From d443cb0c57cf6824811435af92a89c6919af1dd4 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 4 Jan 2021 14:40:47 -0800 Subject: [PATCH 1/3] REF: tm._network --- pandas/_testing/__init__.py | 248 +----------------------------------- pandas/_testing/_network.py | 248 ++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 247 deletions(-) create mode 100644 pandas/_testing/_network.py diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 0591fc6afd633..c57b466f83170 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -55,6 +55,7 @@ Series, bdate_range, ) +from pandas._testing._network import network, with_connectivity_check # noqa:F401 from pandas._testing.asserters import ( # noqa:F401 assert_almost_equal, assert_attr_equal, @@ -89,13 +90,10 @@ ) from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray, period_array -from pandas.io.common import urlopen - lzma = import_lzma() _N = 30 _K = 4 -_RAISE_NETWORK_ERROR_DEFAULT = False UNSIGNED_INT_DTYPES: List[Dtype] = ["uint8", "uint16", "uint32", "uint64"] UNSIGNED_EA_INT_DTYPES: List[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"] @@ -935,250 +933,6 @@ def makeMissingDataframe(density=0.9, random_state=None): return df -def optional_args(decorator): - """ - allows a decorator to take optional positional and keyword arguments. - Assumes that taking a single, callable, positional argument means that - it is decorating a function, i.e. something like this:: - - @my_decorator - def function(): pass - - Calls decorator with decorator(f, *args, **kwargs) - """ - - @wraps(decorator) - def wrapper(*args, **kwargs): - def dec(f): - return decorator(f, *args, **kwargs) - - is_decorating = not kwargs and len(args) == 1 and callable(args[0]) - if is_decorating: - f = args[0] - # pandas\_testing.py:2331: error: Incompatible types in assignment - # (expression has type "List[]", variable has type - # "Tuple[Any, ...]") - args = [] # type: ignore[assignment] - return dec(f) - else: - return dec - - return wrapper - - -# skip tests on exceptions with this message -_network_error_messages = ( - # 'urlopen error timed out', - # 'timeout: timed out', - # 'socket.timeout: timed out', - "timed out", - "Server Hangup", - "HTTP Error 503: Service Unavailable", - "502: Proxy Error", - "HTTP Error 502: internal error", - "HTTP Error 502", - "HTTP Error 503", - "HTTP Error 403", - "HTTP Error 400", - "Temporary failure in name resolution", - "Name or service not known", - "Connection refused", - "certificate verify", -) - -# or this e.errno/e.reason.errno -_network_errno_vals = ( - 101, # Network is unreachable - 111, # Connection refused - 110, # Connection timed out - 104, # Connection reset Error - 54, # Connection reset by peer - 60, # urllib.error.URLError: [Errno 60] Connection timed out -) - -# Both of the above shouldn't mask real issues such as 404's -# or refused connections (changed DNS). -# But some tests (test_data yahoo) contact incredibly flakey -# servers. - -# and conditionally raise on exception types in _get_default_network_errors - - -def _get_default_network_errors(): - # Lazy import for http.client because it imports many things from the stdlib - import http.client - - return (IOError, http.client.HTTPException, TimeoutError) - - -def can_connect(url, error_classes=None): - """ - Try to connect to the given url. True if succeeds, False if IOError - raised - - Parameters - ---------- - url : basestring - The URL to try to connect to - - Returns - ------- - connectable : bool - Return True if no IOError (unable to connect) or URLError (bad url) was - raised - """ - if error_classes is None: - error_classes = _get_default_network_errors() - - try: - with urlopen(url): - pass - except error_classes: - return False - else: - return True - - -@optional_args -def network( - t, - url="https://www.google.com", - raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT, - check_before_test=False, - error_classes=None, - skip_errnos=_network_errno_vals, - _skip_on_messages=_network_error_messages, -): - """ - Label a test as requiring network connection and, if an error is - encountered, only raise if it does not find a network connection. - - In comparison to ``network``, this assumes an added contract to your test: - you must assert that, under normal conditions, your test will ONLY fail if - it does not have network connectivity. - - You can call this in 3 ways: as a standard decorator, with keyword - arguments, or with a positional argument that is the url to check. - - Parameters - ---------- - t : callable - The test requiring network connectivity. - url : path - The url to test via ``pandas.io.common.urlopen`` to check - for connectivity. Defaults to 'https://www.google.com'. - raise_on_error : bool - If True, never catches errors. - check_before_test : bool - If True, checks connectivity before running the test case. - error_classes : tuple or Exception - error classes to ignore. If not in ``error_classes``, raises the error. - defaults to IOError. Be careful about changing the error classes here. - skip_errnos : iterable of int - Any exception that has .errno or .reason.erno set to one - of these values will be skipped with an appropriate - message. - _skip_on_messages: iterable of string - any exception e for which one of the strings is - a substring of str(e) will be skipped with an appropriate - message. Intended to suppress errors where an errno isn't available. - - Notes - ----- - * ``raise_on_error`` supersedes ``check_before_test`` - - Returns - ------- - t : callable - The decorated test ``t``, with checks for connectivity errors. - - Example - ------- - - Tests decorated with @network will fail if it's possible to make a network - connection to another URL (defaults to google.com):: - - >>> from pandas._testing import network - >>> from pandas.io.common import urlopen - >>> @network - ... def test_network(): - ... with urlopen("rabbit://bonanza.com"): - ... pass - Traceback - ... - URLError: - - You can specify alternative URLs:: - - >>> @network("https://www.yahoo.com") - ... def test_something_with_yahoo(): - ... raise IOError("Failure Message") - >>> test_something_with_yahoo() - Traceback (most recent call last): - ... - IOError: Failure Message - - If you set check_before_test, it will check the url first and not run the - test on failure:: - - >>> @network("failing://url.blaher", check_before_test=True) - ... def test_something(): - ... print("I ran!") - ... raise ValueError("Failure") - >>> test_something() - Traceback (most recent call last): - ... - - Errors not related to networking will always be raised. - """ - from pytest import skip - - if error_classes is None: - error_classes = _get_default_network_errors() - - t.network = True - - @wraps(t) - def wrapper(*args, **kwargs): - if ( - check_before_test - and not raise_on_error - and not can_connect(url, error_classes) - ): - skip() - try: - return t(*args, **kwargs) - except Exception as err: - errno = getattr(err, "errno", None) - if not errno and hasattr(errno, "reason"): - # pandas\_testing.py:2521: error: "Exception" has no attribute - # "reason" - errno = getattr(err.reason, "errno", None) # type: ignore[attr-defined] - - if errno in skip_errnos: - skip(f"Skipping test due to known errno and error {err}") - - e_str = str(err) - - if any(m.lower() in e_str.lower() for m in _skip_on_messages): - skip( - f"Skipping test because exception message is known and error {err}" - ) - - if not isinstance(err, error_classes): - raise - - if raise_on_error or can_connect(url, error_classes): - raise - else: - skip(f"Skipping test due to lack of connectivity and error {err}") - - return wrapper - - -with_connectivity_check = network - - @contextmanager def assert_produces_warning( expected_warning: Optional[Union[Type[Warning], bool]] = Warning, diff --git a/pandas/_testing/_network.py b/pandas/_testing/_network.py new file mode 100644 index 0000000000000..c60e4426b695c --- /dev/null +++ b/pandas/_testing/_network.py @@ -0,0 +1,248 @@ +from functools import wraps + +from pandas.io.common import urlopen + +_RAISE_NETWORK_ERROR_DEFAULT = False + +# skip tests on exceptions with these messages +_network_error_messages = ( + # 'urlopen error timed out', + # 'timeout: timed out', + # 'socket.timeout: timed out', + "timed out", + "Server Hangup", + "HTTP Error 503: Service Unavailable", + "502: Proxy Error", + "HTTP Error 502: internal error", + "HTTP Error 502", + "HTTP Error 503", + "HTTP Error 403", + "HTTP Error 400", + "Temporary failure in name resolution", + "Name or service not known", + "Connection refused", + "certificate verify", +) + +# or this e.errno/e.reason.errno +_network_errno_vals = ( + 101, # Network is unreachable + 111, # Connection refused + 110, # Connection timed out + 104, # Connection reset Error + 54, # Connection reset by peer + 60, # urllib.error.URLError: [Errno 60] Connection timed out +) + +# Both of the above shouldn't mask real issues such as 404's +# or refused connections (changed DNS). +# But some tests (test_data yahoo) contact incredibly flakey +# servers. + +# and conditionally raise on exception types in _get_default_network_errors + + +def _get_default_network_errors(): + # Lazy import for http.client because it imports many things from the stdlib + import http.client + + return (IOError, http.client.HTTPException, TimeoutError) + + +def optional_args(decorator): + """ + allows a decorator to take optional positional and keyword arguments. + Assumes that taking a single, callable, positional argument means that + it is decorating a function, i.e. something like this:: + + @my_decorator + def function(): pass + + Calls decorator with decorator(f, *args, **kwargs) + """ + + @wraps(decorator) + def wrapper(*args, **kwargs): + def dec(f): + return decorator(f, *args, **kwargs) + + is_decorating = not kwargs and len(args) == 1 and callable(args[0]) + if is_decorating: + f = args[0] + # pandas\_testing.py:2331: error: Incompatible types in assignment + # (expression has type "List[]", variable has type + # "Tuple[Any, ...]") + args = [] # type: ignore[assignment] + return dec(f) + else: + return dec + + return wrapper + + +@optional_args +def network( + t, + url="https://www.google.com", + raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT, + check_before_test=False, + error_classes=None, + skip_errnos=_network_errno_vals, + _skip_on_messages=_network_error_messages, +): + """ + Label a test as requiring network connection and, if an error is + encountered, only raise if it does not find a network connection. + + In comparison to ``network``, this assumes an added contract to your test: + you must assert that, under normal conditions, your test will ONLY fail if + it does not have network connectivity. + + You can call this in 3 ways: as a standard decorator, with keyword + arguments, or with a positional argument that is the url to check. + + Parameters + ---------- + t : callable + The test requiring network connectivity. + url : path + The url to test via ``pandas.io.common.urlopen`` to check + for connectivity. Defaults to 'https://www.google.com'. + raise_on_error : bool + If True, never catches errors. + check_before_test : bool + If True, checks connectivity before running the test case. + error_classes : tuple or Exception + error classes to ignore. If not in ``error_classes``, raises the error. + defaults to IOError. Be careful about changing the error classes here. + skip_errnos : iterable of int + Any exception that has .errno or .reason.erno set to one + of these values will be skipped with an appropriate + message. + _skip_on_messages: iterable of string + any exception e for which one of the strings is + a substring of str(e) will be skipped with an appropriate + message. Intended to suppress errors where an errno isn't available. + + Notes + ----- + * ``raise_on_error`` supersedes ``check_before_test`` + + Returns + ------- + t : callable + The decorated test ``t``, with checks for connectivity errors. + + Example + ------- + + Tests decorated with @network will fail if it's possible to make a network + connection to another URL (defaults to google.com):: + + >>> from pandas._testing import network + >>> from pandas.io.common import urlopen + >>> @network + ... def test_network(): + ... with urlopen("rabbit://bonanza.com"): + ... pass + Traceback + ... + URLError: + + You can specify alternative URLs:: + + >>> @network("https://www.yahoo.com") + ... def test_something_with_yahoo(): + ... raise IOError("Failure Message") + >>> test_something_with_yahoo() + Traceback (most recent call last): + ... + IOError: Failure Message + + If you set check_before_test, it will check the url first and not run the + test on failure:: + + >>> @network("failing://url.blaher", check_before_test=True) + ... def test_something(): + ... print("I ran!") + ... raise ValueError("Failure") + >>> test_something() + Traceback (most recent call last): + ... + + Errors not related to networking will always be raised. + """ + from pytest import skip + + if error_classes is None: + error_classes = _get_default_network_errors() + + t.network = True + + @wraps(t) + def wrapper(*args, **kwargs): + if ( + check_before_test + and not raise_on_error + and not can_connect(url, error_classes) + ): + skip() + try: + return t(*args, **kwargs) + except Exception as err: + errno = getattr(err, "errno", None) + if not errno and hasattr(errno, "reason"): + # pandas\_testing.py:2521: error: "Exception" has no attribute + # "reason" + errno = getattr(err.reason, "errno", None) # type: ignore[attr-defined] + + if errno in skip_errnos: + skip(f"Skipping test due to known errno and error {err}") + + e_str = str(err) + + if any(m.lower() in e_str.lower() for m in _skip_on_messages): + skip( + f"Skipping test because exception message is known and error {err}" + ) + + if not isinstance(err, error_classes): + raise + + if raise_on_error or can_connect(url, error_classes): + raise + else: + skip(f"Skipping test due to lack of connectivity and error {err}") + + return wrapper + + +with_connectivity_check = network + + +def can_connect(url, error_classes=None): + """ + Try to connect to the given url. True if succeeds, False if IOError + raised + + Parameters + ---------- + url : basestring + The URL to try to connect to + + Returns + ------- + connectable : bool + Return True if no IOError (unable to connect) or URLError (bad url) was + raised + """ + if error_classes is None: + error_classes = _get_default_network_errors() + + try: + with urlopen(url): + pass + except error_classes: + return False + else: + return True From d04c56dd6632666cbd51094e265f6e7d7d59ffd5 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 4 Jan 2021 15:46:20 -0800 Subject: [PATCH 2/3] REF: implement pandas/_testing/_io, _random --- pandas/_testing/__init__.py | 248 ++---------------------- pandas/_testing/{_network.py => _io.py} | 157 +++++++++++++++ pandas/_testing/_random.py | 48 +++++ pandas/_testing/contexts.py | 31 +++ 4 files changed, 254 insertions(+), 230 deletions(-) rename pandas/_testing/{_network.py => _io.py} (64%) create mode 100644 pandas/_testing/_random.py diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index c57b466f83170..07d34218aa6e1 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -1,27 +1,13 @@ -import bz2 from collections import Counter from contextlib import contextmanager from datetime import datetime from functools import wraps -import gzip import operator import os import re import string -from typing import ( - Any, - Callable, - ContextManager, - List, - Optional, - Sequence, - Tuple, - Type, - Union, - cast, -) +from typing import Callable, ContextManager, List, Optional, Sequence, Type, Union, cast import warnings -import zipfile import numpy as np @@ -31,8 +17,7 @@ set_locale, ) -from pandas._typing import Dtype, FilePathOrBuffer, FrameOrSeries -from pandas.compat import get_lzma_file, import_lzma +from pandas._typing import Dtype from pandas.core.dtypes.common import ( is_datetime64_dtype, @@ -55,7 +40,21 @@ Series, bdate_range, ) -from pandas._testing._network import network, with_connectivity_check # noqa:F401 +from pandas._testing._io import ( # noqa:F401 + close, + network, + round_trip_localpath, + round_trip_pathlib, + round_trip_pickle, + with_connectivity_check, + write_to_compressed, +) +from pandas._testing._random import ( # noqa:F401 + randbool, + rands, + rands_array, + randu_array, +) from pandas._testing.asserters import ( # noqa:F401 assert_almost_equal, assert_attr_equal, @@ -80,6 +79,7 @@ raise_assert_detail, ) from pandas._testing.contexts import ( # noqa:F401 + RNGContext, decompress_file, ensure_clean, ensure_clean_dir, @@ -90,8 +90,6 @@ ) from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray, period_array -lzma = import_lzma() - _N = 30 _K = 4 @@ -168,187 +166,6 @@ def reset_display_options(): pd.reset_option("^display.", silent=True) -def round_trip_pickle( - obj: Any, path: Optional[FilePathOrBuffer] = None -) -> FrameOrSeries: - """ - Pickle an object and then read it again. - - Parameters - ---------- - obj : any object - The object to pickle and then re-read. - path : str, path object or file-like object, default None - The path where the pickled object is written and then read. - - Returns - ------- - pandas object - The original object that was pickled and then re-read. - """ - _path = path - if _path is None: - _path = f"__{rands(10)}__.pickle" - with ensure_clean(_path) as temp_path: - pd.to_pickle(obj, temp_path) - return pd.read_pickle(temp_path) - - -def round_trip_pathlib(writer, reader, path: Optional[str] = None): - """ - Write an object to file specified by a pathlib.Path and read it back - - Parameters - ---------- - writer : callable bound to pandas object - IO writing function (e.g. DataFrame.to_csv ) - reader : callable - IO reading function (e.g. pd.read_csv ) - path : str, default None - The path where the object is written and then read. - - Returns - ------- - pandas object - The original object that was serialized and then re-read. - """ - import pytest - - Path = pytest.importorskip("pathlib").Path - if path is None: - path = "___pathlib___" - with ensure_clean(path) as path: - writer(Path(path)) - obj = reader(Path(path)) - return obj - - -def round_trip_localpath(writer, reader, path: Optional[str] = None): - """ - Write an object to file specified by a py.path LocalPath and read it back. - - Parameters - ---------- - writer : callable bound to pandas object - IO writing function (e.g. DataFrame.to_csv ) - reader : callable - IO reading function (e.g. pd.read_csv ) - path : str, default None - The path where the object is written and then read. - - Returns - ------- - pandas object - The original object that was serialized and then re-read. - """ - import pytest - - LocalPath = pytest.importorskip("py.path").local - if path is None: - path = "___localpath___" - with ensure_clean(path) as path: - writer(LocalPath(path)) - obj = reader(LocalPath(path)) - return obj - - -def write_to_compressed(compression, path, data, dest="test"): - """ - Write data to a compressed file. - - Parameters - ---------- - compression : {'gzip', 'bz2', 'zip', 'xz'} - The compression type to use. - path : str - The file path to write the data. - data : str - The data to write. - dest : str, default "test" - The destination file (for ZIP only) - - Raises - ------ - ValueError : An invalid compression value was passed in. - """ - args: Tuple[Any, ...] = (data,) - mode = "wb" - method = "write" - compress_method: Callable - - if compression == "zip": - compress_method = zipfile.ZipFile - mode = "w" - args = (dest, data) - method = "writestr" - elif compression == "gzip": - compress_method = gzip.GzipFile - elif compression == "bz2": - compress_method = bz2.BZ2File - elif compression == "xz": - compress_method = get_lzma_file(lzma) - else: - raise ValueError(f"Unrecognized compression type: {compression}") - - with compress_method(path, mode=mode) as f: - getattr(f, method)(*args) - - -def randbool(size=(), p: float = 0.5): - return np.random.rand(*size) <= p - - -RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1)) -RANDU_CHARS = np.array( - list("".join(map(chr, range(1488, 1488 + 26))) + string.digits), - dtype=(np.unicode_, 1), -) - - -def rands_array(nchars, size, dtype="O"): - """ - Generate an array of byte strings. - """ - retval = ( - np.random.choice(RANDS_CHARS, size=nchars * np.prod(size)) - .view((np.str_, nchars)) - .reshape(size) - ) - return retval.astype(dtype) - - -def randu_array(nchars, size, dtype="O"): - """ - Generate an array of unicode strings. - """ - retval = ( - np.random.choice(RANDU_CHARS, size=nchars * np.prod(size)) - .view((np.unicode_, nchars)) - .reshape(size) - ) - return retval.astype(dtype) - - -def rands(nchars): - """ - Generate one random byte string. - - See `rands_array` if you want to create an array of random strings. - - """ - return "".join(np.random.choice(RANDS_CHARS, nchars)) - - -def close(fignum=None): - from matplotlib.pyplot import close as _close, get_fignums - - if fignum is None: - for fignum in get_fignums(): - _close(fignum) - else: - _close(fignum) - - # ----------------------------------------------------------------------------- # Comparators @@ -1103,35 +920,6 @@ def _assert_raised_with_correct_stacklevel( assert actual_warning.filename == caller.filename, msg -class RNGContext: - """ - Context manager to set the numpy random number generator speed. Returns - to the original value upon exiting the context manager. - - Parameters - ---------- - seed : int - Seed for numpy.random.seed - - Examples - -------- - with RNGContext(42): - np.random.randn() - """ - - def __init__(self, seed): - self.seed = seed - - def __enter__(self): - - self.start_state = np.random.get_state() - np.random.seed(self.seed) - - def __exit__(self, exc_type, exc_value, traceback): - - np.random.set_state(self.start_state) - - def test_parallel(num_threads=2, kwargs_list=None): """ Decorator to run the same function multiple times in parallel. diff --git a/pandas/_testing/_network.py b/pandas/_testing/_io.py similarity index 64% rename from pandas/_testing/_network.py rename to pandas/_testing/_io.py index c60e4426b695c..5f27b016b68a2 100644 --- a/pandas/_testing/_network.py +++ b/pandas/_testing/_io.py @@ -1,9 +1,22 @@ +import bz2 from functools import wraps +import gzip +from typing import Any, Callable, Optional, Tuple +import zipfile + +from pandas._typing import FilePathOrBuffer, FrameOrSeries +from pandas.compat import get_lzma_file, import_lzma + +import pandas as pd +from pandas._testing._random import rands +from pandas._testing.contexts import ensure_clean from pandas.io.common import urlopen _RAISE_NETWORK_ERROR_DEFAULT = False +lzma = import_lzma() + # skip tests on exceptions with these messages _network_error_messages = ( # 'urlopen error timed out', @@ -246,3 +259,147 @@ def can_connect(url, error_classes=None): return False else: return True + + +# ------------------------------------------------------------------ +# File-IO + + +def round_trip_pickle( + obj: Any, path: Optional[FilePathOrBuffer] = None +) -> FrameOrSeries: + """ + Pickle an object and then read it again. + + Parameters + ---------- + obj : any object + The object to pickle and then re-read. + path : str, path object or file-like object, default None + The path where the pickled object is written and then read. + + Returns + ------- + pandas object + The original object that was pickled and then re-read. + """ + _path = path + if _path is None: + _path = f"__{rands(10)}__.pickle" + with ensure_clean(_path) as temp_path: + pd.to_pickle(obj, temp_path) + return pd.read_pickle(temp_path) + + +def round_trip_pathlib(writer, reader, path: Optional[str] = None): + """ + Write an object to file specified by a pathlib.Path and read it back + + Parameters + ---------- + writer : callable bound to pandas object + IO writing function (e.g. DataFrame.to_csv ) + reader : callable + IO reading function (e.g. pd.read_csv ) + path : str, default None + The path where the object is written and then read. + + Returns + ------- + pandas object + The original object that was serialized and then re-read. + """ + import pytest + + Path = pytest.importorskip("pathlib").Path + if path is None: + path = "___pathlib___" + with ensure_clean(path) as path: + writer(Path(path)) + obj = reader(Path(path)) + return obj + + +def round_trip_localpath(writer, reader, path: Optional[str] = None): + """ + Write an object to file specified by a py.path LocalPath and read it back. + + Parameters + ---------- + writer : callable bound to pandas object + IO writing function (e.g. DataFrame.to_csv ) + reader : callable + IO reading function (e.g. pd.read_csv ) + path : str, default None + The path where the object is written and then read. + + Returns + ------- + pandas object + The original object that was serialized and then re-read. + """ + import pytest + + LocalPath = pytest.importorskip("py.path").local + if path is None: + path = "___localpath___" + with ensure_clean(path) as path: + writer(LocalPath(path)) + obj = reader(LocalPath(path)) + return obj + + +def write_to_compressed(compression, path, data, dest="test"): + """ + Write data to a compressed file. + + Parameters + ---------- + compression : {'gzip', 'bz2', 'zip', 'xz'} + The compression type to use. + path : str + The file path to write the data. + data : str + The data to write. + dest : str, default "test" + The destination file (for ZIP only) + + Raises + ------ + ValueError : An invalid compression value was passed in. + """ + args: Tuple[Any, ...] = (data,) + mode = "wb" + method = "write" + compress_method: Callable + + if compression == "zip": + compress_method = zipfile.ZipFile + mode = "w" + args = (dest, data) + method = "writestr" + elif compression == "gzip": + compress_method = gzip.GzipFile + elif compression == "bz2": + compress_method = bz2.BZ2File + elif compression == "xz": + compress_method = get_lzma_file(lzma) + else: + raise ValueError(f"Unrecognized compression type: {compression}") + + with compress_method(path, mode=mode) as f: + getattr(f, method)(*args) + + +# ------------------------------------------------------------------ +# Plotting + + +def close(fignum=None): + from matplotlib.pyplot import close as _close, get_fignums + + if fignum is None: + for fignum in get_fignums(): + _close(fignum) + else: + _close(fignum) diff --git a/pandas/_testing/_random.py b/pandas/_testing/_random.py new file mode 100644 index 0000000000000..a646d7639a4e6 --- /dev/null +++ b/pandas/_testing/_random.py @@ -0,0 +1,48 @@ +import string + +import numpy as np + + +def randbool(size=(), p: float = 0.5): + return np.random.rand(*size) <= p + + +RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1)) +RANDU_CHARS = np.array( + list("".join(map(chr, range(1488, 1488 + 26))) + string.digits), + dtype=(np.unicode_, 1), +) + + +def rands_array(nchars, size, dtype="O"): + """ + Generate an array of byte strings. + """ + retval = ( + np.random.choice(RANDS_CHARS, size=nchars * np.prod(size)) + .view((np.str_, nchars)) + .reshape(size) + ) + return retval.astype(dtype) + + +def randu_array(nchars, size, dtype="O"): + """ + Generate an array of unicode strings. + """ + retval = ( + np.random.choice(RANDU_CHARS, size=nchars * np.prod(size)) + .view((np.unicode_, nchars)) + .reshape(size) + ) + return retval.astype(dtype) + + +def rands(nchars): + """ + Generate one random byte string. + + See `rands_array` if you want to create an array of random strings. + + """ + return "".join(np.random.choice(RANDS_CHARS, nchars)) diff --git a/pandas/_testing/contexts.py b/pandas/_testing/contexts.py index d6a4b47571653..d72dc8c3af104 100644 --- a/pandas/_testing/contexts.py +++ b/pandas/_testing/contexts.py @@ -3,6 +3,8 @@ from shutil import rmtree import tempfile +import numpy as np + from pandas.io.common import get_handle @@ -214,3 +216,32 @@ def use_numexpr(use, min_elements=None): yield expr._MIN_ELEMENTS = oldmin expr.set_use_numexpr(olduse) + + +class RNGContext: + """ + Context manager to set the numpy random number generator speed. Returns + to the original value upon exiting the context manager. + + Parameters + ---------- + seed : int + Seed for numpy.random.seed + + Examples + -------- + with RNGContext(42): + np.random.randn() + """ + + def __init__(self, seed): + self.seed = seed + + def __enter__(self): + + self.start_state = np.random.get_state() + np.random.seed(self.seed) + + def __exit__(self, exc_type, exc_value, traceback): + + np.random.set_state(self.start_state) From 17f86c19d7c071f3e7dd12eedf7612402d412be7 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 4 Jan 2021 15:50:51 -0800 Subject: [PATCH 3/3] REF: implement tm._warnings --- pandas/_testing/__init__.py | 174 +---------------------------------- pandas/_testing/_warnings.py | 174 +++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 172 deletions(-) create mode 100644 pandas/_testing/_warnings.py diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 07d34218aa6e1..c51ceb750c338 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -1,12 +1,11 @@ from collections import Counter -from contextlib import contextmanager from datetime import datetime from functools import wraps import operator import os import re import string -from typing import Callable, ContextManager, List, Optional, Sequence, Type, Union, cast +from typing import Callable, ContextManager, List, Type import warnings import numpy as np @@ -55,6 +54,7 @@ rands_array, randu_array, ) +from pandas._testing._warnings import assert_produces_warning # noqa:F401 from pandas._testing.asserters import ( # noqa:F401 assert_almost_equal, assert_attr_equal, @@ -750,176 +750,6 @@ def makeMissingDataframe(density=0.9, random_state=None): return df -@contextmanager -def assert_produces_warning( - expected_warning: Optional[Union[Type[Warning], bool]] = Warning, - filter_level="always", - check_stacklevel: bool = True, - raise_on_extra_warnings: bool = True, - match: Optional[str] = None, -): - """ - Context manager for running code expected to either raise a specific - warning, or not raise any warnings. Verifies that the code raises the - expected warning, and that it does not raise any other unexpected - warnings. It is basically a wrapper around ``warnings.catch_warnings``. - - Parameters - ---------- - expected_warning : {Warning, False, None}, default Warning - The type of Exception raised. ``exception.Warning`` is the base - class for all warnings. To check that no warning is returned, - specify ``False`` or ``None``. - filter_level : str or None, default "always" - Specifies whether warnings are ignored, displayed, or turned - into errors. - Valid values are: - - * "error" - turns matching warnings into exceptions - * "ignore" - discard the warning - * "always" - always emit a warning - * "default" - print the warning the first time it is generated - from each location - * "module" - print the warning the first time it is generated - from each module - * "once" - print the warning the first time it is generated - - check_stacklevel : bool, default True - If True, displays the line that called the function containing - the warning to show were the function is called. Otherwise, the - line that implements the function is displayed. - raise_on_extra_warnings : bool, default True - Whether extra warnings not of the type `expected_warning` should - cause the test to fail. - match : str, optional - Match warning message. - - Examples - -------- - >>> import warnings - >>> with assert_produces_warning(): - ... warnings.warn(UserWarning()) - ... - >>> with assert_produces_warning(False): - ... warnings.warn(RuntimeWarning()) - ... - Traceback (most recent call last): - ... - AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. - >>> with assert_produces_warning(UserWarning): - ... warnings.warn(RuntimeWarning()) - Traceback (most recent call last): - ... - AssertionError: Did not see expected warning of class 'UserWarning'. - - ..warn:: This is *not* thread-safe. - """ - __tracebackhide__ = True - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter(filter_level) - yield w - - if expected_warning: - expected_warning = cast(Type[Warning], expected_warning) - _assert_caught_expected_warning( - caught_warnings=w, - expected_warning=expected_warning, - match=match, - check_stacklevel=check_stacklevel, - ) - - if raise_on_extra_warnings: - _assert_caught_no_extra_warnings( - caught_warnings=w, - expected_warning=expected_warning, - ) - - -def _assert_caught_expected_warning( - *, - caught_warnings: Sequence[warnings.WarningMessage], - expected_warning: Type[Warning], - match: Optional[str], - check_stacklevel: bool, -) -> None: - """Assert that there was the expected warning among the caught warnings.""" - saw_warning = False - matched_message = False - - for actual_warning in caught_warnings: - if issubclass(actual_warning.category, expected_warning): - saw_warning = True - - if check_stacklevel and issubclass( - actual_warning.category, (FutureWarning, DeprecationWarning) - ): - _assert_raised_with_correct_stacklevel(actual_warning) - - if match is not None and re.search(match, str(actual_warning.message)): - matched_message = True - - if not saw_warning: - raise AssertionError( - f"Did not see expected warning of class " - f"{repr(expected_warning.__name__)}" - ) - - if match and not matched_message: - raise AssertionError( - f"Did not see warning {repr(expected_warning.__name__)} " - f"matching {match}" - ) - - -def _assert_caught_no_extra_warnings( - *, - caught_warnings: Sequence[warnings.WarningMessage], - expected_warning: Optional[Union[Type[Warning], bool]], -) -> None: - """Assert that no extra warnings apart from the expected ones are caught.""" - extra_warnings = [] - - for actual_warning in caught_warnings: - if _is_unexpected_warning(actual_warning, expected_warning): - extra_warnings.append( - ( - actual_warning.category.__name__, - actual_warning.message, - actual_warning.filename, - actual_warning.lineno, - ) - ) - - if extra_warnings: - raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}") - - -def _is_unexpected_warning( - actual_warning: warnings.WarningMessage, - expected_warning: Optional[Union[Type[Warning], bool]], -) -> bool: - """Check if the actual warning issued is unexpected.""" - if actual_warning and not expected_warning: - return True - expected_warning = cast(Type[Warning], expected_warning) - return bool(not issubclass(actual_warning.category, expected_warning)) - - -def _assert_raised_with_correct_stacklevel( - actual_warning: warnings.WarningMessage, -) -> None: - from inspect import getframeinfo, stack - - caller = getframeinfo(stack()[4][0]) - msg = ( - "Warning not set with correct stacklevel. " - f"File where warning is raised: {actual_warning.filename} != " - f"{caller.filename}. Warning message: {actual_warning.message}" - ) - assert actual_warning.filename == caller.filename, msg - - def test_parallel(num_threads=2, kwargs_list=None): """ Decorator to run the same function multiple times in parallel. diff --git a/pandas/_testing/_warnings.py b/pandas/_testing/_warnings.py new file mode 100644 index 0000000000000..6429f74637f01 --- /dev/null +++ b/pandas/_testing/_warnings.py @@ -0,0 +1,174 @@ +from contextlib import contextmanager +import re +from typing import Optional, Sequence, Type, Union, cast +import warnings + + +@contextmanager +def assert_produces_warning( + expected_warning: Optional[Union[Type[Warning], bool]] = Warning, + filter_level="always", + check_stacklevel: bool = True, + raise_on_extra_warnings: bool = True, + match: Optional[str] = None, +): + """ + Context manager for running code expected to either raise a specific + warning, or not raise any warnings. Verifies that the code raises the + expected warning, and that it does not raise any other unexpected + warnings. It is basically a wrapper around ``warnings.catch_warnings``. + + Parameters + ---------- + expected_warning : {Warning, False, None}, default Warning + The type of Exception raised. ``exception.Warning`` is the base + class for all warnings. To check that no warning is returned, + specify ``False`` or ``None``. + filter_level : str or None, default "always" + Specifies whether warnings are ignored, displayed, or turned + into errors. + Valid values are: + + * "error" - turns matching warnings into exceptions + * "ignore" - discard the warning + * "always" - always emit a warning + * "default" - print the warning the first time it is generated + from each location + * "module" - print the warning the first time it is generated + from each module + * "once" - print the warning the first time it is generated + + check_stacklevel : bool, default True + If True, displays the line that called the function containing + the warning to show were the function is called. Otherwise, the + line that implements the function is displayed. + raise_on_extra_warnings : bool, default True + Whether extra warnings not of the type `expected_warning` should + cause the test to fail. + match : str, optional + Match warning message. + + Examples + -------- + >>> import warnings + >>> with assert_produces_warning(): + ... warnings.warn(UserWarning()) + ... + >>> with assert_produces_warning(False): + ... warnings.warn(RuntimeWarning()) + ... + Traceback (most recent call last): + ... + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. + >>> with assert_produces_warning(UserWarning): + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Did not see expected warning of class 'UserWarning'. + + ..warn:: This is *not* thread-safe. + """ + __tracebackhide__ = True + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter(filter_level) + yield w + + if expected_warning: + expected_warning = cast(Type[Warning], expected_warning) + _assert_caught_expected_warning( + caught_warnings=w, + expected_warning=expected_warning, + match=match, + check_stacklevel=check_stacklevel, + ) + + if raise_on_extra_warnings: + _assert_caught_no_extra_warnings( + caught_warnings=w, + expected_warning=expected_warning, + ) + + +def _assert_caught_expected_warning( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: Type[Warning], + match: Optional[str], + check_stacklevel: bool, +) -> None: + """Assert that there was the expected warning among the caught warnings.""" + saw_warning = False + matched_message = False + + for actual_warning in caught_warnings: + if issubclass(actual_warning.category, expected_warning): + saw_warning = True + + if check_stacklevel and issubclass( + actual_warning.category, (FutureWarning, DeprecationWarning) + ): + _assert_raised_with_correct_stacklevel(actual_warning) + + if match is not None and re.search(match, str(actual_warning.message)): + matched_message = True + + if not saw_warning: + raise AssertionError( + f"Did not see expected warning of class " + f"{repr(expected_warning.__name__)}" + ) + + if match and not matched_message: + raise AssertionError( + f"Did not see warning {repr(expected_warning.__name__)} " + f"matching {match}" + ) + + +def _assert_caught_no_extra_warnings( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: Optional[Union[Type[Warning], bool]], +) -> None: + """Assert that no extra warnings apart from the expected ones are caught.""" + extra_warnings = [] + + for actual_warning in caught_warnings: + if _is_unexpected_warning(actual_warning, expected_warning): + extra_warnings.append( + ( + actual_warning.category.__name__, + actual_warning.message, + actual_warning.filename, + actual_warning.lineno, + ) + ) + + if extra_warnings: + raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}") + + +def _is_unexpected_warning( + actual_warning: warnings.WarningMessage, + expected_warning: Optional[Union[Type[Warning], bool]], +) -> bool: + """Check if the actual warning issued is unexpected.""" + if actual_warning and not expected_warning: + return True + expected_warning = cast(Type[Warning], expected_warning) + return bool(not issubclass(actual_warning.category, expected_warning)) + + +def _assert_raised_with_correct_stacklevel( + actual_warning: warnings.WarningMessage, +) -> None: + from inspect import getframeinfo, stack + + caller = getframeinfo(stack()[4][0]) + msg = ( + "Warning not set with correct stacklevel. " + f"File where warning is raised: {actual_warning.filename} != " + f"{caller.filename}. Warning message: {actual_warning.message}" + ) + assert actual_warning.filename == caller.filename, msg