Make remix model more powerful.

Accept elements in the remix that are not present in the
subtrees.
This commit is contained in:
kenkeiras 2017-05-24 20:30:50 +02:00
parent d029ecd91d
commit bbba6b75e1
3 changed files with 25 additions and 12 deletions

View File

@ -127,7 +127,11 @@ def integrate_language(knowledge_base, example):
def apply_remix(tokens, remix): def apply_remix(tokens, remix):
rebuilt = [] rebuilt = []
for i in remix: for i in remix:
if isinstance(i, int):
rebuilt.append(tokens[i]) rebuilt.append(tokens[i])
else:
assert(isinstance(i, str))
rebuilt.append(i)
return rebuilt return rebuilt
@ -154,13 +158,14 @@ def get_possible_remixes(knowledge_base, matcher, similar_matcher):
for element in matcher: for element in matcher:
logging.debug("- {}".format(element)) logging.debug("- {}".format(element))
logging.debug("+ {}".format(similar_matcher)) logging.debug("+ {}".format(similar_matcher))
assert(element in similar_matcher or isinstance(element, dict)) if element in similar_matcher or isinstance(element, dict):
if isinstance(element, dict): if isinstance(element, dict):
indexes = all_matching_indexes(knowledge_base, similar_matcher, element) indexes = all_matching_indexes(knowledge_base, similar_matcher, element)
else: else:
indexes = all_indexes(similar_matcher, element) indexes = all_indexes(similar_matcher, element)
matrix.append(indexes) matrix.append(indexes)
else:
matrix.append([element])
# TODO: do some scoring to find the most "interesting combination" # TODO: do some scoring to find the most "interesting combination"
return [list(x) for x in list(zip(*matrix))] return [list(x) for x in list(zip(*matrix))]
@ -294,10 +299,16 @@ def reprocess_language_knowledge(knowledge_base, examples):
def reverse_remix(tree_section, remix): def reverse_remix(tree_section, remix):
result_section = [] result_section = []
offset = 0
for origin in remix: for origin in remix:
if isinstance(origin, int):
if origin >= len(tree_section): if origin >= len(tree_section):
return None return None
result_section.append(copy.deepcopy(tree_section[origin]))
result_section.append(copy.deepcopy(tree_section[origin + offset]))
else:
assert(isinstance(origin, str))
offset += 1
return result_section + tree_section[len(remix):] return result_section + tree_section[len(remix):]

View File

@ -17,7 +17,9 @@ def main():
test_module.main() test_module.main()
print(" \x1b[1;32m✓\x1b[0m {}".format(test_name)) print(" \x1b[1;32m✓\x1b[0m {}".format(test_name))
except AssertionError as ae: except AssertionError as ae:
print(" \x1b[1;31m✗\x1b[0m {}: {}".format(test_name, ae.args[0])) print(" \x1b[1;31m✗\x1b[0m {}: {}".format(test_name,
ae.args[0] if len(ae.args) > 0
else '\b\b \b'))
failed = True failed = True
except Exception as e: except Exception as e:

View File

@ -36,7 +36,7 @@ examples = [
('full_example', ('full_example',
{ {
"text": "Is it hot during the summer?", "text": "Is it hot during the summer?",
"affirmation": "it is hot during the summer", "affirmation": "it is hot during summer",
"parsed": ("question", "parsed": ("question",
("implies", 'summer', 'hot')), ("implies", 'summer', 'hot')),
"answer": True, "answer": True,