Skip to content

Commit 97eb4eb

Browse files
committed
refactor(//core/conversion/var): Modify var contract
This modifies the variable contract to allow requesting of mutable handles to `IValue`s which previously were provided only as `const`. Now every IValue pointer is not `const` within `Var`, but in order to get a `const` handle you use IValueMut. This supports the evaluation of inplace operations and makes state changes explicit rather than implicit. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e1b68fc commit 97eb4eb

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ void ConvertBlockToNetDef(
416416
EvaluateConditionalBlock(ctx, n);
417417
} else if (to_eval) {
418418
auto eval = EvaluateNode(ctx, n);
419-
if (eval && n->outputs().size() > 0) {
419+
if (eval) {
420420
if (n->outputs().size() > 1) { // For ListUnpack scenario
421421
if (eval.value().isTuple()) {
422422
auto eval_list = eval.value().toTuple();

core/conversion/evaluators/aten.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,14 +288,18 @@ auto aten_registrations TORCHTRT_UNUSED =
288288
.evaluator({c10::Symbol::fromQualString("aten::extend"),
289289
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
290290
if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
291-
auto self = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
291+
c10::IValue* self_ptr = args.at(n->input(0)).IValueMut();
292+
auto self = self_ptr->to<c10::List<c10::IValue>>();
292293
auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
293294
const int64_t other_size = other.size();
294295

296+
// Modify value in place
295297
for (int64_t i = 0; i < other_size; i++) {
296298
self.push_back(other.get(i));
297299
}
298-
return self;
300+
301+
*self_ptr = c10::IValue(self);
302+
return {};
299303
} else {
300304
TORCHTRT_THROW_ERROR(
301305
"Unimplemented data type for aten::extend.t evaluator: "

core/conversion/var/Var.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Var::Var() {
1313
type_ = Type::kNone;
1414
}
1515

16-
Var::Var(const torch::jit::IValue* p) : type_(Type::kIValue) {
16+
Var::Var(torch::jit::IValue* p) : type_(Type::kIValue) {
1717
ptr_.ivalue = p;
1818
}
1919

@@ -56,7 +56,7 @@ Var& Var::operator=(const Var& a) {
5656
return (*this);
5757
}
5858

59-
Var& Var::operator=(const torch::jit::IValue* in) {
59+
Var& Var::operator=(torch::jit::IValue* in) {
6060
ptr_.ivalue = in;
6161
type_ = Type::kIValue;
6262
return (*this);
@@ -116,6 +116,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
116116
}
117117

118118
const torch::jit::IValue* Var::IValue() const {
119+
return IValueMut();
120+
}
121+
122+
torch::jit::IValue* Var::IValueMut() const {
119123
TORCHTRT_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
120124
if (type_ == Type::kIValue) {
121125
return ptr_.ivalue;

core/conversion/var/Var.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ class Var : torch::CustomClassHolder {
1717
enum Type { kITensor, kIValue, kNone };
1818

1919
Var();
20-
Var(const torch::jit::IValue* p);
20+
Var(torch::jit::IValue* p);
2121
Var(nvinfer1::ITensor* p);
2222
Var(const Var& a);
2323
Var& operator=(const Var& a);
24-
Var& operator=(const torch::jit::IValue* in);
24+
Var& operator=(torch::jit::IValue* in);
2525
Var& operator=(nvinfer1::ITensor* in);
2626
const torch::jit::IValue* IValue() const;
27+
torch::jit::IValue* IValueMut() const;
2728
nvinfer1::ITensor* ITensor() const;
2829

2930
// TODO: Can we consolidate this in a way that prevents requesting invalid
@@ -63,7 +64,7 @@ class Var : torch::CustomClassHolder {
6364

6465
private:
6566
union VarContainer {
66-
const torch::jit::IValue* ivalue;
67+
torch::jit::IValue* ivalue;
6768
nvinfer1::ITensor* tensor;
6869
void* none;
6970
};

0 commit comments

Comments
 (0)