|
| 1 | +from tatsu.model import NodeWalker |
| 2 | +from typing import Iterable, Optional, Tuple, List |
| 3 | + |
| 4 | +from textworld.textgen.model import TextGrammarModelBuilderSemantics |
| 5 | +from textworld.textgen.parser import TextGrammarParser |
| 6 | + |
| 7 | + |
class NewAlternative:
    """
    A single alternative in a production rule.

    Subclasses describe themselves through `split_form`; `full_form` is
    derived from it by joining the adjective and the noun with "|".
    """

    def split_form(self, include_adj=True) -> Tuple[Optional[str], str]:
        """
        Return this alternative as an `(adjective, noun)` pair.

        The adjective slot is `None` when there is no adjective to report.

        Raises:
            NotImplementedError: always; subclasses that rely on the default
                `full_form` must override this method.
        """
        # Declared here so that a subclass missing the override fails with an
        # explicit error instead of an AttributeError inside `full_form`.
        raise NotImplementedError()

    def full_form(self, include_adj=True) -> str:
        """
        Return the alternative as a single string.

        The result is `noun` when `split_form` reports no adjective, and
        `adjective + "|" + noun` otherwise.
        """
        adj, noun = self.split_form(include_adj)
        if adj is None:
            return noun
        return adj + "|" + noun
| 19 | + |
| 20 | + |
class LiteralChunk:
    """
    A piece of literal text inside an alternative.

    A literal is any part of the string which is not a symbol, i.e. which is
    not bounded by hashtags. The text is stored unchanged in `_value`.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Makes lists of chunks readable when printed or inspected.
        return "{}({!r})".format(type(self).__name__, self._value)
| 28 | + |
| 29 | + |
class SymbolChunk:
    """
    A symbol reference inside an alternative.

    A symbol is any string in between two consecutive hashtags,
    e.g. #it_is_a_symbol#. The text is stored in `_value` exactly as it was
    given to the constructor.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Makes lists of chunks readable when printed or inspected.
        return "{}({!r})".format(type(self).__name__, self._value)
| 37 | + |
| 38 | + |
class NewLiteralAlternative(NewAlternative):
    """
    An alternative built from a literal string, which is also represented as
    a list of literal and symbol chunk objects.
    """

    def __init__(self, node: str):
        self._node = node
        # self._val_chunk contains the objects which make the string.
        # It is equivalent to self._value in LiteralAlternative.
        self._val_chunk = self._symbol_finder(self._node)

    def _symbol_finder(self, node):
        """
        Split `node` into a list of LiteralChunk and SymbolChunk objects.

        Every `#...#` span becomes a SymbolChunk whose value includes both
        hashtags; the non-empty text around those spans becomes LiteralChunk
        objects. A hashtag with no closing partner is not a symbol: it is
        kept, together with the rest of the string, as literal text.

        The resulting list is also stored in `self.chunks` and is replaced by
        a fresh list on every call.

        Returns:
            The list of chunks, in the order they appear in `node`. It is
            empty when `node` is empty or None.
        """
        found = []
        # Still published on the instance, as earlier versions did.
        self.chunks = found
        if not node:
            return found

        pos = 0
        length = len(node)
        while pos < length:
            start = node.find("#", pos)
            end = node.find("#", start + 1) if start >= 0 else -1
            if end < 0:
                # No complete #...# pair is left: the remainder (including a
                # possible unmatched hashtag) is plain text.
                found.append(LiteralChunk(node[pos:]))
                break
            if start > pos:
                found.append(LiteralChunk(node[pos:start]))
            found.append(SymbolChunk(node[start:end + 1]))
            pos = end + 1
        return found

    def split_form(self, include_adj=True) -> Tuple[Optional[str], str]:
        """Return `(None, text)`: a literal alternative has no adjective."""
        return None, self._node
