Skip to content

Commit 9bc60dd

Browse files
committed
fix: plot.scatter c arguments functionalities
1 parent 429a4a5 commit 9bc60dd

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

bigframes/operations/_matplotlib/core.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
import abc
1616
import typing
17+
import uuid
1718

1819
import matplotlib.pyplot as plt
20+
import pandas as pd
21+
22+
import bigframes.dtypes as dtypes
1923

2024

2125
class MPLPlot(abc.ABC):
@@ -38,12 +42,13 @@ def _kind(self):
3842

3943
def __init__(self, data, **kwargs) -> None:
4044
self.kwargs = kwargs
41-
self.data = self._compute_plot_data(data)
45+
self.data = data
4246

4347
def generate(self) -> None:
44-
self.axes = self.data.plot(kind=self._kind, **self.kwargs)
48+
plot_data = self._compute_plot_data()
49+
self.axes = plot_data.plot(kind=self._kind, **self.kwargs)
4550

46-
def _compute_plot_data(self, data):
51+
def _compute_sample_data(self, data):
4752
# TODO: Cache the sampling data in the PlotAccessor.
4853
sampling_n = self.kwargs.pop("sampling_n", 100)
4954
sampling_random_state = self.kwargs.pop("sampling_random_state", 0)
@@ -53,6 +58,9 @@ def _compute_plot_data(self, data):
5358
sort=False,
5459
).to_pandas()
5560

61+
def _compute_plot_data(self):
62+
return self._compute_sample_data(self.data)
63+
5664

5765
class LinePlot(SamplingPlot):
5866
@property
@@ -70,3 +78,56 @@ class ScatterPlot(SamplingPlot):
7078
@property
7179
def _kind(self) -> typing.Literal["scatter"]:
7280
return "scatter"
81+
82+
def __init__(self, data, **kwargs) -> None:
83+
super().__init__(data, **kwargs)
84+
85+
c = self.kwargs.get("c", None)
86+
if self._is_sequence_arg(c) and len(c) != self.data.shape[0]:
87+
raise ValueError(
88+
f"'c' argument has {len(c)} elements, which is "
89+
+ f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}"
90+
)
91+
92+
def _compute_plot_data(self):
93+
data = self.data.copy()
94+
95+
c = self.kwargs.get("c", None)
96+
c_id = None
97+
if self._is_sequence_arg(c):
98+
c_id = self._generate_new_column_name(data)
99+
print(c_id)
100+
data[c_id] = c
101+
102+
sample = self._compute_sample_data(data)
103+
104+
# Works around a pandas bug:
105+
# https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
106+
if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
107+
sample[c] = sample[c].astype("object")
108+
109+
if c_id is not None:
110+
self.kwargs["c"] = sample[c_id]
111+
sample = sample.drop(columns=[c_id])
112+
113+
return sample
114+
115+
def _is_sequence_arg(self, arg):
116+
return (
117+
arg is not None
118+
and not isinstance(arg, str)
119+
and isinstance(arg, typing.Iterable)
120+
)
121+
122+
def _is_column_name(self, arg, data):
123+
return (
124+
arg is not None
125+
and pd.core.dtypes.common.is_hashable(arg)
126+
and arg in data.columns
127+
)
128+
129+
def _generate_new_column_name(self, data):
130+
col_name = None
131+
while col_name is None or col_name in data.columns:
132+
col_name = f"plot_temp_{str(uuid.uuid4())[:8]}"
133+
return col_name

tests/system/small/operations/test_plotting.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,67 @@ def test_scatter(scalars_dfs):
208208
)
209209

210210

211+
@pytest.mark.parametrize(
212+
("c"),
213+
[
214+
pytest.param("red", id="red"),
215+
pytest.param("c", id="int_column"),
216+
pytest.param("species", id="color_column"),
217+
pytest.param(["red", "green", "blue"], id="color_sequence"),
218+
pytest.param([3.4, 5.3, 2.0], id="number_sequence"),
219+
pytest.param(
220+
[3.4, 5.3],
221+
id="length_mismatches_sequence",
222+
marks=pytest.mark.xfail(
223+
raises=ValueError,
224+
),
225+
),
226+
],
227+
)
228+
def test_scatter_args_c(c):
229+
data = {
230+
"a": [1, 2, 3],
231+
"b": [1, 2, 3],
232+
"c": [1, 2, 3],
233+
"species": ["r", "g", "b"],
234+
}
235+
df = bpd.DataFrame(data)
236+
pd_df = pd.DataFrame(data)
237+
238+
ax = df.plot.scatter(x="a", y="b", c=c)
239+
pd_ax = pd_df.plot.scatter(x="a", y="b", c=c)
240+
assert len(ax.collections[0].get_facecolor()) == len(
241+
pd_ax.collections[0].get_facecolor()
242+
)
243+
for idx in range(len(ax.collections[0].get_facecolor())):
244+
tm.assert_numpy_array_equal(
245+
ax.collections[0].get_facecolor()[idx],
246+
pd_ax.collections[0].get_facecolor()[idx],
247+
)
248+
249+
250+
def test_scatter_args_c_sampling():
251+
data = {
252+
"plot_temp_0": [1, 2, 3, 4, 5],
253+
"plot_temp_1": [5, 4, 3, 2, 1],
254+
}
255+
c = ["red", "green", "blue", "orange", "black"]
256+
257+
df = bpd.DataFrame(data)
258+
pd_df = pd.DataFrame(data)
259+
260+
ax = df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c, sampling_n=3)
261+
pd_ax = pd_df.plot.scatter(x="plot_temp_0", y="plot_temp_1", c=c)
262+
assert len(ax.collections[0].get_facecolor()) == len(
263+
pd_ax.collections[0].get_facecolor()
264+
)
265+
for idx in range(len(ax.collections[0].get_facecolor())):
266+
tm.assert_numpy_array_equal(
267+
ax.collections[0].get_facecolor()[idx],
268+
pd_ax.collections[0].get_facecolor()[idx],
269+
)
270+
271+
211272
def test_sampling_plot_args_n():
212273
df = bpd.DataFrame(np.arange(1000), columns=["one"])
213274
ax = df.plot.line()

0 commit comments

Comments
 (0)