@@ -229,38 +229,75 @@ Autocast
^^^^^^^^

The Autocast dispatch key implements support for
-`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
-(AMP). An autocast kernel typically modifies the operation of an operator by casting the
-input arguments to some precision before carrying out the operation. For some
-operations, it is numerically safe to cast to lower precision, which is how AMP
-can achieve speed ups and reduced memory usage without sacrificing much
-accuracy. A nontrivial autocast kernel looks something like this:
+`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
+An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
+to some preferred precision before running the op.
+For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
+and use less memory in ``float16`` without impairing convergence.
+Autocast wrappers only have an effect in
+`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.
+
+Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp

+  // Autocast-specific helper functions
+  #include <ATen/autocast_mode.h>
+
  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-    return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
+    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
+                    at::autocast::cached_cast(at::kHalf, other));
+  }
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("mymatmul", mymatmul_autocast);
  }

+``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
+otherwise, it leaves ``tensor`` unchanged (cf. the
+`eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
+This ensures that if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
+``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
+inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
+is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
+you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.
+
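+A forced-``float16`` variant of the wrapper above might look like the following sketch
+(``mymatmul`` is the same hypothetical custom op; note that unlike ``cached_cast``,
+``Tensor::half()`` also converts non-CUDA, integer-type, and ``float64`` inputs):
+
+.. code-block:: cpp
+
+  // Sketch only: always run mymatmul in float16, regardless of where the
+  // inputs live or what dtype they arrive in.
+  Tensor mymatmul_forced_half_autocast(const Tensor& self, const Tensor& other) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    return mymatmul(self.half(), other.half());
+  }
+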
Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
-dispatch before redispatching. By default, if no autocast kernel is provided,
-we simply fallthrough directly to the regular operator implementation (no
-autocasting occurs.) (We didn't use ``myadd`` for this example, since pointwise
-addition doesn't do autocasting and should just fall through).
-
-When should an autocast kernel be registered? Unfortunately, there aren't
-cut-and-dry rules for when you should cast to a lower precision. You can
-get a sense for what operators have autocasting behavior by looking at
-the `AMP documentation
-<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
-general rules:
-
-* Operations that do reductions should be carried out in float32,
-* Any operation with multiple float tensor inputs has to standardize them
-  to a common precision, and
-* Any operation that does a convolution or gemm under the hood should
-  probably be float16
+dispatch before redispatching.
+
+By default, if no autocast wrapper is provided,
+we simply fall through to the regular operator implementation (no
+autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
+addition doesn't need autocasting and should just fall through.)
+
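+If you'd rather make that behavior explicit, one option is to register a fallthrough kernel for
+the ``Autocast`` key yourself. A minimal sketch for the ``myops::myadd`` op from earlier in this
+tutorial (functionally equivalent to registering nothing at all):
+
+.. code-block:: cpp
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    // Skip the Autocast key entirely and go straight to the regular myadd kernel.
+    m.impl("myadd", torch::CppFunction::makeFallthrough());
+  }
+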
+When should an autocast wrapper be registered? Unfortunately, there aren't
+cut-and-dried rules for an op's preferred precision. You can
+get a sense for some native ops' preferred precisions by looking at the
+`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
+General guidance:
+
+* Ops that do reductions should probably execute in ``float32`` (see the sketch after this list),
+* Any op that does a convolution or gemm under the hood should
+  probably execute in ``float16``, and
+* Other ops with multiple floating-point tensor inputs should standardize
+  them to a common precision (unless the implementation supports inputs with different precisions).
+
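+For instance, following the first guideline, a wrapper for a hypothetical custom reduction
+``mysum`` (not an op defined in this tutorial) could pin execution to ``float32``. A minimal
+sketch, reusing ``cached_cast`` so that ineligible inputs are left untouched:
+
+.. code-block:: cpp
+
+  #include <ATen/autocast_mode.h>
+
+  Tensor mysum_autocast(const Tensor& self) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    // Cast eligible float16 CUDA tensors up to float32 so the reduction
+    // accumulates at higher precision; other inputs pass through unchanged.
+    return mysum(at::autocast::cached_cast(at::kFloat, self));
+  }
+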
+If your custom op falls into the third category, the ``promote_type`` template
+helps figure out the widest floating-point type present among input tensors, which is
+the safest choice for the execution type:
+
+.. code-block:: cpp
+
+  #include <ATen/autocast_mode.h>
+
+  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    // The required at::kHalf argument is an optimistic initial guess.
+    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
+    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
+                                at::autocast::cached_cast(exec_type, t1));
+  }
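+
+As with ``mymatmul_autocast``, a wrapper like this still needs to be registered to the
+``Autocast`` key for its op. A minimal sketch, assuming the hypothetical op is also declared
+in the ``myops`` namespace:
+
+.. code-block:: cpp
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("my_multiple_input_op", my_multiple_input_op_autocast);
+  }
+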
Batched
^^^^^^^