Skip to content

Commit 87b5cda

Browse files
author
Tavian Barnes
authored
Merge pull request #167 from MarcCote/fix_loading_saved_kb
Fix loading saved kb
2 parents 826f5e6 + c721242 commit 87b5cda

File tree

5 files changed

+95
-54
lines changed

5 files changed

+95
-54
lines changed

textworld/generator/game.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,7 @@ class Game:
339339
_SERIAL_VERSION = 1
340340

341341
def __init__(self, world: World, grammar: Optional[Grammar] = None,
342-
quests: Iterable[Quest] = (),
343-
kb: Optional[KnowledgeBase] = None) -> None:
342+
quests: Iterable[Quest] = ()) -> None:
344343
"""
345344
Args:
346345
world: The world to use for the game.
@@ -352,7 +351,7 @@ def __init__(self, world: World, grammar: Optional[Grammar] = None,
352351
self.metadata = {}
353352
self._objective = None
354353
self._infos = self._build_infos()
355-
self.kb = kb or KnowledgeBase.default()
354+
self.kb = world.kb
356355
self.extras = {}
357356

358357
# Check if we can derive a global winning policy from the quests.
@@ -379,7 +378,7 @@ def _build_infos(self) -> Dict[str, EntityInfo]:
379378

380379
def copy(self) -> "Game":
381380
""" Make a shallow copy of this game. """
382-
game = Game(self.world, self.grammar, self.quests, self.kb)
381+
game = Game(self.world, self.grammar, self.quests)
383382
game._infos = dict(self.infos)
384383
game._objective = self._objective
385384
game.metadata = dict(self.metadata)
@@ -435,12 +434,12 @@ def deserialize(cls, data: Mapping) -> "Game":
435434
if version != cls._SERIAL_VERSION:
436435
raise ValueError("Cannot deserialize a TextWorld version {} game, expected version {}".format(version, cls._SERIAL_VERSION))
437436

438-
world = World.deserialize(data["world"])
437+
kb = KnowledgeBase.deserialize(data["KB"])
438+
world = World.deserialize(data["world"], kb=kb)
439439
game = cls(world)
440440
game.grammar = Grammar(data["grammar"])
441441
game.quests = tuple([Quest.deserialize(d) for d in data["quests"]])
442442
game._infos = {k: EntityInfo.deserialize(v) for k, v in data["infos"]}
443-
game.kb = KnowledgeBase.deserialize(data["KB"])
444443
game.metadata = data.get("metadata", {})
445444
game._objective = data.get("objective", None)
446445
game.extras = data.get("extras", {})

textworld/generator/maker.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
from textworld.envs.wrappers import Recorder
2525

2626

27-
def get_failing_constraints(state):
27+
def get_failing_constraints(state, kb: Optional[KnowledgeBase] = None):
28+
kb = kb or KnowledgeBase.default()
2829
fail = Proposition("fail", [])
2930

3031
failed_constraints = []
31-
constraints = state.all_applicable_actions(KnowledgeBase.default().constraints.values())
32+
constraints = state.all_applicable_actions(kb.constraints.values())
3233
for constraint in constraints:
3334
if state.is_applicable(constraint):
3435
# Optimistically delay copying the state
@@ -53,6 +54,10 @@ class PlayerAlreadySetError(ValueError):
5354
pass
5455

5556

57+
class QuestError(ValueError):
58+
pass
59+
60+
5661
class FailedConstraintsError(ValueError):
5762
"""
5863
Thrown when a constraint has failed during generation.
@@ -76,7 +81,8 @@ class WorldEntity:
7681
"""
7782

7883
def __init__(self, var: Variable, name: Optional[str] = None,
79-
desc: Optional[str] = None) -> None:
84+
desc: Optional[str] = None,
85+
kb: Optional[KnowledgeBase] = None) -> None:
8086
"""
8187
Args:
8288
var: The underlying variable for the entity which is used
@@ -93,6 +99,7 @@ def __init__(self, var: Variable, name: Optional[str] = None,
9399
self.infos.desc = desc
94100
self.content = []
95101
self.parent = None
102+
self._kb = kb or KnowledgeBase.default()
96103

97104
@property
98105
def id(self) -> str:
@@ -158,11 +165,11 @@ def remove_property(self, name: str) -> None:
158165

159166
def add(self, *entities: List["WorldEntity"]) -> None:
160167
""" Add children to this entity. """
161-
if KnowledgeBase.default().types.is_descendant_of(self.type, "r"):
168+
if self._kb.types.is_descendant_of(self.type, "r"):
162169
name = "at"
163-
elif KnowledgeBase.default().types.is_descendant_of(self.type, ["c", "I"]):
170+
elif self._kb.types.is_descendant_of(self.type, ["c", "I"]):
164171
name = "in"
165-
elif KnowledgeBase.default().types.is_descendant_of(self.type, "s"):
172+
elif self._kb.types.is_descendant_of(self.type, "s"):
166173
name = "on"
167174
else:
168175
raise ValueError("Unexpected type {}".format(self.type))
@@ -173,11 +180,11 @@ def add(self, *entities: List["WorldEntity"]) -> None:
173180
entity.parent = self
174181

175182
def remove(self, *entities):
176-
if KnowledgeBase.default().types.is_descendant_of(self.type, "r"):
183+
if self._kb.types.is_descendant_of(self.type, "r"):
177184
name = "at"
178-
elif KnowledgeBase.default().types.is_descendant_of(self.type, ["c", "I"]):
185+
elif self._kb.types.is_descendant_of(self.type, ["c", "I"]):
179186
name = "in"
180-
elif KnowledgeBase.default().types.is_descendant_of(self.type, "s"):
187+
elif self._kb.types.is_descendant_of(self.type, "s"):
181188
name = "on"
182189
else:
183190
raise ValueError("Unexpected type {}".format(self.type))
@@ -283,7 +290,8 @@ class WorldPath:
283290

284291
def __init__(self, src: WorldRoom, src_exit: WorldRoomExit,
285292
dest: WorldRoom, dest_exit: WorldRoomExit,
286-
door: Optional[WorldEntity] = None) -> None:
293+
door: Optional[WorldEntity] = None,
294+
kb: Optional[KnowledgeBase] = None) -> None:
287295
"""
288296
Args:
289297
src: The source room.
@@ -297,6 +305,7 @@ def __init__(self, src: WorldRoom, src_exit: WorldRoomExit,
297305
self.dest = dest
298306
self.dest_exit = dest_exit
299307
self.door = door
308+
self._kb = kb or KnowledgeBase.default()
300309
self.src.exits[self.src_exit].dest = self.dest.exits[self.dest_exit]
301310
self.dest.exits[self.dest_exit].dest = self.src.exits[self.src_exit]
302311

@@ -307,7 +316,7 @@ def door(self) -> Optional[WorldEntity]:
307316

308317
@door.setter
309318
def door(self, door: WorldEntity) -> None:
310-
if door is not None and not KnowledgeBase.default().types.is_descendant_of(door.type, "d"):
319+
if door is not None and not self._kb.types.is_descendant_of(door.type, "d"):
311320
msg = "Expecting a WorldEntity of 'door' type."
312321
raise TypeError(msg)
313322

@@ -348,7 +357,7 @@ class GameMaker:
348357
paths (List[WorldPath]): The connections between the rooms.
349358
"""
350359

