from parallels.core import messages
import logging
import paramiko
from paramiko.ssh_exception import SSHException
from contextlib import closing, contextmanager
from pipes import quote
from parallels.core.utils.steps_profiler import sleep

from parallels.core import MigrationError
from parallels.core.utils.common import poll_data

logger = logging.getLogger(__name__)

class ExecutionError(Exception):
	"""Raised by run() when a remote command exits with a non-zero status.

	The failed command, its exit status and its captured output are kept as
	attributes so that callers can inspect them.
	"""

	def __init__(self, command, exit_status, stdout, stderr):
		self.command = command
		self.exit_status = exit_status
		self.stdout = stdout
		self.stderr = stderr
		super(ExecutionError, self).__init__(
			u"""Remote command "%s" terminated with code %d\nstdout: %s\nstderr: %s""" % (
				command, exit_status, stdout, stderr
			)
		)

# Use this policy to avoid paramiko's missing-host-key warning at WARNING logging level (it makes the migrator output messy).
# Note that how missing host keys should be handled is debatable: silently accepting them, as done here, is probably not the best approach.
class IgnoreMissingKeyPolicy(paramiko.MissingHostKeyPolicy):
	"""Host key policy that silently accepts unknown host keys: no warning is logged and no exception is raised."""

	def missing_host_key(self, client, hostname, key):
		# Deliberately a no-op: the connection proceeds, and the key is neither stored nor reported
		return None

def make_command(cmd, args=()):
	"""Build a shell command line from a program and its arguments.

	Each element (the program included) is converted with str() and
	shell-quoted with quote(), then the parts are joined with single spaces.

	:param cmd: program name or path
	:param args: any iterable of arguments (list, tuple, generator, ...);
		previously only lists were accepted
	:return: command line string
	"""
	# list(args) lets tuples and other iterables be concatenated with [cmd]
	return ' '.join(quote(str(arg)) for arg in [cmd] + list(args))


def run(ssh_client, command, stdin_content=None, output_codepage='utf-8', error_policy='strict'):
	"""Run a command over SSH, failing if it exits with a non-zero status.

	Parameters are the same as for run_unchecked().

	:return: tuple (stdout_content, stderr_content) with decoded output
	:raises ExecutionError: if the command's exit status is not 0
	"""
	exit_status, stdout_content, stderr_content = run_unchecked(
		ssh_client, command, stdin_content,
		output_codepage=output_codepage, error_policy=error_policy
	)
	if exit_status == 0:
		return stdout_content, stderr_content
	raise ExecutionError(command, exit_status, stdout_content, stderr_content)


def run_unchecked(ssh_client, command, stdin_content=None, output_codepage='utf-8', error_policy='strict', env=None):
	"""Execute a command over SSH and return its exit status and decoded output.

	Unlike run(), a non-zero exit status is returned, not raised.

	:param ssh_client: connected paramiko SSH client; its 'hostname' attribute,
		if present, is used in the error message when no channel can be opened
	:param command: either a unicode string (u''), or a byte string that consists of only ASCII symbols
	:param stdin_content: data written to the command's stdin before stdin is closed, or None
	:param output_codepage: codec used to decode stdout and stderr
	:param error_policy: error handling scheme passed to decode() ('strict', 'replace', ...)
	:param env: not supported, must be None
	:return: tuple (exit_status, stdout_content, stderr_content)
	:raises NotImplementedError: if env is not None
	:raises MigrationError: if a channel could not be opened
	"""

	if env is not None:
		raise NotImplementedError(messages.PARAMETER_ENV_IS_NOT_SUPPORTED_2)

	def exec_command():
		try:
			return ssh_client.exec_command(command.encode('utf-8'))
		except SSHException as e:
			# Read the message from 'args': the 'message' attribute is deprecated
			# and does not exist at all on Python 3
			if e.args and e.args[0] == u'Unable to open channel.':
				logger.debug(u"Exception:", exc_info=e)
				# None signals "no result"; presumably poll_data retries on it - see the
				# 'result is None' check below for when all attempts failed
				return None
			else:
				raise

	result = poll_data(exec_command, [5] * 3)
	if result is None:
		# getattr instead of assert: asserts are stripped under -O, and plain
		# paramiko clients have no 'hostname' attribute
		raise MigrationError(
			messages.UNABLE_OPEN_CHANNEL_SSH_CONNECTION_S % getattr(ssh_client, 'hostname', None)
		)
	else:
		stdin, stdout, stderr = result

	channel = stdout.channel

	if stdin_content is not None:
		stdin.write(stdin_content)

	# Close stdin so program waiting for input will get EOF.
	# stdin.close() does nothing beside flushing and setting internal 'closed' flag,
	# so it is needed to explicitly close writing side of channel.
	stdin.close()
	channel.shutdown_write()

	# TODO: This will hang if tool produces a lot of stderr output:
	# stderr is not drained while stdout is being read to the end.
	stdout_content = stdout.read().decode(output_codepage, error_policy)
	stderr_content = stderr.read().decode(output_codepage, error_policy)

	stdout.close()
	stderr.close()

	exit_status = channel.recv_exit_status()

	logger.debug(u"Exit status: %d, stdout: %s, stderr: %s", exit_status, stdout_content, stderr_content)

	return exit_status, stdout_content, stderr_content


