From f398f17c3b06cdd567ae7295896d0d7f97785452 Mon Sep 17 00:00:00 2001 From: Maziyar Panahi Date: Mon, 1 Jan 2024 15:36:27 +0000 Subject: [PATCH] Fix the bad dimensions in BGE output [skip test] --- .../scala/com/johnsnowlabs/ml/ai/BGE.scala | 11 ++--- .../embeddings/BGEEmbeddingsTestSpec.scala | 44 +++++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala index fb421b1fd58ebf..54d401e7c5b321 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala @@ -183,7 +183,7 @@ private[johnsnowlabs] class BGE( val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo] val shape = info.getShape try { - val embeddings = lastHiddenState + val flattenEmbeddings = lastHiddenState .asInstanceOf[OnnxTensor] .getFloatBuffer .array() @@ -191,12 +191,9 @@ private[johnsnowlabs] class BGE( maskTensors.close() segmentTensors.close() - val dim = shape.last.toInt - // Perfom CLS pooling (the first element of each sequence) - val clsPooling = embeddings.grouped(dim).map(_.head).toArray - val normalizedSentenceEmbeddings = LinAlg.lpNormalizeArray(clsPooling, 2) - - Array(normalizedSentenceEmbeddings) + val embeddings = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape) + val normalizedEmbeddings = LinAlg.l2Normalize(embeddings) + LinAlg.denseMatrixToArray(normalizedEmbeddings) } finally if (results != null) results.close() } } diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala index 77e92795b8bacd..567e78fb0e4cc9 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala @@ -113,4 +113,48 @@ class BGEEmbeddingsTestSpec extends AnyFlatSpec { pipelineDF.select("bge.embeddings").show(false) } + it should "not return empty embeddings" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val interests = Seq( + "I like music", + "I like movies", + "I like books", + "I like sports", + "I like travel", + "I like food", + "I like games", + "I like art", + "I like nature", + "I like science", + "I like technology", + "I like history", + "I like fashion", + "I like cars", + "I like animals", + "I like gardening") + val testDf = interests.toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = BGEEmbeddings + .pretrained() + .setInputCols(Array("document")) + .setOutputCol("bge") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + + val embeddingsDF = pipelineDF.withColumn("embeddings", col("bge.embeddings").getItem(0)) + + val sizesArray: Array[Int] = embeddingsDF + .select(size(col("embeddings")).as("size")) + .collect() + .map(row => row.getAs[Int]("size")) + + assert(sizesArray.forall(_ > 0)) + } + }