Make remix model more powerful.

Accept elements in the remix that are not present in the
subtrees.
This commit is contained in:
kenkeiras 2017-05-24 20:30:50 +02:00
parent d029ecd91d
commit bbba6b75e1
3 changed files with 25 additions and 12 deletions

View File

@@ -127,7 +127,11 @@ def integrate_language(knowledge_base, example):
def apply_remix(tokens, remix):
rebuilt = []
for i in remix:
rebuilt.append(tokens[i])
if isinstance(i, int):
rebuilt.append(tokens[i])
else:
assert(isinstance(i, str))
rebuilt.append(i)
return rebuilt
@@ -154,13 +158,14 @@ def get_possible_remixes(knowledge_base, matcher, similar_matcher):
for element in matcher:
logging.debug("- {}".format(element))
logging.debug("+ {}".format(similar_matcher))
assert(element in similar_matcher or isinstance(element, dict))
if isinstance(element, dict):
indexes = all_matching_indexes(knowledge_base, similar_matcher, element)
if element in similar_matcher or isinstance(element, dict):
if isinstance(element, dict):
indexes = all_matching_indexes(knowledge_base, similar_matcher, element)
else:
indexes = all_indexes(similar_matcher, element)
matrix.append(indexes)
else:
indexes = all_indexes(similar_matcher, element)
matrix.append(indexes)
matrix.append([element])
# TODO: do some scoring to find the most "interesting combination"
return [list(x) for x in list(zip(*matrix))]
@@ -294,10 +299,16 @@ def reprocess_language_knowledge(knowledge_base, examples):
def reverse_remix(tree_section, remix):
result_section = []
offset = 0
for origin in remix:
if origin >= len(tree_section):
return None
result_section.append(copy.deepcopy(tree_section[origin]))
if isinstance(origin, int):
if origin >= len(tree_section):
return None
result_section.append(copy.deepcopy(tree_section[origin + offset]))
else:
assert(isinstance(origin, str))
offset += 1
return result_section + tree_section[len(remix):]

View File

@@ -17,7 +17,9 @@ def main():
test_module.main()
print(" \x1b[1;32m✓\x1b[0m {}".format(test_name))
except AssertionError as ae:
print(" \x1b[1;31m✗\x1b[0m {}: {}".format(test_name, ae.args[0]))
print(" \x1b[1;31m✗\x1b[0m {}: {}".format(test_name,
ae.args[0] if len(ae.args) > 0
else '\b\b \b'))
failed = True
except Exception as e:

View File

@@ -36,7 +36,7 @@ examples = [
('full_example',
{
"text": "Is it hot during the summer?",
"affirmation": "it is hot during the summer",
"affirmation": "it is hot during summer",
"parsed": ("question",
("implies", 'summer', 'hot')),
"answer": True,