Skip to content

Commit 41fecd8

Browse files
crusaderkyshoyer
authored andcommitted
__slots__ (#3250)
* Add __slots__ to most classes * Enforced __slots__for all classes; remove _initialized * Speed up __setattr__ * Fix accessors * DeprecationWarning -> FutureWarning * IndexingSupport enum * What's New * Unit tests * Trivial docstrings and comments tweak * Don't expose accessors in Dataset._replace()
1 parent f864718 commit 41fecd8

25 files changed

+373
-76
lines changed

doc/whats-new.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,24 @@ Breaking changes
3131
- The ``inplace`` kwarg for public methods now raises an error, having been deprecated
3232
since v0.11.0.
3333
By `Maximilian Roos <https://github.com/max-sixty>`_
34+
- Most xarray objects now define ``__slots__``. This reduces overall RAM usage by ~22%
35+
(not counting the underlying numpy buffers); on CPython 3.7/x64, a trivial DataArray
36+
has gone down from 1.9kB to 1.5kB.
37+
38+
Caveats:
39+
40+
- Pickle streams produced by older versions of xarray can't be loaded using this
41+
release, and vice versa.
42+
- Any user code that was accessing the ``__dict__`` attribute of
43+
xarray objects will break. The best practice to attach custom metadata to xarray
44+
objects is to use the ``attrs`` dictionary.
45+
- Any user code that defines custom subclasses of xarray classes must now explicitly
46+
define ``__slots__`` itself. Subclasses that don't add any attributes must state so
47+
by defining ``__slots__ = ()`` right after the class header.
48+
Omitting ``__slots__`` will now cause a ``FutureWarning`` to be logged, and a hard
49+
crash in a later release.
50+
51+
(:issue:`3250`) by `Guido Imperiale <https://github.com/crusaderky>`_.
3452

3553
New functions/methods
3654
~~~~~~~~~~~~~~~~~~~~~

xarray/backends/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,8 @@ def open_dataarray(
694694

695695

696696
class _MultiFileCloser:
697+
__slots__ = ("file_objs",)
698+
697699
def __init__(self, file_objs):
698700
self.file_objs = file_objs
699701

xarray/backends/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,16 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
6868

6969

7070
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
71+
__slots__ = ()
72+
7173
def __array__(self, dtype=None):
7274
key = indexing.BasicIndexer((slice(None),) * self.ndim)
7375
return np.asarray(self[key], dtype=dtype)
7476

7577

7678
class AbstractDataStore(Mapping):
79+
__slots__ = ()
80+
7781
def __iter__(self):
7882
return iter(self.variables)
7983

@@ -165,6 +169,8 @@ def __exit__(self, exception_type, exception_value, traceback):
165169

166170

167171
class ArrayWriter:
172+
__slots__ = ("sources", "targets", "regions", "lock")
173+
168174
def __init__(self, lock=None):
169175
self.sources = []
170176
self.targets = []
@@ -205,6 +211,8 @@ def sync(self, compute=True):
205211

206212

207213
class AbstractWritableDataStore(AbstractDataStore):
214+
__slots__ = ()
215+
208216
def encode(self, variables, attributes):
209217
"""
210218
Encode the variables and attributes in this store
@@ -371,6 +379,8 @@ def set_dimensions(self, variables, unlimited_dims=None):
371379

372380

373381
class WritableCFDataStore(AbstractWritableDataStore):
382+
__slots__ = ()
383+
374384
def encode(self, variables, attributes):
375385
# All NetCDF files get CF encoded by default, without this attempting
376386
# to write times, for example, would fail.

xarray/backends/netCDF4_.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131

3232
class BaseNetCDF4Array(BackendArray):
33+
__slots__ = ("datastore", "dtype", "shape", "variable_name")
34+
3335
def __init__(self, variable_name, datastore):
3436
self.datastore = datastore
3537
self.variable_name = variable_name
@@ -52,8 +54,13 @@ def __setitem__(self, key, value):
5254
if self.datastore.autoclose:
5355
self.datastore.close(needs_lock=False)
5456

57+
def get_array(self, needs_lock=True):
58+
raise NotImplementedError("Virtual Method")
59+
5560

5661
class NetCDF4ArrayWrapper(BaseNetCDF4Array):
62+
__slots__ = ()
63+
5764
def get_array(self, needs_lock=True):
5865
ds = self.datastore._acquire(needs_lock)
5966
variable = ds.variables[self.variable_name]
@@ -294,6 +301,17 @@ class NetCDF4DataStore(WritableCFDataStore):
294301
This store supports NetCDF3, NetCDF4 and OpenDAP datasets.
295302
"""
296303

304+
__slots__ = (
305+
"autoclose",
306+
"format",
307+
"is_remote",
308+
"lock",
309+
"_filename",
310+
"_group",
311+
"_manager",
312+
"_mode",
313+
)
314+
297315
def __init__(
298316
self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False
299317
):

xarray/backends/zarr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def _encode_zarr_attr_value(value):
2929

3030

3131
class ZarrArrayWrapper(BackendArray):
32+
__slots__ = ("datastore", "dtype", "shape", "variable_name")
33+
3234
def __init__(self, variable_name, datastore):
3335
self.datastore = datastore
3436
self.variable_name = variable_name
@@ -231,6 +233,15 @@ class ZarrStore(AbstractWritableDataStore):
231233
"""Store for reading and writing data via zarr
232234
"""
233235

236+
__slots__ = (
237+
"append_dim",
238+
"ds",
239+
"_consolidate_on_close",
240+
"_group",
241+
"_read_only",
242+
"_synchronizer",
243+
)
244+
234245
@classmethod
235246
def open_group(
236247
cls,

xarray/conventions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
3131
dtype('int16')
3232
"""
3333

34+
__slots__ = ("array",)
35+
3436
def __init__(self, array):
3537
self.array = indexing.as_indexable(array)
3638

@@ -60,6 +62,8 @@ class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
6062
dtype('bool')
6163
"""
6264

65+
__slots__ = ("array",)
66+
6367
def __init__(self, array):
6468
self.array = indexing.as_indexable(array)
6569

xarray/core/accessor_str.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class StringAccessor:
7575
7676
"""
7777

78+
__slots__ = ("_obj",)
79+
7880
def __init__(self, obj):
7981
self._obj = obj
8082

xarray/core/arithmetic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class SupportsArithmetic:
1414
Used by Dataset, DataArray, Variable and GroupBy.
1515
"""
1616

17+
__slots__ = ()
18+
1719
# TODO: implement special methods for arithmetic here rather than injecting
1820
# them in xarray/core/ops.py. Ideally, do so by inheriting from
1921
# numpy.lib.mixins.NDArrayOperatorsMixin.

xarray/core/common.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections import OrderedDict
23
from contextlib import suppress
34
from textwrap import dedent
@@ -35,6 +36,8 @@
3536

3637

3738
class ImplementsArrayReduce:
39+
__slots__ = ()
40+
3841
@classmethod
3942
def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
4043
if include_skipna:
@@ -72,6 +75,8 @@ def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore
7275

7376

7477
class ImplementsDatasetReduce:
78+
__slots__ = ()
79+
7580
@classmethod
7681
def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
7782
if include_skipna:
@@ -110,6 +115,8 @@ class AbstractArray(ImplementsArrayReduce):
110115
"""Shared base class for DataArray and Variable.
111116
"""
112117

118+
__slots__ = ()
119+
113120
def __bool__(self: Any) -> bool:
114121
return bool(self.values)
115122

@@ -180,7 +187,25 @@ class AttrAccessMixin:
180187
"""Mixin class that allows getting keys with attribute access
181188
"""
182189

183-
_initialized = False
190+
__slots__ = ()
191+
192+
def __init_subclass__(cls):
193+
"""Verify that all subclasses explicitly define ``__slots__``. If they don't,
194+
raise error in the core xarray module and a FutureWarning in third-party
195+
extensions.
196+
This check is only triggered in Python 3.6+.
197+
"""
198+
if not hasattr(object.__new__(cls), "__dict__"):
199+
cls.__setattr__ = cls._setattr_slots
200+
elif cls.__module__.startswith("xarray."):
201+
raise AttributeError("%s must explicitly define __slots__" % cls.__name__)
202+
else:
203+
cls.__setattr__ = cls._setattr_dict
204+
warnings.warn(
205+
"xarray subclass %s should explicitly define __slots__" % cls.__name__,
206+
FutureWarning,
207+
stacklevel=2,
208+
)
184209

185210
@property
186211
def _attr_sources(self) -> List[Mapping[Hashable, Any]]:
@@ -195,7 +220,7 @@ def _item_sources(self) -> List[Mapping[Hashable, Any]]:
195220
return []
196221

197222
def __getattr__(self, name: str) -> Any:
198-
if name != "__setstate__":
223+
if name not in {"__dict__", "__setstate__"}:
199224
# this avoids an infinite loop when pickle looks for the
200225
# __setstate__ attribute before the xarray object is initialized
201226
for source in self._attr_sources:
@@ -205,20 +230,52 @@ def __getattr__(self, name: str) -> Any:
205230
"%r object has no attribute %r" % (type(self).__name__, name)
206231
)
207232

208-
def __setattr__(self, name: str, value: Any) -> None:
209-
if self._initialized:
210-
try:
211-
# Allow setting instance variables if they already exist
212-
# (e.g., _attrs). We use __getattribute__ instead of hasattr
213-
# to avoid key lookups with attribute-style access.
214-
self.__getattribute__(name)
215-
except AttributeError:
216-
raise AttributeError(
217-
"cannot set attribute %r on a %r object. Use __setitem__ "
218-
"style assignment (e.g., `ds['name'] = ...`) instead to "
219-
"assign variables." % (name, type(self).__name__)
220-
)
233+
# This complicated three-method design boosts overall performance of simple
234+
# operations - particularly DataArray methods that perform a _to_temp_dataset()
235+
# round-trip - by a whopping 8% compared to a single method that checks
236+
# hasattr(self, "__dict__") at runtime before every single assignment (like
237+
# _setattr_py35 does). All of this is just temporary until the FutureWarning can be
238+
# changed into a hard crash.
239+
def _setattr_dict(self, name: str, value: Any) -> None:
240+
"""Deprecated third party subclass (see ``__init_subclass__`` above)
241+
"""
221242
object.__setattr__(self, name, value)
243+
if name in self.__dict__:
244+
# Custom, non-slotted attr, or improperly assigned variable?
245+
warnings.warn(
246+
"Setting attribute %r on a %r object. Explicitly define __slots__ "
247+
"to suppress this warning for legitimate custom attributes and "
248+
"raise an error when attempting variables assignments."
249+
% (name, type(self).__name__),
250+
FutureWarning,
251+
stacklevel=2,
252+
)
253+
254+
def _setattr_slots(self, name: str, value: Any) -> None:
255+
"""Objects with ``__slots__`` raise AttributeError if you try setting an
256+
undeclared attribute. This is desirable, but the error message could use some
257+
improvement.
258+
"""
259+
try:
260+
object.__setattr__(self, name, value)
261+
except AttributeError as e:
262+
# Don't accidentally shadow custom AttributeErrors, e.g.
263+
# DataArray.dims.setter
264+
if str(e) != "%r object has no attribute %r" % (type(self).__name__, name):
265+
raise
266+
raise AttributeError(
267+
"cannot set attribute %r on a %r object. Use __setitem__ style"
268+
"assignment (e.g., `ds['name'] = ...`) instead of assigning variables."
269+
% (name, type(self).__name__)
270+
) from e
271+
272+
def _setattr_py35(self, name: str, value: Any) -> None:
273+
if hasattr(self, "__dict__"):
274+
return self._setattr_dict(name, value)
275+
return self._setattr_slots(name, value)
276+
277+
# Overridden in Python >=3.6 by __init_subclass__
278+
__setattr__ = _setattr_py35
222279

223280
def __dir__(self) -> List[str]:
224281
"""Provide method name lookup and completion. Only provide 'public'
@@ -283,6 +340,8 @@ def get_squeeze_dims(
283340
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
284341
"""Shared base class for Dataset and DataArray."""
285342

343+
__slots__ = ()
344+
286345
_rolling_exp_cls = RollingExp
287346

288347
def squeeze(

xarray/core/computation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ class _UFuncSignature:
5151
Core dimension names on each output variable.
5252
"""
5353

54+
__slots__ = (
55+
"input_core_dims",
56+
"output_core_dims",
57+
"_all_input_core_dims",
58+
"_all_output_core_dims",
59+
"_all_core_dims",
60+
)
61+
5462
def __init__(self, input_core_dims, output_core_dims=((),)):
5563
self.input_core_dims = tuple(tuple(a) for a in input_core_dims)
5664
self.output_core_dims = tuple(tuple(a) for a in output_core_dims)

xarray/core/coordinates.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
38-
_data = None # type: Union["DataArray", "Dataset"]
38+
__slots__ = ()
3939

4040
def __getitem__(self, key: Hashable) -> "DataArray":
4141
raise NotImplementedError()
@@ -53,7 +53,7 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
5353

5454
@property
5555
def indexes(self) -> Indexes:
56-
return self._data.indexes
56+
return self._data.indexes # type: ignore
5757

5858
@property
5959
def variables(self):
@@ -108,9 +108,9 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
108108
raise ValueError("no valid index for a 0-dimensional object")
109109
elif len(ordered_dims) == 1:
110110
(dim,) = ordered_dims
111-
return self._data.get_index(dim)
111+
return self._data.get_index(dim) # type: ignore
112112
else:
113-
indexes = [self._data.get_index(k) for k in ordered_dims]
113+
indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore
114114
names = list(ordered_dims)
115115
return pd.MultiIndex.from_product(indexes, names=names)
116116

@@ -187,7 +187,7 @@ class DatasetCoordinates(AbstractCoordinates):
187187
objects.
188188
"""
189189

190-
_data = None # type: Dataset
190+
__slots__ = ("_data",)
191191

192192
def __init__(self, dataset: "Dataset"):
193193
self._data = dataset
@@ -258,7 +258,7 @@ class DataArrayCoordinates(AbstractCoordinates):
258258
dimensions and the values given by corresponding DataArray objects.
259259
"""
260260

261-
_data = None # type: DataArray
261+
__slots__ = ("_data",)
262262

263263
def __init__(self, dataarray: "DataArray"):
264264
self._data = dataarray
@@ -314,6 +314,8 @@ class LevelCoordinatesSource(Mapping[Hashable, Any]):
314314
by any public methods.
315315
"""
316316

317+
__slots__ = ("_data",)
318+
317319
def __init__(self, data_object: "Union[DataArray, Dataset]"):
318320
self._data = data_object
319321

0 commit comments

Comments
 (0)