diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py index 3cb80d322a17f..4a60d6f5798df 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py @@ -386,6 +386,9 @@ def _eq(self, other): isinstance(other, type(self)) and _shape(self) == _shape(other) and np.array_equal(_values_default(self), _values_default(other)) + and self.kind == other.kind + and np.array_equal(_get_sum_of_weights_squared(self), _get_sum_of_weights_squared(other)) + and all(a == b for a, b in zip(self.axes, other.axes)) ) @@ -559,11 +562,12 @@ def _get_sum_of_weights_squared(self) -> np.typing.NDArray[Any]: # noqa: F821 import numpy as np shape = _shape(self, include_flow_bins=False) - return np.frombuffer( + sumw2_arr = np.frombuffer( self.GetSumw2().GetArray(), dtype=self.GetSumw2().GetArray().typecode, count=self.GetSumw2().GetSize(), - ).reshape(shape, order="F")[tuple([slice(1, -1)] * len(shape))] + ) + return sumw2_arr[tuple([slice(1, -1)] * len(shape))].reshape(shape, order="F") if sumw2_arr.size > 0 else sumw2_arr values_func_dict: dict[str, Callable] = { diff --git a/bindings/pyroot/pythonizations/test/uhi_indexing.py b/bindings/pyroot/pythonizations/test/uhi_indexing.py index 2a98d3c7f5c77..a4ae44b92e6df 100644 --- a/bindings/pyroot/pythonizations/test/uhi_indexing.py +++ b/bindings/pyroot/pythonizations/test/uhi_indexing.py @@ -315,6 +315,36 @@ def test_statistics_slice(self, hist_setup): assert hist_setup.GetStdDev() == pytest.approx(sliced_hist.GetStdDev(), rel=10e-5) assert hist_setup.GetMean() == pytest.approx(sliced_hist.GetMean(), rel=10e-5) + def test_equality_operator(self, hist_setup): + if _special_setting(hist_setup) or isinstance(hist_setup, (ROOT.TH1C, ROOT.TH2C, ROOT.TH3C)): + pytest.skip("This feature cannot be tested here") + + assert hist_setup == hist_setup[...] + + cloned_histograms = { + "content": hist_setup.Clone(), + "error": hist_setup.Clone(), + "weight": hist_setup.Clone(), + "axis": hist_setup.Clone(), + } + assert all(hist_setup == other for other in cloned_histograms.values()) + + # Change the content of a bin + cloned_histograms["content"].SetBinContent(1, 100) + assert hist_setup != cloned_histograms["content"] + + # Change the error of a bin + cloned_histograms["error"].SetBinError(1, 10) + assert hist_setup != cloned_histograms["error"] + + # Change the weight of a bin + cloned_histograms["weight"].AddBinContent(1, 100) + assert hist_setup != cloned_histograms["weight"] + + # Change the x axis + cloned_histograms["axis"].GetXaxis().Set(10, 1, 5) + assert hist_setup != cloned_histograms["axis"] + if __name__ == "__main__": pytest.main(args=[__file__])