from parallels.common import messages
import logging
from collections import namedtuple
import threading

from parallels.common.utils import ssh_utils
from parallels.common.utils import yaml_utils
from parallels.utils.threading_utils import synchronized_by_args

logger = logging.getLogger(__name__)

# Record for one configured SSH key: `key_pathname` (path of the key file
# returned by ssh_utils.set_up_keys) and `public_key`. The class is passed
# through yaml_utils.pretty_yaml, presumably so instances serialize cleanly
# when written with yaml_utils.write_yaml -- confirm against yaml_utils.
KeyInfo = namedtuple('KeyInfo', ['key_pathname', 'public_key'])
KeyInfo = yaml_utils.pretty_yaml(KeyInfo)


class SSHKeyPool(object):
	"""Pool of SSH key authentication set-ups between pairs of servers.

	Keys are set up lazily by get() and cached in self.keys per
	(source_server, target_server) pair. After every change the pool is
	written to self.filename as YAML, with the servers replaced by their IPs.
	remove_all() tears down every configured key and makes the pool
	read-only.
	"""

	def __init__(self, filename):
		"""
		:param filename: path of the YAML file the pool is saved to
		"""
		# (source_server, target_server) -> KeyInfo
		self.keys = dict()
		# Guards self.keys, self.readonly and writes of self.filename
		self.lock = threading.Lock()
		self.filename = filename
		# Set by remove_all(); afterwards get() raises instead of
		# configuring new keys
		self.readonly = False

	@synchronized_by_args
	def get(self, source_server, target_server):
		"""Return key info for the server pair, setting the key up on first use.

		Per the log message, this configures SSH key auth from target_server
		to source_server (target_server's runner is the src_exec side of
		ssh_utils.set_up_keys, source_server's runner the dst_exec side).

		:return: parallels.common.utils.ssh_key_pool.KeyInfo
		:raises Exception: if the pool is read-only (remove_all() was called)
		"""
		key = (source_server, target_server)
		# The 'with' statement releases the lock on every exit path, including
		# the read-only raise below and failures while setting the key up.
		with self.lock:
			if self.readonly:
				raise Exception(messages.SSH_KEY_POOL_IS_READ_ONLY)

			if key not in self.keys:
				logger.debug(
					"Configure SSH key auth from %s to %s",
					target_server.description(),
					source_server.description()
				)
				with source_server.runner() as runner_source:
					with target_server.runner() as runner_target:
						key_pathname, public_key = ssh_utils.set_up_keys(
							dst_exec=ssh_utils.runner_exec_adapter(runner_source),
							src_exec=ssh_utils.runner_exec_adapter(runner_target)
						)
						self.keys[key] = KeyInfo(key_pathname, public_key)
						self._save_keys()

			# Read the result while still holding the lock, so a concurrent
			# remove_all() cannot clear the dict in between.
			return self.keys[key]

	def remove_all(self):
		"""Remove every SSH key configured by this pool and make it read-only.

		Removal is best-effort: a failure for one server pair (including a
		failure to open a runner to either server) is logged and the
		remaining pairs are still processed. Afterwards the in-memory pool
		and the YAML file are emptied, and further get() calls raise.
		"""
		with self.lock:
			for (source_server, target_server), key_info in self.keys.iteritems():
				logger.debug(
					"Remove SSH key auth from %s to %s",
					target_server.description(),
					source_server.description()
				)
				# The try covers opening the runners too, so one unreachable
				# server does not abort cleanup of all other pairs.
				try:
					with source_server.runner() as runner_source:
						with target_server.runner() as runner_target:
							ssh_utils.remove_keys(
								dst_exec=ssh_utils.runner_exec_adapter(runner_source),
								src_exec=ssh_utils.runner_exec_adapter(runner_target),
								key_pathname=key_info.key_pathname,
								public_key=key_info.public_key,
							)
				except Exception as e:
					logger.debug("Exception", exc_info=True)
					logger.error(
						"Failed to remove SSH key from %s to %s: %s",
						target_server.description(),
						source_server.description(),
						e
					)
			self.keys = {}
			self._save_keys()
			self.readonly = True

	def _save_keys(self):
		"""Write the pool to self.filename as YAML.

		Server objects are replaced by their ip() values, so the file maps
		(source_ip, target_ip) -> KeyInfo. Called by get() and remove_all()
		while they hold self.lock.
		"""
		keys = {}
		for (source_server, target_server), key_info in self.keys.iteritems():
			keys[source_server.ip(), target_server.ip()] = key_info
		yaml_utils.write_yaml(self.filename, keys)
