from parallels.common import messages
import logging
import parallels
import textwrap
from collections import namedtuple
from parallels.common.utils.database_server_type import DatabaseServerType

from parallels.common.registry import Registry
from parallels.common import run_command
from parallels.common.utils import windows_utils
from parallels.common import MigrationError
from parallels.common.target_panels import TargetPanels
from parallels.common.utils.migrator_utils import get_package_extras_file_path, \
	get_package_scripts_file_path
from parallels.hosting_check.utils.powershell import Powershell
from parallels.utils import split_nonempty_strs, is_empty

logger = logging.getLogger(__name__)

# Connection parameters of a single database on a database server:
# server host and port, login and password, and the database name.
DatabaseCredentials = namedtuple(
	'DatabaseCredentials', ['host', 'port', 'login', 'password', 'db_name'])

def copy_db_content_linux(database_info, key_pathname):
	"""Copy the content of one database between Linux database servers.

	The database is dumped on the source server into a mktemp-created file,
	fetched to the target server with scp and restored there. Credentials
	files and dump files created along the way are removed from both servers
	afterwards, also when one of the steps fails.

	Supported source server types: MySQL and PostgreSQL.

	:param database_info: object providing source_database_server,
		target_database_server, database_name and subscription_name
	:param key_pathname: path of the SSH private key, on the target server,
		that scp uses to fetch the dump from the source server
	:raises Exception: if the source database server type is not supported
	"""
	source = database_info.source_database_server
	target = database_info.target_database_server
	database_name = database_info.database_name
	logger.info(u"Copy database '{db_name}' content from {source} to {target}".format(
		db_name=database_name,
		source=source.description(),
		target=target.description()
	))

	dump_filename = 'db_backup_%s_%s' % (
		database_info.subscription_name, database_name)

	source_credentials = _get_db_credentials(source, database_name)
	target_credentials = _get_db_credentials(target, database_name)

	# Everything created below is tracked here so that the finally-clause can
	# clean it up no matter at which step the copy fails.
	source_passwd_file = None
	passwd_file = None
	source_dump_tmpname = None
	target_dump_tmpname = None
	try:
		if source.type() == DatabaseServerType.MYSQL:
			# host/port are accessor methods and must be called; passing the bound
			# methods made the localhost:3306 Plesk workaround impossible to match
			backup_command = get_mysql_backup_command_template(
				source.host(), source.port(), source.password())
			passwd_file = create_mysql_options_file(target, target_credentials)
			restore_command = get_mysql_restore_command_template(
				target.host(), target.port(), passwd_file)
		elif source.type() == DatabaseServerType.POSTGRESQL:
			# check for pg_dump before any credentials file is uploaded
			with source.runner() as source_runner:
				locate_pg_dump(source_runner)
			# NOTE(review): this pgpass file is not passed to the pg_dump template
			# below (the password goes through PGPASSWORD); it is only cleaned up
			source_passwd_file = create_pgpass_file(source, source_credentials)
			backup_command = get_pgdump_command_template(source.host())
			passwd_file = create_pgpass_file(target, target_credentials)
			restore_command = get_pg_restore_command_template(target.host(), passwd_file)
		else:
			raise Exception(
				messages.CANNOT_TRANSFER_DATABASE_UNSUPPORTED_TYPE_S % source.type())

		with source.runner() as source_runner:
			filename_stem = source.panel_server.get_session_file_path(dump_filename)
			source_dump_tmpname = _get_safe_filename(source_runner, filename_stem)
			source_runner.sh(backup_command + u" > {source_dump_tmpname}", dict(
				src_host=source.host(),
				src_port=source.port(),
				src_admin=source.user(),
				src_password=source.password(),
				db_name=database_name,
				source_dump_tmpname=source_dump_tmpname
			))

		with target.runner() as target_runner:
			filename_stem = target.panel_server.get_session_file_path(dump_filename)
			target_dump_tmpname = _get_safe_filename(target_runner, filename_stem)
			target_runner.sh(
				u"scp -i {key_pathname} -o StrictHostKeyChecking=no -o GSSAPIAuthentication=no "
				u"{src_server_ip}:{source_dump_tmpname} {target_dump_tmpname}",
				dict(
					key_pathname=key_pathname,
					src_server_ip=source.ip(),
					source_dump_tmpname=source_dump_tmpname,
					target_dump_tmpname=target_dump_tmpname
				)
			)

		with target.runner() as target_runner:
			command_parameters = dict(target_dump_tmpname=target_dump_tmpname)
			command_parameters.update(get_target_db_credentials(target_credentials))
			target_runner.sh(restore_command + u" < {target_dump_tmpname}", command_parameters)
	finally:
		_remove_remote_files(source, [source_passwd_file, source_dump_tmpname])
		_remove_remote_files(target, [passwd_file, target_dump_tmpname])


