Migrate from tic/toc timers and %time/%timeit to new Timer context manager #391


Open · wants to merge 9 commits into `main`
53 changes: 30 additions & 23 deletions lectures/jax_intro.md
@@ -18,7 +18,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie
```{code-cell} ipython3
:tags: [hide-output]

!pip install jax
!pip install jax quantecon
```

This lecture provides a short introduction to [Google JAX](https://github.com/google/jax).
@@ -52,6 +52,7 @@ The following import is standard, replacing `import numpy as np`:
```{code-cell} ipython3
import jax
import jax.numpy as jnp
import quantecon as qe
```
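
Throughout this diff, timing is done with quantecon's `Timer`. As a rough mental model, here is a minimal sketch of the context-manager pattern it follows, inferred from its usage in the changes below (the optional message, the `.elapsed` attribute, the printed elapsed time); quantecon's actual implementation may differ.

```python
import time

class Timer:
    """Minimal sketch of the pattern `qe.Timer` follows in this diff."""

    def __init__(self, message="Elapsed"):
        self.message = message   # optional label, as in qe.Timer("Broadcasting operation")
        self.elapsed = None      # read after the block, as in `timer1.elapsed`

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.elapsed = time.perf_counter() - self._start
        print(f"{self.message}: {self.elapsed:.6f} seconds")
        return False             # never suppress exceptions
```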

Now we can use `jnp` in place of `np` for the usual array operations:
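
The demonstration cells are elided from this hunk; a typical example of the operations meant here is sketched below (not the lecture's exact cells).

```python
a = jnp.asarray((1.0, 2.0, 3.0))
print(jnp.sum(a))    # reductions work as in NumPy
print(jnp.mean(a))
print(a @ a)         # inner products too
```
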
@@ -304,7 +305,8 @@ x = jnp.ones(n)
How long does the function take to execute?

```{code-cell} ipython3
%time f(x).block_until_ready()
with qe.Timer():
f(x).block_until_ready()
```

```{note}
@@ -318,7 +320,8 @@ allows the Python interpreter to run ahead of numerical computations.
If we run it a second time it becomes faster again:

```{code-cell} ipython3
%time f(x).block_until_ready()
with qe.Timer():
f(x).block_until_ready()
```

This is because the built-in functions like `jnp.cos` are JIT compiled and the
@@ -341,7 +344,8 @@ y = jnp.ones(m)
```

```{code-cell} ipython3
%time f(y).block_until_ready()
with qe.Timer():
f(y).block_until_ready()
```

Notice that the execution time increases, because now new versions of
@@ -352,14 +356,16 @@ If we run again, the code is dispatched to the correct compiled version and we
get faster execution.

```{code-cell} ipython3
%time f(y).block_until_ready()
with qe.Timer():
f(y).block_until_ready()
```

The compiled versions for the previous array size are still available in memory
too, and the following call is dispatched to the correct compiled code.

```{code-cell} ipython3
%time f(x).block_until_ready()
with qe.Timer():
f(x).block_until_ready()
```
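
In other words, the compile cache is keyed on the shape and dtype of the inputs. Here is a small self-contained sketch of that behavior (the function and sizes are illustrative, not from the lecture).

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(a):
    return jnp.cos(a).sum()

g(jnp.ones(10)).block_until_ready()    # compiles for shape (10,)
g(jnp.zeros(10)).block_until_ready()   # same shape and dtype: cache hit, no recompile
g(jnp.ones(20)).block_until_ready()    # new shape: triggers a fresh compilation
```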

### Compiling the outer function
@@ -379,7 +385,8 @@ f_jit(x)
And now let's time it.

```{code-cell} ipython3
%time f_jit(x).block_until_ready()
with qe.Timer():
f_jit(x).block_until_ready()
```

Note the speed gain.
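
For reference, the pattern being timed is simply `jax.jit` applied to the whole function. A self-contained sketch (the function body is a stand-in, not the lecture's `f`):

```python
import jax
import jax.numpy as jnp

def f(a):
    return jnp.cos(a) + jnp.sin(a)    # stand-in for the lecture's function

f_jit = jax.jit(f)                    # or decorate the definition with @jax.jit
x = jnp.ones(1_000)
f_jit(x).block_until_ready()          # first call compiles
f_jit(x).block_until_ready()          # later calls reuse the compiled code
```
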
@@ -534,10 +541,10 @@ z_loops = np.empty((n, n))
```

```{code-cell} ipython3
%%time
for i in range(n):
for j in range(n):
z_loops[i, j] = f(x[i], y[j])
with qe.Timer():
for i in range(n):
for j in range(n):
z_loops[i, j] = f(x[i], y[j])
```

Even for this very small grid, the run time is extremely slow.
@@ -575,15 +582,15 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
Now we get what we want and the execution time is very fast.

```{code-cell} ipython3
%%time
z_mesh = f(x_mesh, y_mesh).block_until_ready()
with qe.Timer():
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
%%time
z_mesh = f(x_mesh, y_mesh).block_until_ready()
with qe.Timer():
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

Let's confirm that we got the right answer.
@@ -602,8 +609,8 @@ x_mesh, y_mesh = jnp.meshgrid(x, y)
```
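
The confirmation cell itself is elided in this view; presumably it compares the looped and vectorized results, along the lines of this sketch.

```python
print(jnp.allclose(z_mesh, z_loops))   # should print True
```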

```{code-cell} ipython3
%%time
z_mesh = f(x_mesh, y_mesh).block_until_ready()
with qe.Timer():
z_mesh = f(x_mesh, y_mesh).block_until_ready()
```

But there is one problem here: the mesh grids use a lot of memory.
@@ -641,8 +648,8 @@ f_vec = jax.vmap(f_vec_y, in_axes=(0, None))
With this construction, we can now call the function $f$ on flat (low memory) arrays.

```{code-cell} ipython3
%%time
z_vmap = f_vec(x, y).block_until_ready()
with qe.Timer():
z_vmap = f_vec(x, y).block_until_ready()
```

The execution time is essentially the same as the mesh operation but we are using much less memory.
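
The definition of `f_vec_y` is elided above; the standard double-`vmap` pattern it presumably follows is sketched here. In `in_axes`, an integer says which axis of that argument to map over, while `None` means the argument is held fixed (the body of `f` below is a stand-in).

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.cos(x**2 + y**2)                # stand-in bivariate function

f_vec_y = jax.vmap(f, in_axes=(None, 0))       # vectorize in y, holding x fixed
f_vec = jax.vmap(f_vec_y, in_axes=(0, None))   # then vectorize in x

x, y = jnp.linspace(-1, 1, 5), jnp.linspace(-1, 1, 4)
print(f_vec(x, y).shape)                       # (5, 4): one row per x, one column per y
```
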
@@ -711,15 +718,15 @@ def compute_call_price_jax(β=β,
Let's run it once to compile it:

```{code-cell} ipython3
%%time
compute_call_price_jax().block_until_ready()
with qe.Timer():
compute_call_price_jax().block_until_ready()
```

And now let's time it:

```{code-cell} ipython3
%%time
compute_call_price_jax().block_until_ready()
with qe.Timer():
compute_call_price_jax().block_until_ready()
```

```{solution-end}
49 changes: 25 additions & 24 deletions lectures/numba.md
@@ -41,6 +41,8 @@ import quantecon as qe
import matplotlib.pyplot as plt
```



## Overview

In an {doc}`earlier lecture <need_for_speed>` we learned about vectorization, which is one method to improve speed and efficiency in numerical work.
@@ -133,17 +135,17 @@ Let's time and compare identical function calls across these two versions, start
```{code-cell} ipython3
n = 10_000_000

qe.tic()
qm(0.1, int(n))
time1 = qe.toc()
with qe.Timer() as timer1:
qm(0.1, int(n))
time1 = timer1.elapsed
```

Now let's try `qm_numba`.

```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time2 = qe.toc()
with qe.Timer() as timer2:
qm_numba(0.1, int(n))
time2 = timer2.elapsed
```

This is already a very large speed gain.
@@ -153,9 +155,9 @@ In fact, the next time and all subsequent times it runs even faster as the funct
(qm_numba_result)=

```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time3 = qe.toc()
with qe.Timer() as timer3:
qm_numba(0.1, int(n))
time3 = timer3.elapsed
```
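
The body of the next cell is elided in this hunk; it presumably reports the relative speedups from the three elapsed times recorded above, in the spirit of this sketch (the exact wording is an assumption).

```python
print(f"Speedup on first Numba run:  {time1 / time2:.1f}x")
print(f"Speedup on second Numba run: {time1 / time3:.1f}x")
```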

```{code-cell} ipython3
@@ -225,15 +227,13 @@ This is equivalent to adding `qm = jit(qm)` after the function definition.
The following now uses the jitted version:

```{code-cell} ipython3
%%time

qm(0.1, 100_000)
with qe.Timer():
qm(0.1, 100_000)
```

```{code-cell} ipython3
%%time

qm(0.1, 100_000)
with qe.Timer():
qm(0.1, 100_000)
```

Numba also provides several arguments for decorators to accelerate computation and cache functions -- see [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).
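
For instance, here is a sketch using two of those documented arguments; whether they help depends on the workload, and the loop body below is a stand-in (the lecture's `qm` is not shown in this hunk).

```python
from numba import njit

# cache=True persists compiled machine code to disk across sessions;
# fastmath=True permits faster but less strict floating-point semantics.
@njit(cache=True, fastmath=True)
def qm_cached(x0, n):
    x = x0
    for _ in range(n - 1):
        x = 4.0 * x * (1.0 - x)
    return x
```
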
@@ -289,7 +289,8 @@ We can fix this error easily in this case by compiling `mean`.
```{code-cell} ipython3
@njit
def mean(data):
return np.mean(data)

%time bootstrap(data, mean, n_resamples)
with qe.Timer():
bootstrap(data, mean, n_resamples)
```

## Compiling Classes
@@ -534,11 +535,13 @@ def calculate_pi(n=1_000_000):
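
The definition of `calculate_pi` sits above this hunk; the standard Monte Carlo construction that the name suggests looks roughly like this (a sketch, not necessarily the lecture's exact code).

```python
from random import uniform
from numba import njit

@njit
def calculate_pi(n=1_000_000):
    count = 0
    for _ in range(n):
        u, v = uniform(0, 1), uniform(0, 1)   # draw a point in the unit square
        if u**2 + v**2 < 1:                   # does it land in the quarter circle?
            count += 1
    return 4 * count / n                      # area ratio estimates π
```
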
Now let's see how fast it runs:

```{code-cell} ipython3
%time calculate_pi()
with qe.Timer():
calculate_pi()
```

```{code-cell} ipython3
%time calculate_pi()
with qe.Timer():
calculate_pi()
```

If we switch off JIT compilation by removing `@njit`, the code takes around
@@ -639,9 +642,8 @@ This is (approximately) the right output.
Now let's time it:

```{code-cell} ipython3
qe.tic()
compute_series(n)
qe.toc()
with qe.Timer():
compute_series(n)
```

Next let's implement a Numba version, which is easy.
@@ -660,9 +662,8 @@ print(np.mean(x == 0))
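
The cell defining the Numba version is elided here; given the earlier remark that decoration is equivalent to `qm = jit(qm)`, it is presumably along these lines (a sketch using names from the surrounding hunk).

```python
from numba import njit

compute_series_numba = njit(compute_series)   # compile the existing pure-Python function

x = compute_series_numba(n)
print(np.mean(x == 0))                        # the check visible in the hunk header above
```
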
Let's see the time:

```{code-cell} ipython3
qe.tic()
compute_series_numba(n)
qe.toc()
with qe.Timer():
compute_series_numba(n)
```

This is a nice speed improvement for one line of code!
62 changes: 29 additions & 33 deletions lectures/numpy.md
@@ -63,6 +63,8 @@ from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib import cm
```



(numpy_array)=
## NumPy Arrays

Expand Down Expand Up @@ -1190,21 +1192,19 @@ n = 1_000_000
```

```{code-cell} python3
%%time

y = 0 # Will accumulate and store sum
for i in range(n):
x = random.uniform(0, 1)
y += x**2
with qe.Timer():
y = 0 # Will accumulate and store sum
for i in range(n):
x = random.uniform(0, 1)
y += x**2
```

The following vectorized code achieves the same thing.

```{code-cell} ipython
%%time

x = np.random.uniform(0, 1, n)
y = np.sum(x**2)
with qe.Timer():
x = np.random.uniform(0, 1, n)
y = np.sum(x**2)
```

As you can see, the second code block runs much faster. Why?
@@ -1285,24 +1285,22 @@ grid = np.linspace(-3, 3, 1000)
Here's a non-vectorized version that uses Python loops.

```{code-cell} python3
%%time

m = -np.inf
with qe.Timer():
m = -np.inf

for x in grid:
for y in grid:
z = f(x, y)
if z > m:
m = z
for x in grid:
for y in grid:
z = f(x, y)
if z > m:
m = z
```

And here's a vectorized version:

```{code-cell} python3
%%time

x, y = np.meshgrid(grid, grid)
np.max(f(x, y))
with qe.Timer():
x, y = np.meshgrid(grid, grid)
np.max(f(x, y))
```

In the vectorized version, all the looping takes place in compiled code.
@@ -1636,9 +1634,8 @@ np.random.seed(123)
x = np.random.randn(1000, 100, 100)
y = np.random.randn(100)

qe.tic()
B = x / y
qe.toc()
with qe.Timer("Broadcasting operation"):
B = x / y
```

Here is the output:
@@ -1696,14 +1693,13 @@ np.random.seed(123)
x = np.random.randn(1000, 100, 100)
y = np.random.randn(100)

qe.tic()
D = np.empty_like(x)
d1, d2, d3 = x.shape
for i in range(d1):
for j in range(d2):
for k in range(d3):
D[i, j, k] = x[i, j, k] / y[k]
qe.toc()
with qe.Timer("For loop operation"):
D = np.empty_like(x)
d1, d2, d3 = x.shape
for i in range(d1):
for j in range(d2):
for k in range(d3):
D[i, j, k] = x[i, j, k] / y[k]
```

Note that the `for` loop takes much longer than the broadcasting operation.
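
A quick sanity check that the two approaches agree, using the arrays computed above (a sketch):

```python
print(np.allclose(B, D))   # should print True
```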