Skip to content

Commit 3a858ba

Browse files
committed
Update on "port spmm_sum to pytorch and optimize it on CPU"
### Motivation of this PR This patch is to migrate `spmm_reduce` from `torch-sparse` (a 3rd party dependency for PyG) to `torch`, which is a response to the initial proposal for fusion of **Gather, Apply Scatter** in Message Passing of GNN inference/training. pytorch#71300 **GAS** is the major step for Message Passing, the behavior of **GAS** can be classified into 2 kinds depending on the storage type of `EdgeIndex` which records the connections of nodes: * COO: the hotspot is `scatter_reduce` * CSR: the hotspot is `spmm_reduce` The reduce type can be choose from: "max", "mean", "max", "min". `spmm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`, and this operator requires an internal interface `_spmm_reduce` which has dual outputs: * `out` - the actual output * `arg_out` - records output indices in the non zero elements if the reduce type is "max" or "min", this is only useful for training. So for inference, it will not be calculated. ### Performance Benchmark on GCN for obgn-products on Xeon single socket, the workload is improved by `4.3x` with this patch. Performance benefit for training will be bigger, the original backward impl for `sum|mean` is sequential; the original backward impl for `max|min` is not fused. #### before: ``` ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ torch_sparse::spmm_sum 97.09% 56.086s 97.09% 56.088s 6.232s 9 aten::linear 0.00% 85.000us 1.38% 795.485ms 88.387ms 9 aten::matmul 0.00% 57.000us 1.38% 795.260ms 88.362ms 9 aten::mm 1.38% 795.201ms 1.38% 795.203ms 88.356ms 9 aten::relu 0.00% 50.000us 0.76% 440.434ms 73.406ms 6 aten::clamp_min 0.76% 440.384ms 0.76% 440.384ms 73.397ms 6 aten::add_ 0.57% 327.801ms 0.57% 327.801ms 36.422ms 9 aten::log_softmax 0.00% 23.000us 0.10% 55.503ms 18.501ms 3 ``` #### after ``` ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::spmm_sum 87.35% 11.826s 87.36% 11.827s 1.314s 9 aten::linear 0.00% 92.000us 5.87% 794.451ms 88.272ms 9 aten::matmul 0.00% 62.000us 5.87% 794.208ms 88.245ms 9 aten::mm 5.87% 794.143ms 5.87% 794.146ms 88.238ms 9 aten::relu 0.00% 53.000us 3.35% 452.977ms 75.496ms 6 aten::clamp_min 3.35% 452.924ms 3.35% 452.924ms 75.487ms 6 aten::add_ 2.58% 348.663ms 2.58% 348.663ms 38.740ms 9 aten::argmax 0.42% 57.473ms 0.42% 57.475ms 14.369ms 4 aten::log_softmax 0.00% 22.000us 0.39% 52.605ms 17.535ms 3 ``` [ghstack-poisoned]
2 parents 685d432 + 724c74d commit 3a858ba

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

aten/src/ATen/native/sparse/SparseCsrTensorMath.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view
6161
" instead.");
6262
}
6363

64-
SPMM_REDUCE_OP get_operator_enum(const c10::string_view reduce) {
64+
inline SPMM_REDUCE_OP get_operator_enum(const c10::string_view reduce) {
6565
if (reduce == "sum") {
6666
return SPMM_SUM;
6767
} else if (reduce == "mean") {
@@ -76,7 +76,7 @@ SPMM_REDUCE_OP get_operator_enum(const c10::string_view reduce) {
7676
}
7777

7878
template <bool train>
79-
void check_spmm_reduce_inputs(
79+
inline void check_spmm_reduce_inputs(
8080
const Tensor& input,
8181
const Tensor& grad_output,
8282
const Tensor& weight,

0 commit comments

Comments
 (0)