@@ -1778,6 +1778,24 @@ def typeerror():
 
         return result
 
+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences on the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
+
     def contrastive_search(self, *args, **kwargs):
         logger.warning_once(
             "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
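For readers skimming the diff: the new helper condenses the `synced_gpus` handshake that each decoding loop previously inlined. Every rank contributes `1.0` to an all-reduce while it still has sequences to generate and `0.0` once done, so the summed flag reaches `0.0` on every rank only when all peers have finished. A minimal standalone sketch of that synchronization, assuming a `torch.distributed` process group is already initialized (the function name here is illustrative, not part of the diff):

```python
import torch
import torch.distributed as dist

def any_peer_unfinished(this_peer_finished: bool, device: torch.device) -> bool:
    # Each rank contributes 1.0 while it still has work, 0.0 once it is done.
    flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
    # Sum the flags across all ranks; every rank sees the same result.
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    # A reduced sum of 0.0 means all peers finished, so everyone may stop together.
    return flag.item() > 0.0
```

Keeping the all-reduce on every iteration is what makes the loop ZeRO stage 3-friendly: under DeepSpeed ZeRO-3 each rank holds only a shard of the weights, so every rank must keep issuing `forward` calls until all of them are done.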
@@ -1939,19 +1957,9 @@ def _contrastive_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
 
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ def _contrastive_search(
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
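The rewritten exit condition leans on `unfinished_sequences` being a 0/1 mask that the stopping criteria progressively clear; `& ~` works because PyTorch promotes the criteria's boolean output to the mask's integer dtype. A tiny runnable sketch of that masking (the tensor values are illustrative):

```python
import torch

# Batch of 3 sequences: 1 = still generating, 0 = finished.
unfinished_sequences = torch.ones(3, dtype=torch.long)

# A stopping criterion yields a per-sequence bool: True = stop that sequence.
stop_now = torch.tensor([False, True, False])

# `& ~stop_now` clears the entries the criteria flagged as done.
unfinished_sequences = unfinished_sequences & ~stop_now
print(unfinished_sequences)             # tensor([1, 0, 1])
print(unfinished_sequences.max() == 0)  # tensor(False): this peer is not done yet
```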
@@ -2395,6 +2398,7 @@ def _greedy_search(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@ def _greedy_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2480,13 +2473,7 @@ def _greedy_search(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2699,6 +2686,7 @@ def _sample(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2787,13 +2763,7 @@ def _sample(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -3052,20 +3022,11 @@ def _beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
 
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,10 +3153,7 @@ def _beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3441,20 +3399,10 @@ def _beam_sample(
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -3549,10 +3497,7 @@ def _beam_sample(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3804,20 +3749,10 @@ def _group_beam_search(
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
 
@@ -3955,10 +3890,7 @@ def _group_beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
         sequence_outputs = beam_scorer.finalize(
@@ -4213,20 +4145,10 @@ def _constrained_beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -4320,10 +4242,7 @@ def _constrained_beam_search(
             cur_len = cur_len + 1
 
             if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = constrained_beam_scorer.finalize(
             input_ids,
@@ -4553,18 +4472,8 @@ def _assisted_decoding(
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ def _assisted_decoding(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
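Taken together, every decoding strategy in this commit now shares one loop skeleton: recompute the finished flag each step and let `_has_unfinished_sequences` decide whether to iterate, instead of scattering `break`s behind `synced_gpus` checks. A toy, single-process sketch of that control flow (the model step and stopping criteria are faked with a counter; only the loop's shape mirrors the diff):

```python
import torch

def has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool) -> bool:
    # Single-process stand-in for the new helper; the all-reduce path is elided.
    if synced_gpus:
        raise NotImplementedError("distributed path elided in this sketch")
    return not this_peer_finished

unfinished_sequences = torch.ones(2, dtype=torch.long)
this_peer_finished = False
step = 0

while has_unfinished_sequences(this_peer_finished, synced_gpus=False):
    step += 1
    # Fake stopping criteria: sequence i finishes after 2 + i steps.
    stop_now = torch.tensor([step >= 2, step >= 3])
    unfinished_sequences = unfinished_sequences & ~stop_now
    # The flag is recomputed every iteration rather than set once and checked
    # with scattered breaks, which is the point of the refactor.
    this_peer_finished = unfinished_sequences.max() == 0

print(step)  # 3: the loop ran until both fake sequences finished
```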