Skip to content

Commit 933c386

Browse files
anntzermeeseeksmachine
authored andcommitted
Backport PR matplotlib#30114: Fix _is_tensorflow_array.
1 parent d72667f commit 933c386

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

lib/matplotlib/cbook.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,42 +2311,56 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23112311

23122312

23132313
def _is_torch_array(x):
2314-
"""Check if 'x' is a PyTorch Tensor."""
2314+
"""Return whether *x* is a PyTorch Tensor."""
23152315
try:
2316-
# we're intentionally not attempting to import torch. If somebody
2317-
# has created a torch array, torch should already be in sys.modules
2318-
return isinstance(x, sys.modules['torch'].Tensor)
2319-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2320-
# we're attempting to access attributes on imported modules which
2321-
# may have arbitrary user code, so we deliberately catch all exceptions
2322-
return False
2316+
# We're intentionally not attempting to import torch. If somebody
2317+
# has created a torch array, torch should already be in sys.modules.
2318+
tp = sys.modules.get("torch").Tensor
2319+
except AttributeError:
2320+
return False # Module not imported or a nonstandard module with no Tensor attr.
2321+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2322+
and isinstance(x, tp))
23232323

23242324

23252325
def _is_jax_array(x):
2326-
"""Check if 'x' is a JAX Array."""
2326+
"""Return whether *x* is a JAX Array."""
23272327
try:
2328-
# we're intentionally not attempting to import jax. If somebody
2329-
# has created a jax array, jax should already be in sys.modules
2330-
return isinstance(x, sys.modules['jax'].Array)
2331-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2332-
# we're attempting to access attributes on imported modules which
2333-
# may have arbitrary user code, so we deliberately catch all exceptions
2334-
return False
2328+
# We're intentionally not attempting to import jax. If somebody
2329+
# has created a jax array, jax should already be in sys.modules.
2330+
tp = sys.modules.get("jax").Array
2331+
except AttributeError:
2332+
return False # Module not imported or a nonstandard module with no Array attr.
2333+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2334+
and isinstance(x, tp))
2335+
2336+
2337+
def _is_pandas_dataframe(x):
2338+
"""Check if *x* is a Pandas DataFrame."""
2339+
try:
2340+
# We're intentionally not attempting to import Pandas. If somebody
2341+
# has created a Pandas DataFrame, Pandas should already be in sys.modules.
2342+
tp = sys.modules.get("pandas").DataFrame
2343+
except AttributeError:
2344+
return False # Module not imported or a nonstandard module with no Array attr.
2345+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2346+
and isinstance(x, tp))
23352347

23362348

23372349
def _is_tensorflow_array(x):
2338-
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2350+
"""Return whether *x* is a TensorFlow Tensor or Variable."""
23392351
try:
2340-
# we're intentionally not attempting to import TensorFlow. If somebody
2341-
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2342-
# we use `is_tensor` to not depend on the class structure of TensorFlow
2343-
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2344-
# (they both convert the same way)
2345-
return isinstance(x, sys.modules['tensorflow'].is_tensor(x))
2346-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2347-
# we're attempting to access attributes on imported modules which
2348-
# may have arbitrary user code, so we deliberately catch all exceptions
2352+
# We're intentionally not attempting to import TensorFlow. If somebody
2353+
# has created a TensorFlow array, TensorFlow should already be in
2354+
# sys.modules we use `is_tensor` to not depend on the class structure
2355+
# of TensorFlow arrays, as `tf.Variables` are not instances of
2356+
# `tf.Tensor` (they both convert the same way).
2357+
is_tensor = sys.modules.get("tensorflow").is_tensor
2358+
except AttributeError:
23492359
return False
2360+
try:
2361+
return is_tensor(x)
2362+
except Exception:
2363+
return False # Just in case it's a very nonstandard module.
23502364

23512365

23522366
def _unpack_to_numpy(x):
@@ -2401,15 +2415,3 @@ def _auto_format_str(fmt, value):
24012415
return fmt % (value,)
24022416
except (TypeError, ValueError):
24032417
return fmt.format(value)
2404-
2405-
2406-
def _is_pandas_dataframe(x):
2407-
"""Check if 'x' is a Pandas DataFrame."""
2408-
try:
2409-
# we're intentionally not attempting to import Pandas. If somebody
2410-
# has created a Pandas DataFrame, Pandas should already be in sys.modules
2411-
return isinstance(x, sys.modules['pandas'].DataFrame)
2412-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2413-
# we're attempting to access attributes on imported modules which
2414-
# may have arbitrary user code, so we deliberately catch all exceptions
2415-
return False

lib/matplotlib/tests/test_cbook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ def __array__(self):
983983
torch_tensor = torch.Tensor(data)
984984

985985
result = cbook._unpack_to_numpy(torch_tensor)
986+
assert isinstance(result, np.ndarray)
986987
# compare results, do not check for identity: the latter would fail
987988
# if not mocked, and the implementation does not guarantee it
988989
# is the same Python object, just the same values.
@@ -1011,6 +1012,7 @@ def __array__(self):
10111012
jax_array = jax.Array(data)
10121013

10131014
result = cbook._unpack_to_numpy(jax_array)
1015+
assert isinstance(result, np.ndarray)
10141016
# compare results, do not check for identity: the latter would fail
10151017
# if not mocked, and the implementation does not guarantee it
10161018
# is the same Python object, just the same values.
@@ -1040,6 +1042,7 @@ def __array__(self):
10401042
tf_tensor = tensorflow.Tensor(data)
10411043

10421044
result = cbook._unpack_to_numpy(tf_tensor)
1045+
assert isinstance(result, np.ndarray)
10431046
# compare results, do not check for identity: the latter would fail
10441047
# if not mocked, and the implementation does not guarantee it
10451048
# is the same Python object, just the same values.

0 commit comments

Comments
 (0)