| 69 | + |
| 70 | + |
class NewAdjectiveNounAlternative(NewLiteralAlternative):
    """
    An alternative made of an adjective and a noun, each of which is also
    kept as a list of chunk objects.
    """

    def __init__(self, adj_node: str, n_node: str):
        # NOTE(review): the parent initializer is not called here, so the
        # parent's `_node` and `_val_chunk` attributes are not set on
        # instances of this class -- confirm that this is intended.
        self._adj_node, self._n_node = adj_node, n_node
        # self._adj_chunk holds the objects making up the adjective string and
        # self._noun_chunk those making up the noun string. They mirror
        # self._adjective and self._noun in AdjectiveNounAlternative.
        self._adj_chunk = self._symbol_finder(adj_node)
        self._noun_chunk = self._symbol_finder(n_node)

    def split_form(self, include_adj=True) -> Tuple[Optional[str], str]:
        """
        Return `(adjective, noun)`, with the adjective replaced by `None`
        when `include_adj` is false.
        """
        adjective = self._adj_node if include_adj else None
        return adjective, self._n_node
| 90 | + |
| 91 | + |
class MatchAlternative(NewAlternative):
    """
    An alternative that pairs the names of two objects so that they match.
    """

    def __init__(self, lhs: NewAlternative, rhs: NewAlternative):
        self.lhs, self.rhs = lhs, rhs

    def full_form(self, include_adj=True) -> str:
        """
        Return the full forms of both sides joined by " <-> ".
        """
        left = self.lhs.full_form(include_adj)
        right = self.rhs.full_form(include_adj)
        return left + " <-> " + right
| 103 | + |
| 104 | + |
class ProductionRule:
    """
    One rule of a text grammar: a symbol together with its alternatives.
    """

    def __init__(self, symbol: str, alternatives: Iterable[NewAlternative]):
        # The alternatives are materialized once into an immutable tuple, so
        # any iterable (including a one-shot iterator) is accepted.
        frozen = tuple(alternatives)
        self.symbol = symbol
        self.alternatives = frozen
| 113 | + |
| 114 | + |
class _Converter(NodeWalker):
    """
    Walker that turns a parsed grammar model into the classes of this module.

    Each `walk_*` method handles one kind of node; dispatch to them is
    presumably done by `NodeWalker.walk` based on the node's type name --
    verify against the TatSu documentation.
    """

    def walk_list(self, node):
        """Walk every element of a list and collect the results in a list."""
        walked = []
        for child in node:
            walked.append(self.walk(child))
        return walked

    def walk_str(self, node):
        """Turn each backslash-n escape sequence into a real newline."""
        return node.replace("\\n", "\n")

    def walk_Literal(self, node):
        """Build a literal alternative, or return None for an empty literal."""
        value = self.walk(node.value)
        # Empty literals are skipped; callers have to cope with None.
        return NewLiteralAlternative(value) if value else None

    def walk_AdjectiveNoun(self, node):
        """Build an adjective/noun alternative from the walked parts."""
        adjective = self.walk(node.adjective)
        noun = self.walk(node.noun)
        return NewAdjectiveNounAlternative(adjective, noun)

    def walk_Match(self, node):
        """Build a match alternative from the walked left and right sides."""
        left = self.walk(node.lhs)
        right = self.walk(node.rhs)
        return MatchAlternative(left, right)

    def walk_ProductionRule(self, node):
        """Build a production rule, dropping alternatives that walked to None."""
        kept = []
        for alt in self.walk(node.alternatives):
            if alt is None:
                continue
            kept.append(alt)
        return ProductionRule(node.symbol, kept)

    def walk_TextGrammar(self, node):
        """Build the grammar from its walked rules."""
        rules = self.walk(node.rules)
        return TextGrammar(rules)
| 143 | + |
| 144 | + |
class TextGrammar:
    """
    A text grammar: production rules indexed by their symbol.
    """

    # Shared by all instances; both are created once, when the class is defined.
    _PARSER = TextGrammarParser(semantics=TextGrammarModelBuilderSemantics(), parseinfo=True)
    _CONVERTER = _Converter()

    def __init__(self, rules):
        # When several rules share a symbol, the last one wins.
        by_symbol = {}
        for rule in rules:
            by_symbol[rule.symbol] = rule
        self.rules = by_symbol

    @classmethod
    def parse(cls, grammar: str, filename: Optional[str] = None):
        """
        Parse `grammar` with the shared parser and return the result of
        walking the parsed model with the shared converter.

        `filename` is forwarded to the parser unchanged.
        """
        parsed = cls._PARSER.parse(grammar, filename=filename)
        return cls._CONVERTER.walk(parsed)
0 commit comments