Commit efe322b

HumphreyYang and mmcky authored
FIX: Update Numba Lecture to Address Deprecation of @jit (#296)
* update a section on type inference.
* update lecture to avoid literal box warning
* check the type of the function
* Update lectures/numba.md

  Co-authored-by: mmcky <[email protected]>
* reduce redundancy
* further simplifies descriptions
* fix typos

---------

Co-authored-by: mmcky <[email protected]>
1 parent 995c490 commit efe322b

1 file changed: +118 -60 lines changed


lectures/numba.md

Lines changed: 118 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,10 @@ jupytext:
33
text_representation:
44
extension: .md
55
format_name: myst
6+
format_version: 0.13
7+
jupytext_version: 1.14.4
68
kernelspec:
7-
display_name: Python 3
9+
display_name: Python 3 (ipykernel)
810
language: python
911
name: python3
1012
---
@@ -26,10 +28,9 @@ kernelspec:
2628

2729
In addition to what's in Anaconda, this lecture will need the following libraries:
2830

29-
```{code-cell} ipython
30-
---
31-
tags: [hide-output]
32-
---
31+
```{code-cell} ipython3
32+
:tags: [hide-output]
33+
3334
!pip install quantecon
3435
```
3536

@@ -38,7 +39,7 @@ versions are a {doc}`common source of errors <troubleshooting>`.
3839

3940
Let's start with some imports:
4041

41-
```{code-cell} ipython
42+
```{code-cell} ipython3
4243
%matplotlib inline
4344
import numpy as np
4445
import quantecon as qe
@@ -98,13 +99,13 @@ $$
9899

99100
In what follows we set
100101

101-
```{code-cell} python3
102+
```{code-cell} ipython3
102103
α = 4.0
103104
```
104105

105106
Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis
106107

107-
```{code-cell} python3
108+
```{code-cell} ipython3
108109
def qm(x0, n):
109110
x = np.empty(n+1)
110111
x[0] = x0
@@ -122,10 +123,10 @@ plt.show()
122123

123124
To speed the function `qm` up using Numba, our first step is
124125

125-
```{code-cell} python3
126-
from numba import jit
126+
```{code-cell} ipython3
127+
from numba import njit
127128
128-
qm_numba = jit(qm)
129+
qm_numba = njit(qm)
129130
```
130131

131132
The function `qm_numba` is a version of `qm` that is "targeted" for
@@ -135,7 +136,7 @@ We will explain what this means momentarily.
135136

136137
Let's time and compare identical function calls across these two versions, starting with the original function `qm`:
137138

138-
```{code-cell} python3
139+
```{code-cell} ipython3
139140
n = 10_000_000
140141
141142
qe.tic()
@@ -145,7 +146,7 @@ time1 = qe.toc()
145146

146147
Now let's try qm_numba
147148

148-
```{code-cell} python3
149+
```{code-cell} ipython3
149150
qe.tic()
150151
qm_numba(0.1, int(n))
151152
time2 = qe.toc()
@@ -156,13 +157,14 @@ This is already a massive speed gain.
156157
In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:
157158

158159
(qm_numba_result)=
159-
```{code-cell} python3
160+
161+
```{code-cell} ipython3
160162
qe.tic()
161163
qm_numba(0.1, int(n))
162164
time3 = qe.toc()
163165
```
164166

165-
```{code-cell} python3
167+
```{code-cell} ipython3
166168
time1 / time3 # Calculate speed gain
167169
```
168170

@@ -194,12 +196,12 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2
194196

195197
The compiled code is then cached and recycled as required.
196198

197-
## Decorators and "nopython" Mode
199+
## Decorator Notation
198200

199201
In the code above we created a JIT compiled version of `qm` via the call
200202

201-
```{code-cell} python3
202-
qm_numba = jit(qm)
203+
```{code-cell} ipython3
204+
qm_numba = njit(qm)
203205
```
204206

205207
In practice this would typically be done using an alternative *decorator* syntax.
@@ -208,14 +210,12 @@ In practice this would typically be done using an alternative *decorator* syntax
208210

209211
Let's see how this is done.
210212

211-
### Decorator Notation
212-
213-
To target a function for JIT compilation we can put `@jit` before the function definition.
213+
To target a function for JIT compilation we can put `@njit` before the function definition.
214214

215215
Here's what this looks like for `qm`
216216

217-
```{code-cell} python3
218-
@jit
217+
```{code-cell} ipython3
218+
@njit
219219
def qm(x0, n):
220220
x = np.empty(n+1)
221221
x[0] = x0
@@ -224,15 +224,21 @@ def qm(x0, n):
224224
return x
225225
```
226226

227-
This is equivalent to `qm = jit(qm)`.
227+
This is equivalent to `qm = njit(qm)`.
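
As a plain-Python illustration of that equivalence, the sketch below compares the two spellings using a hypothetical stand-in decorator, `mark_compiled`, in place of `njit`, so that it runs without Numba. Only the syntax is being compared; nothing is compiled, and the function names are assumptions made for the example.

```python
def mark_compiled(func):
    """Hypothetical stand-in for njit: tags the function, compiles nothing."""
    func.is_marked = True
    return func


# Decorator syntax
@mark_compiled
def qm_decorated(x0, n):
    x = [x0]
    for t in range(n):
        x.append(4 * x[t] * (1 - x[t]))
    return x


# Explicit call syntax applied to an identical function body
def qm_plain(x0, n):
    x = [x0]
    for t in range(n):
        x.append(4 * x[t] * (1 - x[t]))
    return x


qm_called = mark_compiled(qm_plain)

# Both routes produce a wrapped function with the same behavior
print(qm_decorated(0.1, 5) == qm_called(0.1, 5))
```

With the real `njit`, the same pattern applies: the decorator line is shorthand for rebinding the name to the wrapped function.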
228228

229229
The following now uses the jitted version:
230230

231-
```{code-cell} python3
232-
qm(0.1, 10)
231+
```{code-cell} ipython3
232+
%%time
233+
234+
qm(0.1, 100_000)
233235
```
234236

235-
### Type Inference and "nopython" Mode
237+
Numba provides several arguments for decorators to accelerate computation and cache functions [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).
238+
239+
In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.
240+
241+
## Type Inference
236242

237243
Clearly type inference is a key part of JIT compilation.
238244

@@ -246,29 +252,83 @@ This allows it to generate native machine code, without having to call the Pytho
246252

247253
In such a setting, Numba will be on par with machine code from low-level languages.
248254

249-
When Numba cannot infer all type information, some Python objects are given generic object status and execution falls back to the Python runtime.
255+
When Numba cannot infer all type information, it will raise an error.
250256

251-
When this happens, Numba provides only minor speed gains or none at all.
257+
For example, in the case below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap`
252258

253-
We generally prefer to force an error when this occurs, so we know effective
254-
compilation is failing.
259+
```{code-cell} ipython3
260+
@njit
261+
def bootstrap(data, statistics, n):
262+
bootstrap_stat = np.empty(n)
263+
n = len(data)
264+
for i in range(n_resamples):
265+
resample = np.random.choice(data, size=n, replace=True)
266+
bootstrap_stat[i] = statistics(resample)
267+
return bootstrap_stat
255268
256-
This is done by using either `@jit(nopython=True)` or, equivalently, `@njit` instead of `@jit`.
269+
def mean(data):
270+
return np.mean(data)
257271
258-
For example,
272+
data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
273+
n_resamples = 10
259274
260-
```{code-cell} python3
261-
from numba import njit
275+
print('Type of function:', type(mean))
276+
277+
#Error
278+
try:
279+
bootstrap(data, mean, n_resamples)
280+
except Exception as e:
281+
print(e)
282+
```
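
In the cell above, the argument `n` is reassigned to `len(data)` and the loop bound is read from the global `n_resamples`, so the third argument only sets the length of the output array. The sketch below shows one way to keep the two quantities in separate names. It is plain Python using the standard library (`random.choices`, `statistics.mean`) rather than Numba, and the names `bootstrap_py` and `n_obs` are illustrative assumptions, not part of the commit.

```python
import random
import statistics


def bootstrap_py(data, statistic, n_resamples):
    """Return one value of `statistic` per resample drawn with replacement."""
    n_obs = len(data)            # size of each resample
    bootstrap_stat = []
    for _ in range(n_resamples):  # number of resamples comes from the argument
        resample = random.choices(data, k=n_obs)
        bootstrap_stat.append(statistic(resample))
    return bootstrap_stat


data = [2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2]
random.seed(0)
stats = bootstrap_py(data, statistics.mean, 10)
print(len(stats))  # one statistic per resample
```

The same separation of names should carry over to the jitted version, where it would also remove the dependence on a global variable.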
262283

284+
But Numba recognizes JIT-compiled functions
285+
286+
```{code-cell} ipython3
263287
@njit
264-
def qm(x0, n):
265-
x = np.empty(n+1)
266-
x[0] = x0
267-
for t in range(n):
268-
x[t+1] = 4 * x[t] * (1 - x[t])
269-
return x
288+
def mean(data):
289+
return np.mean(data)
290+
291+
print('Type of function:', type(mean))
292+
293+
%time bootstrap(data, mean, n_resamples)
294+
```
295+
296+
We can check the signature of the JIT-compiled function
297+
298+
```{code-cell} ipython3
299+
bootstrap.signatures
300+
```
301+
302+
The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.
303+
304+
Now let's see what happens when we change the inputs.
305+
306+
Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.
307+
308+
```{code-cell} ipython3
309+
data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
310+
%time bootstrap(data, mean, 100)
311+
bootstrap.signatures
270312
```
271313

314+
As expected, the second run is much faster.
315+
316+
Let's try to change the data again and use an integer array as data
317+
318+
```{code-cell} ipython3
319+
data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
320+
%time bootstrap(data, mean, 100)
321+
bootstrap.signatures
322+
```
323+
324+
Note that a second signature is added.
325+
326+
It also takes longer to run, suggesting that Numba recompiles this function as the type changes.
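
The behavior described here, one compiled version per combination of argument types, can be mimicked with a toy model. The sketch below is plain Python and is not how Numba is implemented; `ToyDispatcher` and `compile_count` are hypothetical names, and the "compilation" step is only a counter.

```python
class ToyDispatcher:
    """Toy model of per-signature caching; not Numba's implementation."""

    def __init__(self, func):
        self.func = func
        self.cache = {}         # tuple of argument types -> "compiled" callable
        self.compile_count = 0  # how many times we pretended to compile

    def __call__(self, *args):
        signature = tuple(type(a) for a in args)
        if signature not in self.cache:
            # Stand-in for the expensive compilation step
            self.compile_count += 1
            self.cache[signature] = self.func
        return self.cache[signature](*args)

    @property
    def signatures(self):
        return list(self.cache)


@ToyDispatcher
def double(x):
    return 2 * x


double(1.5)  # first float call: "compiles"
double(2.5)  # same signature: cache hit
double(3)    # int argument: new signature, "compiles" again
print(double.signatures)
print(double.compile_count)
```

In this model a second entry appears only when the argument types change, which matches the pattern seen in `bootstrap.signatures`.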
327+
328+
Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.
329+
330+
You can refer to the list of supported Python and Numpy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).
331+
272332
## Compiling Classes
273333

274334
As mentioned above, at present Numba can only compile a subset of Python.
@@ -285,7 +345,7 @@ created in {doc}`this lecture <python_oop>`.
285345

286346
To compile this class we use the `@jitclass` decorator:
287347

288-
```{code-cell} python3
348+
```{code-cell} ipython3
289349
from numba import float64
290350
from numba.experimental import jitclass
291351
```
@@ -294,11 +354,11 @@ Notice that we also imported something called `float64`.
294354

295355
This is a data type representing standard floating point numbers.
296356

297-
We are importing it here because Numba needs a bit of extra help with types when it trys to deal with classes.
357+
We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes.
298358

299359
Here's our code:
300360

301-
```{code-cell} python3
361+
```{code-cell} ipython3
302362
solow_data = [
303363
('n', float64),
304364
('s', float64),
@@ -361,7 +421,7 @@ After that, targeting the class for JIT compilation only requires adding
361421

362422
When we call the methods in the class, the methods are compiled just like functions.
363423

364-
```{code-cell} python3
424+
```{code-cell} ipython3
365425
s1 = Solow()
366426
s2 = Solow(k=8.0)
367427
@@ -444,25 +504,25 @@ For larger ones, or for routines using external libraries, it can easily fail.
444504

445505
Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code.
446506

447-
This will give you much better performance than blanketing your Python programs with `@jit` statements.
507+
This will give you much better performance than blanketing your Python programs with `@njit` statements.
448508

449509
### A Gotcha: Global Variables
450510

451511
Here's another thing to be careful about when using Numba.
452512

453513
Consider the following example
454514

455-
```{code-cell} python3
515+
```{code-cell} ipython3
456516
a = 1
457517
458-
@jit
518+
@njit
459519
def add_a(x):
460520
return a + x
461521
462522
print(add_a(10))
463523
```
464524

465-
```{code-cell} python3
525+
```{code-cell} ipython3
466526
a = 2
467527
468528
print(add_a(10))
@@ -492,7 +552,7 @@ Compare speed with and without Numba when the sample size is large.
492552

493553
Here is one solution:
494554

495-
```{code-cell} python3
555+
```{code-cell} ipython3
496556
from random import uniform
497557
498558
@njit
@@ -581,13 +641,13 @@ We let
581641
- 0 represent "low"
582642
- 1 represent "high"
583643

584-
```{code-cell} python3
644+
```{code-cell} ipython3
585645
p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
586646
```
587647

588648
Here's a pure Python version of the function
589649

590-
```{code-cell} python3
650+
```{code-cell} ipython3
591651
def compute_series(n):
592652
x = np.empty(n, dtype=np.int_)
593653
x[0] = 1 # Start in state 1
@@ -604,7 +664,7 @@ def compute_series(n):
604664
Let's run this code and check that the fraction of time spent in the low
605665
state is about 0.666
606666

607-
```{code-cell} python3
667+
```{code-cell} ipython3
608668
n = 1_000_000
609669
x = compute_series(n)
610670
print(np.mean(x == 0)) # Fraction of time x is in state 0
@@ -614,30 +674,28 @@ This is (approximately) the right output.
614674

615675
Now let's time it:
616676

617-
```{code-cell} python3
677+
```{code-cell} ipython3
618678
qe.tic()
619679
compute_series(n)
620680
qe.toc()
621681
```
622682

623683
Next let's implement a Numba version, which is easy
624684

625-
```{code-cell} python3
626-
from numba import jit
627-
628-
compute_series_numba = jit(compute_series)
685+
```{code-cell} ipython3
686+
compute_series_numba = njit(compute_series)
629687
```
630688

631689
Let's check we still get the right numbers
632690

633-
```{code-cell} python3
691+
```{code-cell} ipython3
634692
x = compute_series_numba(n)
635693
print(np.mean(x == 0))
636694
```
637695

638696
Let's see the time
639697

640-
```{code-cell} python3
698+
```{code-cell} ipython3
641699
qe.tic()
642700
compute_series_numba(n)
643701
qe.toc()
