Skip to content

Commit 01e9cf9

Browse files
Merge pull request #79 from IntelPython/hotfix/scipy-fft-rfft
Hotfix/scipy fft rfft
2 parents 69bfdbb + baff159 commit 01e9cf9

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

conda-recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% set version = "1.3.4" %}
1+
{% set version = "1.3.5" %}
22
{% set buildnumber = 0 %}
33

44
package:

mkl_fft/_scipy_fft_backend.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _workers_to_num_threads(w):
151151
return _workers_global_settings.get().workers
152152
_w = operator.index(w)
153153
if (_w == 0):
154-
raise ValueError("Number of workers must be nonzero")
154+
raise ValueError("Number of workers must not be zero")
155155
if (_w < 0):
156156
ub = _cpu_max_threads_count().get_cpu_count()
157157
_w += ub + 1
@@ -338,21 +338,22 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None, plan=None):
338338
return NotImplemented
339339
if x is NotImplemented:
340340
return x
341-
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
341+
nn = n if n else 2*(x.shape[axis]-1)
342+
fsc = _compute_1d_forward_scale(norm, nn, x.shape[axis])
342343
_check_plan(plan)
343344
with Workers(workers):
344345
output = _pydfti.irfft_numpy(x, n=n, axis=axis, forward_scale=fsc)
345346
return output
346347

347348

348-
def _compute_nd_forward_scale_for_rfft(norm, s, axes, x):
349+
def _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=False):
349350
if norm in (None, "backward"):
350351
fsc = 1.0
351352
elif norm == "forward":
352-
s, axes = _cook_nd_args(x, s, axes)
353+
s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
353354
fsc = _frwd_sc_nd(s, axes, x.shape)
354355
elif norm == "ortho":
355-
s, axes = _cook_nd_args(x, s, axes)
356+
s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
356357
fsc = sqrt(_frwd_sc_nd(s, axes, x.shape))
357358
else:
358359
_check_norm(norm)
@@ -380,7 +381,7 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None):
380381
return NotImplemented
381382
if x is NotImplemented:
382383
return x
383-
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
384+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=True)
384385
_check_plan(plan)
385386
with Workers(workers):
386387
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)
@@ -408,7 +409,7 @@ def irfftn(a, s=None, axes=None, norm=None, workers=None, plan=None):
408409
return NotImplemented
409410
if x is NotImplemented:
410411
return x
411-
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
412+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=True)
412413
_check_plan(plan)
413414
with Workers(workers):
414415
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)

mkl_fft/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.3.4'
1+
__version__ = '1.3.5'

mkl_fft/tests/test_interfaces.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def test_scipy_rfft(norm, dtype):
6565
xx = mfi.scipy_fft.irfft(w, n=x.shape[0], norm=norm, workers=None, plan=None)
6666
tol = 64 * np.finfo(np.dtype(dtype)).eps
6767
assert np.allclose(x, xx, atol=tol, rtol=tol)
68+
69+
x = np.ones(510, dtype=dtype)
70+
w = mfi.scipy_fft.rfft(x, norm=norm, workers=None, plan=None)
71+
xx = mfi.scipy_fft.irfft(w, norm=norm, workers=None, plan=None)
72+
tol = 64 * np.finfo(np.dtype(dtype)).eps
73+
assert np.allclose(x, xx, atol=tol, rtol=tol)
6874

6975

7076
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
@@ -99,20 +105,26 @@ def test_numpy_fftn(norm, dtype):
99105

100106
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
101107
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
102-
def test_scipy_rftn(norm, dtype):
108+
def test_scipy_rfftn(norm, dtype):
103109
x = np.ones((37, 83), dtype=dtype)
104110
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
105-
xx = mfi.scipy_fft.ifftn(w, s=x.shape, norm=norm, workers=None, plan=None)
111+
xx = mfi.scipy_fft.irfftn(w, s=x.shape, norm=norm, workers=None, plan=None)
112+
tol = 64 * np.finfo(np.dtype(dtype)).eps
113+
assert np.allclose(x, xx, atol=tol, rtol=tol)
114+
115+
x = np.ones((36, 82), dtype=dtype)
116+
w = mfi.scipy_fft.rfftn(x, norm=norm, workers=None, plan=None)
117+
xx = mfi.scipy_fft.irfftn(w, norm=norm, workers=None, plan=None)
106118
tol = 64 * np.finfo(np.dtype(dtype)).eps
107119
assert np.allclose(x, xx, atol=tol, rtol=tol)
108120

109121

110122
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
111123
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
112-
def test_numpy_rftn(norm, dtype):
124+
def test_numpy_rfftn(norm, dtype):
113125
x = np.ones((37, 83), dtype=dtype)
114126
w = mfi.numpy_fft.rfftn(x, norm=norm)
115-
xx = mfi.numpy_fft.ifftn(w, s=x.shape, norm=norm)
127+
xx = mfi.numpy_fft.irfftn(w, s=x.shape, norm=norm)
116128
tol = 64 * np.finfo(np.dtype(dtype)).eps
117129
assert np.allclose(x, xx, atol=tol, rtol=tol)
118130

0 commit comments

Comments
 (0)