
Commit 4c83df7

Added OpenVINO support
1 parent 4e6b443 commit 4c83df7

4 files changed: +175 -39 lines changed

python/sparknlp/annotator/seq2seq/qwen_transformer.py

Lines changed: 2 additions & 2 deletions

@@ -297,7 +297,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransf
                          repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)
 
     @staticmethod
-    def loadSavedModel(folder, spark_session):
+    def loadSavedModel(folder, spark_session, use_openvino=False):
         """Loads a locally saved model.
 
         Parameters
@@ -313,7 +313,7 @@ def loadSavedModel(folder, spark_session):
            The restored model
        """
        from sparknlp.internal import _QwenLoader
-       jModel = _QwenLoader(folder, spark_session._jsparkSession)._java_obj
+       jModel = _QwenLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return QwenTransformer(java_model=jModel)
 
    @staticmethod
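
The change above only threads a new use_openvino flag through loadSavedModel. A minimal usage sketch of the new flag (the model path and column names below are hypothetical, and the import assumes QwenTransformer is re-exported from sparknlp.annotator as usual):

import sparknlp
from sparknlp.annotator import QwenTransformer

spark = sparknlp.start()

# Hypothetical local export folder; any folder prepared for QwenTransformer.loadSavedModel works here.
model_path = "/tmp/qwen_openvino_model"

# use_openvino=True selects the OpenVINO backend; the default (False) keeps the existing ONNX path.
qwen = (
    QwenTransformer.loadSavedModel(model_path, spark, use_openvino=True)
    .setInputCols(["documents"])
    .setOutputCol("generation")
)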

python/sparknlp/internal/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -379,9 +379,9 @@ def __init__(self, path, jspark, useCache):
         )
 
 class _QwenLoader(ExtendedJavaWrapper):
-    def __init__(self, path, jspark):
+    def __init__(self, path, jspark, use_openvino=False):
         super(_QwenLoader, self).__init__(
-            "com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransformer.loadSavedModel", path, jspark)
+            "com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransformer.loadSavedModel", path, jspark, use_openvino)
 
 
 class _USELoader(ExtendedJavaWrapper):

src/main/scala/com/johnsnowlabs/ml/ai/Qwen.scala

Lines changed: 121 additions & 24 deletions

@@ -21,24 +21,33 @@ import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
 import com.johnsnowlabs.ml.onnx.OnnxSession
 import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
 import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
+import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
 import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
+import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
 import com.johnsnowlabs.nlp.Annotation
 import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
 import com.johnsnowlabs.nlp.annotators.common.SentenceSplit
 import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, QwenTokenizer}
+import org.intel.openvino.InferRequest
 import org.tensorflow.{Session, Tensor}
 
 import scala.collection.JavaConverters._
 
 private[johnsnowlabs] class Qwen(
-    val onnxWrappers: DecoderWrappers,
+    val onnxWrappers: Option[DecoderWrappers],
+    val openvinoWrapper: Option[OpenvinoWrapper],
     merges: Map[(String, String), Int],
     vocabulary: Map[String, Int],
     generationConfig: GenerationConfig)
     extends Serializable
     with Generate {
 
   private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions
+  val detectedEngine: String =
+    if (onnxWrappers.isDefined) ONNX.name
+    else if (openvinoWrapper.isDefined) Openvino.name
+    else ONNX.name
+  private var nextPositionId: Option[Array[Long]] = None
   val bpeTokenizer: QwenTokenizer = BpeTokenizer
     .forModel("qwen", merges = merges, vocab = vocabulary, padWithSequenceTokens = false)
     .asInstanceOf[QwenTokenizer]
@@ -93,8 +102,8 @@ private[johnsnowlabs] class Qwen(
       randomSeed: Option[Long],
       ignoreTokenIds: Array[Int] = Array(),
       beamSize: Int,
-      maxInputLength: Int): Array[Array[Int]] = {
-    val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions)
+      maxInputLength: Int,
+      stopTokenIds: Array[Int]): Array[Array[Int]] = {
     val ignoreTokenIdsInt = ignoreTokenIds
     val expandedDecoderInputsVals = batch
     val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
@@ -121,10 +130,23 @@ private[johnsnowlabs] class Qwen(
 //      (encoderSession, env),
 //      maxOutputLength)
 
-    // dummy tensors for decoder encode state and attention mask
-    val decoderEncoderStateTensors = Right(OnnxTensor.createTensor(env, Array(0)))
-    val encoderAttentionMaskTensors = Right(OnnxTensor.createTensor(env, Array(1)))
-
+    val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) =
+      detectedEngine match {
+        case ONNX.name =>
+          // dummy tensors for decoder encode state and attention mask
+          val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions)
+          (
+            Right(OnnxTensor.createTensor(env, Array(0))),
+            Right(OnnxTensor.createTensor(env, Array(1))),
+            Right((env, encoderSession)))
+        case Openvino.name =>
+          // not needed
+          (null, null, null)
+      }
+    val ovInferRequest: Option[InferRequest] = detectedEngine match {
+      case ONNX.name => None
+      case Openvino.name => Some(openvinoWrapper.get.getCompiledModel().create_infer_request())
+    }
     // output with beam search
     val modelOutputs = generate(
       batch,
@@ -146,8 +168,10 @@ private[johnsnowlabs] class Qwen(
       this.paddingTokenId,
       randomSeed,
       ignoreTokenIdsInt,
-      Right((env, encoderSession)),
-      applySoftmax = false)
+      session,
+      applySoftmax = false,
+      ovInferRequest = ovInferRequest,
+      stopTokenIds = stopTokenIds)
 
     // decoderOutputs
     modelOutputs
@@ -167,7 +191,8 @@ private[johnsnowlabs] class Qwen(
      randomSeed: Option[Long] = None,
      ignoreTokenIds: Array[Int] = Array(),
      beamSize: Int,
-      maxInputLength: Int): Seq[Annotation] = {
+      maxInputLength: Int,
+      stopTokenIds: Array[Int]): Seq[Annotation] = {
 
     val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
       val batchSP = encode(batch)
@@ -184,7 +209,8 @@ private[johnsnowlabs] class Qwen(
         randomSeed,
         ignoreTokenIds,
         beamSize,
-        maxInputLength)
+        maxInputLength,
+        stopTokenIds)
 
       decode(spIds)
 
@@ -239,20 +265,76 @@ private[johnsnowlabs] class Qwen(
       decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
       encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
       maxLength: Int,
-      session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = {
+      session: Either[Session, (OrtEnvironment, OrtSession)],
+      ovInferRequest: Option[InferRequest]): Array[Array[Float]] = {
 
-    session.fold(
-      tfSession => {
+    detectedEngine match {
+      case TensorFlow.name =>
        // not implemented yet
        Array()
-      },
-      onnxSession => {
-        val (env, decoderSession) = onnxSession
+      case ONNX.name =>
+        val (env, decoderSession) = session.right.get
        val decoderOutputs =
          getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env))
        decoderOutputs
-      })
+      case Openvino.name =>
+        val decoderOutputs =
+          getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get)
+        decoderOutputs
+    }
+  }
 
+  private def getDecoderOutputsOv(
+      inputIds: Array[Array[Int]],
+      inferRequest: InferRequest): (Array[Array[Float]]) = {
+    val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) =
+      if (nextPositionId.isDefined) {
+        val inpIdsLong = inputIds.map { tokenIds => tokenIds.last.toLong }
+        (inpIdsLong, nextPositionId.get)
+      } else {
+        val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) }
+        val posIdsLong = inputIds.flatMap { tokenIds =>
+          tokenIds.zipWithIndex.map { case (_, i) =>
+            i.toLong
+          }
+        }
+        (inpIdsLong, posIdsLong)
+      }
+    val attentionMask: Array[Long] =
+      inputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) }
+
+    val batchSize: Int = inputIds.length
+    val beamIdx: Array[Int] = new Array[Int](batchSize)
+    val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize)
+
+    val inputIdsLongTensor: org.intel.openvino.Tensor =
+      new org.intel.openvino.Tensor(shape, inputIdsLong)
+    val decoderAttentionMask: org.intel.openvino.Tensor =
+      new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask)
+    val decoderPositionIDs: org.intel.openvino.Tensor =
+      new org.intel.openvino.Tensor(shape, inputPositionIDsLong)
+    val beamIdxTensor: org.intel.openvino.Tensor =
+      new org.intel.openvino.Tensor(Array(batchSize), beamIdx)
+
+    inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsLongTensor)
+    inferRequest.set_tensor(OpenVinoSignatures.decoderAttentionMask, decoderAttentionMask)
+    inferRequest.set_tensor(OpenVinoSignatures.decoderPositionIDs, decoderPositionIDs)
+    inferRequest.set_tensor(OpenVinoSignatures.decoderBeamIdx, beamIdxTensor)
+
+    inferRequest.infer()
+
+    val result = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput)
+    val logitsRaw = result.data()
+    nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong))
+
+    val sequenceLength = inputIdsLong.length / batchSize
+    val decoderOutputs = (0 until batchSize).map(i => {
+      logitsRaw
+        .slice(
+          i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize,
+          i * sequenceLength * vocabSize + sequenceLength * vocabSize)
+    })
+    decoderOutputs.toArray
   }
   private def getDecoderOutputs(
       inputIds: Array[Array[Int]],
@@ -285,12 +367,12 @@ private[johnsnowlabs] class Qwen(
     val sequenceLength = inputIds.head.length
     val batchSize = inputIds.length
 
-//    val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
-//    inputIdsLongTensor.close()
-//    decoderPositionIDs.close()
-//    decoderAttentionMask.close()
-//    val batchLogits = logits.grouped(vocabSize).toArray
-//    batchLogits
+    //    val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
+    //    inputIdsLongTensor.close()
+    //    decoderPositionIDs.close()
+    //    decoderAttentionMask.close()
+    //    val batchLogits = logits.grouped(vocabSize).toArray
+    //    batchLogits
 
     val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
     val decoderOutputs = (0 until batchSize).map(i => {
@@ -358,4 +440,19 @@ private[johnsnowlabs] class Qwen(
     (0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray
   }
 
+  private object OpenVinoSignatures {
+    val encoderInputIDs: String = "input_ids"
+    val encoderAttentionMask: String = "attention_mask"
+
+    val encoderOutput: String = "last_hidden_state"
+
+    val decoderInputIDs: String = "input_ids"
+    val decoderEncoderAttentionMask: String = "encoder_attention_mask"
+    val decoderAttentionMask: String = "attention_mask"
+    val decoderPositionIDs: String = "position_ids"
+    val decoderBeamIdx: String = "beam_idx"
+    val decoderEncoderState: String = "encoder_hidden_states"
+
+    val decoderOutput: String = "logits"
+  }
 }
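
One detail of getDecoderOutputsOv worth spelling out: the slice arithmetic at the end extracts, for each sequence in the batch, the vocab-sized block of logits for the last position from the flat buffer returned by the infer request. An equivalent NumPy sketch of that indexing (illustrative only, not part of the commit; the shapes are made up):

import numpy as np

batch_size, sequence_length, vocab_size = 2, 5, 8  # toy example shapes

# Flat logits buffer, one float per (batch, position, vocab) entry in row-major order,
# mimicking what result.data() returns on the Scala side.
logits_raw = np.random.rand(batch_size * sequence_length * vocab_size).astype(np.float32)

# Same slice bounds as the Scala code: for sequence i, take
# [i*seq*vocab + (seq-1)*vocab, i*seq*vocab + seq*vocab).
last_token_logits = np.stack([
    logits_raw[i * sequence_length * vocab_size + (sequence_length - 1) * vocab_size:
               i * sequence_length * vocab_size + sequence_length * vocab_size]
    for i in range(batch_size)
])

# Equivalent view: reshape to (batch, seq, vocab) and keep only the last position.
assert np.array_equal(
    last_token_logits,
    logits_raw.reshape(batch_size, sequence_length, vocab_size)[:, -1, :])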
