from parallels.source.custom import messages
import json
import random
import re
import string
import uuid
import os
import logging
from parallels.core import MigrationError
from parallels.core.utils.common_constants import DATABASE_NO_SOURCE_HOST

from parallels.core.utils.yaml_utils import read_yaml
from xml.etree import ElementTree
from parallels.core.plesk_backup import save_backup_tar
from parallels.core.utils.common import xml, if_not_none, group_by_id, generate_random_password, is_empty, default, find_only
from parallels.core.utils.common.xml import elem, text_elem, seq, seq_iter, xml_to_string_pretty


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class BackupCreator(object):
	def write_full_backup(self, light_backup_config, target_backup_filename, database_servers=None):
		"""Read light backup description and write full backup dump built out of it.

		:type light_backup_config: parallels.custom_panel_migrator.connections.LightBackupConfig
		:type target_backup_filename: basestring
		:type database_servers: list[parallels.custom_panel_migrator.connections.DatabaseServerConfig] | None
		"""
		servers = default(database_servers, [])
		self._write_full_backup_for_data(
			read_light_backup(light_backup_config), target_backup_filename, servers
		)

	def _write_full_backup_for_data(self, data, target_backup_filename, database_servers=None):
		"""Build backup dump tree for already parsed data and save it to the target file.

		:type data: list
		:type target_backup_filename: basestring
		:type database_servers: list[parallels.custom_panel_migrator.connections.DatabaseServerConfig] | None
		"""
		servers = default(database_servers, [])
		self._save(self._create_backup_tree(data, servers), target_backup_filename)

	def _create_backup_tree(self, data, database_servers=None):
		"""Assemble 'migration-dump' XML tree.

		The root gets 'admin' node (with 'clients' built from data) and 'server' node
		(with 'db-servers', added only when at least one database server is passed).

		:type data: list
		:type database_servers: list[parallels.custom_panel_migrator.connections.DatabaseServerConfig] | None
		:rtype: xml.etree.ElementTree.ElementTree
		"""
		servers = default(database_servers, [])

		dump_attributes = {
			'content-included': 'false',
			'agent-name': 'PleskX',
			'dump-format': 'panel',
			'dump-version': '11.0.9'
		}
		dump_root = xml.elem('migration-dump', [], dump_attributes)
		tree = ElementTree.ElementTree(dump_root)

		# Keep references to both top-level nodes, so there is no need to look them up later
		admin_node = xml.elem('admin', [], {'guid': str(uuid.uuid1())})
		dump_root.append(admin_node)
		server_node = xml.elem('server')
		dump_root.append(server_node)

		client_nodes = [self._client_node(client, servers) for client in data]
		admin_node.append(xml.elem('clients', client_nodes))

		if len(servers) > 0:
			db_server_nodes = []
			for db_server in servers:
				host_node = text_elem('host', db_server.host)
				port_node = text_elem('port', db_server.port)
				db_admin_node = elem(
					'db-admin',
					[self._password_node(db_server.password)],
					{'name': db_server.login}
				)
				db_server_nodes.append(elem(
					'db-server',
					[host_node, port_node, db_admin_node],
					{'type': db_server.db_type}
				))
			server_node.append(xml.elem('db-servers', db_server_nodes))

		return tree

	def _client_node(self, client, database_servers):
		"""Create 'client' node named by client's login, with preferences, properties and domains."""
		children = [
			self._client_preferences_node(client),
			self._client_properties_node(client),
			self._client_domains_node(client, database_servers)
		]
		attributes = {'name': client['login']}
		return elem('client', children, attributes)

	def _client_preferences_node(self, client):
		"""Create client's 'preferences' node holding personal info item 'name'."""
		# The item is None when client has no 'name'; presumably elem() skips None children - TODO confirm
		name_node = self._pinfo_node('name', client.get('name'))
		return elem('preferences', [name_node])

	@staticmethod
	def _pinfo_node(name, value):
		if value is None:
			return None
		else:
			return text_elem('pinfo', value, {'name': name})

	def _client_properties_node(self, client):
		"""Create client's 'properties' node: password followed by enabled status."""
		password_node = self._password_node(client.get('password'))
		status_node = self._enabled_status_node()
		return elem('properties', [password_node, status_node])

	@staticmethod
	def _enabled_status_node():
		"""Create 'status' node with single 'enabled' child."""
		enabled_marker = elem('enabled')
		return elem('status', [enabled_marker])

	def _client_domains_node(self, client, database_servers):
		"""Create 'domains' node with one 'domain' child per client's subscription."""
		domain_nodes = []
		for subscription in client.get('subscriptions', []):
			domain_nodes.append(self._domain_node(subscription, database_servers))
		return elem('domains', domain_nodes)

	def _domain_node(self, domain, database_servers):
		"""Create 'domain' node of a subscription: properties, mail, databases and hosting."""
		children = seq(
			elem('preferences'),
			self._domain_properties_node(),
			# Only mailboxes that belong to the main domain name go here
			self._mailsystem_node(domain, domain['name']),
			self._databases_node(domain, database_servers),
			self._phosting_node(domain)
		)
		attributes = {'www': 'true', 'name': domain['name']}
		return elem('domain', children, attributes)

	def _domain_properties_node(self):
		"""Create domain's 'properties' node: fixed shared IP 127.0.0.1 and enabled status."""
		ip_node = elem('ip', [
			text_elem('ip-type', 'shared'),
			text_elem('ip-address', '127.0.0.1')
		])
		status_node = self._enabled_status_node()
		return elem('properties', [ip_node, status_node])

	def _mailsystem_node(self, domain, domain_name_filter):
		"""Create 'mailsystem' node with mail users of the domain name given by domain_name_filter."""
		properties_node = elem('properties', [self._enabled_status_node()])
		mailusers_node = self._mailusers_node(domain, domain_name_filter)
		preferences_node = elem('preferences')
		return elem('mailsystem', seq(properties_node, mailusers_node, preferences_node))

	def _mailusers_node(self, domain, domain_name_filter):
		"""Create 'mailusers' node for mailboxes whose domain name equals domain_name_filter.

		Returns None when the subscription has no 'mailboxes' key at all.
		"""
		mailboxes = domain.get('mailboxes')
		if mailboxes is None:
			return None

		mailuser_nodes = []
		for mailbox in mailboxes:
			if self._mailbox_domain_name(domain, mailbox) == domain_name_filter:
				mailuser_nodes.append(self._mailuser_node(mailbox))
		return elem('mailusers', mailuser_nodes)

	def _mailbox_domain_name(self, domain, mailbox):
		"""Get domain name of a mailbox.

		It is the part of mailbox name after the first '@', or the subscription's
		name when mailbox name contains no '@'.
		"""
		_, at_sign, domain_part = mailbox['name'].partition('@')
		return domain_part if at_sign else domain['name']

	def _mailuser_node(self, mailbox):
		"""Create 'mailuser' node: password, mailbox preferences, optional spam and virus filters.

		Mailbox quota attribute is taken from 'limit' converted to bytes.
		"""
		properties_node = elem('properties', [self._password_node(mailbox.get('password'))])
		preferences_node = elem('preferences', seq(
			elem('mailbox', [], {'enabled': 'true', 'type': 'mdir'}),
			self._spamassassin_node(mailbox),
			self._virusfilter_node(mailbox)
		))
		children = seq(properties_node, preferences_node)
		attributes = {
			'name': self._short_mailbox_name(mailbox['name']),
			'forwarding-enabled': 'false',
			'mailbox-quota': str(self._convert_to_bytes(mailbox.get('limit')))
		}
		return elem('mailuser', children, attributes)

	@staticmethod
	def _convert_to_bytes(value):
		if is_empty(value):
			return -1

		value = value.strip()
		m = re.match('^(\d+)\s*(M|K|)$', value)
		if m is None:
			raise Exception("Invalid format: %s" % value)
		val = m.group(1)
		multiplier = m.group(2)

		if is_empty(multiplier):
			return int(val)
		elif multiplier == 'K':
			return int(val) * 1024
		elif multiplier == 'M':
			return int(val) * 1024 * 1024
		else:
			assert False, "Invalid multiplier %s" % multiplier

	def _spamassassin_node(self, mailbox):
		"""Create 'spamassassin' node with status 'on'/'off' according to mailbox's 'spamfilter' flag.

		Returns None when the flag is not specified.
		"""
		flag = mailbox.get('spamfilter')
		if flag is None:
			return None

		status = 'on' if self._parse_bool_value(flag) else 'off'
		return elem('spamassassin', [], {'status': status})

	def _virusfilter_node(self, mailbox):
		"""Create 'virusfilter' node with state 'inout'/'none' according to mailbox's 'virusfilter' flag.

		Returns None when the flag is not specified.
		"""
		flag = mailbox.get('virusfilter')
		if flag is None:
			return None

		state = 'inout' if self._parse_bool_value(flag) else 'none'
		return elem('virusfilter', [], {'state': state})

	@staticmethod
	def _parse_bool_value(value):
		if value in (1, True, '1', 'on', 'true', 'enabled'):
			return True
		elif value in (0, False, '0', 'off', 'false', 'disabled'):
			return False
		else:
			raise Exception(
				"Invalid boolean value '%s'. Expected one of: 1/0/on/off/true/false/enabled/disabled" % value
			)

	@staticmethod
	def _password_node(password):
		"""Create plain 'password' node; a random password is generated when None is passed."""
		effective_password = generate_random_password() if password is None else password
		return text_elem('password', effective_password, {'type': 'plain'})

	def _phosting_node(self, domain):
		"""Create subscription's 'phosting' node: system user, scripting placeholder and sites.

		Document root defaults to 'httpdocs' when 'document_root' is not specified.
		"""
		sysuser_node = self._sysuser_node(domain.get('sys_user', {}))
		children = seq(
			elem('preferences', [sysuser_node]),
			elem('limits-and-permissions', [elem('scripting')]),
			self._sites_node(domain)
		)
		attributes = {'www-root': domain.get('document_root', 'httpdocs')}
		return elem('phosting', children, attributes)

	def _sysuser_node(self, sysuser):
		"""Create 'sysuser' node; login falls back to a random one prefixed with 'sub'."""
		password_node = self._password_node(sysuser.get('password'))
		# Fallback login is generated unconditionally, even if 'login' is specified
		fallback_login = self._random_login('sub')
		attributes = {'name': sysuser.get('login', fallback_login)}
		return elem('sysuser', [password_node], attributes)

	@staticmethod
	def _random_login(prefix):
		random_digits = "".join(random.choice(string.digits) for _ in range(10))
		return "%s_%s" % (prefix, random_digits,)

	def _sites_node(self, domain):
		"""Create 'sites' node: addon domains first, then subdomains.

		Returns None when subscription has neither addon domains nor subdomains.
		"""
		addon_domains = domain.get('addon_domains', [])
		subdomains = domain.get('subdomains', [])
		total_sites = len(addon_domains) + len(subdomains)
		if total_sites == 0:
			return None

		addon_nodes = seq_iter(self._addon_domain_node(domain, addon_domain) for addon_domain in addon_domains)
		subdomain_nodes = seq_iter(self._subdomain_node(domain, subdomain) for subdomain in subdomains)
		return elem('sites', addon_nodes + subdomain_nodes)

	def _addon_domain_node(self, domain, addon_domain):
		"""Create 'site' node of an addon domain, with its own mail system and hosting."""
		children = [
			elem('preferences'),
			# Mailboxes are listed on subscription level, pick the ones of this addon domain
			self._mailsystem_node(domain, addon_domain['name']),
			self._site_phosting_node(addon_domain)
		]
		attributes = {'name': addon_domain['name']}
		return elem('site', children, attributes)

	def _subdomain_node(self, domain, subdomain):
		"""Create 'site' node of a subdomain.

		Parent domain name is taken from subdomain's 'parent-domain' if specified,
		otherwise it is detected by subdomain's name.
		"""
		detected_parent = self._get_parent(domain, subdomain)
		parent_name = subdomain.get('parent-domain', detected_parent)
		children = [
			elem('preferences'),
			self._site_phosting_node(subdomain)
		]
		attributes = {
			'name': subdomain['name'],
			'parent-domain-name': parent_name
		}
		return elem('site', children, attributes)

	@staticmethod
	def _get_parent(domain, subdomain):
		all_domains = [domain['name']] + [addon['name'] for addon in domain.get('addon_domains', [])]
		for domain_name in all_domains:
			if subdomain['name'].endswith('.%s' % domain_name):
				return domain_name

	@staticmethod
	def _site_phosting_node(site):
		"""Create 'phosting' node of a site; document root defaults to the site's name."""
		children = seq(elem('preferences'))
		attributes = {'www-root': site.get('document_root', site['name'])}
		return elem('phosting', children, attributes)

	@staticmethod
	def _short_mailbox_name(mailbox_name):
		if '@' in mailbox_name:
			return mailbox_name[:mailbox_name.find('@')]
		else:
			return mailbox_name

	def _databases_node(self, domain, database_servers):
		"""Create 'databases' node of a subscription, or return None if it has no databases."""
		databases = domain.get('databases', [])
		if len(databases) == 0:
			return None

		database_nodes = seq_iter(
			self._database_node(database, database_servers) for database in databases
		)
		return elem('databases', database_nodes)

	def _database_node(self, database, database_servers):
		"""Create 'database' node with nested 'db-server' node and optional 'dbuser' node.

		Source database server is detected in the following order:
		1. 'host', 'port' and 'type' when all three are specified for the database.
		2. Database server from database_servers with ID equal to database's 'server'.
		3. Otherwise the server is considered unknown: missing host, port and type
		default to DATABASE_NO_SOURCE_HOST, 0 and 'mysql'.

		:type database: dict
		:type database_servers: list[parallels.custom_panel_migrator.connections.DatabaseServerConfig]
		"""
		if database.get('host') is not None and database.get('port') is not None and database.get('type') is not None:
			host = database.get('host')
			port = database.get('port')
			db_type = database.get('type')
		elif database.get('server') is not None:
			db_server_id = database.get('server')
			db_server = find_only(
				database_servers, lambda d: d.db_server_id == db_server_id,
				"Failed to find database server with ID '{db_server_id}' in configuration file".format(
					db_server_id=db_server_id
				)
			)
			host = db_server.host
			port = db_server.port
			db_type = db_server.db_type
		else:
			# We don't have information about source database server, so we should not try
			# to copy database content in a regular way by connecting to the source database server.
			# So we mark that in backup dump by special constant 'DATABASE_NO_SOURCE_HOST'.
			host = database.get('host', DATABASE_NO_SOURCE_HOST)
			port = database.get('port', 0)
			db_type = database.get('type', 'mysql')

		db_server_node = elem(
			'db-server',
			seq(
				text_elem('host', host),
				text_elem('port', port),
			),
			{
				'type': db_type
			}
		)

		return elem(
			'database',
			seq(
				db_server_node,
				if_not_none(database.get('user'), self._database_user)
			),
			{
				'name': database['name'],
				# When database has no own 'type', use the type of its server instead of
				# hardcoded 'mysql': otherwise a database that refers to a non-MySQL
				# server by ID gets type that contradicts the nested 'db-server' node.
				'type': database.get('type', db_type)
			}
		)

	def _database_user(self, user):
		"""Create 'dbuser' node; login falls back to a random one prefixed with 'db'."""
		password_node = self._password_node(user.get('password'))
		# Fallback login is generated unconditionally, even if 'login' is specified
		fallback_login = self._random_login('db')
		attributes = {'name': user.get('login', fallback_login)}
		return elem('dbuser', seq(password_node), attributes)

	@staticmethod
	def _save(backup_tree, target_backup_filename):
		"""Serialize backup tree to pretty XML and store it with save_backup_tar.

		A file that already exists at the target path is removed first.
		"""
		target_exists = os.path.exists(target_backup_filename)
		if target_exists:
			logger.debug(messages.REMOVING_EXISTING_DUMP_FILE)
			os.remove(target_backup_filename)

		save_backup_tar(xml_to_string_pretty(backup_tree), target_backup_filename)


