
Commit b145701

copyrightly authored and facebook-github-bot committed
add aten.sum.default (#2807)
Summary: Pull Request resolved: #2807

The operator `aten.sum.dim_IntList` can take an empty list as its `dims` parameter. We modify `vulkan_graph_builder.py` to accommodate the empty list. Moreover, `aten.sum.default` is implemented as a [decomposition](https://www.internalfb.com/code/fbsource/[96e496f9db8f92967b4394bd4f60e39ab916740b]/xplat/caffe2/torch/_decomp/decompositions.py?lines=4676) into `aten.sum.dim_IntList` with empty `dims`, so these changes also add support for `aten.sum.default`.

Context: `torch.sum(x, ())` and `torch.sum(x)` are two ways to compute the sum of all elements in tensor `x`.

Reviewed By: SS-JIA, jorgep31415

Differential Revision: D55630993

fbshipit-source-id: 923d276118e893ff6885b92eb7b4c7cb7a95b374
1 parent f0bfc3c commit b145701
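
To make the Context above concrete, here is a quick sanity check in plain PyTorch (independent of this diff; `allclose` is used because the two reductions need not be bitwise identical):

    import torch

    x = torch.rand(3, 2, 7, 5)

    # An empty dims list reduces over all dimensions, so both calls
    # collapse x to a single scalar.
    assert torch.allclose(torch.sum(x, ()), torch.sum(x))

    # With keepdim=True, every reduced dimension is kept as size 1.
    print(torch.sum(x, (), keepdim=True).shape)  # torch.Size([1, 1, 1, 1])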

File tree

backends/vulkan/runtime/graph/ops/impl/Sum.cpp
backends/vulkan/serialization/vulkan_graph_builder.py
backends/vulkan/test/test_vulkan_delegate.py

3 files changed: +38 −6 lines changed

backends/vulkan/runtime/graph/ops/impl/Sum.cpp

Lines changed: 11 additions & 4 deletions
@@ -120,10 +120,17 @@ void add_sum_dim_IntList(
   const auto& dims_to_sum = graph.get_val(opt_dim).toIntList();
   int64_t in_dim = in_tensor.sizes().size();

-  for (const auto& dim : dims_to_sum) {
-    // Normalize (negative) dim into range [0, self.dim() - 1]
-    int64_t dim_normalized = normalize(dim, in_dim);
-    dims_set.insert(dim_normalized);
+  if (dims_to_sum.empty()) {
+    // If dim is not specified, reduce over all dims
+    for (int64_t i = 0; i < in_dim; ++i) {
+      dims_set.insert(i);
+    }
+  } else {
+    for (const auto& dim : dims_to_sum) {
+      // Normalize (negative) dim into range [0, self.dim() - 1]
+      int64_t dim_normalized = normalize(dim, in_dim);
+      dims_set.insert(dim_normalized);
+    }
   }

   // Reduce the higher dimensionalities first, otherwise when keepdim is
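
In Python terms, the reduction-dim resolution above amounts to the following (a minimal sketch with a hypothetical helper name; Python's `%` already wraps negative dims the way `normalize` is described in the comment):

    def resolve_sum_dims(dims_to_sum, in_dim):
        # Empty dims list: reduce over all dimensions.
        if len(dims_to_sum) == 0:
            return set(range(in_dim))
        # Otherwise normalize each (possibly negative) dim into
        # [0, in_dim - 1] before collecting it.
        return {d % in_dim for d in dims_to_sum}

    assert resolve_sum_dims((), 4) == {0, 1, 2, 3}
    assert resolve_sum_dims((-1, 1), 4) == {1, 3}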

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 8 additions & 2 deletions
@@ -178,7 +178,11 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:

     def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
         new_id = len(self.values)
-        if isinstance(arg[0], bool):
+        if len(arg) == 0:
+            self.values.append(
+                vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
+            )
+        elif isinstance(arg[0], bool):
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
@@ -227,7 +231,9 @@ def get_or_create_value_for(self, arg: _Argument):
             return self.create_scalar_value(arg)
         elif isinstance(arg, TensorSpec):
             return self.create_tensor_value(arg)
-        elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
+        elif isinstance(arg, list) and (
+            len(arg) == 0 or isinstance(arg[0], _ScalarType)
+        ):
             # pyre-ignore[6]
             return self.create_scalar_list_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
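
The `len(arg) == 0` guard matters because both changed branches previously peeked at `arg[0]` to pick a list type, which raises `IndexError` for the empty `dims` list produced by `aten.sum.default`. A stripped-down illustration of the dispatch (stand-in function, not the real builder):

    def pick_list_value(arg):
        # New behavior: an empty scalar list serializes as an empty IntList.
        if len(arg) == 0:
            return "IntList(items=[])"
        # Old first check: arg[0] would raise IndexError on [].
        # bool is checked before int because bool subclasses int.
        elif isinstance(arg[0], bool):
            return "BoolList"
        elif isinstance(arg[0], int):
            return "IntList"
        else:
            return "DoubleList"

    print(pick_list_value([]))       # IntList(items=[]) -- empty dims case
    print(pick_list_value([0, -1]))  # IntList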

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 19 additions & 0 deletions
@@ -497,6 +497,25 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )

+    def test_vulkan_backend_sum(self):
+        class SumModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = torch.sum(x, (), keepdim=True)
+                x = torch.sum(x)
+                return x
+
+        module = SumModule()
+        sample_inputs = (torch.rand(size=(3, 2, 7, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            module,
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_conv2d(self):
         class Conv2dModule(torch.nn.Module):
             def __init__(self):
