@@ -28,6 +28,8 @@ use itertools::Itertools;
28
28
use serde:: { Deserialize , Serialize } ;
29
29
30
30
use crate :: error:: Result ;
31
+ use crate :: expr:: visitors:: predicate_visitor:: visit;
32
+ use crate :: expr:: visitors:: rewrite_not:: RewriteNotVisitor ;
31
33
use crate :: expr:: { Bind , BoundReference , PredicateOperator , Reference } ;
32
34
use crate :: spec:: { Datum , PrimitiveLiteral , SchemaRef } ;
33
35
use crate :: { Error , ErrorKind } ;
@@ -652,29 +654,8 @@ impl Predicate {
652
654
/// assert_eq!(&format!("{result}"), "a >= 5");
653
655
/// ```
654
656
pub fn rewrite_not ( self ) -> Predicate {
655
- match self {
656
- Predicate :: And ( expr) => {
657
- let [ left, right] = expr. inputs ;
658
- let new_left = Box :: new ( left. rewrite_not ( ) ) ;
659
- let new_right = Box :: new ( right. rewrite_not ( ) ) ;
660
- Predicate :: And ( LogicalExpression :: new ( [ new_left, new_right] ) )
661
- }
662
- Predicate :: Or ( expr) => {
663
- let [ left, right] = expr. inputs ;
664
- let new_left = Box :: new ( left. rewrite_not ( ) ) ;
665
- let new_right = Box :: new ( right. rewrite_not ( ) ) ;
666
- Predicate :: Or ( LogicalExpression :: new ( [ new_left, new_right] ) )
667
- }
668
- Predicate :: Not ( expr) => {
669
- let [ inner] = expr. inputs ;
670
- inner. negate ( )
671
- }
672
- Predicate :: Unary ( expr) => Predicate :: Unary ( expr) ,
673
- Predicate :: Binary ( expr) => Predicate :: Binary ( expr) ,
674
- Predicate :: Set ( expr) => Predicate :: Set ( expr) ,
675
- Predicate :: AlwaysTrue => Predicate :: AlwaysTrue ,
676
- Predicate :: AlwaysFalse => Predicate :: AlwaysFalse ,
677
- }
657
+ visit ( & mut RewriteNotVisitor :: new ( ) , & self )
658
+ . expect ( "RewriteNotVisitor guarantees always success" )
678
659
}
679
660
}
680
661
@@ -1466,4 +1447,29 @@ mod tests {
1466
1447
assert_eq ! ( & format!( "{bound_expr}" ) , r#"True"# ) ;
1467
1448
test_bound_predicate_serialize_diserialize ( bound_expr) ;
1468
1449
}
1450
+
1451
+ #[ test]
1452
+ fn test_rewrite_not_deeply_nested ( ) {
1453
+ // Test nested expression: not((not((not(ref(name="bar") < 40) and ref(name="bar") < 40)) and ref(name="bar") < 40))
1454
+ // Expected rewrite not result: ((bar >= 40) AND (bar < 40)) OR (bar >= 40)
1455
+ let complex_expression = Reference :: new ( "bar" )
1456
+ . less_than ( Datum :: int ( 40 ) )
1457
+ . not ( )
1458
+ . and ( Reference :: new ( "bar" ) . less_than ( Datum :: int ( 40 ) ) )
1459
+ . not ( )
1460
+ . and ( Reference :: new ( "bar" ) . less_than ( Datum :: int ( 40 ) ) )
1461
+ . not ( ) ;
1462
+
1463
+ let expected = Reference :: new ( "bar" )
1464
+ . greater_than_or_equal_to ( Datum :: int ( 40 ) )
1465
+ . and ( Reference :: new ( "bar" ) . less_than ( Datum :: int ( 40 ) ) )
1466
+ . or ( Reference :: new ( "bar" ) . greater_than_or_equal_to ( Datum :: int ( 40 ) ) ) ;
1467
+
1468
+ let result = complex_expression. rewrite_not ( ) ;
1469
+
1470
+ assert_eq ! ( result, expected) ;
1471
+
1472
+ let result_str = format ! ( "{result}" ) ;
1473
+ assert_eq ! ( & result_str, "((bar >= 40) AND (bar < 40)) OR (bar >= 40)" ) ;
1474
+ }
1469
1475
}
0 commit comments