From dec604f52fe3f9a877dd2ff19eafe9c34f3b0062 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 21 Mar 2024 20:56:23 +0000 Subject: [PATCH 1/4] fix: plot.scatter c arguments functionalities --- bigframes/operations/_matplotlib/core.py | 70 +++++++++++++++++-- .../system/small/operations/test_plotting.py | 61 ++++++++++++++++ 2 files changed, 127 insertions(+), 4 deletions(-) diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 663e7a789f..05e960d3c1 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -14,11 +14,16 @@ import abc import typing +import uuid + +import matplotlib.pyplot as plt +import pandas as pd + +import bigframes.dtypes as dtypes DEFAULT_SAMPLING_N = 1000 DEFAULT_SAMPLING_STATE = 0 - class MPLPlot(abc.ABC): @abc.abstractmethod def generate(self): @@ -44,12 +49,13 @@ def _kind(self): def __init__(self, data, **kwargs) -> None: self.kwargs = kwargs - self.data = self._compute_plot_data(data) + self.data = data def generate(self) -> None: - self.axes = self.data.plot(kind=self._kind, **self.kwargs) + plot_data = self._compute_plot_data() + self.axes = plot_data.plot(kind=self._kind, **self.kwargs) - def _compute_plot_data(self, data): + def _compute_sample_data(self, data): # TODO: Cache the sampling data in the PlotAccessor. sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N) sampling_random_state = self.kwargs.pop( @@ -61,6 +67,9 @@ def _compute_plot_data(self, data): sort=False, ).to_pandas() + def _compute_plot_data(self): + return self._compute_sample_data(self.data) + class LinePlot(SamplingPlot): @property @@ -78,3 +87,56 @@ class ScatterPlot(SamplingPlot): @property def _kind(self) -> typing.Literal["scatter"]: return "scatter" + + def __init__(self, data, **kwargs) -> None: + super().__init__(data, **kwargs) + + c = self.kwargs.get("c", None) + if self._is_sequence_arg(c) and len(c) != self.data.shape[0]: + raise ValueError( + f"'c' argument has {len(c)} elements, which is " + + f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}" + ) + + def _compute_plot_data(self): + data = self.data.copy() + + c = self.kwargs.get("c", None) + c_id = None + if self._is_sequence_arg(c): + c_id = self._generate_new_column_name(data) + print(c_id) + data[c_id] = c + + sample = self._compute_sample_data(data) + + # Works around a pandas bug: + # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a + if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE: + sample[c] = sample[c].astype("object") + + if c_id is not None: + self.kwargs["c"] = sample[c_id] + sample = sample.drop(columns=[c_id]) + + return sample + + def _is_sequence_arg(self, arg): + return ( + arg is not None + and not isinstance(arg, str) + and isinstance(arg, typing.Iterable) + ) + + def _is_column_name(self, arg, data): + return ( + arg is not None + and pd.core.dtypes.common.is_hashable(arg) + and arg in data.columns + ) + + def _generate_new_column_name(self, data): + col_name = None + while col_name is None or col_name in data.columns: + col_name = f"plot_temp_{str(uuid.uuid4())[:8]}" + return col_name diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 5ca3382e2a..8edc8e12f6 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -209,6 +209,67 @@ def test_scatter(scalars_dfs): ) +@pytest.mark.parametrize( + ("c"), + [ + pytest.param("red", id="red"), + pytest.param("c", id="int_column"), + pytest.param("species", id="color_column"), + pytest.param(["red", "green", "blue"], id="color_sequence"), + pytest.param([3.4, 5.3, 2.0], id="number_sequence"), + pytest.param( + [3.4, 5.3], + id="length_mismatches_sequence", + marks=pytest.mark.xfail( + raises=ValueError, + ), + ), + ], +) +def test_scatter_args_c(c): + data = { + "a": [1, 2, 3], + "b": [1, 2, 3], + "c": [1, 2, 3], + "species": ["r", "g", "b"], + } + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + ax = df.plot.scatter(x="a", y="b", c=c) + pd_ax = pd_df.plot.scatter(x="a", y="b", c=c) + assert len(ax.collections[0].get_facecolor()) == len( + pd_ax.collections[0].get_facecolor() + ) + for idx in range(len(ax.collections[0].get_facecolor())): + tm.assert_numpy_array_equal( + ax.collections[0].get_facecolor()[idx], + pd_ax.collections[0].get_facecolor()[idx], + ) + + +def test_scatter_args_c_sampling(): + data = { + "plot_temp_0": [1, 2, 3, 4, 5], + "plot_temp_1": [5, 4, 3, 2, 1], + } + c = ["red", "green", "blue", "orange", "black"] + + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3) + pd_ax = pd_df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c) + assert len(ax.collections[0].get_facecolor()) == len( + pd_ax.collections[0].get_facecolor() + ) + for idx in range(len(ax.collections[0].get_facecolor())): + tm.assert_numpy_array_equal( + ax.collections[0].get_facecolor()[idx], + pd_ax.collections[0].get_facecolor()[idx], + ) + + def test_sampling_plot_args_n(): df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"]) ax = df.plot.line() From eefe402fd3eb8b182d1adb57e0228c669d5eca17 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 22 Mar 2024 19:17:22 +0000 Subject: [PATCH 2/4] fixing test errors --- bigframes/operations/_matplotlib/core.py | 3 +-- tests/system/small/operations/test_plotting.py | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 05e960d3c1..306face4b4 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -16,7 +16,6 @@ import typing import uuid -import matplotlib.pyplot as plt import pandas as pd import bigframes.dtypes as dtypes @@ -24,6 +23,7 @@ DEFAULT_SAMPLING_N = 1000 DEFAULT_SAMPLING_STATE = 0 + class MPLPlot(abc.ABC): @abc.abstractmethod def generate(self): @@ -105,7 +105,6 @@ def _compute_plot_data(self): c_id = None if self._is_sequence_arg(c): c_id = self._generate_new_column_name(data) - print(c_id) data[c_id] = c sample = self._compute_sample_data(data) diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 8edc8e12f6..6434ce4def 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -259,7 +259,11 @@ def test_scatter_args_c_sampling(): pd_df = pd.DataFrame(data) ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3) - pd_ax = pd_df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c) + + sampling_index = [0, 1, 2] + pd_ax = pd_df.iloc[sampling_index].plot.scatter( + x="plot_temp_0", y="plot_temp_1", c=[c[i] for i in sampling_index] + ) assert len(ax.collections[0].get_facecolor()) == len( pd_ax.collections[0].get_facecolor() ) From ed89fd8d4491bff42d8f6139dae0fabc04cf0c6c Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 22 Mar 2024 23:04:58 +0000 Subject: [PATCH 3/4] do not support sequence, and support column position --- bigframes/operations/_matplotlib/core.py | 28 ++++++--------- .../system/small/operations/test_plotting.py | 36 +------------------ .../pandas/plotting/_core.py | 4 --- 3 files changed, 12 insertions(+), 56 deletions(-) diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 306face4b4..1dfd5b34ca 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -18,6 +18,7 @@ import pandas as pd +import bigframes.constants as constants import bigframes.dtypes as dtypes DEFAULT_SAMPLING_N = 1000 @@ -92,32 +93,25 @@ def __init__(self, data, **kwargs) -> None: super().__init__(data, **kwargs) c = self.kwargs.get("c", None) - if self._is_sequence_arg(c) and len(c) != self.data.shape[0]: - raise ValueError( - f"'c' argument has {len(c)} elements, which is " - + f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}" + if self._is_sequence_arg(c): + raise NotImplementedError( + f"Only support a single color string or a column name/posision. {constants.FEEDBACK_LINK}" ) def _compute_plot_data(self): - data = self.data.copy() - - c = self.kwargs.get("c", None) - c_id = None - if self._is_sequence_arg(c): - c_id = self._generate_new_column_name(data) - data[c_id] = c - - sample = self._compute_sample_data(data) + sample = self._compute_sample_data(self.data) # Works around a pandas bug: # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a + c = self.kwargs.get("c", None) + if ( + pd.core.dtypes.common.is_integer(c) + and not self.data.columns._holds_integer() + ): + c = self.data.columns[c] if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE: sample[c] = sample[c].astype("object") - if c_id is not None: - self.kwargs["c"] = sample[c_id] - sample = sample.drop(columns=[c_id]) - return sample def _is_sequence_arg(self, arg): diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 6434ce4def..41ea7d4ebb 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -215,15 +215,7 @@ def test_scatter(scalars_dfs): pytest.param("red", id="red"), pytest.param("c", id="int_column"), pytest.param("species", id="color_column"), - pytest.param(["red", "green", "blue"], id="color_sequence"), - pytest.param([3.4, 5.3, 2.0], id="number_sequence"), - pytest.param( - [3.4, 5.3], - id="length_mismatches_sequence", - marks=pytest.mark.xfail( - raises=ValueError, - ), - ), + pytest.param(3, id="column_index"), ], ) def test_scatter_args_c(c): @@ -248,32 +240,6 @@ def test_scatter_args_c(c): ) -def test_scatter_args_c_sampling(): - data = { - "plot_temp_0": [1, 2, 3, 4, 5], - "plot_temp_1": [5, 4, 3, 2, 1], - } - c = ["red", "green", "blue", "orange", "black"] - - df = bpd.DataFrame(data) - pd_df = pd.DataFrame(data) - - ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3) - - sampling_index = [0, 1, 2] - pd_ax = pd_df.iloc[sampling_index].plot.scatter( - x="plot_temp_0", y="plot_temp_1", c=[c[i] for i in sampling_index] - ) - assert len(ax.collections[0].get_facecolor()) == len( - pd_ax.collections[0].get_facecolor() - ) - for idx in range(len(ax.collections[0].get_facecolor())): - tm.assert_numpy_array_equal( - ax.collections[0].get_facecolor()[idx], - pd_ax.collections[0].get_facecolor()[idx], - ) - - def test_sampling_plot_args_n(): df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"]) ax = df.plot.line() diff --git a/third_party/bigframes_vendored/pandas/plotting/_core.py b/third_party/bigframes_vendored/pandas/plotting/_core.py index d901f41ef8..f8da9efdc0 100644 --- a/third_party/bigframes_vendored/pandas/plotting/_core.py +++ b/third_party/bigframes_vendored/pandas/plotting/_core.py @@ -266,10 +266,6 @@ def scatter( - A single color string referred to by name, RGB or RGBA code, for instance 'red' or '#a98d19'. - - A sequence of color strings referred to by name, RGB or RGBA - code, which will be used for each point's color recursively. For - instance ['green','yellow'] all points will be filled in green or - yellow, alternatively. - A column name or position whose values will be used to color the marker points according to a colormap. From 9c4893abf45db3c47300c6ab67bab88b63cc286a Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Sun, 24 Mar 2024 17:59:16 +0000 Subject: [PATCH 4/4] fix integer column --- bigframes/operations/_matplotlib/core.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 1dfd5b34ca..ad5abb4bca 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -104,10 +104,7 @@ def _compute_plot_data(self): # Works around a pandas bug: # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a c = self.kwargs.get("c", None) - if ( - pd.core.dtypes.common.is_integer(c) - and not self.data.columns._holds_integer() - ): + if pd.core.dtypes.common.is_integer(c): c = self.data.columns[c] if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE: sample[c] = sample[c].astype("object")