Spark Planner for Converting Logical Plan to Spark Plan

The SparkPlanner is the bridge from the Logical Plan to the Spark Plan: it converts a Logical Plan into a physical Spark Plan.
The SparkPlanner extends SparkStrategies (code in SparkStrategies.scala).
It has several strategies:

  def strategies: Seq[Strategy] =
      extraStrategies ++ (
      FileSourceStrategy ::
      DataSourceStrategy ::
      DDLStrategy ::
      SpecialLimits ::
      Aggregation ::
      JoinSelection ::
      InMemoryScans ::
      BasicOperators :: Nil)
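
These strategies are tried in order, and the first physical plan they produce becomes the query's Spark Plan. As a point of reference, the optimized logical input and the planned physical output can both be inspected through queryExecution. Below is a minimal sketch under the assumption of a local SparkSession named spark and a hypothetical two-column DataFrame:

// A minimal sketch, assuming a SparkSession named `spark`; the rows are hypothetical.
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("int", "str")

// The optimized logical plan is what the planner receives as input.
println(df.queryExecution.optimizedPlan)

// Internally, QueryExecution calls planner.plan(ReturnAnswer(optimizedPlan)).next(),
// trying the strategies listed above in order; the chosen plan is exposed here.
println(df.queryExecution.sparkPlan)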

Let me introduce several of them here:

SpecialLimits

The Logical Plan Limit is converted into the Spark Plan TakeOrderedAndProjectExec (or CollectLimitExec when there is no sort) here.

 /**
   * Plans special cases of limit operators.
   */
  object SpecialLimits extends Strategy {
    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case logical.ReturnAnswer(rootPlan) => rootPlan match {
        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
          execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
        case logical.Limit(
            IntegerLiteral(limit),
            logical.Project(projectList, logical.Sort(order, true, child))) =>
          execution.TakeOrderedAndProjectExec(
            limit, order, projectList, planLater(child)) :: Nil
        case logical.Limit(IntegerLiteral(limit), child) =>
          execution.CollectLimitExec(limit, planLater(child)) :: Nil
        case other => planLater(other) :: Nil
      }
      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
        execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
      case logical.Limit(
          IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
        execution.TakeOrderedAndProjectExec(
          limit, order, projectList, planLater(child)) :: Nil
      case _ => Nil
    }
  }
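For example, a query that sorts and then limits matches the first case above, so the Sort + Limit pair collapses into a single TakeOrderedAndProjectExec instead of a global sort followed by a separate limit. A minimal sketch, assuming a SparkSession named spark and a hypothetical view name:

// ORDER BY + LIMIT matches logical.Limit(_, logical.Sort(...)) above, so the
// physical plan should show TakeOrderedAndProject rather than Sort plus a limit.
spark.range(0, 1000).toDF("id").createOrReplaceTempView("t")
spark.sql("SELECT id FROM t ORDER BY id DESC LIMIT 10").explain()
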
Aggregation

The Logical Plan Aggregate is converted into one of the following Spark Plans, depending on the conditions:
SortAggregateExec
HashAggregateExec
ObjectHashAggregateExec

 /**
   * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
   */
  object Aggregation extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalAggregation(
          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
 
        val (functionsWithDistinct, functionsWithoutDistinct) =
          aggregateExpressions.partition(_.isDistinct)
......

Let me use the SQL “SELECT x.str, COUNT(*) FROM df x JOIN df y ON x.str = y.str GROUP BY x.str” as an example.
Here, “COUNT(*)” is an aggregate expression.
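To reproduce the example, df can be built from a small sequence of tuples whose second field carries the string key; this is a sketch under that assumption, with the concrete rows hypothetical:

// A minimal sketch; only the (int, str) shape of the DataFrame matters.
import spark.implicits._

val df = Seq((1, "a"), (2, "b"), (3, "a")).toDF("int", "str")
df.createOrReplaceTempView("df")

val result = spark.sql(
  "SELECT x.str, COUNT(*) FROM df x JOIN df y ON x.str = y.str GROUP BY x.str")
result.explain(true)   // prints the logical and physical plans walked through below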

Match the Logical Plan Aggregate

Code in patterns.scala.

object PhysicalAggregation {
  // groupingExpressions, aggregateExpressions, resultExpressions, child
  type ReturnType =
    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
 
  def unapply(a: Any): Option[ReturnType] = a match { //Mark A
    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
      // A single aggregate expression might appear multiple times in resultExpressions.
      // In order to avoid evaluating an individual aggregate function multiple times, we'll
      // build a set of the distinct aggregate expressions and build a function which can
      // be used to re-write expressions so that they reference the single copy of the
      // aggregate function which actually gets computed.
      val aggregateExpressions = resultExpressions.flatMap { expr =>
        expr.collect {
          case agg: AggregateExpression => agg
        }
      }.distinct
 
      val namedGroupingExpressions = groupingExpressions.map {
        case ne: NamedExpression => ne -> ne
        // If the expression is not a NamedExpressions, we add an alias.
        // So, when we generate the result of the operator, the Aggregate Operator
        // can directly get the Seq of attributes representing the grouping expressions.
        case other =>
          val withAlias = Alias(other, other.toString)()
          other -> withAlias
      }
      val groupExpressionMap = namedGroupingExpressions.toMap
 
      // The original `resultExpressions` are a set of expressions which may reference
      // aggregate expressions, grouping column values, and constants. When aggregate operator
      // emits output rows, we will use `resultExpressions` to generate an output projection
      // which takes the grouping columns and final aggregate result buffer as input.
      // Thus, we must re-write the result expressions so that their attributes match up with
      // the attributes of the final result projection's input row:
      val rewrittenResultExpressions = resultExpressions.map { expr =>
        expr.transformDown {
          case ae: AggregateExpression =>
            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
            // so replace each aggregate expression by its corresponding attribute in the set:
            ae.resultAttribute
          case expression =>
            // Since we're using `namedGroupingAttributes` to extract the grouping key
            // columns, we need to replace grouping key expressions with their corresponding
            // attributes. We do not rely on the equality check at here since attributes may
            // differ cosmetically. Instead, we use semanticEquals.
            groupExpressionMap.collectFirst {
              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
            }.getOrElse(expression)
        }.asInstanceOf[NamedExpression]
      }
 
      Some((
        namedGroupingExpressions.map(_._2),
        aggregateExpressions,
        rewrittenResultExpressions,
        child))
 
    case _ => None
  }
}

First, PhysicalAggregation is used to try to match the Logical Plan Aggregate.
At Mark A, the variables are as below:
groupingExpressions: [ 0 : AttributeReference “str#227” ]
resultExpressions: [ 0 : AttributeReference “str#227” , 1 : Alias count(1) AS count(1)#235L ]
child: Logical Plan

Project [str#227]
+- Join Inner, (str#227 = str#233)
   :- Project [_2#224 AS str#227]
   :  +- Filter isnotnull(_2#224)
   :     +- LocalRelation [_1#223, _2#224]
   +- Project [_2#224 AS str#233]
      +- Filter isnotnull(_2#224)
         +- LocalRelation [_1#223, _2#224]

After the match, the extracted result will be:
aggregateExpressions: [ 0 : AggregateExpression “count(1)” ]
groupingExpressions: [ 0 : AttributeReference “str#227” ]
resultExpressions: [ 0 : AttributeReference “str#227” , 1 : Alias count(1) AS count(1)#235L ]

Create the Spark Plan based on PhysicalAggregation

Code in SparkStrategies.scala.

  object Aggregation extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalAggregation(
          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
 
        val (functionsWithDistinct, functionsWithoutDistinct) = //Mark B
          aggregateExpressions.partition(_.isDistinct)
        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
          // This is a sanity check. We should not reach here when we have multiple distinct
          // column sets. Our MultipleDistinctRewriter should take care this case.
          sys.error("You hit a query analyzer bug. Please report your query to " +
              "Spark user mailing list.")
        }
 
        val aggregateOperator =
          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
            if (functionsWithDistinct.nonEmpty) {
              sys.error("Distinct columns cannot exist in Aggregate operator containing " +
                "aggregate functions which don't support partial aggregation.")
            } else {
              aggregate.AggUtils.planAggregateWithoutPartial(
                groupingExpressions,
                aggregateExpressions,
                resultExpressions,
                planLater(child))
            }
          } else if (functionsWithDistinct.isEmpty) {
            aggregate.AggUtils.planAggregateWithoutDistinct(//Mark C
              groupingExpressions,
              aggregateExpressions,
              resultExpressions,
              planLater(child))
          } else {
            aggregate.AggUtils.planAggregateWithOneDistinct(
              groupingExpressions,
              functionsWithDistinct,
              functionsWithoutDistinct,
              resultExpressions,
              planLater(child))
          }
 
        aggregateOperator
 
      case _ => Nil
    }
  }

At Mark B, for the example SQL, functionsWithDistinct is empty and
functionsWithoutDistinct contains one element, the AggregateExpression “count(1)”.
So the planAggregateWithoutDistinct branch at Mark C is taken.

Finally, the resulting HashAggregateExec Spark Plan will be (planAggregateWithoutDistinct plans a partial aggregation below the shuffle and a final aggregation above it):

HashAggregate(keys=[str#227], functions=[count(1)], output=[str#227, count(1)#235L])
+- HashAggregate(keys=[str#227], functions=[partial_count(1)], output=[str#227, count#240L])
   +- PlanLater Project [str#227]
JoinSelection

JoinSelection converts the Logical Plan Join into one of the following Spark Plans, depending on the conditions:
BroadcastHashJoinExec
ShuffledHashJoinExec
SortMergeJoinExec
BroadcastNestedLoopJoinExec
CartesianProductExec

Convert Logical Plan to Spark Plan

First, if one side can be used as the build side and can be broadcast (its serialized size is not bigger than the configuration “spark.sql.autoBroadcastJoinThreshold”), BroadcastHashJoinExec is used, see Mark A and Mark B.
Second, if the configuration “spark.sql.join.preferSortMergeJoin” is false, one side can be used as the build side, it can be built into a local hash map (its serialized size is not bigger than “spark.sql.autoBroadcastJoinThreshold” * “spark.sql.shuffle.partitions”), and the build side is much smaller than the other side, or the leftKeys are not orderable, ShuffledHashJoinExec is used, see Mark C and Mark D.
Third, if the leftKeys are orderable, SortMergeJoinExec is used, see Mark E.
Then, if there are no join keys and one side can be broadcast, BroadcastNestedLoopJoinExec is used, see Mark F and Mark G.
Then, if the join type is an inner join, CartesianProductExec is used, see Mark H.
At last, only BroadcastNestedLoopJoinExec can be used; since there are no join keys, a large amount of memory may be consumed, see Mark I.
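
The choice can therefore be steered through configuration. A minimal sketch, assuming a SparkSession named spark and two equi-joinable DataFrames df1 and df2 (hypothetical names):

// Setting the broadcast threshold to -1 disables the BroadcastHashJoinExec cases,
// so an equi-join with orderable keys falls through to SortMergeJoinExec (Mark E).
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1L)
df1.join(df2, "str").explain()

// With the default 10 MB threshold restored, a small enough side is broadcast again
// and Mark A / Mark B pick BroadcastHashJoinExec.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024)
df1.join(df2, "str").explain()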

 object JoinSelection extends Strategy with PredicateHelper {
 
......
 
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
      // --- BroadcastHashJoin --------------------------------------------------------------------
 
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)  //Mark A
        if canBuildRight(joinType) && canBroadcast(right) =>
        Seq(joins.BroadcastHashJoinExec(
          leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)  //Mark B
        if canBuildLeft(joinType) && canBroadcast(left) =>
        Seq(joins.BroadcastHashJoinExec(
          leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
 
      // --- ShuffledHashJoin ---------------------------------------------------------------------
 
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)  //Mark C
         if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
           && muchSmaller(right, left) ||
           !RowOrdering.isOrderable(leftKeys) =>
        Seq(joins.ShuffledHashJoinExec(
          leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)  //Mark D
         if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
           && muchSmaller(left, right) ||
           !RowOrdering.isOrderable(leftKeys) =>
        Seq(joins.ShuffledHashJoinExec(
          leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
 
      // --- SortMergeJoin ------------------------------------------------------------
 
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)  //Mark E
        if RowOrdering.isOrderable(leftKeys) =>
        joins.SortMergeJoinExec(
          leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
      // --- Without joining keys ------------------------------------------------------------
 
      // Pick BroadcastNestedLoopJoin if one side could be broadcasted
      case j @ logical.Join(left, right, joinType, condition)   //Mark F
          if canBuildRight(joinType) && canBroadcast(right) =>
        joins.BroadcastNestedLoopJoinExec(
          planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
      case j @ logical.Join(left, right, joinType, condition)   //Mark G
          if canBuildLeft(joinType) && canBroadcast(left) =>
        joins.BroadcastNestedLoopJoinExec(
          planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
 
      // Pick CartesianProduct for InnerJoin
      case logical.Join(left, right, _: InnerLike, condition) =>   //Mark H
        joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
 
      case logical.Join(left, right, joinType, condition) =>      //Mark I
        val buildSide =
          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
            BuildRight
          } else {
            BuildLeft
          }
        // This join could be very slow or OOM
        joins.BroadcastNestedLoopJoinExec(
          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
      // --- Cases where this strategy does not apply ---------------------------------------------
 
      case _ => Nil
    }
  }
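Besides the size-based statistics, the broadcast hint marks one side as broadcastable, which makes the Mark A / Mark B cases fire even for relations the planner cannot prove to be small. A minimal sketch, where df1 and df2 are hypothetical DataFrames sharing a "str" column:

import org.apache.spark.sql.functions.broadcast

// The hint marks df2 as broadcastable, so canBroadcast(right) holds and
// JoinSelection picks BroadcastHashJoinExec with BuildRight.
df1.join(broadcast(df2), "str").explain()
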
How BroadcastHashJoinExec Works

When the RDDs are generated, BroadcastHashJoinExec is invoked through its doProduce method (see Mark A). The streamedPlan is responsible for producing the code.
After the InputAdapter calls back to its parent, doConsume is called. In this case the join type is an inner join, so the codegenInner branch runs (see Mark B).
When preparing the broadcast, the build plan is executed and broadcast first (see Mark C, Mark D and Mark E).

case class BroadcastHashJoinExec(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    buildSide: BuildSide,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryExecNode with HashJoin with CodegenSupport {
......
 
  override def doProduce(ctx: CodegenContext): String = {   //Mark A
    streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
  }
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    joinType match {
      case _: InnerLike => codegenInner(ctx, input) //Mark B
      case LeftOuter | RightOuter => codegenOuter(ctx, input)
      case LeftSemi => codegenSemi(ctx, input)
      case LeftAnti => codegenAnti(ctx, input)
      case j: ExistenceJoin => codegenExistence(ctx, input)
      case x =>
        throw new IllegalArgumentException(
          s"BroadcastHashJoin should not take $x as the JoinType")
    }
  }
 
  /**
   * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
   */
  private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { //Mark D
    // create a name for HashedRelation
    val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()  //Mark E
    val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
    val relationTerm = ctx.freshName("relation")
    val clsName = broadcastRelation.value.getClass.getName
    ctx.addMutableState(clsName, relationTerm,
      s"""
         | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
         | incPeakExecutionMemory($relationTerm.estimatedSize());
       """.stripMargin)
    (broadcastRelation, relationTerm)
  }
 
  /**
   * Generates the code for Inner join.
   */
  private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)  //Mark C
    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
    val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
    val numOutput = metricTerm(ctx, "numOutputRows")
 
    val resultVars = buildSide match {
      case BuildLeft => buildVars ++ input
      case BuildRight => input ++ buildVars
    }
    if (broadcastRelation.value.keyIsUnique) {  //Mark F
      s"""
         |// generate join key for stream side
         |${keyEv.code}
         |// find matches from HashedRelation
         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
         |if ($matched == null) continue;
         |$checkCondition
         |$numOutput.add(1);
         |${consume(ctx, resultVars)}
       """.stripMargin
 
    } else {
      ctx.copyResult = true
      val matches = ctx.freshName("matches")
      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
      s"""
         |// generate join key for stream side
         |${keyEv.code}
         |// find matches from HashRelation
         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
         |if ($matches == null) continue;
         |while ($matches.hasNext()) {
         |  UnsafeRow $matched = (UnsafeRow) $matches.next();
         |  $checkCondition
         |  $numOutput.add(1);
         |  ${consume(ctx, resultVars)}
         |}
       """.stripMargin
    }
  }

At last, the generated Java code will be:

/* 193 */     // PRODUCE: BroadcastHashJoin [str#227], [str#233], Inner, BuildRight
/* 194 */     // PRODUCE: Project [_2#224 AS str#227]
/* 195 */     // PRODUCE: Filter isnotnull(_2#224)
/* 196 */     // PRODUCE: InputAdapter
/* 197 */     while (inputadapter_input.hasNext()) {
/* 198 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 199 */       // CONSUME: Filter isnotnull(_2#224)
/* 200 */       // input[1, string, true]
/* 201 */       boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 202 */       UTF8String inputadapter_value1 = inputadapter_isNull1 ? null : (inputadapter_row.getUTF8String(1));
/* 203 */
/* 204 */       if (!(!(inputadapter_isNull1))) continue;
/* 205 */
/* 206 */       filter_numOutputRows.add(1);
/* 207 */
/* 208 */       // CONSUME: Project [_2#224 AS str#227]
/* 209 */       // CONSUME: BroadcastHashJoin [str#227], [str#233], Inner, BuildRight
/* 210 */       // generate join key for stream side
/* 211 */
/* 212 */       bhj_holder.reset();
/* 213 */
/* 214 */       bhj_rowWriter.write(0, inputadapter_value1);
/* 215 */       bhj_result.setTotalSize(bhj_holder.totalSize());
/* 216 */
/* 217 */       // find matches from HashedRelation
/* 218 */       UnsafeRow bhj_matched = bhj_result.anyNull() ? null: (UnsafeRow)bhj_relation.getValue(bhj_result);
/* 219 */       if (bhj_matched == null) continue;
/* 220 */
/* 221 */       bhj_numOutputRows.add(1);
