
Commit a74f50d

kshitij12345 authored and pytorchmergebot committed

torch.compile-functorch interaction: update docs (pytorch#108130)

Doc Preview: https://docs-preview.pytorch.org/pytorch/pytorch/108130/torch.compiler_faq.html#torch-func-works-with-torch-compile-for-grad-and-vmap-transforms

Will also cherry-pick this for the release branch.

Pull Request resolved: pytorch#108130
Approved by: https://github.com/zou3519

1 parent 42f94d7 commit a74f50d


docs/source/torch.compiler_faq.rst

Lines changed: 120 additions & 30 deletions
@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2.
 CUDA graphs with Triton are enabled by default in inductor but removing
 them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

-``torch.func`` does not work with ``torch.compile``
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Applying a ``torch.func`` transform to a function that uses ``torch.compile``
 does not work:
@@ -337,12 +337,20 @@ does not work:
     x = torch.randn(2, 3)
     g(x)

+This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
+that you can track for this.
+
 As a workaround, use ``torch.compile`` outside of the ``torch.func`` function:

+.. note::
+    This is an experimental feature and can be used by setting `torch._dynamo.config.capture_func_transforms=True`
+
 .. code-block:: python

     import torch

+    torch._dynamo.config.capture_func_transforms=True
+
     def f(x):
         return torch.sin(x)

@@ -353,58 +361,140 @@ As a workaround, use ``torch.compile`` outside of the ``torch.func`` function:
     x = torch.randn(2, 3)
     g(x)

-Applying a ``torch.func`` transform to a function handled with ``torch.compile``
---------------------------------------------------------------------------------
+Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
+------------------------------------------------------------------------------------
+

-For example, you have the following code:
+Compiling ``torch.func.grad`` with ``torch.compile``
+----------------------------------------------------

 .. code-block:: python

     import torch

-    @torch.compile
-    def f(x):
-        return torch.sin(x)
+    torch._dynamo.config.capture_func_transforms=True

-    def g(x):
-        return torch.grad(f)(x)
+    def wrapper_fn(x):
+        return torch.func.grad(lambda x: x.sin().sum())(x)

-    x = torch.randn(2, 3)
-    g(x)
+    x = torch.randn(3, 3, 3)
+    grad_x = torch.compile(wrapper_fn)(x)

-This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
-that you can track for this.
-As a workaround, please put the ``torch.compile`` outside of ``torch.func`` transform:
+Compiling ``torch.vmap`` with ``torch.compile``
+-----------------------------------------------

 .. code-block:: python

     import torch

-    def f(x):
-        return torch.sin(x)
+    torch._dynamo.config.capture_func_transforms=True

-    @torch.compile
-    def g(x):
-        return torch.vmap(f)(x)
+    def my_fn(x):
+        return torch.vmap(lambda x: x.sum(1))(x)

-    x = torch.randn(2, 3)
-    g(x)
+    x = torch.randn(3, 3, 3)
+    output = torch.compile(my_fn)(x)

-Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
-------------------------------------------------------------------------------------
+Limitations
+-----------
+
+There are currently a few cases which are not supported and lead to graph breaks
+(that is, torch.compile falls back to eager-mode PyTorch on these). We are working
+on improving the situation for the next release (PyTorch 2.2)
+
+1. The inputs and outputs of the function being transformed over must be tensors.
+   We do not yet support things like tuple of Tensors.

 .. code-block:: python

     import torch

-    @torch.compile
-    def f(x):
-        return torch.vmap(torch.sum)(x)
+    torch._dynamo.config.capture_func_transforms=True

-    x = torch.randn(2, 3)
-    f(x)
+    def fn(x):
+        x1, x2 = x
+        return x1 + x2
+
+    def my_fn(x):
+        return torch.func.vmap(fn)(x)
+
+    x1 = torch.randn(3, 3, 3)
+    x2 = torch.randn(3, 3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)((x1, x2))
+
+2. Keyword arguments are not supported.
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms=True
+
+    def fn(x, y):
+        return (x + y).sum()
+
+    def my_fn(x, y):
+        return torch.func.grad(fn)(x, y=y)
+
+    x = torch.randn(3, 3)
+    y = torch.randn(3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x, y)
+
+3. Functions with observable side effects. For example, it is OK to mutate a list created in the function,
+   but not OK to mutate a list created outside of the function.
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms=True
+
+    some_list = []
+
+    def f(x, y):
+        some_list.append(1)
+        return x + y
+
+    def my_fn(x, y):
+        return torch.func.vmap(f)(x, y)
+
+    x = torch.ones(2, 3)
+    y = torch.randn(2, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x, y)
+
+4. ``torch.vmap`` over a function that calls one or more operators in the following list.
+
+.. note::
+    'stride', 'requires_grad', 'storage_offset', 'layout', 'data', 'is_coalesced', 'is_complex',
+    'is_conj', 'is_contiguous', 'is_cpu', 'is_cuda', 'is_distributed', 'is_floating_point',
+    'is_inference', 'is_ipu', 'is_leaf', 'is_meta', 'is_mkldnn', 'is_mps', 'is_neg', 'is_nested',
+    'is_nonzero', 'is_ort', 'is_pinned', 'is_quantized', 'is_same_size', 'is_set_to', 'is_shared',
+    'is_signed', 'is_sparse', 'is_sparse_csr', 'is_vulkan', 'is_xla', 'is_xpu'
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms=True
+
+    def bad_fn(x):
+        x.stride()
+        return x
+
+    def my_fn(x):
+        return torch.func.vmap(bad_fn)(x)
+
+    x = torch.randn(3, 3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x)
+
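As an aside on the limitations above: a quick way to check whether one of these cases actually triggers the eager fallback is to compile with ``fullgraph=True``, which raises an error on any graph break instead of silently falling back. The sketch below is illustrative only (it reuses the keyword-argument case from limitation 2) and is not part of the patch itself.

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def fn(x, y):
        return (x + y).sum()

    def my_fn(x, y):
        # Keyword arguments to the transformed function are not captured (limitation 2).
        return torch.func.grad(fn)(x, y=y)

    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    # With fullgraph=True, a graph break becomes a hard error rather than a
    # silent fallback to eager-mode PyTorch, so unsupported cases surface immediately.
    try:
        torch.compile(my_fn, fullgraph=True)(x, y)
    except Exception as err:
        print("graph break:", err)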
+Compiling functions besides the ones which are supported (escape hatch)
+-----------------------------------------------------------------------

-This doesn't work yet. As a workaround, use ``torch._dynamo.allow_in_graph``
+For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph``
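A minimal sketch of what that escape hatch can look like; the wrapper function and the choice of ``torch.func.jacrev`` below are illustrative assumptions, not taken from the documentation.

.. code-block:: python

    import torch

    # Illustrative wrapper around a transform that torch.compile does not capture.
    def jacobian_wrapper(x):
        return torch.func.jacrev(torch.sin)(x)

    # allow_in_graph tells dynamo to put this callable into the captured graph
    # as-is instead of introspecting its Python bytecode.
    torch._dynamo.allow_in_graph(jacobian_wrapper)

    @torch.compile
    def fn(x):
        return jacobian_wrapper(x)

    x = torch.randn(3)
    out = fn(x)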

 ``allow_in_graph`` is an escape hatch. If your code does not work with
 ``torch.compile``, which introspects Python bytecode, but you believe it