def _remove_remote_files(server, filenames):
	"""Best-effort removal of ``filenames`` on ``server``; falsy entries are skipped.

	Failures are logged as warnings instead of being raised, so that cleanup
	never hides the error that interrupted a copy.
	"""
	filenames = [filename for filename in filenames if filename]
	if not filenames:
		return
	try:
		with server.runner() as runner:
			for filename in filenames:
				try:
					runner.remove_file(filename)
				except Exception:
					logger.warning(u"Failed to remove temporary file %s", filename)
					logger.debug(u"Exception: ", exc_info=True)
	except Exception:
		logger.warning(u"Failed to remove temporary files %s", filenames)
		logger.debug(u"Exception: ", exc_info=True)


def get_target_db_credentials(target_credentials):
	"""Return ``target_credentials`` as a dict with every key prefixed by 'dst_'.

	For example the ``host`` field becomes the ``dst_host`` key, matching the
	{dst_...} placeholders used by the restore command templates.
	"""
	credentials_by_field = target_credentials._asdict()
	return _add_dict_key_prefix('dst_', credentials_by_field)


def _get_db_credentials(server, db_name):
	"""Build the DatabaseCredentials for database ``db_name`` on ``server``."""
	return DatabaseCredentials(
		host=server.host(),
		port=server.port(),
		login=server.user(),
		password=server.password(),
		db_name=db_name,
	)


def _add_dict_key_prefix(prefix, dictionary):
	"""Return a new dictionary, with all key names prefixed by a given string."""
	return dict((prefix+k, v) for k, v in dictionary.iteritems())


def get_mysql_backup_command_template(host, port, password):
	"""Return the mysqldump command template used to dump a source database.

	Placeholders left for the caller: {src_host}, {src_port}, {src_admin},
	{db_name}, and {src_password} when a password option is added.
	For host 'localhost' with port 3306 the password is read from
	/etc/psa/.psa.shadow (a workaround for Plesk); an empty ``password``
	adds no -p option at all.
	"""
	template = (
		u"mysqldump -h {src_host} -P {src_port} -u{src_admin} --quick --quote-names "
		u"--add-drop-table --default-character-set=utf8 --set-charset {db_name}"
	)
	if host == 'localhost' and port == 3306:
		password_option = u" -p\"`cat /etc/psa/.psa.shadow`\""
	elif password != '':
		password_option = u" -p{src_password}"
	else:
		password_option = u""
	return template + password_option


def get_mysql_restore_command_template(host, port, option_file=None):
	"""Return the mysql client command template used to restore a dump.

	Credentials are taken, in order of preference, from ``option_file``
	(passed as --defaults-file), from /etc/psa/.psa.shadow for 'localhost'
	with port 3306 (a workaround for Plesk), or from the {dst_password}
	placeholder. Other placeholders: {dst_host}, {dst_port}, {dst_login},
	{dst_db_name}.
	"""
	if option_file:
		credentials_part = '--defaults-file=' + option_file
	elif host == 'localhost' and port == 3306:
		credentials_part = u" --no-defaults -p\"`cat /etc/psa/.psa.shadow`\""
	else:
		credentials_part = u" --no-defaults -p{dst_password}"
	return (
		u"mysql " + credentials_part +
		u" -h {dst_host} -P {dst_port} -u{dst_login} {dst_db_name}"
	)