class SSHClientWithHostname(paramiko.SSHClient):
	"""paramiko SSH client that remembers its connection parameters.

	connect() retries on I/O errors, and exec_command() / open_sftp()
	reconnect with the remembered parameters when the transport is no longer active.
	"""

	def __init__(self):
		paramiko.SSHClient.__init__(self)
		# Host passed to the last connect() call; None until connect() is called
		self.hostname = None
		# Extra positional/keyword arguments of the last connect() call, reused on reconnect.
		# Initialized here so the attributes always exist, even before connect().
		self.args = ()
		self.kw = {}

	def connect(self, hostname, *args, **kw):
		"""Connect to hostname; arguments are as for paramiko.SSHClient.connect.

		:raises MigrationError: if every connection attempt fails with an I/O error
		"""
		self.hostname = hostname
		self.args = args
		self.kw = kw
		self._connect_multiple_attempts()

	def _connect_multiple_attempts(self):
		"""Connect with the parameters saved by connect(), retrying on IOError.

		Up to 5 attempts are made, 10 seconds apart. Exceptions that are not
		IOError (e.g. paramiko.SSHException) propagate immediately.

		:raises MigrationError: if the last attempt fails with an IOError
		"""
		max_attempts = 5
		interval_between_attempts = 10

		for attempt in range(0, max_attempts):
			try:
				paramiko.SSHClient.connect(self, self.hostname, *self.args, **self.kw)
				if attempt > 0:
					logger.info(messages.SUCCESSFULLY_CONNECTED_S_BY_SSH, self.hostname)
				return
			except IOError as e:
				logger.debug("Exception: ", exc_info=True)
				if attempt >= max_attempts - 1:
					raise MigrationError(
						messages.UNABLE_CONNECT_HOST_BY_SSH_EXCEPTION.format(
							host=self.hostname, exception=str(e)
						)
					)
				else:
					logger.error(
						u"Unable to connect to '{host}' by SSH: {exception}. Retrying in {interval_between_attempts} seconds.".format(
							host=self.hostname,
							interval_between_attempts=interval_between_attempts,
							exception=str(e)))
					sleep(interval_between_attempts, 'Reconnect by SSH')

	# Backward-compatible alias for the former, misspelled method name
	_connect_multiple_attemtps = _connect_multiple_attempts

	def _reconnect_if_inactive(self, warning_message):
		"""Reconnect if the underlying transport is no longer active.

		:param warning_message: message template with one %s placeholder for the host name,
			logged as a warning before reconnecting
		"""
		if not self.get_transport().is_active():
			logger.warning(warning_message % self.hostname)
			self.reconnect()

	def exec_command(self, *args, **kwargs):
		"""paramiko.SSHClient.exec_command, reconnecting first if the connection was lost."""
		self._reconnect_if_inactive(messages.SSH_CONNECTION_S_WAS_UNEXPECTEDLY_CLOSED)
		return super(SSHClientWithHostname, self).exec_command(*args, **kwargs)

	def open_sftp(self, *args, **kwargs):
		"""paramiko.SSHClient.open_sftp, reconnecting first if the connection was lost."""
		self._reconnect_if_inactive(messages.SSH_CONNECTION_S_WAS_UNEXPECTEDLY_CLOSED_1)
		return super(SSHClientWithHostname, self).open_sftp(*args, **kwargs)

	def reconnect(self):
		"""Close the current connection and connect again with the saved parameters."""
		super(SSHClientWithHostname, self).close()
		self._connect_multiple_attempts()

def connect(settings):
	"""Create an SSH client and connect it using settings.

	:param settings: object with 'ip' (host to connect to) and 'ssh_auth', whose
		connect(ip, client) method is called to connect the client
	:return: contextlib.closing wrapper around the connected SSHClientWithHostname;
		use it in a 'with' statement so the client is closed afterwards
	"""
	client = SSHClientWithHostname()
	client.set_missing_host_key_policy(IgnoreMissingKeyPolicy())
	try:
		settings.ssh_auth.connect(settings.ip, client)
	except BaseException:
		# closing() only protects the client once the caller enters the 'with' block,
		# so a failed connect (e.g. authentication error after the transport was started)
		# would otherwise leave a half-open client behind
		client.close()
		raise
	return closing(client)


def run_shell_script(ssh, content, args=''):
	"""Run a shell script over SSH by feeding it to 'sh -s' on stdin.

	:param content: script text, sent to the remote shell's stdin
	:param args: string inserted verbatim (not quoted) after 'sh -s' as script arguments
	:return: (stdout, stderr) as returned by run()
	"""
	command = 'sh -s %s' % args
	return run(ssh, command, stdin_content=content)


