Pass tests using tokenization.

This commit is contained in:
kenkeiras 2018-04-15 21:10:49 +02:00
parent 45cc3a8a31
commit 1306306723
4 changed files with 30 additions and 14 deletions

View File

@ -406,22 +406,33 @@ def all_indexes(collection, element):
def all_matching_indexes(knowledge_base, collection, element): def all_matching_indexes(knowledge_base, collection, element):
indexes = [] indexes = []
assert("groups" in element) with session().log('Matching “{}'.format(element)):
element = element["groups"] assert("groups" in element)
for i, instance in enumerate(collection): element = element["groups"]
if isinstance(instance, dict): for i, instance in enumerate(collection):
instance = instance["groups"] session().log('Checking “{}'.format(instance))
elif instance in knowledge_base.knowledge:
instance = knowledge_base.knowledge[instance]["groups"]
intersection = set(instance) & set(element) if isinstance(instance, dict):
if (len(intersection) > 0 or (0 == len(instance) == len(element))): instance = instance["groups"]
indexes.append((i, intersection)) elif instance in knowledge_base.knowledge:
session().log('Knowledge about “{}”: ”{}'.format(instance, knowledge_base.knowledge[instance]))
return [x[0] for x in sorted(indexes, key=lambda x: len(x[1]), reverse=True)] if "groups" not in knowledge_base.knowledge[instance]:
# This means that is only known as token
# so we should try to avoid using it
continue
instance = knowledge_base.knowledge[instance]["groups"]
intersection = set(instance) & set(element)
if (len(intersection) > 0 or (0 == len(instance) == len(element))):
indexes.append((i, intersection))
return [x[0] for x in sorted(indexes, key=lambda x: len(x[1]), reverse=True)]
def element_matches_groups(knowledge, element: Dict, groups): def element_matches_groups(knowledge, element: Dict, groups):
with session().log("Checking if e “{}” matches groups “{}".format(element, groups)):
if isinstance(groups, str) and groups in knowledge: if isinstance(groups, str) and groups in knowledge:
return len(knowledge[groups].get("groups", set()) & element['groups']) > 0 return len(knowledge[groups].get("groups", set()) & element['groups']) > 0
elif isinstance(groups, dict): elif isinstance(groups, dict):

View File

@ -11,9 +11,9 @@ logging.getLogger().setLevel(logging.ERROR)
tests = ( tests = (
("tokenization", tokenization), ("tokenization", tokenization),
# ("basic", basic), ("basic", basic),
# ("gac 100", gac_100), ("gac 100", gac_100),
# ("gac+", gac_extension), ("gac+", gac_extension),
) )

View File

@ -668,6 +668,10 @@ base_knowledge = {
'electricity': { 'electricity': {
"groups": {'power'}, "groups": {'power'},
}, },
'airplanes': {},
'white': {
'groups': {'property'},
}
} }
def main(): def main():

View File

@ -22,4 +22,5 @@ def ask_then_learn_test(knowledge: KnowledgeBase):
def main(): def main():
knowledge = gac_100.main() knowledge = gac_100.main()
knowledge.knowledge['blue'] = {'groups': {'property'}}
knowledge = ask_then_learn_test(knowledge) knowledge = ask_then_learn_test(knowledge)