import logging

from parallels.core.utils.common import format_list
from parallels.core.utils.json_utils import write_json
from parallels.plesk.source.helm3 import messages
from parallels.core.actions.base.common_action import CommonAction

logger = logging.getLogger(__name__)


class CheckHelmNodes(CommonAction):
    """Check connections to Helm nodes, write results to JSON file

    For each Helm server involved in the migration:
    - deploy the RPC agent, if it is not deployed yet
    - try to execute a simple command on the server
    - print the result to the console

    The migrator also writes a JSON file with information about failed servers, for GUI integration.
    """

    def get_description(self):
        """Return the human-readable description of this action."""
        return messages.CHECK_HELM_NODES_ACTION_DESCRIPTION

    def get_failure_message(self, global_context):
        """Return the message reported when this action fails.

        :type global_context: parallels.plesk.source.helm3.global_context.Helm3GlobalMigrationContext
        """
        return messages.CHECK_HELM_NODES_ACTION_FAILURE

    def run(self, global_context):
        """Check connection to every involved source Helm server, write failed ones to a JSON file.

        The JSON file is written even when every server is reachable (then it holds an empty object).

        :type global_context: parallels.plesk.source.helm3.global_context.Helm3GlobalMigrationContext
        """
        servers = self._collect_source_servers(global_context)
        logger.info(messages.LOG_CHECK_CONNECTIONS_TO_SERVERS_LIST.format(
            count=len(servers), servers=format_list([server.ip() for server in servers])
        ))

        failed_servers_info = self._check_servers(global_context, servers)

        write_json(
            global_context.session_files.get_path_to_helm_communication_failure_servers(),
            failed_servers_info, pretty_print=True
        )

    @staticmethod
    def _collect_source_servers(global_context):
        """Return the distinct source servers used to copy mail and web content, sorted by IP.

        :type global_context: parallels.plesk.source.helm3.global_context.Helm3GlobalMigrationContext
        :rtype: list
        """
        servers = set()
        for subscription in global_context.iter_all_subscriptions():
            # servers involved in copying mail content
            if subscription.mail_source_server is not None:
                servers.add(subscription.mail_source_server)

            # servers involved in copying web content: addon domains with physical hosting...
            for site in subscription.raw_dump.iter_addon_domains():
                if site.hosting_type == 'phosting':
                    servers.add(global_context.migrator.get_domain_source_web_server(subscription, site.name))
            # ...and the subscription's own domain, unless the subscription is fake
            if not subscription.is_fake and subscription.raw_dump.hosting_type == 'phosting':
                servers.add(global_context.migrator.get_domain_source_web_server(subscription, subscription.name))

        # Set iteration order is arbitrary; sort so that the server list and the "num of count"
        # progress messages come out in a stable order between runs.
        return sorted(servers, key=lambda server: server.ip())

    @staticmethod
    def _check_servers(global_context, servers):
        """Try to run a check command on each server; return info about the servers that failed.

        Failures are logged and collected, never raised, so that every server gets checked.

        :type global_context: parallels.plesk.source.helm3.global_context.Helm3GlobalMigrationContext
        :type servers: list
        :return: mapping of Helm server name -> {'ips': all IPs of that server, 'current_ip': IP used}
        :rtype: dict
        """
        failed_servers_info = {}
        if not servers:
            # Nothing to check: do not query the Helm agent at all
            return failed_servers_info

        # Fetched once for all servers, as it does not depend on the server being checked.
        # NOTE(review): assumes the servers info does not change while the servers are checked.
        servers_info = global_context.helm3_agent.get_servers_info()
        count = len(servers)

        for num, server in enumerate(servers, start=1):
            server_ip = server.ip()
            server_name = servers_info.get_server_name_by_ip(server_ip)
            all_server_ips = servers_info.get_all_server_ips_by_ip(server_ip)
            logger.info(messages.LOG_CHECK_CONNECTIONS_TO_SERVER.format(
                server_name=server_name, server_ip=server_ip, num=num, count=count)
            )
            try:
                with server.runner() as runner:
                    runner.check(server_ip)
            except Exception:
                # Deliberately broad: any failure means the node is unreachable for migration purposes.
                # The details go to the debug log; the user gets a short error message.
                logger.debug(messages.LOG_EXCEPTION, exc_info=True)
                failed_servers_info[server_name] = {
                    'ips': all_server_ips,
                    'current_ip': server_ip
                }
                logger.error(messages.FAILED_TO_CONNECT_TO_HELM3_SERVER.format(
                    server_name=server_name, server_ip=server_ip
                ))
            else:
                logger.info(messages.SUCCESS_CONNECTED_TO_HELM3_SERVER.format(
                    server_name=server_name, server_ip=server_ip
                ))

        return failed_servers_info
