from parallels.source.helm3 import messages
import logging
import os

from contextlib import closing

from parallels.core.actions.base.common_action import CommonAction
from parallels.core.plesk_backup import plesk_backup_xml

logger = logging.getLogger(__name__)


class ReplaceDbServerIp(CommonAction):
	"""Replace database server host names in the Helm 3 backup with source server IPs.

	Loads the "pre" Plesk backup for Helm 3, rewrites the host of every database
	server and of every subscription database, and saves the result as the "raw"
	Plesk backup.
	"""

	def get_description(self):
		return messages.REPLACE_DATABASE_SERVER_IPS

	def get_failure_message(self, global_context):
		"""
		:type global_context: parallels.source.helm3.global_context.Helm3GlobalMigrationContext
		"""
		return messages.FAILED_TO_REPLACE_DATABASE_SERVER_IPS

	def run(self, global_context):
		"""Rewrite database host names in the pre backup and save it as the raw backup.

		The step is skipped when the raw backup file already exists and reloading
		of source data was not requested.

		:type global_context: parallels.source.helm3.global_context.Helm3GlobalMigrationContext
		"""
		raw_backup_filename = global_context.session_files.get_path_to_raw_plesk_backup('helm3')
		if not global_context.options.reload_source_data and os.path.exists(raw_backup_filename):
			logger.debug(messages.DATABASE_SERVER_NAMES_ALREADY_REPLACED)
			return

		pre_backup_filename = global_context.session_files.get_path_to_pre_plesk_backup('helm3')
		with closing(plesk_backup_xml.load_backup(pre_backup_filename)) as backup:
			for dbs in backup.iter_db_servers():
				dbs.set_host(self._replace_host_name(global_context, dbs.host))

			for subscription in backup.iter_all_subscriptions():
				for db in subscription.all_databases:
					db.set_host(self._replace_host_name(global_context, db.host))

			self._save_backup(backup, raw_backup_filename)

	@staticmethod
	def _save_backup(backup, filename):
		"""Save backup into filename, removing the file again if saving fails.

		run() treats the mere existence of the raw backup file as "already done",
		so a truncated file left behind by a failed save must not survive:
		otherwise the next run would silently skip this step.
		"""
		saved = False
		try:
			with open(filename, 'wb') as backup_file:
				backup.save(backup_file)
			saved = True
		finally:
			if not saved and os.path.exists(filename):
				try:
					os.remove(filename)
				except OSError:
					# Best effort only: never mask the original save error.
					logger.debug(
						"Failed to remove partially written backup file %s", filename, exc_info=True
					)

	def _replace_host_name(self, global_context, host):
		"""Return host with its server-name part replaced by a matching source server IP.

		The part of host before the first backslash is taken as the server name;
		anything from the backslash on is kept unchanged.
		NOTE(review): presumably this is the "server<backslash>instance" notation
		of named database instances -- confirm.
		The helm3 agent is asked for the IPs of that server name, and the first
		configured source server whose IP is among them supplies the replacement.

		:type global_context: parallels.source.helm3.global_context.Helm3GlobalMigrationContext
		:return: rewritten host, or host unchanged when no source server matches
		"""
		server_name = host.split("\\")[0]
		server_ips = global_context.helm3_agent.get_server_ips(server_name)
		for server in global_context.conn.helm3.source_servers:
			server_ip = server.ip()
			if server_ip in server_ips:
				return server_ip + host[len(server_name):]

		# No source server matched; keep the host as is
		# (the original author assumed it was already replaced in a previous run).
		return host
