diff --git a/src/main/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModel.scala b/src/main/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModel.scala index 66ea140..cb6af9f 100644 --- a/src/main/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModel.scala +++ b/src/main/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModel.scala @@ -40,34 +40,47 @@ sealed trait DruidQueryCost extends Ordered[DruidQueryCost] { } } -case class CostDetails( - costInput : CostInput[_ <: QuerySpec], - histProcessingCostPerRow : Double, - queryOutputSizeEstimate : Long, - segmentOutputSizeEstimate : Long, - numSegmentsProcessed : Long, - numSparkCores : Long, - numHistoricalThreads : Long, - parallelismPerWave : Long, - minCost : DruidQueryCost, - allCosts : List[DruidQueryCost] - ) { +case class CostDetails(costInput : CostInput, + histMergeCostPerRow: Double, + brokerMergeCostPerRow: Double, + histInputProcessingCostPerRow: Double, + histOutputProcessingCostPerRow: Double, + queryInputSizeEstimate: Long, + segmentInputSizeEstimate : Long, + queryOutputSizeEstimate: Long, + segmentOutputSizeEstimate: Long, + numSparkCores: Long, + numHistoricalThreads: Long, + parallelismPerWave: Long, + minCost : DruidQueryCost, + allCosts : List[DruidQueryCost] + ) { override def toString : String = { val header = s"""Druid Query Cost Model:: - | |${costInput}histProcessingCost = $histProcessingCostPerRow - | |queryOutputSizeEstimate = $queryOutputSizeEstimate - | |segmentOutputSizeEstimate = $segmentOutputSizeEstimate - | |numSegmentsProcessed = $numSegmentsProcessed - | |numSparkCores = $numSparkCores - | |numHistoricalThreads = $numHistoricalThreads - | |parallelismPerWave = $parallelismPerWave + ${costInput} +Cost Per Row( + histMergeCost=$histMergeCostPerRow, + brokerMergeCost=$brokerMergeCostPerRow, + histInputProcessingCost=$histInputProcessingCostPerRow, + histOutputProcessingCost=$histOutputProcessingCostPerRow + ) +numSegmentsProcessed = ${costInput.numSegmentsProcessed} +Size Estimates( + queryInputSize=$queryInputSizeEstimate, + segmentInputSize=$segmentInputSizeEstimate, + queryOutputSize=$queryOutputSizeEstimate, + segmentOutputSize=$segmentOutputSizeEstimate + ) +Environment( + numSparkCores = $numSparkCores + numHistoricalThreads = $numHistoricalThreads + parallelismPerWave = $parallelismPerWave + ) + |minCost : $minCost | - | |minCost : $minCost - | | - | """.stripMargin - + |""".stripMargin val rows = allCosts.map(CostDetails.costRow(_)).mkString("", "\n", "") @@ -216,8 +229,9 @@ case class DruidQueryMethod( } } -case class CostInput[T <: QuerySpec]( - dimsNDVEstimate : Long, +case class CostInput( + inputEstimate : Long, + outputEstimate : Long, shuffleCostPerRow : Double, histMergeCostPerRowFactor : Double, histSegsPerQueryLimit : Int, @@ -235,13 +249,14 @@ case class CostInput[T <: QuerySpec]( numSparkExecutors : Int, numProcessingThreadsPerHistorical : Long, numHistoricals : Int, - querySpecClass : Class[T] + querySpec : Either[QuerySpec,Class[_ <: QuerySpec]] ) { // scalastyle:off line.size.limit override def toString : String = { s""" - |dimsNDVEstimate = $dimsNDVEstimate + |inputEstimate = $inputEstimate + |outputEstimate = $outputEstimate |shuffleCostPerRow = $shuffleCostPerRow, |histMergeCostPerRowFactor = $histMergeCostPerRowFactor, |histSegsPerQueryLimit = $histSegsPerQueryLimit, @@ -259,191 +274,131 @@ case class CostInput[T <: QuerySpec]( |numSparkExecutors = $numSparkExecutors, |numProcessingThreadsPerHistorical = $numProcessingThreadsPerHistorical, |numHistoricals 
= $numHistoricals, - |querySpecClass = $querySpecClass + |querySpec = $querySpec """.stripMargin } // scalastyle:on } -object DruidQueryCostModel extends Logging { +sealed trait QueryCost extends Logging { + val cI: CostInput - def intervalsMillis(intervals : List[Interval]) : Long = Utils.intervalsMillis(intervals) + import cI._ - def intervalNDVEstimate( - intervalMillis : Long, - totalIntervalMillis : Long, - ndvForIndexEstimate : Long, - queryIntervalRatioScaleFactor : Double - ) : Long = { - val intervalRatio : Double = Math.min(intervalMillis.toDouble/totalIntervalMillis.toDouble, 1.0) - var scaledRatio = Math.min(queryIntervalRatioScaleFactor * intervalRatio * 10.0, 10.0) - if ( scaledRatio < 1.0 ) { - scaledRatio = queryIntervalRatioScaleFactor * intervalRatio - } else { - scaledRatio = Math.log10(scaledRatio) - } - Math.round(ndvForIndexEstimate * scaledRatio) - } + def histMergeCostPerRow: Double - private[druid] def compute[T <: QuerySpec](cI : CostInput[T]) : DruidQueryMethod = { - import cI._ - - val histMergeCostPerRow = shuffleCostPerRow * histMergeCostPerRowFactor - val brokerMergeCostPerRow = shuffleCostPerRow * histMergeCostPerRowFactor - val histProcessingCostPerRow = { - if (classOf[TimeSeriesQuerySpec].isAssignableFrom(querySpecClass)) { - historicalTimeSeriesProcessingCostPerRowFactor - } else { - historicalGByProcessigCostPerRowFactor - } * shuffleCostPerRow - } + def brokerMergeCostPerRow: Double - val queryOutputSizeEstimate = intervalNDVEstimate(queryIntervalMillis, - indexIntervalMillis, - dimsNDVEstimate, - queryIntervalRatioScaleFactor - ) + def histInputProcessingCostPerRow: Double + def histOutputProcessingCostPerRow: Double - val segmentOutputSizeEstimate = querySpecClass match { - case qC if classOf[SelectSpec].isAssignableFrom(qC) => - queryOutputSizeEstimate/numSegmentsProcessed - case _ => intervalNDVEstimate(segIntervalMillis, - indexIntervalMillis, - dimsNDVEstimate, - queryIntervalRatioScaleFactor - ) - } + def queryInputSizeEstimate: Long - val numSparkCores : Long = { - numSparkExecutors * sparkCoresPerExecutor - } + def segmentInputSizeEstimate : Long - val numHistoricalThreads = numHistoricals * numProcessingThreadsPerHistorical + def queryOutputSizeEstimate: Long - val parallelismPerWave = Math.min(numHistoricalThreads, numSparkCores) + def segmentOutputSizeEstimate: Long - def estimateNumWaves(numSegsPerQuery : Long, - parallelism : Long = parallelismPerWave) : Long = { - var d = numSegmentsProcessed.toDouble / numSegsPerQuery.toDouble - d = d/ parallelism - Math.round(d + 0.5) - } + def numSparkCores: Long - def brokerQueryCost : DruidQueryCost = { - val numWaves: Long = estimateNumWaves(1, numHistoricalThreads) - val processingCostPerHist : Double = - segmentOutputSizeEstimate * histProcessingCostPerRow - - val numMerges = 2 * (numSegmentsProcessed - 1) - - val brokertMergeCost : Double = - (Math.max(numMerges.toDouble / numProcessingThreadsPerHistorical.toDouble,1.0)) * - segmentOutputSizeEstimate * brokerMergeCostPerRow - val segmentOutputTransportCost = queryOutputSizeEstimate * - (druidOutputTransportCostPerRowFactor * shuffleCostPerRow) - val queryCost: Double = { - /* - * SearchQuerySpecs cannot be run against broker. 
- */ - if ( classOf[SelectSpec].isAssignableFrom(cI.querySpecClass)) { - Double.MaxValue - } else { - numWaves * processingCostPerHist + segmentOutputTransportCost + brokertMergeCost - } - } + def numHistoricalThreads: Long - BrokerQueryCost( - numWaves, - processingCostPerHist, - brokertMergeCost, - segmentOutputTransportCost, - queryCost - ) - } + def parallelismPerWave: Long - def histQueryCost(numSegsPerQuery : Long) : DruidQueryCost = { + def estimateNumWaves(numSegsPerQuery: Long, + parallelism: Long = parallelismPerWave): Long = { + var d = numSegmentsProcessed.toDouble / numSegsPerQuery.toDouble + d = d / parallelism + Math.round(d + 0.5) + } - val numWaves = estimateNumWaves(numSegsPerQuery) - val estimateOutputSizePerHist = intervalNDVEstimate( - segIntervalMillis * numSegsPerQuery, - indexIntervalMillis, - dimsNDVEstimate, - queryIntervalRatioScaleFactor - ) + def shuffleCostPerWave: Double - val processingCostPerHist : Double = - numSegsPerQuery * segmentOutputSizeEstimate * histProcessingCostPerRow + def sparkSchedulingCostPerWave: Double - val histMergeCost : Double = - (numSegsPerQuery - 1) * segmentOutputSizeEstimate * histMergeCostPerRow + def sparkAggCostPerWave: Double - val segmentOutputTransportCost = estimateOutputSizePerHist * - (druidOutputTransportCostPerRowFactor * shuffleCostPerRow) + def intervalRatio(intervalMillis: Long): Double = + Math.min(intervalMillis.toDouble / indexIntervalMillis.toDouble, 1.0) - val shuffleCost = numWaves * segmentOutputSizeEstimate * shuffleCostPerRow + def scaledRatio(intervalMillis: Long): Double = { + val s = Math.min(queryIntervalRatioScaleFactor * intervalRatio(intervalMillis) * 10.0, 10.0) - val sparkSchedulingCost = - numWaves * Math.min(parallelismPerWave, numSegmentsProcessed) * sparkSchedulingCostPerTask + if (s < 1.0) { + queryIntervalRatioScaleFactor * intervalRatio(intervalMillis) + } else { + Math.log10(s) + } + } - val sparkAggCost = numWaves * segmentOutputSizeEstimate * - (sparkAggregationCostPerRowFactor * shuffleCostPerRow) + def estimateInput(intervalMillis : Long) : Long - val costPerHistoricalWave = processingCostPerHist + histMergeCost + segmentOutputTransportCost + def estimateOutput(intervalMillis : Long) : Long - val druidStageCost = numWaves * costPerHistoricalWave + def brokerQueryCost : DruidQueryCost - val queryCost = druidStageCost + shuffleCost + sparkSchedulingCost + sparkAggCost + def histQueryCost(numSegsPerQuery : Long) : DruidQueryCost - HistoricalQueryCost( - numWaves, - numSegsPerQuery, - estimateOutputSizePerHist, - processingCostPerHist, - histMergeCost, - segmentOutputTransportCost, - shuffleCost, - sparkSchedulingCost, - sparkAggCost, - costPerHistoricalWave, - druidStageCost, - queryCost - ) - } + def druidQueryMethod: DruidQueryMethod = { log.info( s"""Druid Query Cost Model Input: - |$cI - |histProcessingCost = $histProcessingCostPerRow - |queryOutputSizeEstimate = $queryOutputSizeEstimate - |segmentOutputSizeEstimate = $segmentOutputSizeEstimate - |numSegmentsProcessed = $numSegmentsProcessed - |numSparkCores = $numSparkCores - |numHistoricalThreads = $numHistoricalThreads - |parallelismPerWave = $parallelismPerWave - | + |$cI + |Cost Per Row( + | histMergeCost=$histMergeCostPerRow, + | brokerMergeCost=$brokerMergeCostPerRow, + | histInputProcessingCost=$histInputProcessingCostPerRow, + | histOutputProcessingCost=$histOutputProcessingCostPerRow + | ) + |numSegmentsProcessed = $numSegmentsProcessed + |Size Estimates( + | queryInputSize=$queryInputSizeEstimate, + | 
segmentInputSize=$segmentInputSizeEstimate,
+           |      queryOutputSize=$queryOutputSizeEstimate,
+           |      segmentOutputSize=$segmentOutputSizeEstimate
+           |  )
+           |Environment(
+           |   numSparkCores = $numSparkCores
+           |   numHistoricalThreads = $numHistoricalThreads
+           |   parallelismPerWave = $parallelismPerWave
+           |  )
+           |
+         """.stripMargin)
-    val allCosts : ArrayBuffer[DruidQueryCost] = ArrayBuffer()
-    var minCost : DruidQueryCost = brokerQueryCost
+    val allCosts: ArrayBuffer[DruidQueryCost] = ArrayBuffer()
+    var minCost: DruidQueryCost = brokerQueryCost
     allCosts += minCost
     var minNumSegmentsPerQuery = -1

-    (1 to histSegsPerQueryLimit).foreach { (numSegsPerQuery : Int) =>
+    (1 to histSegsPerQueryLimit).foreach { (numSegsPerQuery: Int) =>
       val c = histQueryCost(numSegsPerQuery)
       allCosts += c
-      if ( c < minCost ) {
+      if (c < minCost) {
         minCost = c
         minNumSegmentsPerQuery = numSegsPerQuery
       }
     }

     val costDetails = CostDetails(
       cI,
-      histProcessingCostPerRow,
+      histMergeCostPerRow,
+      brokerMergeCostPerRow,
+      histInputProcessingCostPerRow,
+      histOutputProcessingCostPerRow,
+      queryInputSizeEstimate,
+      segmentInputSizeEstimate,
       queryOutputSizeEstimate,
       segmentOutputSizeEstimate,
-      numSegmentsProcessed,
       numSparkCores,
       numHistoricalThreads,
       parallelismPerWave,
@@ -452,12 +407,276 @@ object DruidQueryCostModel extends Logging {
     )

     log.info(costDetails.toString)
-
     DruidQueryMethod(minCost.isInstanceOf[HistoricalQueryCost],
       minNumSegmentsPerQuery, minCost, costDetails)
   }
-
+
+}
+
+class AggQueryCost(val cI : CostInput) extends QueryCost {
+
+  import cI._
+
+  val histMergeCostPerRow = shuffleCostPerRow * histMergeCostPerRowFactor
+  val brokerMergeCostPerRow = shuffleCostPerRow * histMergeCostPerRowFactor
+
+  val histInputProcessingCostPerRow =
+    historicalTimeSeriesProcessingCostPerRowFactor * shuffleCostPerRow
+
+  val histOutputProcessingCostPerRow =
+    historicalGByProcessigCostPerRowFactor * shuffleCostPerRow
+
+  val queryInputSizeEstimate = estimateInput(queryIntervalMillis)
+
+  val segmentInputSizeEstimate : Long = estimateInput(segIntervalMillis)
+
+  val queryOutputSizeEstimate = estimateOutput(queryIntervalMillis)
+
+  val segmentOutputSizeEstimate : Long = estimateOutput(segIntervalMillis)
+
+  val numSparkCores : Long = {
+    numSparkExecutors * sparkCoresPerExecutor
+  }
+
+  val numHistoricalThreads = numHistoricals * numProcessingThreadsPerHistorical
+
+  val parallelismPerWave = Math.min(numHistoricalThreads, numSparkCores)
+
+  def estimateOutput(intervalMillis : Long) : Long =
+    Math.round(outputEstimate * scaledRatio(intervalMillis))
+
+  def estimateInput(intervalMillis : Long) : Long =
+    Math.round(inputEstimate * intervalRatio(intervalMillis))
+
+
+  def shuffleCostPerWave = segmentOutputSizeEstimate * shuffleCostPerRow
+
+  def sparkSchedulingCostPerWave =
+    Math.min(parallelismPerWave, numSegmentsProcessed) * sparkSchedulingCostPerTask
+
+  def sparkAggCostPerWave = segmentOutputSizeEstimate *
+    (sparkAggregationCostPerRowFactor * shuffleCostPerRow)
+
+  def costPerHistorical(numSegsPerQuery : Int) : Double = {
+    val inputProcessingCostPerHist: Double =
+      numSegsPerQuery * segmentInputSizeEstimate * histInputProcessingCostPerRow
+
+    val outputProcessingCostPerHist: Double =
+      numSegsPerQuery * segmentOutputSizeEstimate * histOutputProcessingCostPerRow
+
+    inputProcessingCostPerHist + outputProcessingCostPerHist
+  }
+
+  def brokerMergeCost : Double = {
+    val numMerges = 2 * (numSegmentsProcessed - 1)
+
+
(Math.max(numMerges.toDouble / numProcessingThreadsPerHistorical.toDouble,1.0)) * + segmentOutputSizeEstimate * brokerMergeCostPerRow + } + + def brokerQueryCost : DruidQueryCost = { + val numWaves: Long = estimateNumWaves(1, numHistoricalThreads) + val processingCostPerHistPerWave : Double = costPerHistorical(1) + + val brokerTransportCostPerWave = segmentOutputSizeEstimate * + (druidOutputTransportCostPerRowFactor * shuffleCostPerRow) + + val histCostPerWave = processingCostPerHistPerWave + brokerTransportCostPerWave + + val mergeCost : Double = brokerMergeCost + + val segmentOutputTransportCost = queryOutputSizeEstimate * + (druidOutputTransportCostPerRowFactor * shuffleCostPerRow) + + + val queryCost: Double = + numWaves * histCostPerWave + segmentOutputTransportCost + mergeCost + + BrokerQueryCost( + numWaves, + numWaves * histCostPerWave, + mergeCost, + segmentOutputTransportCost, + queryCost + ) + } + + def histQueryCost(numSegsPerQuery : Long) : DruidQueryCost = { + + val numWaves = estimateNumWaves(numSegsPerQuery) + val estimateOutputSizePerHist = + estimateOutput(segIntervalMillis * numSegsPerQuery) + + val inputProcessingCostPerHist : Double = + numSegsPerQuery * segmentInputSizeEstimate * histInputProcessingCostPerRow + + val outputProcessingCostPerHist : Double = + numSegsPerQuery * segmentOutputSizeEstimate * histOutputProcessingCostPerRow + + val processingCostPerHist = inputProcessingCostPerHist + outputProcessingCostPerHist + + val histMergeCost : Double = + (numSegsPerQuery - 1) * segmentOutputSizeEstimate * histMergeCostPerRow + + val segmentOutputTransportCost = estimateOutputSizePerHist * + (druidOutputTransportCostPerRowFactor * shuffleCostPerRow) + + val costPerHistoricalWave = processingCostPerHist + histMergeCost + segmentOutputTransportCost + val costPerSparkWave = shuffleCostPerWave + sparkSchedulingCostPerWave + sparkAggCostPerWave + + val druidStageCost = numWaves * costPerHistoricalWave + val sparkStageCost = numWaves * costPerSparkWave + + val queryCost = druidStageCost + sparkStageCost + + HistoricalQueryCost( + numWaves, + numSegsPerQuery, + estimateOutputSizePerHist, + processingCostPerHist, + histMergeCost, + segmentOutputTransportCost, + shuffleCostPerWave, + sparkSchedulingCostPerWave, + sparkAggCostPerWave, + costPerHistoricalWave, + druidStageCost, + queryCost + ) + } + +} + +class GroupByQueryCost(cI : CostInput) extends AggQueryCost(cI) { + +} + +/* + * - OutputSize = InputSize + * - there is no cost for doing GroupBy. + * - Prevent Broker plans. TODO revisit this + * - No Spark-side shuffle. + */ +class SelectQueryCost(i : CostInput) extends GroupByQueryCost(i) { + + override val histMergeCostPerRow = 0.0 + override val brokerMergeCostPerRow = 0.0 + + override val histOutputProcessingCostPerRow = 0.0 + + override def estimateOutput(intervalMillis : Long) = estimateInput(intervalMillis) + + override def shuffleCostPerWave = 0.0 + + override def sparkAggCostPerWave = 0.0 + + override def brokerMergeCost : Double = Double.MaxValue +} + +/* + * - inputSize = 0, because rows are not scanned. + * - Prevent Broker plans. 
+ */ +class SearchQueryCost(i : CostInput) extends GroupByQueryCost(i) { + override val queryInputSizeEstimate : Long = 0 + + override val segmentInputSizeEstimate : Long = 0 + + override def brokerMergeCost : Double = Double.MaxValue + +} + +/* + * - outputSize = 1 + */ +class TimeSeriesQueryCost(i : CostInput) extends GroupByQueryCost(i) { + + override val queryOutputSizeEstimate : Long = 1 + + override val segmentOutputSizeEstimate : Long = 1 +} + +/* + * - queryOutputSize = query TopN threshold + * - segmentOutputSize = queryContext maxTopNThreshold + */ +class TopNQueryCost(i : CostInput) extends GroupByQueryCost(i) { + import cI._ + override val segmentOutputSizeEstimate : Long = cI.querySpec match { + case Left(qS : TopNQuerySpec) if qS.context.flatMap(_.minTopNThreshold).isDefined => + qS.context.flatMap(_.minTopNThreshold).get + case _ => estimateOutput(segIntervalMillis) + } + + override val queryOutputSizeEstimate: Long = cI.querySpec match { + case Left(qS : TopNQuerySpec) => qS.threshold + case _ => estimateOutput(queryIntervalMillis) + } +} + +object DruidQueryCostModel extends Logging { + + def intervalsMillis(intervals : List[Interval]) : Long = Utils.intervalsMillis(intervals) + + def intervalNDVEstimate( + intervalMillis : Long, + totalIntervalMillis : Long, + ndvForIndexEstimate : Long, + queryIntervalRatioScaleFactor : Double + ) : Long = { + val intervalRatio : Double = Math.min(intervalMillis.toDouble/totalIntervalMillis.toDouble, 1.0) + var scaledRatio = Math.min(queryIntervalRatioScaleFactor * intervalRatio * 10.0, 10.0) + if ( scaledRatio < 1.0 ) { + scaledRatio = queryIntervalRatioScaleFactor * intervalRatio + } else { + scaledRatio = Math.log10(scaledRatio) + } + Math.round(ndvForIndexEstimate * scaledRatio) + } + + private[druid] def compute[T <: QuerySpec](cI : CostInput) : DruidQueryMethod = { + val queryCost = cI.querySpec match { + case Left(q : SelectSpec) => new SelectQueryCost(cI) + case Right(cls) if classOf[SelectSpec].isAssignableFrom(cls) => + new SelectQueryCost(cI) + case Left(q : SearchQuerySpec) => new SearchQueryCost(cI) + case Right(cls) if classOf[SearchQuerySpec].isAssignableFrom(cls) => + new SearchQueryCost(cI) + case Left(q : TimeSeriesQuerySpec) => new TimeSeriesQueryCost(cI) + case Right(cls) if classOf[TimeSeriesQuerySpec].isAssignableFrom(cls) => + new TimeSeriesQueryCost(cI) + case Left(q : TopNQuerySpec) => new TopNQueryCost(cI) + case Right(cls) if classOf[TopNQuerySpec].isAssignableFrom(cls) => + new TopNQueryCost(cI) + case Left(q : GroupByQuerySpec) => new GroupByQueryCost(cI) + case Right(cls) if classOf[GroupByQuerySpec].isAssignableFrom(cls) => + new GroupByQueryCost(cI) + case _ => new GroupByQueryCost(cI) + } + queryCost.druidQueryMethod + } + + def estimateInput(qs : QuerySpec, + drInfo : DruidRelationInfo) : Long = { + + def applyFilterSelectivity(f : FilterSpec) : Double = f match { + case SelectorFilterSpec(_, dm, _) => + (1.0/drInfo.druidDS.columns(dm).cardinality) + case ExtractionFilterSpec(_, dm, _, InExtractionFnSpec(_, LookUpMap(_, lm))) + => + (lm.size.toDouble/drInfo.druidDS.columns(dm).cardinality) + case LogicalFilterSpec("and", sfs) => sfs.map(applyFilterSelectivity).product + case LogicalFilterSpec("or", sfs) => sfs.map(applyFilterSelectivity).sum + case NotFilterSpec(_, fs) => 1.0 - applyFilterSelectivity(fs) + case _ => 1.0/3.0 + } + + val selectivity : Double = qs.filter.map(applyFilterSelectivity).getOrElse(1.0) + (drInfo.druidDS.numRows.toDouble * selectivity).toLong + } + + /** * Cardinality = product of 
dimension(ndv) * dimension(selectivity) * where selectivity: @@ -469,54 +688,48 @@ object DruidQueryCostModel extends Logging { * @param drInfo * @return */ - def estimateNDV(qs : QuerySpec, + private def estimateOutputCardinality(qs : QuerySpec, drInfo : DruidRelationInfo) : Long = { - val m : MMap[String, Long] = MHashMap() - def applyFilterSelectivity(f : FilterSpec) : Unit = f match { - case SelectorFilterSpec(_, dm, _) if m.contains(dm) => m(dm) = 1 - case ExtractionFilterSpec(_, dm, _, InExtractionFnSpec(_, LookUpMap(_, lm))) - if m.contains(dm) => m(dm) = Math.min(m(dm), lm.size) - case LogicalFilterSpec("and", sfs) => sfs.foreach(applyFilterSelectivity) - case _ => () - } + val gByColumns : Map[String, DruidColumn] = + qs.dimensionSpecs.map(ds => (ds.dimension, drInfo.druidDS.columns(ds.dimension))).toMap - val timeTicks = drInfo.options.queryGranularity.ndv(drInfo.druidDS.intervals) + var queryNumRows = gByColumns.values.map(_.cardinality).map(_.toDouble).product + queryNumRows = Math.min(queryNumRows, drInfo.druidDS.numRows) - def populateNDVMap(columns : Iterable[DruidColumn]) : Unit = { - columns.foreach { dC => - m(dC.name) = dC match { - case t: DruidTimeDimension => timeTicks - case dC => dC.cardinality - } - } - } - - val columns = qs match { - case s : SelectSpec => { - drInfo.druidDS.columns.values + def applyFilterSelectivity(f : FilterSpec) : Unit = f match { + case SelectorFilterSpec(_, dm, _) if gByColumns.contains(dm) => { + queryNumRows = queryNumRows * (1.0/gByColumns(dm).cardinality) } - case a => { - a.dimensionSpecs.map(ds => drInfo.druidDS.columns(ds.dimension)) + case ExtractionFilterSpec(_, dm, _, InExtractionFnSpec(_, LookUpMap(_, lm))) + if gByColumns.contains(dm) => { + queryNumRows = queryNumRows * (lm.size.toDouble/gByColumns(dm).cardinality) } + case LogicalFilterSpec("and", sfs) => sfs.foreach(applyFilterSelectivity) + case _ => () } - populateNDVMap(columns) - if ( qs.filter.isDefined) { applyFilterSelectivity(qs.filter.get) } - m.values.product + queryNumRows.toLong + } + + def estimateOutput(qs : QuerySpec, + drInfo : DruidRelationInfo) : Long = qs match { + case select : SelectSpec => estimateInput(select, drInfo) + case _ => estimateOutputCardinality(qs, drInfo) } - private[druid] def computeMethod[T <: QuerySpec]( + private[druid] def computeMethod( sqlContext : SQLContext, druidDSIntervals : List[Interval], druidDSFullName : DruidRelationName, druidDSOptions : DruidRelationOptions, - dimsNDVEstimate : Long, + inputEstimate : Long, + outputEstimate : Long, queryIntervalMillis : Long, - querySpecClass : Class[_ <: T] + querySpec : QuerySpec ) : DruidQueryMethod = { val shuffleCostPerRow: Double = 1.0 @@ -567,18 +780,6 @@ object DruidQueryCostModel extends Logging { (totalSegments / totalIntervals) } - val queryOutputSizeEstimate = intervalNDVEstimate(queryIntervalMillis, - indexIntervalMillis, - dimsNDVEstimate, - queryIntervalRatioScaleFactor - ) - - val segmentOutputSizeEstimate = intervalNDVEstimate(segIntervalMillis, - indexIntervalMillis, - dimsNDVEstimate, - queryIntervalRatioScaleFactor - ) - val numSegmentsProcessed: Long = Math.round( Math.max(Math.round(queryIntervalMillis / segIntervalMillis), 1L ) * avgNumSegmentsPerSegInterval @@ -603,7 +804,8 @@ object DruidQueryCostModel extends Logging { compute( CostInput( - dimsNDVEstimate, + inputEstimate, + outputEstimate, shuffleCostPerRow, histMergeFactor, histSegsPerQueryLimit, @@ -621,14 +823,14 @@ object DruidQueryCostModel extends Logging { numSparkExecutors, 
numProcessingThreadsPerHistorical, numHistoricals, - querySpecClass + Left(querySpec) ) ) } - def computeMethod[T <: QuerySpec](sqlContext : SQLContext, + def computeMethod(sqlContext : SQLContext, drInfo : DruidRelationInfo, - querySpec : T + querySpec : QuerySpec ) : DruidQueryMethod = { val queryIntervalMillis : Long = intervalsMillis(querySpec.intervalList.map(Interval.parse(_))) @@ -637,19 +839,21 @@ object DruidQueryCostModel extends Logging { drInfo.druidDS.intervals, drInfo.fullName, drInfo.options, - estimateNDV(querySpec, drInfo), + estimateInput(querySpec, drInfo), + estimateOutput(querySpec, drInfo), queryIntervalMillis, - querySpec.getClass + querySpec ) } - def computeMethod[T <: QuerySpec](sqlContext : SQLContext, + def computeMethod(sqlContext : SQLContext, druidDSIntervals : List[Interval], druidDSFullName : DruidRelationName, druidDSOptions : DruidRelationOptions, - ndvEstimate : Long, - querySpec : T + inputEstimate: Long, + outputEstimate : Long, + querySpec : QuerySpec ) : DruidQueryMethod = { val queryIntervalMillis : Long = intervalsMillis(querySpec.intervalList.map(Interval.parse(_))) @@ -658,9 +862,10 @@ object DruidQueryCostModel extends Logging { druidDSIntervals, druidDSFullName, druidDSOptions, - ndvEstimate, + inputEstimate, + outputEstimate, queryIntervalMillis, - querySpec.getClass + querySpec ) } diff --git a/src/main/scala/org/apache/spark/sql/sources/druid/DruidStrategy.scala b/src/main/scala/org/apache/spark/sql/sources/druid/DruidStrategy.scala index 525383c..f54018c 100644 --- a/src/main/scala/org/apache/spark/sql/sources/druid/DruidStrategy.scala +++ b/src/main/scala/org/apache/spark/sql/sources/druid/DruidStrategy.scala @@ -95,11 +95,11 @@ private[druid] class DruidStrategy(val planner: DruidPlanner) extends Strategy * and not an epoch. 
*/ val replaceTimeReferencedDruidColumns = dqb1.referencedDruidColumns.mapValues { - case dtc@DruidTimeDimension(_, _, sz) => DruidDimension( + case dtc@DruidTimeDimension(_, _, sz, card) => DruidDimension( DruidDataSource.EVENT_TIMESTAMP_KEY_NAME, DruidDataType.String, sz, - dtc.cardinality) + card) case dc => dc } diff --git a/src/main/scala/org/apache/spark/sql/sparklinedata/commands/DruidMetadataCommands.scala b/src/main/scala/org/apache/spark/sql/sparklinedata/commands/DruidMetadataCommands.scala index 0184de9..56a8e0d 100644 --- a/src/main/scala/org/apache/spark/sql/sparklinedata/commands/DruidMetadataCommands.scala +++ b/src/main/scala/org/apache/spark/sql/sparklinedata/commands/DruidMetadataCommands.scala @@ -63,11 +63,13 @@ case class ExplainDruidRewrite(sql: String) extends RunnableCommand { val druidDSIntervals = dR.drDSIntervals val druidDSFullName= dR.drFullName val druidDSOptions = dR.drOptions - val ndvEstimate = dR.ndvEstimate + val inputEstimate = dR.inputEstimate + val outputEstimate = dR.outputEstimate s"""DruidQuery(${System.identityHashCode(dR.dQuery)}) details :: |${DruidQueryCostModel.computeMethod( - sqlContext, druidDSIntervals, druidDSFullName, druidDSOptions, ndvEstimate, dR.dQuery.q) + sqlContext, druidDSIntervals, druidDSFullName, druidDSOptions, + inputEstimate, outputEstimate, dR.dQuery.q) } """.stripMargin.split("\n").map(Row(_)) } diff --git a/src/main/scala/org/sparklinedata/druid/DruidQueryGranularity.scala b/src/main/scala/org/sparklinedata/druid/DruidQueryGranularity.scala index a766e0f..5c0a58f 100644 --- a/src/main/scala/org/sparklinedata/druid/DruidQueryGranularity.scala +++ b/src/main/scala/org/sparklinedata/druid/DruidQueryGranularity.scala @@ -20,7 +20,8 @@ package org.sparklinedata.druid import jodd.datetime.Period import org.joda.time.chrono.ISOChronology import org.joda.time.{DateTime, DateTimeZone, Interval} -import org.json4s.MappingException +import org.json4s.JsonAST.{JField, JObject, JString} +import org.json4s.{CustomSerializer, MappingException} import scala.util.Try import scala.language.postfixOps @@ -105,4 +106,23 @@ case class DurationGranularity( } } +class DruidQueryGranularitySerializer extends CustomSerializer[DruidQueryGranularity](format => { + implicit val f = format + ( + { + case jO@JObject(JField("type", JString("duration")) :: rest) => + jO.extract[DurationGranularity] + case jO@JObject(JField("type", JString("period")) :: rest) => + jO.extract[PeriodGranularity] + case jO@JObject(JField("type", JString("all")) :: rest) => AllGranularity + case jO@JObject(JField("type", JString("none")) :: rest) => NoneGranularity + } + , + { + case x: DruidQueryGranularity => + throw new RuntimeException("DruidQueryGranularity serialization not supported.") + } + ) +}) + diff --git a/src/main/scala/org/sparklinedata/druid/DruidRDD.scala b/src/main/scala/org/sparklinedata/druid/DruidRDD.scala index e2f003d..7f5d489 100644 --- a/src/main/scala/org/sparklinedata/druid/DruidRDD.scala +++ b/src/main/scala/org/sparklinedata/druid/DruidRDD.scala @@ -105,7 +105,8 @@ class DruidRDD(sqlContext: SQLContext, val drOptions = drInfo.options val drFullName = drInfo.fullName val drDSIntervals = drInfo.druidDS.intervals - val ndvEstimate = DruidQueryCostModel.estimateNDV(dQuery.q, drInfo) + val inputEstimate = DruidQueryCostModel.estimateInput(dQuery.q, drInfo) + val outputEstimate = DruidQueryCostModel.estimateOutput(dQuery.q, drInfo) val (httpMaxPerRoute, httpMaxTotal) = ( DruidPlanner.getConfValue(sqlContext, 
DruidPlanner.DRUID_CONN_POOL_MAX_CONNECTIONS_PER_ROUTE), diff --git a/src/main/scala/org/sparklinedata/druid/Utils.scala b/src/main/scala/org/sparklinedata/druid/Utils.scala index 2bbde7d..b8f082e 100644 --- a/src/main/scala/org/sparklinedata/druid/Utils.scala +++ b/src/main/scala/org/sparklinedata/druid/Utils.scala @@ -89,6 +89,7 @@ object Utils extends Logging { ) + new EnumNameSerializer(FunctionalDependencyType) + new EnumNameSerializer(DruidDataType) + + new DruidQueryGranularitySerializer + new QueryResultRowSerializer + new SelectResultRowSerializer + new TopNResultRowSerializer ++ diff --git a/src/main/scala/org/sparklinedata/druid/client/DruidClient.scala b/src/main/scala/org/sparklinedata/druid/client/DruidClient.scala index babf2e7..8622c9d 100644 --- a/src/main/scala/org/sparklinedata/druid/client/DruidClient.scala +++ b/src/main/scala/org/sparklinedata/druid/client/DruidClient.scala @@ -349,7 +349,7 @@ abstract class DruidClient(val host : String, val jR = render( ("queryType" -> "segmentMetadata") ~ ("dataSource" -> dataSource) ~ ("intervals" -> ins) ~ - ("analysisTypes" -> List[String]("cardinality")) ~ + ("analysisTypes" -> List[String]("cardinality", "interval", "minmax", "queryGranularity")) ~ ("merge" -> "true") ) diff --git a/src/main/scala/org/sparklinedata/druid/client/DruidMessages.scala b/src/main/scala/org/sparklinedata/druid/client/DruidMessages.scala index c546ec9..99675da 100644 --- a/src/main/scala/org/sparklinedata/druid/client/DruidMessages.scala +++ b/src/main/scala/org/sparklinedata/druid/client/DruidMessages.scala @@ -20,13 +20,37 @@ package org.sparklinedata.druid.client import org.joda.time.{DateTime, Interval} import org.json4s.CustomSerializer import org.json4s.JsonAST._ +import org.sparklinedata.druid.{DruidQueryGranularity, NoneGranularity} case class ColumnDetails(typ : String, size : Long, - cardinality : Option[Long], errorMessage : Option[String]) + cardinality : Option[Long], + minValue : Option[String], + maxValue : Option[String], + errorMessage : Option[String]) { + + def isDimension = cardinality.isDefined +} + case class MetadataResponse(id : String, intervals : List[String], columns : Map[String, ColumnDetails], - size : Long) + size : Long, + numRows : Option[Long], + queryGranularity : Option[DruidQueryGranularity] + ) { + + def getIntervals : List[Interval] = intervals.map(Interval.parse(_)) + + def timeTicks(ins : List[Interval] ) : Long = + queryGranularity.getOrElse(NoneGranularity).ndv(ins) + + def getNumRows : Long = numRows.getOrElse{ + val p = + columns.values.filter(c => c.isDimension).map(_.cardinality.get).map(_.toDouble).product + if (p > Long.MaxValue) Long.MaxValue else p.toLong + } + +} case class SegmentInfo(id : String, intervals : Interval, diff --git a/src/main/scala/org/sparklinedata/druid/metadata/DruidDataSource.scala b/src/main/scala/org/sparklinedata/druid/metadata/DruidDataSource.scala index 0317c84..31fab4d 100644 --- a/src/main/scala/org/sparklinedata/druid/metadata/DruidDataSource.scala +++ b/src/main/scala/org/sparklinedata/druid/metadata/DruidDataSource.scala @@ -46,20 +46,23 @@ sealed trait DruidColumn { * assume the worst for time dim and metrics, * this is only used during query costing */ - def cardinality : Long = Int.MaxValue.toLong + val cardinality : Long def isDimension(excludeTime : Boolean = false) : Boolean } object DruidColumn { - def apply(nm : String, c : ColumnDetails) : DruidColumn = { + def apply(nm : String, + c : ColumnDetails, + idxNumRows : Long, + idxTicks : Long) : DruidColumn = { if 
(nm == DruidDataSource.TIME_COLUMN_NAME) { - DruidTimeDimension(nm, DruidDataType.withName(c.typ), c.size) + DruidTimeDimension(nm, DruidDataType.withName(c.typ), c.size, Math.min(idxTicks,idxNumRows)) } else if ( c.cardinality.isDefined) { DruidDimension(nm, DruidDataType.withName(c.typ), c.size, c.cardinality.get) } else { - DruidMetric(nm, DruidDataType.withName(c.typ), c.size) + DruidMetric(nm, DruidDataType.withName(c.typ), c.size, idxNumRows) } } } @@ -67,19 +70,21 @@ object DruidColumn { case class DruidDimension(name : String, dataType : DruidDataType.Value, size : Long, - override val cardinality : Long) extends DruidColumn { + cardinality : Long) extends DruidColumn { def isDimension(excludeTime : Boolean = false) = true } case class DruidMetric(name : String, dataType : DruidDataType.Value, - size : Long) extends DruidColumn { + size : Long, + cardinality : Long) extends DruidColumn { def isDimension(excludeTime : Boolean = false) = false } case class DruidTimeDimension(name : String, dataType : DruidDataType.Value, - size : Long) extends DruidColumn { + size : Long, + cardinality : Long) extends DruidColumn { def isDimension(excludeTime : Boolean = false) = !excludeTime } @@ -93,6 +98,8 @@ case class DruidDataSource(name : String, intervals : List[Interval], columns : Map[String, DruidColumn], size : Long, + numRows : Long, + timeTicks : Long, druidVersion : String = null) extends DruidDataSourceCapability { import DruidDataSource._ @@ -132,9 +139,13 @@ object DruidDataSource { def apply(dataSource : String, mr : MetadataResponse, is : List[Interval]) : DruidDataSource = { + + val idxNumRows = mr.getNumRows + val idxTicks = mr.timeTicks(is) + val columns = mr.columns.map { - case (n, c) => (n -> DruidColumn(n,c)) + case (n, c) => (n -> DruidColumn(n,c, idxNumRows, idxTicks)) } - new DruidDataSource(dataSource, is, columns, mr.size) + new DruidDataSource(dataSource, is, columns, mr.size, idxNumRows, idxTicks) } } diff --git a/src/test/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModelTest.scala b/src/test/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModelTest.scala index accf32d..852d189 100644 --- a/src/test/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModelTest.scala +++ b/src/test/scala/org/apache/spark/sql/sources/druid/DruidQueryCostModelTest.scala @@ -56,10 +56,13 @@ class DruidQueryCostModelTest extends fixture.FunSuite with val tpch_numProcessingThreadsPerHistorical = 7 val tpch_numHistoricals = 8 + val tpchInputSize : Long = 100L * 1000 * 1000 + test("tpch_one_year_query") { td => // distinct values estimate = 100 var costScenario = CostInput( + tpchInputSize, 100, 1.0, default_histMergeFactor, @@ -78,15 +81,15 @@ class DruidQueryCostModelTest extends fixture.FunSuite with tpch_numSparkExecutors, tpch_numProcessingThreadsPerHistorical, tpch_numHistoricals, - classOf[GroupByQuerySpec] + Right(classOf[GroupByQuerySpec]) ) DruidQueryCostModel.compute(costScenario) // distinct values estimate = 1000 - DruidQueryCostModel.compute(costScenario.copy(dimsNDVEstimate = 1000)) + DruidQueryCostModel.compute(costScenario.copy(outputEstimate = 1000)) // distinct values estimate = 10000 - DruidQueryCostModel.compute(costScenario.copy(dimsNDVEstimate = 10000)) + DruidQueryCostModel.compute(costScenario.copy(outputEstimate = 10000)) } @@ -94,6 +97,7 @@ class DruidQueryCostModelTest extends fixture.FunSuite with val costScenario = CostInput( + tpchInputSize, 100, 1.0, default_histMergeFactor, @@ -112,7 +116,7 @@ class DruidQueryCostModelTest extends 
fixture.FunSuite with tpch_numSparkExecutors, tpch_numProcessingThreadsPerHistorical, tpch_numHistoricals, - classOf[GroupByQuerySpec] + Right(classOf[GroupByQuerySpec]) ) DruidQueryCostModel.compute(costScenario) } @@ -122,6 +126,7 @@ class DruidQueryCostModelTest extends fixture.FunSuite with // distinct values estimate = 100 val costScenario = CostInput( + tpchInputSize, 10000, 1.0, default_histMergeFactor, @@ -140,22 +145,22 @@ class DruidQueryCostModelTest extends fixture.FunSuite with tpch_numSparkExecutors, tpch_numProcessingThreadsPerHistorical, tpch_numHistoricals, - classOf[GroupByQuerySpec] + Right(classOf[GroupByQuerySpec]) ) DruidQueryCostModel.compute(costScenario) // distinct values estimate = 1000 - DruidQueryCostModel.compute(costScenario.copy(dimsNDVEstimate = 1000)) + DruidQueryCostModel.compute(costScenario.copy(outputEstimate = 1000)) // distinct values estimate = 10000 - DruidQueryCostModel.compute(costScenario.copy(dimsNDVEstimate = 10000)) + DruidQueryCostModel.compute(costScenario.copy(outputEstimate = 10000)) // distinct values estimate = 4611686018427387904L - DruidQueryCostModel.compute(costScenario.copy(dimsNDVEstimate = 4611686018427387904L)) + DruidQueryCostModel.compute(costScenario.copy(outputEstimate = 50L * 1000 * 1000)) DruidQueryCostModel.compute( - costScenario.copy(dimsNDVEstimate = 4611686018427387904L). - copy(querySpecClass = classOf[SelectSpecWithIntervals])) + costScenario.copy(outputEstimate = 50L * 1000 * 1000). + copy(querySpec = Right(classOf[SelectSpecWithIntervals]))) }
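
Note (illustrative, not part of the patch): the historical-plan cost computed above reduces to numWaves * (per-wave processing + merge + transport), with numWaves = ceil(numSegmentsProcessed / numSegsPerQuery / parallelismPerWave). The self-contained Scala sketch below mirrors that arithmetic with invented inputs; the segment counts, row sizes and cost factors are hypothetical, the Spark-side shuffle/scheduling/aggregation terms are omitted for brevity, and this is not the project's DruidQueryCostModel API.

object CostSketch {
  // Illustrative only: mirrors the shape of estimateNumWaves / histQueryCost
  // in DruidQueryCostModel; every input below is hypothetical.
  def estimateNumWaves(numSegments: Long, numSegsPerQuery: Long, parallelism: Long): Long =
    Math.round(numSegments.toDouble / numSegsPerQuery / parallelism + 0.5)

  def main(args: Array[String]): Unit = {
    val numSegments          = 400L     // hypothetical segments touched by the query
    val numHistoricalThreads = 8L * 7L  // numHistoricals * numProcessingThreadsPerHistorical
    val numSparkCores        = 8L * 7L  // numSparkExecutors * sparkCoresPerExecutor
    val parallelismPerWave   = Math.min(numHistoricalThreads, numSparkCores)

    val segmentInputSize  = 1000000L    // hypothetical rows scanned per segment
    val segmentOutputSize = 10000L      // hypothetical rows returned per segment
    val shuffleCostPerRow = 1.0
    val histInputCostPerRow  = 0.1 * shuffleCostPerRow   // made-up cost factors
    val histOutputCostPerRow = 0.25 * shuffleCostPerRow
    val histMergeCostPerRow  = 0.2 * shuffleCostPerRow

    (1 to 5).foreach { numSegsPerQuery =>
      val waves = estimateNumWaves(numSegments, numSegsPerQuery, parallelismPerWave)
      // per-historical work in one wave: scan inputs, produce outputs, merge partial results
      val processingPerHist =
        numSegsPerQuery * (segmentInputSize * histInputCostPerRow +
          segmentOutputSize * histOutputCostPerRow)
      val mergePerHist = (numSegsPerQuery - 1) * segmentOutputSize * histMergeCostPerRow
      val druidStageCost = waves * (processingPerHist + mergePerHist)
      println(f"segsPerQuery=$numSegsPerQuery%d waves=$waves%d druidStageCost=$druidStageCost%.0f")
    }
  }
}

Running the sketch shows the trade-off the model searches over: packing more segments into one historical query lowers the number of waves but raises the per-historical processing and merge cost, which is why compute() scans numSegsPerQuery from 1 to histSegsPerQueryLimit and keeps the minimum.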
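
Note (illustrative, not part of the patch): estimateInput in the patch derives an input row estimate as numRows * selectivity(filter), where a selector filter contributes 1/cardinality, an IN-style lookup contributes lookupSize/cardinality, "and" multiplies child selectivities, "or" adds them, "not" takes the complement, and anything else defaults to 1/3. The sketch below restates that composition over an invented FilterSel ADT (not the project's FilterSpec classes) purely to make the arithmetic concrete.

// Illustrative only: a stand-in ADT, not the project's FilterSpec hierarchy.
sealed trait FilterSel
case class Selector(cardinality: Long) extends FilterSel                    // equality on one dimension value
case class InLookup(lookupSize: Long, cardinality: Long) extends FilterSel  // IN via a lookup map
case class And(children: List[FilterSel]) extends FilterSel
case class Or(children: List[FilterSel]) extends FilterSel
case class Not(child: FilterSel) extends FilterSel
case object Unknown extends FilterSel                                       // any unsupported predicate

object SelectivitySketch {
  def selectivity(f: FilterSel): Double = f match {
    case Selector(card)       => 1.0 / card
    case InLookup(size, card) => size.toDouble / card
    case And(cs)              => cs.map(selectivity).product
    case Or(cs)               => cs.map(selectivity).sum
    case Not(c)               => 1.0 - selectivity(c)
    case Unknown              => 1.0 / 3.0
  }

  def main(args: Array[String]): Unit = {
    val numRows = 100000000L   // hypothetical index row count
    // e.g. region = 'X' AND (state IN (3 of 1000 values) OR <some unsupported predicate>)
    val filter = And(List(Selector(25), Or(List(InLookup(3, 1000), Unknown))))
    val inputEstimate = (numRows * selectivity(filter)).toLong
    println(s"selectivity=${selectivity(filter)} inputEstimate=$inputEstimate")
  }
}

As in the patch, "or" is approximated by a plain sum of child selectivities, so heavily overlapping disjuncts can overestimate the input size; that conservatism only makes the model more likely to prefer the historical plan.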