57 changes: 13 additions & 44 deletions bigframes/core/indexes/base.py
@@ -27,16 +27,12 @@
import pandas

from bigframes import dtypes
-from bigframes.core.array_value import ArrayValue
import bigframes.core.block_transforms as block_ops
import bigframes.core.blocks as blocks
import bigframes.core.expression as ex
-import bigframes.core.identifiers as ids
-import bigframes.core.nodes as nodes
import bigframes.core.ordering as order
import bigframes.core.utils as utils
import bigframes.core.validations as validations
-import bigframes.core.window_spec as window_spec
import bigframes.dtypes
import bigframes.formatting_helpers as formatter
import bigframes.operations as ops
@@ -272,37 +268,20 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
        # Get the index column from the block
        index_column = self._block.index_columns[0]

-        # Apply row numbering to the original data
-        row_number_column_id = ids.ColumnId.unique()
-        window_node = nodes.WindowOpNode(
-            child=self._block._expr.node,
-            expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
-            window_spec=window_spec.unbound(),
-            output_name=row_number_column_id,
-            never_skip_nulls=True,
-        )
-
-        windowed_array = ArrayValue(window_node)
-        windowed_block = blocks.Block(
-            windowed_array,
-            index_columns=self._block.index_columns,
-            column_labels=self._block.column_labels.insert(
-                len(self._block.column_labels), None
-            ),
-            index_labels=self._block._index_labels,
+        # Use promote_offsets to get row numbers (similar to argmax/argmin implementation)
+        block_with_offsets, offsets_id = self._block.promote_offsets(
+            "temp_get_loc_offsets_"
        )

        # Create expression to find matching positions
        match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
-        windowed_block, match_col_id = windowed_block.project_expr(match_expr)
+        block_with_offsets, match_col_id = block_with_offsets.project_expr(match_expr)

        # Filter to only rows where the key matches
-        filtered_block = windowed_block.filter_by_id(match_col_id)
+        filtered_block = block_with_offsets.filter_by_id(match_col_id)

-        # Check if key exists at all by counting on the filtered block
-        count_agg = ex.UnaryAggregation(
-            agg_ops.count_op, ex.deref(row_number_column_id.name)
-        )
+        # Check if key exists at all by counting
+        count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id))
        count_result = filtered_block._expr.aggregate([(count_agg, "count")])
        count_scalar = self._block.session._executor.execute(
            count_result
@@ -313,9 +292,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:

        # If only one match, return integer position
        if count_scalar == 1:
-            min_agg = ex.UnaryAggregation(
-                agg_ops.min_op, ex.deref(row_number_column_id.name)
-            )
+            min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id))
            position_result = filtered_block._expr.aggregate([(min_agg, "position")])
            position_scalar = self._block.session._executor.execute(
                position_result
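In plain-pandas terms, this branch materializes row offsets for the original order, filters them down to the rows matching the key, and then aggregates. A rough, hypothetical sketch of the equivalent logic (the zero-match KeyError path presumably lives in the lines collapsed above):

import numpy as np
import pandas as pd

def single_match_position(index: pd.Index, key) -> int:
    # promote_offsets analogue: row numbers 0..n-1 in the original order.
    offsets = np.arange(len(index))
    # project_expr + filter_by_id analogue: keep offsets where index == key.
    matched = offsets[np.asarray(index == key)]
    if len(matched) == 0:
        raise KeyError(key)  # the count aggregation returned 0
    # min aggregation over the single matching row gives the position.
    return int(matched.min())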
@@ -325,32 +302,24 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
        # Handle multiple matches based on index monotonicity
        is_monotonic = self.is_monotonic_increasing or self.is_monotonic_decreasing
        if is_monotonic:
-            return self._get_monotonic_slice(filtered_block, row_number_column_id)
+            return self._get_monotonic_slice(filtered_block, offsets_id)
        else:
            # Return boolean mask for non-monotonic duplicates
-            mask_block = windowed_block.select_columns([match_col_id])
+            # Reset the index to use positional integers instead of original index values
+            mask_block = block_with_offsets.select_columns([match_col_id])
            mask_block = mask_block.reset_index(drop=True)
            # Ensure correct dtype and name to match pandas behavior
            result_series = bigframes.series.Series(mask_block)
            return result_series.astype("boolean")

-    def _get_monotonic_slice(
-        self, filtered_block, row_number_column_id: "ids.ColumnId"
-    ) -> slice:
+    def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice:
        """Helper method to get a slice for monotonic duplicates with an optimized query."""
        # Combine min and max aggregations into a single query for efficiency
        min_max_aggs = [
            (
-                ex.UnaryAggregation(
-                    agg_ops.min_op, ex.deref(row_number_column_id.name)
-                ),
+                ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)),
                "min_pos",
            ),
            (
-                ex.UnaryAggregation(
-                    agg_ops.max_op, ex.deref(row_number_column_id.name)
-                ),
+                ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)),
                "max_pos",
            ),
        ]
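Computing both endpoints in one aggregation keeps this to a single query, and the intended result is the half-open slice pandas users expect. A minimal sketch, assuming the key is present and the elided remainder of the helper returns slice(min_pos, max_pos + 1):

import numpy as np
import pandas as pd

def monotonic_slice(index: pd.Index, key) -> slice:
    # One pass yields both endpoints, mirroring the combined min/max aggregation.
    offsets = np.arange(len(index))
    matched = offsets[np.asarray(index == key)]
    # pandas slices are half-open: stop is exclusive.
    return slice(int(matched.min()), int(matched.max()) + 1)

# For example, pd.Index([10, 20, 20, 30]).get_loc(20) == slice(1, 3, None).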
@@ -767,7 +767,7 @@ def get_loc(
        1     True
        2    False
        3     True
-        Name: nan, dtype: boolean
+        dtype: boolean

        Args:
            key: Label to get the location for.
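The doctest fix matches stock pandas, where get_loc returns an unnamed NumPy boolean mask for non-monotonic duplicates rather than a named Series; for comparison:

import pandas as pd

idx = pd.Index(["b", "a", "c", "b"])
idx.get_loc("b")                          # array([ True, False, False,  True]), no name attached
pd.Index([10, 20, 20, 30]).get_loc(20)    # slice(1, 3, None) for monotonic duplicates
pd.Index([10, 20, 30]).get_loc(20)        # 1 for a unique key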