import logging

from parallels.core.registry import Registry
from parallels.core.runners.unix.base import UnixRunner
from parallels.core.utils.common.threading_utils import synchronized
from parallels.core.utils.migrator_utils import secure_write_open
from parallels.core.utils.steps_profiler import get_default_steps_profiler
from parallels.core.utils import unix_utils, ssh_utils
from parallels.core.utils.common import default
from parallels.core.utils.unix_utils import format_command

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default steps profiler, obtained once at import time.
# NOTE(review): neither `logger` nor `profiler` is referenced by SSHRunner in
# this module; verify against the rest of the codebase before removing.
profiler = get_default_steps_profiler()


class SSHRunner(UnixRunner):
    """Run commands on a remote Unix server over SSH.

    Commands are sent through ``ssh_utils.run_unchecked`` on the given SSH
    connection. File transfers are done by running ``rsync`` on the migrator
    server, which connects to this server over ``ssh`` using a key obtained
    from the context's SSH key pool.
    """

    def __init__(self, ssh, server):
        """
        :param ssh: SSH connection object handed to ``ssh_utils.run_unchecked``
        :param server: remote server; must provide ``description()``, ``ip()``
            and ``settings().ssh_auth`` (with ``username`` and ``port``)
        """
        super(SSHRunner, self).__init__(host_description=server.description(), hostname=server.ip())
        self.ssh = ssh
        self._server = server

    @property
    def codepage(self):
        """Output codepage passed to ``ssh_utils.run_unchecked`` when the caller gives none."""
        return 'utf-8'

    def _run_unchecked_no_logging(
        self, cmd, args=None, stdin_content=None, output_codepage=None, error_policy='strict', env=None
    ):
        """Execute a command with a list of positional 'args'.

        The command line is built by ``_format_run_command``; the result of
        ``ssh_utils.run_unchecked`` is returned as is.
        """
        command_str = self._format_run_command(cmd, args)
        return self._run_command_str_unchecked(command_str, stdin_content, output_codepage, error_policy, env)

    def _sh_unchecked_no_logging(
        self, cmd, args=None, stdin_content=None, output_codepage=None, error_policy='strict',
        env=None, log_output=True, working_dir=None, redirect_output_file=None
    ):
        """Execute a shell command built by ``_format_sh_command`` from ``cmd`` and ``args``.

        The result of ``ssh_utils.run_unchecked`` is returned as is.

        :raises NotImplementedError: if ``working_dir`` or ``redirect_output_file``
            is given; neither is supported by this runner
        """
        # log_output is accepted for interface compatibility but is not used here.
        if working_dir is not None:
            raise NotImplementedError()
        if redirect_output_file is not None:
            raise NotImplementedError()

        command_str = self._format_sh_command(cmd, args)
        return self._run_command_str_unchecked(command_str, stdin_content, output_codepage, error_policy, env)

    def _run_command_str_unchecked(self, command_str, stdin_content, output_codepage, error_policy, env):
        """Send an already formatted command line over SSH and return the raw result.

        Falls back to ``self.codepage`` when ``output_codepage`` is empty, and
        applies the default ``LANG`` via ``_get_env``.
        """
        return ssh_utils.run_unchecked(
            self.ssh,
            command_str,
            stdin_content,
            output_codepage or self.codepage,
            error_policy,
            self._get_env(env),
        )

    @staticmethod
    def _get_env(env):
        """Return a copy of ``env`` (or a new dict) with ``LANG`` defaulted to 'en_US.utf-8'.

        A copy is made so that the caller's dictionary is never modified.
        """
        env = dict(default(env, {}))
        # Execute all commands under 'en_US.utf-8' locale, if it is not specified,
        # so migrator works in the same way regardless of server's locale
        env.setdefault('LANG', 'en_US.utf-8')
        return env

    def upload_file(self, local_filename, remote_filename):
        """Copy ``local_filename`` (migrator server) to ``remote_filename`` (this server) with rsync."""
        self._transfer_file_with_rsync(remote_filename, local_filename, direction='from_local_to_remote')

    @synchronized
    def upload_file_content(self, filename, content):
        """Write ``content`` to ``filename`` on this server.

        The content is written to a session temp file on the migrator server and
        transferred with rsync. The temp file is removed afterwards, even if the
        write or the transfer fails.
        """
        # NOTE(review): the temp file name is derived from the fixed key
        # 'ssh_upload', which is presumably why calls are serialized with
        # @synchronized - confirm before relaxing the locking.
        global_context = Registry.get_instance().get_context()
        local_temp_file = global_context.migrator_server.get_session_file_path(
            global_context.local_temp_filename.get('ssh_upload')
        )
        try:
            # The write is inside the try so that a partially written temp file
            # is still cleaned up when writing fails.
            with secure_write_open(local_temp_file) as fp:
                fp.write(content)
            self._transfer_file_with_rsync(filename, local_temp_file, direction='from_local_to_remote')
        finally:
            with global_context.migrator_server.runner() as runner:
                runner.remove_file(local_temp_file)

    def get_file(self, remote_filename, local_filename):
        """Copy ``remote_filename`` (this server) to ``local_filename`` (migrator server) with rsync."""
        self._transfer_file_with_rsync(remote_filename, local_filename, direction='from_remote_to_local')

    def _transfer_file_with_rsync(self, remote_filename, local_filename, direction):
        """Copy a file between the migrator server and this server with rsync.

        rsync runs on the migrator server and reaches this server over ssh,
        authenticating with the key obtained from the context's SSH key pool.

        :param remote_filename: path on this (remote) server
        :param local_filename: path on the migrator server
        :param direction: 'from_local_to_remote' or 'from_remote_to_local'
        :raises ValueError: if ``direction`` is not one of the values above
        """
        remote_spec = "{ssh_user}@{remote_ip}:{remote_filename}"
        local_spec = "{local_filename}"
        # Validate first, so an invalid call fails before any key is requested
        # from the key pool.
        if direction == 'from_local_to_remote':
            source, destination = local_spec, remote_spec
        elif direction == 'from_remote_to_local':
            source, destination = remote_spec, local_spec
        else:
            raise ValueError(
                "Invalid direction passed to _transfer_file_with_rsync function: %r" % (direction,)
            )

        global_context = Registry.get_instance().get_context()
        key_pathname = global_context.ssh_key_pool.get(
            self._server, global_context.migrator_server
        ).key_pathname
        ssh_auth = self._server.settings().ssh_auth

        # Key-only authentication; the remote host key is not verified.
        ssh_command = format_command(
            'ssh -i {key} -p {port} '
            '-o PasswordAuthentication=no -o StrictHostKeyChecking=no -o GSSAPIAuthentication=no',
            key=key_pathname,
            port=ssh_auth.port,
        )
        with global_context.migrator_server.runner() as runner:
            # -l: copy symlinks as symlinks; --chmod: files read-only for the
            # owner, directories rwx for the owner, no group/other access.
            runner.sh(
                "rsync -l --chmod=Fu=r,Du=rwx,go= -e {ssh_command} " + source + " " + destination,
                dict(
                    local_filename=local_filename,
                    ssh_user=ssh_auth.username,
                    remote_ip=self._server.ip(),
                    remote_filename=remote_filename,
                    ssh_command=ssh_command,
                )
            )

    def get_file_contents(self, remote_filename):
        """Return the result of running ``/bin/cat`` on ``remote_filename``.

        ``log_output=False`` is passed, presumably to keep file contents out of the log.
        """
        return self.run('/bin/cat', [remote_filename], log_output=False)

    def move(self, src_path, dst_path):
        """Move/rename ``src_path`` to ``dst_path`` with ``/bin/mv``."""
        return self.run('/bin/mv', [src_path, dst_path])

    def get_files_list(self, path):
        """Return the listing of ``path`` as produced by ``unix_utils.get_files_list``."""
        return unix_utils.get_files_list(self, path)

    def remove_file(self, filename):
        """Remove ``filename`` with ``/bin/rm -f``."""
        return self.run('/bin/rm', ['-f', filename])

    def remove_directory(self, directory, is_remove_root=True):
        """Recursively remove ``directory`` with ``/bin/rm -rf``.

        :raises NotImplementedError: if ``is_remove_root`` is false; removing
            only the directory's contents is not supported
        """
        if not is_remove_root:
            raise NotImplementedError()

        return self.run('/bin/rm', ['-rf', directory])

    def mkdir(self, dirname):
        """Create ``dirname`` (including missing parents) with ``/bin/mkdir -p``."""
        return self.run('/bin/mkdir', ['-p', dirname])

    def file_exists(self, filename):
        """Check whether ``filename`` exists, via ``unix_utils.file_exists``."""
        return unix_utils.file_exists(self, filename)

    def is_dir(self, path):
        """Check whether ``path`` is a directory, via ``unix_utils.is_directory``."""
        return unix_utils.is_directory(self, path)
