-
Notifications
You must be signed in to change notification settings - Fork 736
Closed
Labels
Description
When I run E5Embeddings on sentence-level input (the output of a sentence detector), I get the error below. If the input contains no more than one sentence, it works without issue.
In addition, there is no problem when E5Embeddings takes the document column as its input instead of the sentences column.
sparknlp.version: 5.1.4
spark.version: 3.3.1
colab link: https://colab.research.google.com/drive/1NhmhGcuNiR5hrNn2POm3uj7W_zJZn-mz?usp=sharing
Code:
documentAssembler = DocumentAssembler() \
.setInputCol("text") \
.setOutputCol("document")
sentencerDL = SentenceDetectorDLModel\
.pretrained("sentence_detector_dl", "en") \
.setInputCols(["document"]) \
.setOutputCol("sentences")
e5embeddings = E5Embeddings.pretrained("e5_base","en") \
.setInputCols("sentences") \
.setOutputCol("e5_embeddings")
pipeline = Pipeline(stages=[documentAssembler,sentencerDL,e5embeddings])
data = spark.createDataFrame([
["""passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day."""],
]).toDF("text")
pipeline.fit(data).transform(data).show()
Error:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
[<ipython-input-4-507e232c7224>](https://localhost:8080/#) in <cell line: 6>()
4 ["""passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day."""],
5 ]).toDF("text")
----> 6 pipeline.fit(data).transform(data).show()
3 frames
[/usr/local/lib/python3.10/dist-packages/py4j/protocol.py](https://localhost:8080/#) in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o229.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 5) (fdd11149b536 executor driver): ai.onnxruntime.OrtException: Supplied array is ragged, expected 4, found 28
at ai.onnxruntime.TensorInfo.extractShape(TensorInfo.java:391)
at ai.onnxruntime.TensorInfo.extractShape(TensorInfo.java:395)
at ai.onnxruntime.TensorInfo.constructFromJavaArray(TensorInfo.java:300)
at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:313)
at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:297)
at com.johnsnowlabs.ml.ai.E5.getSentenceEmbeddingFromOnnx(E5.scala:156)
at com.johnsnowlabs.ml.ai.E5.getSentenceEmbedding(E5.scala:68)
at com.johnsnowlabs.ml.ai.E5.$anonfun$predict$1(E5.scala:221)
at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
at com.johnsnowlabs.ml.ai.E5.predict(E5.scala:215)
at com.johnsnowlabs.nlp.embeddings.E5Embeddings.batchAnnotate(E5Embeddings.scala:314)
at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:59)
at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:364)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:136)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
at java.base/java.lang.Thread.run(Thread.java:829)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:506)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:459)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:48)
at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3868)
at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2863)
at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:3858)
at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:510)
at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3856)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3856)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2863)
at org.apache.spark.sql.Dataset.take(Dataset.scala:3084)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:288)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:327)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: ai.onnxruntime.OrtException: Supplied array is ragged, expected 4, found 28
at ai.onnxruntime.TensorInfo.extractShape(TensorInfo.java:391)
at ai.onnxruntime.TensorInfo.extractShape(TensorInfo.java:395)
at ai.onnxruntime.TensorInfo.constructFromJavaArray(TensorInfo.java:300)
at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:313)
at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:297)
at com.johnsnowlabs.ml.ai.E5.getSentenceEmbeddingFromOnnx(E5.scala:156)
at com.johnsnowlabs.ml.ai.E5.getSentenceEmbedding(E5.scala:68)
at com.johnsnowlabs.ml.ai.E5.$anonfun$predict$1(E5.scala:221)
at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
at com.johnsnowlabs.ml.ai.E5.predict(E5.scala:215)
at com.johnsnowlabs.nlp.embeddings.E5Embeddings.batchAnnotate(E5Embeddings.scala:314)
at com.johnsnowlabs.nlp.HasBatchedAnnotate.$anonfun$batchProcess$1(HasBatchedAnnotate.scala:59)
at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:364)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)
at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:136)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
... 1 more