Skip to content

Commit d41c22b

Browse files
jansel authored and pytorchmergebot committed
Revert "[fx] Move Node._prepend/Node._remove_from_list to C++ (pytorch#148261)" (pytorch#150542)
Reverts pytorch#148261 due to a possible memory leak. This reverts commit 5d4e7d5. Pull Request resolved: pytorch#150542. Approved by: https://github.com/clee2000
1 parent 277369a commit d41c22b

File tree

4 files changed

+66
-226
lines changed

4 files changed

+66
-226
lines changed
Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,76 +1,77 @@
1-
add_loop_eager,compile_time_instruction_count,2866000000,0.015
1+
add_loop_eager,compile_time_instruction_count,2926000000,0.015
22

33

44

5-
add_loop_eager_dynamic,compile_time_instruction_count,5460000000,0.025
5+
add_loop_eager_dynamic,compile_time_instruction_count,5637000000,0.025
66

77

88

9-
add_loop_inductor,compile_time_instruction_count,27660000000,0.015
9+
add_loop_inductor,compile_time_instruction_count,28680000000,0.015
1010

1111

1212

13-
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40640000000,0.025
13+
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42170000000,0.025
1414

1515

1616

17-
add_loop_inductor_gpu,compile_time_instruction_count,23970000000,0.015
17+
add_loop_inductor_gpu,compile_time_instruction_count,24980000000,0.015
1818

1919

2020

21-
basic_modules_ListOfLinears_eager,compile_time_instruction_count,953800000,0.015
21+
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969300000,0.015
2222

2323

2424

25-
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17190000000,0.015
25+
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17840000000,0.015
2626

2727

2828

29-
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15410000000,0.015
29+
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15990000000,0.015
3030

3131

3232

3333
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9714000000,0.2
3434

3535

3636

37-
update_hint_regression,compile_time_instruction_count,1523000000,0.02
37+
update_hint_regression,compile_time_instruction_count,1593000000,0.02
3838

3939

4040

41-
float_args,compile_time_instruction_count,413700000,0.015
41+
float_args,compile_time_instruction_count,416400000,0.015
4242

4343

44-
sum_floordiv_regression,compile_time_instruction_count,970100000,0.015
4544

45+
sum_floordiv_regression,compile_time_instruction_count,989900000,0.015
4646

4747

48-
symint_sum,compile_time_instruction_count,3080000000,0.015
4948

49+
symint_sum,compile_time_instruction_count,3164000000,0.015
5050

5151

52-
symint_sum_loop,compile_time_instruction_count,3988000000,0.015
5352

53+
symint_sum_loop,compile_time_instruction_count,4142000000,0.015
5454

5555

56-
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1989000000,0.015
5756

57+
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2034000000,0.015
5858

5959

60-
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5759000000,0.015
6160

61+
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5880000000,0.015
6262

6363

64-
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7873000000,0.015
6564

65+
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8419000000,0.015
6666

6767

68-
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1746000000,0.015
6968

69+
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1838000000,0.015
7070

7171

72-
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3579000000,0.015
7372

73+
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3742000000,0.015
7474

7575

76-
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9830000000,0.015
76+
77+
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10190000000,0.015

torch/_C/__init__.pyi.in

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2539,12 +2539,6 @@ class _NodeBase:
25392539
return_type: Any,
25402540
) -> None: ...
25412541
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
2542-
def _prepend(self, n: FxNode) -> None: ...
2543-
def _remove_from_list(self) -> None: ...
2544-
def __lt__(self, n: Self) -> _bool: ...
2545-
def __gt__(self, n: Self) -> _bool: ...
2546-
def __le__(self, n: Self) -> _bool: ...
2547-
def __ge__(self, n: Self) -> _bool: ...
25482542

25492543
class _NodeIter(Iterator):
25502544
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

torch/csrc/fx/node.cpp

Lines changed: 6 additions & 199 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,11 @@
11
#include <torch/csrc/fx/node.h>
22

3-
#include <c10/util/SmallVector.h>
43
#include <structmember.h>
54
#include <torch/csrc/utils/object_ptr.h>
65
#include <torch/csrc/utils/pythoncapi_compat.h>
7-
#include <algorithm>
86

