@@ -21,24 +21,33 @@ import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
2121import com .johnsnowlabs .ml .onnx .OnnxSession
2222import com .johnsnowlabs .ml .onnx .OnnxWrapper .DecoderWrappers
2323import com .johnsnowlabs .ml .onnx .TensorResources .implicits ._
24+ import com .johnsnowlabs .ml .openvino .OpenvinoWrapper
2425import com .johnsnowlabs .ml .tensorflow .sentencepiece .SentencePieceWrapper
26+ import com .johnsnowlabs .ml .util .{ONNX , Openvino , TensorFlow }
2527import com .johnsnowlabs .nlp .Annotation
2628import com .johnsnowlabs .nlp .AnnotatorType .DOCUMENT
2729import com .johnsnowlabs .nlp .annotators .common .SentenceSplit
2830import com .johnsnowlabs .nlp .annotators .tokenizer .bpe .{BpeTokenizer , QwenTokenizer }
31+ import org .intel .openvino .InferRequest
2932import org .tensorflow .{Session , Tensor }
3033
3134import scala .collection .JavaConverters ._
3235
3336private [johnsnowlabs] class Qwen (
34- val onnxWrappers : DecoderWrappers ,
37+ val onnxWrappers : Option [DecoderWrappers ],
38+ val openvinoWrapper : Option [OpenvinoWrapper ],
3539 merges : Map [(String , String ), Int ],
3640 vocabulary : Map [String , Int ],
3741 generationConfig : GenerationConfig )
3842 extends Serializable
3943 with Generate {
4044
4145 private val onnxSessionOptions : Map [String , String ] = new OnnxSession ().getSessionOptions
46+ val detectedEngine : String =
47+ if (onnxWrappers.isDefined) ONNX .name
48+ else if (openvinoWrapper.isDefined) Openvino .name
49+ else ONNX .name
50+ private var nextPositionId : Option [Array [Long ]] = None
4251 val bpeTokenizer : QwenTokenizer = BpeTokenizer
4352 .forModel(" qwen" , merges = merges, vocab = vocabulary, padWithSequenceTokens = false )
4453 .asInstanceOf [QwenTokenizer ]
@@ -93,8 +102,8 @@ private[johnsnowlabs] class Qwen(
93102 randomSeed : Option [Long ],
94103 ignoreTokenIds : Array [Int ] = Array (),
95104 beamSize : Int ,
96- maxInputLength : Int ) : Array [ Array [ Int ]] = {
97- val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions)
105+ maxInputLength : Int ,
106+ stopTokenIds : Array [ Int ]) : Array [ Array [ Int ]] = {
98107 val ignoreTokenIdsInt = ignoreTokenIds
99108 val expandedDecoderInputsVals = batch
100109 val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
@@ -121,10 +130,23 @@ private[johnsnowlabs] class Qwen(
121130// (encoderSession, env),
122131// maxOutputLength)
123132
124- // dummy tensors for decoder encode state and attention mask
125- val decoderEncoderStateTensors = Right (OnnxTensor .createTensor(env, Array (0 )))
126- val encoderAttentionMaskTensors = Right (OnnxTensor .createTensor(env, Array (1 )))
127-
133+ val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) =
134+ detectedEngine match {
135+ case ONNX .name =>
136+ // dummy tensors for decoder encode state and attention mask
137+ val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions)
138+ (
139+ Right (OnnxTensor .createTensor(env, Array (0 ))),
140+ Right (OnnxTensor .createTensor(env, Array (1 ))),
141+ Right ((env, encoderSession)))
142+ case Openvino .name =>
143+ // not needed
144+ (null , null , null )
145+ }
146+ val ovInferRequest : Option [InferRequest ] = detectedEngine match {
147+ case ONNX .name => None
148+ case Openvino .name => Some (openvinoWrapper.get.getCompiledModel().create_infer_request())
149+ }
128150 // output with beam search
129151 val modelOutputs = generate(
130152 batch,
@@ -146,8 +168,10 @@ private[johnsnowlabs] class Qwen(
146168 this .paddingTokenId,
147169 randomSeed,
148170 ignoreTokenIdsInt,
149- Right ((env, encoderSession)),
150- applySoftmax = false )
171+ session,
172+ applySoftmax = false ,
173+ ovInferRequest = ovInferRequest,
174+ stopTokenIds = stopTokenIds)
151175
152176// decoderOutputs
153177 modelOutputs
@@ -167,7 +191,8 @@ private[johnsnowlabs] class Qwen(
167191 randomSeed : Option [Long ] = None ,
168192 ignoreTokenIds : Array [Int ] = Array (),
169193 beamSize : Int ,
170- maxInputLength : Int ): Seq [Annotation ] = {
194+ maxInputLength : Int ,
195+ stopTokenIds : Array [Int ]): Seq [Annotation ] = {
171196
172197 val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
173198 val batchSP = encode(batch)
@@ -184,7 +209,8 @@ private[johnsnowlabs] class Qwen(
184209 randomSeed,
185210 ignoreTokenIds,
186211 beamSize,
187- maxInputLength)
212+ maxInputLength,
213+ stopTokenIds)
188214
189215 decode(spIds)
190216
@@ -239,20 +265,76 @@ private[johnsnowlabs] class Qwen(
239265 decoderEncoderStateTensors : Either [Tensor , OnnxTensor ],
240266 encoderAttentionMaskTensors : Either [Tensor , OnnxTensor ],
241267 maxLength : Int ,
242- session : Either [Session , (OrtEnvironment , OrtSession )]): Array [Array [Float ]] = {
268+ session : Either [Session , (OrtEnvironment , OrtSession )],
269+ ovInferRequest : Option [InferRequest ]): Array [Array [Float ]] = {
243270
244- session.fold(
245- tfSession => {
271+ detectedEngine match {
272+ case TensorFlow .name =>
246273 // not implemented yet
247274 Array ()
248- },
249- onnxSession => {
250- val (env, decoderSession) = onnxSession
275+ case ONNX .name =>
276+ val (env, decoderSession) = session.right.get
251277 val decoderOutputs =
252278 getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env))
253279 decoderOutputs
254- })
280+ case Openvino .name =>
281+ val decoderOutputs =
282+ getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get)
283+ decoderOutputs
284+ }
285+ }
255286
/** Runs a single decoding step on the given OpenVINO infer request and returns, for each
  * sequence in the batch, the logits of its last position.
  *
  * When `nextPositionId` is empty, the full token sequences are sent together with position ids
  * `0 until length`. When it is defined, only the last token of every sequence is sent together
  * with the stored position ids (this presumably relies on the infer request keeping the
  * key/value state of earlier steps -- confirm against the exported model). After inference,
  * `nextPositionId` is set to the current length of every sequence.
  *
  * NOTE(review): `nextPositionId` is never cleared in this method; it has to be reset before a
  * new generation starts -- confirm that this happens in the caller.
  *
  * NOTE(review): all sequences are assumed to have the same length, since tensor shapes are
  * derived from the first sequence and from `total / batchSize`.
  *
  * @param inputIds
  *   token ids generated so far, one array per sequence in the batch
  * @param inferRequest
  *   OpenVINO infer request the tensors are set on
  * @return
  *   one array of `vocabSize` logits per sequence; empty if `inputIds` is empty
  */
