7 changes: 4 additions & 3 deletions test/test_nestedtensor.py
@@ -48,12 +48,13 @@ def gen_float_tensor(self, seed, shape, requires_grad=False):

     def test_constructor(self):
         tensors = []
-        for i in range(16):
+        num_tensors = 16
+        for i in range(num_tensors):
             tensors.append(self.gen_float_tensor(i, (i + 1, 128, 128)))
         nested_tensor = torch.nestedtensor(tensors)
-        for i in range(16):
+        for i in range(num_tensors):
             tensors[i].mul_(i + 2)
-        for i in range(16):
+        for i in range(num_tensors):
             self.assertTrue((tensors[i] != nested_tensor._tensors[i]).all())
         self.assertRaises(ValueError, lambda: torch.nestedtensor([]))
         self.assertRaises(ValueError, lambda: torch.nestedtensor(torch.tensor([3.0])))
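The constructor test exercises copy semantics: torch.nestedtensor copies its inputs, so in-place changes to the original tensors afterwards must leave the NestedTensor's copies untouched. A minimal hedged sketch of that contract, assuming the prototype torch.nestedtensor constructor from this PR:

import torch

a = torch.rand(2, 3)
nt = torch.nestedtensor([a])  # the prototype constructor copies its inputs
a.add_(1)                     # mutate the source tensor in place
# The NestedTensor kept its own copy, so every element now differs.
assert (nt._tensors[0] != a).all()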
11 changes: 7 additions & 4 deletions torch/nested/__init__.py
@@ -19,7 +19,7 @@ def _nary(*args, **kwargs):
         func = args[1]
         inputs = args[2:]
         out = kwargs.get('out', None)
-        # NOTE: We are disabling broadcasting for now.
+        # NOTE: We are disabling broadcasting for now. These checks introduce a lot of overhead.
         for i in range(1, len(inputs)):
             for j in range(len(inputs[i])):
                 assert inputs[0]._tensors[j].size() == inputs[i]._tensors[j].size()
@@ -32,13 +32,16 @@ def _nary(*args, **kwargs):
                 out_tensors.append(out_tensor)
             return NestedTensor(out_tensors)
         else:
-            # NOTE: We are disabling broadcasting for now.
+            # NOTE: We are disabling broadcasting for now. These checks introduce a lot of overhead.
             for i in range(len(out)):
                 assert out._tensors[i].size() == inputs[0]._tensors[i].size()
             if out_dtype is not None:
                 out = out.to(out_dtype)
-            for i in range(len(inputs[0])):
-                func(*list(map(lambda x: x._tensors[i], inputs)), out=out._tensors[i])
+            if all(nested_tensor.is_contiguous() for nested_tensor in inputs):
+                func(*list(map(lambda x: x.buffer_, inputs)), out=out.buffer_)
+            else:
+                for i in range(len(inputs[0])):
+                    func(*list(map(lambda x: x._tensors[i], inputs)), out=out._tensors[i])
         return out
     return _nary
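The new branch above dispatches to a single call over each NestedTensor's flat buffer_ when every input is contiguous, instead of looping over the constituent tensors. A minimal sketch of that fast path outside the codegen machinery (nt_add is a hypothetical helper; it assumes contiguous NestedTensors with matching nested sizes, as produced by make_nested_tensor after this PR):

import torch

def nt_add(a, b, out):
    # Hypothetical illustration of the contiguous fast path used by _nary.
    if a.is_contiguous() and b.is_contiguous() and out.is_contiguous():
        # One call over the packed 1-d buffers replaces the per-tensor loop.
        torch.add(a.buffer_, b.buffer_, out=out.buffer_)
    else:
        # Fallback: apply the op tensor by tensor.
        for i in range(len(a._tensors)):
            torch.add(a._tensors[i], b._tensors[i], out=out._tensors[i])
    return out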

2 changes: 0 additions & 2 deletions torch/nested/codegen/extension.py
@@ -69,8 +69,6 @@ def get_unary_functions():
         'sqrt',
         'tan',
         'tanh',
-        'tril',
-        'triu',
         'trunc']


60 changes: 59 additions & 1 deletion torch/nested/nested.py
@@ -42,7 +42,7 @@ def make_nested_tensor(data, dtype=None, device=None, requires_grad=False, pin_m
             new_data = new_data.pin_memory()
         tensors.append(new_data)

-    return NestedTensor(tensors)
+    return NestedTensor(tensors).contiguous()

 def as_nestedtensor(data, dtype=None, device=None):
     ret = NestedTensor(data)
@@ -119,6 +119,24 @@ def dim(self):
             _verify_tensors(self._tensors)
         return self._tensors[0].dim

+    @property
+    def dtype(self):
+        if DEBUG:
+            _verify_tensors(self._tensors)
+        return self._tensors[0].dtype
+
+    @property
+    def layout(self):
+        if DEBUG:
+            _verify_tensors(self._tensors)
+        return self._tensors[0].layout
+
+    @property
+    def device(self):
+        if DEBUG:
+            _verify_tensors(self._tensors)
+        return self._tensors[0].device
+
     @property
     def shape(self):
         raise NotImplementedError()
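These properties delegate to the first constituent tensor, so they rely on all constituents sharing dtype, layout, and device (which _verify_tensors checks when DEBUG is set). A short hedged usage sketch, assuming the prototype constructor:

import torch

nt = torch.nestedtensor([torch.rand(2, 3), torch.rand(4, 5)])
# All constituents share metadata, so the NestedTensor reports the first tensor's.
assert nt.dtype == torch.float32
assert nt.device == torch.device('cpu')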
@@ -196,3 +214,43 @@ def requires_grad_(self, *args, **kwargs):

     def backward(self, *args, **kwargs):
         self.__apply(lambda x: x.backward(*args, **kwargs))
+
+    # The overhead on this function is very heavy
+    def is_contiguous(self):
+        first_data_ptr = self._tensors[0].data_ptr()
+        current_offset = 0
+        is_cont = hasattr(self, 'buffer_')
+        for tensor in self._tensors:
+            if not is_cont:
+                return False
+            test_data_ptr = first_data_ptr + current_offset
+            is_cont = is_cont and tensor.data_ptr() == test_data_ptr
+            is_cont = is_cont and tensor.is_contiguous()
+            current_offset += tensor.numel() * tensor.element_size()
+        return is_cont
+
+    def contiguous(self):
+        flat_tensors = []
+        for tensor in self._tensors:
+            flat_tensors.append(tensor.view(-1))
+        self.buffer_ = torch.cat(flat_tensors)
+        current_offset = 0
+        for i in range(len(self._tensors)):
+            # This is an unnecessary allocation
+            new_tensor = torch.empty_like(self._tensors[i],
+                                          dtype=self.dtype, layout=self.layout, device=self.device)
+            with torch.no_grad():
+                new_tensor.set_(self.buffer_.storage(),
+                                storage_offset=current_offset,
+                                size=self._tensors[i].size(),
+                                stride=self._tensors[i].stride())
+            new_tensor.requires_grad_(self.requires_grad)
+            self._tensors[i] = new_tensor
+            current_offset += self._tensors[i].numel()
+        return self
+
+    def numel(self):
+        all_numel = 0
+        for tensor in self._tensors:
+            all_numel += tensor.numel()
+        return all_numel
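Taken together, contiguous() packs every constituent tensor into one flat buffer_ and rebinds each _tensors[i] as a view into that buffer at a running offset, which is exactly the layout is_contiguous() verifies through data_ptr() arithmetic. A standalone hedged sketch of that packing scheme on plain torch tensors (pack and is_packed are hypothetical names, not part of the PR):

import torch

def pack(tensors):
    # Concatenate flattened copies into one buffer, then view slices of the
    # buffer with the original shapes (mirrors NestedTensor.contiguous()).
    buffer_ = torch.cat([t.reshape(-1) for t in tensors])
    views, offset = [], 0
    for t in tensors:
        n = t.numel()
        views.append(buffer_[offset:offset + n].view(t.size()))
        offset += n
    return buffer_, views

def is_packed(views):
    # Each view must start exactly where the previous one ended, measured in
    # bytes from the first view's data pointer (mirrors is_contiguous()).
    first, offset = views[0].data_ptr(), 0
    for v in views:
        if v.data_ptr() != first + offset or not v.is_contiguous():
            return False
        offset += v.numel() * v.element_size()
    return True

buffer_, views = pack([torch.randn(2, 3), torch.randn(4, 128)])
assert is_packed(views)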