File tree Expand file tree Collapse file tree 4 files changed +17
-8
lines changed Expand file tree Collapse file tree 4 files changed +17
-8
lines changed Original file line number Diff line number Diff line change @@ -416,7 +416,7 @@ void ConvertBlockToNetDef(
416
416
EvaluateConditionalBlock (ctx, n);
417
417
} else if (to_eval) {
418
418
auto eval = EvaluateNode (ctx, n);
419
- if (eval && n-> outputs (). size () > 0 ) {
419
+ if (eval) {
420
420
if (n->outputs ().size () > 1 ) { // For ListUnpack scenario
421
421
if (eval.value ().isTuple ()) {
422
422
auto eval_list = eval.value ().toTuple ();
Original file line number Diff line number Diff line change @@ -288,14 +288,18 @@ auto aten_registrations TORCHTRT_UNUSED =
288
288
.evaluator({c10::Symbol::fromQualString (" aten::extend" ),
289
289
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
290
290
if (args.at (n->input (0 )).IValue ()->isList () && args.at (n->input (1 )).IValue ()->isList ()) {
291
- auto self = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
291
+ c10::IValue* self_ptr = args.at (n->input (0 )).IValueMut ();
292
+ auto self = self_ptr->to <c10::List<c10::IValue>>();
292
293
auto other = args.at (n->input (1 )).IValue ()->to <c10::List<c10::IValue>>();
293
294
const int64_t other_size = other.size ();
294
295
296
+ // Modify value in place
295
297
for (int64_t i = 0 ; i < other_size; i++) {
296
298
self.push_back (other.get (i));
297
299
}
298
- return self;
300
+
301
+ *self_ptr = c10::IValue (self);
302
+ return {};
299
303
} else {
300
304
TORCHTRT_THROW_ERROR (
301
305
" Unimplemented data type for aten::extend.t evaluator: "
Original file line number Diff line number Diff line change @@ -13,7 +13,7 @@ Var::Var() {
13
13
type_ = Type::kNone ;
14
14
}
15
15
16
- Var::Var (const torch::jit::IValue* p) : type_(Type::kIValue ) {
16
+ Var::Var (torch::jit::IValue* p) : type_(Type::kIValue ) {
17
17
ptr_.ivalue = p;
18
18
}
19
19
@@ -56,7 +56,7 @@ Var& Var::operator=(const Var& a) {
56
56
return (*this );
57
57
}
58
58
59
- Var& Var::operator =(const torch::jit::IValue* in) {
59
+ Var& Var::operator =(torch::jit::IValue* in) {
60
60
ptr_.ivalue = in;
61
61
type_ = Type::kIValue ;
62
62
return (*this );
@@ -116,6 +116,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
116
116
}
117
117
118
118
const torch::jit::IValue* Var::IValue () const {
119
+ return IValueMut ();
120
+ }
121
+
122
+ torch::jit::IValue* Var::IValueMut () const {
119
123
TORCHTRT_CHECK (isIValue (), " Requested IValue from Var, however Var type is " << type_name ());
120
124
if (type_ == Type::kIValue ) {
121
125
return ptr_.ivalue ;
Original file line number Diff line number Diff line change @@ -17,13 +17,14 @@ class Var : torch::CustomClassHolder {
17
17
enum Type { kITensor , kIValue , kNone };
18
18
19
19
Var ();
20
- Var (const torch::jit::IValue* p);
20
+ Var (torch::jit::IValue* p);
21
21
Var (nvinfer1::ITensor* p);
22
22
Var (const Var& a);
23
23
Var& operator =(const Var& a);
24
- Var& operator =(const torch::jit::IValue* in);
24
+ Var& operator =(torch::jit::IValue* in);
25
25
Var& operator =(nvinfer1::ITensor* in);
26
26
const torch::jit::IValue* IValue () const ;
27
+ torch::jit::IValue* IValueMut () const ;
27
28
nvinfer1::ITensor* ITensor () const ;
28
29
29
30
// TODO: Can we consolidate this in a way that prevents requesting invalid
@@ -63,7 +64,7 @@ class Var : torch::CustomClassHolder {
63
64
64
65
private:
65
66
union VarContainer {
66
- const torch::jit::IValue* ivalue;
67
+ torch::jit::IValue* ivalue;
67
68
nvinfer1::ITensor* tensor;
68
69
void * none;
69
70
};
You can’t perform that action at this time.
0 commit comments