@@ -15,105 +15,204 @@
 package org.apache.spark.sql.clickhouse
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal}
+import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
+import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.Expressions._
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, _}
-import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _}
+import org.apache.spark.sql.types.{StructField, StructType}
 import xenon.clickhouse.exception.CHClientException
 import xenon.clickhouse.expr._
+import xenon.clickhouse.func.FunctionRegistry
+import xenon.clickhouse.spec.ClusterSpec
 
-import scala.annotation.tailrec
 import scala.util.{Failure, Success, Try}
 
-object ExprUtils extends SQLConfHelper {
+object ExprUtils extends SQLConfHelper with Serializable {
 
-  def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
-    partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
+  def toSparkPartitions(
+    partitionKey: Option[List[Expr]],
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
-  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
-    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
+  def toSparkSplits(
+    shardingKey: Option[Expr],
+    partitionKey: Option[List[Expr]],
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
   def toSparkSortOrders(
     shardingKeyIgnoreRand: Option[Expr],
     partitionKey: Option[List[Expr]],
-    sortingKey: Option[List[OrderExpr]]
-  ): Array[SortOrder] =
-    toSparkSplits(shardingKeyIgnoreRand, partitionKey).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
+    sortingKey: Option[List[OrderExpr]],
+    cluster: Option[ClusterSpec],
+    functionRegistry: FunctionRegistry
+  ): Array[V2SortOrder] =
+    toSparkSplits(
+      shardingKeyIgnoreRand,
+      partitionKey,
+      functionRegistry
+    ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
       sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
        val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
        val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
-       toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
+       toSparkTransformOpt(expr, functionRegistry).map(trans =>
+         Expressions.sort(trans, direction, nullOrder)
+       )
      }.toArray
 
-  @tailrec
-  def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
+  private def loadV2FunctionOpt(
+    name: String,
+    args: Seq[Expression],
+    functionRegistry: FunctionRegistry
+  ): Option[BoundFunction] = {
+    def loadFunction(ident: Identifier): UnboundFunction =
+      functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
+    val inputType = StructType(args.zipWithIndex.map {
+      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+    })
+    try {
+      val unbound = loadFunction(Identifier.of(Array.empty, name))
+      Some(unbound.bind(inputType))
+    } catch {
+      case e: NoSuchFunctionException =>
+        throw e
+      case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) =>
+        None
+      case e: UnsupportedOperationException =>
+        throw new AnalysisException(e.getMessage, cause = Some(e))
+    }
+  }
+
+  def resolveTransformCatalyst(
+    catalystExpr: Expression,
+    timeZoneId: Option[String] = None
+  ): Expression =
+    new TypeCoercionExecutor(timeZoneId)
+      .execute(DummyLeafNode(resolveTransformExpression(catalystExpr)))
+      .asInstanceOf[DummyLeafNode].expr
+
+  private case class DummyLeafNode(expr: Expression) extends LeafNode {
+    override def output: Seq[Attribute] = Nil
+  }
+
+  private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] {
+    private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
+      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+        e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone))
+      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+      // the types between the value expression and list query expression of IN expression.
+      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+      // information for time zone aware expressions.
+      case e: ListQuery => e.withNewPlan(apply(e.plan))
+    }
+
+    override def apply(plan: LogicalPlan): LogicalPlan =
+      plan.resolveExpressionsWithPruning(
+        _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION),
+        ruleId
+      )(transformTimeZoneExprs)
+  }
+
+  private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] {
+    override val batches =
+      Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) ::
+        Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil
+  }
+
+  private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
+  }
+
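Note: `resolveTransformCatalyst` above relies on a wrap-run-unwrap trick. Spark's `RuleExecutor` operates on whole logical plans rather than bare expressions, so the expression is wrapped in a throwaway `DummyLeafNode`, the type-coercion and time-zone batches run over that one-node plan, and the rewritten expression is extracted again. A minimal self-contained sketch of the pattern, with Spark's classes replaced by hypothetical stand-ins:

    // Stand-ins for LogicalPlan / DummyLeafNode / RuleExecutor; not Spark's API.
    sealed trait MiniPlan
    final case class MiniLeaf(expr: String) extends MiniPlan

    object MiniCoercer {
      // Plays the role of TypeCoercionExecutor.execute: rules only see plans,
      // so the bare expression travels inside the leaf node.
      def execute(plan: MiniPlan): MiniPlan = plan match {
        case MiniLeaf(e) => MiniLeaf(s"cast(($e) as bigint)") // pretend coercion rule
      }
    }

    // Wrap, run, unwrap -- the same shape as resolveTransformCatalyst:
    val MiniLeaf(coerced) = MiniCoercer.execute(MiniLeaf("_0 + 1"))
    // coerced == "cast((_0 + 1) as bigint)"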
+  def toCatalyst(
+    v2Expr: V2Expression,
+    fields: Array[StructField],
+    functionRegistry: FunctionRegistry
+  ): Expression =
     v2Expr match {
-      case IdentityTransform(ref) => toCatalyst(ref, fields)
+      case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
       case ref: NamedReference if ref.fieldNames.length == 1 =>
         val (field, ordinal) = fields
           .zipWithIndex
           .find { case (field, _) => field.name == ref.fieldNames.head }
           .getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
         BoundReference(ordinal, field.dataType, field.nullable)
+      case t: Transform =>
+        val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
+        loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
+          .map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
+            throw CHClientException(s"Unsupported expression: $v2Expr")
+          }
+      case literal: LiteralValue[Any] => expressions.Literal(literal.value)
       case _ => throw CHClientException(
-        s"Unsupported V2 expression: $v2Expr, SPARK-33779: Spark 3.3 only support IdentityTransform"
+        s"Unsupported expression: $v2Expr"
       )
     }
 
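For reference, the rewritten `toCatalyst` handles two common shapes; sketched below with a placeholder `registry: FunctionRegistry`:

    import org.apache.spark.sql.connector.expressions.Expressions
    import org.apache.spark.sql.types.{LongType, StructField}

    val fields = Array(StructField("id", LongType, nullable = false))

    // A bare column reference binds positionally against `fields`:
    //   ExprUtils.toCatalyst(Expressions.column("id"), fields, registry)
    //   ==> BoundReference(0, LongType, false)

    // A function transform converts its arguments recursively, then binds the
    // named function through loadV2FunctionOpt:
    //   ExprUtils.toCatalyst(Expressions.apply("ck_xx_hash64", Expressions.column("id")), fields, registry)
    //   ==> TransformExpression(<bound ck_xx_hash64>, Seq(BoundReference(0, LongType, false)))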
-  def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkTransform(expr)) match {
-    case Success(t) => Some(t)
-    case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
-    case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
-  }
+  def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
+    Try(toSparkExpression(expr, functionRegistry)) match {
+      // need this function because spark `Table`'s `partitioning` field should be `Transform`
+      case Success(t: Transform) => Some(t)
+      case Success(_) => None
+      case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
+      case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
+    }
 
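The `Success(_) => None` arm is new: `toSparkExpression` can now return a plain literal, but Spark `Table`'s `partitioning()` only accepts `Transform`s, so a successfully converted non-transform is dropped rather than rejected. Illustrative behavior (placeholder `registry`):

    // toSparkTransformOpt(FieldRef("id"), registry)       ==> Some(identity("id"))
    // toSparkTransformOpt(StringLiteral("abc"), registry) ==> None (a literal is a
    //                                                         V2Expression, not a Transform)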
-  // Some functions of ClickHouse which match Spark pre-defined Transforms
-  //
-  // toYear, YEAR - Converts a date or date with time to a UInt16 (AD)
-  // toYYYYMM - Converts a date or date with time to a UInt32 (YYYY*100 + MM)
-  // toYYYYMMDD - Converts a date or date with time to a UInt32 (YYYY*10000 + MM*100 + DD)
-  // toHour, HOUR - Converts a date with time to a UInt8 (0-23)
-
-  def toSparkTransform(expr: Expr): Transform = expr match {
-    case FieldRef(col) => identity(col)
-    case FuncExpr("toYear", List(FieldRef(col))) => years(col)
-    case FuncExpr("YEAR", List(FieldRef(col))) => years(col)
-    case FuncExpr("toYYYYMM", List(FieldRef(col))) => months(col)
-    case FuncExpr("toYYYYMMDD", List(FieldRef(col))) => days(col)
-    case FuncExpr("toHour", List(FieldRef(col))) => hours(col)
-    case FuncExpr("HOUR", List(FieldRef(col))) => hours(col)
-    // TODO support arbitrary functions
-    // case FuncExpr("xxHash64", List(FieldRef(col))) => apply("ck_xx_hash64", column(col))
-    case FuncExpr("rand", Nil) => apply("rand")
-    case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
-    case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
-  }
+  def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
+    expr match {
+      case FieldRef(col) => identity(col)
+      case StringLiteral(value) => literal(value) // TODO LiteralTransform
+      case FuncExpr("rand", Nil) => apply("rand")
+      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
+      case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) =>
+        apply(functionRegistry.clickHouseToSparkFunc(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
+      case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
+    }
 
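The name mapping itself now lives in `xenon.clickhouse.func.FunctionRegistry`. A hypothetical sketch of the two lookup tables that `toSparkExpression` and `toClickHouse` consult; the entry below reuses `ck_xx_hash64`, the Spark-side name the old hard-coded TODO used, but the concrete contents are an assumption:

    // Hypothetical stand-in for the registry's bidirectional name maps; the
    // real trait also loads UnboundFunction implementations by name.
    trait MiniFunctionRegistry {
      def clickHouseToSparkFunc: Map[String, String]
      def sparkToClickHouseFunc: Map[String, String]
    }

    object XxHash64Mapping extends MiniFunctionRegistry {
      val clickHouseToSparkFunc = Map("xxHash64" -> "ck_xx_hash64")
      val sparkToClickHouseFunc = Map("ck_xx_hash64" -> "xxHash64")
    }

    // With such entries, FuncExpr("xxHash64", List(FieldRef("id"))) round-trips:
    //   toSparkExpression ==> apply("ck_xx_hash64", identity("id"))
    //   toClickHouse      ==> FuncExpr("xxHash64", List(SQLExpr("id")))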
-  def toClickHouse(transform: Transform): Expr = transform match {
-    case YearsTransform(FieldReference(Seq(col))) => FuncExpr("toYear", List(FieldRef(col)))
-    case MonthsTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMM", List(FieldRef(col)))
-    case DaysTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMMDD", List(FieldRef(col)))
-    case HoursTransform(FieldReference(Seq(col))) => FuncExpr("toHour", List(FieldRef(col)))
+  def toClickHouse(
+    transform: Transform,
+    functionRegistry: FunctionRegistry
+  ): Expr = transform match {
     case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
-    case ApplyTransform(name, args) => FuncExpr(name, args.map(arg => SQLExpr(arg.describe())).toList)
+    case ApplyTransform(name, args) if functionRegistry.sparkToClickHouseFunc.contains(name) =>
+      FuncExpr(functionRegistry.sparkToClickHouseFunc(name), args.map(arg => SQLExpr(arg.describe)).toList)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }
 
   def inferTransformSchema(
     primarySchema: StructType,
     secondarySchema: StructType,
-    transform: Transform
+    transform: Transform,
+    functionRegistry: FunctionRegistry
   ): StructField = transform match {
-    case years: YearsTransform => StructField(years.toString, IntegerType)
-    case months: MonthsTransform => StructField(months.toString, IntegerType)
-    case days: DaysTransform => StructField(days.toString, IntegerType)
-    case hours: HoursTransform => StructField(hours.toString, IntegerType)
     case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
       .orElse(secondarySchema.find(_.name == col))
       .getOrElse(throw CHClientException(s"Invalid partition column: $col"))
-    case ckXxhHash64 @ ApplyTransform("ck_xx_hash64", _) => StructField(ckXxhHash64.toString, LongType)
+    case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined =>
+      val resType = functionRegistry.load(transformName) match {
+        case Some(f: ScalarFunction[_]) => f.resultType
+        case other => throw CHClientException(s"Unsupported function: $other")
+      }
+      StructField(t.toString, resType)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }
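Taken together, a `FunctionRegistry` is now threaded through every entry point. A sketch of the new call shape, where every value is a placeholder for what the connector supplies at runtime:

    import org.apache.spark.sql.clickhouse.ExprUtils
    import xenon.clickhouse.expr.{Expr, OrderExpr}
    import xenon.clickhouse.func.FunctionRegistry

    val registry: FunctionRegistry = ???          // built by the catalog
    val shardingKey: Option[Expr] = ???           // parsed from the Distributed engine DDL
    val partitionKey: Option[List[Expr]] = ???    // parsed from PARTITION BY
    val sortingKey: Option[List[OrderExpr]] = ??? // parsed from ORDER BY

    val partitions = ExprUtils.toSparkPartitions(partitionKey, registry)
    val splits     = ExprUtils.toSparkSplits(shardingKey, partitionKey, registry)
    val ordering   = ExprUtils.toSparkSortOrders(shardingKey, partitionKey, sortingKey, None, registry)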