from socket import socket, AF_INET, AF_INET6, SOCK_STREAM
import re
import logging

from parallels.core.utils.common.ip import is_ipv6
from parallels.core.utils.windows_utils import get_binary_full_path
from parallels.hosting_check import ServiceIssue
from parallels.hosting_check import Severity
from parallels.hosting_check import ServiceIssueType
from parallels.hosting_check.messages import MSG
from parallels.hosting_check import NonZeroExitCodeException

# Module-level logger named after this module; used for debug traces of
# exceptions caught while checking services.
logger = logging.getLogger(__name__)


class ServiceChecker(object):
    """Check that the processes of services are running and their ports are open.

    For every service the checker lists the processes on the service's server,
    looks for the expected process names, then inspects the output of a
    netstat-like tool for the expected ports and, optionally, tries to open
    a TCP connection to each port.

    Successful command output and the detected netstat-like tool are cached
    per server for the lifetime of the checker instance.
    """

    def __init__(self):
        # (server, command_key) -> output of the command executed on the server
        self._command_cache = {}
        # server -> name of the netstat-like tool found there (None if absent)
        self._netstat_tool_cache = {}

    def check(self, services_to_check):
        """Check all given services and return the list of detected issues.

        :param services_to_check: iterable of objects providing ``service``,
            ``server``, ``host`` and ``description`` attributes
        :return: list of ServiceIssue objects, empty if no problem was found
        """
        issues = []
        for service_to_check in services_to_check:
            self._check_single_service(service_to_check, issues)
        return issues

    def _check_single_service(self, service_to_check, issues):
        """Check every process of one service, appending problems to `issues`.

        Ports of a process are checked only when the process is running.
        Any unexpected exception is reported as an internal error issue
        instead of being propagated, so one broken check does not stop
        the checks of the remaining services.
        """
        try:
            for process in service_to_check.service.processes:
                if self._check_processes(service_to_check, process, issues):
                    self._check_ports(service_to_check, process, issues)
        except KeyboardInterrupt:
            # On old Python versions KeyboardInterrupt derives from Exception:
            # re-raise it so that user interruption is never swallowed below.
            raise
        except Exception as e:
            self._report_internal_error(service_to_check, e, issues)

    @staticmethod
    def _report_internal_error(service_to_check, error, issues):
        """Log `error` and append a SERVICE_INTERNAL_ERROR warning to `issues`."""
        logger.debug(u"Exception:", exc_info=error)
        issues.append(
            ServiceIssue(
                severity=Severity.WARNING,
                category=ServiceIssueType.SERVICE_INTERNAL_ERROR,
                problem=MSG(
                    ServiceIssueType.SERVICE_INTERNAL_ERROR,
                    description=service_to_check.description,
                    type=service_to_check.service.type,
                    error_message=str(error)
                )
            )
        )

    def _check_processes(self, service_to_check, process, issues):
        """Return True if any of the process names is found in the process list.

        Otherwise a SERVICE_NOT_STARTED error is appended to `issues` and
        False is returned. If the listing command exits with non-zero code,
        a single internal error warning is appended as well.
        """
        process_list = None
        # Do not touch the server at all when there is nothing to look for.
        if process.names:
            try:
                # The listing command does not depend on the process name,
                # so it is executed once instead of once per name.
                args = {}
                if service_to_check.service.is_windows:
                    cmd = "{tasklist_bin}"
                    args['tasklist_bin'] = get_binary_full_path(service_to_check.server, 'tasklist')
                else:
                    cmd = "ps ax"
                process_list = self._cached_sh(service_to_check.server, 'tasklist', cmd, args)
            except NonZeroExitCodeException as e:
                self._report_internal_error(service_to_check, e, issues)

        if process_list is not None:
            for process_name in process.names:
                if process_name in process_list:
                    return True

        issues.append(
            ServiceIssue(
                severity=Severity.ERROR,
                category=ServiceIssueType.SERVICE_NOT_STARTED,
                problem=MSG(
                    ServiceIssueType.SERVICE_NOT_STARTED,
                    description=service_to_check.description,
                    type=service_to_check.service.type
                )
            )
        )
        return False

    def _check_ports(self, service_to_check, process, issues):
        """Check every port of the process, appending problems to `issues`.

        A port missing from the netstat-like tool output is reported as
        SERVICE_PORT_IS_CLOSED. Otherwise, if the service requests it,
        a TCP connection is attempted and a failure is reported as
        SERVICE_CONNECTION_ERROR. When no netstat-like tool is available
        on a non-Windows server, only the connection check is performed.
        """
        for port in process.ports:
            try:
                args = {}
                if service_to_check.service.is_windows:
                    cmd = '{netstat_bin} -an'
                    args['netstat_bin'] = get_binary_full_path(service_to_check.server, 'netstat')
                else:
                    netstat_tool = self._get_netstat_tool(service_to_check.server)
                    if netstat_tool is not None:
                        cmd = "%s -tpln" % netstat_tool
                    else:
                        # If no netstat-like tool is installed, silently skip the checks
                        cmd = None

                if cmd is not None:
                    result = self._cached_sh(service_to_check.server, 'netstat', cmd, args)
                    # NOTE(review): the pattern matches ":<port>" in any column,
                    # so on Windows ("-an" is not limited to listening sockets)
                    # a remote address with the same port presumably matches
                    # too - confirm whether that is acceptable.
                    if re.search(r":%s\s" % port, result) is None:
                        issues.append(
                            ServiceIssue(
                                severity=Severity.ERROR,
                                category=ServiceIssueType.SERVICE_PORT_IS_CLOSED,
                                problem=MSG(
                                    ServiceIssueType.SERVICE_PORT_IS_CLOSED,
                                    description=service_to_check.description,
                                    port=port,
                                    service=service_to_check.service.type
                                )
                            )
                        )
                        # No point in connecting to a port nobody listens on.
                        continue

                if service_to_check.service.check_connection:
                    if not self._check_connection(service_to_check.host, port):
                        issues.append(
                            ServiceIssue(
                                severity=Severity.ERROR,
                                category=ServiceIssueType.SERVICE_CONNECTION_ERROR,
                                problem=MSG(
                                    ServiceIssueType.SERVICE_CONNECTION_ERROR,
                                    description=service_to_check.description,
                                    port=port,
                                    service=service_to_check.service.type
                                )
                            )
                        )
            except NonZeroExitCodeException as e:
                self._report_internal_error(service_to_check, e, issues)

    def _get_netstat_tool(self, server):
        """Return the name of the netstat-like tool available on the server.

        Tools are probed in order of preference and the first one found is
        used. The result (None if no tool is installed) is cached per server.
        """
        if server not in self._netstat_tool_cache:
            self._netstat_tool_cache[server] = None

            for tool in ['netstat', 'ss']:
                if self._tool_exists(server, tool):
                    self._netstat_tool_cache[server] = tool
                    # Stop at the first available tool: the list is ordered
                    # by preference and further probing is a wasted command.
                    break

        return self._netstat_tool_cache[server]

    @staticmethod
    def _tool_exists(server, tool):
        """Return True if ``which <tool>`` exits with zero code on the server."""
        with server.runner() as runner:
            exit_code, _, _ = runner.sh_unchecked('which {tool}', dict(tool=tool))
        return exit_code == 0

    def _cached_sh(self, server, command_key, command, args):
        """Run the command on the server, or return its cached output.

        The cache key is the (server, command_key) pair, not the command
        itself. Only successful runs are cached: if ``runner.sh`` raises,
        nothing is stored and the next call executes the command again.
        """
        cache_key = (server, command_key)
        if cache_key not in self._command_cache:
            with server.runner() as runner:
                self._command_cache[cache_key] = runner.sh(command, args)
        return self._command_cache[cache_key]

    def _check_connection(self, host, port):
        """Return True if a TCP connection to host:port can be established."""
        s = socket(AF_INET6 if is_ipv6(host) else AF_INET, SOCK_STREAM)
        try:
            # connect_ex() reports connection failures as an error code,
            # but may still raise (e.g. on address resolution problems),
            # so the socket is closed in "finally" to avoid leaking it.
            result = s.connect_ex((host, port))
        finally:
            s.close()
        return result == 0
