
Commit 50e8abb

ezyang authored and pytorchmergebot committed
Change SymIntNode into an intrusive pointer (#82548)

This will make the pointer type a single word, which is important for packing it into an int64_t.

This time, this diff doesn't segfault when you build with DEBUG mode; more details at pybind/pybind11#4099

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: #82548
Approved by: https://github.com/albanD
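A minimal sketch, not from this PR, of the size argument in the commit message: an intrusive pointer stores only the raw object pointer, because the refcount lives inside the object itself, while std::shared_ptr also carries a control-block pointer. On a typical 64-bit toolchain this prints 8 versus 16 bytes, which is why the intrusive handle can be bit-packed into an int64_t.

#include <cstdio>
#include <memory>

struct Node { long refcount_ = 0; };     // stand-in for c10::intrusive_ptr_target
struct IntrusiveHandle { Node* ptr_; };  // the whole handle is one raw pointer

int main() {
  std::printf("intrusive-style handle: %zu bytes\n", sizeof(IntrusiveHandle));
  std::printf("std::shared_ptr:        %zu bytes\n", sizeof(std::shared_ptr<Node>));
}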
1 parent 36120ce commit 50e8abb

File tree

15 files changed: +41 -25 lines changed


.github/ci_commit_pins/xla.txt (+1 -1)

@@ -1 +1 @@
-b3342319e96a0becd139019620d8665605b78475
+1693c1c728c05a3afda946415acc9dfd6adf421d

c10/core/SymInt.cpp (+5 -2)

@@ -6,8 +6,11 @@
 namespace c10 {
 
 std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
-  SymIntNode a = a_.is_symbolic() ? a_.toSymIntNodeImpl() : nullptr;
-  SymIntNode b = b_.is_symbolic() ? b_.toSymIntNodeImpl() : nullptr;
+  SymIntNode a, b;
+  if (a_.is_symbolic())
+    a = a_.toSymIntNodeImpl();
+  if (b_.is_symbolic())
+    b = b_.toSymIntNodeImpl();
 
   SymIntNodeImpl* common = a ? a.get() : b.get();
   // TODO: technically we need to check that the classes match
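The ternary-to-if rewrite above is presumably needed because a c10::intrusive_ptr's null state is its default-constructed value rather than an implicit conversion from nullptr, so `cond ? x.toSymIntNodeImpl() : nullptr` no longer has a common type. A hedged sketch of the resulting pattern (the function name is hypothetical):

#include <c10/core/SymInt.h>

// Hypothetical illustration, not PR code: default-construct the handle to get
// the null state, assign only when a value exists, then rely on the handle's
// boolean test (as `a ? a.get() : b.get()` does above).
c10::SymIntNode maybe_node(const c10::SymInt& s) {
  c10::SymIntNode n;  // default-constructed == null
  if (s.is_symbolic())
    n = s.toSymIntNodeImpl();
  return n;
}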

c10/core/SymInt.h (+2 -1)

@@ -2,13 +2,14 @@
 
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
+#include <c10/util/intrusive_ptr.h>
 
 #include <memory>
 
 namespace c10 {
 
 class SymIntNodeImpl;
-using SymIntNode = std::shared_ptr<SymIntNodeImpl>;
+using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
 
 // `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
 // represent concrete dimension values.

c10/core/SymIntNodeImpl.h (+9 -3)

@@ -3,20 +3,26 @@
 #include <c10/core/SymInt.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
+#include <c10/util/intrusive_ptr.h>
 #include <memory>
 #include <mutex>
 #include <vector>
 
 namespace c10 {
 
 class SymIntNodeImpl;
-using SymIntNode = std::shared_ptr<SymIntNodeImpl>;
+using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
 
-class C10_API SymIntNodeImpl
-    : public std::enable_shared_from_this<SymIntNodeImpl> {
+class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
  public:
   c10::SymInt toSymInt();
   virtual ~SymIntNodeImpl(){};
+
+  template <typename T>
+  c10::intrusive_ptr<T> dyn_cast() const {
+    return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
+  }
+
   // these could be pure virtual when we implement LTC versions
   virtual SymIntNode add(const SymIntNode& other) {
     TORCH_CHECK(false, "NYI");
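Because intrusive refcounting keeps the count in a field of the object itself (inherited from c10::intrusive_ptr_target), any raw pointer can be promoted back to an owning handle; this is why the class no longer needs std::enable_shared_from_this, and it is what the new dyn_cast helper above relies on. A hedged sketch of the idea, with MyNode as a hypothetical subclass:

#include <c10/util/intrusive_ptr.h>

// Hypothetical sketch, not PR code: reclaim_copy re-wraps a raw pointer as a
// new owning intrusive_ptr and increments the in-object refcount, so existing
// owners keep their references. It plays the role shared_from_this() played.
struct MyNode : c10::intrusive_ptr_target {
  c10::intrusive_ptr<MyNode> self() {
    return c10::intrusive_ptr<MyNode>::reclaim_copy(this);
  }
};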

c10/core/SymIntTable.cpp (+2 -2)

@@ -5,7 +5,7 @@ namespace c10 {
 uint64_t SymIntTable::addNode(SymIntNode sin) {
   std::lock_guard<std::mutex> lock(mutex_);
   auto index = nodes_.size();
-  nodes_.push_back(sin);
+  nodes_.push_back(std::move(sin));
   return index;
 }
 SymIntNode SymIntTable::getNode(size_t index) {
@@ -17,7 +17,7 @@ SymIntNode SymIntTable::getNode(size_t index) {
 c10::SymInt SymIntNodeImpl::toSymInt() {
   // We will need to figure out a way
   // to dedup nodes
-  auto sit_sp = this->shared_from_this();
+  auto sit_sp = SymIntNode::reclaim_copy(this);
   return SymInt::toSymInt(sit_sp);
 }
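The toSymInt() change above is the direct replacement of shared_from_this(): reclaim_copy takes the raw this pointer, bumps the refcount, and returns a new owning handle. A hedged sketch of the refcount behavior under that reading (the function name is hypothetical):

#include <c10/core/SymIntNodeImpl.h>

// Hedged sketch, not PR code: both handles own a reference after reclaim_copy,
// and each handle's destructor decrements the shared in-object count.
void reclaim_copy_sketch() {
  auto n = c10::make_intrusive<c10::SymIntNodeImpl>();          // refcount == 1
  c10::SymIntNode n2 = c10::SymIntNode::reclaim_copy(n.get());  // refcount == 2
}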

c10/test/core/SymInt_test.cpp (+1 -1)

@@ -21,7 +21,7 @@ TEST(SymIntTest, ConcreteInts) {
 }
 
 TEST(SymIntTest, AddNode) {
-  auto n = std::make_shared<SymIntNodeImpl>();
+  auto n = c10::make_intrusive<SymIntNodeImpl>();
   auto i = n->toSymInt();
   EXPECT_TRUE(i.is_symbolic());
 }

test/cpp/lazy/test_lazy_ops.cpp (+1 -1)

@@ -91,7 +91,7 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
   auto y = torch::rand({Y_DIM}).to(kLazy);
   auto ly = torch::lazy::TryGetLtcTensor(y);
   auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
-  auto lmn = std::make_shared<torch::lazy::SymIntNodeImpl>(dim_node);
+  auto lmn = c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
   auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt());
   AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
 }

tools/autograd/templates/python_functions.cpp (+1)

@@ -11,6 +11,7 @@
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/autograd/saved_variable.h>
 #include <pybind11/pybind11.h>
+#include <torch/csrc/utils/pybind.h>
 
 // NOTE: See [Sharded File] comment in VariableType

torch/csrc/Size.cpp (+3)

@@ -11,6 +11,7 @@
 
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/jit/frontend/tracer.h>
+#include <torch/csrc/utils/pybind.h>
 
 struct THPSize {
   PyTupleObject tuple;
@@ -58,6 +59,8 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
           !torch::jit::tracer::isTracing(),
           "JIT Tracing of SymInts isn't supported");
       auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
+      if (!py_symint)
+        throw python_error();
       PyTuple_SET_ITEM(ret.get(), i, py_symint);
     } else {
       if (torch::jit::tracer::isTracing()) {

torch/csrc/jit/python/init.cpp (+5 -5)

@@ -145,7 +145,7 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
   virtual SymIntNode wrap(int64_t num) override {
     py::gil_scoped_acquire acquire;
     auto r = getPyObj().attr("wrap")(num);
-    return std::make_shared<PythonSymIntNodeImpl>(r);
+    return c10::make_intrusive<PythonSymIntNodeImpl>(r);
   }
 
   virtual bool bool_() override {
@@ -166,11 +166,11 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
   virtual SymIntNode dispatch_common_(
       const char* fname,
       const SymIntNode& other) {
-    auto pother = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(other);
+    auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
     TORCH_CHECK(pother);
     py::gil_scoped_acquire acquire;
     auto r = getPyObj().attr(fname)(pother->getPyObj());
-    return std::make_shared<PythonSymIntNodeImpl>(r);
+    return c10::make_intrusive<PythonSymIntNodeImpl>(r);
   }
 
   virtual SymIntNode add(const SymIntNode& other) override {
@@ -1182,12 +1182,12 @@ void initJITBindings(PyObject* module) {
       .def_static(
           "new_symint",
          [](py::object obj) -> c10::SymIntNode {
-            return std::make_shared<PythonSymIntNodeImpl>(obj);
+            return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
          })
      .def(
          "get_pyobj",
          [](c10::SymIntNode a) -> py::object {
-            if (auto psn = std::dynamic_pointer_cast<PythonSymIntNodeImpl>(a)) {
+            if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
             return py::reinterpret_borrow<py::object>(psn->getPyObj());
            }
            return py::none();
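Replacing std::dynamic_pointer_cast with a raw dynamic_cast on .get() is safe here because the owning handle (`other`, or `a`) outlives the borrowed pointer, and it skips the refcount round trip an owning cast would pay. A hedged sketch of the pattern (the helper name is hypothetical):

#include <c10/core/SymIntNodeImpl.h>

// Hedged sketch, not PR code: borrow a downcast raw pointer from an owning
// handle that is guaranteed to outlive the use; no refcount traffic needed.
template <typename Derived>
Derived* borrow_as(const c10::SymIntNode& owner) {
  return dynamic_cast<Derived*>(owner.get());  // nullptr if the cast fails
}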

torch/csrc/lazy/core/shape_inference.cpp (+3 -3)

@@ -452,9 +452,9 @@ std::vector<Shape> compute_shape_expand(
   for (const auto idx : c10::irange(_sizes.size())) {
     if (_sizes[idx].is_symbolic()) {
       c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl();
-      auto lazySymIntNode =
-          std::dynamic_pointer_cast<torch::lazy::SymIntNodeImpl>(
-              symbolicIntNode);
+      auto* lazySymIntNode =
+          dynamic_cast<torch::lazy::SymIntNodeImpl*>(symbolicIntNode.get());
+      TORCH_INTERNAL_ASSERT(lazySymIntNode);
       auto size_node = lazySymIntNode->node_;
       auto static_value =
           std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)

torch/csrc/lazy/core/tensor_impl.cpp (+1 -1)

@@ -95,7 +95,7 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
   for (auto i : c10::irange(rank)) {
     auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode(
         this->tensor_->GetIrValue(), i);
-    auto sn = std::make_shared<torch::lazy::SymIntNodeImpl>(dim_node);
+    auto sn = c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
     sym_sizes_.push_back(sn->toSymInt());
   }
 }

torch/csrc/utils/python_arg_parser.cpp (+3)

@@ -335,6 +335,9 @@ auto handle_torch_function_no_python_arg_parser(
   // NOLINTNEXTLINE(clang-diagnostic-writable-strings)
   py::object torch_function =
       PyObject_FastGetAttrString(arg.ptr(), torch_function_name_str);
+  if (!torch_function) {
+    TORCH_INTERNAL_ASSERT(0);
+  }
 
   // See https://github.com/pytorch/pytorch/issues/63767
   if (PyObject_FastGetAttrString(torch_function.ptr(), "__self__")

torch/csrc/utils/python_arg_parser.h (+3 -1)

@@ -487,7 +487,9 @@ inline bool is_symint_node(py::handle obj) {
 
 inline PyObject* toPyObject(c10::SymInt symint) {
   if (symint.is_symbolic()) {
-    return py::cast(symint.toSymIntNodeImpl()).release().ptr();
+    auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr();
+    TORCH_INTERNAL_ASSERT(r);
+    return r;
   } else {
     return THPUtils_packInt64(symint.data());
   }

torchgen/dest/lazy_ir.py (+1 -4)

@@ -48,10 +48,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
             return f"lazy_{arg.name}_tensorlist"
         elif arg.is_symint_or_list:
             cpp_type = arg.lazy_type.cpp_type()
-            return (
-                f"{cpp_type}(std::dynamic_pointer_cast<torch::lazy::SymIntNodeImpl>"
-                f"({arg.name}.toSymIntNodeImpl())->node_, 0)"
-            )
+            return f"{cpp_type}(dynamic_cast<torch::lazy::SymIntNodeImpl*>({arg.name}.toSymIntNodeImpl().get())->node_, 0)"
         return f"lazy_{arg.name}->GetIrValue()"
     elif isinstance(arg.lazy_type, OptionalCType):
         if arg.is_wrapped_scalar:
