
Commit fc8eafb

feat: Add functionality for QAT workflow
Signed-off-by: Dheeraj Peri <[email protected]>
Parent: 54f08f9

File tree: 5 files changed, +130 −38 lines

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 4 additions & 4 deletions
@@ -70,10 +70,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
         cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
       }
       input_type = nvinfer1::DataType::kFLOAT;
-      TRTORCH_CHECK(
-          settings.calibrator != nullptr,
-          "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
-      cfg->setInt8Calibrator(settings.calibrator);
+      // TRTORCH_CHECK(
+      //     settings.calibrator != nullptr,
+      //     "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
+      // cfg->setInt8Calibrator(settings.calibrator);
       break;
     case nvinfer1::DataType::kFLOAT:
     default:
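
Note: with the calibrator check disabled, an INT8 engine build can proceed for QAT models, whose scales travel with the graph as Quantize/Dequantize nodes rather than coming from a PTQ calibrator. A minimal sketch of that configuration, assuming TensorRT's standard builder API ("builder" and "calibrator" are illustrative names, not from this commit):

#include "NvInfer.h"

// QAT path (sketch): enable INT8 but attach no calibrator; scales are taken
// from the Quantize/Dequantize layers embedded in the network itself.
nvinfer1::IBuilderConfig* cfg = builder->createBuilderConfig();
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
// A PTQ build would instead require: cfg->setInt8Calibrator(calibrator);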

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ cc_library(
         "impl/matrix_multiply.cpp",
         "impl/normalize.cpp",
         "impl/pooling.cpp",
+        "impl/quantization.cpp",
         "impl/reduce.cpp",
         "impl/replication_pad.cpp",
         "impl/select.cpp",

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 88 additions & 31 deletions
@@ -11,15 +11,95 @@ namespace impl {
 namespace {
 
 bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-  auto in = args[0].ITensor(); // assumes non-static input Tensor
-  auto w = Weights(ctx, args[1].unwrapToTensor());
+  // Input to conv/deconv
+  auto in = args[0].ITensor();
+
+  // Conv/deconv parameters
   auto stride = util::toDims(args[3].unwrapToIntList());
   auto padding = util::toDims(args[4].unwrapToIntList());
   auto dilation = util::toDims(args[5].unwrapToIntList());
   bool transposed = args[6].unwrapToBool();
   auto out_padding = util::toDims(args[7].unwrapToIntList());
   int64_t groups = args[8].unwrapToInt();
 
+  // Reshape the parameters to 2D if needed
+  if (stride.nbDims == 1) {
+    stride = util::unsqueezeDims(stride, 1, 1);
+    LOG_DEBUG("Reshaped stride: " << stride);
+  }
+  if (dilation.nbDims == 1) {
+    dilation = util::unsqueezeDims(dilation, 1, 1);
+    LOG_DEBUG("Reshaped dilation: " << dilation);
+  }
+  if (padding.nbDims == 1) {
+    padding = util::unsqueezeDims(padding, 1, 0);
+    LOG_DEBUG("Reshaped padding: " << padding);
+  }
+  if (out_padding.nbDims == 1) {
+    out_padding = util::unsqueezeDims(out_padding, 1, 0);
+    LOG_DEBUG("Reshaped out_padding: " << out_padding);
+  }
+
+  // Get the bias tensor, or default to an empty (no-op) bias.
+  Weights bias;
+  if (args[2].IValue()->isTensor()) {
+    bias = Weights(ctx, args[2].unwrapToTensor());
+  } else {
+    bias = Weights(); // nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0}
+  }
+
+  // Handle the case where the conv/deconv weights are an ITensor. This happens for QAT networks, where
+  // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
+  // new_conv_weights is an ITensor because it is the output of the Dequantize layer defined in impl/quantization.cpp.
+  if (args[1].isITensor()) {
+    // Get the kernel tensor
+    auto kernel = args[1].ITensor();
+    auto kernel_dims = kernel->getDimensions();
+
+    // Make a new Dims with only the spatial dimensions.
+    nvinfer1::Dims filter_dim;
+    int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
+    TRTORCH_CHECK(
+        nbSpatialDims == kernel_dims.nbDims - 2,
+        "Number of input spatial dimensions should match the kernel spatial dimensions");
+    filter_dim.nbDims = nbSpatialDims;
+    filter_dim.d[0] = kernel_dims.d[2];
+    filter_dim.d[1] = kernel_dims.d[3];
+
+    // Initialize a dummy constant kernel to pass to the INetwork->addConvolutionNd/addDeconvolutionNd API.
+    auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
+
+    nvinfer1::ILayer* layer = nullptr;
+    if (transposed) {
+      nvinfer1::IDeconvolutionLayer* deconvLayer =
+          ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
+      deconvLayer->setStrideNd(stride);
+      deconvLayer->setDilationNd(dilation);
+      deconvLayer->setNbGroups(groups);
+      deconvLayer->setPaddingNd(padding);
+      // Set deconv kernel weights
+      deconvLayer->setInput(1, *kernel);
+      layer = deconvLayer;
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer =
+          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      TRTORCH_CHECK(convLayer, "Unable to create conv layer with non-const weights from node: " << *n);
+      convLayer->setStrideNd(stride);
+      convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+      convLayer->setPaddingNd(padding);
+      convLayer->setPostPadding(out_padding);
+      convLayer->setDilationNd(dilation);
+      convLayer->setNbGroups(groups);
+
+      // Set conv kernel weights
+      convLayer->setInput(1, *kernel);
+      layer = convLayer;
+    }
+
+    ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+    LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
+    return true;
+  }
+
+  auto w = Weights(ctx, args[1].unwrapToTensor());
   auto dims = in->getDimensions();
   auto orig_dims = dims;
   LOG_DEBUG("Input dims: " << orig_dims);
@@ -46,32 +126,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     w.kernel_shape.d[1] = 1;
     LOG_DEBUG("Reshaped Weights: " << w);
   }
-  if (stride.nbDims == 1) {
-    stride = util::unsqueezeDims(stride, 1, 1);
-    LOG_DEBUG("Reshaped stride: " << stride);
-  }
-  if (dilation.nbDims == 1) {
-    dilation = util::unsqueezeDims(dilation, 1, 1);
-    LOG_DEBUG("Reshaped dilation: " << dilation);
-  }
-  if (padding.nbDims == 1) {
-    padding = util::unsqueezeDims(padding, 1, 0);
-    LOG_DEBUG("Reshaped padding: " << padding);
-  }
-  if (out_padding.nbDims == 1) {
-    out_padding = util::unsqueezeDims(out_padding, 1, 0);
-    LOG_DEBUG("Reshaped out_padding: " << out_padding);
-  }
 
   nvinfer1::ILayer* new_layer;
   if (transposed) {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[1] * groups));
-    }
-
     // shape of deconvolution's weight: [in, out/groups, ...]
     auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
     TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
@@ -89,12 +146,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 #endif
     new_layer = deconv;
   } else {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[0]));
-    }
+    // Weights bias;
+    // if (args[2].IValue()->isTensor()) {
+    //   bias = Weights(ctx, args[2].unwrapToTensor());
+    // } else {
+    //   bias = Weights(ctx, torch::zeros(w.shape.d[0]));
+    // }
 
     // shape of convolution's weight: [out, in/groups, ...]
     auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
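
A worked shape example for the new ITensor-weights branch (hypothetical values, not from the commit): for a Conv2d whose fake-quantized weight has shape [32, 16, 3, 3], kernel_dims.d[0] == 32 becomes the number of output maps, filter_dim picks up the spatial (3, 3) from d[2] and d[3], and an input of shape [N, 16, H, W] gives nbSpatialDims == 2, satisfying the dimension check before the layer is built with empty weights and rewired via setInput(1, *kernel).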

core/conversion/converters/impl/linear.cpp

Lines changed: 24 additions & 0 deletions
@@ -40,6 +40,30 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt
       in = in_shuffle->getOutput(0);
     }
 
+    // Get the bias
+    Weights bias;
+    if (!args[2].IValue()->isNone()) {
+      bias = Weights(ctx, args[2].IValue()->toTensor());
+    } else {
+      bias = Weights();
+    }
+
+    // Handle the case where the linear weights are an ITensor, mirroring the conv/deconv converter.
+    // This happens for QAT networks, where
+    // linear_weights -> Quantize -> Dequantize -> new_linear_weights -> linear <- input
+    // new_linear_weights is an ITensor because it is the output of the Dequantize layer defined in impl/quantization.cpp.
+    if (args[1].isITensor()) {
+      auto kernel_tensor = args[1].ITensor();
+      auto kernel_dims = kernel_tensor->getDimensions();
+      // Initialize a dummy constant kernel to pass to the INetwork->addFullyConnected API.
+      auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
+      auto fc_layer = ctx->net->addFullyConnected(*in, kernel_dims.d[0], kernel_weights, bias.data);
+      TRTORCH_CHECK(fc_layer, "Unable to create fully connected layer with non-const weights from node: " << *n);
+      fc_layer->setInput(1, *kernel_tensor);
+      fc_layer->setName(util::node_info(n).c_str());
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], fc_layer->getOutput(0));
+      LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+      return true;
+    }
+
     auto w_tensor = args[1].IValue()->toTensor();
     Weights w = Weights(ctx, w_tensor);

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 13 additions & 3 deletions
@@ -19,8 +19,11 @@ namespace passes {
 void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
   // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
   static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
-    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
+    def linear_bias_none(self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
         return torch.matmul(self, mat1.t())
+
+    def linear(self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
+        return torch.matmul(self, torch.transpose(mat1, 0, 1)) + mat2
 )SCRIPT");
 
   // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
@@ -29,16 +32,23 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
     auto n = *it;
     if (n->kind().toQualString() == std::string("aten::linear")) {
       auto input_values = n->inputs();
+      std::cout << "WEIGHT CONST ?: " << input_values[1]->type()->isSubtypeOf(c10::TensorType::get()) << std::endl;
       // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
       if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
-        continue;
-      } else {
+        // continue;
         torch::jit::WithInsertPoint guard(*it);
         std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
         torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
         new_output->setType(it->output()->type());
         it->output()->replaceAllUsesWith(new_output);
         it.destroyCurrent();
+      } else {
+        torch::jit::WithInsertPoint guard(*it);
+        std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear_bias_none").graph();
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
       }
     }
   }
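
The decomposition this pass now applies when a bias is present can be sanity-checked numerically. A small standalone sketch using libtorch (not part of this commit; shapes and names are illustrative):

#include <torch/torch.h>
#include <cassert>

int main() {
  // Verify that the pass's bias-present decomposition matches aten::linear.
  auto x = torch::randn({2, 3});
  auto w = torch::randn({4, 3});
  auto b = torch::randn({4});
  auto reference = torch::linear(x, w, b);
  auto decomposed = torch::matmul(x, torch::transpose(w, 0, 1)) + b;
  assert(torch::allclose(reference, decomposed));
  return 0;
}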
