14
14
15
15
import abc
16
16
import typing
17
+ import uuid
17
18
18
19
import matplotlib .pyplot as plt
20
+ import pandas as pd
21
+
22
+ import bigframes .dtypes as dtypes
19
23
20
24
21
25
class MPLPlot (abc .ABC ):
@@ -38,12 +42,13 @@ def _kind(self):
38
42
39
43
def __init__ (self , data , ** kwargs ) -> None :
40
44
self .kwargs = kwargs
41
- self .data = self . _compute_plot_data ( data )
45
+ self .data = data
42
46
43
47
def generate(self) -> None:
    """Compute the data to plot, draw it, and store the result in ``self.axes``.

    ``_compute_plot_data`` is called first, on its own line, because it may
    modify ``self.kwargs``; the keyword arguments are only expanded into the
    ``plot`` call afterwards.
    """
    frame_to_plot = self._compute_plot_data()
    self.axes = frame_to_plot.plot(kind=self._kind, **self.kwargs)
45
50
46
- def _compute_plot_data (self , data ):
51
+ def _compute_sample_data (self , data ):
47
52
# TODO: Cache the sampling data in the PlotAccessor.
48
53
sampling_n = self .kwargs .pop ("sampling_n" , 100 )
49
54
sampling_random_state = self .kwargs .pop ("sampling_random_state" , 0 )
@@ -53,6 +58,9 @@ def _compute_plot_data(self, data):
53
58
sort = False ,
54
59
).to_pandas ()
55
60
61
def _compute_plot_data(self):
    """Return the data that ``generate`` should plot.

    The default implementation passes ``self.data`` to
    ``_compute_sample_data`` and returns its result unchanged.
    """
    sampled = self._compute_sample_data(self.data)
    return sampled
63
+
56
64
57
65
class LinePlot (SamplingPlot ):
58
66
@property
@@ -70,3 +78,56 @@ class ScatterPlot(SamplingPlot):
70
78
@property
71
79
def _kind (self ) -> typing .Literal ["scatter" ]:
72
80
return "scatter"
81
+
82
def __init__(self, data, **kwargs) -> None:
    """Initialise the plot and validate a sequence-valued ``c`` argument.

    Args:
        data: The object to plot; its ``shape[0]`` is used as the row count.
        **kwargs: Plotting options. If ``c`` is a non-string iterable, its
            length must equal the number of rows in ``data``.

    Raises:
        ValueError: If ``c`` is a non-string iterable whose length differs
            from ``self.data.shape[0]``.
    """
    super().__init__(data, **kwargs)

    color_arg = self.kwargs.get("c", None)

    # Only sequence-like values are length-checked; other values of ``c``
    # are not validated here.
    if not self._is_sequence_arg(color_arg):
        return
    if len(color_arg) == self.data.shape[0]:
        return

    raise ValueError(
        f"'c' argument has {len(color_arg)} elements, which is "
        + f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}"
    )
91
+
92
def _compute_plot_data(self):
    """Return the sampled data to plot, keeping the ``c`` argument consistent.

    ``c`` is read from ``self.kwargs`` and handled in one of two ways:

    * If it is a non-string iterable, it is attached to a copy of
      ``self.data`` under a temporary column so that it goes through
      ``_compute_sample_data`` together with the rows. Afterwards
      ``self.kwargs["c"]`` is replaced by the sampled values and the
      temporary column is dropped from the result.
    * If it names a column of the sample whose dtype equals
      ``dtypes.STRING_DTYPE``, that column is cast to ``object``.

    Returns:
        The result of ``_compute_sample_data`` without the temporary column.
    """
    data = self.data.copy()

    c = self.kwargs.get("c", None)
    c_id = None
    if self._is_sequence_arg(c):
        # Attach the values under a name that cannot clash with a real column.
        c_id = self._generate_new_column_name(data)
        data[c_id] = c

    sample = self._compute_sample_data(data)

    # Works around a pandas bug:
    # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
    if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
        sample[c] = sample[c].astype("object")

    if c_id is not None:
        # NOTE: this overwrites the caller-supplied ``c`` in self.kwargs with
        # the sampled values, so the plot call receives matching lengths.
        self.kwargs["c"] = sample[c_id]
        sample = sample.drop(columns=[c_id])

    return sample
114
+
115
+ def _is_sequence_arg (self , arg ):
116
+ return (
117
+ arg is not None
118
+ and not isinstance (arg , str )
119
+ and isinstance (arg , typing .Iterable )
120
+ )
121
+
122
+ def _is_column_name (self , arg , data ):
123
+ return (
124
+ arg is not None
125
+ and pd .core .dtypes .common .is_hashable (arg )
126
+ and arg in data .columns
127
+ )
128
+
129
+ def _generate_new_column_name (self , data ):
130
+ col_name = None
131
+ while col_name is None or col_name in data .columns :
132
+ col_name = f"plot_temp_{ str (uuid .uuid4 ())[:8 ]} "
133
+ return col_name
0 commit comments