Skip to content

Commit c77a105

Browse files
authored
Merge pull request #42 from MarcCote/enh_dependency_tree
Support subquests - part 1
2 parents fc31a4a + 9025cc8 commit c77a105

File tree

7 files changed

+533
-234
lines changed

7 files changed

+533
-234
lines changed

tests/test_play_generated_games.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def test_play_generated_games():
4747
game_state, reward, done = env.step(command)
4848

4949
if done:
50-
msg = "Finished before playing `max_steps` steps."
50+
msg = "Finished before playing `max_steps` steps because of command '{}'.".format(command)
5151
if game_state.has_won:
5252
msg += " (winning)"
53-
assert game_state._game_progression.winning_policy == []
53+
assert len(game_state._game_progression.winning_policy) == 0
5454

5555
if game_state.has_lost:
5656
msg += " (losing)"

textworld/envs/glulx/git_glulx_ml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def init(self, output: str, game=None,
146146
"""
147147
output = _strip_input_prompt_symbol(output)
148148
super().init(output)
149-
self._game_progression = GameProgression(game, track_quest=compute_intermediate_reward)
149+
self._game_progression = GameProgression(game, track_quests=compute_intermediate_reward)
150150
self._state_tracking = state_tracking
151151
self._compute_intermediate_reward = compute_intermediate_reward and len(game.quests) > 0
152152

textworld/generator/dependency_tree.py

Lines changed: 102 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
import textwrap
6+
from typing import List, Any, Iterable
67

78
from textworld.utils import uniquify
89

@@ -18,42 +19,51 @@ class DependencyTreeElement:
1819
`__str__` accordingly.
1920
"""
2021

21-
def __init__(self, value):
22+
def __init__(self, value: Any):
2223
self.value = value
24+
self.parent = None
2325

24-
def depends_on(self, other):
26+
def depends_on(self, other: "DependencyTreeElement") -> bool:
2527
"""
2628
Check whether this element depends on the `other`.
2729
"""
2830
return self.value > other.value
2931

30-
def is_distinct_from(self, others):
32+
def is_distinct_from(self, others: Iterable["DependencyTreeElement"]) -> bool:
3133
"""
3234
Check whether this element is distinct from `others`.
3335
"""
3436
return self.value not in [other.value for other in others]
3537

36-
def __str__(self):
38+
def __str__(self) -> str:
3739
return str(self.value)
3840

3941

4042
class DependencyTree:
4143
class _Node:
42-
def __init__(self, element):
44+
def __init__(self, element: DependencyTreeElement):
4345
self.element = element
4446
self.children = []
47+
self.parent = None
4548

46-
def push(self, node):
49+
def push(self, node: "DependencyTree._Node") -> bool:
4750
if node == self:
48-
return
51+
return True
4952

53+
added = False
5054
for child in self.children:
51-
child.push(node)
55+
added |= child.push(node)
5256

5357
if self.element.depends_on(node.element) and not self.already_added(node):
58+
node = node.copy()
5459
self.children.append(node)
60+
node.element.parent = self.element
61+
node.parent = self
62+
return True
5563

56-
def already_added(self, node):
64+
return added
65+
66+
def already_added(self, node: "DependencyTree._Node") -> bool:
5767
# We want to avoid duplicate information about dependencies.
5868
if node in self.children:
5969
return True
@@ -63,14 +73,15 @@ def already_added(self, node):
6373
if not node.element.is_distinct_from((child.element for child in self.children)):
6474
return True
6575

66-
# for child in self.children:
67-
# # if node.element.value == child.element.value:
68-
# if not node.element.is_distinct_from((child.element):
69-
# return True
70-
7176
return False
7277

73-
def __str__(self):
78+
def __iter__(self) -> Iterable["DependencyTree._Node"]:
79+
for child in self.children:
80+
yield from list(child)
81+
82+
yield self
83+
84+
def __str__(self) -> str:
7485
node_text = str(self.element)
7586

7687
txt = [node_text]
@@ -79,85 +90,112 @@ def __str__(self):
7990

8091
return "\n".join(txt)
8192

82-
def copy(self):
93+
def copy(self) -> "DependencyTree._Node":
8394
node = DependencyTree._Node(self.element)
84-
node.children = [child.copy() for child in self.children]
95+
for child in self.children:
96+
child_ = child.copy()
97+
child_.parent = node
98+
node.children.append(child_)
99+
85100
return node
86101

87-
def __init__(self, element_type=DependencyTreeElement):
88-
self.root = None
102+
def __init__(self, element_type: type = DependencyTreeElement, trees: Iterable["DependencyTree"] = []):
103+
self.roots = []
89104
self.element_type = element_type
105+
for tree in trees:
106+
self.roots += [root.copy() for root in tree.roots]
107+
90108
self._update()
91109

92-
def push(self, value):
110+
def push(self, value: Any, allow_multi_root: bool = False) -> bool:
111+
""" Add a value to this dependency tree.
112+
113+
Adding a value already present in the tree does not modify the tree.
114+
115+
Args:
116+
value: value to add.
117+
allow_multi_root: if `True`, allow the value to spawn an
118+
additional root if needed.
119+
120+
"""
93121
element = self.element_type(value)
94122
node = DependencyTree._Node(element)
95-
if self.root is None:
96-
self.root = node
97-
else:
98-
self.root.push(node)
99123

100-
# Recompute leaves.
101-
self._update()
102-
if element in self.leaves_elements:
103-
return node
124+
added = False
125+
for root in self.roots:
126+
added |= root.push(node)
104127

105-
return None
128+
if len(self.roots) == 0 or (not added and allow_multi_root):
129+
self.roots.append(node)
130+
added = True
106131

107-
def pop(self, value):
108-
if value not in self.leaves_values:
109-
raise ValueError("That element is not a leaf: {!r}.".format(value))
132+
self._update() # Recompute leaves.
133+
return added
110134

111-
def _visit(node):
112-
for child in list(node.children):
113-
if child.element.value == value:
114-
node.children.remove(child)
135+
def remove(self, value: Any) -> None:
136+
""" Remove all leaves having the given value.
115137
116-
self._postorder(self.root, _visit)
117-
if self.root.element.value == value:
118-
self.root = None
138+
The value to remove needs to belong to at least one leaf in this tree.
139+
Otherwise, the tree remains unchanged.
119140
120-
# Recompute leaves.
121-
self._update()
141+
Args:
142+
value: value to remove from the tree.
143+
144+
Returns:
145+
Whether the tree has changed or not.
146+
"""
147+
if value not in self.leaves_values:
148+
return False
149+
150+
root_to_remove = []
151+
for node in self:
152+
if node.element.value == value:
153+
if node.parent is not None:
154+
node.parent.children.remove(node)
155+
else:
156+
root_to_remove.append(node)
122157

123-
def _postorder(self, node, visit):
124-
for child in node.children:
125-
self._postorder(child, visit)
158+
for node in root_to_remove:
159+
self.roots.remove(node)
126160

127-
visit(node)
161+
self._update() # Recompute leaves.
162+
return True
128163

129-
def _update(self):
164+
def _update(self) -> None:
130165
self._leaves_values = []
131-
self._leaves_elements = set()
166+
self._leaves_elements = []
132167

133-
def _visit(node):
168+
for node in self:
134169
if len(node.children) == 0:
135-
self._leaves_elements.add(node.element)
170+
self._leaves_elements.append(node.element)
136171
self._leaves_values.append(node.element.value)
137172

138-
if self.root is not None:
139-
self._postorder(self.root, _visit)
140-
141173
self._leaves_values = uniquify(self._leaves_values)
174+
self._leaves_elements = uniquify(self._leaves_elements)
142175

143-
def copy(self):
144-
tree = DependencyTree(self.element_type)
145-
if self.root is not None:
146-
tree.root = self.root.copy()
147-
tree._update()
176+
def copy(self) -> "DependencyTree":
177+
tree = type(self)(element_type=self.element_type)
178+
for root in self.roots:
179+
tree.roots.append(root.copy())
148180

181+
tree._update()
149182
return tree
150183

184+
def __iter__(self) -> Iterable["DependencyTree._Node"]:
185+
for root in self.roots:
186+
yield from list(root)
187+
151188
@property
152-
def leaves_elements(self):
189+
def values(self) -> List[Any]:
190+
return [node.element.value for node in self]
191+
192+
@property
193+
def leaves_elements(self) -> List[DependencyTreeElement]:
153194
return self._leaves_elements
154195

155196
@property
156-
def leaves_values(self):
197+
def leaves_values(self) -> List[Any]:
157198
return self._leaves_values
158199

159-
def __str__(self):
160-
if self.root is None:
161-
return ""
162-
163-
return str(self.root)
200+
def __str__(self) -> str:
201+
return "\n".join(map(str, self.roots))

0 commit comments

Comments
 (0)