Commit eca477b

Flag misnomer: GeneratorArgs.is_torchtune_model (pytorch#1274)
1 parent b217158 commit eca477b

1 file changed, 14 additions (+), 10 deletions (-)

torchchat/generate.py
@@ -140,6 +140,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
     max_autotune: bool = False
+    # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False

     def __post_init__(self):
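
To make the misnomer concrete: a model can be a torchtune model (e.g. Flamingo) while this flag is False, so any logic keyed on the flag must also inspect the model type. Below is a minimal sketch of that combined check, assuming a simplified ModelType enum and a hypothetical helper needs_text_clamp; only the is_torchtune_model field and the Flamingo comparison come from the actual diff.

    from dataclasses import dataclass
    from enum import Enum, auto

    # Simplified stand-in for torchchat's ModelType; members other than
    # Flamingo are assumptions for this sketch.
    class ModelType(Enum):
        TextOnly = auto()
        Flamingo = auto()

    @dataclass
    class GeneratorArgs:
        # (Misnomer) False does not rule out a torchtune model such as
        # Flamingo. See https://github.com/pytorch/torchchat/issues/1273
        is_torchtune_model: bool = False

    # Hypothetical helper mirroring the guard in the second hunk below:
    # the flag alone cannot be trusted, so the model type is checked too.
    def needs_text_clamp(args: GeneratorArgs, model_type: ModelType) -> bool:
        return not args.is_torchtune_model and model_type != ModelType.Flamingo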
@@ -958,16 +959,19 @@ def chat(
         if get_system_prompt == "y" or get_system_prompt == "Y":
             self.system_prompt = input("What is your system prompt? \n")

-        # elif not generator_args.is_torchtune_model:
-        #     max_seq_length = min(
-        #         encoded.size(0) + generator_args.max_new_tokens,
-        #         (
-        #             text_transformer_args.block_size
-        #             if text_transformer_args is not None
-        #             else 2048
-        #         ),
-        #         max_seq_length,
-        #     )
+        # `is_torchtune_model` is a misnomer since it doesn't capture all
+        # torchtune models (i.e. Flamingo)
+        # See Issue: https://github.com/pytorch/torchchat/issues/1273
+        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+            max_seq_length = min(
+                encoded.size(0) + generator_args.max_new_tokens,
+                (
+                    text_transformer_args.block_size
+                    if text_transformer_args is not None
+                    else 2048
+                ),
+                max_seq_length,
+            )

         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
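
The restored branch is a clamp on the sequence-length budget: prompt length plus requested new tokens, capped by the model's context window (2048 when text_transformer_args is absent) and the existing max_seq_length, then padded for speculative decoding by the context lines after the hunk. A self-contained sketch of that arithmetic follows, with illustrative names (clamp_seq_length is not a torchchat function).

    from typing import Optional

    def clamp_seq_length(
        prompt_len: int,
        max_new_tokens: int,
        block_size: Optional[int],
        cap: int,
        speculate_k: int,
    ) -> int:
        # Mirrors the min(...) in the hunk above: never exceed the model's
        # context window (2048 if unknown) or the caller's existing cap.
        budget = min(
            prompt_len + max_new_tokens,
            block_size if block_size is not None else 2048,
            cap,
        )
        # The context lines after the hunk then reserve speculate_k + 1
        # extra slots for speculative decoding's draft tokens.
        return budget + speculate_k + 1

    # Example: a 16-token prompt asking for 100 new tokens fits inside a
    # 2048-token window, so only the speculative padding is added.
    assert clamp_seq_length(16, 100, 2048, 4096, 5) == 122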
