
Commit 9e4df7c

gante and amyeroberts authored
Generate: replace breaks by a loop condition (#29662)
* replace breaks by a loop condition

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
1 parent 28de2f4 commit 9e4df7c
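The change is mechanical: each decoding method previously carried its own `while True:` loop, with the synced-GPU bookkeeping and `break` statements duplicated in every copy; this commit centralizes that bookkeeping in the new `_has_unfinished_sequences` helper and uses it as the loop condition. A minimal sketch of the before/after shape (illustrative only, not code from the diff; `step` and `is_done` are hypothetical stand-ins for one decoding step and its stopping check):

    # before: break-driven control flow, duplicated across decoding methods
    while True:
        step()
        if is_done():
            break

    # after: a single loop condition; each iteration's result feeds the next test
    this_peer_finished = False
    while has_unfinished_sequences(this_peer_finished):
        step()
        this_peer_finished = is_done()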

File tree: 1 file changed, +42 −139 lines changed


src/transformers/generation/utils.py

Lines changed: 42 additions & 139 deletions
@@ -1778,6 +1778,24 @@ def typeerror():
 
         return result
 
+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
+
     def contrastive_search(self, *args, **kwargs):
         logger.warning_once(
             "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
@@ -1939,19 +1957,9 @@ def _contrastive_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
 
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ def _contrastive_search(
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2395,6 +2398,7 @@ def _greedy_search(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@ def _greedy_search(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2480,13 +2473,7 @@ def _greedy_search(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -2699,6 +2686,7 @@ def _sample(
         )
 
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
 
-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -2787,13 +2763,7 @@ def _sample(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
@@ -3052,20 +3022,11 @@ def _beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
 
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,10 +3153,7 @@ def _beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3441,20 +3399,10 @@ def _beam_sample(
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -3549,10 +3497,7 @@ def _beam_sample(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
             input_ids,
@@ -3804,20 +3749,10 @@ def _group_beam_search(
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
 
@@ -3955,10 +3890,7 @@ def _group_beam_search(
             cur_len = cur_len + 1
 
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
         sequence_outputs = beam_scorer.finalize(
@@ -4213,20 +4145,10 @@ def _constrained_beam_search(
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             outputs = self(
@@ -4320,10 +4242,7 @@ def _constrained_beam_search(
             cur_len = cur_len + 1
 
             if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
 
         sequence_outputs = constrained_beam_scorer.finalize(
             input_ids,
@@ -4553,18 +4472,8 @@ def _assisted_decoding(
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
 
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
-
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ def _assisted_decoding(
             )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
 
         if streamer is not None:
             streamer.end()
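Why the helper keeps the all-reduce path: under `synced_gpus` (e.g. DeepSpeed ZeRO stage 3), every rank must keep issuing `forward` calls until all ranks are done, so termination has to be agreed on collectively. A standalone toy sketch of that handshake, assuming a `torch.distributed` process group is already initialized (`all_peers_finished` and `finished` are hypothetical names, not part of this commit):

    import torch
    import torch.distributed as dist

    def all_peers_finished(finished: bool, device: torch.device) -> bool:
        # each rank contributes 1.0 while it still has work, 0.0 once done
        flag = torch.tensor(0.0 if finished else 1.0, device=device)
        # summed across ranks: the total is 0.0 only when every rank sent 0.0
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() == 0.0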