@contextmanager
def public_key_ssh_access(src_exec, dst_exec):
	"""Setup temporary key-based SSH access from one host to another.

	`src_exec` and `dst_exec` are functions for executing remote commands
	on source and destination hosts respectively. They must accept
	shell commands, return (stdout, stderr) and fail in case of error.

	The temporary keys are removed when the 'with' block exits, whether
	normally or by an exception.

	Usage:
		with public_key_ssh_access(exec_on_src, exec_on_dst) as key_path:
			# `key_path` is path to private key on source server
			pass
	"""
	key_pathname, public_key = set_up_keys(dst_exec, src_exec)
	try:
		yield key_pathname
	finally:
		# Without finally, an exception in the 'with' body would leave the private key
		# on the source host and the public key in the destination's authorized_keys
		remove_keys(dst_exec, src_exec, key_pathname, public_key)



def set_up_keys(dst_exec, src_exec, authorized_keys_path='~/.ssh/authorized_keys'):
	"""Generate a temporary key pair on the source host and authorize it on the destination host.

	:param dst_exec: function executing a shell script on the destination host, returning (stdout, stderr)
	:param src_exec: function executing a shell script on the source host, returning (stdout, stderr)
	:param authorized_keys_path: authorized_keys file on the destination host; inserted unquoted
		into the script so that '~' is expanded by the remote shell
	:return: tuple (key_pathname, public_key): path of the private key on the source host,
		and the public key line appended to authorized_keys
	"""
	# RSA rather than DSA: OpenSSH 7.0+ rejects ssh-dss keys by default and OpenSSH 9.8+
	# drops DSA support entirely, while RSA keys are accepted by both old and new servers.
	(key_pathname, public_key) = src_exec("""
		set -e;
		ssh_config_dir=~/.ssh
		if [ ! -d "$ssh_config_dir" ]; then
			mkdir -p "$ssh_config_dir"
		fi
		key_filename=`mktemp $ssh_config_dir/id_rsa.XXXXXXXX`;
		echo y | ssh-keygen -q -t rsa -b 2048 -P "" -f "$key_filename" > /dev/null;
		echo "$key_filename";
		cat "$key_filename.pub" """
	)[0].strip().split("\n")
	# The key is shell-quoted and matched as a fixed string (-F), so regex metacharacters
	# in it are taken literally; -q keeps the matched line out of the output.
	# The key is appended only if it is not present yet, with umask 0077 so a newly
	# created authorized_keys file is private.
	dst_exec(
		"""
		authorized_keys_dir="`dirname %(authorized_keys_path)s`"
		if [ ! -d "$authorized_keys_dir" ]; then
			mkdir -p "$authorized_keys_dir"
		fi

		grep -qF %(public_key)s %(authorized_keys_path)s || {
		umask 0077;
		echo %(public_key)s >> %(authorized_keys_path)s;
		}
		"""
		% {'public_key': quote(public_key),
		   'authorized_keys_path': authorized_keys_path}
	)
	return key_pathname, public_key



def remove_keys(dst_exec, src_exec, key_pathname, public_key, authorized_keys_path='~/.ssh/authorized_keys'):
	"""Revoke temporary key-based SSH access created by set_up_keys.

	Deletes lines matching public_key from authorized_keys on the destination host,
	then deletes the private and public key files on the source host.

	:param dst_exec: function executing a shell script on the destination host
	:param src_exec: function executing a shell script on the source host
	:param key_pathname: path of the private key on the source host ('.pub' is appended for the public key)
	:param public_key: public key line to remove from authorized_keys
	:param authorized_keys_path: authorized_keys file on the destination host (unquoted, so '~' is expanded)
	"""
	try:
		# undo key-based ssh access: drop the key from the destination's authorized_keys
		dst_exec("sed -i -e \"\\|%s|d\" %s" % (public_key, authorized_keys_path))
	finally:
		# Delete the key pair even if the authorized_keys cleanup failed, so the private key
		# is never left behind. Both files are listed explicitly: brace expansion ({,.pub})
		# is a bash feature, and under a POSIX sh such as dash it is not expanded and
		# nothing would be removed.
		src_exec("rm -f %s %s" % (quote(key_pathname), quote(key_pathname + '.pub')))



def public_key_ssh_access_runner(runner_src, runner_dst):
	"""Adapter that takes run_command transports and calls public_key_ssh_access.

	Each runner is wrapped with runner_exec_adapter, so its sh() method serves
	as the exec function for the corresponding host.
	"""
	return public_key_ssh_access(
		runner_exec_adapter(runner_src),
		runner_exec_adapter(runner_dst),
	)



def runner_exec_adapter(runner):
	"""Turn a runner into an exec function of the form expected by public_key_ssh_access.

	The returned function runs a script with runner.sh() and returns
	(output, None): the sh() result in the stdout slot and no stderr.
	"""
	def exec_script(script):
		return runner.sh(script), None
	return exec_script

