diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py index 2bb58da330..9ad201c73d 100644 --- a/bigframes/core/indexes/base.py +++ b/bigframes/core/indexes/base.py @@ -27,16 +27,12 @@ import pandas from bigframes import dtypes -from bigframes.core.array_value import ArrayValue import bigframes.core.block_transforms as block_ops import bigframes.core.blocks as blocks import bigframes.core.expression as ex -import bigframes.core.identifiers as ids -import bigframes.core.nodes as nodes import bigframes.core.ordering as order import bigframes.core.utils as utils import bigframes.core.validations as validations -import bigframes.core.window_spec as window_spec import bigframes.dtypes import bigframes.formatting_helpers as formatter import bigframes.operations as ops @@ -272,37 +268,20 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: # Get the index column from the block index_column = self._block.index_columns[0] - # Apply row numbering to the original data - row_number_column_id = ids.ColumnId.unique() - window_node = nodes.WindowOpNode( - child=self._block._expr.node, - expression=ex.NullaryAggregation(agg_ops.RowNumberOp()), - window_spec=window_spec.unbound(), - output_name=row_number_column_id, - never_skip_nulls=True, - ) - - windowed_array = ArrayValue(window_node) - windowed_block = blocks.Block( - windowed_array, - index_columns=self._block.index_columns, - column_labels=self._block.column_labels.insert( - len(self._block.column_labels), None - ), - index_labels=self._block._index_labels, + # Use promote_offsets to get row numbers (similar to argmax/argmin implementation) + block_with_offsets, offsets_id = self._block.promote_offsets( + "temp_get_loc_offsets_" ) # Create expression to find matching positions match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key)) - windowed_block, match_col_id = windowed_block.project_expr(match_expr) + block_with_offsets, match_col_id = block_with_offsets.project_expr(match_expr) # Filter to only rows where the key matches - filtered_block = windowed_block.filter_by_id(match_col_id) + filtered_block = block_with_offsets.filter_by_id(match_col_id) - # Check if key exists at all by counting on the filtered block - count_agg = ex.UnaryAggregation( - agg_ops.count_op, ex.deref(row_number_column_id.name) - ) + # Check if key exists at all by counting + count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id)) count_result = filtered_block._expr.aggregate([(count_agg, "count")]) count_scalar = self._block.session._executor.execute( count_result @@ -313,9 +292,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: # If only one match, return integer position if count_scalar == 1: - min_agg = ex.UnaryAggregation( - agg_ops.min_op, ex.deref(row_number_column_id.name) - ) + min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)) position_result = filtered_block._expr.aggregate([(min_agg, "position")]) position_scalar = self._block.session._executor.execute( position_result @@ -325,32 +302,24 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: # Handle multiple matches based on index monotonicity is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing if is_monotonic: - return self._get_monotonic_slice(filtered_block, row_number_column_id) + return self._get_monotonic_slice(filtered_block, offsets_id) else: # Return boolean mask for non-monotonic duplicates - mask_block = windowed_block.select_columns([match_col_id]) - # Reset the index to use positional integers instead of original index values + mask_block = block_with_offsets.select_columns([match_col_id]) mask_block = mask_block.reset_index(drop=True) - # Ensure correct dtype and name to match pandas behavior result_series = bigframes.series.Series(mask_block) return result_series.astype("boolean") - def _get_monotonic_slice( - self, filtered_block, row_number_column_id: "ids.ColumnId" - ) -> slice: + def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice: """Helper method to get a slice for monotonic duplicates with an optimized query.""" # Combine min and max aggregations into a single query for efficiency min_max_aggs = [ ( - ex.UnaryAggregation( - agg_ops.min_op, ex.deref(row_number_column_id.name) - ), + ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)), "min_pos", ), ( - ex.UnaryAggregation( - agg_ops.max_op, ex.deref(row_number_column_id.name) - ), + ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)), "max_pos", ), ] diff --git a/third_party/bigframes_vendored/pandas/core/indexes/base.py b/third_party/bigframes_vendored/pandas/core/indexes/base.py index 035eba74fd..eba47fc1f9 100644 --- a/third_party/bigframes_vendored/pandas/core/indexes/base.py +++ b/third_party/bigframes_vendored/pandas/core/indexes/base.py @@ -767,7 +767,7 @@ def get_loc( 1 True 2 False 3 True - Name: nan, dtype: boolean + dtype: boolean Args: key: Label to get the location for.