diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index abac19bc19..bd9f30f2f0 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -917,6 +917,82 @@ def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_se return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive)) +class _ProjectedColumnsEvaluator(BooleanExpressionVisitor[BooleanExpression]): + """Evaluated predicates which involve projected columns missing from the file. + + Args: + file_schema (Schema): The schema of the file. + projected_schema (Schema): The schema to project onto the data files. + case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True. + projected_missing_fields(dict[str, Any]): Map of fields missing in file_schema, but present as partition values. + + Raises: + TypeError: In the case of an UnboundPredicate. + """ + + file_schema: Schema + case_sensitive: bool + + def __init__( + self, file_schema: Schema, projected_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any] + ) -> None: + self.file_schema = file_schema + self.projected_schema = projected_schema + self.case_sensitive = case_sensitive + self.projected_missing_fields = projected_missing_fields + + def visit_true(self) -> BooleanExpression: + return AlwaysTrue() + + def visit_false(self) -> BooleanExpression: + return AlwaysFalse() + + def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: + return Not(child=child_result) + + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return And(left=left_result, right=right_result) + + def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return Or(left=left_result, right=right_result) + + def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + raise TypeError(f"Expected Bound Predicate, got: {predicate.term}") + + def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression: + file_column_name = self.file_schema.find_column_name(predicate.term.ref().field.field_id) + + if file_column_name is None and (field_name := predicate.term.ref().field.name) in self.projected_missing_fields: + unbound_predicate: BooleanExpression + if isinstance(predicate, BoundUnaryPredicate): + unbound_predicate = predicate.as_unbound(field_name) + elif isinstance(predicate, BoundLiteralPredicate): + unbound_predicate = predicate.as_unbound(field_name, predicate.literal) + elif isinstance(predicate, BoundSetPredicate): + unbound_predicate = predicate.as_unbound(field_name, predicate.literals) + else: + raise ValueError(f"Unsupported predicate: {predicate}") + field = self.projected_schema.find_field(field_name) + schema = Schema(field) + evaluator = expression_evaluator(schema, unbound_predicate, self.case_sensitive) + if evaluator(Record(self.projected_missing_fields[field_name])): + return AlwaysTrue() + else: + return AlwaysFalse() + + return predicate + + +def evaluate_projected_columns( + expr: BooleanExpression, + file_schema: Schema, + projected_schema: Schema, + case_sensitive: bool, + projected_missing_fields: dict[str, Any], +) -> BooleanExpression: + return visit(expr, _ProjectedColumnsEvaluator(file_schema, projected_schema, case_sensitive, projected_missing_fields)) + + class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]): """Extracts the field IDs used in the BooleanExpression.""" diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 3e49885e58..17793b6330 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -78,6 +78,7 @@ from pyiceberg.expressions.visitors import ( BoundBooleanExpressionVisitor, bind, + evaluate_projected_columns, extract_field_ids, translate_column_names, ) @@ -1458,18 +1459,29 @@ def _task_to_record_batches( # the table format version. file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True) - pyarrow_filter = None - if bound_row_filter is not AlwaysTrue(): - translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive) - bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) - pyarrow_filter = expression_to_pyarrow(bound_file_filter) - # Apply column projection rules # https://iceberg.apache.org/spec/#column-projection should_project_columns, projected_missing_fields = _get_column_projection_values( task.file, projected_schema, partition_spec, file_schema.field_ids ) + pyarrow_filter = None + if bound_row_filter is not AlwaysTrue(): + evaluated_projected_columns_filter = evaluate_projected_columns( + bound_row_filter, + file_schema, + projected_schema, + case_sensitive=case_sensitive, + projected_missing_fields=projected_missing_fields, + ) + translated_row_filter = translate_column_names( + evaluated_projected_columns_filter, + file_schema, + case_sensitive=case_sensitive, + ) + bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) + pyarrow_filter = expression_to_pyarrow(bound_file_filter) + file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) fragment_scanner = ds.Scanner.from_fragment( diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 273bd24c9b..7d39dc96d7 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -68,6 +68,8 @@ BooleanExpressionVisitor, BoundBooleanExpressionVisitor, _ManifestEvalVisitor, + bind, + evaluate_projected_columns, expression_evaluator, expression_to_plain_format, rewrite_not, @@ -1623,3 +1625,37 @@ def test_expression_evaluator_null() -> None: assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True + + +@pytest.mark.parametrize( + "before_expression,after_expression", + [ + (In("id", {1, 2, 3}), AlwaysTrue()), + (EqualTo("id", 3), AlwaysFalse()), + ( + And(EqualTo("id", 1), EqualTo("all_same_value_or_null", "string")), + And(AlwaysTrue(), EqualTo("all_same_value_or_null", "string")), + ), + ( + And(EqualTo("all_same_value_or_null", "string"), GreaterThan("id", 2)), + And(EqualTo("all_same_value_or_null", "string"), AlwaysFalse()), + ), + ( + Or( + And(EqualTo("id", 1), EqualTo("all_same_value_or_null", "string")), + And(EqualTo("all_same_value_or_null", "string"), GreaterThan("id", 2)), + ), + Or( + And(AlwaysTrue(), EqualTo("all_same_value_or_null", "string")), + And(EqualTo("all_same_value_or_null", "string"), AlwaysFalse()), + ), + ), + ], +) +def test_eval_projected_fields(schema: Schema, before_expression: BooleanExpression, after_expression: BooleanExpression) -> None: + # exclude id from file_schema pretending that it's part of partition values + file_schema = Schema(*[field for field in schema.columns if field.name != "id"]) + projected_missing_fields = {"id": 1} + assert evaluate_projected_columns( + bind(schema, before_expression, True), file_schema, schema, True, projected_missing_fields + ) == bind(schema, after_expression, True)