97
namespace {
108

11-
using NodeSortKey = c10::SmallVector<int64_t, 4>;
129
struct NodeBase;
1310

1411
// Thrown to exit out of a C++ function and return an error to Python.
@@ -166,22 +163,7 @@ struct NodeBase {
166163
PyObject* users;
167164
PyObject* _repr_fn;
168165
PyObject* meta;
169-
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
170-
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
171-
172-
inline NodeSortKey& sort_key() {
173-
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
174-
}
175-
176-
// Equivalent to:
177-
// p, n = self._prev, self._next
178-
// p._next, n._prev = n, p
179-
inline void remove_from_list() {
180-
NodeBase* p = this->_prev;
181-
NodeBase* n = this->_next;
182-
p->_next = n;
183-
n->_prev = p;
184-
}
166+
PyObject* _sort_key;
185167
};
186168

187169
static PyObject* NodeBase_new(
@@ -191,8 +173,6 @@ static PyObject* NodeBase_new(
191173
PyObject* self = type->tp_alloc(type, 0);
192174
if (!self)
193175
return nullptr;
194-
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
195-
NodeSortKey(); // placement new does not allocate
196176
return self;
197177
}
198178

@@ -221,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
221201
self->users = PyDict_New();
222202
self->_repr_fn = Py_NewRef(Py_None);
223203
self->meta = PyDict_New();
204+
self->_sort_key = PyTuple_New(0);
224205
return 0;
225206
}
226207

@@ -240,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
240221
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
241222
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
242223
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
224+
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
243225
{nullptr} /* Sentinel */
244226
};
245227

@@ -257,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
257239
Py_VISIT(self->users);
258240
Py_VISIT(self->_repr_fn);
259241
Py_VISIT(self->meta);
242+
Py_VISIT(self->_sort_key);
260243
return 0;
261244
}
262245

@@ -274,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
274257
Py_CLEAR(self->users);
275258
Py_CLEAR(self->_repr_fn);
276259
Py_CLEAR(self->meta);
260+
Py_CLEAR(self->_sort_key);
277261
return 0;
278262
}
279263

280264
static void NodeBase_dealloc(PyObject* self) {
281265
PyObject_GC_UnTrack(self);
282-
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
283266
(void)NodeBase_clear((NodeBase*)self);
284267
Py_TYPE(self)->tp_free(self);
285268
}
@@ -338,191 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
338321
}
339322
}
340323

