import logging

from parallels.expand_api.operator import PleskServerOperator, \
	PleskMailServerOperator, PleskCentralizedDbServerOperator, ClientOperator, \
	PleskClientOperator, PleskDomainOperator, DomainTemplateOperator, \
	PleskDomainAliasOperator
from parallels.expand_migrator.expand_data.model import CentralizedMailServer, ExpandClient, PleskClient, PleskDomain, \
	PleskDomainAlias
from parallels.utils import group_by_id, if_not_none
from parallels.utils.ip import resolve

logger = logging.getLogger(__name__)


class ExpandAPIData(object):
	"""Retrieve data from Expand with the help of Expand API.

	Every public method issues one or more requests through ``self.api.send``
	and converts the returned results (objects exposing a ``data`` attribute)
	into model objects or plain Python collections.
	"""

	def __init__(self, api):
		"""
		:param api: Expand API client; must provide ``send(operation)`` returning
			an iterable of results, each with a ``data`` attribute
		"""
		self.api = api

	def _iter_mail_servers(self):
		"""Yield data of all centralized mail servers registered in Expand.

		Shared by every method that lists mail servers so that they all issue
		the same request with the mail server operator's own filter class.
		"""
		for result in self.api.send(
			PleskMailServerOperator.Get(
				filter=PleskMailServerOperator.Get.FilterAll(),
			)
		):
			yield result.data

	def _iter_plesk_servers(self):
		"""Yield data of all Plesk servers registered in Expand, with their IP pools."""
		for result in self.api.send(
			PleskServerOperator.Get(
				filter=PleskServerOperator.Get.FilterAll(),
				# need to specify something here to get the IDs
				dataset=[PleskServerOperator.Get.Dataset.IPPOOL]
			)
		):
			yield result.data

	def get_centralized_mail_servers(self, centralized_mail_servers):
		"""Get Expand centralized mail servers that match the selected ones by IP address.

		:param centralized_mail_servers: dict mapping source (Plesk) ID to a server
			object with an ``ip`` attribute
		:return: list of CentralizedMailServer; Expand mail servers whose IP is not
			among the selected servers' IPs are skipped and logged at debug level
		"""
		# If two selected servers share an IP, the one iterated last wins.
		cmail_server_plesk_id_by_ip = dict(
			(srv.ip, source_id) for source_id, srv in centralized_mail_servers.iteritems()
		)

		cmail_servers = []
		for server_info in self._iter_mail_servers():
			server_ip = server_info.ip
			if server_ip in cmail_server_plesk_id_by_ip:
				cmail_servers.append(
					CentralizedMailServer(
						id=server_info.id, ip=server_ip,
						plesk_id=cmail_server_plesk_id_by_ip[server_ip],
						assigned_server_ids=server_info.assigned_server_ids
					)
				)
			else:
				logger.debug(
					u"Expand centralized mail server #%s is skipped "
					u"as its IP '%s' is not among selected centralized mail servers IP addresses",
					server_info.id, server_ip
				)

		return cmail_servers

	def get_expand_clients(self):
		"""Get all Expand clients together with their info and personal info.

		:return: list of ExpandClient
		"""
		logger.debug(u"Get Expand clients")
		return [
			ExpandClient(
				id=result.data.id,
				info=result.data.info,
				personal_info=result.data.personal_info
			)
			for result in self.api.send(
				ClientOperator.Get(
					filter=ClientOperator.Get.FilterAll(),
					dataset=[ClientOperator.Get.Dataset.INFO, ClientOperator.Get.Dataset.PERSONAL_INFO]
				)
			)
		]

	def iter_plesk_clients(self, expand_plesk_servers):
		"""Yield Plesk clients of the given Expand Plesk servers.

		This is a generator: no API request is made until iteration starts.

		:param expand_plesk_servers: iterable of server objects with ``id`` and
			``plesk_id`` attributes
		:return: iterator of PleskClient; ``expand_client_id`` is None for Plesk
			clients not linked to any Expand client
		"""
		logger.debug(u"Get Plesk clients")
		plesk_client_id_to_expand_client_id = self._get_plesk_client_id_to_expand_client_id()
		for server in expand_plesk_servers:
			for plesk_client_info in self._get_plesk_clients_by_server_id(server.id):
				yield PleskClient(
					id=plesk_client_info.id,
					login=plesk_client_info.gen_info.login,
					pname=plesk_client_info.gen_info.pname,
					plesk_id=server.plesk_id,
					expand_client_id=plesk_client_id_to_expand_client_id.get(plesk_client_info.id)
				)

	def _get_plesk_client_id_to_expand_client_id(self):
		"""Build a dict mapping Plesk client ID to the ID of the Expand client owning it."""
		result = {}
		expand_client_ids = [client.id for client in self.get_expand_clients()]
		for expand_client_id, plesk_client_ids in self._get_plesk_client_ids_by_expand_client_ids(expand_client_ids):
			for plesk_client_id in plesk_client_ids:
				result[plesk_client_id] = expand_client_id
		return result

	def _get_plesk_client_ids_by_expand_client_ids(self, expand_client_ids):
		"""Get Plesk client instances linked to the given Expand clients.

		:param expand_client_ids: list of Expand client IDs
		:return: list of ``result.data`` items, consumed by the caller as
			``(expand_client_id, plesk_client_ids)`` pairs; empty list when
			``expand_client_ids`` is empty (no request is sent then, as the
			meaning of an empty ID filter for the API is not defined here)
		"""
		if not expand_client_ids:
			return []
		return [
			result.data for result in self.api.send(
				ClientOperator.InstanceGet(
					filter=ClientOperator.InstanceGet.FilterByClientId(ids=expand_client_ids)
				)
			)
		]

	def _get_plesk_clients_by_server_id(self, server_id):
		"""Get general info of all Plesk clients on the Expand Plesk server with the given ID."""
		return [
			result.data for result in self.api.send(
				PleskClientOperator.Get(
					filter=PleskClientOperator.Get.FilterByServerId(server_id=server_id),
					dataset=[PleskClientOperator.Get.Dataset.GEN_INFO]
				)
			)
		]

	def get_plesk_domains(self, expand_plesk_servers):
		"""Get Plesk domains located on the given Expand Plesk servers.

		:param expand_plesk_servers: iterable of server objects with ``id`` and
			``plesk_id`` attributes; domains on other servers are skipped
		:return: list of PleskDomain; ``tmpl_name`` is None when the domain has
			no template
		:raises KeyError: if a domain refers to a template ID not returned by the
			domain template request
		"""
		logger.debug(u"Get Plesk domains")
		domain_templates = {
			result.data.id: result.data.name
			for result in self.api.send(
				DomainTemplateOperator.Get(
					filter=DomainTemplateOperator.Get.FilterAll(),
					dataset=[DomainTemplateOperator.Get.Dataset.GEN_SETUP]
				)
			)
		}

		expand_plesk_servers_by_id = group_by_id(expand_plesk_servers, lambda s: s.id)

		return [
			PleskDomain(
				id=result.data.id,
				name=result.data.name,
				client_id=result.data.client_id,
				plesk_id=expand_plesk_servers_by_id[result.data.server_id].plesk_id,
				tmpl_id=result.data.tmpl_id,
				tmpl_name=if_not_none(result.data.tmpl_id, lambda tmpl_id: domain_templates[tmpl_id])
			)
			for result in self.api.send(
				PleskDomainOperator.Get(
					filter=PleskDomainOperator.Get.FilterAll(),
					dataset=[]
				)
			)
			if result.data.server_id in expand_plesk_servers_by_id
		]

	def get_plesk_domain_aliases(self, plesk_domains):
		"""Get aliases of the given Plesk domains.

		One API request is sent per domain.

		:param plesk_domains: iterable of objects with an ``id`` attribute
		:return: flat list of PleskDomainAlias for all given domains
		"""
		logger.debug(u"Get Plesk aliases")

		domain_aliases = []
		for domain in plesk_domains:
			for result in self.api.send(
				PleskDomainAliasOperator.Get(
					filter=PleskDomainAliasOperator.Get.FilterByDomainId(domain_id=domain.id)
				)
			):
				domain_aliases.append(
					PleskDomainAlias(
						id=result.data.id,
						name=result.data.name
					)
				)
		return domain_aliases

	def get_plesk_server_ips(self):
		"""Get the set of all IP addresses from the IP pools of all Expand Plesk servers."""
		logger.debug(u"Get IP addresses of Plesk servers")
		return {
			ip_info.ip_address
			for server_info in self._iter_plesk_servers()
			for ip_info in server_info.ips
		}

	def get_mail_server_ips(self):
		"""Get the set of IP addresses of all Expand centralized mail servers."""
		logger.debug(u"Get IP addresses of centralized mail servers")
		return {server_info.ip for server_info in self._iter_mail_servers()}

	def get_db_server_ips(self):
		"""Get IP addresses of non-local centralized database servers of Plesk servers.

		Database server hostnames are resolved on the host running this code.
		Hostnames that cannot be resolved are skipped (logged at debug level).

		:return: set of IP addresses; empty when there are no Plesk servers
		"""
		logger.debug(u"Get IP addresses of centralized database servers")

		plesk_server_ids = {server_info.id for server_info in self._iter_plesk_servers()}
		if not plesk_server_ids:
			# Nothing to filter by: avoid sending an empty ID filter to the API.
			return set()

		db_hosts = {
			result.data.host for result in self.api.send(
				PleskCentralizedDbServerOperator.Get(
					filter=PleskCentralizedDbServerOperator.Get.FilterByPleskServerId(plesk_server_ids),
				)
			) if not result.data.is_local
		}

		db_server_ips = set()
		for db_host in db_hosts:
			# may be it is better to resolve hostname on Expand cluster,
			# but it is much more complex to implement
			db_server_ip = resolve(db_host)
			if db_server_ip is None:
				logger.debug(
					u"Unable to resolve database server hostname '%s'. It will be silently skipped.",
					db_host
				)
			else:
				db_server_ips.add(db_server_ip)

		return db_server_ips