import os
import random
from socket import gethostname

from OpenSSL import crypto


class SSLKeys(object):
	"""Self-signed SSL certificates and private keys used by migration tools.

	Manages two key pairs - one for the migration node and one for the source
	node - stored at paths returned by
	``migrator_server.get_session_file_path``. Keys are generated on
	construction unless all of them, plus a marker file written after a
	successful generation, already exist.
	"""

	def __init__(self, migrator_server):
		"""
		:type migrator_server parallels.common.connections.migrator_server.MigratorServer
		"""
		self._migrator_server = migrator_server
		files = [
			self.migration_node_crt_filename,
			self.migration_node_key_filename,
			self.source_node_crt_filename,
			self.source_node_key_filename,
			self._keys_generated_filename
		]
		# If anything is missing (including the marker, which is written last),
		# regenerate everything so all files always come from one complete run
		if not all(os.path.exists(filename) for filename in files):
			self._generate_keys()

	def _generate_keys(self):
		"""Generate both key pairs, then write the "keys generated" marker file.

		The marker is written only after both pairs were saved, so an
		interrupted generation is redone on the next instantiation.
		"""
		self._generate_key_pair(
			cert_file=self.migration_node_crt_filename,
			key_file=self.migration_node_key_filename
		)
		self._generate_key_pair(
			cert_file=self.source_node_crt_filename,
			key_file=self.source_node_key_filename
		)
		with self._migrator_server.runner() as runner:
			""":type runner parallels.common.run_command.BaseRunner"""
			runner.upload_file_content(
				self._keys_generated_filename,
				"This is a file indicating that SSL certificates and keys "
				"used by migration tools were generated successfully. "
				"Remove that file if you want to regenerate them. "
			)

	@property
	def migration_node_crt_filename(self):
		"""Path of the migration node certificate (PEM), from get_session_file_path."""
		return self._migrator_server.get_session_file_path('migration-node.crt')

	@property
	def migration_node_key_filename(self):
		"""Path of the migration node private key (PEM), from get_session_file_path."""
		return self._migrator_server.get_session_file_path('migration-node.key')

	@property
	def source_node_crt_filename(self):
		"""Path of the source node certificate (PEM), from get_session_file_path."""
		return self._migrator_server.get_session_file_path('source-node.crt')

	@property
	def source_node_key_filename(self):
		"""Path of the source node private key (PEM), from get_session_file_path."""
		return self._migrator_server.get_session_file_path('source-node.key')

	@property
	def _keys_generated_filename(self):
		"""Path of the marker file written after keys were generated successfully."""
		return self._migrator_server.get_session_file_path('ssl-keys-generated')

	def _generate_key_pair(self, cert_file, key_file):
		"""Create a private key and a self-signed certificate for it, and save both.

		:param cert_file: path to write the PEM certificate to
		:param key_file: path to write the PEM private key to
		"""
		key = self._create_private_key()
		cert = self._create_certificate(key)

		self._save_private_key(key, key_file)
		self._save_certificate(cert, cert_file)

	@staticmethod
	def _create_private_key(bits=2048):
		"""Generate a new RSA private key.

		:param bits: RSA key size. 1024-bit keys are rejected at OpenSSL
			security level 2, which is the default in current OpenSSL releases
			and many distributions, so 2048 is used by default.
		:rtype: OpenSSL.crypto.PKey
		"""
		key = crypto.PKey()
		key.generate_key(crypto.TYPE_RSA, bits)
		return key

	@staticmethod
	def _create_certificate(key, digest='sha256'):
		"""Create a self-signed X.509 certificate for the given key.

		:param key: private key used both as the certificate public key
			and to sign the certificate
		:param digest: name of the signature digest; SHA-1 signatures are
			deprecated and refused by current OpenSSL default security levels
		:rtype: OpenSSL.crypto.X509
		"""
		cert = crypto.X509()
		# get_subject() returns a view bound to the certificate, so attribute
		# assignments below modify the certificate's subject in place
		subject = cert.get_subject()
		subject.C = "US"
		subject.ST = "Virginia"
		subject.L = "Herndon"
		subject.O = "Parallels, Inc."
		subject.OU = "Parallels Panel"
		# X.509 limits commonName to 64 characters (RFC 5280 ub-common-name);
		# OpenSSL raises an error for longer values, so truncate long host names
		subject.CN = gethostname()[:64]
		# Both generated certificates have the same issuer name, and some clients
		# reject different certificates sharing issuer and serial, so use a random
		# positive serial instead of a constant
		cert.set_serial_number(random.SystemRandom().getrandbits(63) + 1)
		cert.set_notBefore(
			# Fake date in the past to avoid any issues with date on server.
			# pyOpenSSL requires bytes here; 'Z' (UTC) is the RFC 5280 form
			b'20000101000000Z'
		)
		year = 365 * 24 * 60 * 60
		cert.gmtime_adj_notAfter(
			# Valid for 10 years - should be enough for any migration and date fixes on server
			10 * year
		)
		cert.set_issuer(subject)
		cert.set_pubkey(key)
		cert.sign(key, digest)
		return cert

	@staticmethod
	def _save_private_key(key, key_file):
		"""Write the private key to key_file in PEM format.

		crypto.dump_privatekey returns bytes, so the file is opened in binary mode.
		"""
		# NOTE(review): the file is created with the process umask, so the private
		# key may be readable by other local users - consider restricting permissions
		with open(key_file, "wb") as fp:
			fp.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))

	@staticmethod
	def _save_certificate(cert, cert_file):
		"""Write the certificate to cert_file in PEM format.

		crypto.dump_certificate returns bytes, so the file is opened in binary mode.
		"""
		with open(cert_file, "wb") as fp:
			fp.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))