diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 9768d9ea7d..bc6ae52564 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -62,6 +62,8 @@ get_print_options, print_options, set_print_options, + usm_ndarray_repr, + usm_ndarray_str, ) from dpctl.tensor._reshape import reshape from dpctl.tensor._usmarray import usm_ndarray @@ -137,4 +139,6 @@ "get_print_options", "set_print_options", "print_options", + "usm_ndarray_repr", + "usm_ndarray_str", ] diff --git a/dpctl/tensor/_print.py b/dpctl/tensor/_print.py index 914f555a77..15c06a875f 100644 --- a/dpctl/tensor/_print.py +++ b/dpctl/tensor/_print.py @@ -107,8 +107,8 @@ def set_print_options( ): """ set_print_options(linewidth=None, edgeitems=None, threshold=None, - precision=None, floatmode=None, suppress=None, nanstr=None, - infstr=None, sign=None, numpy=False) + precision=None, floatmode=None, suppress=None, + nanstr=None, infstr=None, sign=None, numpy=False) Set options for printing ``dpctl.tensor.usm_ndarray`` class. @@ -238,7 +238,7 @@ def _nd_corners(x, edge_items, slices=()): return _nd_corners(x, edge_items, slices + (slice(None, None, None),)) -def _usm_ndarray_str( +def usm_ndarray_str( x, line_width=None, edge_items=None, @@ -252,6 +252,72 @@ def _usm_ndarray_str( prefix="", suffix="", ): + """ + usm_ndarray_str(x, line_width=None, edgeitems=None, threshold=None, + precision=None, floatmode=None, suppress=None, + sign=None, numpy=False, separator=" ", prefix="", + suffix="") -> str + + Returns a string representing the elements of a + ``dpctl.tensor.usm_ndarray``. + + Args: + x (usm_ndarray): Input array. + line_width (int, optional): Number of characters printed per line. + Raises `TypeError` if line_width is not an integer. + Default: `75`. + edgeitems (int, optional): Number of elements at the beginning and end + when the printed array is abbreviated. + Raises `TypeError` if edgeitems is not an integer. + Default: `3`. + threshold (int, optional): Number of elements that triggers array + abbreviation. + Raises `TypeError` if threshold is not an integer. + Default: `1000`. + precision (int or None, optional): Number of digits printed for + floating point numbers. + Raises `TypeError` if precision is not an integer. + Default: `8`. + floatmode (str, optional): Controls how floating point + numbers are interpreted. + + `"fixed:`: Always prints exactly `precision` digits. + `"unique"`: Ignores precision, prints the number of + digits necessary to uniquely specify each number. + `"maxprec"`: Prints `precision` digits or fewer, + if fewer will uniquely represent a number. + `"maxprec_equal"`: Prints an equal number of digits + for each number. This number is `precision` digits or fewer, + if fewer will uniquely represent each number. + Raises `ValueError` if floatmode is not one of + `fixed`, `unique`, `maxprec`, or `maxprec_equal`. + Default: "maxprec_equal" + suppress (bool, optional): If `True,` numbers equal to zero + in the current precision will print as zero. + Default: `False`. + sign (str, optional): Controls the sign of floating point + numbers. + `"-"`: Omit the sign of positive numbers. + `"+"`: Always print the sign of positive numbers. + `" "`: Always print a whitespace in place of the + sign of positive numbers. + Raises `ValueError` if sign is not one of + `"-"`, `"+"`, or `" "`. + Default: `"-"`. + numpy (bool, optional): If `True,` then before other specified print + options are set, a dictionary of Numpy's print options + will be used to initialize dpctl's print options. + Default: "False" + separator (str, optional): String inserted between elements of + the array string. + Default: " " + prefix (str, optional): String used to determine spacing to the left + of the array string. + Default: "" + suffix (str, optional): String that determines length of the last line + of the array string. + Default: "" + """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") @@ -285,7 +351,33 @@ def _usm_ndarray_str( return s -def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None): +def usm_ndarray_repr( + x, line_width=None, precision=None, suppress=None, prefix="usm_ndarray" +): + """ + usm_ndarray_repr(x, line_width=None, precision=None, + suppress=None, prefix="") -> str + + Returns a formatted string representing the elements + of a ``dpctl.tensor.usm_ndarray`` and its data type, + if not a default type. + + Args: + x (usm_ndarray): Input array. + line_width (int, optional): Number of characters printed per line. + Raises `TypeError` if line_width is not an integer. + Default: `75`. + precision (int or None, optional): Number of digits printed for + floating point numbers. + Raises `TypeError` if precision is not an integer. + Default: `8`. + suppress (bool, optional): If `True,` numbers equal to zero + in the current precision will print as zero. + Default: `False`. + prefix (str, optional): String inserted at the start of the array + string. + Default: "" + """ if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") @@ -299,10 +391,10 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None): dpt.complex128, ] - prefix = "usm_ndarray(" + prefix = prefix + "(" suffix = ")" - s = _usm_ndarray_str( + s = usm_ndarray_str( x, line_width=line_width, precision=precision, diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 8c383b72c6..e2524b866d 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -26,7 +26,7 @@ import dpctl import dpctl.memory as dpmem from ._device import Device -from ._print import _usm_ndarray_repr, _usm_ndarray_str +from ._print import usm_ndarray_repr, usm_ndarray_str from cpython.mem cimport PyMem_Free from cpython.tuple cimport PyTuple_New, PyTuple_SetItem @@ -1145,10 +1145,10 @@ cdef class usm_ndarray: return self def __str__(self): - return _usm_ndarray_str(self) + return usm_ndarray_str(self) def __repr__(self): - return _usm_ndarray_repr(self) + return usm_ndarray_repr(self) cdef usm_ndarray _real_view(usm_ndarray ary): diff --git a/dpctl/tests/test_usm_ndarray_print.py b/dpctl/tests/test_usm_ndarray_print.py index 05a4a2b8a9..8d4e3d9b7f 100644 --- a/dpctl/tests/test_usm_ndarray_print.py +++ b/dpctl/tests/test_usm_ndarray_print.py @@ -48,6 +48,57 @@ def test_print_option_arg_validation(self, arg, err): with pytest.raises(err): dpt.set_print_options(**arg) + def test_usm_ndarray_repr_arg_validation(self): + X = dict() + with pytest.raises(TypeError): + dpt.usm_ndarray_repr(X) + + X = dpt.arange(4) + with pytest.raises(TypeError): + dpt.usm_ndarray_repr(X, line_width="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_repr(X, precision="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_repr(X, prefix=4) + + def test_usm_ndarray_str_arg_validation(self): + X = dict() + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X) + + X = dpt.arange(4) + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, line_width="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, edge_items="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, threshold="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, precision="I") + + with pytest.raises(ValueError): + dpt.usm_ndarray_str(X, floatmode="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, edge_items="I") + + with pytest.raises(ValueError): + dpt.usm_ndarray_str(X, sign="I") + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, prefix=4) + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, prefix=4) + + with pytest.raises(TypeError): + dpt.usm_ndarray_str(X, suffix=4) + class TestSetPrintOptions(TestPrint): def test_set_linewidth(self): @@ -188,6 +239,16 @@ def test_print_str_abbreviated(self): x = dpt.reshape(x, (3, 3)) assert str(x) == "[[0 ... 2]\n ...\n [6 ... 8]]" + def test_usm_ndarray_str_separator(self): + q = get_queue_or_skip() + + x = dpt.reshape(dpt.arange(4, sycl_queue=q), (2, 2)) + + np.testing.assert_equal( + dpt.usm_ndarray_str(x, prefix="test", separator=" "), + "[[0 1]\n [2 3]]", + ) + def test_print_repr(self): q = get_queue_or_skip() @@ -282,6 +343,19 @@ def test_repr_appended_dtype(self, dtype): x = dpt.empty(4, dtype=dtype) assert repr(x).split("=")[-1][:-1] == x.dtype.name + def test_usm_ndarray_repr_prefix(self): + q = get_queue_or_skip() + + x = dpt.arange(4, dtype=np.intp, sycl_queue=q) + np.testing.assert_equal( + dpt.usm_ndarray_repr(x, prefix="test"), "test([0, 1, 2, 3])" + ) + x = dpt.reshape(x, (2, 2)) + np.testing.assert_equal( + dpt.usm_ndarray_repr(x, prefix="test"), + "test([[0, 1]," "\n [2, 3]])", + ) + class TestContextManager: def test_context_manager_basic(self):