def create_mysql_options_file(remote_server, credentials):
	"""Create a mysql options file and upload it to the remote server."""
	filename = 'my_{host}_{db_name}.cnf'.format(**credentials._asdict())
	connection_str = u"[client]\nuser={login}\npassword={password}".format(
			**credentials._asdict())
	remote_filename = _upload_file_content_linux(remote_server, filename, connection_str)
	return remote_filename


def get_pgdump_command_template(hostname, passwd_file=None):
	"""Return the pg_dump command template used to dump a source database.

	:param hostname: source database host; no -h/-p options are added for
		'localhost'
	:param passwd_file: optional path of a pgpass file on the source server;
		when given, it replaces the {src_password} placeholder (PGPASSWORD)

	Placeholders left for the caller: {src_admin}, {db_name},
	{src_password} (without ``passwd_file``), and {src_host} / {src_port}
	for a host other than 'localhost'.
	"""
	if passwd_file:
		# pg_dump gets neither a user nor a database on its command line, so
		# both still have to come from the environment when a pgpass file is used
		credentials = u"PGPASSFILE=%s PGUSER={src_admin} PGDATABASE={db_name}" % passwd_file
	else:
		credentials = u"PGUSER={src_admin} PGPASSWORD={src_password} PGDATABASE={db_name}"
	backup_command = "pg_dump --format=custom --blobs --no-owner --ignore-version"
	if hostname != 'localhost':
		backup_command += " -h {src_host} -p {src_port}"
	return u"%s %s" % (credentials, backup_command)


def get_pg_restore_command_template(hostname, passwd_file):
	"""Return the pg_restore command template used to restore a dump.

	The password is taken from the pgpass file ``passwd_file``. Placeholders
	left for the caller: {dst_login}, {dst_db_name}, and {dst_host} /
	{dst_port} for a host other than 'localhost'.
	"""
	parts = [
		u"PGPASSFILE=%s" % passwd_file,
		u"pg_restore -v -U {dst_login} -d {dst_db_name} -c",
	]
	if hostname != 'localhost':
		parts.append(u"-h {dst_host} -p {dst_port}")
	return u" ".join(parts)


def create_pgpass_file(server, credentials):
	"""Create a PostgreSQL password file and upload it to the remote server."""
	filename = 'pgpass_{host}_{db_name}'.format(**credentials._asdict())
	connection_str = '*:*:{db_name}:{login}:{password}'.format(**credentials._asdict())
	remote_filename = _upload_file_content_linux(server, filename, connection_str)
	return remote_filename


def _upload_file_content_linux(server, filename, content):
	"""Upload ``content`` to a new file on ``server`` and return the remote path.

	``content`` is first written to ``filename`` in the migrator server's
	session directory and uploaded from there. On ``server`` the file is
	created with mktemp, based on the server's session file path when the
	server object provides ``get_session_file_path``, or under '/tmp'
	otherwise, and is made readable and writable by its owner only
	(chmod 600).

	The staging copy on the migrator server is removed afterwards, because
	callers in this module use this function for database credentials files.
	"""
	context = Registry.get_instance().get_context()
	staging_filename = context.migrator_server.get_session_file_path(filename)
	with context.migrator_server.runner() as runner:
		runner.upload_file_content(staging_filename, content)
	try:
		with server.runner() as runner:
			if hasattr(server, 'get_session_file_path'):
				filename_stem = server.get_session_file_path(filename)
			else:
				# _get_safe_filename appends the "_XXXXXX" mktemp template itself
				filename_stem = "/tmp/%s" % filename
			remote_filename = _get_safe_filename(runner, filename_stem)
			runner.upload_file(staging_filename, remote_filename)
			runner.sh(u'chmod 600 "%s"' % remote_filename)
	finally:
		# do not leave a copy of the credentials behind on the migrator server
		with context.migrator_server.runner() as runner:
			runner.remove_file(staging_filename)
	return remote_filename


def _get_safe_filename(runner, filename_stem):
	"""Create a temporary file with 600 access rights."""
	return runner.sh(u'mktemp "%s_XXXXXX"' % filename_stem).strip()


