Skip to content

Commit ca3d36a

Browse files
committed
fix: plot.scatter c arguments functionalities
1 parent ad0e99e commit ca3d36a

File tree

2 files changed

+127
-4
lines changed

2 files changed

+127
-4
lines changed

bigframes/operations/_matplotlib/core.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414

1515
import abc
1616
import typing
17+
import uuid
18+
19+
import matplotlib.pyplot as plt
20+
import pandas as pd
21+
22+
import bigframes.dtypes as dtypes
1723

1824
DEFAULT_SAMPLING_N = 1000
1925
DEFAULT_SAMPLING_STATE = 0
2026

21-
2227
class MPLPlot(abc.ABC):
2328
@abc.abstractmethod
2429
def generate(self):
@@ -44,12 +49,13 @@ def _kind(self):
4449

4550
def __init__(self, data, **kwargs) -> None:
4651
self.kwargs = kwargs
47-
self.data = self._compute_plot_data(data)
52+
self.data = data
4853

4954
def generate(self) -> None:
50-
self.axes = self.data.plot(kind=self._kind, **self.kwargs)
55+
plot_data = self._compute_plot_data()
56+
self.axes = plot_data.plot(kind=self._kind, **self.kwargs)
5157

52-
def _compute_plot_data(self, data):
58+
def _compute_sample_data(self, data):
5359
# TODO: Cache the sampling data in the PlotAccessor.
5460
sampling_n = self.kwargs.pop("sampling_n", DEFAULT_SAMPLING_N)
5561
sampling_random_state = self.kwargs.pop(
@@ -61,6 +67,9 @@ def _compute_plot_data(self, data):
6167
sort=False,
6268
).to_pandas()
6369

70+
def _compute_plot_data(self):
71+
return self._compute_sample_data(self.data)
72+
6473

6574
class LinePlot(SamplingPlot):
6675
@property
@@ -78,3 +87,56 @@ class ScatterPlot(SamplingPlot):
7887
@property
7988
def _kind(self) -> typing.Literal["scatter"]:
8089
return "scatter"
90+
91+
def __init__(self, data, **kwargs) -> None:
92+
super().__init__(data, **kwargs)
93+
94+
c = self.kwargs.get("c", None)
95+
if self._is_sequence_arg(c) and len(c) != self.data.shape[0]:
96+
raise ValueError(
97+
f"'c' argument has {len(c)} elements, which is "
98+
+ f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}"
99+
)
100+
101+
def _compute_plot_data(self):
102+
data = self.data.copy()
103+
104+
c = self.kwargs.get("c", None)
105+
c_id = None
106+
if self._is_sequence_arg(c):
107+
c_id = self._generate_new_column_name(data)
108+
print(c_id)
109+
data[c_id] = c
110+
111+
sample = self._compute_sample_data(data)
112+
113+
# Works around a pandas bug:
114+
# https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
115+
if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
116+
sample[c] = sample[c].astype("object")
117+
118+
if c_id is not None:
119+
self.kwargs["c"] = sample[c_id]
120+
sample = sample.drop(columns=[c_id])
121+
122+
return sample
123+
124+
def _is_sequence_arg(self, arg):
125+
return (
126+
arg is not None
127+
and not isinstance(arg, str)
128+
and isinstance(arg, typing.Iterable)
129+
)
130+
131+
def _is_column_name(self, arg, data):
132+
return (
133+
arg is not None
134+
and pd.core.dtypes.common.is_hashable(arg)
135+
and arg in data.columns
136+
)
137+
138+
def _generate_new_column_name(self, data):
139+
col_name = None
140+
while col_name is None or col_name in data.columns:
141+
col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
142+
return col_name

tests/system/small/operations/test_plotting.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,67 @@ def test_scatter(scalars_dfs):
209209
)
210210

211211

212+
@pytest.mark.parametrize(
213+
("c"),
214+
[
215+
pytest.param("red", id="red"),
216+
pytest.param("c", id="int_column"),
217+
pytest.param("species", id="color_column"),
218+
pytest.param(["red", "green", "blue"], id="color_sequence"),
219+
pytest.param([3.4, 5.3, 2.0], id="number_sequence"),
220+
pytest.param(
221+
[3.4, 5.3],
222+
id="length_mismatches_sequence",
223+
marks=pytest.mark.xfail(
224+
raises=ValueError,
225+
),
226+
),
227+
],
228+
)
229+
def test_scatter_args_c(c):
230+
data = {
231+
"a": [1, 2, 3],
232+
"b": [1, 2, 3],
233+
"c": [1, 2, 3],
234+
"species": ["r", "g", "b"],
235+
}
236+
df = bpd.DataFrame(data)
237+
pd_df = pd.DataFrame(data)
238+
239+
ax = df.plot.scatter(x="a", y="b", c=c)
240+
pd_ax = pd_df.plot.scatter(x="a", y="b", c=c)
241+
assert len(ax.collections[0].get_facecolor()) == len(
242+
pd_ax.collections[0].get_facecolor()
243+
)
244+
for idx in range(len(ax.collections[0].get_facecolor())):
245+
tm.assert_numpy_array_equal(
246+
ax.collections[0].get_facecolor()[idx],
247+
pd_ax.collections[0].get_facecolor()[idx],
248+
)
249+
250+
251+
def test_scatter_args_c_sampling():
252+
data = {
253+
"plot_temp_0": [1, 2, 3, 4, 5],
254+
"plot_temp_1": [5, 4, 3, 2, 1],
255+
}
256+
c = ["red", "green", "blue", "orange", "black"]
257+
258+
df = bpd.DataFrame(data)
259+
pd_df = pd.DataFrame(data)
260+
261+
ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3)
262+
pd_ax = pd_df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c)
263+
assert len(ax.collections[0].get_facecolor()) == len(
264+
pd_ax.collections[0].get_facecolor()
265+
)
266+
for idx in range(len(ax.collections[0].get_facecolor())):
267+
tm.assert_numpy_array_equal(
268+
ax.collections[0].get_facecolor()[idx],
269+
pd_ax.collections[0].get_facecolor()[idx],
270+
)
271+
272+
212273
def test_sampling_plot_args_n():
213274
df = bpd.DataFrame(np.arange(bf_mpl.DEFAULT_SAMPLING_N * 10), columns=["one"])
214275
ax = df.plot.line()

0 commit comments

Comments
 (0)