Skip to content

Commit 64b2fae

Browse files
authored
remove kept flag after tensor is replaced or popped (#6759)
1 parent b02de70 commit 64b2fae

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

tfjs-converter/src/executor/tensor_list.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ export class TensorList {
143143
const outputElementShape =
144144
inferElementShape(this.elementShape, this.tensors, elementShape);
145145
const tensor = this.tensors.pop();
146+
tensor.kept = false;
146147

147148
assertShapesMatchAllowUndefinedSize(
148149
tensor.shape, elementShape, 'TensorList shape mismatch: ');
@@ -243,6 +244,12 @@ export class TensorList {
243244
assertShapesMatchAllowUndefinedSize(
244245
this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
245246
keep(tensor);
247+
248+
// dispose the previous value if it is replacing.
249+
if (this.tensors[elementIndex] != null) {
250+
this.tensors[elementIndex].kept = false;
251+
}
252+
246253
this.tensors[elementIndex] = tensor;
247254
}
248255

tfjs-converter/src/executor/tensor_list_test.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ describe('TensorList', () => {
122122
it('should create no new tensors', () => {
123123
tensorList.pushBack(tensor);
124124
const numTensors = memory().numTensors;
125-
tensorList.popBack(SHAPE, DTYPE);
125+
const tensorPoped = tensorList.popBack(SHAPE, DTYPE);
126+
expect(tensorPoped.kept).toBeFalsy();
126127
// a new reshaped tensor
127128
expect(memory().numTensors).toEqual(numTensors + 1);
128129
});
@@ -164,6 +165,14 @@ describe('TensorList', () => {
164165
tensorList.setItem(0, tensor);
165166
expect(memory().numTensors).toEqual(numTensors);
166167
});
168+
it('should remove kept flag for replaced tensor', () => {
169+
tensorList = new TensorList([], [-1, 1], DTYPE, SIZE);
170+
tensorList.setItem(0, tensor);
171+
expect(tensor.kept).toBeTruthy();
172+
tensorList.setItem(0, tensor2);
173+
expect(tensor.kept).toBeFalsy();
174+
expect(tensor2.kept).toBeTruthy();
175+
});
167176
});
168177

169178
describe('getItem', () => {

0 commit comments

Comments
 (0)