def locate_pg_dump(runner):
	"""Ensure the 'pg_dump' utility is available on the runner's host.

	:raises MigrationError: when ``which pg_dump`` exits with a non-zero code
	"""
	which_exit_code = runner.run_unchecked('which', ['pg_dump'])[0]
	if which_exit_code == 0:
		return
	raise MigrationError(
		textwrap.dedent(messages.PG_DUMP_UTILITY_IS_NOT_INSTALLED % runner.ssh.hostname))


def copy_db_content_windows(database_info, rsync):
	"""Copy the content of one database between Windows database servers.

	MySQL: the database is dumped with mysqldump on the source panel server,
	the dump is transferred with ``rsync`` and loaded with the mysql client on
	the target panel server. MSSQL: the copy is done by Plesk's dbbackup.exe
	run on the target panel server. Other types are only logged as an error.

	:param database_info: object providing source_database_server,
		target_database_server, database_name and subscription_name
	:param rsync: object whose ``sync(source_path=..., target_path=...)``
		copies the dump from the source to the target server
	:raises MigrationError: if the mysql client is missing on the target or
		the dump transfer fails
	"""
	source = database_info.source_database_server
	target = database_info.target_database_server
	logger.info(u"Copy database '{database_name}' content from {source} to {target}".format(
		database_name=database_info.database_name,
		source=source.description(),
		target=target.description()
	))

	dump_filename = 'db_backup_%s_%s.sql' % (database_info.subscription_name, database_info.database_name)
	source_dump_filename = source.panel_server.get_session_file_path(dump_filename)
	target_dump_filename = target.panel_server.get_session_file_path(dump_filename)

	with target.panel_server.runner() as runner_target, source.panel_server.runner() as runner_source:
		if source.type() == DatabaseServerType.MYSQL:
			_copy_mysql_db_content_windows(
				database_info, rsync, runner_source, runner_target,
				source_dump_filename, target_dump_filename
			)
		elif source.type() == DatabaseServerType.MSSQL:
			_copy_mssql_db_content_windows(database_info, runner_target)
		else:
			logger.error(messages.DATABASE_UNSUPPORTED_TYPE_AND_HENCE_NOT)


def _copy_mysql_db_content_windows(
	database_info, rsync, runner_source, runner_target, source_dump_filename, target_dump_filename
):
	"""Dump a MySQL database on the source, rsync the dump and load it on the target.

	Once the dump has been created, both dump files are removed (best effort)
	whether or not the transfer and the restore succeed.
	"""
	source = database_info.source_database_server
	target = database_info.target_database_server

	if not is_windows_mysql_client_configured(target.panel_server):
		raise MigrationError(
			messages.MYSQL_CLIENT_BINARY_WAS_NOT_FOUND % (target.description(), database_info.database_name)
		)

	# --skip-secure-auth is a mysqldump option, so it belongs inside the quoted
	# "cmd.exe /C" command rather than after its closing quote
	if source.panel_server.mysql_use_skip_secure_auth():
		extra_options = u' --skip-secure-auth'
	else:
		extra_options = u''
	runner_source.sh(
		u'cmd.exe /C "{path_to_mysqldump} -h {host} -P {port} -u{user} -p{password} {database_name} '
		u'--result-file={source_dump_filename}' + extra_options + u'"',
		dict(
			path_to_mysqldump=source.panel_server.get_path_to_mysqldump(),
			host=source.host(),
			port=source.port(),
			user=source.user(),
			password=source.password(),
			database_name=database_info.database_name,
			source_dump_filename=source_dump_filename
		)
	)

	try:
		try:
			logger.debug(messages.COPY_DATABASE_DUMP_FROM_SOURCE_TARGET)
			rsync.sync(
				# rsync module path: 'migrator/' + base name of the Windows dump path
				source_path='migrator/%s' % (
					source_dump_filename[source_dump_filename.rfind('\\') + 1:]
				),
				target_path=windows_utils.convert_path_to_cygwin(
					target_dump_filename
				),
			)
		except run_command.HclRunnerException as e:
			logger.debug(u"Exception: ", exc_info=e)
			raise MigrationError(
				messages.RSYNC_FAILED_COPY_DATABASE_DUMP_FROM % (
					source.panel_server.ip(),
					target.panel_server.ip(),
					e.stderr
				)
			)
		except Exception as e:
			logger.debug(u"Exception: ", exc_info=e)
			raise MigrationError(
				messages.RSYNC_FAILED_COPY_DATABASE_DUMP_FROM_1 % (
					source.panel_server.ip(),
					target.panel_server.ip(),
					str(e)
				)
			)

		logger.debug(messages.RESTORE_DATABASE_DUMP_TARGET_SERVER)
		runner_target.sh(
			u'{mysql_client} --no-defaults -h {host} -P {port} -u{login} -p{password} {database_name} '
			u'-e "source {target_dump_filename}"',
			dict(
				mysql_client=get_windows_mysql_client(target.panel_server),
				host=target.host(),
				port=target.port(),
				login=target.user(),
				password=target.password(),
				database_name=database_info.database_name,
				target_dump_filename=target_dump_filename
			)
		)
	finally:
		logger.debug(u"Remove database dump files; source: %s, target: %s" % (
			source_dump_filename,
			target_dump_filename
		))
		_remove_dump_file_quietly(runner_source, source_dump_filename)
		_remove_dump_file_quietly(runner_target, target_dump_filename)


