import logging
import threading
import Queue 
from contextlib import contextmanager
from parallels.common.actions.base.action_pointer import ActionPointer

from parallels.common.actions.base.common_action \
	import CommonAction
from parallels.common.actions.base.condition_action_pointer import ConditionActionPointer
from parallels.common.actions.base.subscription_action \
	import SubscriptionAction
from parallels.common.actions.base.compound_action \
	import CompoundAction

from parallels.common import migrator_config
from parallels.common.logging_context import log_context
from parallels.common.migrator_config import MultithreadingStatus
from parallels.common.utils.migrator_utils import trace_step
from parallels.common.utils.steps_profiler import get_default_steps_profiler

# Module-level logger used for all action progress and debug messages below
logger = logging.getLogger(__name__)


class ActionRunner(object):
	"""Executes workflow actions and entry points

	Considers their sequence, handles exceptions so non-critical failures do
	not break migration.

	Compound actions are executed either layer by layer (each child action
	runs on all subscriptions before the next child starts) or, when both the
	action and the multithreading settings request it, subscription by
	subscription (the whole chain of child subscription actions runs for one
	subscription, several subscriptions in parallel worker threads).
	"""

	def __init__(self, global_context, multithreading=None):
		"""Arguments:
		- global_context - context object passed to every action and used to
		  iterate over subscriptions
		- multithreading - instance of migrator_config.MultithreadingParams;
		  when None, multithreading is disabled
		"""
		self._global_context = global_context
		# '/'-joined paths of actions that completed successfully; consulted by
		# run_entry_point to decide which shutdown actions must be run
		self._executed_actions = set()
		if multithreading is None:
			multithreading = migrator_config.MultithreadingParams(
				status=MultithreadingStatus.DISABLED, num_workers=None, by_subscriptions=False
			)
		self._multithreading = multithreading

	def run_entry_point(self, entry_point):
		"""Run entry point, then the shutdown actions it requires

		Shutdown actions are run even if the entry point fails: the ones
		registered for each successfully executed action path, plus the
		overall ones registered under the None key. Shutdown actions that were
		already executed as part of the entry point are not run again.
		"""
		try:
			self.run(entry_point)
		finally:
			all_shutdown_actions = entry_point.get_shutdown_actions()
			shutdown_paths = set()

			# add shutdown actions for each set-up action that was executed
			for path in self._executed_actions:
				if path in all_shutdown_actions:
					shutdown_paths.update(all_shutdown_actions[path])

			# add overall shutdown actions
			if None in all_shutdown_actions:
				shutdown_paths.update(all_shutdown_actions[None])

			for path in shutdown_paths - self._executed_actions:
				self.run(entry_point.get_path(path))

	def run(self, action, action_id=None, path=None):
		"""Run an action of any supported kind and record its path on success

		Arguments:
		- action - CompoundAction, SubscriptionAction, CommonAction, or an
		  ActionPointer / ConditionActionPointer that resolves to one of them
		- action_id - identifier of the action, used for logging and profiling
		- path - list of action ids leading to this action; defaults to []
		"""
		if path is None:
			path = []

		action = self._resolve_action_pointer(action)

		if isinstance(action, CompoundAction):
			if not action.run_by_subscription or not self._multithreading.by_subscriptions:
				self._run_compound_action_by_layers(action, action_id, path)
			else:
				self._run_compound_action_by_subscriptions(action, action_id, path)
		elif isinstance(action, SubscriptionAction):
			self._run_subscription_action(action_id, action)
		elif isinstance(action, CommonAction):
			self._run_action(action_id, action)
		else:
			raise Exception(
				"Invalid action '%s': %s is not a subclass of "
				"CompoundAction, SubscriptionAction or CommonAction" % (
					action_id,
					action.__class__.__name__
				)
			)

		self._executed_actions.add('/'.join(path))

	def _resolve_action_pointer(self, action):
		"""Return the action a pointer refers to, or the action itself if it is not a pointer"""
		if isinstance(action, ActionPointer):
			return action.resolve()
		elif isinstance(action, ConditionActionPointer):
			return action.resolve(self._global_context)
		else:
			return action

	def _run_compound_action_by_layers(self, action, action_id, path):
		"""Run child actions of a compound action one after another"""
		with self._profile_action(action_id, action):
			with self._trace_action(action_id, action):
				for child_action_id, child_action in action.get_all_actions():
					self.run(child_action, child_action_id, path + [child_action_id])

	def _run_compound_action_by_subscriptions(self, action, action_id, path):
		"""Run all subscription actions of a compound action subscription by subscription

		The compound action is flattened into a plain list of subscription
		actions; the whole list is run for each subscription, and subscriptions
		are processed by a pool of worker threads. Child actions that can not
		use threads are serialized with a per-action lock, so at most one
		subscription runs such an action at a time.
		"""
		action_queue = self._get_subscription_action_queue(action, action_id, path)

		locks = {}
		for child_action, _, _ in action_queue:
			if not child_action.get_multithreading_properties().can_use_threads:
				locks[child_action] = threading.Lock()

		def process_subscription(subscription):
			logger.info("START Processing subscription '%s'", subscription.name)
			with log_context(subscription.name):
				for child_action, _, child_action_path in action_queue:
					logger.info("START: %s", child_action.get_description())
					lock = locks.get(child_action)
					if lock is not None:
						lock.acquire()
					try:
						if child_action.filter_subscription(self._global_context, subscription):
							self._run_single_subscription(child_action, subscription)
					finally:
						# always release, otherwise other workers would block forever
						if lock is not None:
							lock.release()
					logger.info("END: %s", child_action.get_description())
					self._executed_actions.add('/'.join(child_action_path))
			logger.info("END Processing subscription '%s'", subscription.name)

		self._run_in_worker_threads(
			self._global_context.iter_all_subscriptions(),
			process_subscription,
			self._multithreading.num_workers
		)

	@staticmethod
	def _run_in_worker_threads(items, process_item, num_workers):
		"""Call process_item for each of items using a pool of worker threads

		Items are handed out in their original order. Once any call raises,
		no further items are handed out; after all threads have finished, the
		first exception is re-raised in the calling thread. Threads are joined
		directly, so the caller never blocks forever because a worker died.

		Arguments:
		- items - iterable of items to process
		- process_item - function called with a single item
		- num_workers - maximum number of threads; None or values below 1
		  are treated as 1
		"""
		items = list(items)
		worker_count = min(max(num_workers or 1, 1), len(items))
		pending = iter(items)
		pending_lock = threading.Lock()
		no_more_items = object()
		errors = []

		def take_next_item():
			with pending_lock:
				if errors:
					# stop processing after the first failure, like sequential mode does
					return no_more_items
				return next(pending, no_more_items)

		def worker():
			while True:
				item = take_next_item()
				if item is no_more_items:
					return
				try:
					process_item(item)
				except Exception as e:
					with pending_lock:
						errors.append(e)

		threads = [threading.Thread(target=worker) for _ in range(worker_count)]
		for thread in threads:
			thread.start()
		for thread in threads:
			thread.join()

		if errors:
			raise errors[0]

	def _run_action(self, action_id, action):
		"""Run a common (non-subscription) action once"""
		with self._profile_action(action_id, action):
			with self._trace_action(action_id, action):
				action.run(self._global_context)

	def _is_multithreading_enabled(self, action):
		"""Decide whether subscription action should process subscriptions in threads

		Raises Exception on an unknown multithreading status.
		"""
		action_threading_props = action.get_multithreading_properties()
		status = self._multithreading.status

		if status == MultithreadingStatus.DISABLED:
			return False
		elif status == MultithreadingStatus.DEFAULT:
			return action_threading_props.use_threads_by_default
		elif status == MultithreadingStatus.FULL:
			return action_threading_props.can_use_threads
		else:
			raise Exception("Unknown multithreading status: %s" % (status,))

	def _run_subscription_action(self, action_id, action):
		"""Run specified action on all migrated subscriptions

		This function also considers isolated safe execution for each
		subscription and logging. Subscriptions are processed sequentially or
		by worker threads, depending on multithreading settings and on the
		action's multithreading properties.

		Arguments:
		- action_id - identifier of the action, used for logging and profiling
		- action - subscription action, instance of
		  parallels.common.actions.base.subscription_action.SubscriptionAction
		"""
		with self._profile_action(action_id, action):
			logger.debug("Run subscription action class %s", action.__class__.__name__)
			logger.debug(action.get_description())
			subscriptions = [
				subscription
				for subscription in self._global_context.iter_all_subscriptions()
				if action.filter_subscription(self._global_context, subscription)
			]
			if not subscriptions:
				return

			with self._trace_action(action_id, action):
				if action.get_logging_properties().info_log:
					logging_function = logger.info
				else:
					logging_function = logger.debug

				if not self._is_multithreading_enabled(action):
					for n, subscription in enumerate(subscriptions, start=1):
						logging_function(
							"Process subscription '%s' (#%s out of %s)",
							subscription.name, n, len(subscriptions)
						)
						self._run_single_subscription(action, subscription)
				else:
					def process_subscription(subscription):
						logging_function(
							"Start processing subscription '%s'",
							subscription.name
						)
						self._run_single_subscription(action, subscription)
						logging_function(
							"Finish processing subscription '%s'",
							subscription.name
						)

					self._run_in_worker_threads(
						subscriptions, process_subscription,
						self._multithreading.num_workers
					)

	def _run_single_subscription(self, action, subscription):
		"""Run subscription action on one subscription via safe, with one rerun on failure"""
		safe = self._global_context.safe
		failure_message = action.get_failure_message(self._global_context, subscription)

		safe.try_subscription_with_rerun(
			lambda: action.run(self._global_context, subscription),
			subscription.name,
			error_message=failure_message,
			repeat_error="%s. Try to repeat operation once more." % (failure_message,)
		)

	@contextmanager
	def _trace_action(self, action_id, action):
		"""Wrap action execution into trace_step logging if action has a description"""
		if action.get_description() is not None:
			if action.get_logging_properties().info_log:
				log_level = 'info'
			else:
				log_level = 'debug'

			with trace_step(
				action_id, action.get_description(),
				log_level=log_level,
				compound=action.get_logging_properties().compound
			):
				yield
		else:
			yield

	@staticmethod
	@contextmanager
	def _profile_action(action_id, action):
		"""Measure action execution time with the default steps profiler if action has a description"""
		if action.get_description() is not None:
			with get_default_steps_profiler().measure_time(
				action_id, action.get_description()
			):
				yield
		else:
			# 1) Compound actions that have no descriptions are profiled by all child actions.
			# 2) Other actions that have no descriptions are measured as "Other" action
			yield

	def _get_subscription_action_queue(self, action, action_id, path):
		"""Get plain list of (action, action_id, path) out of CompoundAction or SubscriptionAction

		Pointers are resolved the same way run() resolves them. Note that
		ConditionActionPointer conditions are therefore evaluated when the
		queue is built, before any of the queued actions run.

		Raises Exception if a CommonAction or an unsupported object is found.
		"""
		action = self._resolve_action_pointer(action)

		if isinstance(action, CompoundAction):
			actions = []
			for child_action_id, child_action in action.get_all_actions():
				actions.extend(
					self._get_subscription_action_queue(
						child_action, child_action_id, path + [child_action_id]
					)
				)
			return actions
		elif isinstance(action, SubscriptionAction):
			return [(action, action_id, path)]
		elif isinstance(action, CommonAction):
			raise Exception(
				"Can not run common action in subscription's actions queue"
			)
		else:
			raise Exception(
				"Invalid action '%s': %s is not a subclass of "
				"CompoundAction or SubscriptionAction" % (
					action_id,
					action.__class__.__name__
				)
			)