|
1 | | -#include <stdexcept> |
2 | 1 | #include "torch/csrc/jit/tensorexpr/mem_arena.h" |
3 | 2 |
|
4 | 3 | namespace torch { |
5 | 4 | namespace jit { |
6 | 5 | namespace tensorexpr { |
7 | 6 |
|
| 7 | +namespace { |
| 8 | +// Define in an anonymous namespace to hide this symbol from other compilation |
| 9 | +// units |
| 10 | +thread_local KernelArena* current_arena = nullptr; |
| 11 | +} |
| 12 | + |
8 | 13 | KernelArena::~KernelArena() { |
9 | 14 | for (KernelScopedObject* p : kernel_objects_) { |
10 | 15 | delete p; |
11 | 16 | } |
12 | 17 | } |
13 | 18 |
|
14 | 19 | KernelScopedObject::KernelScopedObject() { |
15 | | - KernelArena& kernel = KernelArena::GetCurrentKernelArena(); |
16 | | - kernel.kernel_objects_.push_back(this); |
| 20 | + KernelArena* kernel = KernelArena::GetCurrentKernelArena(); |
| 21 | + kernel->kernel_objects_.push_back(this); |
17 | 22 | } |
18 | 23 |
|
19 | 24 | static std::vector<KernelArena*>& GetKernelArenaStack() { |
20 | 25 | thread_local std::vector<KernelArena*> kernel_arena_stack; |
21 | 26 | return kernel_arena_stack; |
22 | 27 | } |
23 | 28 |
|
24 | | -KernelArena& KernelArena::GetCurrentKernelArena() { |
25 | | - std::vector<KernelArena*>& kernel_arena_stack = GetKernelArenaStack(); |
26 | | - if (kernel_arena_stack.empty()) { |
27 | | - throw std::runtime_error( |
28 | | - "A KernelScope must be bound before creating KernelScopedObject"); |
29 | | - } |
30 | | - return *kernel_arena_stack.back(); |
| 29 | +void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) { |
| 30 | + current_arena = new_kernel_arena; |
31 | 31 | } |
32 | 32 |
|
33 | | -KernelScope::KernelScope() : owning_kernel_arena_(true) { |
34 | | - kernel_arena_ = new KernelArena; |
35 | | - GetKernelArenaStack().push_back(kernel_arena_); |
| 33 | +KernelArena* KernelArena::GetCurrentKernelArena() { |
| 34 | + return current_arena; |
36 | 35 | } |
37 | 36 |
|
38 | | -KernelScope::KernelScope(KernelArena& kernel_arena) |
39 | | - : owning_kernel_arena_(false) { |
40 | | - kernel_arena_ = &kernel_arena; |
41 | | - GetKernelArenaStack().push_back(&kernel_arena); |
| 37 | +KernelScope::KernelScope() : owning_(true) { |
| 38 | + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); |
| 39 | + KernelArena::SetCurrentKernelArena(new KernelArena); |
42 | 40 | } |
43 | 41 |
|
44 | | -KernelScope::~KernelScope() noexcept(false) { |
45 | | - std::vector<KernelArena*>& kernel_arena_stack = GetKernelArenaStack(); |
46 | | - if (kernel_arena_ != kernel_arena_stack.back()) { |
47 | | - throw std::runtime_error("Mismatch KernelScope and kernel"); |
48 | | - } |
49 | | - if (owning_kernel_arena_) { |
50 | | - delete kernel_arena_; |
| 42 | +KernelScope::KernelScope(KernelArena* arena_) : owning_(false) { |
| 43 | + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); |
| 44 | + KernelArena::SetCurrentKernelArena(arena_); |
| 45 | +} |
| 46 | + |
| 47 | +KernelScope::~KernelScope() { |
| 48 | + if (owning_) { |
| 49 | + delete KernelArena::GetCurrentKernelArena(); |
51 | 50 | } |
52 | | - kernel_arena_stack.pop_back(); |
| 51 | + KernelArena::SetCurrentKernelArena(old_kernel_arena_); |
53 | 52 | } |
54 | 53 |
|
55 | 54 | } // namespace tensorexpr |
|
0 commit comments