
Commit 93b329a

jstac and mmcky authored
misc (#49)
Co-authored-by: mmcky <[email protected]>
1 parent 72925c4 commit 93b329a

File tree

3 files changed (+99, -79 lines changed)


lectures/about.md

Lines changed: 56 additions & 5 deletions
@@ -1,13 +1,58 @@
 
-# About these Lectures
+# About
 
-## About
+This lecture series provides an introduction to quantitative economics using [Google JAX](https://github.com/google/jax).
 
-This lecture series introduces quantitative economics using Google JAX.
 
-We assume that readers have covered most of the QuantEcon lecture
-series [on Python programming](https://python-programming.quantecon.org/intro.html).
+## What is JAX?
+
+JAX is an open source Python library developed by Google Research to support
+in-house artificial intelligence and machine learning.
+
+JAX provides data types, functions and a compiler for fast linear
+algebra operations and automatic differentiation.
+
+Loosely speaking, JAX is like [NumPy](https://numpy.org/) with the addition of
+
+* automatic differentiation
+* automated GPU/TPU support
+* a just-in-time compiler
+
+One of the great benefits of JAX is that exactly the same code can be run either
+on the CPU or on a hardware accelerator, such as a GPU or TPU.
+
+In short, JAX delivers
+
+1. high execution speeds on CPUs due to efficient parallelization and JIT
+compilation,
+1. a powerful and convenient environment for GPU programming, and
+1. the ability to efficiently differentiate smooth functions for optimization
+and estimation.
+
+These features make JAX ideal for almost all quantitative economic modeling
+problems that require heavy-duty computing.
+
+## How to run these lectures
+
+The easiest way to run these lectures is via [Google Colab](https://colab.research.google.com/).
 
+JAX is pre-installed with GPU support on Colab and Colab provides GPU access
+even on the free tier.
+
+Each lecture has a "play" button on the top right that you can use to launch the
+lecture on Colab.
+
+You might also like to try using JAX locally.
+
+If you do not own a GPU, you can still install JAX for the CPU by following the relevant [install instructions](https://github.com/google/jax).
+
+(We recommend that you install [Anaconda
+Python](https://www.anaconda.com/download) first.)
+
+If you do have a GPU, you can try installing JAX for the GPU by following the
+install instructions for GPUs.
+
+(This is not always trivial but is starting to get easier.)
 
 ## Credits
 
@@ -23,3 +68,9 @@ In particular, we thank and credit
 - [Hengcheng Zhang](https://github.com/HengchengZhang)
 - [Frank Wu](https://github.com/chappiewuzefan)
 
+
+## Prerequisites
+
+We assume that readers have covered most of the QuantEcon lecture
+series [on Python programming](https://python-programming.quantecon.org/intro.html).
+
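
The features the new `about.md` text lists (NumPy-style arrays, automatic differentiation, GPU/TPU execution and JIT compilation) can be sketched in a few lines of JAX. This snippet is an illustrative aside, not part of the commit; the function `f` and the input values are invented for the example.

```python
import jax
import jax.numpy as jnp

# A smooth scalar-valued function written with jax.numpy
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.linspace(0.0, 1.0, 5)

print(jax.devices())   # available backends (CPU, GPU or TPU)
print(jax.grad(f)(x))  # automatic differentiation: gradient of f at x
print(jax.jit(f)(x))   # just-in-time compiled evaluation of f
```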

lectures/jax_intro.md

Lines changed: 3 additions & 54 deletions
@@ -11,7 +11,7 @@ kernelspec:
 name: python3
 ---
 
-# JAX
+# An Introduction to JAX
 
 
 ```{admonition} GPU
@@ -26,63 +26,12 @@ Alternatively, if you have your own GPU, you can follow the [instructions](https
 
 This lecture provides a short introduction to [Google JAX](https://github.com/google/jax).
 
-## Overview
-
-Let's start with an overview of JAX.
-
-### Capabilities
-
-[JAX](https://github.com/google/jax) is a Python library initially developed by
-Google to support in-house artificial intelligence and machine learning.
-
-
-JAX provides data types, functions and a compiler for fast linear
-algebra operations and automatic differentiation.
-
-Loosely speaking, JAX is like NumPy with the addition of
-
-* automatic differentiation
-* automated GPU/TPU support
-* a just-in-time compiler
-
-One of the great benefits of JAX is that the same code can be run either on
-the CPU or on a hardware accelerator, such as a GPU or TPU.
-
-For example, JAX automatically builds and deploys kernels on the GPU whenever
-an accessible device is detected.
-
-### History
-
-In 2015, Google open-sourced part of its AI infrastructure called TensorFlow.
-
-Around two years later, Facebook open-sourced PyTorch beta, an alternative AI
-framework which is regarded as developer-friendly and more Pythonic than
-TensorFlow.
-
-By 2019, PyTorch was surging in popularity, adopted by Uber, Airbnb, Tesla and
-many other companies.
-
-In 2020, Google launched JAX as an open-source framework, simultaneously
-beginning to shift away from TPUs to Nvidia GPUs.
-
-In the last few years, uptake of Google JAX has accelerated rapidly, bringing
-attention back to Google-based machine learning architectures.
-
-
-### Installation
-
-JAX can be installed with or without GPU support by following [the install guide](https://github.com/google/jax).
-
-Note that JAX is pre-installed with GPU support on [Google Colab](https://colab.research.google.com/).
-
-If you do not have your own GPU, we recommend that you run this lecture on Colab.
-
-+++
 
 ## JAX as a NumPy Replacement
 
 
-One way to use JAX is as a plug-in NumPy replacement. Let's look at the similarities and differences.
+One way to use JAX is as a plug-in NumPy replacement. Let's look at the
+similarities and differences.
 
 ### Similarities
 
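
A small illustrative aside (not part of the commit) on the "plug-in NumPy replacement" idea: most NumPy-style operations carry over directly, while immutability of JAX arrays is one well-known difference.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])

# Familiar NumPy-style operations carry over directly
print(a.mean(), jnp.dot(a, a), jnp.sqrt(a))

# One difference: JAX arrays are immutable, so in-place assignment
# is replaced by a functional update that returns a new array
b = a.at[0].set(10.0)
print(a, b)
```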
lectures/newtons_method.md

Lines changed: 40 additions & 20 deletions
@@ -26,9 +26,13 @@ Alternatively, if you have your own GPU, you can follow the [instructions](https
 
 ## Overview
 
-Continuing from the [Newton's Method lecture](https://python.quantecon.org/newton_method.html), we are going to solve the multidimensional problem with `JAX`.
+In this lecture we highlight some of the capabilities of JAX, including JIT
+compilation and automatic differentiation.
 
-More information about JAX can be found [here](https://python-programming.quantecon.org/jax_intro.html).
+The application is computing equilibria via Newton's method, which we discussed
+in [a more elementary QuantEcon lecture](https://python.quantecon.org/newton_method.html).
+
+Here our focus is on how to apply JAX to this problem.
 
 We use the following imports in this lecture
 
@@ -38,11 +42,21 @@ import jax.numpy as jnp
 from scipy.optimize import root
 ```
 
-## The Two Goods Market Equilibrium
+## The Equilibrium Problem
+
+In this section we describe the market equilibrium problem we will solve with
+JAX.
+
+We begin with a two-good case,
+which is borrowed from [an earlier lecture](https://python.quantecon.org/newton_method.html).
 
-Let's have a quick recap of this problem -- a more detailed explanation and derivation can be found at [A Two Goods Market Equilibrium](https://python.quantecon.org/newton_method.html#two-goods-market).
+Then we shift to higher dimensions.
 
-Assume we have a market for two complementary goods where demand depends on the price of both components.
+
+### The Two Goods Market Equilibrium
+
+Assume we have a market for two complementary goods where demand depends on the
+price of both components.
 
 We label them good 0 and good 1, with price vector $p = (p_0, p_1)$.
 
@@ -90,23 +104,24 @@ $$
 for this particular question.
 
 
-### The Multivariable Market Equilibrium
+### A High-Dimensional Version
 
-We can now easily get the multivariable version of the problem above.
+Let's now shift to a linear algebra formulation, which allows us to handle
+arbitrarily many goods.
 
 The supply function remains unchanged,
 
 $$
-q^s (p) =b \sqrt{p}
+q^s (p) =b \sqrt{p}
 $$
 
-The demand function is,
+The demand function becomes
 
 $$
-q^d (p) = \text{exp}(- A \cdot p) + c
+q^d (p) = \text{exp}(- A \cdot p) + c
 $$
 
-Our new excess demand function is,
+Our new excess demand function is
 
 $$
 e(p) = \text{exp}(- A \cdot p) + c - b \sqrt{p}
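
For reference, an excess demand function matching the formulas above can be written directly with `jax.numpy`. This is an illustrative sketch; the commit's own code cell (headed `def e(p, A, b, c):` in the next hunk) is not shown in full and may differ in detail.

```python
import jax.numpy as jnp

def e(p, A, b, c):
    # excess demand: demand exp(-A p) + c minus supply b * sqrt(p)
    return jnp.exp(-A @ p) + c - b * jnp.sqrt(p)
```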
@@ -120,9 +135,17 @@ def e(p, A, b, c):
 ```
 
 
-## Using Newton's Method
 
-Now let's use the multivariate version of Newton's method to compute the equilibrium price
+## Computation
+
+In this section we describe and then implement the solution method.
+
+
+### Newton's Method
+
+We use a multivariate version of Newton's method to compute the equilibrium price.
+
+The rule for updating a guess $p_n$ of the price vector is
 
 ```{math}
 :label: multi-newton
@@ -131,11 +154,9 @@ p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n)
 
 Here $J_e(p_n)$ is the Jacobian of $e$ evaluated at $p_n$.
 
-The iteration starts from some initial guess of the price vector $p_0$.
-
-Here, instead of coding Jacobian by hand, We use the `jax.jacobian()` function to auto-differentiate and calculate the Jacobian.
+Iteration starts from initial guess $p_0$.
 
-With only slight modification, we can generalize [our previous attempt](https://python.quantecon.org/newton_method.html#first-newton-attempt) to multi-dimensional problems
+Instead of coding the Jacobian by hand, we use `jax.jacobian()`.
 
 ```{code-cell} ipython3
 def newton(f, x_0, tol=1e-5, max_iter=15):
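
The body of `newton` is cut off at the hunk boundary above. A minimal sketch of the update rule labelled `multi-newton`, using `jax.jacobian()` and the signature shown, might look as follows; it is an editorial illustration, not necessarily the commit's exact implementation.

```python
import jax
import jax.numpy as jnp

def newton(f, x_0, tol=1e-5, max_iter=15):
    x = x_0
    f_jac = jax.jacobian(f)   # Jacobian of f via automatic differentiation
    for _ in range(max_iter):
        # Newton step: x_new = x - J_f(x)^{-1} f(x), done as a linear solve
        x_new = x - jnp.linalg.solve(f_jac(x), f(x))
        if jnp.max(jnp.abs(x_new - x)) < tol:
            return x_new
        x = x_new
    return x
```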
@@ -159,9 +180,9 @@ def newton(f, x_0, tol=1e-5, max_iter=15):
 ```
 
 
-### A High-Dimensional Problem
+### Application
 
-We now apply the multivariate Newton's Method to investigate a large market with 5,000 goods.
+Let's now apply the method just described to investigate a large market with 5,000 goods.
 
 We randomly generate the matrix $A$ and set the parameter vectors $b \text{ and } c$ to $1$.
 
@@ -189,7 +210,6 @@ Here's our initial condition $p_0$
 init_p = jnp.ones(dim)
 ```
 
-
 By leveraging the power of Newton's method, JAX accelerated linear algebra,
 automatic differentiation, and a GPU, we obtain a relatively small error for
 this very large problem in just a few seconds:
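
Putting the pieces together, here is a sketch of the 5,000-good experiment described above, assuming the `e` and `newton` sketches from earlier are in scope; the scaling of the random matrix `A` is an arbitrary illustrative choice, since the commit's setup cells are not shown in this diff.

```python
import jax
import jax.numpy as jnp

dim = 5_000
key = jax.random.PRNGKey(0)

# Random matrix A (scaled so that A @ p stays moderate) and b = c = 1
A = jax.random.uniform(key, (dim, dim)) / dim
b = c = jnp.ones(dim)

init_p = jnp.ones(dim)

# Solve e(p) = 0 for the equilibrium price vector
p_star = newton(lambda p: e(p, A, b, c), init_p)

# Maximum absolute excess demand at the computed solution
print(jnp.max(jnp.abs(e(p_star, A, b, c))))
```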
