33
44
55import textwrap
6+ from typing import List , Any , Iterable
67
78from textworld .utils import uniquify
89
@@ -18,42 +19,51 @@ class DependencyTreeElement:
1819 `__str__` accordingly.
1920 """
2021
21- def __init__ (self , value ):
22+ def __init__ (self , value : Any ):
2223 self .value = value
24+ self .parent = None
2325
24- def depends_on (self , other ) :
26+ def depends_on (self , other : "DependencyTreeElement" ) -> bool :
2527 """
2628 Check whether this element depends on the `other`.
2729 """
2830 return self .value > other .value
2931
30- def is_distinct_from (self , others ) :
32+ def is_distinct_from (self , others : Iterable [ "DependencyTreeElement" ]) -> bool :
3133 """
3234 Check whether this element is distinct from `others`.
3335 """
3436 return self .value not in [other .value for other in others ]
3537
36- def __str__ (self ):
38+ def __str__ (self ) -> str :
3739 return str (self .value )
3840
3941
4042class DependencyTree :
4143 class _Node :
42- def __init__ (self , element ):
44+ def __init__ (self , element : DependencyTreeElement ):
4345 self .element = element
4446 self .children = []
47+ self .parent = None
4548
46- def push (self , node ) :
49+ def push (self , node : "DependencyTree._Node" ) -> bool :
4750 if node == self :
48- return
51+ return True
4952
53+ added = False
5054 for child in self .children :
51- child .push (node )
55+ added |= child .push (node )
5256
5357 if self .element .depends_on (node .element ) and not self .already_added (node ):
58+ node = node .copy ()
5459 self .children .append (node )
60+ node .element .parent = self .element
61+ node .parent = self
62+ return True
5563
56- def already_added (self , node ):
64+ return added
65+
66+ def already_added (self , node : "DependencyTree._Node" ) -> bool :
5767 # We want to avoid duplicate information about dependencies.
5868 if node in self .children :
5969 return True
@@ -63,14 +73,15 @@ def already_added(self, node):
6373 if not node .element .is_distinct_from ((child .element for child in self .children )):
6474 return True
6575
66- # for child in self.children:
67- # # if node.element.value == child.element.value:
68- # if not node.element.is_distinct_from((child.element):
69- # return True
70-
7176 return False
7277
73- def __str__ (self ):
78+ def __iter__ (self ) -> Iterable ["DependencyTree._Node" ]:
79+ for child in self .children :
80+ yield from list (child )
81+
82+ yield self
83+
84+ def __str__ (self ) -> str :
7485 node_text = str (self .element )
7586
7687 txt = [node_text ]
@@ -79,85 +90,112 @@ def __str__(self):
7990
8091 return "\n " .join (txt )
8192
82- def copy (self ):
93+ def copy (self ) -> "DependencyTree._Node" :
8394 node = DependencyTree ._Node (self .element )
84- node .children = [child .copy () for child in self .children ]
95+ for child in self .children :
96+ child_ = child .copy ()
97+ child_ .parent = node
98+ node .children .append (child_ )
99+
85100 return node
86101
87- def __init__ (self , element_type = DependencyTreeElement ):
88- self .root = None
102+ def __init__ (self , element_type : type = DependencyTreeElement , trees : Iterable [ "DependencyTree" ] = [] ):
103+ self .roots = []
89104 self .element_type = element_type
105+ for tree in trees :
106+ self .roots += [root .copy () for root in tree .roots ]
107+
90108 self ._update ()
91109
92- def push (self , value ):
110+ def push (self , value : Any , allow_multi_root : bool = False ) -> bool :
111+ """ Add a value to this dependency tree.
112+
113+ Adding a value already present in the tree does not modify the tree.
114+
115+ Args:
116+ value: value to add.
117+ allow_multi_root: if `True`, allow the value to spawn an
118+ additional root if needed.
119+
120+ """
93121 element = self .element_type (value )
94122 node = DependencyTree ._Node (element )
95- if self .root is None :
96- self .root = node
97- else :
98- self .root .push (node )
99123
100- # Recompute leaves.
101- self ._update ()
102- if element in self .leaves_elements :
103- return node
124+ added = False
125+ for root in self .roots :
126+ added |= root .push (node )
104127
105- return None
128+ if len (self .roots ) == 0 or (not added and allow_multi_root ):
129+ self .roots .append (node )
130+ added = True
106131
107- def pop (self , value ):
108- if value not in self .leaves_values :
109- raise ValueError ("That element is not a leaf: {!r}." .format (value ))
132+ self ._update () # Recompute leaves.
133+ return added
110134
111- def _visit (node ):
112- for child in list (node .children ):
113- if child .element .value == value :
114- node .children .remove (child )
135+ def remove (self , value : Any ) -> None :
136+ """ Remove all leaves having the given value.
115137
116- self ._postorder (self .root , _visit )
117- if self .root .element .value == value :
118- self .root = None
138+ The value to remove needs to belong to at least one leaf in this tree.
139+ Otherwise, the tree remains unchanged.
119140
120- # Recompute leaves.
121- self ._update ()
141+ Args:
142+ value: value to remove from the tree.
143+
144+ Returns:
145+ Whether the tree has changed or not.
146+ """
147+ if value not in self .leaves_values :
148+ return False
149+
150+ root_to_remove = []
151+ for node in self :
152+ if node .element .value == value :
153+ if node .parent is not None :
154+ node .parent .children .remove (node )
155+ else :
156+ root_to_remove .append (node )
122157
123- def _postorder (self , node , visit ):
124- for child in node .children :
125- self ._postorder (child , visit )
158+ for node in root_to_remove :
159+ self .roots .remove (node )
126160
127- visit (node )
161+ self ._update () # Recompute leaves.
162+ return True
128163
129- def _update (self ):
164+ def _update (self ) -> None :
130165 self ._leaves_values = []
131- self ._leaves_elements = set ()
166+ self ._leaves_elements = []
132167
133- def _visit ( node ) :
168+ for node in self :
134169 if len (node .children ) == 0 :
135- self ._leaves_elements .add (node .element )
170+ self ._leaves_elements .append (node .element )
136171 self ._leaves_values .append (node .element .value )
137172
138- if self .root is not None :
139- self ._postorder (self .root , _visit )
140-
141173 self ._leaves_values = uniquify (self ._leaves_values )
174+ self ._leaves_elements = uniquify (self ._leaves_elements )
142175
143- def copy (self ):
144- tree = DependencyTree (self .element_type )
145- if self .root is not None :
146- tree .root = self .root .copy ()
147- tree ._update ()
176+ def copy (self ) -> "DependencyTree" :
177+ tree = type (self )(element_type = self .element_type )
178+ for root in self .roots :
179+ tree .roots .append (root .copy ())
148180
181+ tree ._update ()
149182 return tree
150183
184+ def __iter__ (self ) -> Iterable ["DependencyTree._Node" ]:
185+ for root in self .roots :
186+ yield from list (root )
187+
151188 @property
152- def leaves_elements (self ):
189+ def values (self ) -> List [Any ]:
190+ return [node .element .value for node in self ]
191+
192+ @property
193+ def leaves_elements (self ) -> List [DependencyTreeElement ]:
153194 return self ._leaves_elements
154195
155196 @property
156- def leaves_values (self ):
197+ def leaves_values (self ) -> List [ Any ] :
157198 return self ._leaves_values
158199
159- def __str__ (self ):
160- if self .root is None :
161- return ""
162-
163- return str (self .root )
200+ def __str__ (self ) -> str :
201+ return "\n " .join (map (str , self .roots ))
0 commit comments