diff --git a/naive-nlu/tree_nlu/parsing.py b/naive-nlu/tree_nlu/parsing.py index b06e18b..1705286 100644 --- a/naive-nlu/tree_nlu/parsing.py +++ b/naive-nlu/tree_nlu/parsing.py @@ -406,22 +406,33 @@ def all_indexes(collection, element): def all_matching_indexes(knowledge_base, collection, element): indexes = [] - assert("groups" in element) - element = element["groups"] - for i, instance in enumerate(collection): - if isinstance(instance, dict): - instance = instance["groups"] - elif instance in knowledge_base.knowledge: - instance = knowledge_base.knowledge[instance]["groups"] + with session().log('Matching “{}”'.format(element)): + assert("groups" in element) + element = element["groups"] + for i, instance in enumerate(collection): + session().log('Checking “{}”'.format(instance)) - intersection = set(instance) & set(element) - if (len(intersection) > 0 or (0 == len(instance) == len(element))): - indexes.append((i, intersection)) + if isinstance(instance, dict): + instance = instance["groups"] + elif instance in knowledge_base.knowledge: + session().log('Knowledge about “{}”: ”{}”'.format(instance, knowledge_base.knowledge[instance])) - return [x[0] for x in sorted(indexes, key=lambda x: len(x[1]), reverse=True)] + if "groups" not in knowledge_base.knowledge[instance]: + # This means that is only known as token + # so we should try to avoid using it + continue + + instance = knowledge_base.knowledge[instance]["groups"] + + intersection = set(instance) & set(element) + if (len(intersection) > 0 or (0 == len(instance) == len(element))): + indexes.append((i, intersection)) + + return [x[0] for x in sorted(indexes, key=lambda x: len(x[1]), reverse=True)] def element_matches_groups(knowledge, element: Dict, groups): + with session().log("Checking if e “{}” matches groups “{}”".format(element, groups)): if isinstance(groups, str) and groups in knowledge: return len(knowledge[groups].get("groups", set()) & element['groups']) > 0 elif isinstance(groups, dict): diff --git a/naive-nlu/tree_nlu/test.py b/naive-nlu/tree_nlu/test.py index 11cd561..683f85e 100644 --- a/naive-nlu/tree_nlu/test.py +++ b/naive-nlu/tree_nlu/test.py @@ -11,9 +11,9 @@ logging.getLogger().setLevel(logging.ERROR) tests = ( ("tokenization", tokenization), - # ("basic", basic), - # ("gac 100", gac_100), - # ("gac+", gac_extension), + ("basic", basic), + ("gac 100", gac_100), + ("gac+", gac_extension), ) diff --git a/naive-nlu/tree_nlu/tests/gac_100.py b/naive-nlu/tree_nlu/tests/gac_100.py index 2e6bcf4..f4656fb 100644 --- a/naive-nlu/tree_nlu/tests/gac_100.py +++ b/naive-nlu/tree_nlu/tests/gac_100.py @@ -668,6 +668,10 @@ base_knowledge = { 'electricity': { "groups": {'power'}, }, + 'airplanes': {}, + 'white': { + 'groups': {'property'}, + } } def main(): diff --git a/naive-nlu/tree_nlu/tests/gac_extension.py b/naive-nlu/tree_nlu/tests/gac_extension.py index 5aae0a2..abb87ba 100644 --- a/naive-nlu/tree_nlu/tests/gac_extension.py +++ b/naive-nlu/tree_nlu/tests/gac_extension.py @@ -22,4 +22,5 @@ def ask_then_learn_test(knowledge: KnowledgeBase): def main(): knowledge = gac_100.main() + knowledge.knowledge['blue'] = {'groups': {'property'}} knowledge = ask_then_learn_test(knowledge)