Skip to content

Commit bd9e837

Browse files
authored
Merge pull request ggml-org#23 from awgu/pt2
Register `freqs_cis` as non-persistent buffer
2 parents 3bfa566 + af3b5c0 commit bd9e837

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def __init__(self, params: ModelArgs):
195195
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
196196

197197
# some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
198-
self.freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
198+
freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
199+
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
199200

200201
# init all weights
201202
self.apply(self._init_weights)
@@ -215,7 +216,6 @@ def _init_weights(self, module):
215216
def forward(self, tokens, targets=None):
216217
_bsz, seqlen = tokens.shape
217218
h = self.tok_embeddings(tokens)
218-
self.freqs_cis = self.freqs_cis.to(h.device)
219219
freqs_cis = self.freqs_cis[:seqlen]
220220

221221
for layer in self.layers:

0 commit comments

Comments
 (0)