def _copy_mssql_db_content_windows(database_info, runner_target):
	"""Copy an MSSQL database with Plesk's dbbackup.exe, run on the target panel server."""
	source = database_info.source_database_server
	target = database_info.target_database_server
	runner_target.sh(
		u'cmd.exe /C "{dbbackup_path} --copy -copy-if-logins-exist -with-data -src-server={src_host} '
		u'-server-type={db_type} -src-server-login={src_admin} -src-server-pwd={src_password} '
		u'-src-database={database_name} -dst-server={dst_host} -dst-server-login={dst_admin} '
		u'-dst-server-pwd={dst_password} -dst-database={database_name}"',
		dict(
			dbbackup_path=windows_utils.path_join(target.panel_server.plesk_dir, r'admin\bin\dbbackup.exe'),
			db_type=source.type(),
			src_host=windows_utils.get_dbbackup_mssql_host(source.host(), source.ip()),
			src_admin=source.user(),
			src_password=source.password(),
			dst_host=target.host(),
			dst_admin=target.user(),
			dst_password=target.password(),
			database_name=database_info.database_name
		)
	)


def _remove_dump_file_quietly(runner, filename):
	"""Remove ``filename`` with ``runner``; log a warning instead of raising on failure."""
	try:
		runner.remove_file(filename)
	except Exception:
		logger.warning(u"Failed to remove database dump file %s", filename)
		logger.debug(u"Exception: ", exc_info=True)


def get_windows_mysql_client(target_server):
	"""Return the mysql client command to run on a Windows target server.

	For Plesk and PVPS target panels this is the client under the target's
	``plesk_dir`` (MySQL\\bin\\mysql); for any other target panel it is
	plain 'mysql'.
	"""
	target_panel = Registry.get_instance().get_context().target_panel
	if target_panel not in [TargetPanels.PLESK, TargetPanels.PVPS]:
		return 'mysql'
	return windows_utils.path_join(target_server.plesk_dir, 'MySQL\\bin\\mysql')


def is_windows_mysql_client_configured(target_server):
	"""Return True if the mysql client runs with ``--version`` on ``target_server``."""
	mysql_client = get_windows_mysql_client(target_server)
	with target_server.runner() as runner:
		version_exit_code = runner.sh_unchecked(u'cmd /c "%s" --version' % mysql_client)[0]
	return version_exit_code == 0


def check_connection(database_server):
	"""Check connection to the given database server.

	A server with an empty user or password is rejected up front. MySQL,
	PostgreSQL and MSSQL servers are then checked by the type-specific
	function; other server types are not checked.

	:type database_server: parallels.common.connections.database_server.DatabaseServer
	:raises parallels.common.utils.database_utils.DatabaseServerConnectionException:
	"""
	if is_empty(database_server.user()) or is_empty(database_server.password()):
		raise DatabaseServerConnectionException(
			messages.SERVER_IS_NOT_PROPERLY_CONFIGURED_TARGET.format(
				server=database_server.description()
			)
		)

	server_type = database_server.type()
	if server_type == DatabaseServerType.MYSQL:
		check_mysql_connection(database_server)
	elif server_type == DatabaseServerType.POSTGRESQL:
		check_postgresql_connection(database_server)
	elif server_type == DatabaseServerType.MSSQL:
		check_mssql_connection(database_server)


