import logging
import os
from collections import defaultdict

from parallels.core.dump.data_model import DatabaseServer
from parallels.core.utils.common import obj, cached, find_first
from parallels.core.utils.common.ip import ip_distance
from parallels.core.utils.database_utils import get_windows_mysql_client, split_mssql_host_parts_str,\
    join_mssql_hostname_and_instance
from parallels.core.migrator import Migrator as CommonMigrator
from parallels.core.utils.json_utils import read_json
from parallels.plesk.source.helm3.migrated_subscription import Helm3MigratedSubscription
from parallels.plesk.source.helm3.web_files import HelmWebFiles
from parallels.plesk.source.plesk.infrastructure_checks import checks as infrastructure_checks
from parallels.plesk.source.plesk.infrastructure_checks.checks import NodesPair
from parallels.plesk.source.helm3 import helm_constants
from parallels.plesk.source.helm3 import messages
from parallels.plesk.source.helm3.helm3_agent import NoServiceIpHelmDatabase
from parallels.plesk.source.helm3.global_context import Helm3GlobalMigrationContext
from parallels.plesk.source.helm3.session_files import Helm3SessionFiles
from parallels.plesk.source.helm3 import connections


class Migrator(CommonMigrator):
    """Migrator for Helm 3 source panel

    Specializes the common migrator with Helm 3 specific connections, session files,
    resolution of source web/database servers and Windows infrastructure checks.
    """

    logger = logging.getLogger(__name__)

    def _load_connections_configuration(self, global_context, target_panel_type):
        """Create connections object for Helm 3 migration

        :param global_context: global migration context
        :param target_panel_type: name of target panel, resolved with _get_target_panel_by_name
        :rtype: parallels.plesk.source.helm3.connections.Helm3MigratorConnections
        """
        return connections.Helm3MigratorConnections(
            global_context, self._get_target_panel_by_name(target_panel_type), self._get_migrator_server()
        )

    def _create_global_context(self):
        """Create global migration context for Helm 3

        Source is marked as having no DNS forwarding.

        :rtype: parallels.plesk.source.helm3.global_context.Helm3GlobalMigrationContext
        """
        context = Helm3GlobalMigrationContext()
        context.source_has_dns_forwarding = False
        return context

    def _create_session_files(self):
        """Create object describing Helm 3 session files on the migrator server

        :rtype: parallels.plesk.source.helm3.session_files.Helm3SessionFiles
        """
        return Helm3SessionFiles(self.global_context.conn, self._get_migrator_server())

    @property
    def web_files(self):
        """Object to list files to be transferred from source to target

        :rtype: parallels.core.utils.paths.web_files.BaseWebFiles
        """
        return HelmWebFiles()

    def _get_src_db_server(self, database_server_host, database_server_port, database_server_type, server_dump):
        """Get source database server by database object

        This function takes information about database server from backup dump and modifies it in the following way:
        1) If database server hostname is hostname of some Helm 3 server - replace it with IP address, used for
        migration.
        2) If database server hostname is some IP address of some Helm 3 server - replace it
        with IP address (of the same server) used for migration.
        3) Otherwise - consider that is some remote host, and leave it as is. In that case there could be
        issues if host could be resolved / is accessible from source but could not be resolved / is not accessible
        from target.

        This works both for MySQL and MSSQL. In case of MSSQL both named and default instances are handled.

        :param database_server_host: host of database server, as specified in the dump
        :param database_server_port: port of database server, as specified in the dump
        :param database_server_type: type of database server (not used by this implementation)
        :param server_dump: dump to look for database server in
        :return: DatabaseServer model, or None if there is no such database server in the dump
        """
        source_db_servers = {(dbs.host, dbs.port): dbs for dbs in server_dump.iter_db_servers()}
        server_model = source_db_servers.get((database_server_host, database_server_port))
        if server_model is None:
            # If can not find information about database server - return None, which actually means error
            return None

        # Use the same function both for MSSQL and MySQL. In case of MySQL instance name will be always None.
        host, instance, port = split_mssql_host_parts_str(server_model.host)
        servers_info = self.global_context.helm3_agent.get_servers_info()
        all_ips = servers_info.get_server_ips_by_host_or_ip(host)
        if all_ips:
            # Host belongs to one of Helm 3 servers: rewrite it to the IP used for migration,
            # keeping MSSQL instance name and port intact
            db_server_params = server_model.as_dictionary()
            migration_ip = self.get_helm_server_migration_ip(all_ips[0])
            db_server_params['host'] = join_mssql_hostname_and_instance(migration_ip, instance, port)
            server_model = DatabaseServer(**db_server_params)

        return server_model

    def shallow_dump_supported(self, source_id):
        """Shallow dump is always supported for Helm 3 source; source_id is not checked"""
        return True

    # ======================== copy web content ===============================

    def _get_subscription_content_ip(self, subscription):
        """Get IP address of the source server to copy subscription web content from

        For a fake subscription, the first addon domain with 'phosting' hosting type is used.
        For a real subscription with 'phosting' hosting type, the subscription itself is used.

        :type subscription: parallels.core.migrated_subscription.MigratedSubscription
        :return: IP address, or None if there is no web hosting to copy content of
        :raises NoServiceIpHelmDatabase: propagated from get_domain_source_web_ip
        """
        if subscription.is_fake:
            for site in subscription.raw_dump.iter_addon_domains():
                if site.hosting_type == 'phosting':
                    return self.get_domain_source_web_ip(subscription, site.name)
            return None
        elif subscription.raw_dump.hosting_type == 'phosting':
            return self.get_domain_source_web_ip(subscription, subscription.name)
        else:
            return None

    def get_domain_source_web_ip(self, subscription, domain_name):
        """Get IP address of a server from which we should copy web content

        Domain name is name of some domain within specified subscription.
        IP of web service is tried first; if the domain has no web service IP, IP of FTP service is used.

        :type subscription: parallels.core.migrated_subscription.MigratedSubscription
        :type domain_name: str | unicode
        :rtype: str | unicode
        :raises NoServiceIpHelmDatabase: when domain has neither web nor FTP service IP
            (the exception raised for the web service is re-raised)
        """
        try:
            return self.get_service_ip(
                helm_constants.WEB_SERVICE_TYPE, domain_name, 'web'
            )
        except NoServiceIpHelmDatabase as e:
            self.logger.debug(messages.LOG_EXCEPTION, exc_info=e)
            self.logger.debug(
                messages.CANNOT_GET_HELM_WEB_NODE_IP.format(
                    domain_name=domain_name
                )
            )
            try:
                return self.get_service_ip(
                    helm_constants.FTP_SERVICE_TYPE, domain_name, 'ftp'
                )
            except NoServiceIpHelmDatabase:
                # Report the original (web) problem rather than the FTP fallback one
                raise e

    def get_domain_source_web_server(self, subscription, domain_name):
        """Get server object for source web server of specified domain

        Domain name is name of some domain within specified subscription.

        :type subscription: parallels.core.migrated_subscription.MigratedSubscription
        :type domain_name: str | unicode
        :rtype: parallels.core.connections.source_server.SourceServer | None
        :raises NoServiceIpHelmDatabase: propagated from get_domain_source_web_ip
        """
        source_web_ip = self.get_domain_source_web_ip(subscription, domain_name)
        if source_web_ip is not None:
            return self.global_context.conn.helm3.get_source_server_by_ip(source_web_ip)
        else:
            return None

    def _get_source_web_node(self, subscription_name):
        """Get source server object to copy web content of subscription from

        :type subscription_name: str | unicode
        :rtype: parallels.core.connections.source_server.SourceServer | None
        """
        subscription = self._create_migrated_subscription(subscription_name)
        source_ip = self._get_subscription_content_ip(subscription)
        if source_ip is not None:
            return self.global_context.conn.helm3.get_source_server_by_ip(source_ip)
        else:
            return None

    # ======================== infrastructure checks ==========================

    def _check_infrastructure_connections(self, report, safe):
        """Check infrastructure connections: web nodes (rsync) and database nodes (MySQL rsync and MSSQL)

        :param report: report to add check results to (sub-reports are created for web and DB servers)
        :param safe: context manager factory used to catch and report failures of each group of checks
        """
        self.logger.info(messages.LOG_CHECK_CONNECTION_REQUIREMENTS)
        checks = infrastructure_checks.InfrastructureChecks()

        web_report = report.subtarget(messages.REPORT_CONNECTIONS_WEB_NODES, None)
        with safe(web_report, messages.FAILED_TO_CHECK_CONNECTIONS_WEB_NODES):
            self._check_windows_copy_web_content_rsync_connections(checks, web_report)

        db_report = report.subtarget(messages.REPORT_DB_SERVERS_CONNECTIONS, None)
        with safe(db_report, messages.FAILED_TO_CHECK_CONNECTIONS_DB_NODES):
            self._check_windows_mysql_copy_db_content_rsync_connections(checks, db_report)
            self._check_windows_copy_mssql_db_content(db_report)

    def _check_disk_space(self, report, safe):
        """Check disk space requirements for source and target servers

        :param report: report to create disk space sub-report in
        :param safe: not used here - the disk space check is not wrapped into a safe context
        """
        self.logger.info(messages.LOG_CHECK_DISK_SPACE_REQUIREMENTS)
        disk_space_report = report.subtarget(messages.REPORT_TARGET_DISK_SPACE_REQUIREMENTS, None)
        self._check_disk_space_windows(disk_space_report)

    def _check_disk_space_windows(self, report):
        """Check that target Windows nodes have enough disk space for migrated content

        Collects source disk usage of web content (for each non-fake subscription), of MySQL databases
        (only those going to Windows target database servers) and of MSSQL databases, groups it by
        target node and runs WindowsDiskSpaceChecker for each target node.

        :param report: report to add disk space issues to
        """
        self.logger.info(messages.GET_DISK_USAGE_ON_WINDOWS_SERVER)

        helm3_agent = self.global_context.helm3_agent

        def get_database_usage(database_info, service_type):
            # NOTE(review): usage is requested per subscription and service type, not per database,
            # so a subscription with several databases of the same type contributes its usage once
            # per database - confirm whether this over-estimate is intended.
            target = database_info.target_database_server
            return obj(
                usage_source=helm3_agent.get_domain_diskspace_usage(
                    database_info.subscription_name,
                    service_type
                ),
                subscription_name=database_info.subscription_name,
                target_node=target.panel_server,
                db=obj(
                    db_name=database_info.database_name,
                    db_target_node=target
                )
            )

        web_diskspace_usages = []
        mssql_diskspace_usages = []
        mysql_diskspace_usages = []

        for subscription in self.iter_all_subscriptions():
            if subscription.is_fake:
                continue

            web_diskspace_usages.append(
                obj(
                    usage_source=helm3_agent.get_domain_diskspace_usage(
                        subscription.name,
                        helm_constants.WEB_SERVICE_TYPE
                    ),
                    subscription_name=subscription.name,
                    target_node=subscription.web_target_server
                )
            )

        for database_info in self._list_databases_to_copy():
            source_type = database_info.source_database_server.type()
            if source_type == 'mysql' and database_info.target_database_server.is_windows():
                mysql_diskspace_usages.append(
                    get_database_usage(database_info, helm_constants.MYSQL_SERVICE_TYPE)
                )
            elif source_type == 'mssql':
                mssql_diskspace_usages.append(
                    get_database_usage(database_info, helm_constants.MSSQL_SERVICE_TYPE)
                )

        self.logger.info(messages.CHECKING_IF_THERE_IS_ENOUGH_DISK_SPACE)
        checker_target = infrastructure_checks.WindowsDiskSpaceChecker()
        diskspace_usages = web_diskspace_usages + mysql_diskspace_usages + mssql_diskspace_usages
        for target_node in set(usage.target_node for usage in diskspace_usages):
            node_web_usages = [u for u in web_diskspace_usages if u.target_node == target_node]
            node_mysql_usages = [u for u in mysql_diskspace_usages if u.target_node == target_node]
            node_mssql_usages = [u for u in mssql_diskspace_usages if u.target_node == target_node]

            # The same subscription is listed once for web content and once per each of its databases:
            # pass every domain name only once, keeping first-seen order
            subscription_names = []
            seen_subscription_names = set()
            for usage in diskspace_usages:
                if usage.target_node == target_node and usage.subscription_name not in seen_subscription_names:
                    seen_subscription_names.add(usage.subscription_name)
                    subscription_names.append(usage.subscription_name)

            usages_source_mysql = defaultdict(list)  # { (db server host, db server port): [ usage ] }
            for usage in node_mysql_usages:
                db_target_node = usage.db.db_target_node
                usages_source_mysql[(db_target_node.host(), db_target_node.port())].append(usage.usage_source)

            usages_source_mssql = defaultdict(list)  # { db server host: [ usage ] }
            for usage in node_mssql_usages:
                usages_source_mssql[usage.db.db_target_node.host()].append(usage.usage_source)

            checker_target.check_with_source_usages(
                target_node,
                mysql_bin=get_windows_mysql_client(target_node),
                usage_source_web=sum(u.usage_source for u in node_web_usages),
                usages_source_mysql_db=usages_source_mysql,
                usages_source_mssql_db=usages_source_mssql,
                domains=subscription_names,
                mysql_databases=[u.db for u in node_mysql_usages],
                mssql_databases=[u.db for u in node_mssql_usages],
                report=report
            )

    def _check_windows_copy_web_content_rsync_connections(self, checks, report):
        """Check rsync connections between source and target web nodes, per domain

        Domains with 'phosting' hosting type (addon domains and non-fake subscriptions themselves)
        are grouped by pair of source and target web servers.

        NOTE(review): get_domain_source_web_server raises NoServiceIpHelmDatabase for a domain without
        web and FTP service IP, which aborts the whole check rather than skipping that domain.

        :param checks: infrastructure_checks.InfrastructureChecks instance
        :param report: report to add check results to
        """
        domains_by_servers = defaultdict(list)

        for subscription in self.global_context.iter_all_subscriptions():
            target_server = subscription.web_target_server
            for site in subscription.raw_dump.iter_addon_domains():
                if site.hosting_type == 'phosting':
                    source_server = self.get_domain_source_web_server(subscription, site.name)
                    domains_by_servers[NodesPair(source_server, target_server)].append(site.name)
            if not subscription.is_fake and subscription.raw_dump.hosting_type == 'phosting':
                source_server = self.get_domain_source_web_server(subscription, subscription.name)
                domains_by_servers[NodesPair(source_server, target_server)].append(subscription.name)

        checks.check_copy_web_content_per_domain(
            domains=domains_by_servers.items(),
            report=report,
            check_function=lambda nodes_pair_info: self._check_windows_copy_content(
                checker=infrastructure_checks.WindowsFileCopyBetweenNodesRsyncChecker(),
                nodes_pair_info=nodes_pair_info
            ),
        )

    # ======================== utility functions ==============================

    def _create_migrated_subscription(self, name):
        """Create Helm 3 specific migrated subscription object

        :type name: str | unicode
        :rtype: parallels.plesk.source.helm3.migrated_subscription.Helm3MigratedSubscription
        """
        return Helm3MigratedSubscription(self, name)

    @cached
    def _is_fake_domain(self, subscription_name):
        """Detect if subscription is fake (so it does not exist on source)

        In case when client has one domain with enabled web or ftp we don't
        create fake domain. In all other cases we create fake domain.
        Result is cached per subscription name.
        """
        return self.global_context.helm3_agent.is_fake_domain(subscription_name)

    def get_service_ip(self, service_type_id, domain_name, content_type):
        """Get migration IP of the server where specified service of domain is located

        :param service_type_id: Helm service type (one of helm_constants.*_SERVICE_TYPE)
        :type domain_name: str | unicode
        :param content_type: human-readable content type, used only in error message
        :rtype: str | unicode
        :raises NoServiceIpHelmDatabase: when agent returns no IP addresses for the service
        """
        # Command returns all IP addresses of a server where specified service of domain is located
        service_ips = self.global_context.helm3_agent.get_service_ips(domain_name, service_type_id)
        if service_ips:
            return self.get_helm_server_migration_ip(service_ips[0])
        else:
            raise NoServiceIpHelmDatabase(
                messages.DOMAIN_HAS_NO_CONTENT_OF_THAT_TYPE.format(
                    content_type=content_type, domain_name=domain_name
                )
            )

    def _get_rsync(self, source_server, target_server, source_ip=None):
        """Get rsync object from the pool for pair of source and target servers

        If source_server is None, it is looked up by source_ip.
        Source vhosts directory reported by Helm 3 agent is used as rsync base directory on source.
        """
        if source_server is None:
            source_server = self.global_context.conn.helm3.get_source_server_by_ip(source_ip)

        source_vhosts_dir = self.global_context.helm3_agent.get_vhosts_dir_source(source_server)

        return self.global_context.rsync_pool.get(source_server, target_server, source_vhosts_dir)

    def get_helm_server_migration_ip(self, service_ip):
        """Get IP used for communication with Helm server, by any other of its IP addresses

        Selection order among all IP addresses of the server:
        1) IP of a source server from configuration file;
        2) first IP from migration IPs file which belongs to the server;
        3) IP with minimal distance to main source node IP.
        If server IP addresses are unknown, service_ip is returned as is.

        :type service_ip: str | unicode
        :rtype: str | unicode
        """
        servers_info = self.global_context.helm3_agent.get_servers_info()
        all_ips = servers_info.get_all_server_ips_by_ip(service_ip)
        if not all_ips:
            # Unknown server (None) or no IPs (empty list): nothing to choose from,
            # and min() below would fail on an empty sequence
            return service_ip

        # First, if server is present in configuration file, use IP from configuration file
        server = find_first(self.global_context.conn.helm3.source_servers, lambda s: s.ip() in all_ips)
        if server is not None:
            return server.ip()

        # Then, try to check file with preferred IPs used for migration
        for ip in self._read_migration_ips_file():
            if ip in all_ips:
                return ip

        # Otherwise, look for an IP with minimal difference from main node IP
        main_node_ip = self.global_context.conn.helm3.get_main_source_server().ip()
        return min(all_ips, key=lambda x: ip_distance(x, main_node_ip))

    @cached
    def _read_migration_ips_file(self):
        """Read file which contains IP addresses to use for migration

        File is JSON, with dictionary which has:
        - server hostname as a key
        - IP address to use for migration as value
        This function returns just a list of IP addresses (empty list if the file does not exist).

        :rtype: list[str | unicode]
        """
        filename = self.global_context.session_files.get_path_to_migration_ips()
        if os.path.exists(filename):
            # list() so that Python 3 returns a real list (as documented), not a dict view
            return list(read_json(filename).values())
        else:
            return []
