diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 285044803d73..0e8e19b69294 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -37,6 +37,7 @@ use datafusion_common::pruning::{
 };
 use datafusion_common::{exec_err, Result};
 use datafusion_datasource::PartitionedFile;
+use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
 use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder};
@@ -117,7 +118,6 @@ impl FileOpener for ParquetOpener {
 
         let projected_schema =
             SchemaRef::from(self.logical_file_schema.project(&self.projection)?);
-        let schema_adapter_factory = Arc::clone(&self.schema_adapter_factory);
         let schema_adapter = self
             .schema_adapter_factory
             .create(projected_schema, Arc::clone(&self.logical_file_schema));
@@ -159,7 +159,7 @@ impl FileOpener for ParquetOpener {
             if let Some(pruning_predicate) = pruning_predicate {
                 // The partition column schema is the schema of the table - the schema of the file
                 let mut pruning = Box::new(PartitionPruningStatistics::try_new(
-                    vec![file.partition_values],
+                    vec![file.partition_values.clone()],
                     partition_fields.clone(),
                 )?) as Box<dyn PruningStatistics>;
 
@@ -248,10 +248,25 @@
                 }
             }
 
+            let predicate = predicate
+                .map(|p| {
+                    PhysicalExprSchemaRewriter::new(
+                        &physical_file_schema,
+                        &logical_file_schema,
+                    )
+                    .with_partition_columns(
+                        partition_fields.to_vec(),
+                        file.partition_values,
+                    )
+                    .rewrite(p)
+                    .map_err(ArrowError::from)
+                })
+                .transpose()?;
+
             // Build predicates for this specific file
             let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates(
                 predicate.as_ref(),
-                &logical_file_schema,
+                &physical_file_schema,
                 &predicate_creation_errors,
             );
 
@@ -288,11 +303,9 @@ impl FileOpener for ParquetOpener {
                 let row_filter = row_filter::build_row_filter(
                     &predicate,
                     &physical_file_schema,
-                    &logical_file_schema,
                     builder.metadata(),
                     reorder_predicates,
                     &file_metrics,
-                    &schema_adapter_factory,
                 );
 
                 match row_filter {
@@ -879,4 +892,115 @@ mod test {
         assert_eq!(num_batches, 0);
         assert_eq!(num_rows, 0);
     }
+
+    #[tokio::test]
+    async fn test_prune_on_partition_value_and_data_value() {
+        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+
+        // Note: number 3 is missing!
+        let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap();
+        let data_size =
+            write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await;
+
+        let file_schema = batch.schema();
+        let mut file = PartitionedFile::new(
+            "part=1/file.parquet".to_string(),
+            u64::try_from(data_size).unwrap(),
+        );
+        file.partition_values = vec![ScalarValue::Int32(Some(1))];
+
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("part", DataType::Int32, false),
+            Field::new("a", DataType::Int32, false),
+        ]));
+
+        let make_opener = |predicate| {
+            ParquetOpener {
+                partition_index: 0,
+                projection: Arc::new([0]),
+                batch_size: 1024,
+                limit: None,
+                predicate: Some(predicate),
+                logical_file_schema: file_schema.clone(),
+                metadata_size_hint: None,
+                metrics: ExecutionPlanMetricsSet::new(),
+                parquet_file_reader_factory: Arc::new(
+                    DefaultParquetFileReaderFactory::new(Arc::clone(&store)),
+                ),
+                partition_fields: vec![Arc::new(Field::new(
+                    "part",
+                    DataType::Int32,
+                    false,
+                ))],
+                pushdown_filters: true, // note that this is true!
+                reorder_filters: true,
+                enable_page_index: false,
+                enable_bloom_filter: false,
+                schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory),
+                enable_row_group_stats_pruning: false, // note that this is false!
+                coerce_int96: None,
+            }
+        };
+
+        let make_meta = || FileMeta {
+            object_meta: ObjectMeta {
+                location: Path::from("part=1/file.parquet"),
+                last_modified: Utc::now(),
+                size: u64::try_from(data_size).unwrap(),
+                e_tag: None,
+                version: None,
+            },
+            range: None,
+            extensions: None,
+            metadata_size_hint: None,
+        };
+
+        // Filter should match the partition value and data value
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 3);
+
+        // Filter should match the partition value but not the data value
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 3);
+
+        // Filter should not match the partition value but match the data value
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 1);
+
+        // Filter should not match the partition value or the data value
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener.open(make_meta(), file).unwrap().await.unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 0);
+        assert_eq!(num_rows, 0);
+    }
 }
diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs
index db455fed6160..5626f83186e3 100644
--- a/datafusion/datasource-parquet/src/row_filter.rs
+++ b/datafusion/datasource-parquet/src/row_filter.rs
@@ -67,6 +67,7 @@ use arrow::array::BooleanArray;
 use arrow::datatypes::{DataType, Schema, SchemaRef};
 use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
+use itertools::Itertools;
 use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
 use parquet::arrow::ProjectionMask;
 use parquet::file::metadata::ParquetMetaData;
@@ -74,9 +75,8 @@ use parquet::file::metadata::ParquetMetaData;
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::Result;
-use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::utils::reassign_predicate_columns;
+use datafusion_physical_expr::utils::{collect_columns, reassign_predicate_columns};
 use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
 use datafusion_physical_plan::metrics;
 
@@ -106,8 +106,6 @@ pub(crate) struct DatafusionArrowPredicate {
     rows_matched: metrics::Count,
     /// how long was spent evaluating this predicate
     time: metrics::Time,
-    /// used to perform type coercion while filtering rows
-    schema_mapper: Arc<dyn SchemaMapper>,
 }
 
 impl DatafusionArrowPredicate {
@@ -132,7 +130,6 @@ impl DatafusionArrowPredicate {
             rows_pruned,
             rows_matched,
             time,
-            schema_mapper: candidate.schema_mapper,
         })
     }
 }