def check_mysql_connection(database_server):
	"""Check connection to a MySQL database server by running ``SELECT 1``.

	:type database_server: parallels.common.connections.database_server.DatabaseServer:
	:raises parallels.common.utils.database_utils.DatabaseServerConnectionException:
		if the mysql client exits with a non-zero code or does not print "1";
		the password is masked in the command shown in the message
	"""
	with database_server.runner() as runner:
		command = (
			'{mysql} --silent --skip-column-names -h {host} -P {port} -u {user} -p{password} -e {query}')
		args = dict(
			mysql=database_server.mysql_bin,
			user=database_server.user(), password=database_server.password(),
			host=database_server.host(), port=database_server.port(),
			query='SELECT 1'
		)
		exit_code, stdout, stderr = runner.sh_unchecked(command, args)

	if exit_code != 0 or stdout.strip() != '1':
		# the command is echoed for diagnostics; never put the real password
		# into an exception message, which may end up in logs and reports
		displayed_args = dict(args, password='***')
		raise DatabaseServerConnectionException(
			'Connection to {server} failed.\nCommand was: {command}\nStdout: {stdout}\nStderr: {stderr}\nExit code: {exit_code}'.format(
				server=database_server.description(),
				command=command.format(**displayed_args),
				stdout=stdout,
				stderr=stderr,
				exit_code=exit_code
			)
		)


def check_postgresql_connection(database_server):
	"""Check connection to a PostgreSQL database server by running ``SELECT 1``.

	psql connects to the template1 database; -h/-p are passed only when the
	host is not 'localhost'.

	:type database_server: parallels.common.connections.database_server.DatabaseServer:
	:raises parallels.common.utils.database_utils.DatabaseServerConnectionException:
		if psql exits with a non-zero code or does not print "1"; the password
		is masked in the command shown in the message
	"""
	command = "PGUSER={user} PGPASSWORD={password} psql"
	if database_server.host() != 'localhost':
		command += " -h {host} -p {port}"
	command += " -dtemplate1 -A -t -q -c {query}"
	args = dict(
		user=database_server.user(), password=database_server.password(),
		host=database_server.host(), port=database_server.port(),
		query='SELECT 1'
	)
	with database_server.runner() as runner:
		exit_code, stdout, stderr = runner.sh_unchecked(command, args)

	if exit_code != 0 or stdout.strip() != '1':
		# the command is echoed for diagnostics; never put the real password
		# into an exception message, which may end up in logs and reports
		displayed_args = dict(args, password='***')
		raise DatabaseServerConnectionException(
			'Connection to {server} failed.\nCommand was: {command}\nStdout: {stdout}\nStderr: {stderr}\nExit code: {exit_code}'.format(
				server=database_server.description(),
				command=command.format(**displayed_args),
				stdout=stdout,
				stderr=stderr,
				exit_code=exit_code
			)
		)


def check_mssql_connection(database_server):
	"""Check connection to an MSSQL database server.

	Uploads the check_mssql_connection.ps1 script, looked up in the
	parallels.plesks_migrator package, to the panel server's session
	directory and runs it with PowerShell.

	:type database_server: parallels.common.connections.database_server.DatabaseServer:
	:raises parallels.common.utils.database_utils.DatabaseServerConnectionException:
		if the script exits with a non-zero code or writes to stderr; the
		password is masked in the command shown in the message
	"""
	# import explicitly: "import parallels" alone does not guarantee that the
	# plesks_migrator subpackage is available as an attribute of "parallels"
	import parallels.plesks_migrator

	with database_server.runner() as runner:
		script_name = 'check_mssql_connection.ps1'
		local_script_path = get_package_scripts_file_path(parallels.plesks_migrator, script_name)
		remote_script_path = database_server.panel_server.get_session_file_path(script_name)
		runner.upload_file(local_script_path, remote_script_path)
		powershell = Powershell(runner, input_format_none=True)
		args = {
			'dataSource': database_server.host(),
			'login': database_server.user(),
			'pwd': database_server.password(),
		}
		exit_code, stdout, stderr = powershell.execute_script_unchecked(remote_script_path, args)
		if exit_code != 0 or stderr != '':
			# never put the real password into an exception message
			displayed_args = dict(args, pwd='***')
			raise DatabaseServerConnectionException(
				'Connection to {server} failed.\nCommand that checks MSSQL connection with Powershell was: {command}\nStdout: {stdout}\nStderr: {stderr}\nExit code: {exit_code}'.format(
					server=database_server.description(),
					command=powershell.get_command_string(remote_script_path, displayed_args),
					stdout=stdout,
					stderr=stderr,
					exit_code=exit_code
				)
			)


