Skip to content

Commit

Permalink
Fix the bad dimensions in BGE output [skip test]
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Jan 1, 2024
1 parent d1d2261 commit f398f17
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
11 changes: 4 additions & 7 deletions src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,17 @@ private[johnsnowlabs] class BGE(
val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
val shape = info.getShape
try {
val embeddings = lastHiddenState
val flattenEmbeddings = lastHiddenState
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

val dim = shape.last.toInt
// Perform CLS pooling (the first element of each sequence)
val clsPooling = embeddings.grouped(dim).map(_.head).toArray
val normalizedSentenceEmbeddings = LinAlg.lpNormalizeArray(clsPooling, 2)

Array(normalizedSentenceEmbeddings)
val embeddings = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape)
val normalizedEmbeddings = LinAlg.l2Normalize(embeddings)
LinAlg.denseMatrixToArray(normalizedEmbeddings)
} finally if (results != null) results.close()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,48 @@ class BGEEmbeddingsTestSpec extends AnyFlatSpec {
pipelineDF.select("bge.embeddings").show(false)
}

// Runs the pretrained BGE embeddings stage over a small set of short sentences and
// checks that every input row comes back with a non-empty embedding vector.
it should "not return empty embeddings" taggedAs SlowTest in {
  import ResourceHelper.spark.implicits._
  val interests = Seq(
    "I like music",
    "I like movies",
    "I like books",
    "I like sports",
    "I like travel",
    "I like food",
    "I like games",
    "I like art",
    "I like nature",
    "I like science",
    "I like technology",
    "I like history",
    "I like fashion",
    "I like cars",
    "I like animals",
    "I like gardening")
  val testDf = interests.toDF("text")

  val document = new DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")

  val embeddings = BGEEmbeddings
    .pretrained()
    .setInputCols(Array("document"))
    .setOutputCol("bge")

  val pipeline = new Pipeline().setStages(Array(document, embeddings))

  val pipelineDF = pipeline.fit(testDf).transform(testDf)

  // Keep only the first annotation's embedding vector for each row.
  val embeddingsDF = pipelineDF.withColumn("embeddings", col("bge.embeddings").getItem(0))

  val sizesArray: Array[Int] = embeddingsDF
    .select(size(col("embeddings")).as("size"))
    .collect()
    .map(row => row.getAs[Int]("size"))

  // `forall` is vacuously true on an empty array, so first make sure the pipeline
  // actually produced one result per input sentence.
  assert(
    sizesArray.length == interests.length,
    s"expected ${interests.length} embedding rows but got ${sizesArray.length}")
  assert(
    sizesArray.forall(_ > 0),
    s"found empty embeddings; sizes were: ${sizesArray.mkString(", ")}")
}

}

0 comments on commit f398f17

Please sign in to comment.