from paramiko import SSHClient, MissingHostKeyPolicy
from paramiko.ssh_exception import SSHException
import logging

from parallels.hosting_check import DomainIssue
from parallels.hosting_check import Severity
from parallels.hosting_check import SSHAuthIssueType
from parallels.hosting_check.messages import MSG

from parallels.hosting_check.utils.list_utils import none_list

logger = logging.getLogger(__name__)

class SSHAuthChecker(object):
	def check(self, domains_to_check):
		issues = []
		for domain_to_check in domains_to_check:
			self._check_single_domain(domain_to_check, issues)
		return issues

	def _check_single_domain(self, domain_to_check, issues):
		for user in domain_to_check.users:
			try:
				if not self._can_verify_user(user):
					issues.append(
						DomainIssue(
							domain_name=domain_to_check.domain_name, 
							severity=Severity.WARNING, 
							category=SSHAuthIssueType.ENCRYPTED_PASSWORD,
							problem=MSG(
								SSHAuthIssueType.ENCRYPTED_PASSWORD,
								server_ip=domain_to_check.web_server_ip, 
								login=user.login
							)
						)
					)
				else:
					auth_successful, message = self._check_ssh_auth(
						domain_to_check.web_server_ip, user
					)
					if not auth_successful:
						issues.append(DomainIssue(
							domain_name=domain_to_check.domain_name, 
							severity=Severity.ERROR, 
							category=SSHAuthIssueType.CHECK_FAILED, 
							problem=MSG(
								SSHAuthIssueType.CHECK_FAILED,
								server_ip=domain_to_check.web_server_ip,
								login=user.login, error_message=message
							)
						))
			except KeyboardInterrupt:
				# for compatibility with python 2.4
				raise
			except Exception, e:
				logger.debug(u"Exception:", exc_info=e)
				issues.append(DomainIssue(
					domain_name=domain_to_check.domain_name, 
					severity=Severity.WARNING, 
					category=SSHAuthIssueType.INTERNAL_ERROR,
					problem=MSG(
						SSHAuthIssueType.INTERNAL_ERROR,
						server_ip=domain_to_check.web_server_ip,
						login=user.login, error_message=str(e)
					)
				))
		for user in none_list(domain_to_check.inactive_users):
			try:
				auth_successful, _ = self._check_ssh_auth(
					domain_to_check.web_server_ip, user
				)
				if auth_successful:
					issues.append(DomainIssue(
						domain_name=domain_to_check.domain_name, 
						severity=Severity.ERROR, 
						category=SSHAuthIssueType.INACTIVE_USER_CAN_CONNECT, 
						problem=MSG(
							SSHAuthIssueType.INACTIVE_USER_CAN_CONNECT,
							server_ip=domain_to_check.web_server_ip,
							login=user.login
						)
					))
			except KeyboardInterrupt:
				# for compatibility with python 2.4
				raise
			except Exception, e:
				logger.debug(u"Exception:", exc_info=e)
				issues.append(DomainIssue(
					domain_name=domain_to_check.domain_name, 
					severity=Severity.WARNING, 
					category=SSHAuthIssueType.INTERNAL_ERROR,
					problem=MSG(
						SSHAuthIssueType.INTERNAL_ERROR,
						server_ip=domain_to_check.web_server_ip,
						login=user.login, error_message=str(e)
					)
				))

	def _can_verify_user(self, user):
		return user.password_type == 'plain'

	@staticmethod
	def _check_ssh_auth(server_ip, user):
		logger.debug(MSG(
			'ssh_log_checking', 
			user=user.login, server_ip=server_ip
		))
		client = SSHClient()
		client.set_missing_host_key_policy(IgnoreMissingKeyPolicy())

		try:
			try:
				client.connect(
					server_ip,
					username=user.login, password=user.password,
					look_for_keys=False
				)
				_, stdout, _ = client.exec_command('echo test_ok')
				stdout_content = stdout.read().strip()
				if stdout_content != 'test_ok':
					return False, 'failed to execute simple command with SSH'
			finally:
				client.close()
			return True, None
		except KeyboardInterrupt:
			# for compatibility with python 2.4
			raise
		except SSHException, e:
			logger.debug("Exception: ", exc_info=True)
			return False, str(e)
		
# Host key policy that accepts unknown host keys without logging anything.
# Use it to avoid the missing host key warning at WARNING logging level
# (which makes migrator output a bit messy).
class IgnoreMissingKeyPolicy(MissingHostKeyPolicy):
	"""Host key policy that silently accepts servers with unknown keys.

	NOTE(review): since nothing is raised here, the connection proceeds to
	hosts whose key is unknown, so the server identity is not verified
	before credentials are sent. Confirm this is acceptable for the
	environments the checker runs in.
	"""

	def missing_host_key(self, client, hostname, key):
		"""Neither log a warning nor reject the connection."""
		return None