def read_light_backup(light_backup_config):
	"""Read light backup description with the reader that matches the configured file format.

	:type light_backup_config: parallels.source.custom.connections.LightBackupConfig
	:raises MigrationError: if there is no reader for the configured file format
	"""
	readers_by_name = group_by_id(INPUT_FORMATS, lambda f: f.name)
	requested_format = light_backup_config.file_format
	if requested_format not in readers_by_name:
		# NOTE(review): the message is formatted with the file path, not with the format name - confirm
		raise MigrationError(
			messages.INVALID_FILE_FORMAT % light_backup_config.path
		)

	reader = readers_by_name[requested_format]
	return reader.read(light_backup_config.path)



class InputFormat(object):
	"""Interface of a light backup description reader; both members must be overridden."""

	@property
	def name(self):
		"""Name of the format, the one looked up by read_light_backup."""
		raise NotImplementedError

	def read(self, filename):
		"""Read the file and return parsed backup description."""
		raise NotImplementedError


class InputFormatJSON(InputFormat):
	"""Reader of light backup description in JSON format."""

	@property
	def name(self):
		return 'json'

	def read(self, filename):
		"""Parse the whole file as a single JSON document."""
		with open(filename) as json_file:
			return json.load(json_file)


class InputFormatYAML(InputFormat):
	"""Reader of light backup description in YAML format."""

	@property
	def name(self):
		return 'yaml'

	def read(self, filename):
		"""Parse the file with read_yaml."""
		parsed = read_yaml(filename)
		return parsed


