@@ -89,6 +89,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
8989 if (isIValue ()) {
9090 LOG_DEBUG (ctx->logger , " Found IValue containing object of type " << *(ptr_.ivalue ->type ()));
9191 }
92+
9293 TRTORCH_CHECK (
9394 isITensor () || (isIValue () && (ptr_.ivalue ->isTensor () || ptr_.ivalue ->isCustomClass ())),
9495 " Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name ());
@@ -97,11 +98,22 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9798
9899 if (isIValue ()) {
99100 if (ptr_.ivalue ->isTensor ()) {
100- auto weights = converters::Weights (ctx, ptr_.ivalue ->toTensor ());
101-
101+ auto weights = converters::Weights ();
102+ auto tensor = ptr_.ivalue ->toTensor ();
103+ if ((tensor.scalar_type () == at::kLong || tensor.scalar_type () == at::kDouble ) && !ctx->settings .truncate_long_and_double ) {
104+ TRTORCH_THROW_ERROR (" Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled" );
105+ } else if (tensor.scalar_type () == at::kLong && ctx->settings .truncate_long_and_double ) {
106+ weights = converters::Weights (ctx, tensor.toType (at::kInt ));
107+ LOG_WARNING (" Truncating weight (constant in the graph) from Int64 to Int32" );
108+ } else if (tensor.scalar_type () == at::kDouble && ctx->settings .truncate_long_and_double ) {
109+ weights = converters::Weights (ctx, tensor.toType (at::kFloat ));
110+ LOG_WARNING (" Truncating weight (constant in the graph) from Float64 to Float32" );
111+ } else {
112+ weights = converters::Weights (ctx, tensor);
113+ }
114+
102115 auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
103116 TRTORCH_CHECK (const_layer, " Unable to freeze tensor into constant layer" );
104-
105117 out = const_layer->getOutput (0 );
106118
107119 std::ostringstream tensor_id;
@@ -119,7 +131,6 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
119131 }
120132
121133 LOG_DEBUG (" Frozen tensor shape: " << out->getDimensions ());
122-
123134 return out;
124135}
125136
0 commit comments