11# Copyright (c) Microsoft Corporation. All rights reserved.
22# Licensed under the MIT license.
33
4-
54import time
65import argparse
6+ from os .path import join as pjoin
7+
8+ import numpy as np
79
810import textworld
911from textworld import g_rng
1012from textworld .generator import World
1113
14+ from textworld .generator .game import GameOptions
15+
1216
1317def generate_never_ending_game (args ):
1418 g_rng .set_seed (args .seed )
19+
1520 msg = "--max-steps {} --nb-objects {} --nb-rooms {} --quest-length {} --quest-breadth {} --seed {}"
1621 print (msg .format (args .max_steps , args .nb_objects , args .nb_rooms , args .quest_length , args .quest_breadth , g_rng .seed ))
1722 print ("Generating game..." )
1823
19- grammar_flags = {}
20- game = textworld .generator .make_game (args .nb_rooms , args .nb_objects , args .quest_length , args .quest_breadth , grammar_flags )
24+ options = GameOptions ()
25+ options .seeds = g_rng .seed
26+ options .nb_rooms = args .nb_rooms
27+ options .nb_objects = args .nb_objects
28+ options .quest_length = args .quest_length
29+ options .quest_breadth = args .quest_breadth
30+
31+ game = textworld .generator .make_game (options )
2132 if args .no_quest :
2233 game .quests = []
2334
2435 game_name = "neverending"
25- game_file = textworld .generator .compile_game (game , game_name , force_recompile = True , games_folder = args .output )
36+ path = pjoin (args .output , game_name + ".ulx" )
37+ game_file = textworld .generator .compile_game (game , path , force_recompile = True )
2638 return game_file
2739
2840
@@ -33,7 +45,7 @@ def benchmark(game_file, args):
3345 if args .mode == "random" :
3446 agent = textworld .agents .NaiveAgent ()
3547 elif args .mode == "random-cmd" :
36- agent = textworld .agents .RandomCommandAgent ()
48+ agent = textworld .agents .RandomCommandAgent (seed = args . agent_seed )
3749 elif args .mode == "walkthrough" :
3850 agent = textworld .agents .WalkthroughAgent ()
3951
@@ -52,14 +64,13 @@ def benchmark(game_file, args):
5264
5365 reward = 0
5466 done = False
55- print ("Benchmarking {} using ..." .format (game_file , env .__class__ .__name__ ))
5667 start_time = time .time ()
5768 for _ in range (args .max_steps ):
5869 command = agent .act (game_state , reward , done )
5970 game_state , reward , done = env .step (command )
6071
6172 if done :
62- print ("Win! Reset." )
73+ # print("Win! Reset.")
6374 env .reset ()
6475 done = False
6576
@@ -69,6 +80,7 @@ def benchmark(game_file, args):
6980 duration = time .time () - start_time
7081 speed = args .max_steps / duration
7182 print ("Done {:,} steps in {:.2f} secs ({:,.1f} steps/sec)" .format (args .max_steps , duration , speed ))
83+ return speed
7284
7385
7486def parse_args ():
@@ -90,11 +102,21 @@ def parse_args():
90102 parser .add_argument ("--compute_intermediate_reward" , action = "store_true" )
91103 parser .add_argument ("--activate_state_tracking" , action = "store_true" )
92104 parser .add_argument ("--seed" , type = int )
105+ parser .add_argument ("--agent-seed" , type = int , default = 2018 )
93106 parser .add_argument ("-v" , "--verbose" , action = "store_true" )
94107 return parser .parse_args ()
95108
96109
97110if __name__ == "__main__" :
98111 args = parse_args ()
99112 game_file = generate_never_ending_game (args )
100- benchmark (game_file , args )
113+
114+
115+ speeds = []
116+ for _ in range (10 ):
117+ speed = benchmark (game_file , args )
118+ speeds .append (speed )
119+ args .agent_seed = args .agent_seed + 1
120+
121+ print ("-----\n Average: {:,.1f} steps/sec" .format (np .mean (speeds )))
122+
0 commit comments