11package org .beehive .gpullama3 .model .loader ;
22
3- import org .beehive .gpullama3 .LlamaApp ;
43import org .beehive .gpullama3 .Options ;
5- import org .beehive .gpullama3 .auxiliary .Timer ;
64import org .beehive .gpullama3 .core .model .GGMLType ;
75import org .beehive .gpullama3 .core .model .GGUF ;
86import org .beehive .gpullama3 .core .model .tensor .ArrayFloatTensor ;
2119import org .beehive .gpullama3 .tokenizer .impl .Qwen3Tokenizer ;
2220import org .beehive .gpullama3 .tokenizer .impl .Tokenizer ;
2321import org .beehive .gpullama3 .tokenizer .vocabulary .Vocabulary ;
22+ import org .beehive .gpullama3 .tornadovm .TornadoVMMasterPlan ;
2423import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
2524
2625import java .io .IOException ;
@@ -40,11 +39,9 @@ public Model loadModel() {
4039 Map <String , Object > metadata = gguf .getMetadata ();
4140 String basename = (String ) metadata .get ("general.basename" );
4241
43- String modelName = "DeepSeek-R1-Distill-Qwen" .equals (basename )
44- ? "DeepSeek-R1-Distill-Qwen"
45- : "Qwen2.5" ;
42+ String modelName = "DeepSeek-R1-Distill-Qwen" .equals (basename ) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5" ;
4643
47- try ( var ignored = Timer . log ( "Load " + modelName + " model" )) {
44+ try {
4845 // reuse method of Qwen3
4946 Vocabulary vocabulary = loadQwen3Vocabulary (metadata );
5047 boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen" .equals (metadata .get ("general.basename" ));
@@ -55,11 +52,8 @@ public Model loadModel() {
5552 contextLength = modelContextLength ;
5653 }
5754
58- int numberOfKeyValueHeads = metadata .containsKey ("qwen2.attention.head_count_kv" )
59- ? (int ) metadata .get ("qwen2.attention.head_count_kv" )
60- : (int ) metadata .get ("qwen2.attention.head_count" );
61- Qwen2Configuration config = new Qwen2Configuration (
62- (int ) metadata .get ("qwen2.embedding_length" ), // dim
55+ int numberOfKeyValueHeads = metadata .containsKey ("qwen2.attention.head_count_kv" ) ? (int ) metadata .get ("qwen2.attention.head_count_kv" ) : (int ) metadata .get ("qwen2.attention.head_count" );
56+ Qwen2Configuration config = new Qwen2Configuration ((int ) metadata .get ("qwen2.embedding_length" ), // dim
6357 (int ) metadata .get ("qwen2.feed_forward_length" ), // hiddendim
6458 (int ) metadata .get ("qwen2.block_count" ), // numberOfLayers
6559 (int ) metadata .get ("qwen2.attention.head_count" ), // numberOfHeads
@@ -68,22 +62,17 @@ public Model loadModel() {
6862 numberOfKeyValueHeads , // numberOfHeadsKey
6963 numberOfKeyValueHeads , // numberOfHeadsValue
7064
71- vocabulary .size (),
72- modelContextLength , contextLength ,
73- false ,
74- (float ) metadata .get ("qwen2.attention.layer_norm_rms_epsilon" ),
75- (float ) metadata .get ("qwen2.rope.freq_base" )
76- );
65+ vocabulary .size (), modelContextLength , contextLength , false , (float ) metadata .get ("qwen2.attention.layer_norm_rms_epsilon" ), (float ) metadata .get ("qwen2.rope.freq_base" ));
7766
7867 Weights weights = null ;
7968 if (loadWeights ) {
8069 Map <String , GGMLTensorEntry > tensorEntries = GGUF .loadTensors (fileChannel , gguf .getTensorDataOffset (), gguf .getTensorInfos ());
8170 weights = loadWeights (tensorEntries , config );
8271 }
8372 // Qwen2.5-Coder uses <|endoftext|> as stop-token.
84- ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
85- new ChatTokens ( "<|begin▁of▁sentence|>" , "" , "" , "<|end▁of▁sentence|>" , "" ) :
86- new ChatTokens ( "<|im_start|>" , "<|im_end|>" , "" , "<|end_of_text|>" , "<|endoftext|>" );
73+ ChatTokens chatTokens = isDeepSeekR1DistillQwen
74+ ? new ChatTokens ("<|begin▁of▁sentence|>" , "" , "" , "<|end▁of▁sentence|>" , "" )
75+ : new ChatTokens ("<|im_start|>" , "<|im_end|>" , "" , "<|end_of_text|>" , "<|endoftext|>" );
8776 return new Qwen2 (config , tokenizer , weights , ChatFormat .create (tokenizer , chatTokens ));
8877 } catch (IOException e ) {
8978 throw new RuntimeException (e );
@@ -108,7 +97,9 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
10897 GGMLTensorEntry outputWeight = tensorEntries .getOrDefault ("output.weight" , tokenEmbeddings );
10998
11099 if (Options .getDefaultOptions ().useTornadovm ()) {
111- System .out .println ("Loading model weights in TornadoVM format (loading " + outputWeight .ggmlType () + " -> " + GGMLType .F16 + ")" );
100+ if (TornadoVMMasterPlan .ENABLE_TORNADOVM_INIT_TIME ) {
101+ System .out .println ("Loading model weights in TornadoVM format (loading " + outputWeight .ggmlType () + " -> " + GGMLType .F16 + ")" );
102+ }
112103 return createTornadoVMWeights (tensorEntries , config , ropeFreqs , tokenEmbeddings , outputWeight );
113104 } else {
114105 return createStandardWeights (tensorEntries , config , ropeFreqs , tokenEmbeddings , outputWeight );
0 commit comments