1
1
#include < torch/csrc/fx/node.h>
2
2
3
- #include < c10/util/SmallVector.h>
4
3
#include < structmember.h>
5
4
#include < torch/csrc/utils/object_ptr.h>
6
5
#include < torch/csrc/utils/pythoncapi_compat.h>
7
- #include < algorithm>
8
6
9
7
namespace {
10
8
11
- using NodeSortKey = c10::SmallVector<int64_t , 4 >;
12
9
struct NodeBase ;
13
10
14
11
// Thrown to exit out of a C++ function and return an error to Python.
@@ -166,22 +163,7 @@ struct NodeBase {
166
163
PyObject* users;
167
164
PyObject* _repr_fn;
168
165
PyObject* meta;
169
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
170
- alignas (NodeSortKey) char sort_key_buf[sizeof (NodeSortKey)];
171
-
172
- inline NodeSortKey& sort_key () {
173
- return *reinterpret_cast <NodeSortKey*>(sort_key_buf);
174
- }
175
-
176
- // Equivalent to:
177
- // p, n = self._prev, self._next
178
- // p._next, n._prev = n, p
179
- inline void remove_from_list () {
180
- NodeBase* p = this ->_prev ;
181
- NodeBase* n = this ->_next ;
182
- p->_next = n;
183
- n->_prev = p;
184
- }
166
+ PyObject* _sort_key;
185
167
};
186
168
187
169
static PyObject* NodeBase_new (
@@ -191,8 +173,6 @@ static PyObject* NodeBase_new(
191
173
PyObject* self = type->tp_alloc (type, 0 );
192
174
if (!self)
193
175
return nullptr ;
194
- new (reinterpret_cast <NodeBase*>(self)->sort_key_buf )
195
- NodeSortKey (); // placement new does not allocate
196
176
return self;
197
177
}
198
178
@@ -221,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
221
201
self->users = PyDict_New ();
222
202
self->_repr_fn = Py_NewRef (Py_None);
223
203
self->meta = PyDict_New ();
204
+ self->_sort_key = PyTuple_New (0 );
224
205
return 0 ;
225
206
}
226
207
@@ -240,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = {
240
221
{" users" , T_OBJECT_EX, offsetof (NodeBase, users), 0 , nullptr },
241
222
{" _repr_fn" , T_OBJECT_EX, offsetof (NodeBase, _repr_fn), 0 , nullptr },
242
223
{" meta" , T_OBJECT_EX, offsetof (NodeBase, meta), 0 , nullptr },
224
+ {" _sort_key" , T_OBJECT_EX, offsetof (NodeBase, _sort_key), 0 , nullptr },
243
225
{nullptr } /* Sentinel */
244
226
};
245
227
@@ -257,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
257
239
Py_VISIT (self->users );
258
240
Py_VISIT (self->_repr_fn );
259
241
Py_VISIT (self->meta );
242
+ Py_VISIT (self->_sort_key );
260
243
return 0 ;
261
244
}
262
245
@@ -274,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) {
274
257
Py_CLEAR (self->users );
275
258
Py_CLEAR (self->_repr_fn );
276
259
Py_CLEAR (self->meta );
260
+ Py_CLEAR (self->_sort_key );
277
261
return 0 ;
278
262
}
279
263
280
264
static void NodeBase_dealloc (PyObject* self) {
281
265
PyObject_GC_UnTrack (self);
282
- reinterpret_cast <NodeBase*>(self)->sort_key ().~NodeSortKey ();
283
266
(void )NodeBase_clear ((NodeBase*)self);
284
267
Py_TYPE (self)->tp_free (self);
285
268
}
@@ -338,191 +321,15 @@ static PyObject* NodeBase__update_args_kwargs(
338
321
}
339
322
}
340
323
341
- static PyObject* NodeBase__remove_from_list (
342
- PyObject* self,
343
- PyObject* _ignored) {
344
- reinterpret_cast <NodeBase*>(self)->remove_from_list ();
345
- Py_RETURN_NONE;
346
- }
347
-
348
- static PyObject* NodeBase__prepend (PyObject* self_, PyObject* arg) {
349
- if (self_ == arg) {
350
- Py_RETURN_NONE;
351
- }
352
- if (!is_node (arg)) {
353
- PyErr_SetString (PyExc_TypeError, " _prepend() argument must be a Node" );
354
- return nullptr ;
355
- }
356
- NodeBase* self = reinterpret_cast <NodeBase*>(self_);
357
- NodeBase* x = reinterpret_cast <NodeBase*>(arg);
358
- if (self->graph != x->graph ) {
359
- PyErr_SetString (
360
- PyExc_AssertionError,
361
- " Attempting to move a Node into a different Graph" );
362
- return nullptr ;
363
- }
364
-
365
- x->remove_from_list ();
366
- NodeBase* p = self->_prev ;
367
- p->_next = x;
368
- x->_prev = p;
369
- x->_next = self;
370
- self->_prev = x;
371
-
372
- // Now compute x.sort_key()
373
- const NodeSortKey& psk = x->_prev ->sort_key ();
374
- const NodeSortKey& nsk = x->_next ->sort_key ();
375
- if (psk.size () > nsk.size ()) {
376
- // prefix = psk[: len(nsk)+1]
377
- size_t slice_len = nsk.size () + 1 ;
378
- NodeSortKey prefix (psk.begin (), psk.begin () + slice_len);
379
- // last element is idx => increment by 1
380
- prefix.back ()++;
381
- x->sort_key () = std::move (prefix);
382
- } else if (psk.size () < nsk.size ()) {
383
- // prefix = nsk[: len(psk)+1]
384
- size_t slice_len = psk.size () + 1 ;
385
- NodeSortKey prefix (nsk.begin (), nsk.begin () + slice_len);
386
- // last element is idx => decrement by 1
387
- prefix.back ()--;
388
- x->sort_key () = std::move (prefix);
389
- } else {
390
- // same length => add a 0
391
- x->sort_key () = psk;
392
- x->sort_key ().emplace_back (0 );
393
- }
394
- Py_RETURN_NONE;
395
- }
396
-
397
- // __lt__(self, other): Return self.sort_key < other.sort_key
398
- static PyObject* NodeBase___lt__ (PyObject* self, PyObject* other) {
399
- // METH_O => one argument: 'other'
400
- if (!is_node (other)) {
401
- Py_RETURN_NOTIMPLEMENTED;
402
- }
403
- const NodeSortKey& lhs = reinterpret_cast <NodeBase*>(self)->sort_key ();
404
- const NodeSortKey& rhs = reinterpret_cast <NodeBase*>(other)->sort_key ();
405
- bool less = std::lexicographical_compare (
406
- lhs.begin (), lhs.end (), rhs.begin (), rhs.end ());
407
- if (less)
408
- Py_RETURN_TRUE;
409
- Py_RETURN_FALSE;
410
- }
411
-
412
- // __gt__(self, other): Return self.sort_key() > other.sort_key
413
- static PyObject* NodeBase___gt__ (PyObject* self, PyObject* other) {
414
- if (!is_node (other)) {
415
- Py_RETURN_NOTIMPLEMENTED;
416
- }
417
- const NodeSortKey& lhs = reinterpret_cast <NodeBase*>(self)->sort_key ();
418
- const NodeSortKey& rhs = reinterpret_cast <NodeBase*>(other)->sort_key ();
419
- // "a > b" is equivalent to "b < a"
420
- bool greater = std::lexicographical_compare (
421
- rhs.begin (), rhs.end (), lhs.begin (), lhs.end ());
422
- if (greater)
423
- Py_RETURN_TRUE;
424
- Py_RETURN_FALSE;
425
- }
426
-
427
- static PyObject* NodeBase___ge__ (PyObject* self, PyObject* other) {
428
- if (self == other) {
429
- Py_RETURN_TRUE;
430
- }
431
- return NodeBase___gt__ (self, other);
432
- }
433
-
434
- // __le__(self, other): Return not (self > other)
435
- static PyObject* NodeBase___le__ (PyObject* self, PyObject* other) {
436
- if (self == other) {
437
- Py_RETURN_TRUE;
438
- }
439
- return NodeBase___lt__ (self, other);
440
- }
441
-
442
- // Convert the NodeBase::sort_key vector<long> into a Python tuple of ints
443
- // Only used by pickle/__getstate__
444
- static PyObject* NodeBase_get_sort_key (PyObject* self, void * /* closure*/ ) {
445
- NodeBase* node = reinterpret_cast <NodeBase*>(self);
446
- const NodeSortKey& vec = node->sort_key ();
447
- Py_ssize_t n = static_cast <Py_ssize_t>(vec.size ());
448
- THPObjectPtr tuple (PyTuple_New (n));
449
- if (!tuple) {
450
- return nullptr ; // Out of memory
451
- }
452
- for (Py_ssize_t i = 0 ; i < n; i++) {
453
- PyTuple_SET_ITEM (tuple.get (), i, PyLong_FromSsize_t (vec[i]));
454
- }
455
- return tuple.release ();
456
- }
457
-
458
- // Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
459
- // node._sort_key = (1,2,3) Only used by pickle/__setstate__
460
- static int NodeBase_set_sort_key (
461
- PyObject* self,
462
- PyObject* value,
463
- void * /* closure*/ ) {
464
- NodeBase* node = reinterpret_cast <NodeBase*>(self);
465
- if (!PyTuple_Check (value)) {
466
- PyErr_SetString (PyExc_TypeError, " _sort_key must be an tuple of ints" );
467
- return -1 ;
468
- }
469
- Py_ssize_t size = PyTuple_GET_SIZE (value);
470
- NodeSortKey new_vec;
471
- new_vec.reserve (size);
472
- for (Py_ssize_t i = 0 ; i < size; i++) {
473
- int64_t val = PyLong_AsSsize_t (PyTuple_GET_ITEM (value, i));
474
- if (val == -1 && PyErr_Occurred ()) {
475
- return -1 ;
476
- }
477
- new_vec.emplace_back (val);
478
- }
479
- node->sort_key () = std::move (new_vec);
480
- return 0 ;
481
- }
482
-
483
324
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
484
325
static PyMethodDef NodeBase_methods[] = {
485
326
{" _update_args_kwargs" ,
486
327
(PyCFunction)(void *)(NodeBase__update_args_kwargs),
487
328
METH_FASTCALL,
488
329
" Internal method: do not call directly." },
489
- {" _remove_from_list" ,
490
- (PyCFunction)(void *)(NodeBase__remove_from_list),
491
- METH_NOARGS,
492
- " Internal method: do not call directly." },
493
- {" _prepend" ,
494
- (PyCFunction)(void *)(NodeBase__prepend),
495
- METH_O,
496
- " Internal method: do not call directly." },
497
- {" __lt__" ,
498
- (PyCFunction)(void *)NodeBase___lt__,
499
- METH_O,
500
- " Return True if self.sort_key < other.sort_key" },
501
- {" __gt__" ,
502
- (PyCFunction)(void *)NodeBase___gt__,
503
- METH_O,
504
- " Return True if self.sort_key > other.sort_key" },
505
- {" __ge__" ,
506
- (PyCFunction)(void *)NodeBase___ge__,
507
- METH_O,
508
- " Return True if self.sort_key >= other.sort_key" },
509
- {" __le__" ,
510
- (PyCFunction)(void *)NodeBase___le__,
511
- METH_O,
512
- " Return True if self.sort_key <= other.sort_key" },
513
330
{nullptr , nullptr , 0 , nullptr } // Sentinel
514
331
};
515
332
516
- // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
517
- static PyGetSetDef NodeBase_getset[] = {
518
- {" _sort_key" , // attribute name in Python
519
- (getter)NodeBase_get_sort_key, // C getter function
520
- (setter)NodeBase_set_sort_key, // C setter function
521
- (char *)" The sort key as a tuple of ints" , // docstring
522
- nullptr },
523
- {nullptr , nullptr , nullptr , nullptr , nullptr } // Sentinel
524
- };
525
-
526
333
PyTypeObject NodeBaseType = {
527
334
PyVarObject_HEAD_INIT (nullptr , 0 )
528
335
" torch._C._NodeBase" , /* tp_name */
@@ -554,7 +361,7 @@ PyTypeObject NodeBaseType = {
554
361
nullptr , /* tp_iternext */
555
362
NodeBase_methods, /* tp_methods */
556
363
NodeBase_members, /* tp_members */
557
- NodeBase_getset , /* tp_getset */
364
+ nullptr , /* tp_getset */
558
365
nullptr , /* tp_base */
559
366
nullptr , /* tp_dict */
560
367
nullptr , /* tp_descr_get */
0 commit comments