@@ -95,6 +95,70 @@ def test_load_and_set(dummy_config, use_discrete):
9595 np .testing .assert_array_equal (w , lw )
9696
9797
def test_resume(dummy_config, tmp_path):
    """Check that a freshly built GhostTrainer ends up with the saved weights.

    Two team behaviors ("test_brain?team=0" and "test_brain?team=1") are
    registered on a GhostTrainer wrapping a PPOTrainer, and ``save_model`` is
    called with the team=0 behavior id. A second PPOTrainer/GhostTrainer pair
    is then built over the same directory, the same two policies are created
    and added, and the team=1 policy weights of both trainers are compared
    element-wise.

    :param dummy_config: trainer configuration fixture passed to both trainers.
    :param tmp_path: pytest temporary directory; its POSIX string form is
        handed to the trainers as their path argument.
    """
    brain_params_team0 = BrainParameters(
        brain_name="test_brain?team=0",
        vector_observation_space_size=1,
        camera_resolutions=[],
        vector_action_space_size=[2],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )
    brain_params_team1 = BrainParameters(
        brain_name="test_brain?team=1",
        vector_observation_space_size=1,
        camera_resolutions=[],
        vector_action_space_size=[2],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )

    # Parse each behavior id once and reuse it for both trainers.
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team0.brain_name
    )
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team1.brain_name
    )
    brain_name = parsed_behavior_id0.brain_name

    # Keep the fixture untouched; the trainers receive a plain string path.
    model_path = tmp_path.as_posix()

    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, model_path)
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, model_path
    )

    policy_team0 = trainer.create_policy(parsed_behavior_id0, brain_params_team0)
    trainer.add_policy(parsed_behavior_id0, policy_team0)

    policy_team1 = trainer.create_policy(parsed_behavior_id1, brain_params_team1)
    trainer.add_policy(parsed_behavior_id1, policy_team1)

    trainer.save_model(parsed_behavior_id0.behavior_id)

    # Make a new trainer, check that the policies are the same.
    # NOTE(review): the fifth PPOTrainer argument flips from False to True
    # here -- presumably the "load existing model" flag; confirm against the
    # PPOTrainer signature.
    ppo_trainer2 = PPOTrainer(brain_name, 0, dummy_config, True, True, 0, model_path)
    trainer2 = GhostTrainer(
        ppo_trainer2, brain_name, controller, 0, dummy_config, True, model_path
    )

    resumed_policy_team0 = trainer2.create_policy(
        parsed_behavior_id0, brain_params_team0
    )
    trainer2.add_policy(parsed_behavior_id0, resumed_policy_team0)

    resumed_policy_team1 = trainer2.create_policy(
        parsed_behavior_id1, brain_params_team1
    )
    trainer2.add_policy(parsed_behavior_id1, resumed_policy_team1)

    # NOTE(review): the model was saved via the team=0 behavior id but the
    # comparison reads the team=1 policies -- presumably intentional for the
    # ghost trainer; confirm.
    trainer1_policy = trainer.get_policy(parsed_behavior_id1.behavior_id)
    trainer2_policy = trainer2.get_policy(parsed_behavior_id1.behavior_id)
    weights = trainer1_policy.get_weights()
    weights2 = trainer2_policy.get_weights()

    # zip() stops at the shorter sequence, so without this check a resumed
    # policy with missing or extra weight tensors would pass unnoticed.
    assert len(weights) == len(weights2)

    for w, lw in zip(weights, weights2):
        np.testing.assert_array_equal(w, lw)
160+
161+
98162def test_process_trajectory (dummy_config ):
99163 brain_params_team0 = BrainParameters (
100164 brain_name = "test_brain?team=0" ,
0 commit comments