@@ -3,7 +3,6 @@
 
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ def wrapped_fn(*args, **kwargs):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
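`_recurse_apply` (defined outside this diff) walks nested lists and tuples in the positional args and substitutes each lazy tensor, here with its `_meta` stand-in; keyword arguments are currently left untouched, hence the TODO. A self-contained sketch of the idea, with a hypothetical `Fake` class and `recurse_apply` analogue rather than the patch's actual code:

```python
from typing import Any, Callable

class Fake:
    """Stand-in for a lazy tensor carrying only metadata."""
    def __init__(self, meta: str):
        self.meta = meta

def recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
    # walk nested lists/tuples, replacing every Fake with fn(Fake)
    if isinstance(o, (list, tuple)):
        return type(o)(recurse_apply(x, fn) for x in o)
    return fn(o) if isinstance(o, Fake) else o

args = (Fake("f32[2,3]"), 2, [Fake("f16[4]"), "pad"])
print(recurse_apply(args, lambda t: t.meta))  # ('f32[2,3]', 2, ['f16[4]', 'pad'])
```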
@@ -140,23 +138,7 @@ def wrapped_fn(*args, **kwargs):
                     res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                        t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
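Previously every node stored a closure `lambda a: fn(*a, **kwargs)` that received the args tuple itself; now `fn` is stored directly and `args`/`kwargs` are splatted at evaluation time. A hedged sketch of the two calling conventions, with a toy `fn` that is not from the patch:

```python
import numpy as np

fn = np.multiply
args = (np.arange(3), 2)
kwargs: dict = {}

# old convention: the stored callable took the args tuple, so every node
# needed a per-node wrapper closing over kwargs
old_style = (lambda a: fn(*a, **kwargs))(args)

# new convention: args and kwargs live on the node and are splatted at call
# time, so the original fn can be stored with no closure at all
new_style = fn(*args, **kwargs)

assert (old_style == new_style).all()  # identical results, simpler storage
```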
@@ -168,26 +150,18 @@ def collect_replace(t: LazyBase):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
 
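Evaluation now recurses once per graph node instead of draining a shared queue, which is simpler but bounded by Python's recursion limit (`sys.getrecursionlimit()`, 1000 by default). A self-contained sketch of the strategy, with a hypothetical `Node` class standing in for LazyBase:

```python
from __future__ import annotations
from typing import Any, Callable

class Node:
    def __init__(self, data: Any | None = None, args: tuple = (),
                 kwargs: dict[str, Any] | None = None,
                 func: Callable[..., Any] | None = None):
        self.data, self.args = data, args
        self.kwargs, self.func = kwargs or {}, func

def to_eager(t: Node) -> Any:
    if t.data is not None:  # already-eager leaf, or a node computed earlier
        return t.data
    assert t.func is not None
    # recursion depth equals graph depth, hence the recursion-limit note above
    args = tuple(to_eager(a) if isinstance(a, Node) else a for a in t.args)
    t.data = t.func(*args, **t.kwargs)  # cache, so shared nodes run only once
    return t.data

leaf = Node(data=3)
doubled = Node(args=(leaf,), func=lambda x: x * 2)
total = Node(args=(doubled, doubled), func=lambda a, b: a + b)
print(to_eager(total))  # 12; `doubled` is evaluated a single time
```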
@@ -206,7 +180,7 @@ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) ->
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
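A hedged end-to-end sketch of the rewritten `astype`: the dtype conversion stays deferred until something like `tofile` forces evaluation. This assumes gguf-py's `LazyNumpyTensor`; the import path and the output file name are illustrative:

```python
import numpy as np
from gguf.lazy import LazyNumpyTensor  # assumed module path

lazy = LazyNumpyTensor.from_eager(np.arange(4, dtype=np.float32))
half = lazy.astype(np.float16)  # nothing is converted yet; only meta changes
half.tofile("demo.bin")         # to_eager runs the deferred astype, then writes
```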