Update autocast in dispatcher tutorial #1128

Merged 3 commits on Aug 20, 2020

advanced_source/dispatcher.rst (61 additions, 24 deletions)

Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
to some preferred precision before running the op.
For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
and use less memory in ``float16`` without impairing convergence.
Autocast wrappers only have an effect in
`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.

Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp

  // Autocast-specific helper functions
  #include <ATen/autocast_mode.h>

  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                    at::autocast::cached_cast(at::kHalf, other));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mymatmul", mymatmul_autocast);
  }

``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
otherwise, it leaves ``tensor`` unchanged (cf. the
`eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
This ensures that if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.
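
For illustration, here is a minimal sketch of that forced-``float16`` variant (hypothetical, and
usually not what you want, since it also casts CPU, integer, and ``float64`` inputs):

.. code-block:: cpp

  // Hypothetical variant that ignores the native eligibility policy and
  // always runs the op in float16, regardless of device or input dtype.
  Tensor mymatmul_autocast_forced_half(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(self.half(), other.half());
  }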

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching.

By default, if no autocast wrapper is provided, we fall through directly to the
regular operator implementation (no autocasting occurs). (We didn't use ``myadd``
for this example, since pointwise
addition doesn't need autocasting and should just fall through.)
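
Fall-through is the default, so registering anything for this case is optional. If you do
want to make it explicit, one option (a sketch, assuming the ``myops::myadd`` op defined
earlier in this tutorial) is to register ``torch::CppFunction::makeFallthrough()`` for the
``Autocast`` key:

.. code-block:: cpp

  #include <torch/library.h>

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    // Explicitly skip autocasting for myadd and redispatch to the next kernel;
    // this mirrors the default behavior when no autocast wrapper is registered.
    m.impl("myadd", torch::CppFunction::makeFallthrough());
  }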

When should an autocast wrapper be registered? Unfortunately, there aren't
cut-and-dried rules for an op's preferred precision. You can
get a sense for some native ops' preferred precisions by looking at the
`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
General guidance:

* Ops that do reductions should probably execute in ``float32`` (a sketch follows this list),
* Any op that does a convolution or gemm under the hood should
probably execute in ``float16``, and
* Other ops with multiple floating-point tensor inputs should standardize
them to a common precision (unless the implementation supports inputs with different precisions).
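
For the first bullet, here is a sketch for a hypothetical reduction op ``mysum``, forcing
``float32`` execution (``cached_cast(at::kFloat, ...)`` upcasts eligible ``float16`` CUDA
tensors and leaves everything else unchanged):

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor mysum_autocast(const Tensor& self) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // Reductions accumulate error in float16, so upcast eligible inputs to float32.
    return mysum(at::autocast::cached_cast(at::kFloat, self));
  }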

If your custom op falls into the third category, the ``promote_type`` template
helps figure out the widest floating-point type present among input tensors, which is
the safest choice for the execution type:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // The required at::kHalf argument is an optimistic initial guess.
    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                                at::autocast::cached_cast(exec_type, t1));
  }
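
As with ``mymatmul``, the wrapper only takes effect once it is registered to the ``Autocast``
key (a sketch, assuming ``my_multiple_input_op`` lives in the same hypothetical ``myops``
namespace used above):

.. code-block:: cpp

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("my_multiple_input_op", my_multiple_input_op_autocast);
  }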

Batched
^^^^^^^