From c025ef39b584f3b4ce4afbec148ca5b10affbae4 Mon Sep 17 00:00:00 2001
From: Xinyuan Yang
Date: Wed, 17 May 2023 14:54:38 +0800
Subject: [PATCH] Spark 3.4: Write supports using V2 functions for distribution and ordering

---
 .../WriteDistributionAndOrderingSuite.scala   |   8 +-
 .../spark/sql/clickhouse/ExprUtils.scala      | 211 +++++++++++++-----
 .../xenon/clickhouse/ClickHouseCatalog.scala  |  10 +-
 .../xenon/clickhouse/ClickHouseTable.scala    |  28 ++-
 .../clickhouse/write/ClickHouseWriter.scala   |  48 +++-
 .../write/WriteJobDescription.scala           |  21 +-
 6 files changed, 232 insertions(+), 94 deletions(-)

diff --git a/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala b/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala
index fe9ba535..7fc0972d 100644
--- a/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala
+++ b/spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/WriteDistributionAndOrderingSuite.scala
@@ -78,12 +78,8 @@ class WriteDistributionAndOrderingSuite extends SparkClickHouseSingleTest {
       WRITE_REPARTITION_BY_PARTITION.key -> repartitionByPartition.toString,
       WRITE_LOCAL_SORT_BY_KEY.key -> localSortByKey.toString
     ) {
-      if (!ignoreUnsupportedTransform && repartitionByPartition) {
-        intercept[AnalysisException](write())
-      } else {
-        write()
-        check()
-      }
+      write()
+      check()
     }
 
   Seq(true, false).foreach { ignoreUnsupportedTransform =>
diff --git a/spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala b/spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala
index 314c65f3..b9502822 100644
--- a/spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala
+++ b/spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala
@@ -15,88 +15,185 @@
 package org.apache.spark.sql.clickhouse
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal}
+import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
+import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.Expressions._
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, _}
-import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _}
+import org.apache.spark.sql.types.{StructField, StructType}
 import xenon.clickhouse.exception.CHClientException
 import xenon.clickhouse.expr._
+import xenon.clickhouse.func.FunctionRegistry
+import xenon.clickhouse.spec.ClusterSpec
 
-import scala.annotation.tailrec
 import scala.util.{Failure, Success, Try}
 
-object ExprUtils extends SQLConfHelper {
+object ExprUtils extends SQLConfHelper with Serializable {
 
-  def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
-    partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
+  def toSparkPartitions(
+    partitionKey: Option[List[Expr]],
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
-  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
-    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
+  def toSparkSplits(
+    shardingKey: Option[Expr],
+    partitionKey: Option[List[Expr]],
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
   def toSparkSortOrders(
     shardingKeyIgnoreRand: Option[Expr],
     partitionKey: Option[List[Expr]],
-    sortingKey: Option[List[OrderExpr]]
-  ): Array[SortOrder] =
-    toSparkSplits(shardingKeyIgnoreRand, partitionKey).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
+    sortingKey: Option[List[OrderExpr]],
+    cluster: Option[ClusterSpec],
+    functionRegistry: FunctionRegistry
+  ): Array[V2SortOrder] =
+    toSparkSplits(
+      shardingKeyIgnoreRand,
+      partitionKey,
+      functionRegistry
+    ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
       sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
         val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
         val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
-        toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
+        toSparkTransformOpt(expr, functionRegistry).map(trans =>
+          Expressions.sort(trans, direction, nullOrder)
+        )
       }.toArray
 
-  @tailrec
-  def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
+  private def loadV2FunctionOpt(
+    name: String,
+    args: Seq[Expression],
+    functionRegistry: FunctionRegistry
+  ): Option[BoundFunction] = {
+    def loadFunction(ident: Identifier): UnboundFunction =
+      functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
+    val inputType = StructType(args.zipWithIndex.map {
+      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+    })
+    try {
+      val unbound = loadFunction(Identifier.of(Array.empty, name))
+      Some(unbound.bind(inputType))
+    } catch {
+      case e: NoSuchFunctionException =>
+        throw e
+      case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) =>
+        None
+      case e: UnsupportedOperationException =>
+        throw new AnalysisException(e.getMessage, cause = Some(e))
+    }
+  }
+
+  def resolveTransformCatalyst(
+    catalystExpr: Expression,
+    timeZoneId: Option[String] = None
+  ): Expression =
+    new TypeCoercionExecutor(timeZoneId)
+      .execute(DummyLeafNode(resolveTransformExpression(catalystExpr)))
+      .asInstanceOf[DummyLeafNode].expr
+
+  private case class DummyLeafNode(expr: Expression) extends LeafNode {
+    override def output: Seq[Attribute] = Nil
+  }
+
+  private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] {
+    private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
+      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+        e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone))
+      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+      // the types between the value expression and list query expression of IN expression.
+      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+      // information for time zone aware expressions.
+      case e: ListQuery => e.withNewPlan(apply(e.plan))
+    }
+
+    override def apply(plan: LogicalPlan): LogicalPlan =
+      plan.resolveExpressionsWithPruning(
+        _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION),
+        ruleId
+      )(transformTimeZoneExprs)
+  }
+
+  private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] {
+    override val batches =
+      Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) ::
+        Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil
+  }
+
+  private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
+  }
+
+  def toCatalyst(
+    v2Expr: V2Expression,
+    fields: Array[StructField],
+    functionRegistry: FunctionRegistry
+  ): Expression =
     v2Expr match {
-      case IdentityTransform(ref) => toCatalyst(ref, fields)
+      case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
       case ref: NamedReference if ref.fieldNames.length == 1 =>
         val (field, ordinal) = fields
           .zipWithIndex
           .find { case (field, _) => field.name == ref.fieldNames.head }
           .getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
         BoundReference(ordinal, field.dataType, field.nullable)
+      case t: Transform =>
+        val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
+        loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
+          .map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
+            throw CHClientException(s"Unsupported expression: $v2Expr")
+          }
+      case literal: LiteralValue[Any] => expressions.Literal(literal.value)
       case _ => throw CHClientException(
-        s"Unsupported V2 expression: $v2Expr, SPARK-33779: Spark 3.3 only support IdentityTransform"
+        s"Unsupported expression: $v2Expr"
       )
     }
 
-  def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkTransform(expr)) match {
-    case Success(t) => Some(t)
-    case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
-    case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
-  }
+  def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
+    Try(toSparkExpression(expr, functionRegistry)) match {
+      // need this function because spark `Table`'s `partitioning` field should be `Transform`
+      case Success(t: Transform) => Some(t)
+      case Success(_) => None
+      case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
+      case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
+    }
 
-  // Some functions of ClickHouse which match Spark pre-defined Transforms
-  //
-  // toYear, YEAR - Converts a date or date with time to a UInt16 (AD)
-  // toYYYYMM - Converts a date or date with time to a UInt32 (YYYY*100 + MM)
-  // toYYYYMMDD - Converts a date or date with time to a UInt32 (YYYY*10000 + MM*100 + DD)
-  // toHour, HOUR - Converts a date with time to a UInt8 (0-23)
-
-  def toSparkTransform(expr: Expr): Transform = expr match {
-    case FieldRef(col) => identity(col)
-    case FuncExpr("toYear", List(FieldRef(col))) => years(col)
-    case FuncExpr("YEAR", List(FieldRef(col))) => years(col)
-    case FuncExpr("toYYYYMM", List(FieldRef(col))) => months(col)
-    case FuncExpr("toYYYYMMDD", List(FieldRef(col))) => days(col)
-    case FuncExpr("toHour", List(FieldRef(col))) => hours(col)
-    case FuncExpr("HOUR", List(FieldRef(col))) => hours(col)
-    // TODO support arbitrary functions
-    // case FuncExpr("xxHash64", List(FieldRef(col))) => apply("ck_xx_hash64", column(col))
-    case FuncExpr("rand", Nil) => apply("rand")
-    case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
-    case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
-  }
+  def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
+    expr match {
+      case FieldRef(col) => identity(col)
+      case StringLiteral(value) => literal(value) // TODO LiteralTransform
+      case FuncExpr("rand", Nil) => apply("rand")
+      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
+      case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) =>
+        apply(functionRegistry.clickHouseToSparkFunc(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
+      case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
    }
 
-  def toClickHouse(transform: Transform): Expr = transform match {
-    case YearsTransform(FieldReference(Seq(col))) => FuncExpr("toYear", List(FieldRef(col)))
-    case MonthsTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMM", List(FieldRef(col)))
-    case DaysTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMMDD", List(FieldRef(col)))
-    case HoursTransform(FieldReference(Seq(col))) => FuncExpr("toHour", List(FieldRef(col)))
+  def toClickHouse(
+    transform: Transform,
+    functionRegistry: FunctionRegistry
+  ): Expr = transform match {
     case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
-    case ApplyTransform(name, args) => FuncExpr(name, args.map(arg => SQLExpr(arg.describe())).toList)
+    case ApplyTransform(name, args) if functionRegistry.sparkToClickHouseFunc.contains(name) =>
+      FuncExpr(functionRegistry.sparkToClickHouseFunc(name), args.map(arg => SQLExpr(arg.describe)).toList)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }
@@ -104,16 +201,18 @@ object ExprUtils extends SQLConfHelper {
   def inferTransformSchema(
     primarySchema: StructType,
     secondarySchema: StructType,
-    transform: Transform
+    transform: Transform,
+    functionRegistry: FunctionRegistry
   ): StructField = transform match {
-    case years: YearsTransform => StructField(years.toString, IntegerType)
-    case months: MonthsTransform => StructField(months.toString, IntegerType)
-    case days: DaysTransform => StructField(days.toString, IntegerType)
-    case hours: HoursTransform => StructField(hours.toString, IntegerType)
     case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
         .orElse(secondarySchema.find(_.name == col))
         .getOrElse(throw CHClientException(s"Invalid partition column: $col"))
-    case ckXxhHash64 @ ApplyTransform("ck_xx_hash64", _) => StructField(ckXxhHash64.toString, LongType)
+    case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined =>
+      val resType = functionRegistry.load(transformName) match {
+        case Some(f: ScalarFunction[_]) => f.resultType
+        case other => throw CHClientException(s"Unsupported function: $other")
+      }
+      StructField(t.toString, resType)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }
diff --git a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala
index 02862392..9698e823 100644
--- a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala
+++ b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala
@@ -26,7 +26,7 @@ import xenon.clickhouse.Constants._
 import xenon.clickhouse.client.NodeClient
 import xenon.clickhouse.exception.CHClientException
 import xenon.clickhouse.exception.ClickHouseErrCode._
-import xenon.clickhouse.func.{FunctionRegistry, _}
+import xenon.clickhouse.func.{ClickHouseXxHash64Shard, FunctionRegistry, _}
 import xenon.clickhouse.spec._
 
 import java.time.ZoneId
@@ -91,6 +91,7 @@ class ClickHouseCatalog extends TableCatalog
 
     log.info(s"Detect ${clusterSpecs.size} ClickHouse clusters: ${clusterSpecs.map(_.name).mkString(",")}")
     log.info(s"ClickHouse clusters' detail: $clusterSpecs")
+    log.info(s"Registered functions: ${this.functionRegistry.list.mkString(",")}")
   }
 
   override def name(): String = catalogName
@@ -141,7 +142,8 @@ class ClickHouseCatalog extends TableCatalog
       tableClusterSpec,
       _tz,
       tableSpec,
-      tableEngineSpec
+      tableEngineSpec,
+      functionRegistry
     )
   }
 
@@ -206,7 +208,7 @@ class ClickHouseCatalog extends TableCatalog
 
     val partitionsClause = partitions match {
       case transforms if transforms.nonEmpty =>
-        transforms.map(ExprUtils.toClickHouse(_).sql).mkString("PARTITION BY (", ", ", ")")
+        transforms.map(ExprUtils.toClickHouse(_, functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
       case _ => ""
     }
 
@@ -297,7 +299,7 @@ class ClickHouseCatalog extends TableCatalog
     }
     tableOpt match {
       case None => false
-      case Some(ClickHouseTable(_, cluster, _, tableSpec, _)) =>
+      case Some(ClickHouseTable(_, cluster, _, tableSpec, _, _)) =>
        val (db, tbl) = (tableSpec.database, tableSpec.name)
        val isAtomic = loadNamespaceMetadata(Array(db)).get("engine").equalsIgnoreCase("atomic")
        val syncClause = if (isAtomic) "SYNC" else ""
diff --git a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala
index 59b3ca9f..83846c34 100644
--- a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala
+++ b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala
@@ -14,16 +14,12 @@
 
 package xenon.clickhouse
 
-import java.lang.{Integer => JInt, Long => JLong}
-import java.time.{LocalDate, ZoneId}
-import java.util
-import scala.collection.JavaConverters._
-import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.{READ_DISTRIBUTED_CONVERT_LOCAL, USE_NULLABLE_QUERY_SCHEMA}
-import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
 import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.connector.write.LogicalWriteInfo
@@ -34,16 +30,23 @@ import org.apache.spark.unsafe.types.UTF8String
 import xenon.clickhouse.Utils._
 import xenon.clickhouse.client.NodeClient
 import xenon.clickhouse.expr.{Expr, OrderExpr}
+import xenon.clickhouse.func.FunctionRegistry
 import xenon.clickhouse.read.{ClickHouseMetadataColumn, ClickHouseScanBuilder, ScanJobDescription}
 import xenon.clickhouse.spec._
 import xenon.clickhouse.write.{ClickHouseWriteBuilder, WriteJobDescription}
 
+import java.lang.{Integer => JInt, Long => JLong}
+import java.time.{LocalDate, ZoneId}
+import java.util
+import scala.collection.JavaConverters._
+
 case class ClickHouseTable(
   node: NodeSpec,
   cluster: Option[ClusterSpec],
   implicit val tz: ZoneId,
   spec: TableSpec,
-  engineSpec: TableEngineSpec
+  engineSpec: TableEngineSpec,
+  functionRegistry: FunctionRegistry
 ) extends Table
     with SupportsRead
     with SupportsWrite
@@ -130,10 +133,12 @@ case class ClickHouseTable(
   private lazy val metadataSchema: StructType =
     StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField))
 
-  override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey)
+  override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey, functionRegistry)
 
   override lazy val partitionSchema: StructType = StructType(
-    partitioning.map(partTransform => ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform))
+    partitioning.map { partTransform =>
+      ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform, functionRegistry)
+    }
   )
 
   override lazy val properties: util.Map[String, String] = spec.toJavaMap
@@ -170,7 +175,8 @@ case class ClickHouseTable(
       shardingKey = shardingKey,
       partitionKey = partitionKey,
      sortingKey = sortingKey,
-      writeOptions = new WriteOptions(info.options.asCaseSensitiveMap())
+      writeOptions = new WriteOptions(info.options.asCaseSensitiveMap()),
+      functionRegistry = functionRegistry
    )
 
    new ClickHouseWriteBuilder(writeJob)
diff --git a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala
index d18319e5..56e1b457 100644
--- a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala
+++ b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala
@@ -17,7 +17,8 @@ package xenon.clickhouse.write
 import com.clickhouse.client.ClickHouseProtocol
 import com.clickhouse.data.ClickHouseCompression
 import org.apache.commons.io.IOUtils
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, SafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{Projection, SafeProjection}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.clickhouse.ExprUtils
 import org.apache.spark.sql.connector.metric.CustomTaskMetric
@@ -56,7 +57,7 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
   protected lazy val shardExpr: Option[Expression] = writeJob.sparkShardExpr match {
     case None => None
     case Some(v2Expr) =>
-      val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields)
+      val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields, writeJob.functionRegistry)
       catalystExpr match {
         case BoundReference(_, dataType, _)
             if dataType.isInstanceOf[ByteType] // list all integral types here because we can not access `IntegralType`
@@ -66,6 +67,11 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
           Some(catalystExpr)
         case BoundReference(_, dataType, _) =>
           throw CHClientException(s"Invalid data type of sharding field: $dataType")
+        case TransformExpression(function, _, _) =>
+          function.resultType match {
+            case ByteType | ShortType | IntegerType | LongType => Some(catalystExpr)
+            case _ => throw CHClientException(s"Invalid data type of sharding field: ${function.resultType}")
+          }
         case unsupported: Expression =>
           log.warn(s"Unsupported expression of sharding field: $unsupported")
           None
@@ -74,7 +80,21 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
 
   protected lazy val shardProjection: Option[expressions.Projection] = shardExpr
     .filter(_ => writeJob.writeOptions.convertDistributedToLocal)
-    .map(expr => SafeProjection.create(Seq(expr)))
+    .flatMap {
+      case expr: BoundReference =>
+        Some(SafeProjection.create(Seq(expr)))
+      case expr @ TransformExpression(function, _, _) =>
+        // result type must be integer class
+        function.resultType match {
+          case ByteType => classOf[Byte]
+          case ShortType => classOf[Short]
+          case IntegerType => classOf[Int]
+          case LongType => classOf[Long]
+          case _ => throw CHClientException(s"Invalid return data type for function ${function.name()}," +
+              s"sharding field: ${function.resultType}")
+        }
+        Some(SafeProjection.create(Seq(ExprUtils.resolveTransformCatalyst(expr, Some(writeJob.tz.getId)))))
+    }
 
   // put the node select strategy in executor side because we need to calculate shard and don't know the records
   // util DataWriter#write(InternalRow) invoked.
@@ -99,17 +119,23 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
 
   def calcShard(record: InternalRow): Option[Int] = (shardExpr, shardProjection) match {
     case (Some(BoundReference(_, dataType, _)), Some(projection)) =>
-      val shardValue = dataType match {
-        case ByteType => Some(projection(record).getByte(0).toLong)
-        case ShortType => Some(projection(record).getShort(0).toLong)
-        case IntegerType => Some(projection(record).getInt(0).toLong)
-        case LongType => Some(projection(record).getLong(0))
-        case _ => None
-      }
-      shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
+      doCalcShard(record, dataType, projection)
+    case (Some(TransformExpression(function, _, _)), Some(projection)) =>
+      doCalcShard(record, function.resultType, projection)
     case _ => None
   }
 
+  private def doCalcShard(record: InternalRow, dataType: DataType, projection: Projection): Option[Int] = {
+    val shardValue = dataType match {
+      case ByteType => Some(projection(record).getByte(0).toLong)
+      case ShortType => Some(projection(record).getShort(0).toLong)
+      case IntegerType => Some(projection(record).getInt(0).toLong)
+      case LongType => Some(projection(record).getLong(0))
+      case _ => None
+    }
+    shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
+  }
+
   val _currentBufferedRows = new LongAdder
   def currentBufferedRows: Long = _currentBufferedRows.longValue
   val _totalRecordsWritten = new LongAdder
diff --git a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala
index 9cd8262f..411f08a4 100644
--- a/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala
+++ b/spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala
@@ -15,11 +15,11 @@
 package xenon.clickhouse.write
 
 import java.time.ZoneId
-
 import org.apache.spark.sql.clickhouse.{ExprUtils, WriteOptions}
 import org.apache.spark.sql.connector.expressions.{Expression, SortOrder, Transform}
 import org.apache.spark.sql.types.StructType
 import xenon.clickhouse.expr.{Expr, FuncExpr, OrderExpr}
+import xenon.clickhouse.func.FunctionRegistry
 import xenon.clickhouse.spec._
 
 case class WriteJobDescription(
@@ -37,7 +38,8 @@ case class WriteJobDescription(
   shardingKey: Option[Expr],
   partitionKey: Option[List[Expr]],
   sortingKey: Option[List[OrderExpr]],
-  writeOptions: WriteOptions
+  writeOptions: WriteOptions,
+  functionRegistry: FunctionRegistry
 ) {
 
   def targetDatabase(convert2Local: Boolean): String = tableEngineSpec match {
@@ -56,20 +57,28 @@ case class WriteJobDescription(
   }
 
   def sparkShardExpr: Option[Expression] = shardingKeyIgnoreRand match {
-    case Some(expr) => ExprUtils.toSparkTransformOpt(expr)
+    case Some(expr) => ExprUtils.toSparkTransformOpt(expr, functionRegistry)
     case _ => None
   }
 
   def sparkSplits: Array[Transform] =
     if (writeOptions.repartitionByPartition) {
-      ExprUtils.toSparkSplits(shardingKeyIgnoreRand, partitionKey)
+      ExprUtils.toSparkSplits(
+        shardingKeyIgnoreRand,
+        partitionKey,
+        functionRegistry
+      )
     } else {
-      ExprUtils.toSparkSplits(shardingKeyIgnoreRand, None)
+      ExprUtils.toSparkSplits(
+        shardingKeyIgnoreRand,
+        None,
+        functionRegistry
+      )
     }
 
   def sparkSortOrders: Array[SortOrder] = {
     val _partitionKey = if (writeOptions.localSortByPartition) partitionKey else None
     val _sortingKey = if (writeOptions.localSortByKey) sortingKey else None
-    ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey)
+    ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster, functionRegistry)
   }
 }
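
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch above): roughly what a V2 function
// registered in FunctionRegistry is expected to look like for this write path.
// ExprUtils.loadV2FunctionOpt binds such an UnboundFunction against the argument
// schema, and ClickHouseWriter uses the bound ScalarFunction's resultType (an
// integral type) to validate and evaluate the sharding expression. The object
// and function names below are hypothetical; only the public Spark API in
// org.apache.spark.sql.connector.catalog.functions is assumed.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

object ExampleCityHash64 extends UnboundFunction with ScalarFunction[Long] {

  override def name(): String = "example_city_hash64"

  override def canonicalName(): String = s"clickhouse.${name()}"

  override def description(): String = s"${name()}(value: STRING) => LONG"

  // Called by ExprUtils.loadV2FunctionOpt with a struct describing the arguments.
  override def bind(inputType: StructType): BoundFunction = inputType.fields match {
    case Array(StructField(_, StringType, _, _)) => this
    case _ => throw new UnsupportedOperationException(s"Expect 1 STRING argument, got $inputType")
  }

  override def inputTypes(): Array[DataType] = Array(StringType)

  // An integral result type is required by ClickHouseWriter.shardExpr/shardProjection.
  override def resultType(): DataType = LongType

  // Evaluated through the SafeProjection built from the resolved TransformExpression.
  // A real implementation would mirror the corresponding ClickHouse hash function;
  // a plain JVM hash is used here only to keep the sketch self-contained.
  override def produceResult(input: InternalRow): Long =
    input.getUTF8String(0).toString.hashCode.toLong
}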