File tree 1 file changed +2
-2
lines changed
1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -195,7 +195,8 @@ def __init__(self, params: ModelArgs):
195
195
self .tok_embeddings .weight = self .output .weight # https://paperswithcode.com/method/weight-tying
196
196
197
197
# some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
198
- self .freqs_cis = precompute_freqs_cis (self .params .dim // self .params .n_heads , self .params .max_seq_len * 2 )
198
+ freqs_cis = precompute_freqs_cis (self .params .dim // self .params .n_heads , self .params .max_seq_len * 2 )
199
+ self .register_buffer ("freqs_cis" , freqs_cis , persistent = False )
199
200
200
201
# init all weights
201
202
self .apply (self ._init_weights )
@@ -215,7 +216,6 @@ def _init_weights(self, module):
215
216
def forward (self , tokens , targets = None ):
216
217
_bsz , seqlen = tokens .shape
217
218
h = self .tok_embeddings (tokens )
218
- self .freqs_cis = self .freqs_cis .to (h .device )
219
219
freqs_cis = self .freqs_cis [:seqlen ]
220
220
221
221
for layer in self .layers :
You can’t perform that action at this time.
0 commit comments