#!/usr/bin/env python
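"""Learn language patterns from annotated examples and fit new token
streams against them.

Templates are (tokens, matcher, result) triples: the literal tokens of an
example, a matcher in which known words are generalized to their group
sets, and the parse tree with words replaced by token indexes.
"""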

import copy
import logging
import re
from functools import reduce
from typing import List

from . import depth_meter
from . import knowledge_evaluation
from . import parameters
from .modifiable_property import ModifiableProperty

# TODO: more flexible tokenization
def to_tokens(text):
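    """Split `text` into word tokens and single non-space characters.

    Illustrative: to_tokens("what's this?") -> ['what', "'", 's', 'this', '?']
    """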
    return re.findall(r'(\w+|[^\s])', text)


def make_template(knowledge_base, tokens, parsed):
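    """Generalize an example into a (tokens, matcher, template) triple.

    Words that appear in `parsed` are replaced by their token index in the
    template, while the matcher keeps only their group sets, so the pattern
    can match any word sharing those groups.
    """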
    matcher = list(tokens)
    template = list(parsed)
    for i, word in enumerate(matcher):
        if word in template:
            template[template.index(word)] = i
            matcher[i] = {
                'groups': set(knowledge_base.knowledge[word]['groups'])
            }
    return tokens, matcher, template


def is_bottom_level(tree):
    return not any(isinstance(element, (list, tuple)) for element in tree)


def get_lower_levels(parsed):
    lower = []
    def aux(subtree, path):
        # Treat the root (empty path) as already "deeper" so it is never
        # collected on its own; integrate_language handles the final flat
        # tree itself.
        deeper = len(path) == 0
        for i, element in enumerate(subtree):
            if isinstance(element, (list, tuple)):
                aux(element, path + (i,))
                deeper = True

        if not deeper:
            lower.append((path, subtree))

    aux(parsed, path=())
    return lower


# TODO: probably optimize this, it creates lots of unnecessary tuples
def replace_position(tree, position, new_element):
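    """Return a copy of `tree` with the element at `position` (a path of
    indexes) replaced by `new_element`.

    Illustrative: replace_position(('a', ('b', 'c')), (1, 0), 'X')
    returns ('a', ('X', 'c')).
    """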

    def aux(current_tree, remaining_route):
        if len(remaining_route) == 0:
            return new_element

        step = remaining_route[0]
        return (
            current_tree[:step]
            + (aux(current_tree[step], remaining_route[1:]),)
            + current_tree[step + 1:]
        )

    return aux(tree, position)


def integrate_language(knowledge_base, example):
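    """Integrate one (text, parsed) example into the language model.

    Repeatedly collapses bottom-level subtrees of the parse into typed
    placeholder tokens until the tree is flat, then builds the final
    (tokens, matcher, result) template from what remains.
    """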
    text = example["text"].lower()
    parsed = example["parsed"]

    resolved_parsed = copy.deepcopy(parsed)
    tokens = to_tokens(text)

    while True:
        logging.debug("P: {}".format(resolved_parsed))
        lower_levels = get_lower_levels(resolved_parsed)
        logging.debug("Lower: {}".format(lower_levels))
        if len(lower_levels) == 0:
            break

        for position, atom in lower_levels:
            logging.debug("\x1b[1mSelecting\x1b[0m: {}".format(atom))
            similar = get_similar_tree(knowledge_base, atom)
            remix, (start_bounds, end_bounds) = build_remix_matrix(knowledge_base, tokens, atom, similar)
            _, matcher, result = make_template(knowledge_base, tokens, atom)
            logging.debug("Tx: {}".format(tokens))
            logging.debug("Mx: {}".format(matcher))
            logging.debug("Rx: {}".format(result))
            logging.debug("Remix: {}".format(remix))

            # Guard against a -0 slice when end_bounds is empty
            end = len(tokens) - len(end_bounds)
            after_remix = apply_remix(tokens[len(start_bounds):end], remix)
            assert len(after_remix) + len(start_bounds) + len(end_bounds) == len(tokens)
            logging.debug("  +-> {}".format(after_remix))
            subquery_type = knowledge_evaluation.get_subquery_type(knowledge_base.knowledge, atom)
            logging.debug(r"  \-> <{}>".format(subquery_type))

            # Clean remaining tokens
            new_tokens = list(tokens)
            offset = len(start_bounds)
            for _ in range(len(remix)):
                new_tokens.pop(offset)

            # TODO: get a specific type for... types
            new_tokens.insert(offset, (subquery_type, remix))
            tokens = new_tokens

            resolved_parsed = replace_position(resolved_parsed, position, offset)
            logging.debug("#########")

    tokens, matcher, result = make_template(knowledge_base, tokens, resolved_parsed)
    logging.debug("T: {}".format(tokens))
    logging.debug("M: {}".format(matcher))
    logging.debug("R: {}".format(result))
    logging.debug("---")
    return tokens, matcher, result


def apply_remix(tokens, remix):
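    """Reorder `tokens` according to the index list `remix`.

    Illustrative: apply_remix(['a', 'b', 'c'], [2, 0]) == ['c', 'a']
    """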
    return [tokens[i] for i in remix]


def build_remix_matrix(knowledge_base, tokens, atom, similar):
    tokens = list(tokens)
    tokens, matcher, result = make_template(knowledge_base, tokens, atom)
    similar_matcher, similar_result, similar_result_resolved, _ = similar

    start_bounds, end_bounds = find_bounds(matcher, similar_matcher)

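    # Bound indexes arrive in descending order (end bounds first, then the
    # reversed start bounds), so each pop() leaves the remaining indexes valid.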
    for i, _ in (end_bounds + start_bounds[::-1]):
        matcher.pop(i)
        tokens.pop(i)

    possible_remixes = get_possible_remixes(knowledge_base, matcher, similar_matcher)
    chosen_remix = possible_remixes[0]

    return chosen_remix, (start_bounds, end_bounds)


def get_possible_remixes(knowledge_base, matcher, similar_matcher):

    matrix = []
    for element in matcher:
        logging.debug("- {}".format(element))
        logging.debug("+ {}".format(similar_matcher))
        assert element in similar_matcher or isinstance(element, dict)

        if isinstance(element, dict):
            indexes = all_matching_indexes(knowledge_base, similar_matcher, element)
        else:
            indexes = all_indexes(similar_matcher, element)
        matrix.append(indexes)

    # TODO: do some scoring to find the most "interesting combination"
    return [list(x) for x in list(zip(*matrix))]


def all_indexes(collection, element):
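    """Return every index at which `element` occurs in `collection`.

    Illustrative: all_indexes(['a', 'b', 'a'], 'a') == [0, 2]
    """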
    indexes = []
    base = 0

    for _ in range(collection.count(element)):
        i = collection.index(element, base)
        base = i + 1
        indexes.append(i)

    return indexes


def all_matching_indexes(knowledge_base, collection, element):
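    """Return indexes of `collection` entries whose groups overlap
    `element`'s groups, ordered by decreasing overlap size so the best
    candidates come first.
    """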
    indexes = []

    assert "groups" in element
    element = element["groups"]
    for i, instance in enumerate(collection):
        if isinstance(instance, dict):
            instance = instance["groups"]
        elif instance in knowledge_base.knowledge:
            instance = knowledge_base.knowledge[instance]["groups"]

        intersection = set(instance) & set(element)
        if len(intersection) > 0:
            indexes.append((i, intersection))

    return [x[0] for x in sorted(indexes, key=lambda x: len(x[1]), reverse=True)]


def find_bounds(matcher, similar_matcher):
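    """Collect the (index, element) prefix and suffix of `matcher` whose
    elements do not appear in `similar_matcher`.
    """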
    start_bounds = []
    for i, element in enumerate(matcher):
        if element in similar_matcher:
            break
        else:
            start_bounds.append((i, element))

    end_bounds = []
    for i, element in enumerate(matcher[::-1]):
        if element in similar_matcher:
            break
        else:
            end_bounds.append((len(matcher) - (i + 1), element))

    return start_bounds, end_bounds


def get_similar_tree(knowledge_base, atom):
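    """Pick the trained bottom-level tree that best matches `atom`.

    Candidates must share the head element ([0]); they are scored by how
    many of their resolved elements coincide with `atom`'s. Returns the
    best (raw, tree, resolved, score) tuple, or None if no candidates.
    """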
    possibilities = []

    # Find matching possibilities
    for entry, tree in knowledge_base.trained:
        if not is_bottom_level(tree):
            continue
        if tree[0] == atom[0]:
            possibilities.append((entry, tree))

    # Sort by number of matching elements
    sorted_possibilities = []
    for (raw, possibility) in possibilities:
        resolved = []
        for element in atom:
            if isinstance(element, str):
                resolved.append(element)
            else:
                resolved.append(knowledge_evaluation.resolve(
                    knowledge_base.knowledge,
                    element,
                    raw))

        # TODO: Probably should take into account the categories of the elements in the "intake" ([0]) element
        score = sum(r == a for r, a in zip(resolved, atom))
        sorted_possibilities.append((raw, possibility, resolved, score))
    sorted_possibilities = sorted(sorted_possibilities, key=lambda p: p[3], reverse=True)
    if len(sorted_possibilities) < 1:
        return None

    return sorted_possibilities[0]


# TODO: unroll this mess
def get_matching(sample, other):
    expected_length = len(sample[0])
    other = [x for x in other if len(x[0]) == expected_length]
    for i in range(expected_length):
        if len(other) == 0:
            return []

        if isinstance(sample[0][i], dict):  # Dictionaries are compared by groups
            other = [x for x in other
                     if isinstance(x[0][i], dict)
                     and len(x[0][i]['groups'] & sample[0][i]['groups']) > 0]

        elif isinstance(sample[0][i], tuple):  # Tuples are compared by their type ([0])
            other = [x for x in other
                     if isinstance(x[0][i], tuple)
                     and x[0][i][0] == sample[0][i][0]]

    # The last filtering round may have emptied `other`; bail out before
    # `reduce`, which raises on an empty sequence
    if len(other) == 0:
        return []

    return [sample[0][x] if isinstance(sample[0][x], (str, tuple))
            else {'groups': sample[0][x]['groups'] & reduce(lambda a, b: a & b,
                                                            map(lambda y: y[0][x]['groups'],
                                                                other))}
            for x in range(expected_length)]


def reprocess_language_knowledge(knowledge_base, examples):
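    """Rebuild the set of pattern examples: each sample is matched against
    all the others and generalized to the common pattern when one exists.
    """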
    examples = knowledge_base.examples + examples

    pattern_examples = []
    for i, sample in enumerate(examples):
        other = examples[:i] + examples[i + 1:]
        match = get_matching(sample, other)
        if len(match) > 0:
            sample = (match, sample[1],)
        pattern_examples.append(sample)

    return pattern_examples


def reverse_remix(tree_section, remix):
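    """Rebuild a token section using `remix` indexes (same semantics as
    apply_remix), passing any tail beyond the remix through unchanged.
    Returns None when an index falls outside the section.
    """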
    result_section = []
    for origin in remix:
        if origin >= len(tree_section):
            return None
        result_section.append(copy.deepcopy(tree_section[origin]))
    return result_section + tree_section[len(remix):]


def get_fit(knowledge, tokens, remaining_recursions=parameters.MAX_RECURSIONS):
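    """Try each trained (matcher, ast) pair in order; return the first
    full match, or None.
    """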
    for matcher, ast in knowledge.trained:
        result = match_fit(knowledge, tokens, matcher, ast,
                           remaining_recursions)
        if result is not None:
            return result

    return None


def is_definite_minisegment(minisegment):
    return isinstance(minisegment, (str, dict))


def match_token(knowledge, next_token, minisegment):
    if isinstance(minisegment, dict):
        # TODO: check if the dictionary matches the values
        return True
    elif isinstance(minisegment, str):
        # TODO: check if the two elements can be used in each other place
        return next_token == minisegment

    return False


def resolve_fit(knowledge, fit, remaining_recursions):
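    """Resolve the non-definite parts of a matched fit by re-fitting
    their captured tokens, with one fewer recursion level allowed.
    """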
    fitted = []
    for element in fit:
        if is_definite_minisegment(element):
            fitted.append(element)
        else:
            ((result_type, remixer), tokens) = element
            remixed_tokens = reverse_remix(tokens, remixer)
            if remixed_tokens is None:
                return None

            minifit = get_fit(knowledge, remixed_tokens, remaining_recursions - 1)
            if minifit is None:
                return None

            minitokens, miniast = minifit
            subproperty = knowledge_evaluation.resolve(knowledge.knowledge, minitokens, miniast)
            fitted.append(subproperty)

    return fitted


def match_fit(knowledge, tokens, matcher, ast, remaining_recursions):
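    """Match `tokens` against `matcher`, branching on every way a
    non-definite minisegment could consume tokens; resolve and return
    the first fully-consumed fit together with `ast`.
    """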
    segment_possibilities = [([], tokens)]  # Matched tokens, remaining tokens
    for minisegment in matcher:
        possibilities_after_round = []
        for matched_tokens, remaining_tokens in segment_possibilities:
            if len(remaining_tokens) < 1:
                continue

            if is_definite_minisegment(minisegment):
                if match_token(knowledge, remaining_tokens[0], minisegment):
                    possibilities_after_round.append((
                        matched_tokens + [remaining_tokens[0]],
                        remaining_tokens[1:]
                    ))
            else:
                # TODO: optimize this with a look ahead
                # Branch on every possible split of the remaining tokens,
                # including the one that consumes them all
                for i in range(1, len(remaining_tokens) + 1):
                    possibilities_after_round.append((
                        matched_tokens + [(minisegment, remaining_tokens[:i])],
                        remaining_tokens[i:]
                    ))
        segment_possibilities = possibilities_after_round

    fully_matched_segments = [(matched, remaining)
                              for (matched, remaining)
                              in segment_possibilities
                              if len(remaining) == 0]

    resolved_fits = []
    for fit, _ in fully_matched_segments:
        resolved_fit = resolve_fit(knowledge, fit, remaining_recursions)
        if resolved_fit is not None:
            resolved_fits.append(resolved_fit)

    if len(resolved_fits) == 0:
        return None

    return resolved_fits[0], ast