File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed
backends/candle/src/models Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -451,16 +451,17 @@ impl Qwen3Model {
451451 . flat_map ( |i| ( 0 ..seq_len) . map ( move |j| ( j > i) as u8 ) )
452452 . collect ( ) ;
453453
454- let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , & Device :: Cpu ) ?;
454+ let device = attention_bias. device ( ) ;
455+ let causal_mask = Tensor :: from_slice ( & mask, ( seq_len, seq_len) , device) ?;
455456 let causal_mask = causal_mask. expand ( & [ bs, dim, seq_len, seq_len] ) ?;
456457
457458 let negatives =
458- Tensor :: full ( f32:: MIN , attention_bias. shape ( ) , & Device :: Cpu ) ?. to_dtype ( self . dtype ) ?;
459+ Tensor :: full ( f32:: MIN , attention_bias. shape ( ) , device ) ?. to_dtype ( self . dtype ) ?;
459460 let zeros = Tensor :: zeros_like ( & attention_bias) ?. to_dtype ( self . dtype ) ?;
460461
461462 let causal_mask = causal_mask
462463 . where_cond ( & negatives, & zeros) ?
463- . to_device ( & self . device ) ?;
464+ . to_device ( device) ?;
464465
465466 attention_bias. broadcast_add ( & causal_mask)
466467 }
You can’t perform that action at this time.
0 commit comments