2424from textworld .envs .wrappers import Recorder
2525
2626
27- def get_failing_constraints (state ):
27+ def get_failing_constraints (state , kb : Optional [KnowledgeBase ] = None ):
28+ kb = kb or KnowledgeBase .default ()
2829 fail = Proposition ("fail" , [])
2930
3031 failed_constraints = []
31- constraints = state .all_applicable_actions (KnowledgeBase . default () .constraints .values ())
32+ constraints = state .all_applicable_actions (kb .constraints .values ())
3233 for constraint in constraints :
3334 if state .is_applicable (constraint ):
3435 # Optimistically delay copying the state
@@ -53,6 +54,10 @@ class PlayerAlreadySetError(ValueError):
5354 pass
5455
5556
57+ class QuestError (ValueError ):
58+ pass
59+
60+
5661class FailedConstraintsError (ValueError ):
5762 """
5863 Thrown when a constraint has failed during generation.
@@ -76,7 +81,8 @@ class WorldEntity:
7681 """
7782
7883 def __init__ (self , var : Variable , name : Optional [str ] = None ,
79- desc : Optional [str ] = None ) -> None :
84+ desc : Optional [str ] = None ,
85+ kb : Optional [KnowledgeBase ] = None ) -> None :
8086 """
8187 Args:
8288 var: The underlying variable for the entity which is used
@@ -93,6 +99,7 @@ def __init__(self, var: Variable, name: Optional[str] = None,
9399 self .infos .desc = desc
94100 self .content = []
95101 self .parent = None
102+ self ._kb = kb or KnowledgeBase .default ()
96103
97104 @property
98105 def id (self ) -> str :
@@ -158,11 +165,11 @@ def remove_property(self, name: str) -> None:
158165
159166 def add (self , * entities : List ["WorldEntity" ]) -> None :
160167 """ Add children to this entity. """
161- if KnowledgeBase . default () .types .is_descendant_of (self .type , "r" ):
168+ if self . _kb .types .is_descendant_of (self .type , "r" ):
162169 name = "at"
163- elif KnowledgeBase . default () .types .is_descendant_of (self .type , ["c" , "I" ]):
170+ elif self . _kb .types .is_descendant_of (self .type , ["c" , "I" ]):
164171 name = "in"
165- elif KnowledgeBase . default () .types .is_descendant_of (self .type , "s" ):
172+ elif self . _kb .types .is_descendant_of (self .type , "s" ):
166173 name = "on"
167174 else :
168175 raise ValueError ("Unexpected type {}" .format (self .type ))
@@ -173,11 +180,11 @@ def add(self, *entities: List["WorldEntity"]) -> None:
173180 entity .parent = self
174181
175182 def remove (self , * entities ):
176- if KnowledgeBase . default () .types .is_descendant_of (self .type , "r" ):
183+ if self . _kb .types .is_descendant_of (self .type , "r" ):
177184 name = "at"
178- elif KnowledgeBase . default () .types .is_descendant_of (self .type , ["c" , "I" ]):
185+ elif self . _kb .types .is_descendant_of (self .type , ["c" , "I" ]):
179186 name = "in"
180- elif KnowledgeBase . default () .types .is_descendant_of (self .type , "s" ):
187+ elif self . _kb .types .is_descendant_of (self .type , "s" ):
181188 name = "on"
182189 else :
183190 raise ValueError ("Unexpected type {}" .format (self .type ))
@@ -283,7 +290,8 @@ class WorldPath:
283290
284291 def __init__ (self , src : WorldRoom , src_exit : WorldRoomExit ,
285292 dest : WorldRoom , dest_exit : WorldRoomExit ,
286- door : Optional [WorldEntity ] = None ) -> None :
293+ door : Optional [WorldEntity ] = None ,
294+ kb : Optional [KnowledgeBase ] = None ) -> None :
287295 """
288296 Args:
289297 src: The source room.
@@ -297,6 +305,7 @@ def __init__(self, src: WorldRoom, src_exit: WorldRoomExit,
297305 self .dest = dest
298306 self .dest_exit = dest_exit
299307 self .door = door
308+ self ._kb = kb or KnowledgeBase .default ()
300309 self .src .exits [self .src_exit ].dest = self .dest .exits [self .dest_exit ]
301310 self .dest .exits [self .dest_exit ].dest = self .src .exits [self .src_exit ]
302311
@@ -307,7 +316,7 @@ def door(self) -> Optional[WorldEntity]:
307316
308317 @door .setter
309318 def door (self , door : WorldEntity ) -> None :
310- if door is not None and not KnowledgeBase . default () .types .is_descendant_of (door .type , "d" ):
319+ if door is not None and not self . _kb .types .is_descendant_of (door .type , "d" ):
311320 msg = "Expecting a WorldEntity of 'door' type."
312321 raise TypeError (msg )
313322
@@ -348,7 +357,7 @@ class GameMaker:
348357 paths (List[WorldPath]): The connections between the rooms.
349358 """
350359
351- def __init__ (self ) -> None :
360+ def __init__ (self , kb : Optional [ KnowledgeBase ] = None ) -> None :
352361 """
353362 Creates an empty world, with a player and an empty inventory.
354363 """
@@ -357,7 +366,7 @@ def __init__(self) -> None:
357366 self .quests = []
358367 self .rooms = []
359368 self .paths = []
360- self ._kb = KnowledgeBase .default ()
369+ self ._kb = kb or KnowledgeBase .default ()
361370 self ._types_counts = self ._kb .types .count (State (self ._kb .logic ))
362371 self .player = self .new (type = 'P' )
363372 self .inventory = self .new (type = 'I' )
@@ -442,15 +451,15 @@ def new(self, type: str, name: Optional[str] = None,
442451 * Otherwise, a `WorldEntity` is returned.
443452 """
444453 var_id = type
445- if not KnowledgeBase . default () .types .is_constant (type ):
454+ if not self . _kb .types .is_constant (type ):
446455 var_id = get_new (type , self ._types_counts )
447456
448457 var = Variable (var_id , type )
449458 if type == "r" :
450459 entity = WorldRoom (var , name , desc )
451460 self .rooms .append (entity )
452461 else :
453- entity = WorldEntity (var , name , desc )
462+ entity = WorldEntity (var , name , desc , kb = self . _kb )
454463
455464 self ._entities [var_id ] = entity
456465 if entity .name :
@@ -546,7 +555,7 @@ def connect(self, exit1: WorldRoomExit, exit2: WorldRoomExit) -> WorldPath:
546555 exit2 .dest .src , exit2 .dest .direction )
547556 raise ExitAlreadyUsedError (msg )
548557
549- path = WorldPath (exit1 .src , exit1 .direction , exit2 .src , exit2 .direction )
558+ path = WorldPath (exit1 .src , exit1 .direction , exit2 .src , exit2 .direction , kb = self . _kb )
550559 self .paths .append (path )
551560 return path
552561
@@ -658,6 +667,9 @@ def set_quest_from_commands(self, commands: List[str], ask_for_state: bool = Fal
658667 winning_facts = [user_query .query_for_important_facts (actions = recorder .actions ,
659668 facts = recorder .last_game_state .state .facts ,
660669 varinfos = self ._working_game .infos )]
670+ if len (commands ) != len (actions ):
671+ unrecognized_commands = [c for c , a in zip (commands , recorder .actions ) if a is None ]
672+ raise QuestError ("Some of the actions were unrecognized: {}" .format (unrecognized_commands ))
661673
662674 event = Event (actions = actions , conditions = winning_facts )
663675 self .quests = [Quest (win_events = [event ])]
@@ -726,7 +738,7 @@ def validate(self) -> bool:
726738 msg = "Player position has not been specified. Use 'M.set_player(room)'."
727739 raise MissingPlayerError (msg )
728740
729- failed_constraints = get_failing_constraints (self .state )
741+ failed_constraints = get_failing_constraints (self .state , self . _kb )
730742 if len (failed_constraints ) > 0 :
731743 raise FailedConstraintsError (failed_constraints )
732744
@@ -747,7 +759,7 @@ def build(self, validate: bool = True) -> Game:
747759 if validate :
748760 self .validate () # Validate the state of the world.
749761
750- world = World .from_facts (self .facts )
762+ world = World .from_facts (self .facts , kb = self . _kb )
751763 game = Game (world , quests = self .quests )
752764
753765 # Keep names and descriptions that were manually provided.
0 commit comments