From ded9d769a9f3f290d5b2ef80268487e562e64069 Mon Sep 17 00:00:00 2001
From: sychen <sychen@ctrip.com>
Date: Tue, 23 Jul 2024 20:56:18 -0700
Subject: [PATCH] ORC-1741: Respect decimal reader isRepeating flag

### What changes were proposed in this pull request?
For decimal types, when `isRepeating` is already false, do not try to change it.
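
For context, here is a minimal, self-contained sketch of the pattern this patch changes. The class and values below are illustrative stand-ins, not the actual `TreeReaderFactory`/`DecimalColumnVector` code: previously `nextVector` forced `result.isRepeating = true` before the per-row loop and relied on `setIsRepeatingIfNeeded` to clear it again, which overrode an `isRepeating = false` set by the caller; with this change the incoming flag is left as-is.

```java
import java.util.Arrays;

/** Illustration only: stand-ins for DecimalColumnVector / TreeReaderFactory. */
public class IsRepeatingSketch {
  static final long[] vector = new long[4];
  static boolean isRepeating = false; // the caller already marked the column as non-repeating

  // Mirrors the intent of TreeReaderFactory#setIsRepeatingIfNeeded: once a value
  // differs from the first one, the column can no longer be considered repeating.
  static void setIsRepeatingIfNeeded(int r) {
    if (isRepeating && r > 0 && vector[r] != vector[0]) {
      isRepeating = false;
    }
  }

  public static void main(String[] args) {
    // Before this patch the decimal readers did "isRepeating = true;" right here,
    // silently overriding the caller's flag. With the patch, an incoming "false"
    // stays false even when every decoded value happens to be identical.
    for (int r = 0; r < vector.length; ++r) {
      vector[r] = 42L;
      setIsRepeatingIfNeeded(r);
    }
    System.out.println("isRepeating = " + isRepeating + ", vector = " + Arrays.toString(vector));
  }
}
```

With the old code the flag would come back as `true` in this all-identical case; the new unit test below asserts it stays `false`.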

### Why are the changes needed?
https://github.com/apache/hive/pull/5218#discussion_r1647367003

[ORC-1266](https://issues.apache.org/jira/browse/ORC-1266): DecimalColumnVector resets the isRepeating flag in the nextVector method

### How was this patch tested?
Added a unit test (`TestRecordReaderImpl#testDecimalIsRepeatingFlag`).

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #1960 from cxzl25/decimal_isRepeating.

Authored-by: sychen <sychen@ctrip.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit e818d56f06610d8c6dd41304d7a327062f3a0cd8)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
---
 .../apache/orc/impl/TreeReaderFactory.java    |  8 ---
 .../impl/TestConvertTreeReaderFactory.java    |  8 +--
 .../apache/orc/impl/TestRecordReaderImpl.java | 68 +++++++++++++++++++
 3 files changed, 72 insertions(+), 12 deletions(-)

diff --git a/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java b/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
index 2a2adf50d7..418b9c9561 100644
--- a/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
+++ b/java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java
@@ -1551,7 +1551,6 @@ private void nextVector(DecimalColumnVector result,
       HiveDecimalWritable[] vector = result.vector;
       HiveDecimalWritable decWritable;
       if (result.noNulls) {
-        result.isRepeating = true;
         for (int r = 0; r < batchSize; ++r) {
           decWritable = vector[r];
           if (!decWritable.serializationUtilsRead(
@@ -1563,7 +1562,6 @@ private void nextVector(DecimalColumnVector result,
           setIsRepeatingIfNeeded(result, r);
         }
       } else if (!result.isRepeating || !result.isNull[0]) {
-        result.isRepeating = true;
         for (int r = 0; r < batchSize; ++r) {
           if (!result.isNull[r]) {
             decWritable = vector[r];
@@ -1595,7 +1593,6 @@ private void nextVector(DecimalColumnVector result,
       HiveDecimalWritable[] vector = result.vector;
       HiveDecimalWritable decWritable;
       if (result.noNulls) {
-        result.isRepeating = true;
         int previousIdx = 0;
         for (int r = 0; r != filterContext.getSelectedSize(); ++r) {
           int idx = filterContext.getSelected()[r];
@@ -1614,7 +1611,6 @@ private void nextVector(DecimalColumnVector result,
         }
         skipStreamRows(batchSize - previousIdx);
       } else if (!result.isRepeating || !result.isNull[0]) {
-        result.isRepeating = true;
         int previousIdx = 0;
         for (int r = 0; r != filterContext.getSelectedSize(); ++r) {
           int idx = filterContext.getSelected()[r];
@@ -1651,14 +1647,12 @@ private void nextVector(Decimal64ColumnVector result,
       // read the scales
       scaleReader.nextVector(result, scratchScaleVector, batchSize);
       if (result.noNulls) {
-        result.isRepeating = true;
         for (int r = 0; r < batchSize; ++r) {
           final long scaleFactor = powerOfTenTable[scale - scratchScaleVector[r]];
           result.vector[r] = SerializationUtils.readVslong(valueStream) * scaleFactor;
           setIsRepeatingIfNeeded(result, r);
         }
       } else if (!result.isRepeating || !result.isNull[0]) {
-        result.isRepeating = true;
         for (int r = 0; r < batchSize; ++r) {
           if (!result.isNull[r]) {
             final long scaleFactor = powerOfTenTable[scale - scratchScaleVector[r]];
@@ -1686,7 +1680,6 @@ private void nextVector(Decimal64ColumnVector result,
       // Read all the scales
       scaleReader.nextVector(result, scratchScaleVector, batchSize);
       if (result.noNulls) {
-        result.isRepeating = true;
         int previousIdx = 0;
         for (int r = 0; r != filterContext.getSelectedSize(); r++) {
           int idx = filterContext.getSelected()[r];
@@ -1702,7 +1695,6 @@ private void nextVector(Decimal64ColumnVector result,
         }
         skipStreamRows(batchSize - previousIdx);
       } else if (!result.isRepeating || !result.isNull[0]) {
-        result.isRepeating = true;
         int previousIdx = 0;
         for (int r = 0; r != filterContext.getSelectedSize(); r++) {
           int idx = filterContext.getSelected()[r];
diff --git a/java/core/src/test/org/apache/orc/impl/TestConvertTreeReaderFactory.java b/java/core/src/test/org/apache/orc/impl/TestConvertTreeReaderFactory.java
index a90a285a65..860b18aa7e 100644
--- a/java/core/src/test/org/apache/orc/impl/TestConvertTreeReaderFactory.java
+++ b/java/core/src/test/org/apache/orc/impl/TestConvertTreeReaderFactory.java
@@ -707,7 +707,7 @@ private void readDecimalInNullStripe(String typeString, Class<?> expectedColumnT
     assertTrue(batch.cols[0].isRepeating);
     StringBuilder sb = new StringBuilder();
     batch.cols[0].stringifyValue(sb, 1023);
-    assertEquals(sb.toString(), expectedResult[0]);
+    assertEquals(expectedResult[0], sb.toString());
 
     rows.nextBatch(batch);
     assertEquals(1024, batch.size);
@@ -717,17 +717,17 @@ private void readDecimalInNullStripe(String typeString, Class<?> expectedColumnT
     assertFalse(batch.cols[0].isRepeating);
     StringBuilder sb2 = new StringBuilder();
     batch.cols[0].stringifyValue(sb2, 1023);
-    assertEquals(sb2.toString(), expectedResult[1]);
+    assertEquals(expectedResult[1], sb2.toString());
 
     rows.nextBatch(batch);
     assertEquals(1024, batch.size);
     assertEquals(expected, options.toString());
     assertEquals(batch.cols.length, 1);
     assertEquals(batch.cols[0].getClass(), expectedColumnType);
-    assertTrue(batch.cols[0].isRepeating);
+    assertFalse(batch.cols[0].isRepeating);
     StringBuilder sb3 = new StringBuilder();
     batch.cols[0].stringifyValue(sb3, 1023);
-    assertEquals(sb3.toString(), expectedResult[2]);
+    assertEquals(expectedResult[2], sb3.toString());
   }
 
   private void testDecimalConvertToLongInNullStripe() throws Exception {
diff --git a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
index f0124715b8..378f0fcdad 100644
--- a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
+++ b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java
@@ -28,6 +28,7 @@
 import org.apache.hadoop.hive.common.io.DiskRangeList;
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
@@ -2732,4 +2733,71 @@ public void testHadoopVectoredIO() throws Exception {
 
     verify(spyFSDataInputStream, atLeastOnce()).readVectored(any(), any());
   }
+
+  @Test
+  public void testDecimalIsRepeatingFlag() throws IOException {
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(conf);
+    Path testFilePath = new Path(workDir, "testDecimalIsRepeatingFlag.orc");
+    fs.delete(testFilePath, true);
+
+    Configuration decimalConf = new Configuration(conf);
+    decimalConf.set(OrcConf.STRIPE_ROW_COUNT.getAttribute(), "1024");
+    decimalConf.set(OrcConf.ROWS_BETWEEN_CHECKS.getAttribute(), "1");
+    String typeStr = "decimal(20,10)";
+    TypeDescription schema = TypeDescription.fromString("struct<col1:" + typeStr + ">");
+    Writer w = OrcFile.createWriter(testFilePath, OrcFile.writerOptions(decimalConf).setSchema(schema));
+
+    VectorizedRowBatch b = schema.createRowBatch();
+    DecimalColumnVector f1 = (DecimalColumnVector) b.cols[0];
+    for (int i = 0; i < 1024; i++) {
+      f1.set(i, HiveDecimal.create("-119.4594594595"));
+    }
+    b.size = 1024;
+    w.addRowBatch(b);
+
+    b.reset();
+    for (int i = 0; i < 1024; i++) {
+      f1.set(i, HiveDecimal.create("9318.4351351351"));
+    }
+    b.size = 1024;
+    w.addRowBatch(b);
+
+    b.reset();
+    for (int i = 0; i < 1024; i++) {
+      f1.set(i, HiveDecimal.create("-4298.1513513514"));
+    }
+    b.size = 1024;
+    w.addRowBatch(b);
+
+    b.reset();
+    w.close();
+
+    Reader.Options options = new Reader.Options();
+    try (Reader reader = OrcFile.createReader(testFilePath, OrcFile.readerOptions(conf));
+         RecordReader rows = reader.rows(options)) {
+      VectorizedRowBatch batch = schema.createRowBatch();
+
+      rows.nextBatch(batch);
+      assertEquals(1024, batch.size);
+      assertFalse(batch.cols[0].isRepeating);
+      for (HiveDecimalWritable hiveDecimalWritable : ((DecimalColumnVector) batch.cols[0]).vector) {
+        assertEquals(HiveDecimal.create("-119.4594594595"), hiveDecimalWritable.getHiveDecimal());
+      }
+
+      rows.nextBatch(batch);
+      assertEquals(1024, batch.size);
+      assertFalse(batch.cols[0].isRepeating);
+      for (HiveDecimalWritable hiveDecimalWritable : ((DecimalColumnVector) batch.cols[0]).vector) {
+        assertEquals(HiveDecimal.create("9318.4351351351"), hiveDecimalWritable.getHiveDecimal());
+      }
+
+      rows.nextBatch(batch);
+      assertEquals(1024, batch.size);
+      assertFalse(batch.cols[0].isRepeating);
+      for (HiveDecimalWritable hiveDecimalWritable : ((DecimalColumnVector) batch.cols[0]).vector) {
+        assertEquals(HiveDecimal.create("-4298.1513513514"), hiveDecimalWritable.getHiveDecimal());
+      }
+    }
+  }
 }