Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions asv_bench/benchmarks/cohorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

import flox
from flox.factorize import _factorize_multiple

from .helpers import codes_for_resampling

Expand Down Expand Up @@ -89,7 +90,7 @@ def setup(self, *args, **kwargs):
y = np.repeat(np.arange(30), 60)
by = x[np.newaxis, :] * y[:, np.newaxis]

self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0]
self.by = _factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0]

self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
self.axis = (-2, -1)
Expand Down Expand Up @@ -149,7 +150,7 @@ class ERA5MonthHour(ERA5Dataset, Cohorts):
def setup(self, *args, **kwargs):
super().__init__()
by = (self.time.dt.month.values, self.time.dt.hour.values)
ret = flox.core._factorize_multiple(
ret = _factorize_multiple(
by,
(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
any_by_dask=False,
Expand Down
2 changes: 1 addition & 1 deletion flox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from .aggregations import Aggregation, Scan # noqa
from .core import (
groupby_reduce,
groupby_scan,
rechunk_for_blockwise,
rechunk_for_cohorts,
ReindexStrategy,
ReindexArrayType,
) # noqa
from .scan import groupby_scan # noqa


def _get_version():
Expand Down
17 changes: 17 additions & 0 deletions flox/aggregate_numbagg.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,23 @@ def nanlen(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None)
# Boolean reductions: `_numbagg_wrapper` pre-bound to the "nanany" / "nanall"
# function names.
# NOTE(review): these module-level assignments shadow the `any` / `all`
# builtins for the rest of this module — presumably intentional so the
# reductions can be looked up by name; confirm before using the builtins below.
any = partial(_numbagg_wrapper, func="nanany")
all = partial(_numbagg_wrapper, func="nanall")


def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
"""Account for numbagg not providing a fill_value kwarg."""
if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
return result
# The condition needs to be
# len(found_groups) < size; if so we mask with fill_value (?)
default_fv = DEFAULT_FILL_VALUE[func]
needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
groups = np.arange(size)
if needs_masking:
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
if mask.any():
result[..., groups[mask]] = fill_value
return result


# sum = nansum
# mean = nanmean
# sum_of_squares = nansum_of_squares
Loading
Loading