Commit de2571f

Update autocast in dispatcher tutorial (#1128)

* draft
* fixes
* dont overrun the line

1 parent fdbe99c commit de2571f

File tree

1 file changed: +61 -24 lines changed

advanced_source/dispatcher.rst (+61 -24)
@@ -229,38 +229,75 @@ Autocast
 ^^^^^^^^
 
 The Autocast dispatch key implements support for
-`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
-(AMP). An autocast kernel typically modifies the operation of an operator by casting the
-input arguments to some precision before carrying out the operation. For some
-operations, it is numerically safe to cast to lower precision, which is how AMP
-can achieve speed ups and reduced memory usage without sacrificing much
-accuracy. A nontrivial autocast kernel looks something like this:
+`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
+An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
+to some preferred precision before running the op.
+For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
+and use less memory in ``float16`` without impairing convergence.
+Autocast wrappers only have an effect in
+`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.
+
+Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:
 
 .. code-block:: cpp
 
+  // Autocast-specific helper functions
+  #include <ATen/autocast_mode.h>
+
   Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-    return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
+    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
+                    at::autocast::cached_cast(at::kHalf, other));
+  }
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("mymatmul", mymatmul_autocast);
   }
 
+``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
+otherwise, it leaves ``tensor`` unchanged (cf. the
+`eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
+This ensures that if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
+``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
+inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
+is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
+you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.
+
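For illustration, that forced-``float16`` variant might look like the following sketch (the wrapper name ``mymatmul_autocast_forced_half`` is hypothetical and not part of the tutorial's code):

.. code-block:: cpp

  // Hypothetical sketch: always cast to float16, bypassing the native
  // eligibility policy that cached_cast follows.
  Tensor mymatmul_autocast_forced_half(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(self.half(), other.half());
  }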
 Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
-dispatch before redispatching. By default, if no autocast kernel is provided,
-we simply fallthrough directly to the regular operator implementation (no
-autocasting occurs.) (We didn't use ``myadd`` for this example, since pointwise
-addition doesn't do autocasting and should just fall through).
-
-When should an autocast kernel be registered? Unfortunately, there aren't
-cut-and-dry rules for when you should cast to a lower precision. You can
-get a sense for what operators have autocasting behavior by looking at
-the `AMP documentation
-<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
-general rules:
-
-* Operations that do reductions should be carried out in float32,
-* Any operation with multiple float tensor inputs has to standardize them
-  to a common precision, and
-* Any operation that does a convolution or gemm under the hood should
-  probably be float16
+dispatch before redispatching.
+
+By default, if no autocast wrapper is provided,
+we fall through directly to the regular operator implementation (no
+autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
+addition doesn't need autocasting and should just fall through.)
+
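This default is equivalent to explicitly registering a fallthrough kernel for the op under the ``Autocast`` key. A minimal sketch of that equivalent registration, reusing the tutorial's ``myops::myadd`` (illustrative only, since the default already behaves this way):

.. code-block:: cpp

  // Illustrative sketch: an explicit fallthrough under Autocast mimics the
  // default behavior. No casting happens, and dispatch simply continues to
  // the next kernel in line for myadd.
  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("myadd", torch::CppFunction::makeFallthrough());
  }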
+When should an autocast wrapper be registered? Unfortunately, there aren't
+cut-and-dried rules for an op's preferred precision. You can
+get a sense for some native ops' preferred precisions by looking at the
+`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
+General guidance:
+
+* Ops that do reductions should probably execute in ``float32``,
+* Any op that does a convolution or gemm under the hood should
+  probably execute in ``float16``, and
+* Other ops with multiple floating-point tensor inputs should standardize
+  them to a common precision (unless the implementation supports inputs with different precisions).
+
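For the first bullet (reductions), a wrapper might look like the following sketch; the op ``mysum`` and its wrapper are hypothetical, not ops defined in this tutorial:

.. code-block:: cpp

  // Hypothetical sketch: run a custom reduction in float32.
  // cached_cast(at::kFloat, ...) widens eligible float16 CUDA tensors to
  // float32 and leaves float32 (and non-eligible) inputs unchanged.
  Tensor mysum_autocast(const Tensor& self) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mysum(at::autocast::cached_cast(at::kFloat, self));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mysum", mysum_autocast);
  }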
+If your custom op falls into the third category, the ``promote_type`` template
+helps figure out the widest floating-point type present among input tensors, which is
+the safest choice for the execution type:
+
+.. code-block:: cpp
+
+  #include <ATen/autocast_mode.h>
+
+  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    // The required at::kHalf argument is an optimistic initial guess.
+    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
+    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
+                                at::autocast::cached_cast(exec_type, t1));
+  }
 
 Batched
 ^^^^^^^
