Skip to content

Commit 09ed835

Browse files
authored
Merge pull request #174 from structuredllm/tmp
Fix temperature logits processor
2 parents 657b451 + 6ef8767 commit 09ed835

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

syncode/language_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
import syncode.common as common
55
from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
6-
from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria
6+
from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel
77
from syncode.parsers.grammars import Grammar
88
from syncode.utils.generation import filter_code, fix_indents
99
from typing import Callable, Iterable, Union
@@ -48,7 +48,7 @@ def __init__(
4848
super().__init__()
4949

5050
self.prompt_template = prompt_template
51-
self.model = model
51+
self.model: PreTrainedModel = model
5252
self.tokenizer = tokenizer
5353
self.device = device
5454
self.best_of = best_of
@@ -193,7 +193,9 @@ def _generate(
193193

194194
# This does not include grammar decoder
195195
self.model._prepare_special_tokens(gen_config, False, device=self.device)
196-
logits_processor = self.model._get_logits_processor(gen_config, token_ids.size(1), token_ids, prefix_allowed_tokens_fn=None, logits_processor=[])
196+
197+
# Add logits processor for generation parameters such as top_k, top_p, temperature, etc.
198+
logits_processor = self.model._get_logits_warper(gen_config, self.device)
197199

198200
max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
199201
self.model.config.pad_token_id = pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

0 commit comments

Comments
 (0)