Skip to content

Commit 9a784c9

Browse files
crcrparkulinseth
authored andcommitted
[ReduceOp] ameliorate custom __eq__ (pytorch#90088)
Improve the completeness of `ReduceOp.__eq__`. Should support the equal operator with the first argument of `RedOpType` and the second of `ReduceOp` in a follow-up. Fixes pytorch#90072 Pull Request resolved: pytorch#90088 Approved by: https://github.com/kwen2501
1 parent f65249b commit 9a784c9

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

test/distributed/test_c10d_common.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,28 @@ def test_reduceop_pickle(self):
16961696
reduce_op = dist._make_nccl_premul_sum(scale)
16971697
self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op)
16981698

1699+
# Ref: https://github.com/pytorch/pytorch/issues/90072
1700+
def test_reduceop_equal(self):
1701+
not_reduceop = "abc"
1702+
for reduce_op in (
1703+
c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX,
1704+
c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR,
1705+
):
1706+
reduce_op_obj = c10d.ReduceOp(reduce_op)
1707+
# this calls `ReduceOp.__eq__(self, other)`
1708+
self.assertEqual(reduce_op_obj, reduce_op_obj)
1709+
self.assertEqual(reduce_op_obj, reduce_op)
1710+
self.assertNotEqual(reduce_op_obj, not_reduceop)
1711+
self.assertNotEqual(reduce_op, not_reduceop)
1712+
# TODO(crcrpar): This needs to be `assertEqual` for the associativity even though
1713+
# the comparison of `RedOpType` and `ReduceOp` sounds less likely to happen compared
1714+
# to that of `ReduceOp` and `RedOptype`.
1715+
# this calls `RedOpType.__eq__(self, other)`
1716+
self.assertNotEqual(reduce_op, reduce_op_obj)
1717+
1718+
self.assertFalse(None in (reduce_op, reduce_op_obj))
1719+
self.assertFalse(not_reduceop in (reduce_op, reduce_op_obj))
1720+
16991721

17001722
if __name__ == "__main__":
17011723
assert (

torch/csrc/distributed/c10d/Types.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
8585
return *this == static_cast<std::uint8_t>(other);
8686
}
8787

88+
// todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor.
8889
bool operator==(const ReduceOp& other) {
8990
return *this == other.op_;
9091
}

torch/csrc/distributed/c10d/init.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,17 +608,26 @@ This class does not support ``__members__`` property.)");
608608
// take hash of `::c10d::ReduceOp`. To avoid losing these functionality, here
609609
// I define some member methods.
610610
reduce_op
611+
// todo(crcrpar): Support `RedOpType == ReduceOp`.
611612
.def(
613+
// This calls `operator==(const ReduceOp::RedOpType)`
612614
"__eq__",
613615
[](const ::c10d::ReduceOp& self,
614616
const ::c10d::ReduceOp::RedOpType& other) {
615617
return self == other;
616618
})
617619
.def(
620+
// This calls `operator==(const ReduceOp)` for the future support of
621+
// `PREMUL_SUM` comparison
618622
"__eq__",
619623
[](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) {
620-
return self == other.op_;
624+
return self == other;
621625
})
626+
.def(
627+
// With the above custom `__eq__`'s, I have to manually support the
628+
// other types.
629+
"__eq__",
630+
[](const ::c10d::ReduceOp& self, py::object) { return false; })
622631
.def(
623632
"__hash__",
624633
[](const ::c10d::ReduceOp& self) {

0 commit comments

Comments
 (0)