fix predicates push down error when query app_logs with line_num
naive-zhang committed Dec 21, 2024
1 parent 5f6597a commit f9f1fb2
Showing 9 changed files with 84 additions and 30 deletions.
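In short, the change narrows filter push-down to the columns each scan can actually honor, carries the pushed filters into the input partitions, and re-applies them in the partition readers; predicates on any other column (such as line_num) are handed back to Spark. The snippet below is only a sketch of the query shape the fix targets, lifted from the new YarnLogQuerySuite test further down (catalog and table names as used there):

// Query shape exercised by the new YarnLogQuerySuite test below; per the commit
// title, a predicate on a non-pushable column such as line_num previously hit
// the push-down path and errored.
spark.sql("USE yarn")
val rows = spark.sql(
  "select * from yarn.default.app_logs where line_num = 10 limit 10").collect()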
YarnAppPartitionReader.scala
@@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.sources.{EqualTo, In}
import org.apache.spark.sql.sources.{EqualTo, Filter, In}
import org.apache.spark.unsafe.types.UTF8String

class YarnAppPartitionReader(yarnAppPartition: YarnAppPartition)
@@ -98,26 +98,41 @@ class YarnAppPartitionReader(yarnAppPartition: YarnAppPartition)
case _ => yarnClient.getApplications()
}.get
}

val appSeq = applicationReports.asScala.map(app => {
YarnApplication(
id = app.getApplicationId.toString,
appType = app.getApplicationType,
user = app.getUser,
name = app.getName,
state = app.getYarnApplicationState.name,
queue = app.getQueue,
attemptId = app.getCurrentApplicationAttemptId.toString,
submitTime = app.getSubmitTime,
launchTime = app.getLaunchTime,
startTime = app.getStartTime,
finishTime = app.getFinishTime,
trackingUrl = app.getTrackingUrl,
originalTrackingUrl = app.getOriginalTrackingUrl)
})
val appSeq = applicationReports.asScala.filter(app =>
yarnAppPartition.filters
.forall(filter => maybeFilter(app, filter)))
.map(app => {
YarnApplication(
id = app.getApplicationId.toString,
appType = app.getApplicationType,
user = app.getUser,
name = app.getName,
state = app.getYarnApplicationState.name,
queue = app.getQueue,
attemptId = app.getCurrentApplicationAttemptId.toString,
submitTime = app.getSubmitTime,
launchTime = app.getLaunchTime,
startTime = app.getStartTime,
finishTime = app.getFinishTime,
trackingUrl = app.getTrackingUrl,
originalTrackingUrl = app.getOriginalTrackingUrl)
})
yarnClient.close()
appSeq
}

private def maybeFilter(app: ApplicationReport, filter: Filter): Boolean = {
filter match {
case EqualTo("id", appId: String) => app.getApplicationId.toString eq appId
case EqualTo("state", appState: String) => app.getYarnApplicationState.name() eq appState
case EqualTo("type", appType: String) => app.getApplicationType eq appType
case In("state", states) => states.map(x => x.toString)
.contains(app.getYarnApplicationState.name())
case In("type", types) => types.map(x => x.toString)
.contains(app.getApplicationType)
case _ => false
}
}
}

// Helper class to represent app
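Pushed filters reach the reader as org.apache.spark.sql.sources.Filter values and are combined conjunctively, hence the forall above. Below is a self-contained illustration of that re-application using plain strings instead of YARN ApplicationReports; the attribute names mirror maybeFilter, the values are made up:

import org.apache.spark.sql.sources.{EqualTo, Filter, In}

// Mirrors the shape of maybeFilter above, but over plain values so it runs
// without a YARN client; filters this function does not understand reject the row.
def matches(state: String, appType: String)(filter: Filter): Boolean = filter match {
  case EqualTo("state", s: String) => state == s
  case In("type", types) => types.map(_.toString).contains(appType)
  case _ => false
}

val pushed: Array[Filter] = Array(EqualTo("state", "RUNNING"), In("type", Array("TYPE_1", "TYPE_2")))
pushed.forall(matches("RUNNING", "TYPE_1"))  // true  -> row kept
pushed.forall(matches("FINISHED", "TYPE_1")) // false -> row dropped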
YarnAppScanBuilder.scala
@@ -18,6 +18,7 @@
package org.apache.kyuubi.spark.connector.yarn

import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.sources.{EqualTo, Filter, In}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -27,4 +28,18 @@ case class YarnAppScanBuilder(options: CaseInsensitiveStringMap, schema: StructT
override def build(): Scan = {
YarnAppScan(options, schema, pushed)
}

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supportedFilter, unsupportedFilter) = filters.partition {
      case EqualTo("id", _) => true
      case EqualTo("state", _) => true
      case EqualTo("type", _) => true
      case In("state", _) => true
      case In("type", _) => true
      case _ => false
    }
    pushed = supportedFilter
    unsupportedFilter
  }
}
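Spark only hands translatable predicates to pushFilters, and whatever the method returns is still evaluated by Spark after the scan; filters the builder keeps must therefore be honored by the reader itself, which is what maybeFilter above does. A hypothetical query against the applications table (the name yarn.default.apps is assumed here; only app_logs appears in the tests below) would split as follows:

// Hypothetical usage, table name assumed: the state predicate is claimed by
// pushFilters and re-checked in YarnAppPartitionReader.maybeFilter, while the
// LIKE predicate is returned to Spark and applied after the scan.
spark.sql("USE yarn")
spark.sql(
  "select id, state, `type`, user from yarn.default.apps " +
    "where state = 'RUNNING' and name like 'Test%'").show()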
YarnLogPartition.scala
@@ -18,9 +18,11 @@
package org.apache.kyuubi.spark.connector.yarn

import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.sources.Filter

case class YarnLogPartition(
hadoopConfMap: Map[String, String],
logPath: String,
remoteAppLogDir: String)
remoteAppLogDir: String,
filters: Array[Filter])
extends InputPartition
YarnLogPartitionReader.scala
@@ -96,14 +96,13 @@ class YarnLogPartitionReader(yarnLogPartition: YarnLogPartition)
s"${containerHost}_${containerSuffix}",
containerHost,
lineNumber,
path.getName,
path.toUri.getPath,
line)
}
logEntries
} finally {
IOUtils.closeStream(inputStream)
reader.close()
fs.close()
}
case _ => Seq.empty
}
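Note that only the stream and reader are closed above; the sketch below gives the assumed reasoning for dropping fs.close(), based on Hadoop's standard FileSystem caching behavior (not stated in the commit itself):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

// FileSystem.get returns a JVM-wide cached instance per scheme/authority/user,
// so closing it in one partition reader would also close it under every other
// reader sharing the cache; FileSystem.newInstance is the caller-owned variant
// that is safe to close.
val conf = new Configuration()
val sharedFs = FileSystem.get(conf)           // cached, shared - do not close per reader
val privateFs = FileSystem.newInstance(conf)  // private copy - safe to close
privateFs.close()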
YarnLogScan.scala
@@ -89,9 +89,13 @@ case class YarnLogScan(
case pushed if pushed.isEmpty => listFiles(remoteAppLogDir)
case pushed => pushed.collectFirst {
case EqualTo("app_id", appId: String) =>
listFiles(s"${remoteAppLogDir}/*/*/*/${appId}")
listFiles(s"${remoteAppLogDir}/*/*/*/${appId}") ++
// compatible for hadoop2
listFiles(s"${remoteAppLogDir}/*/*/${appId}")
case EqualTo("container_id", containerId: String) =>
listFiles(s"${remoteAppLogDir}/*/*/*/*/${containerId}")
listFiles(s"${remoteAppLogDir}/*/*/*/*/${containerId}") ++
// compatible for hadoop2
listFiles(s"${remoteAppLogDir}/*/*/*/${containerId}")
case EqualTo("user", user: String) => listFiles(s"${remoteAppLogDir}/${user}")
case _ => listFiles(remoteAppLogDir)
}.get
@@ -101,7 +105,7 @@
override def planInputPartitions(): Array[InputPartition] = {
// get file nums and construct nums inputPartition
tryPushDownPredicates().map(fileStatus => {
YarnLogPartition(hadoopConfMap, fileStatus.getPath.toString, remoteAppLogDir)
YarnLogPartition(hadoopConfMap, fileStatus.getPath.toString, remoteAppLogDir, filters)
}).toArray
}

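Each pushed predicate above expands into two glob patterns because Hadoop 2 and Hadoop 3 lay out aggregated logs under the remote app-log dir with a different number of path segments. listFiles itself is outside this hunk, so the helper below is only an assumed shape showing how such a pattern can be expanded with Hadoop's FileSystem.globStatus:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

// Assumed helper shape, not the connector's actual listFiles.
def listFilesSketch(pattern: String, conf: Configuration): Seq[FileStatus] = {
  val glob = new Path(pattern)
  val fs = glob.getFileSystem(conf)
  // globStatus returns null for a non-matching, non-glob path, so guard it
  Option(fs.globStatus(glob)).map(_.toSeq).getOrElse(Seq.empty)
}

// e.g. listFilesSketch(s"$remoteAppLogDir/*/*/*/$appId", conf) ++
//      listFilesSketch(s"$remoteAppLogDir/*/*/$appId", conf)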
YarnLogScanBuilder.scala
@@ -18,7 +18,7 @@
package org.apache.kyuubi.spark.connector.yarn

import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.{EqualTo, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -31,7 +31,17 @@ case class YarnLogScanBuilder(options: CaseInsensitiveStringMap, schema: StructT

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
val (supportedFilter, unsupportedFilter) = filters.partition {
case _: org.apache.spark.sql.sources.EqualTo => true
      case EqualTo("app_id", _) => true
      case EqualTo("container_id", _) => true
      case EqualTo("user", _) => true
case _ => false
}
pushed = supportedFilter
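Whatever pushFilters leaves in the returned array is evaluated by Spark on the rows the scan emits, which is how the line_num and host predicates in the tests below are handled after this change. A standalone sketch of the split (not the connector's class; the filter values are made up):

import org.apache.spark.sql.sources.{EqualTo, Filter}

// Same supported/unsupported split as the builder above, as a plain function.
def split(filters: Array[Filter]): (Array[Filter], Array[Filter]) =
  filters.partition {
    case EqualTo("app_id" | "container_id" | "user", _) => true
    case _ => false
  }

val (pushedF, residual) =
  split(Array(EqualTo("app_id", "application_1_0001"), EqualTo("line_num", 10)))
// pushedF  -> Array(EqualTo(app_id,application_1_0001))  (drives the log-dir glob)
// residual -> Array(EqualTo(line_num,10))                 (evaluated by Spark)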
WithKyuubiServerAndYarnMiniCluster.scala
@@ -20,6 +20,8 @@ package org.apache.kyuubi.spark.connector.yarn
import java.io.{File, FileWriter}
import java.util.Collections

import scala.util.Random

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.records.{ApplicationSubmissionContext, ContainerLaunchContext, Resource, YarnApplicationState}
import org.apache.hadoop.yarn.client.api.YarnClient
@@ -34,6 +36,8 @@ import org.apache.kyuubi.util.JavaUtils

trait WithKyuubiServerAndYarnMiniCluster extends KyuubiFunSuite with WithKyuubiServer {

private val taskTypeSet: Set[String] = Set("TYPE_1", "TYPE_2", "TYPE_3")

override protected val conf: KyuubiConf = new KyuubiConf(false)

val kyuubiHome: String = JavaUtils.getCodeSourceLocation(getClass).split("extensions").head
@@ -154,6 +158,8 @@ trait WithKyuubiServerAndYarnMiniCluster extends KyuubiFunSuite with WithKyuubiS
.getApplicationSubmissionContext.getApplicationId
appContext.setApplicationId(applicationId)
appContext.setApplicationName("TestApp")
// pick an application type at random
appContext.setApplicationType(taskTypeSet.toSeq(Random.nextInt(taskTypeSet.size)))

// Set up container launch context (e.g., commands to execute)
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
YarnAppQuerySuite.scala
@@ -190,5 +190,4 @@ class YarnAppQuerySuite extends SparkYarnConnectorWithYarn {
yarnClient.close()
}
}

}
YarnLogQuerySuite.scala
@@ -42,7 +42,9 @@ class YarnLogQuerySuite extends SparkYarnConnectorWithYarn {
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
spark.sql("USE yarn")
val cnt = spark.sql(
"select count(1) from yarn.default.app_logs where host='localhost'").collect().head.getLong(
"select count(1) from yarn.default.app_logs " +
"where (host='localhost' or host like '%host') and " +
"app_id like '%application%'").collect().head.getLong(
0)
assert(cnt > 0)
}
@@ -66,8 +68,10 @@ class YarnLogQuerySuite extends SparkYarnConnectorWithYarn {
})
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
spark.sql("USE yarn")
val host = spark.sql(
"select * from yarn.default.app_logs limit 10").collect().head.getString(2)
val rows = spark.sql(
"select * from yarn.default.app_logs where line_num = 10" +
" limit 10").collect()
val host = rows.head.getString(2)
assert(host == "localhost")
}
}
