File tree: 1 file changed, +8 -8 lines changed
applications/ColossalChat/coati/distributed: 1 file changed, +8 -8 lines changed
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,14 @@ def __init__(
128128 drop_last = True ,
129129 collate_fn = collate_fn_grpo ,
130130 )
131+ if grpo_config ["reward_fn_type" ] == "think_answer_tags" :
132+ self .evaluation_function = math_reward_fn
133+ elif grpo_config ["reward_fn_type" ] == "boxed" :
134+ self .evaluation_function = boxed_math_reward_fn
135+ elif grpo_config ["reward_fn_type" ] == "code" :
136+ self .evaluation_function = code_reward_fn
137+ else :
138+ raise ValueError (f"Unknown evaluation function type { grpo_config ['reward_fn_type' ]} " )
131139
132140 self .eval_dataset_config = eval_dataset_config
133141 if self .eval_dataset_config is not None :
@@ -151,14 +159,6 @@ def __init__(
151159 ),
152160 collate_fn = collate_fn_grpo ,
153161 )
154- if grpo_config ["reward_fn_type" ] == "think_answer_tags" :
155- self .evaluation_function = math_reward_fn
156- elif grpo_config ["reward_fn_type" ] == "boxed" :
157- self .evaluation_function = boxed_math_reward_fn
158- elif grpo_config ["reward_fn_type" ] == "code" :
159- self .evaluation_function = code_reward_fn
160- else :
161- raise ValueError (f"Unknown evaluation function type { grpo_config ['reward_fn_type' ]} " )
162162 else :
163163 print ("No eval dataset provided, skip eval" )
164164 self .device = get_current_device ()
You can’t perform that action at this time.
0 commit comments