Spark 3.4: Write supports using V2 functions for distribution and ordering #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Merged (1 commit) on Aug 27, 2023

@@ -78,12 +78,8 @@ class WriteDistributionAndOrderingSuite extends SparkClickHouseSingleTest {
WRITE_REPARTITION_BY_PARTITION.key -> repartitionByPartition.toString,
WRITE_LOCAL_SORT_BY_KEY.key -> localSortByKey.toString
) {
if (!ignoreUnsupportedTransform && repartitionByPartition) {
intercept[AnalysisException](write())
} else {
write()
check()
}
write()
check()
}

Seq(true, false).foreach { ignoreUnsupportedTransform =>

@@ -15,105 +15,204 @@
package org.apache.spark.sql.clickhouse

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal}
import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, _}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _}
import org.apache.spark.sql.types.{StructField, StructType}
import xenon.clickhouse.exception.CHClientException
import xenon.clickhouse.expr._
import xenon.clickhouse.func.FunctionRegistry
import xenon.clickhouse.spec.ClusterSpec

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object ExprUtils extends SQLConfHelper {
object ExprUtils extends SQLConfHelper with Serializable {

def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
def toSparkPartitions(
partitionKey: Option[List[Expr]],
functionRegistry: FunctionRegistry
): Array[Transform] =
partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray

def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
(shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
def toSparkSplits(
shardingKey: Option[Expr],
partitionKey: Option[List[Expr]],
functionRegistry: FunctionRegistry
): Array[Transform] =
(shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray

def toSparkSortOrders(
shardingKeyIgnoreRand: Option[Expr],
partitionKey: Option[List[Expr]],
sortingKey: Option[List[OrderExpr]]
): Array[SortOrder] =
toSparkSplits(shardingKeyIgnoreRand, partitionKey).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
sortingKey: Option[List[OrderExpr]],
cluster: Option[ClusterSpec],
functionRegistry: FunctionRegistry
): Array[V2SortOrder] =
toSparkSplits(
shardingKeyIgnoreRand,
partitionKey,
functionRegistry
).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
toSparkTransformOpt(expr, functionRegistry).map(trans =>
Expressions.sort(trans, direction, nullOrder)
)
}.toArray

@tailrec
def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
private def loadV2FunctionOpt(
name: String,
args: Seq[Expression],
functionRegistry: FunctionRegistry
): Option[BoundFunction] = {
def loadFunction(ident: Identifier): UnboundFunction =
functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
val inputType = StructType(args.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
try {
val unbound = loadFunction(Identifier.of(Array.empty, name))
Some(unbound.bind(inputType))
} catch {
case e: NoSuchFunctionException =>
throw e
case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) =>
None
case e: UnsupportedOperationException =>
throw new AnalysisException(e.getMessage, cause = Some(e))
}
}
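
A minimal sketch, not part of this diff, of the DataSource V2 function contract that loadV2FunctionOpt relies on: the registry hands back an UnboundFunction by name, bind validates the argument types derived from the catalyst expressions, and the bound ScalarFunction later supplies resultType for schema inference. The object HypotheticalHash and the name "ck_murmur_hash" are invented for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

object HypotheticalHash extends UnboundFunction with ScalarFunction[Long] {
  override def name: String = "ck_murmur_hash"
  override def canonicalName: String = s"clickhouse.$name"
  override def description: String = s"$name(col): hypothetical sharding hash"

  // Accept exactly one STRING argument; anything else fails at bind time, which
  // loadV2FunctionOpt above turns into None or an AnalysisException depending on
  // IGNORE_UNSUPPORTED_TRANSFORM.
  override def bind(inputType: StructType): BoundFunction = inputType.fields match {
    case Array(StructField(_, StringType, _, _)) => this
    case _ => throw new UnsupportedOperationException(s"$name accepts one STRING argument")
  }

  override def inputTypes: Array[DataType] = Array(StringType)
  override def resultType: DataType = LongType

  // Interpreted fallback; a real implementation must reproduce the ClickHouse hash exactly.
  override def produceResult(input: InternalRow): Long =
    input.getUTF8String(0).hashCode.toLong
}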

Review comment (Collaborator Author): follow SPARK-44180

def resolveTransformCatalyst(
catalystExpr: Expression,
timeZoneId: Option[String] = None
): Expression =
new TypeCoercionExecutor(timeZoneId)
.execute(DummyLeafNode(resolveTransformExpression(catalystExpr)))
.asInstanceOf[DummyLeafNode].expr

private case class DummyLeafNode(expr: Expression) extends LeafNode {
override def output: Seq[Attribute] = Nil
}

private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] {
private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone))
// Casts could be added in the subquery plan through the rule TypeCoercion while coercing
// the types between the value expression and list query expression of IN expression.
// We need to subject the subquery plan through ResolveTimeZone again to setup timezone
// information for time zone aware expressions.
case e: ListQuery => e.withNewPlan(apply(e.plan))
}

override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveExpressionsWithPruning(
_.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION),
ruleId
)(transformTimeZoneExprs)
}

private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] {
override val batches =
Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) ::
Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil
}

private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}

private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
AnsiTypeCoercion.typeCoercionRules
} else {
TypeCoercion.typeCoercionRules
}
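
A usage sketch for the resolution path above, mirroring the upstream change referenced in the review comment (SPARK-44180). Assumptions: registry is a FunctionRegistry that already contains the hypothetical "ck_murmur_hash" function from the earlier sketch, and the write schema has a single STRING column id.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.types.{StringType, StructField}
import org.apache.spark.unsafe.types.UTF8String

// registry: FunctionRegistry assumed to already contain "ck_murmur_hash" (construction omitted).
val fields = Array(StructField("id", StringType))
val transform = Expressions.apply("ck_murmur_hash", Expressions.column("id"))

// toCatalyst yields TransformExpression(<bound ck_murmur_hash>, BoundReference(0, StringType)).
val unresolved = ExprUtils.toCatalyst(transform, fields, registry)

// resolveTransformCatalyst substitutes the bound scalar function via
// V2ExpressionUtils.resolveScalarFunction, then applies type coercion and timezone
// resolution; the result is an evaluable expression.
val resolved = ExprUtils.resolveTransformCatalyst(unresolved, timeZoneId = Some("UTC"))
val shardValue = resolved.eval(InternalRow(UTF8String.fromString("user-42")))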

def toCatalyst(
v2Expr: V2Expression,
fields: Array[StructField],
functionRegistry: FunctionRegistry
): Expression =
v2Expr match {
case IdentityTransform(ref) => toCatalyst(ref, fields)
case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
case ref: NamedReference if ref.fieldNames.length == 1 =>
val (field, ordinal) = fields
.zipWithIndex
.find { case (field, _) => field.name == ref.fieldNames.head }
.getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
BoundReference(ordinal, field.dataType, field.nullable)
case t: Transform =>
val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
.map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
throw CHClientException(s"Unsupported expression: $v2Expr")
}
case literal: LiteralValue[Any] => expressions.Literal(literal.value)
case _ => throw CHClientException(
s"Unsupported V2 expression: $v2Expr, SPARK-33779: Spark 3.3 only support IdentityTransform"
s"Unsupported expression: $v2Expr"
)
}

def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkTransform(expr)) match {
case Success(t) => Some(t)
case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
}
def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
Try(toSparkExpression(expr, functionRegistry)) match {
// need this function because spark `Table`'s `partitioning` field should be `Transform`
case Success(t: Transform) => Some(t)
case Success(_) => None
case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
}

// Some functions of ClickHouse which match Spark pre-defined Transforms
//
// toYear, YEAR - Converts a date or date with time to a UInt16 (AD)
// toYYYYMM - Converts a date or date with time to a UInt32 (YYYY*100 + MM)
// toYYYYMMDD - Converts a date or date with time to a UInt32 (YYYY*10000 + MM*100 + DD)
// toHour, HOUR - Converts a date with time to a UInt8 (0-23)

def toSparkTransform(expr: Expr): Transform = expr match {
case FieldRef(col) => identity(col)
case FuncExpr("toYear", List(FieldRef(col))) => years(col)
case FuncExpr("YEAR", List(FieldRef(col))) => years(col)
case FuncExpr("toYYYYMM", List(FieldRef(col))) => months(col)
case FuncExpr("toYYYYMMDD", List(FieldRef(col))) => days(col)
case FuncExpr("toHour", List(FieldRef(col))) => hours(col)
case FuncExpr("HOUR", List(FieldRef(col))) => hours(col)
// TODO support arbitrary functions
// case FuncExpr("xxHash64", List(FieldRef(col))) => apply("ck_xx_hash64", column(col))
case FuncExpr("rand", Nil) => apply("rand")
case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
}
def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
expr match {
case FieldRef(col) => identity(col)
case StringLiteral(value) => literal(value) // TODO LiteralTransform
case FuncExpr("rand", Nil) => apply("rand")
case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) =>
apply(functionRegistry.clickHouseToSparkFunc(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
}
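
A short sketch of the DDL-to-Spark mapping above, reusing the same assumed registry and a hypothetical entry that pairs ClickHouse's murmurHash2_64 with the Spark-side name "ck_murmur_hash"; functions the registry does not know still fall back to None or an AnalysisException in toSparkTransformOpt, controlled by IGNORE_UNSUPPORTED_TRANSFORM.

import xenon.clickhouse.expr.{FieldRef, FuncExpr}

// Sharding key murmurHash2_64(id) as parsed from the ClickHouse table DDL.
val shardingExpr = FuncExpr("murmurHash2_64", List(FieldRef("id")))

// Some(apply("ck_murmur_hash", column("id"))) when the registry knows the function.
val transformOpt = ExprUtils.toSparkTransformOpt(shardingExpr, registry)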

def toClickHouse(transform: Transform): Expr = transform match {
case YearsTransform(FieldReference(Seq(col))) => FuncExpr("toYear", List(FieldRef(col)))
case MonthsTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMM", List(FieldRef(col)))
case DaysTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMMDD", List(FieldRef(col)))
case HoursTransform(FieldReference(Seq(col))) => FuncExpr("toHour", List(FieldRef(col)))
def toClickHouse(
transform: Transform,
functionRegistry: FunctionRegistry
): Expr = transform match {
case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
case ApplyTransform(name, args) => FuncExpr(name, args.map(arg => SQLExpr(arg.describe())).toList)
case ApplyTransform(name, args) if functionRegistry.sparkToClickHouseFunc.contains(name) =>
FuncExpr(functionRegistry.sparkToClickHouseFunc(name), args.map(arg => SQLExpr(arg.describe)).toList)
case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
case other: Transform => throw CHClientException(s"Unsupported transform: $other")
}

def inferTransformSchema(
primarySchema: StructType,
secondarySchema: StructType,
transform: Transform
transform: Transform,
functionRegistry: FunctionRegistry
): StructField = transform match {
case years: YearsTransform => StructField(years.toString, IntegerType)
case months: MonthsTransform => StructField(months.toString, IntegerType)
case days: DaysTransform => StructField(days.toString, IntegerType)
case hours: HoursTransform => StructField(hours.toString, IntegerType)
case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
.orElse(secondarySchema.find(_.name == col))
.getOrElse(throw CHClientException(s"Invalid partition column: $col"))
case ckXxhHash64 @ ApplyTransform("ck_xx_hash64", _) => StructField(ckXxhHash64.toString, LongType)
case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined =>
val resType = functionRegistry.load(transformName) match {
case Some(f: ScalarFunction[_]) => f.resultType
case other => throw CHClientException(s"Unsupported function: $other")
}
StructField(t.toString, resType)
case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
case other: Transform => throw CHClientException(s"Unsupported transform: $other")
}
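
Continuing the same hypothetical registration, a sketch of how the partition schema field for a function-based transform is now inferred from the bound ScalarFunction instead of a hard-coded transform list:

import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val tableSchema = StructType(Seq(StructField("id", StringType)))
val partTransform = Expressions.apply("ck_murmur_hash", Expressions.column("id"))

// The field is named after the transform's string form (e.g. "ck_murmur_hash(id)")
// and typed by the registered function's resultType (LongType in the sketch).
val partField = ExprUtils.inferTransformSchema(tableSchema, new StructType(), partTransform, registry)
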
@@ -26,7 +26,7 @@ import xenon.clickhouse.Constants._
import xenon.clickhouse.client.NodeClient
import xenon.clickhouse.exception.CHClientException
import xenon.clickhouse.exception.ClickHouseErrCode._
import xenon.clickhouse.func.{FunctionRegistry, _}
import xenon.clickhouse.func.{ClickHouseXxHash64Shard, FunctionRegistry, _}
import xenon.clickhouse.spec._

import java.time.ZoneId
@@ -91,6 +91,7 @@ class ClickHouseCatalog extends TableCatalog

log.info(s"Detect ${clusterSpecs.size} ClickHouse clusters: ${clusterSpecs.map(_.name).mkString(",")}")
log.info(s"ClickHouse clusters' detail: $clusterSpecs")
log.info(s"Registered functions: ${this.functionRegistry.list.mkString(",")}")
}

override def name(): String = catalogName
@@ -141,7 +142,8 @@ class ClickHouseCatalog extends TableCatalog
tableClusterSpec,
_tz,
tableSpec,
tableEngineSpec
tableEngineSpec,
functionRegistry
)
}

@@ -206,7 +208,7 @@

val partitionsClause = partitions match {
case transforms if transforms.nonEmpty =>
transforms.map(ExprUtils.toClickHouse(_).sql).mkString("PARTITION BY (", ", ", ")")
transforms.map(ExprUtils.toClickHouse(_, functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
case _ => ""
}
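
For context, a sketch of the clause built above, reusing the hypothetical "ck_murmur_hash" / murmurHash2_64 registry mapping and the catalog's functionRegistry field; the rendered SQL is approximate.

import org.apache.spark.sql.clickhouse.ExprUtils
import org.apache.spark.sql.connector.expressions.{Expressions, Transform}

val examplePartitions: Array[Transform] = Array(Expressions.apply("ck_murmur_hash", Expressions.column("id")))
val exampleClause = examplePartitions
  .map(ExprUtils.toClickHouse(_, functionRegistry).sql)
  .mkString("PARTITION BY (", ", ", ")")
// Roughly: PARTITION BY (murmurHash2_64(id))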

@@ -297,7 +299,7 @@ class ClickHouseCatalog extends TableCatalog
}
tableOpt match {
case None => false
case Some(ClickHouseTable(_, cluster, _, tableSpec, _)) =>
case Some(ClickHouseTable(_, cluster, _, tableSpec, _, _)) =>
val (db, tbl) = (tableSpec.database, tableSpec.name)
val isAtomic = loadNamespaceMetadata(Array(db)).get("engine").equalsIgnoreCase("atomic")
val syncClause = if (isAtomic) "SYNC" else ""

@@ -14,16 +14,12 @@

package xenon.clickhouse

import java.lang.{Integer => JInt, Long => JLong}
import java.time.{LocalDate, ZoneId}
import java.util
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.{READ_DISTRIBUTED_CONVERT_LOCAL, USE_NULLABLE_QUERY_SCHEMA}
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.clickhouse.{ExprUtils, ReadOptions, WriteOptions}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.LogicalWriteInfo
@@ -34,16 +30,23 @@ import org.apache.spark.unsafe.types.UTF8String
import xenon.clickhouse.Utils._
import xenon.clickhouse.client.NodeClient
import xenon.clickhouse.expr.{Expr, OrderExpr}
import xenon.clickhouse.func.FunctionRegistry
import xenon.clickhouse.read.{ClickHouseMetadataColumn, ClickHouseScanBuilder, ScanJobDescription}
import xenon.clickhouse.spec._
import xenon.clickhouse.write.{ClickHouseWriteBuilder, WriteJobDescription}

import java.lang.{Integer => JInt, Long => JLong}
import java.time.{LocalDate, ZoneId}
import java.util
import scala.collection.JavaConverters._

case class ClickHouseTable(
node: NodeSpec,
cluster: Option[ClusterSpec],
implicit val tz: ZoneId,
spec: TableSpec,
engineSpec: TableEngineSpec
engineSpec: TableEngineSpec,
functionRegistry: FunctionRegistry
) extends Table
with SupportsRead
with SupportsWrite
@@ -130,10 +133,12 @@ case class ClickHouseTable(
private lazy val metadataSchema: StructType =
StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField))

override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey)
override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey, functionRegistry)

override lazy val partitionSchema: StructType = StructType(
partitioning.map(partTransform => ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform))
partitioning.map { partTransform =>
ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform, functionRegistry)
}
)

override lazy val properties: util.Map[String, String] = spec.toJavaMap
@@ -170,7 +175,8 @@ case class ClickHouseTable(
shardingKey = shardingKey,
partitionKey = partitionKey,
sortingKey = sortingKey,
writeOptions = new WriteOptions(info.options.asCaseSensitiveMap())
writeOptions = new WriteOptions(info.options.asCaseSensitiveMap()),
functionRegistry = functionRegistry
)

new ClickHouseWriteBuilder(writeJob)