class DatabaseServerConnectionException(Exception):
	"""Raised when a database server is misconfigured or cannot be connected to."""


def list_databases(database_server):
	"""List database names on the given server.

	MySQL, PostgreSQL and MSSQL servers are supported; for any other server
	type None is returned.

	:type database_server: parallels.common.connections.database_server.DatabaseServer
	:rtype: set[basestring] | None
	"""
	server_type = database_server.type()
	if server_type == DatabaseServerType.MYSQL:
		return list_mysql_databases(database_server)
	if server_type == DatabaseServerType.POSTGRESQL:
		return list_postgresql_databases(database_server)
	if server_type == DatabaseServerType.MSSQL:
		return list_mssql_databases(database_server)
	return None


def list_mysql_databases(database_server):
	"""Return the set of database names on a MySQL server (``SHOW DATABASES``).

	:type database_server: parallels.common.connections.database_server.DatabaseServer
	:rtype: set[basestring]
	"""
	with database_server.runner() as runner:
		query_args = dict(
			mysql=database_server.mysql_bin,
			user=database_server.user(), password=database_server.password(),
			host=database_server.host(), port=database_server.port(),
			query='SHOW DATABASES'
		)
		output = runner.sh(
			'{mysql} --silent --skip-column-names -h {host} -P {port} -u {user} -p{password} -e {query}',
			query_args
		)
		database_names = split_nonempty_strs(output)
		return set(database_names)


def list_postgresql_databases(database_server):
	"""Return the set of database names on a PostgreSQL server.

	Queries ``pg_database`` with psql connected to template1; -h/-p are
	passed only when the host is not 'localhost'.

	:type database_server: parallels.common.connections.database_server.DatabaseServer
	:rtype: set[basestring]
	"""
	with database_server.runner() as runner:
		command_parts = ["PGUSER={user} PGPASSWORD={password} psql"]
		if database_server.host() != 'localhost':
			command_parts.append(" -h {host} -p {port}")
		command_parts.append(" -dtemplate1 -A -t -q -c {query}")
		query_args = dict(
			user=database_server.user(), password=database_server.password(),
			host=database_server.host(), port=database_server.port(),
			query='SELECT datname FROM pg_database'
		)
		output = runner.sh("".join(command_parts), query_args)
		database_names = split_nonempty_strs(output)
		return set(database_names)


def list_mssql_databases(database_server):
	"""Return the set of database names on an MSSQL server.

	Uploads the list_mssql_databases.ps1 script to the panel server's session
	directory and runs it with PowerShell.

	:type database_server: parallels.common.connections.database_server.DatabaseServer
	:rtype: set[basestring]
	:raises MigrationError: if the script exits with a non-zero code
	"""
	with database_server.runner() as runner:
		script_name = 'list_mssql_databases.ps1'
		source_path = get_package_extras_file_path(parallels.common, script_name)
		uploaded_path = database_server.panel_server.get_session_file_path(script_name)
		runner.upload_file(source_path, uploaded_path)
		powershell = Powershell(runner, input_format_none=True)
		script_args = {
			'dataSource': database_server.host(),
			'login': database_server.user(),
			'pwd': database_server.password(),
		}
		exit_code, output, _stderr = powershell.execute_script_unchecked(uploaded_path, script_args)

		if exit_code == 0:
			return set(split_nonempty_strs(output))
		raise MigrationError(
			messages.FAILED_GET_LIST_DATABASES_MSSQL_DATABASE % (
				database_server.host(), database_server.port()
			)
		)