@@ -143,8 +140,6 @@ impl ArrowPredicate for DatafusionArrowPredicate {
     }
 
     fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult<BooleanArray> {
-        let batch = self.schema_mapper.map_batch(batch)?;
-
         // scoped timer updates on drop
         let mut timer = self.time.timer();
 
@@ -187,9 +182,6 @@ pub(crate) struct FilterCandidate {
     /// required to pass through a `SchemaMapper` to the table schema
     /// upon which we then evaluate the filter expression.
     projection: Vec<usize>,
-    /// A `SchemaMapper` used to map batches read from the file schema to
-    /// the filter's projection of the table schema.
-    schema_mapper: Arc<dyn SchemaMapper>,
     /// The projected table schema that this filter references
     filter_schema: SchemaRef,
 }
@@ -230,26 +222,11 @@ struct FilterCandidateBuilder {
     /// columns in the file schema that are not in the table schema or columns that
     /// are in the table schema that are not in the file schema.
     file_schema: SchemaRef,
-    /// The schema of the table (merged schema) -- columns may be in different
-    /// order than in the file and have columns that are not in the file schema
-    table_schema: SchemaRef,
-    /// A `SchemaAdapterFactory` used to map the file schema to the table schema.
-    schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
 }
 
 impl FilterCandidateBuilder {
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        file_schema: Arc<Schema>,
-        table_schema: Arc<Schema>,
-        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
-    ) -> Self {
-        Self {
-            expr,
-            file_schema,
-            table_schema,
-            schema_adapter_factory,
-        }
+    pub fn new(expr: Arc<dyn PhysicalExpr>, file_schema: Arc<Schema>) -> Self {
+        Self { expr, file_schema }
     }
 
     /// Attempt to build a `FilterCandidate` from the expression
@@ -261,20 +238,21 @@ impl FilterCandidateBuilder {
     /// * `Err(e)` if an error occurs while building the candidate
     pub fn build(self, metadata: &ParquetMetaData) -> Result<Option<FilterCandidate>> {
         let Some(required_indices_into_table_schema) =
-            pushdown_columns(&self.expr, &self.table_schema)?
+            pushdown_columns(&self.expr, &self.file_schema)?
         else {
             return Ok(None);
         };
 
         let projected_table_schema = Arc::new(
-            self.table_schema
+            self.file_schema
                 .project(&required_indices_into_table_schema)?,
         );
 
-        let (schema_mapper, projection_into_file_schema) = self
-            .schema_adapter_factory
-            .create(Arc::clone(&projected_table_schema), self.table_schema)
-            .map_schema(&self.file_schema)?;
+        let projection_into_file_schema = collect_columns(&self.expr)
+            .iter()
+            .map(|c| c.index())
+            .sorted_unstable()
+            .collect_vec();
 
         let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?;
         let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?;
@@ -284,7 +262,6 @@ impl FilterCandidateBuilder {
             required_bytes,
             can_use_index,
             projection: projection_into_file_schema,
-            schema_mapper: Arc::clone(&schema_mapper),
             filter_schema: Arc::clone(&projected_table_schema),
         }))
     }
@@ -426,11 +403,9 @@ fn columns_sorted(_columns: &[usize], _metadata: &ParquetMetaData) -> Result<bool> {
 pub fn build_row_filter(
     expr: &Arc<dyn PhysicalExpr>,
     physical_file_schema: &SchemaRef,
-    logical_file_schema: &SchemaRef,
     metadata: &ParquetMetaData,
     reorder_predicates: bool,
     file_metrics: &ParquetFileMetrics,
-    schema_adapter_factory: &Arc<dyn SchemaAdapterFactory>,
 ) -> Result<Option<RowFilter>> {
     let rows_pruned = &file_metrics.pushdown_rows_pruned;
     let rows_matched = &file_metrics.pushdown_rows_matched;
@@ -447,8 +422,6 @@ pub fn build_row_filter(
             FilterCandidateBuilder::new(
                 Arc::clone(expr),
                 Arc::clone(physical_file_schema),
-                Arc::clone(logical_file_schema),
-                Arc::clone(schema_adapter_factory),
             )
             .build(metadata)
         })
@@ -492,13 +465,9 @@ mod test {
     use super::*;
     use datafusion_common::ScalarValue;
 
-    use arrow::datatypes::{Field, TimeUnit::Nanosecond};
-    use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
     use datafusion_expr::{col, Expr};
     use datafusion_physical_expr::planner::logical2physical;
-    use datafusion_physical_plan::metrics::{Count, Time};
-
-    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
     use parquet::arrow::parquet_to_arrow_schema;
     use parquet::file::reader::{FileReader, SerializedFileReader};
 
@@ -520,111 +489,15 @@ mod test {
         let expr = col("int64_list").is_not_null();
         let expr = logical2physical(&expr, &table_schema);
 
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
         let table_schema = Arc::new(table_schema.clone());
 
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            table_schema.clone(),
-            table_schema,
-            schema_adapter_factory,
-        )
-        .build(metadata)
-        .expect("building candidate");
+        let candidate = FilterCandidateBuilder::new(expr, table_schema.clone())
+            .build(metadata)
+            .expect("building candidate");
 
         assert!(candidate.is_none());
     }
 
-    #[test]
-    fn test_filter_type_coercion() {
-        let testdata = datafusion_common::test_util::parquet_test_data();
-        let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet"))
-            .expect("opening file");
-
-        let parquet_reader_builder =
-            ParquetRecordBatchReaderBuilder::try_new(file).expect("creating reader");
-        let metadata = parquet_reader_builder.metadata().clone();
-        let file_schema = parquet_reader_builder.schema().clone();
-
-        // This is the schema we would like to coerce to,
-        // which is different from the physical schema of the file.
-        let table_schema = Schema::new(vec![Field::new(
-            "timestamp_col",
-            DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))),
-            false,
-        )]);
-
-        // Test all should fail
-        let expr = col("timestamp_col").lt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let table_schema = Arc::new(table_schema.clone());
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema.clone(),
-            table_schema.clone(),
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let mut parquet_reader = parquet_reader_builder
-            .with_projection(row_filter.projection().clone())
-            .build()
-            .expect("building reader");
-
-        // Parquet file is small, we only need 1 record batch
-        let first_rb = parquet_reader
-            .next()
-            .expect("expected record batch")
-            .expect("expected error free record batch");
-
-        let filtered = row_filter.evaluate(first_rb.clone());
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8])));
-
-        // Test all should pass
-        let expr = col("timestamp_col").gt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema,
-            table_schema,
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let filtered = row_filter.evaluate(first_rb);
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![true; 8])));
-    }
-
     #[test]
     fn nested_data_structures_prevent_pushdown() {
         let table_schema = Arc::new(get_lists_table_schema());
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 6741f94c9545..f74b739d15a4 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -37,6 +37,7 @@ mod partitioning;
 mod physical_expr;
 pub mod planner;
 mod scalar_function;
+pub mod schema_rewriter;
 pub mod statistics;
 pub mod utils;
 pub mod window;
@@ -67,6 +68,7 @@ pub use datafusion_physical_expr_common::sort_expr::{
 
 pub use planner::{create_physical_expr, create_physical_exprs};
 pub use scalar_function::ScalarFunctionExpr;
+pub use schema_rewriter::PhysicalExprSchemaRewriter;
 pub use utils::{conjunction, conjunction_opt, split_conjunction};
 
 // For backwards compatibility
diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
new file mode 100644
index 000000000000..53af90862435
--- /dev/null
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -0,0 +1,318 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical expression schema rewriting utilities
+
+use std::sync::Arc;
+
+use arrow::compute::can_cast_types;
+use arrow::datatypes::{FieldRef, Schema};
+use datafusion_common::{
+    exec_err,
+    tree_node::{Transformed, TransformedResult, TreeNode},
+    Result, ScalarValue,
+};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+use crate::expressions::{self, CastExpr, Column};
+
+/// Builder for rewriting physical expressions to match different schemas.
+///
+/// # Example
+///
+/// ```rust
+/// use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter;
+/// use arrow::datatypes::Schema;
+///
+/// # fn example(
+/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr::PhysicalExpr>,
+/// #     physical_file_schema: &Schema,
+/// #     logical_file_schema: &Schema,
+/// # ) -> datafusion_common::Result<()> {
+/// let rewriter = PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema);
+/// let adapted_predicate = rewriter.rewrite(predicate)?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct PhysicalExprSchemaRewriter<'a> {
+    physical_file_schema: &'a Schema,
+    logical_file_schema: &'a Schema,
+    partition_fields: Vec<FieldRef>,
+    partition_values: Vec<ScalarValue>,
+}
+
+impl<'a> PhysicalExprSchemaRewriter<'a> {
+    /// Create a new schema rewriter with the given schemas
+    pub fn new(
+        physical_file_schema: &'a Schema,
+        logical_file_schema: &'a Schema,
+    ) -> Self {
+        Self {
+            physical_file_schema,
+            logical_file_schema,
+            partition_fields: Vec::new(),
+            partition_values: Vec::new(),
+        }
+    }
+
+    /// Add partition columns and their corresponding values
+    ///
+    /// When a column reference matches a partition field, it will be replaced
+    /// with the corresponding literal value from partition_values.
+    pub fn with_partition_columns(
+        mut self,
+        partition_fields: Vec<FieldRef>,
+        partition_values: Vec<ScalarValue>,
+    ) -> Self {
+        self.partition_fields = partition_fields;
+        self.partition_values = partition_values;
+        self
+    }
+
+    /// Rewrite the given physical expression to match the target schema
+    ///
+    /// This method applies the following transformations:
+    /// 1. Replaces partition column references with literal values
+    /// 2. Handles missing columns by inserting null literals
+    /// 3. Casts columns when logical and physical schemas have different types
+    pub fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+        expr.transform(|expr| self.rewrite_expr(expr)).data()
+    }
+
+    fn rewrite_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+            return self.rewrite_column(Arc::clone(&expr), column);
+        }
+
+        Ok(Transformed::no(expr))
+    }
+
+    fn rewrite_column(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+        column: &Column,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        // Get the logical field for this column
+        let logical_field = match self.logical_file_schema.field_with_name(column.name())
+        {
+            Ok(field) => field,
+            Err(e) => {
+                // If the column is a partition field, we can use the partition value
+                if let Some(partition_value) = self.get_partition_value(column.name()) {
+                    return Ok(Transformed::yes(expressions::lit(partition_value)));
+                }
+                // If the column is not found in the logical schema and is not a partition value, return an error.
+                // This should probably never be hit unless something upstream broke, but nonetheless it's better
+                // for us to return a handleable error than to panic / do something unexpected.
+                return Err(e.into());
+            }
+        };
+
+        // Check if the column exists in the physical schema
+        let physical_column_index =
+            match self.physical_file_schema.index_of(column.name()) {
+                Ok(index) => index,
+                Err(_) => {
+                    if !logical_field.is_nullable() {
+                        return exec_err!(
+                            "Non-nullable column '{}' is missing from the physical schema",
+                            column.name()
+                        );
+                    }
+                    // If the column is missing from the physical schema, fill it in with nulls as `SchemaAdapter` would do.
+                    // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
+                    // While the default implementation fills in nulls, in theory a custom `SchemaAdapter` could do something else!
+                    let null_value =
+                        ScalarValue::Null.cast_to(logical_field.data_type())?;
+                    return Ok(Transformed::yes(expressions::lit(null_value)));
+                }
+            };
+        let physical_field = self.physical_file_schema.field(physical_column_index);
+
+        let column = match (
+            column.index() == physical_column_index,
+            logical_field.data_type() == physical_field.data_type(),
+        ) {
+            // If the column index matches and the data types match, we can use the column as is
+            (true, true) => return Ok(Transformed::no(expr)),
+            // If the indexes or data types do not match, we need to create a new column expression
+            (true, _) => column.clone(),
+            (false, _) => {
+                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
+            }
+        };
+
+        if logical_field.data_type() == physical_field.data_type() {
+            // If the data types match, we can use the column as is
+            return Ok(Transformed::yes(Arc::new(column)));
+        }
+
+        // We need to cast the column to the logical data type
+        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
+        // since that's much cheaper to evaluate.
+        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
+        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
+            return exec_err!(
+                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
+                column.name(),
+                physical_field.data_type(),
+                logical_field.data_type()
+            );
+        }
+
+        let cast_expr = Arc::new(CastExpr::new(
+            Arc::new(column),
+            logical_field.data_type().clone(),
+            None,
+        ));
+
+        Ok(Transformed::yes(cast_expr))
+    }
+
+    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
+        self.partition_fields
+            .iter()
+            .zip(self.partition_values.iter())
+            .find(|(field, _)| field.name() == column_name)
+            .map(|(_, value)| value.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::ScalarValue;
+    use std::sync::Arc;
+
+    fn create_test_schema() -> (Schema, Schema) {
+        let physical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int64, false), // Different type
+            Field::new("b", DataType::Utf8, true),
+            Field::new("c", DataType::Float64, true), // Missing from physical
+        ]);
+
+        (physical_schema, logical_schema)
+    }
+
+    #[test]
+    fn test_rewrite_column_with_type_cast() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("a", 0));
+
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be wrapped in a cast expression
+        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_missing_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("c", 2));
+
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with a literal null
+        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
+            assert_eq!(*literal.value(), ScalarValue::Float64(None));
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_partition_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let partition_fields =
+            vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))];
+        let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))];
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema)
+            .with_partition_columns(partition_fields, partition_values);
+
+        let column_expr = Arc::new(Column::new("partition_col", 0));
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with the partition value
+        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
+            assert_eq!(
+                *literal.value(),
+                ScalarValue::Utf8(Some("test_value".to_string()))
+            );
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_no_change_needed() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
+
+        let result = rewriter.rewrite(Arc::clone(&column_expr))?;
+
+        // Should be the same expression (no transformation needed)
+        // We compare the underlying pointer through the trait object
+        assert!(std::ptr::eq(
+            column_expr.as_ref() as *const dyn PhysicalExpr,
+            result.as_ref() as *const dyn PhysicalExpr
+        ));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_non_nullable_missing_column_error() {
+        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
+        ]);
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1));
+
+        let result = rewriter.rewrite(column_expr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Non-nullable column 'b' is missing"));
+    }
+}