Skip to content

Commit 4a10673

Browse files
committed
feat: Add QAT patch which modifies the zero-point tensor dtype (passed alongside the scale factor) to INT32
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b3101c6 commit 4a10673

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

docker/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ COPY . /workspace/trtorch/src
2626
WORKDIR /workspace/trtorch/src
2727
RUN cp ./docker/WORKSPACE.cu.docker WORKSPACE
2828

29-
# This script builds both the libtrtorch bin/lib/include tarball and the Python wheel, in dist/
29+
# This script builds both the libtrtorch bin/lib/include tarball and the Python wheel, in dist/
3030
RUN ./docker/dist-build.sh
3131

3232
FROM base as trtorch
3333

3434
# copy source repo
3535
COPY . /workspace/trtorch
3636
COPY --from=trtorch-builder /workspace/trtorch/src/dist/ .
37-
37+
RUN patch -u /opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py -i /workspace/trtorch/docker/qat.patch
3838
RUN conda init bash
3939

4040
RUN pip3 install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org

docker/qat.patch

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
--- /opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py 2021-08-16 22:50:37.000000000 +0000
2+
+++ tensor_quantizer.py 2021-10-19 20:41:54.288077426 +0000
3+
@@ -291,7 +291,7 @@
4+
quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])
5+
scale = amax_sequeeze / bound
6+
outputs = torch.fake_quantize_per_channel_affine(
7+
- inputs, scale.data, torch.zeros_like(scale, dtype=torch.long).data, quant_dim,
8+
+ inputs, scale.data, torch.zeros_like(scale, dtype=torch.int32).data, quant_dim,
9+
-bound - 1 if not self._unsigned else 0, bound)
10+
11+
return outputs

0 commit comments

Comments
 (0)