|
3 | 3 | import torch |
4 | 4 | import syncode.common as common |
5 | 5 | from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor |
6 | | -from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria |
| 6 | +from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel |
7 | 7 | from syncode.parsers.grammars import Grammar |
8 | 8 | from syncode.utils.generation import filter_code, fix_indents |
9 | 9 | from typing import Callable, Iterable, Union |
@@ -48,7 +48,7 @@ def __init__( |
48 | 48 | super().__init__() |
49 | 49 |
|
50 | 50 | self.prompt_template = prompt_template |
51 | | - self.model = model |
| 51 | + self.model: PreTrainedModel = model |
52 | 52 | self.tokenizer = tokenizer |
53 | 53 | self.device = device |
54 | 54 | self.best_of = best_of |
@@ -193,7 +193,9 @@ def _generate( |
193 | 193 |
|
194 | 194 | # This does not include the grammar decoder |
195 | 195 | self.model._prepare_special_tokens(gen_config, False, device=self.device) |
196 | | - logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[]) |
| 196 | + |
| 197 | + # Build the logits warpers for generation parameters such as top_k, top_p, temperature, etc. |
| 198 | + logits_processor = self.model._get_logits_warper(gen_config, self.device) |
197 | 199 |
|
198 | 200 | max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1) |
199 | 201 | self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id |
|
0 commit comments