from functools import partial
import collections
import datetime
import asyncio
import time
import lxml.etree
import io

common_replacements = {
    'irc_server_one': 'irc.localhost@biboumi.localhost',
    'irc_server_two': 'localhost@biboumi.localhost',
    'irc_host_one': 'irc.localhost',
    'irc_host_two': 'localhost',
    'biboumi_host': 'biboumi.localhost',
    'resource_one': 'resource1',
    'resource_two': 'resource2',
    'nick_one': 'Nick',
    'jid_one': 'first@example.com',
    'jid_two': 'second@example.com',
    'jid_admin': 'admin@example.com',
    'nick_two': 'Nick2',
    'nick_three': 'Nick3',
    'lower_nick_one': 'nick',
    'lower_nick_two': 'nick2',
}

class SkipStepError(Exception):
    """
    Raised by a step when it needs to be skiped, by running
    the next available step immediately.
    """
    pass

class StanzaError(Exception):
    """
    Raised when a step fails.
    """
    pass

def match(stanza, xpath):
    tree = lxml.etree.parse(io.StringIO(str(stanza)))
    matched = tree.xpath(xpath, namespaces={'re': 'http://exslt.org/regular-expressions',
                                            'muc_user': 'http://jabber.org/protocol/muc#user',
                                            'muc_owner': 'http://jabber.org/protocol/muc#owner',
                                            'muc': 'http://jabber.org/protocol/muc',
                                            'disco_info': 'http://jabber.org/protocol/disco#info',
                                            'muc_traffic': 'http://jabber.org/protocol/muc#traffic',
                                            'disco_items': 'http://jabber.org/protocol/disco#items',
                                            'commands': 'http://jabber.org/protocol/commands',
                                            'dataform': 'jabber:x:data',
                                            'version': 'jabber:iq:version',
                                            'mam': 'urn:xmpp:mam:2',
                                            'rms': 'http://jabber.org/protocol/rsm',
                                            'delay': 'urn:xmpp:delay',
                                            'forward': 'urn:xmpp:forward:0',
                                            'client': 'jabber:client',
                                            'rsm': 'http://jabber.org/protocol/rsm',
                                            'carbon': 'urn:xmpp:carbons:2',
                                            'hints': 'urn:xmpp:hints',
                                            'stanza': 'urn:ietf:params:xml:ns:xmpp-stanzas',
                                            'stable_id': 'urn:xmpp:sid:0'})
    return matched

def check_xpath(xpaths, xmpp, after, stanza):
    for xpath in xpaths:
        expected = True
        real_xpath = xpath
        # We can check that a stanza DOESN’T match, by adding a ! before it.
        if xpath.startswith('!'):
            expected = False
            xpath = xpath[1:]
        matched = match(stanza, xpath)
        if (expected and not matched) or (not expected and matched):
            raise StanzaError("Received stanza\n%s\ndid not match expected xpath\n%s" % (stanza, real_xpath))
    if after:
        if isinstance(after, collections.Iterable):
            for af in after:
                af(stanza, xmpp)
        else:
            after(stanza, xmpp)

def check_xpath_optional(xpaths, xmpp, after, stanza):
    try:
        check_xpath(xpaths, xmpp, after, stanza)
    except StanzaError:
        raise SkipStepError()

def all_xpaths_match(stanza, xpaths):
    try:
        check_xpath(xpaths, None, None, stanza)
    except StanzaError:
        return False
    return True

def check_list_of_xpath(list_of_xpaths, xmpp, stanza):
    found = False
    for i, xpaths in enumerate(list_of_xpaths):
        if all_xpaths_match(stanza, xpaths):
            found = True
            list_of_xpaths.pop(i)
            break

    if not found:
        raise StanzaError("Received stanza “%s” did not match any of the expected xpaths:\n%s" % (stanza, list_of_xpaths))

    if list_of_xpaths:
        step = partial(expect_unordered_already_formatted, list_of_xpaths)
        xmpp.scenario.steps.insert(0, step)

def extract_attribute(xpath, name):
    def f(xpath, name, stanza):
        matched = match(stanza, xpath)
        return matched[0].get(name)
    return partial(f, xpath, name)

def extract_text(xpath, stanza):
    matched = match(stanza, xpath)
    return matched[0].text

def save_value(name, func):
    def f(name, func, stanza, xmpp):
        xmpp.saved_values[name] = func(stanza)
    return partial(f, name, func)

def expect_stanza(*args, optional=False, after=None):
    def f(*xpaths, xmpp, biboumi, optional, after):
        replacements = common_replacements
        replacements.update(xmpp.saved_values)
        check_func = check_xpath if not optional else check_xpath_optional
        formatted_xpaths = [xpath.format_map(replacements) for xpath in xpaths]
        xmpp.stanza_checker = partial(check_func, formatted_xpaths, xmpp, after)
        xmpp.timeout_handler = asyncio.get_event_loop().call_later(10, partial(xmpp.on_timeout, formatted_xpaths))
    return partial(f, *args, optional=optional, after=after)

def send_stanza(stanza):
    def internal(stanza, xmpp, biboumi):
        replacements = common_replacements
        replacements.update(xmpp.saved_values)
        xmpp.send_raw(stanza.format_map(replacements))
        asyncio.get_event_loop().call_soon(xmpp.run_scenario)
    return partial(internal, stanza)

def expect_unordered(*args):
    def f(*lists_of_xpaths, xmpp, biboumi):
        formatted_list_of_xpaths = []
        for list_of_xpaths in lists_of_xpaths:
            formatted_xpaths = []
            for xpath in list_of_xpaths:
                formatted_xpath = xpath.format_map(common_replacements)
                formatted_xpaths.append(formatted_xpath)
            formatted_list_of_xpaths.append(tuple(formatted_xpaths))
        expect_unordered_already_formatted(formatted_list_of_xpaths, xmpp, biboumi)
        xmpp.timeout_handler = asyncio.get_event_loop().call_later(10, partial(xmpp.on_timeout, formatted_list_of_xpaths))
    return partial(f, *args)

def expect_unordered_already_formatted(formatted_list_of_xpaths, xmpp, biboumi):
    xmpp.stanza_checker = partial(check_list_of_xpath, formatted_list_of_xpaths, xmpp)

def sleep_for(duration):
    def f(duration, xmpp, biboumi):
        time.sleep(duration)
        asyncio.get_event_loop().call_soon(xmpp.run_scenario)
    return partial(f, duration)

def save_current_timestamp_plus_delta(key, delta):
    def f(key, delta, message, xmpp):
        now_plus_delta = datetime.datetime.utcnow() + delta
        xmpp.saved_values[key] = now_plus_delta.strftime("%FT%T.967Z")
    return partial(f, key, delta)
