14
14
15
15
import abc
16
16
import typing
17
+ import uuid
18
+
19
+ import matplotlib .pyplot as plt
20
+ import pandas as pd
21
+
22
+ import bigframes .dtypes as dtypes
17
23
18
24
DEFAULT_SAMPLING_N = 1000
19
25
DEFAULT_SAMPLING_STATE = 0
20
26
21
-
22
27
class MPLPlot (abc .ABC ):
23
28
@abc .abstractmethod
24
29
def generate (self ):
@@ -44,12 +49,13 @@ def _kind(self):
44
49
45
50
def __init__ (self , data , ** kwargs ) -> None :
46
51
self .kwargs = kwargs
47
- self .data = self . _compute_plot_data ( data )
52
+ self .data = data
48
53
49
54
def generate(self) -> None:
    """Compute the data to plot, draw it, and keep the result on ``self.axes``."""
    # The plot data is computed first; ``self.kwargs`` is only unpacked
    # afterwards, when the plotting call is made.
    frame = self._compute_plot_data()
    self.axes = frame.plot(kind=self._kind, **self.kwargs)
51
57
52
- def _compute_plot_data (self , data ):
58
+ def _compute_sample_data (self , data ):
53
59
# TODO: Cache the sampling data in the PlotAccessor.
54
60
sampling_n = self .kwargs .pop ("sampling_n" , DEFAULT_SAMPLING_N )
55
61
sampling_random_state = self .kwargs .pop (
@@ -61,6 +67,9 @@ def _compute_plot_data(self, data):
61
67
sort = False ,
62
68
).to_pandas ()
63
69
70
def _compute_plot_data(self):
    """Return the result of ``_compute_sample_data`` applied to ``self.data``."""
    source = self.data
    return self._compute_sample_data(source)
72
+
64
73
65
74
class LinePlot (SamplingPlot ):
66
75
@property
@@ -78,3 +87,56 @@ class ScatterPlot(SamplingPlot):
78
87
@property
79
88
def _kind (self ) -> typing .Literal ["scatter" ]:
80
89
return "scatter"
90
+
91
def __init__(self, data, **kwargs) -> None:
    """Initialise the plot and validate a sequence-valued ``c`` argument.

    Args:
        data: Object to plot; passed to the parent initialiser.
        **kwargs: Plotting options; ``c`` is checked when it is a sequence.

    Raises:
        ValueError: If ``c`` is a sequence whose length differs from the
            number of rows in ``self.data``.
    """
    super().__init__(data, **kwargs)

    color = self.kwargs.get("c", None)
    if not self._is_sequence_arg(color):
        # Anything that is not a sequence is accepted without a length check.
        return

    if len(color) != self.data.shape[0]:
        raise ValueError(
            f"'c' argument has {len(color)} elements, which is "
            f"inconsistent with 'x' and 'y' with size {self.data.shape[0]}"
        )
100
+
101
def _compute_plot_data(self):
    """Build the sampled frame that is handed to the plotting call.

    When the ``c`` keyword is a sequence, it is attached to a copy of the
    data under a temporary column so that it goes through
    ``_compute_sample_data`` together with the rows. Afterwards the sampled
    values replace ``self.kwargs["c"]`` and the temporary column is dropped.

    Returns:
        The result of ``_compute_sample_data`` without the temporary column.
    """
    # Work on a copy so the temporary column never appears on ``self.data``.
    data = self.data.copy()

    c = self.kwargs.get("c", None)
    c_id = None
    if self._is_sequence_arg(c):
        c_id = self._generate_new_column_name(data)
        data[c_id] = c

    sample = self._compute_sample_data(data)

    # Works around a pandas bug:
    # https://github.com/pandas-dev/pandas/commit/45b937d64f6b7b6971856a47e379c7c87af7e00a
    if self._is_column_name(c, sample) and sample[c].dtype == dtypes.STRING_DTYPE:
        sample[c] = sample[c].astype("object")

    if c_id is not None:
        # Hand the sampled values back through kwargs, then remove the
        # helper column from the returned frame.
        self.kwargs["c"] = sample[c_id]
        sample = sample.drop(columns=[c_id])

    return sample
123
+
124
+ def _is_sequence_arg (self , arg ):
125
+ return (
126
+ arg is not None
127
+ and not isinstance (arg , str )
128
+ and isinstance (arg , typing .Iterable )
129
+ )
130
+
131
+ def _is_column_name (self , arg , data ):
132
+ return (
133
+ arg is not None
134
+ and pd .core .dtypes .common .is_hashable (arg )
135
+ and arg in data .columns
136
+ )
137
+
138
+ def _generate_new_column_name (self , data ):
139
+ col_name = None
140
+ while col_name is None or col_name in data .columns :
141
+ col_name = f"plot_temp_{ str (uuid .uuid4 ())[:8 ]} "
142
+ return col_name
0 commit comments