@@ -9,6 +9,7 @@
 import torch
 
 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor
 
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -41,6 +42,18 @@ def decorator(annotator: Callable):
 
     return decorator
 
+def _is_input_float_tensor(node: Node):
+    """Check whether the input is a float tensor, so that we can skip
+    quantization for non-float inputs, since observers only work with float Tensors.
+    """
+    if (
+        not isinstance(node, Node)
+        or "val" not in node.meta
+        or not isinstance(node.meta["val"], FakeTensor)
+    ):
+        return False
+    return node.meta["val"].dtype == torch.float32
+
 
 def _is_annotated(nodes: List[Node]):
     """
@@ -123,11 +136,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
 
     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if _is_input_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec
 
     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if _is_input_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(