341-
static PyObject* NodeBase__remove_from_list(
342-
PyObject* self,
343-
PyObject* _ignored) {
344-
reinterpret_cast<NodeBase*>(self)->remove_from_list();
345-
Py_RETURN_NONE;
346-
}
347-
348-
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
349-
if (self_ == arg) {
350-
Py_RETURN_NONE;
351-
}
352-
if (!is_node(arg)) {
353-
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
354-
return nullptr;
355-
}
356-
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
357-
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
358-
if (self->graph != x->graph) {
359-
PyErr_SetString(
360-
PyExc_AssertionError,
361-
"Attempting to move a Node into a different Graph");
362-
return nullptr;
363-
}
364-
365-
x->remove_from_list();
366-
NodeBase* p = self->_prev;
367-
p->_next = x;
368-
x->_prev = p;
369-
x->_next = self;
370-
self->_prev = x;
371-
372-
// Now compute x.sort_key()
373-
const NodeSortKey& psk = x->_prev->sort_key();
374-
const NodeSortKey& nsk = x->_next->sort_key();
375-
if (psk.size() > nsk.size()) {
376-
// prefix = psk[: len(nsk)+1]
377-
size_t slice_len = nsk.size() + 1;
378-
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
379-
// last element is idx => increment by 1
380-
prefix.back()++;
381-
x->sort_key() = std::move(prefix);
382-
} else if (psk.size() < nsk.size()) {
383-
// prefix = nsk[: len(psk)+1]
384-
size_t slice_len = psk.size() + 1;
385-
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
386-
// last element is idx => decrement by 1
387-
prefix.back()--;
388-
x->sort_key() = std::move(prefix);
389-
} else {
390-
// same length => add a 0
391-
x->sort_key() = psk;
392-
x->sort_key().emplace_back(0);
393-
}
394-
Py_RETURN_NONE;
395-
}
396-
397-
// __lt__(self, other): Return self.sort_key < other.sort_key
398-
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
399-
// METH_O => one argument: 'other'
400-
if (!is_node(other)) {
401-
Py_RETURN_NOTIMPLEMENTED;
402-
}
403-
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
404-
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
405-
bool less = std::lexicographical_compare(
406-
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
407-
if (less)
408-
Py_RETURN_TRUE;
409-
Py_RETURN_FALSE;
410-
}
411-
412-
// __gt__(self, other): Return self.sort_key() > other.sort_key
413-
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
414-
if (!is_node(other)) {
415-
Py_RETURN_NOTIMPLEMENTED;
416-
}
417-
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
418-
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
419-
// "a > b" is equivalent to "b < a"
420-
bool greater = std::lexicographical_compare(
421-
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
422-
if (greater)
423-
Py_RETURN_TRUE;
424-
Py_RETURN_FALSE;
425-
}
426-
427-
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
428-
if (self == other) {
429-
Py_RETURN_TRUE;
430-
}
431-
return NodeBase___gt__(self, other);
432-
}
433-
434-
// __le__(self, other): Return not (self > other)
435-
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
436-
if (self == other) {
437-
Py_RETURN_TRUE;
438-
}
439-
return NodeBase___lt__(self, other);
440-
}
441-
442-
// Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
443-
// Only used by pickle/__getstate__
444-
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
445-
NodeBase* node = reinterpret_cast<NodeBase*>(self);
446-
const NodeSortKey& vec = node->sort_key();
447-
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
448-
THPObjectPtr tuple(PyTuple_New(n));
449-
if (!tuple) {
450-
return nullptr; // Out of memory
451-
}
452-
for (Py_ssize_t i = 0; i < n; i++) {
453-
PyTuple_SET_ITEM(tuple.get(), i, PyLong_FromSsize_t(vec[i]));
454-
}
455-
return tuple.release();
456-
}
457-
458-
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
459-
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
460-
static int NodeBase_set_sort_key(
461-
PyObject* self,
462-
PyObject* value,
463-
void* /*closure*/) {
464-
NodeBase* node = reinterpret_cast<NodeBase*>(self);
465-
if (!PyTuple_Check(value)) {
466-
PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints");
467-
return -1;
468-
}
469-
Py_ssize_t size = PyTuple_GET_SIZE(value);
470-
NodeSortKey new_vec;
471-
new_vec.reserve(size);
472-
for (Py_ssize_t i = 0; i < size; i++) {
473-
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
474-
if (val == -1 && PyErr_Occurred()) {
475-
return -1;
476-
}
477-
new_vec.emplace_back(val);
478-
}
479-
node->sort_key() = std::move(new_vec);
480-
return 0;
481-
}
482-
483324
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
484325
static PyMethodDef NodeBase_methods[] = {
485326
{"_update_args_kwargs",
486327
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
487328
METH_FASTCALL,
488329
"Internal method: do not call directly."},
489-
{"_remove_from_list",
490-
(PyCFunction)(void*)(NodeBase__remove_from_list),
491-
METH_NOARGS,
492-
"Internal method: do not call directly."},
493-
{"_prepend",
494-
(PyCFunction)(void*)(NodeBase__prepend),
495-
METH_O,
496-
"Internal method: do not call directly."},
497-
{"__lt__",
498-
(PyCFunction)(void*)NodeBase___lt__,
499-
METH_O,
500-
"Return True if self.sort_key < other.sort_key"},
501-
{"__gt__",
502-
(PyCFunction)(void*)NodeBase___gt__,
503-
METH_O,
504-
"Return True if self.sort_key > other.sort_key"},
505-
{"__ge__",
506-
(PyCFunction)(void*)NodeBase___ge__,
507-
METH_O,
508-
"Return True if self.sort_key >= other.sort_key"},
509-
{"__le__",
510-
(PyCFunction)(void*)NodeBase___le__,
511-
METH_O,
512-
"Return True if self.sort_key <= other.sort_key"},
513330
{nullptr, nullptr, 0, nullptr} // Sentinel
514331
};
515332

516-
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
517-
static PyGetSetDef NodeBase_getset[] = {
518-
{"_sort_key", // attribute name in Python
519-
(getter)NodeBase_get_sort_key, // C getter function
520-
(setter)NodeBase_set_sort_key, // C setter function
521-
(char*)"The sort key as a tuple of ints", // docstring
522-
nullptr},
523-
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
524-
};
525-
526333
PyTypeObject NodeBaseType = {
527334
PyVarObject_HEAD_INIT(nullptr, 0)
528335
"torch._C._NodeBase", /* tp_name */
@@ -554,7 +361,7 @@ PyTypeObject NodeBaseType = {
554361
nullptr, /* tp_iternext */
555362
NodeBase_methods, /* tp_methods */
556363
NodeBase_members, /* tp_members */
557-
NodeBase_getset, /* tp_getset */
364+
nullptr, /* tp_getset */
558365
nullptr, /* tp_base */
559366
nullptr, /* tp_dict */
560367
nullptr, /* tp_descr_get */

0 commit comments

Comments
 (0)