import logging

import dns.exception
import dns.message, dns.query, dns.rdtypes.ANY.TXT, dns.rdtypes.ANY.CNAME, dns.rdtypes.ANY.MX, dns.rdtypes.ANY.PTR, dns.rdtypes.IN.A, dns.rdtypes.IN.AAAA
import dns.rdataclass
import dns.rdatatype
import dns.reversename

from parallels.hosting_check import DNSRecord

# Module logger; DNS queries, responses and record checks are logged at DEBUG level.
logger = logging.getLogger(__name__)


def get_difference(records, dns_server_ip):
    """
    Check the given records against the DNS server at dns_server_ip.

    Records are grouped by (rec_type, src) and one DNS query is sent per group via
    resolve_record(), in the order defined by cmp_records_src_type(). Groups whose type
    is not supported (see is_allowed_record_type()) are skipped.

    Returns a list of human-readable difference descriptions produced by compare_records();
    an empty list means that every checked group matched.
    Raises DnsRequestTimeoutException (from resolve_record()) when the server does not respond.
    """
    groups = {}
    for rec in records:
        groups.setdefault((rec.rec_type, rec.src), []).append(rec)

    def by_src_then_type(left, right):
        # Items are ((rec_type, src), [records]) pairs from groups.iteritems()
        (left_type, left_src), _ = left
        (right_type, right_src), _ = right
        return cmp_records_src_type(left_type, left_src, right_type, right_src)

    differences = []
    for (rec_type, rec_src), expected in sorted(groups.iteritems(), cmp=by_src_then_type):
        if not is_allowed_record_type(rec_type):
            continue  # skip all records we don't know how to handle
        logger.debug(
            u"Checking the records %s on the DNS server at %s",
            _format_list([_pretty_record_str(r) for r in expected]), dns_server_ip
        )
        comparison_result = compare_records(expected, resolve_record(rec_type, rec_src, dns_server_ip))
        if comparison_result is not None:
            differences.append(comparison_result)
    return differences


def cmp_records_src_type(rec_type1, rec_src1, rec_type2, rec_src2):
    """
    Old-style (cmp) comparator that orders DNS records by source, then by type.

    Records with a shorter source go first; sources of equal length are compared
    as strings, and records with the same source are ordered by record type.
    Returns -1, 0 or 1 like the builtin cmp().
    """
    src_len1, src_len2 = len(rec_src1), len(rec_src2)
    if src_len1 != src_len2:
        return 1 if src_len1 > src_len2 else -1
    if rec_src1 == rec_src2:
        # same source - order by record type
        return cmp(rec_type1, rec_type2)
    return cmp(rec_src1, rec_src2)


def compare_records(expected, actual):
    """
    Compare expected DNS records with the records actually received from a DNS server.

    Both arguments are sequences of objects with rec_type, src, dst and opt attributes.
    They are wrapped into ComparableRec and sorted first, so their order does not matter.

    Returns None when both sets are equal, otherwise a human-readable description
    of the difference.
    """
    def normalize(records):
        comparable = [ComparableRec.from_rec(r) for r in records]
        comparable.sort(key=lambda r: (r.rec_type, r.src, r.dst.lower(), r.opt))
        return comparable

    expected_recs = normalize(expected)
    actual_recs = normalize(actual)

    if expected_recs == actual_recs:
        return None  # no difference

    if len(expected_recs) == len(actual_recs):
        return "expected %s, got %s" % (_format_list(expected_recs), _format_list(actual_recs))

    if not actual_recs:
        return "expected record(s) %s, got no records" % (_format_list(expected_recs),)

    return "expected %s record(s) %s, got %s record(s) %s" % (
        len(expected_recs), _format_list(expected_recs),
        len(actual_recs), _format_list(actual_recs),
    )


def is_allowed_record_type(rec_type):
    """Return True when records of type rec_type can be queried and compared (see rec_type_to_rdatatype)."""
    return rec_type_to_rdatatype.get(rec_type) is not None

# Record types we know how to query and compare, mapped to dnspython rdata types.
# Records of any other type are skipped by is_allowed_record_type().
# SRV records are intentionally not checked: they are backed up and restored
# incorrectly, there are issues when creating/editing them in PPA, and they are
# not so important or frequent when migrating from other panels.
rec_type_to_rdatatype = dict(
    A=dns.rdatatype.A,
    AAAA=dns.rdatatype.AAAA,
    CNAME=dns.rdatatype.CNAME,
    MX=dns.rdatatype.MX,
    TXT=dns.rdatatype.TXT,
    PTR=dns.rdatatype.PTR,
    NS=dns.rdatatype.NS,
)


