1 file changed: +14 −10 lines changed

@@ -140,6 +140,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
     max_autotune: bool = False
+    # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
 
     def __post_init__(self):
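
The misnomer flagged in the comment above is what drives the extra check in the second hunk below: Flamingo is a torchtune model, yet `is_torchtune_model` does not capture it. A minimal sketch of the resulting guard logic, assuming ModelType is an enum with a Flamingo member and that the flag stays False for Flamingo (both inferred from the diff, not verified against torchchat):

from enum import Enum

class ModelType(Enum):
    TextOnly = "text_only"  # hypothetical member, for illustration only
    Flamingo = "flamingo"

def should_clamp_seq_length(is_torchtune_model: bool, model_type: ModelType) -> bool:
    # The flag alone would send Flamingo down the text-only clamping path,
    # since it is False for Flamingo; the model_type check closes that gap.
    return not is_torchtune_model and model_type is not ModelType.Flamingo

assert should_clamp_seq_length(False, ModelType.TextOnly) is True
assert should_clamp_seq_length(False, ModelType.Flamingo) is False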
@@ -958,16 +959,19 @@ def chat(
         if get_system_prompt == "y" or get_system_prompt == "Y":
             self.system_prompt = input("What is your system prompt? \n")
 
-        # elif not generator_args.is_torchtune_model:
-        #     max_seq_length = min(
-        #         encoded.size(0) + generator_args.max_new_tokens,
-        #         (
-        #             text_transformer_args.block_size
-        #             if text_transformer_args is not None
-        #             else 2048
-        #         ),
-        #         max_seq_length,
-        #     )
+        # `is_torchtune_model` is a misnomer since it doesn't capture all
+        # torchtune models (i.e. Flamingo)
+        # See Issue: https://github.com/pytorch/torchchat/issues/1273
+        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+            max_seq_length = min(
+                encoded.size(0) + generator_args.max_new_tokens,
+                (
+                    text_transformer_args.block_size
+                    if text_transformer_args is not None
+                    else 2048
+                ),
+                max_seq_length,
+            )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
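
For reference, the restored clamp takes the minimum of three quantities: the tokens actually needed, the model's block size, and the previously computed cap. A self-contained sketch of that computation, with names mirroring the diff but every default assumed:

from typing import Optional

def clamp_max_seq_length(
    prompt_len: int,            # encoded.size(0) in the diff
    max_new_tokens: int,        # generator_args.max_new_tokens
    block_size: Optional[int],  # text_transformer_args.block_size, may be None
    max_seq_length: int,        # cap computed earlier in chat()
) -> int:
    # Never allocate more than the prompt plus requested new tokens, the
    # model's block size (2048 when unknown), or the existing cap.
    return min(
        prompt_len + max_new_tokens,
        block_size if block_size is not None else 2048,
        max_seq_length,
    )

# Example: a 100-token prompt asking for 200 new tokens against a 2048 block
# size and a prior cap of 4096 yields 300.
assert clamp_max_seq_length(100, 200, 2048, 4096) == 300

The trailing context lines then widen the cap by speculate_k + 1, which reserves room for the k drafted tokens plus the verifier's bonus token when speculative decoding is enabled.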