class InputFormatXML(InputFormat):
	"""Reader of light backup description in XML format.

	XML is converted to nested lists, dicts and strings:
	- node with one of the list tags becomes a list of its converted children
	- any other node with children becomes a dict of child tag to converted child
	- node without children becomes its text
	"""

	# List tag -> the only tag allowed for its children
	_LIST_ITEM_TAGS = {
		'clients': 'client',
		'subscriptions': 'subscription',
		'addon_domains': 'addon_domain',
		'subdomains': 'subdomain',
		'databases': 'database',
		'mailboxes': 'mailbox'
	}

	@property
	def name(self):
		return 'xml'

	def read(self, filename):
		"""Parse the file and convert its root node."""
		return self._parse(ElementTree.parse(filename).getroot())

	def _parse(self, node):
		"""Convert XML node recursively.

		:raises MigrationError: if a list node has a child with unexpected tag
		"""
		item_tag = self._LIST_ITEM_TAGS.get(node.tag)
		if item_tag is not None:
			items = []
			for child in node:
				if child.tag != item_tag:
					raise MigrationError("Invalid tag: %s. Expected: %s." % (child.tag, item_tag))
				items.append(self._parse(child))
			return items

		if len(node) == 0:
			return node.text

		# If several children have the same tag, the last one wins
		return dict((child.tag, self._parse(child)) for child in node)

# Reader instances that read_light_backup chooses from, by their 'name' property
INPUT_FORMATS = [InputFormatYAML(), InputFormatJSON(), InputFormatXML()]