def resolve_record(rec_type, rec_src, dns_server_ip):
    """
    Send a single DNS query and convert the answer section into a list of record objects.

    This simplifies interaction with the dnspython library: pass a DNS record type and
    a DNS record source (the queried name; for 'PTR' an IP address, which is converted
    to its reverse-lookup name) and get back the records from the answer.

    Returns a list of items where each item is:
    - 'DNSRecord', when the DNS server responded with data we know how to convert
      (A, AAAA, CNAME, TXT, MX, NS and PTR); every such record gets src=rec_src,
      even when the answer contains records for other names (e.g. a CNAME chain);
    - 'UnknownDnsResponse', when the DNS server responded with data of any other type.

    Raises KeyError if rec_type is not a key of rec_type_to_rdatatype
    (check with is_allowed_record_type() first).
    Raises DnsRequestTimeoutException in case of timeout.
    """
    logger.debug(
        u"Sending the DNS query: type='%s', src='%s', server='%s'",
        rec_type, rec_src, dns_server_ip
    )

    rdatatype = rec_type_to_rdatatype[rec_type]
    if rec_type == 'PTR':
        # PTR queries are sent for the reverse-lookup name of the IP address in rec_src
        host = dns.reversename.from_address(rec_src)
    else:
        host = rec_src

    dns_request = dns.message.make_query(host, rdatatype, dns.rdataclass.IN)

    def convert_rdata_to_record(rdata):
        rdtype = rdata.rdtype
        if rdtype == dns.rdatatype.A:
            return DNSRecord(rec_type='A', src=rec_src, dst=str(rdata.address), opt='')
        elif rdtype == dns.rdatatype.AAAA:
            return DNSRecord(rec_type='AAAA', src=rec_src, dst=str(rdata.address), opt='')
        elif rdtype == dns.rdatatype.CNAME:
            return DNSRecord(rec_type='CNAME', src=rec_src, dst=str(rdata.target), opt='')
        elif rdtype == dns.rdatatype.TXT:
            return DNSRecord(rec_type='TXT', src=rec_src, dst=' '.join(rdata.strings), opt='')
        elif rdtype == dns.rdatatype.MX:
            return DNSRecord(rec_type='MX', src=rec_src, dst=str(rdata.exchange), opt=str(rdata.preference))
        elif rdtype == dns.rdatatype.NS:
            return DNSRecord(rec_type='NS', src=rec_src, dst=str(rdata.target), opt='')
        elif rdtype == dns.rdatatype.PTR:
            # PTR is a queryable type (see rec_type_to_rdatatype); without this branch every
            # PTR answer would become an UnknownDnsResponse and break compare_records()
            return DNSRecord(rec_type='PTR', src=rec_src, dst=str(rdata.target), opt='')
        else:
            return UnknownDnsResponse(raw_data=rdata)

    try:
        dns_response = dns.query.udp(dns_request, dns_server_ip, timeout=DNS_REQUEST_TIMEOUT)
    except dns.exception.Timeout:
        logger.debug(u'Exception:', exc_info=True)
        raise DnsRequestTimeoutException(server_ip=dns_server_ip)

    result = [
        convert_rdata_to_record(rdata)
        for rdset in dns_response.answer
        for rdata in rdset
    ]
    logger.debug(u"Received the DNS response: %s", result)
    return result


def get_difference_between_dns_servers(records, expected_dns_server_ip, actual_dns_server_ip):
    """
    Compare the answers of two DNS servers for the (type, source) pairs present in 'records'.

    'records' only decides which queries are sent: the expected data is taken from the
    DNS server at expected_dns_server_ip, not from 'records' themselves. Groups are
    processed in the order defined by cmp_records_src_type(); types not supported by
    is_allowed_record_type() are skipped.

    Returns a list of human-readable difference descriptions (empty if the servers agree).
    Raises DnsRequestTimeoutException if either server does not respond in time.
    """
    difference = []

    records_grouped = dict()
    for record in records:
        records_grouped.setdefault((record.rec_type, record.src), []).append(record)

    for (rec_type, rec_src), group_records in sorted(
        records_grouped.iteritems(), cmp=lambda x, y: cmp_records_src_type(x[0][0], x[0][1], y[0][0], y[0][1])
    ):
        if not is_allowed_record_type(rec_type):
            continue  # skip all records we don't know how to handle

        # Log the records of the group being checked (not a leftover loop variable)
        logger.debug(
            u"Checking the records %s on the DNS server at %s and "
            u"comparing them to the response of the DNS server at %s",
            _format_list([_pretty_record_str(r) for r in group_records]),
            actual_dns_server_ip, expected_dns_server_ip
        )
        actual = resolve_record(rec_type, rec_src, actual_dns_server_ip)
        expected = resolve_record(rec_type, rec_src, expected_dns_server_ip)
        comparison_result = compare_records(expected, actual)
        if comparison_result is not None:
            difference.append(comparison_result)

    return difference


