
Commit 6fb2ef6

add e2e example
1 parent bfeb392 commit 6fb2ef6

2 files changed: +213 −6 lines changed
Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-

"""
``torch.compile`` End-to-End Tutorial
======================================
**Author:** William Wen
"""

import warnings

######################################################################
# ``torch.compile`` is the new way to speed up your PyTorch code!
# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling PyTorch code into optimized kernels,
# while requiring minimal code changes.
#
# This tutorial covers an end-to-end example of training and evaluating a
# real model with ``torch.compile``. For a gentler introduction to ``torch.compile``,
# please check out our ```torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
#
# **Required pip Dependencies**
#
# - ``torch >= 2.0``
# - ``torchvision``

# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
# order to reproduce the speedup numbers shown below and documented elsewhere.

import torch

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
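
######################################################################
# As a minimal sketch of the "minimal code changes" point above (this toy
# function is purely illustrative and is not part of the benchmark below),
# wrapping an ordinary function with ``torch.compile`` is all that is needed
# to compile it:

toy_device = "cuda" if torch.cuda.is_available() else "cpu"

def toy_example(x):
    return torch.sin(x) + torch.cos(x)

compiled_toy_example = torch.compile(toy_example)
print(compiled_toy_example(torch.randn(8, device=toy_device)))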

######################################################################
# Let's demonstrate how using ``torch.compile`` can speed up a real model.
# We will compare standard eager mode and
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


# Generates random input and target data for the model, where `b` is the
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )


N_ITERS = 10

from torchvision.models import densenet121


def init_model():
    return densenet121().cuda()


######################################################################
# First, let's compare inference.
#
# Note that in the call to ``torch.compile``, we have the additional
# ``mode`` argument, which we will discuss below.

model = init_model()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete its first run
# compared to eager. This is because ``torch.compile`` compiles
# the model into optimized kernels as it executes. In our example, the
# structure of the model doesn't change, and so recompilation is not
# needed. If we run our optimized model several more times, we should
# see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
# GPU read/writes, and so the observed speedup may vary with factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
# GPU compute and the observed speedup may be less significant.
#
# You may also see different speedup results depending on the chosen ``mode``
# argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
# the overhead of Python. For your own models,
# you may need to experiment with different modes to maximize speedup. You can
# read more about modes `here <https://pytorch.org/get-started/pytorch-2.0/#user-experience>`__.
#
# You might also notice that the second time we run our model with ``torch.compile`` is significantly
# slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
# mode runs a few warm-up iterations for CUDA graphs.
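
######################################################################
# If you would like to experiment with the other built-in modes mentioned
# above, here is a minimal sketch (guarded behind a flag and disabled by
# default, since each ``torch.compile`` call triggers a fresh, potentially
# lengthy compilation; the variable names are only illustrative).

TRY_OTHER_MODES = False  # flip to True to experiment with other modes
if TRY_OTHER_MODES:
    model_default = torch.compile(init_model())  # the default mode
    model_autotune = torch.compile(init_model(), mode="max-autotune")
    with torch.no_grad():
        # first calls trigger compilation; time later calls for a fair comparison
        model_default(inp), model_autotune(inp)
        print("default:", timed(lambda: model_default(inp))[1])
        print("max-autotune:", timed(lambda: model_autotune(inp))[1])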

######################################################################
# Now, let's compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())


# A single training step: forward pass, cross-entropy loss, backward pass,
# and optimizer update.
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
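
######################################################################
# As an aside, here is a small sketch (not used for the numbers above, and the
# names are only illustrative): when benchmarking your own models, it is common
# to run a few explicit warm-up iterations first, so that compilation and CUDA
# graph capture are excluded from the timings you report.

warmup_inp = generate_data(16)
for _ in range(3):
    train_opt(model, warmup_inp)

steady_inp = generate_data(16)
_, steady_time = timed(lambda: train_opt(model, steady_inp))
print(f"steady-state compiled train time: {steady_time}")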

######################################################################
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.
#
# We remark that the speedup numbers presented in this tutorial are for
# demonstration purposes only. Official speedup values can be seen at the
# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.

intermediate_source/torch_compile_tutorial.py

Lines changed: 9 additions & 6 deletions
@@ -25,6 +25,8 @@
 # our previous PyTorch compiler solution,
 # `TorchScript <https://pytorch.org/docs/stable/jit.html>`__.
 #
+# For an end-to-end example on a real model, check out our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
+#
 # **Contents**
 #
 # .. contents::
@@ -167,10 +169,6 @@ def timed(fn):
 # turn off logging for now to prevent spam
 torch._logging.set_logs(graph_code=False)

-# from torch._inductor.runtime.benchmarking import benchmarker
-# eager_latency = benchmarker.benchmark_gpu(lambda: mod1(torch.randn)) * 1e3
-# compiled_latency = benchmarker.benchmark_gpu(lambda: compiled_f(x)) * 1e3
-
 eager_times = []
 for i in range(10):
     _, eager_time = timed(lambda: foo3(inp))
@@ -203,6 +201,8 @@ def timed(fn):
 # architecture and batch size. For example, if a model's architecture is simple
 # and the amount of data is large, then the bottleneck would be
 # GPU compute and the observed speedup may be less significant.
+#
+# To see speedups on a real model, check out our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.

 ######################################################################
 # Benefits over TorchScript
@@ -424,9 +424,9 @@ def false_branch(y):
 # ---------------
 # Is ``torch.compile`` failing to speed up your model? Is compile time unreasonably long?
 # Is your code recompiling excessively? Are you having difficulties dealing with graph breaks?
-# Or maybe you want to simply learn more about the inner workings of ``torch.compile``?
+# Or maybe you simply want to learn more about the inner workings of ``torch.compile``?
 #
-# Check out `the ``torch.compile`` troubleshooting guide <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`_!
+# Check out `the ``torch.compile`` troubleshooting guide <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`__!

 ######################################################################
 # Conclusion
@@ -435,4 +435,7 @@ def false_branch(y):
 # In this tutorial, we introduced ``torch.compile`` by covering
 # basic usage, demonstrating speedups over eager mode, comparing to TorchScript,
 # and briefly describing graph breaks.
+#
+# For an end-to-end example on a real model, check out our `end-to-end ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
+#
 # We hope that you will give ``torch.compile`` a try!
