@@ -26,6 +26,7 @@ auto mm_registrations TRTORCH_UNUSED =
2626
2727 auto mm_layer = ctx->net ->addMatrixMultiply (
2828 *self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
29+
2930 TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
3031 mm_layer->setName (util::node_info (n).c_str ());
3132 auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
@@ -73,51 +74,6 @@ auto mm_registrations TRTORCH_UNUSED =
7374
7475 LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
7576 return true ;
76- }})
77- .pattern(
78- {" aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)" ,
79- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
80- auto self = args[0 ].ITensorOrFreeze (ctx);
81- auto mat1 = args[1 ].ITensorOrFreeze (ctx);
82- auto mat2 = args[2 ].ITensorOrFreeze (ctx);
83- auto beta = args[3 ].unwrapToScalar ().to <float >();
84- auto betaTensor = tensor_to_const (ctx, torch::tensor ({beta}));
85- auto alpha = args[4 ].unwrapToScalar ().to <float >();
86- auto alphaTensor = tensor_to_const (ctx, torch::tensor ({alpha}));
87-
88- // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
89- // necessary.
90- if (mat1->getDimensions ().nbDims < mat2->getDimensions ().nbDims ) {
91- mat1 = addPadding (ctx, n, mat1, mat2->getDimensions ().nbDims , false , false );
92- } else {
93- mat2 = addPadding (ctx, n, mat2, mat1->getDimensions ().nbDims , false , false );
94- }
95-
96- auto mm_layer = ctx->net ->addMatrixMultiply (
97- *mat1, nvinfer1::MatrixOperation::kNONE , *mat2, nvinfer1::MatrixOperation::kNONE );
98- TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication layer in node: " << *n);
99- auto mm_scale_layer = add_elementwise (
100- ctx,
101- nvinfer1::ElementWiseOperation::kPROD ,
102- mm_layer->getOutput (0 ),
103- alphaTensor,
104- util::node_info (n) + " _alphaScale" );
105- TRTORCH_CHECK (mm_scale_layer, " Unable to create alpha scaling layer in node: " << *n);
106- auto beta_scale_layer = add_elementwise (
107- ctx, nvinfer1::ElementWiseOperation::kPROD , self, betaTensor, util::node_info (n) + " _betaScale" );
108- TRTORCH_CHECK (beta_scale_layer, " Unable to create beta scaling layer in node: " << *n);
109- auto add_mm_layer = add_elementwise (
110- ctx,
111- nvinfer1::ElementWiseOperation::kSUM ,
112- beta_scale_layer->getOutput (0 ),
113- mm_scale_layer->getOutput (0 ),
114- util::node_info (n));
115- TRTORCH_CHECK (add_mm_layer, " Unable to create addmm layer in node: " << *n);
116-
117- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
118-
119- LOG_DEBUG (" [AddMM layer] Output tensor shape: " << out_tensor->getDimensions ());
120- return true ;
12177 }});
12278} // namespace
12379} // namespace impl
0 commit comments