351-
def __init__(self) -> None:
360+
def __init__(self, kb: Optional[KnowledgeBase] = None) -> None:
352361
"""
353362
Creates an empty world, with a player and an empty inventory.
354363
"""
@@ -357,7 +366,7 @@ def __init__(self) -> None:
357366
self.quests = []
358367
self.rooms = []
359368
self.paths = []
360-
self._kb = KnowledgeBase.default()
369+
self._kb = kb or KnowledgeBase.default()
361370
self._types_counts = self._kb.types.count(State(self._kb.logic))
362371
self.player = self.new(type='P')
363372
self.inventory = self.new(type='I')
@@ -442,15 +451,15 @@ def new(self, type: str, name: Optional[str] = None,
442451
* Otherwise, a `WorldEntity` is returned.
443452
"""
444453
var_id = type
445-
if not KnowledgeBase.default().types.is_constant(type):
454+
if not self._kb.types.is_constant(type):
446455
var_id = get_new(type, self._types_counts)
447456

448457
var = Variable(var_id, type)
449458
if type == "r":
450459
entity = WorldRoom(var, name, desc)
451460
self.rooms.append(entity)
452461
else:
453-
entity = WorldEntity(var, name, desc)
462+
entity = WorldEntity(var, name, desc, kb=self._kb)
454463

455464
self._entities[var_id] = entity
456465
if entity.name:
@@ -546,7 +555,7 @@ def connect(self, exit1: WorldRoomExit, exit2: WorldRoomExit) -> WorldPath:
546555
exit2.dest.src, exit2.dest.direction)
547556
raise ExitAlreadyUsedError(msg)
548557

549-
path = WorldPath(exit1.src, exit1.direction, exit2.src, exit2.direction)
558+
path = WorldPath(exit1.src, exit1.direction, exit2.src, exit2.direction, kb=self._kb)
550559
self.paths.append(path)
551560
return path
552561

@@ -658,6 +667,9 @@ def set_quest_from_commands(self, commands: List[str], ask_for_state: bool = Fal
658667
winning_facts = [user_query.query_for_important_facts(actions=recorder.actions,
659668
facts=recorder.last_game_state.state.facts,
660669
varinfos=self._working_game.infos)]
670+
if len(commands) != len(actions):
671+
unrecognized_commands = [c for c, a in zip(commands, recorder.actions) if a is None]
672+
raise QuestError("Some of the actions were unrecognized: {}".format(unrecognized_commands))
661673

662674
event = Event(actions=actions, conditions=winning_facts)
663675
self.quests = [Quest(win_events=[event])]
@@ -726,7 +738,7 @@ def validate(self) -> bool:
726738
msg = "Player position has not been specified. Use 'M.set_player(room)'."
727739
raise MissingPlayerError(msg)
728740

729-
failed_constraints = get_failing_constraints(self.state)
741+
failed_constraints = get_failing_constraints(self.state, self._kb)
730742
if len(failed_constraints) > 0:
731743
raise FailedConstraintsError(failed_constraints)
732744

@@ -747,7 +759,7 @@ def build(self, validate: bool = True) -> Game:
747759
if validate:
748760
self.validate() # Validate the state of the world.
749761

750-
world = World.from_facts(self.facts)
762+
world = World.from_facts(self.facts, kb=self._kb)
751763
game = Game(world, quests=self.quests)
752764

753765
# Keep names and descriptions that were manually provided.

textworld/generator/tests/test_game.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from textworld.generator.game import ActionDependencyTree, ActionDependencyTreeElement
2828
from textworld.generator.inform7 import Inform7Game
2929

30-
from textworld.logic import Proposition
30+
from textworld.logic import Proposition, GameLogic
3131

3232

3333
def _find_action(command: str, actions: Iterable[Action], inform7: Inform7Game) -> None:
@@ -65,6 +65,36 @@ def test_game_comparison():
6565
assert game1 != game3
6666

6767

68+
def test_reloading_game_with_custom_kb():
69+
twl = KnowledgeBase.default().logic._document
70+
twl += """
71+
type customobj : o {
72+
inform7 {
73+
type {
74+
kind :: "custom-obj-like";
75+
}
76+
}
77+
}
78+
"""
79+
80+
logic = GameLogic.parse(twl)
81+
kb = KnowledgeBase(logic, "")
82+
83+
M = GameMaker(kb=kb)
84+
85+
room = M.new_room("room")
86+
M.set_player(room)
87+
88+
custom_obj = M.new(type='customobj', name='customized object')
89+
M.inventory.add(custom_obj)
90+
91+
commands = ["drop customized object"]
92+
quest = M.set_quest_from_commands(commands)
93+
assert quest.commands == commands
94+
game = M.build()
95+
assert game == Game.deserialize(game.serialize())
96+
97+
6898
def test_variable_infos(verbose=False):
6999
options = textworld.GameOptions()
70100
options.nb_rooms = 5

0 commit comments

Comments
 (0)