class DnsServerInfo(object):
    """A DNS server described by its host name and a list of its IP addresses."""

    def __init__(self, hostname, ips):
        # 'ips' is stored as-is (not copied), so callers can fill it in after construction
        self.hostname, self.ips = hostname, ips


def get_authoritative_dns_servers(domain, dns_server_ip):
    """
    Look up the NS records of 'domain' and the IPv4 addresses of every listed name server.

    Both lookups are sent to the DNS server at dns_server_ip. 'domain' may be given with
    or without a trailing dot.

    Returns a list of DnsServerInfo, one per NS record in the answer, each with 'ips'
    filled from the A records found for its host name (possibly empty).
    Raises DnsRequestTimeoutException (from resolve_record()) on timeout.
    """
    fqdn = '%s.' % (domain.rstrip('.'),)  # avoid an invalid "example.com.." name

    dns_servers = []
    for ns_rec in resolve_record('NS', fqdn, dns_server_ip):
        # resolve_record() returns every item of the answer section, which may include other
        # record types (e.g. CNAME) or UnknownDnsResponse objects (no rec_type/dst):
        # only NS records name a DNS server
        if getattr(ns_rec, 'rec_type', None) != 'NS':
            continue
        dns_server = DnsServerInfo(hostname=ns_rec.dst, ips=[])
        dns_servers.append(dns_server)
        for a_rec in resolve_record('A', dns_server.hostname, dns_server_ip):
            if getattr(a_rec, 'rec_type', None) == 'A':
                dns_server.ips.append(a_rec.dst)
    return dns_servers


# Timeout for every DNS query sent by resolve_record() (passed as 'timeout' to
# dns.query.udp()); presumably seconds, dnspython's unit for this argument - confirm.
DNS_REQUEST_TIMEOUT = 60


class DnsRequestTimeoutException(Exception):
    """
    Raised by resolve_record() when a DNS server does not answer within DNS_REQUEST_TIMEOUT.

    The IP address of the unresponsive server is available as 'server_ip'.
    """

    def __init__(self, server_ip):
        # Pass server_ip to Exception so that 'args', str() and repr() carry it, and so the
        # exception survives pickling (unpickling re-creates it as cls(*args)).
        super(DnsRequestTimeoutException, self).__init__(server_ip)
        self.server_ip = server_ip


class UnknownDnsResponse(object):
    """
    An answer item that resolve_record() does not know how to convert into a DNSRecord.

    'raw_data' keeps the unconverted object (the dnspython rdata, in resolve_record()).
    """

    def __init__(self, raw_data):
        self.raw_data = raw_data

    def __str__(self):
        return str(self.raw_data)

    def __repr__(self):
        # resolve_record() logs whole result lists, and lists format their items with repr()
        return '%s(raw_data=%r)' % (self.__class__.__name__, self.raw_data)


class ComparableRec(object):
    """
    Copy of a DNS record's (rec_type, src, dst, opt) fields with lenient equality.

    Two records are equal when rec_type, src and opt are identical and dst matches
    case-insensitively, or, for TXT records, when the dst values are identical after
    stripping trailing ';' characters.
    """

    def __init__(self, rec_type, src, dst, opt):
        self.rec_type = rec_type
        self.src = src
        self.dst = dst
        self.opt = opt

    @classmethod
    def from_rec(cls, rec):
        """Build a ComparableRec from any object with rec_type, src, dst and opt attributes."""
        return cls(rec_type=rec.rec_type, src=rec.src, dst=rec.dst, opt=rec.opt)

    def __str__(self):
        return _pretty_record_str(self)

    def __repr__(self):
        return '%s(rec_type=%r, src=%r, dst=%r, opt=%r)' % (
            self.__class__.__name__, self.rec_type, self.src, self.dst, self.opt
        )

    def __eq__(self, rec):
        try:
            rec_type, src, dst, opt = rec.rec_type, rec.src, rec.dst, rec.opt
        except AttributeError:
            # Not a record-like object (e.g. None): let Python fall back to its default
            # comparison instead of raising
            return NotImplemented
        return (
            self.rec_type == rec_type and
            self.src == src and
            (
                self.dst.lower() == dst.lower() or
                # Fix for domainkey TXT records, which may have an additional ";" symbol at the end of the record after migration
                (self.rec_type == 'TXT' and self.dst.rstrip(';') == dst.rstrip(';'))
            ) and
            self.opt == opt
        )

    def __ne__(self, rec):
        result = self.__eq__(rec)
        if result is NotImplemented:
            return result
        return not result


def _format_list(lst):
    return ", ".join(["'%s'" % (elem,) for elem in lst])


def _pretty_record_str(rec):
    if rec.opt is not None and rec.opt != "":
        opt_str = u" %s" % (rec.opt,) 
    else:
        opt_str = ""

    return u"%s %s%s %s" % (rec.src, rec.rec_type, opt_str, rec.dst)
