@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2.
 CUDA graphs with Triton are enabled by default in inductor but removing
 them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.
 
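For instance, a minimal sketch of applying this setting before compiling (the small
``nn.Linear`` model here is only a stand-in for illustration):

.. code-block:: python

    import torch
    import torch._inductor.config

    # Disable CUDA graphs in the inductor backend; this can alleviate some OOM issues.
    torch._inductor.config.triton.cudagraphs = False

    model = torch.nn.Linear(8, 8)          # stand-in model
    compiled_model = torch.compile(model)  # inductor is the default backend
    out = compiled_model(torch.randn(4, 8))
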
-``torch.func`` does not work with ``torch.compile``
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Applying a ``torch.func`` transform to a function that uses ``torch.compile``
 does not work:
@@ -337,12 +337,20 @@ does not work:
     x = torch.randn(2, 3)
     g(x)
 
+This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
+that you can track for this.
+
 As a workaround, use ``torch.compile`` outside of the ``torch.func`` function:
 
+.. note::
+    This is an experimental feature and can be used by setting ``torch._dynamo.config.capture_func_transforms=True``
+
 .. code-block:: python
 
     import torch
 
+    torch._dynamo.config.capture_func_transforms = True
+
     def f(x):
         return torch.sin(x)
 
@@ -353,58 +361,140 @@ As a workaround, use ``torch.compile`` outside of the ``torch.func`` function:
     x = torch.randn(2, 3)
     g(x)
 
-Applying a ``torch.func`` transform to a function handled with ``torch.compile``
---------------------------------------------------------------------------------
+Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
+------------------------------------------------------------------------------------
+
 
-For example, you have the following code:
+Compiling ``torch.func.grad`` with ``torch.compile``
+----------------------------------------------------
 
 .. code-block:: python
 
     import torch
 
-    @torch.compile
-    def f(x):
-        return torch.sin(x)
+    torch._dynamo.config.capture_func_transforms = True
 
-    def g(x):
-        return torch.grad(f)(x)
+    def wrapper_fn(x):
+        return torch.func.grad(lambda x: x.sin().sum())(x)
 
-    x = torch.randn(2, 3)
-    g(x)
+    x = torch.randn(3, 3, 3)
+    grad_x = torch.compile(wrapper_fn)(x)
 
-This code will not work. There is an `issue <https://github.com/pytorch/pytorch/issues/100320>`__
-that you can track for this.
-As a workaround, please put the ``torch.compile`` outside of ``torch.func`` transform:
+Compiling ``torch.vmap`` with ``torch.compile``
+-----------------------------------------------
 
 .. code-block:: python
 
     import torch
 
-    def f(x):
-        return torch.sin(x)
+    torch._dynamo.config.capture_func_transforms = True
 
-    @torch.compile
-    def g(x):
-        return torch.vmap(f)(x)
+    def my_fn(x):
+        return torch.vmap(lambda x: x.sum(1))(x)
 
-    x = torch.randn(2, 3)
-    g(x)
+    x = torch.randn(3, 3, 3)
+    output = torch.compile(my_fn)(x)
 
-Calling ``torch.func`` transform inside of a function handled with ``torch.compile``
-------------------------------------------------------------------------------------
+Limitations
+-----------
+
+There are currently a few cases which are not supported and lead to graph breaks
+(that is, torch.compile falls back to eager-mode PyTorch on these). We are working
+on improving the situation for the next release (PyTorch 2.2).
+
+1. The inputs and outputs of the function being transformed over must be tensors.
+   We do not yet support things like tuples of Tensors.
 
 .. code-block:: python
 
     import torch
 
-    @torch.compile
-    def f(x):
-        return torch.vmap(torch.sum)(x)
+    torch._dynamo.config.capture_func_transforms = True
 
-    x = torch.randn(2, 3)
-    f(x)
+    def fn(x):
+        x1, x2 = x
+        return x1 + x2
+
+    def my_fn(x):
+        return torch.func.vmap(fn)(x)
+
+    x1 = torch.randn(3, 3, 3)
+    x2 = torch.randn(3, 3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)((x1, x2))
+
+2. Keyword arguments are not supported.
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms = True
+
+    def fn(x, y):
+        return (x + y).sum()
+
+    def my_fn(x, y):
+        return torch.func.grad(fn)(x, y=y)
+
+    x = torch.randn(3, 3)
+    y = torch.randn(3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x, y)
+
+3. Functions with observable side effects. For example, it is OK to mutate a list created in the function,
+   but not OK to mutate a list created outside of the function.
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms = True
+
+    some_list = []
+
+    def f(x, y):
+        some_list.append(1)
+        return x + y
+
+    def my_fn(x, y):
+        return torch.func.vmap(f)(x, y)
+
+    x = torch.ones(2, 3)
+    y = torch.randn(2, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x, y)
+
+4. ``torch.vmap`` over a function that calls one or more operators in the following list.
+
+.. note::
+    'stride', 'requires_grad', 'storage_offset', 'layout', 'data', 'is_coalesced', 'is_complex',
+    'is_conj', 'is_contiguous', 'is_cpu', 'is_cuda', 'is_distributed', 'is_floating_point',
+    'is_inference', 'is_ipu', 'is_leaf', 'is_meta', 'is_mkldnn', 'is_mps', 'is_neg', 'is_nested',
+    'is_nonzero', 'is_ort', 'is_pinned', 'is_quantized', 'is_same_size', 'is_set_to', 'is_shared',
+    'is_signed', 'is_sparse', 'is_sparse_csr', 'is_vulkan', 'is_xla', 'is_xpu'
+
+.. code-block:: python
+
+    import torch
+
+    torch._dynamo.config.capture_func_transforms = True
+
+    def bad_fn(x):
+        x.stride()
+        return x
+
+    def my_fn(x):
+        return torch.func.vmap(bad_fn)(x)
+
+    x = torch.randn(3, 3, 3)
+    # Unsupported, falls back to eager-mode PyTorch
+    output = torch.compile(my_fn)(x)
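
To check whether one of the cases above actually caused a graph break, one option is
``torch._dynamo.explain``. The sketch below assumes the PyTorch 2.1 ``explain(fn)(*args)``
calling convention and reuses the ``vmap``-over-``stride`` pattern from the last example:

.. code-block:: python

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def bad_fn(x):
        x.stride()  # listed above: not supported under vmap
        return x

    def my_fn(x):
        return torch.func.vmap(bad_fn)(x)

    x = torch.randn(3, 3, 3)
    # The ExplainOutput summarizes graph breaks and their reasons
    print(torch._dynamo.explain(my_fn)(x))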
+
+Compiling functions besides the ones which are supported (escape hatch)
+-----------------------------------------------------------------------
 
-This doesn't work yet. As a workaround, use ``torch._dynamo.allow_in_graph``
+For other transforms, as a workaround, use ``torch._dynamo.allow_in_graph``
 
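As a rough sketch of that workaround (``compute_jacobian`` is a hypothetical helper, and
``jacrev`` stands in for a transform that is not captured yet):

.. code-block:: python

    import torch

    def compute_jacobian(x):
        # a torch.func transform that torch.compile does not capture yet
        return torch.func.jacrev(torch.sin)(x)

    # Keep compute_jacobian as a single call in the graph instead of tracing into it
    torch._dynamo.allow_in_graph(compute_jacobian)

    @torch.compile
    def fn(x):
        return compute_jacobian(x)

    x = torch.randn(3)
    out = fn(x)
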
 ``allow_in_graph`` is an escape hatch. If your code does not work with
 ``torch.compile``, which introspects Python bytecode, but you believe it