private def getDecoderOutputsOv(
    inputIds: Array[Array[Int]],
    inferRequest: InferRequest): Array[Array[Float]] = {
  if (inputIds.isEmpty) {
    // Nothing to decode. Returning early also avoids the division by zero and the `.head`
    // call on an empty batch further down.
    Array.empty[Array[Float]]
  } else {
    val batchSize: Int = inputIds.length

    val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) =
      nextPositionId match {
        case Some(storedPositions) =>
          // Follow-up step: only the newest token of every sequence is fed.
          (inputIds.map(tokenIds => tokenIds.last.toLong), storedPositions)
        case None =>
          // First step: feed the whole prompt with positions 0, 1, 2, ...
          val allTokens = inputIds.flatMap(tokenIds => tokenIds.map(_.toLong))
          val allPositions = inputIds.flatMap(tokenIds => tokenIds.indices.map(_.toLong))
          (allTokens, allPositions)
      }

    // The attention mask always covers the complete sequences, not only the tokens fed now.
    val attentionMask: Array[Long] = Array.fill(inputIds.map(_.length).sum)(1L)

    val sequenceLength: Int = inputIdsLong.length / batchSize
    val shape: Array[Int] = Array(batchSize, sequenceLength)

    // beam_idx tells a stateful model which row of the stored state each batch row continues
    // from. It must be the identity mapping by default: an all-zero array would make every row
    // reuse the state of row 0 as soon as the batch holds more than one sequence.
    // NOTE(review): beam re-ordering performed by the generation loop is not propagated here.
    val beamIdx: Array[Int] = Array.range(0, batchSize)

    val inputIdsLongTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(shape, inputIdsLong)
    val decoderAttentionMask: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask)
    val decoderPositionIDs: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(shape, inputPositionIDsLong)
    val beamIdxTensor: org.intel.openvino.Tensor =
      new org.intel.openvino.Tensor(Array(batchSize), beamIdx)

    inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsLongTensor)
    inferRequest.set_tensor(OpenVinoSignatures.decoderAttentionMask, decoderAttentionMask)
    inferRequest.set_tensor(OpenVinoSignatures.decoderPositionIDs, decoderPositionIDs)
    inferRequest.set_tensor(OpenVinoSignatures.decoderBeamIdx, beamIdxTensor)

    inferRequest.infer()

    val result = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput)
    val logitsRaw = result.data()

    // Position id of the token that will be fed on the next step.
    nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong))

    // The flat logits are sliced as [batch, sequence, vocab]; keep the last position per row.
    (0 until batchSize).map { i =>
      val lastPositionStart = (i * sequenceLength + sequenceLength - 1) * vocabSize
      logitsRaw.slice(lastPositionStart, lastPositionStart + vocabSize)
    }.toArray
  }
}
257339 private def getDecoderOutputs (
258340 inputIds : Array [Array [Int ]],
@@ -285,12 +367,12 @@ private[johnsnowlabs] class Qwen(
285367 val sequenceLength = inputIds.head.length
286368 val batchSize = inputIds.length
287369
288- // val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
289- // inputIdsLongTensor.close()
290- // decoderPositionIDs.close()
291- // decoderAttentionMask.close()
292- // val batchLogits = logits.grouped(vocabSize).toArray
293- // batchLogits
370+ // val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
371+ // inputIdsLongTensor.close()
372+ // decoderPositionIDs.close()
373+ // decoderAttentionMask.close()
374+ // val batchLogits = logits.grouped(vocabSize).toArray
375+ // batchLogits
294376
295377 val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures .decoderOutput)
296378 val decoderOutputs = (0 until batchSize).map(i => {
@@ -358,4 +440,19 @@ private[johnsnowlabs] class Qwen(
358440 (0 until 32 ).flatMap(i => Seq (s " present. $i.key " , s " present. $i.value " )).toArray
359441 }
360442
// Tensor names used when talking to the OpenVINO model. In the visible part of this file,
// decoderInputIDs, decoderAttentionMask, decoderPositionIDs and decoderBeamIdx are passed to
// InferRequest.set_tensor, and decoderOutput is passed to InferRequest.get_tensor.
// NOTE(review): the encoder* and decoderEncoder* names are not referenced in the visible part
// of this file -- confirm whether they are still needed.
// NOTE(review): the literals below appear with a space after the opening quote; this looks like
// an artifact of how this listing was extracted -- confirm against the real file.
443+ private object OpenVinoSignatures {
// Encoder input names.
444+ val encoderInputIDs : String = " input_ids"
445+ val encoderAttentionMask : String = " attention_mask"
446+
// Encoder output name.
447+ val encoderOutput : String = " last_hidden_state"
448+
// Decoder input names.
449+ val decoderInputIDs : String = " input_ids"
450+ val decoderEncoderAttentionMask : String = " encoder_attention_mask"
451+ val decoderAttentionMask : String = " attention_mask"
452+ val decoderPositionIDs : String = " position_ids"
453+ val decoderBeamIdx : String = " beam_idx"
454+ val decoderEncoderState : String = " encoder_hidden_states"
455+
// Decoder output name; the tensor fetched under this name is read as the raw logits.
456+ val decoderOutput : String = " logits"
457+ }
361458}
0 commit comments