@@ -498,7 +498,7 @@ def generate(
498498
499499 # 3. Make sure generation config is correctly set
500500 # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
501- self ._set_return_outputs (
501+ return_dict_in_generate = self ._set_return_outputs (
502502 return_dict_in_generate = return_dict_in_generate ,
503503 return_token_timestamps = return_token_timestamps ,
504504 logprob_threshold = logprob_threshold ,
@@ -732,7 +732,7 @@ def generate(
732732 else :
733733 outputs = sequences
734734
735- if generation_config .return_dict_in_generate :
735+ if return_dict_in_generate and generation_config .return_dict_in_generate :
736736 dict_outputs = self ._stack_split_outputs (seek_outputs , model_output_type , sequences .device , kwargs )
737737
738738 if num_return_sequences > 1 :
@@ -1109,18 +1109,20 @@ def _maybe_warn_unused_inputs(
11091109 def _set_return_outputs (return_dict_in_generate , return_token_timestamps , logprob_threshold , generation_config ):
11101110 if return_dict_in_generate is None :
11111111 return_dict_in_generate = generation_config .return_dict_in_generate
1112+ else :
1113+ generation_config .return_dict_in_generate = return_dict_in_generate
11121114
11131115 generation_config .return_token_timestamps = return_token_timestamps
11141116 if return_token_timestamps :
1115- return_dict_in_generate = True
1117+ generation_config . return_dict_in_generate = True
11161118 generation_config .output_attentions = True
11171119 generation_config .output_scores = True
11181120
11191121 if logprob_threshold is not None :
1120- return_dict_in_generate = True
1122+ generation_config . return_dict_in_generate = True
11211123 generation_config .output_scores = True
11221124
1123- generation_config . return_dict_in_generate = return_dict_in_generate
1125+ return return_dict_in_generate
11241126
11251127 def _set_return_timestamps (self , return_timestamps , is_shortform , generation_config ):
11261128 if not is_shortform :
0 commit comments