diff --git a/plotly/tests/test_optional/test_figure_factory.py b/plotly/tests/test_optional/test_figure_factory.py index 5eff6b2c16a..f75c0789ced 100644 --- a/plotly/tests/test_optional/test_figure_factory.py +++ b/plotly/tests/test_optional/test_figure_factory.py @@ -627,18 +627,6 @@ def test_valid_colormap(self): self.assertRaises(PlotlyError, tls.FigureFactory.create_trisurf, x, y, z, simplices, colormap='foo') - # check that colormap is a list, if not a string - - pattern1 = ( - "If 'colormap' is a list, then its items must be tripets of the " - "form a,b,c or 'rgbx,y,z' where a,b,c are between 0 and 1 " - "inclusive and x,y,z are between 0 and 255 inclusive." - ) - - self.assertRaisesRegexp(PlotlyError, pattern1, - tls.FigureFactory.create_trisurf, - x, y, z, simplices, colormap=3) - # check: if colormap is a list of rgb color strings, make sure the # entries of each color are no greater than 255.0 @@ -650,20 +638,31 @@ def test_valid_colormap(self): self.assertRaisesRegexp(PlotlyError, pattern2, tls.FigureFactory.create_trisurf, x, y, z, simplices, - colormap=['rgb(1, 2, 3)', 'rgb(4, 5, 600)']) + colormap=['rgb(4, 5, 600)']) # check: if colormap is a list of tuple colors, make sure the entries # of each tuple are no greater than 1.0 pattern3 = ( - "Whoops! The elements in your rgb colormap tuples " - "cannot exceed 1.0." + "Whoops! The elements in your colormap tuples cannot exceed 1.0." ) self.assertRaisesRegexp(PlotlyError, pattern3, tls.FigureFactory.create_trisurf, x, y, z, simplices, - colormap=[(0.2, 0.4, 0.6), (0.8, 1.0, 1.2)]) + colormap=[(0.8, 1.0, 1.2)]) + + # check: + + pattern4 = ( + "You must input a valid colormap. Valid types include a plotly " + "scale, rgb, hex or tuple color, or lastly a list of any color " + "types." + ) + + self.assertRaisesRegexp(PlotlyError, pattern4, + tls.FigureFactory.create_trisurf, + x, y, z, simplices, colormap=1) def test_trisurf_all_args(self): @@ -865,27 +864,55 @@ def test_same_data_in_index(self): tls.FigureFactory.create_scatterplotmatrix, df, index='apple') - def test_valid_palette(self): + def test_valid_colormap(self): - # check: the palette argument is in a acceptable form + # check: the colormap argument is in a valid form df = pd.DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=['a', 'b', 'c']) - self.assertRaisesRegexp(PlotlyError, "You must pick a valid " - "plotly colorscale name.", - tls.FigureFactory.create_scatterplotmatrix, - df, use_theme=True, index='a', - palette='fake_scale') + # check: valid plotly scalename is entered + self.assertRaises(PlotlyError, + tls.FigureFactory.create_scatterplotmatrix, + df, index='a', colormap='fake_scale') pattern = ( - "The items of 'palette' must be tripets of the form a,b,c or " - "'rgbx,y,z' where a,b,c belong to the interval 0,1 and x,y,z " - "belong to 0,255." + "You must input a valid colormap. Valid types include a plotly " + "scale, rgb, hex or tuple color, a list of any color types, or a " + "dictionary with index names each assigned to a color." ) + # check: accepted data type for colormap self.assertRaisesRegexp(PlotlyError, pattern, tls.FigureFactory.create_scatterplotmatrix, - df, use_theme=True, palette=1, index='c') + df, colormap=1) + + pattern_rgb = ( + "Whoops! The elements in your rgb colormap tuples cannot " + "exceed 255.0." + ) + + # check: proper 'rgb' color + self.assertRaisesRegexp(PlotlyError, pattern_rgb, + tls.FigureFactory.create_scatterplotmatrix, + df, colormap='rgb(500, 1, 1)', index='c') + + self.assertRaisesRegexp(PlotlyError, pattern_rgb, + tls.FigureFactory.create_scatterplotmatrix, + df, colormap=['rgb(500, 1, 1)'], index='c') + + pattern_tuple = ( + "Whoops! The elements in your colormap tuples cannot " + "exceed 1.0." + ) + + # check: proper color tuple + self.assertRaisesRegexp(PlotlyError, pattern_tuple, + tls.FigureFactory.create_scatterplotmatrix, + df, colormap=(2, 1, 1), index='c') + + self.assertRaisesRegexp(PlotlyError, pattern_tuple, + tls.FigureFactory.create_scatterplotmatrix, + df, colormap=[(2, 1, 1)], index='c') def test_valid_endpts(self): @@ -900,20 +927,35 @@ def test_valid_endpts(self): self.assertRaisesRegexp(PlotlyError, pattern, tls.FigureFactory.create_scatterplotmatrix, - df, use_theme=True, index='a', - palette='Blues', endpts='foo') + df, index='a', colormap='Hot', endpts='foo') # check: the endpts are a list of numbers self.assertRaisesRegexp(PlotlyError, pattern, tls.FigureFactory.create_scatterplotmatrix, - df, use_theme=True, index='a', - palette='Blues', endpts=['a']) + df, index='a', colormap='Hot', endpts=['a']) # check: endpts is a list of INCREASING numbers self.assertRaisesRegexp(PlotlyError, pattern, tls.FigureFactory.create_scatterplotmatrix, - df, use_theme=True, index='a', - palette='Blues', endpts=[2, 1]) + df, index='a', colormap='Hot', endpts=[2, 1]) + + def test_dictionary_colormap(self): + + # if colormap is a dictionary, make sure it all the values in the + # index column are keys in colormap + df = pd.DataFrame([['apple', 'happy'], ['pear', 'sad']], + columns=['Fruit', 'Emotion']) + + colormap = {'happy': 'rgb(5, 5, 5)'} + + pattern = ( + "If colormap is a dictionary, all the names in the index " + "must be keys." + ) + + self.assertRaisesRegexp(PlotlyError, pattern, + tls.FigureFactory.create_scatterplotmatrix, + df, index='Emotion', colormap=colormap) def test_scatter_plot_matrix(self): @@ -926,7 +968,7 @@ def test_scatter_plot_matrix(self): test_scatter_plot_matrix = tls.FigureFactory.create_scatterplotmatrix( df=df, diag='scatter', height=1000, width=1000, size=13, - title='Scatterplot Matrix', use_theme=False + title='Scatterplot Matrix' ) exp_scatter_plot_matrix = { @@ -1020,17 +1062,17 @@ def test_scatter_plot_matrix_kwargs(self): test_scatter_plot_matrix = tls.FigureFactory.create_scatterplotmatrix( df, index='Fruit', endpts=[-10, -1], diag='histogram', height=1000, width=1000, size=13, title='Scatterplot Matrix', - use_theme=True, palette='YlOrRd', marker=dict(symbol=136) + colormap='YlOrRd', marker=dict(symbol=136) ) exp_scatter_plot_matrix = { - 'data': [{'marker': {'color': 'rgb(128.0, 0.0, 38.0)'}, + 'data': [{'marker': {'color': 'rgb(128,0,38)'}, 'showlegend': False, 'type': 'histogram', 'x': [2, -15, -2, 0], 'xaxis': 'x1', 'yaxis': 'y1'}, - {'marker': {'color': 'rgb(255.0, 255.0, 204.0)'}, + {'marker': {'color': 'rgb(255,255,204)'}, 'showlegend': False, 'type': 'histogram', 'x': [6, 5], diff --git a/plotly/tools.py b/plotly/tools.py index 0f71d950e85..42e0479caf9 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -1432,6 +1432,7 @@ def return_figure_from_figure_or_data(figure_or_data, validate_figure): _DEFAULT_DECREASING_COLOR = '#FF4136' DIAG_CHOICES = ['scatter', 'histogram', 'box'] +VALID_COLORMAP_TYPES = ['cat', 'seq'] class FigureFactory(object): @@ -1475,20 +1476,29 @@ def _unconvert_from_RGB_255(colors): """ Return a tuple where each element gets divided by 255 - Takes a list of color tuples where each element is between 0 and 255 - and returns the same list where each tuple element is normalized to be - between 0 and 1 + Takes a (list of) color tuple(s) where each element is between 0 and + 255. Returns the same tuples where each tuple element is normalized to + a value between 0 and 1 """ - un_rgb_colors = [] - for color in colors: - un_rgb_color = (color[0]/(255.0), - color[1]/(255.0), - color[2]/(255.0)) + if isinstance(colors, tuple): - un_rgb_colors.append(un_rgb_color) + un_rgb_color = (colors[0]/(255.0), + colors[1]/(255.0), + colors[2]/(255.0)) - return un_rgb_colors + return un_rgb_color + + if isinstance(colors, list): + un_rgb_colors = [] + for color in colors: + un_rgb_color = (color[0]/(255.0), + color[1]/(255.0), + color[2]/(255.0)) + + un_rgb_colors.append(un_rgb_color) + + return un_rgb_colors @staticmethod def _map_z2color(zval, colormap, vmin, vmax): @@ -1509,9 +1519,30 @@ def _map_z2color(zval, colormap, vmin, vmax): # find distance t of zval from vmin to vmax where the distance # is normalized to be between 0 and 1 t = (zval - vmin)/float((vmax - vmin)) - t_color = FigureFactory._find_intermediate_color(colormap[0], - colormap[1], - t) + + # for colormaps of more than 2 colors, find two closest colors based + # on relative position between vmin and vmax + if len(colormap) == 1: + t_color = colormap[0] + else: + num_steps = len(colormap) - 1 + step = 1./num_steps + + if t == 1.0: + t_color = FigureFactory._find_intermediate_color( + colormap[int(t/step) - 1], + colormap[int(t/step)], + t + ) + else: + new_t = (t - int(t/step)*step)/float(step) + + t_color = FigureFactory._find_intermediate_color( + colormap[int(t/step)], + colormap[int(t/step) + 1], + new_t + ) + t_color = (t_color[0]*255.0, t_color[1]*255.0, t_color[2]*255.0) labelled_color = 'rgb{}'.format(t_color) @@ -1525,7 +1556,7 @@ def _tri_indices(simplices): return ([triplet[c] for triplet in simplices] for c in range(3)) @staticmethod - def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, + def _trisurf(x, y, z, simplices, colormap=None, color_func=None, plot_edges=None, x_edge=None, y_edge=None, z_edge=None): """ Refer to FigureFactory.create_trisurf() for docstring @@ -1541,7 +1572,7 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, # vertices of the surface triangles tri_vertices = list(map(lambda index: points3D[index], simplices)) - if not dist_func: + if not color_func: # mean values of z-coordinates of triangle vertices mean_dists = [np.mean(tri[:, 2]) for tri in tri_vertices] else: @@ -1552,7 +1583,7 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, for triangle in tri_vertices: dists = [] for vertex in triangle: - dist = dist_func(vertex[0], vertex[1], vertex[2]) + dist = color_func(vertex[0], vertex[1], vertex[2]) dists.append(dist) mean_dists.append(np.mean(dists)) @@ -1603,7 +1634,7 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, @staticmethod def create_trisurf(x, y, z, simplices, colormap=None, - dist_func=None, title='Trisurf Plot', + color_func=None, title='Trisurf Plot', showbackground=True, backgroundcolor='rgb(230, 230, 230)', gridcolor='rgb(255, 255, 255)', @@ -1619,11 +1650,12 @@ def create_trisurf(x, y, z, simplices, colormap=None, :param (array) simplices: an array of shape (ntri, 3) where ntri is the number of triangles in the triangularization. Each row of the array contains the indicies of the verticies of each triangle. - :param (str|list) colormap: either a plotly scale name, or a list - containing 2 triplets. These triplets must be of the form (a,b,c) - or 'rgb(x,y,z)' where a,b,c belong to the interval [0,1] and x,y,z - belong to [0,255] - :param (function) dist_func: The function that determines how the + :param (str|tuple|list) colormap: takes a plotly scale, an rgb or hex + string, a tuple, or a list of colors of the aforementioned rgb, + hex or tuple types. An rgb/triplet color type is a triplet of the + form (a,b,c) or 'rgb(x,y,z)' respectively where a,b,c belong to + the interval [0,1] and x,y,z belong to [0,255] + :param (function) color_func: The function that determines how the coloring of the surface changes. It takes 3 arguments x, y, z and must return a formula of these variables which can include numpy functions (eg. np.sqrt). If set to None, color will only depend on @@ -1773,9 +1805,11 @@ def dist_origin(x, y, z): # Create a figure fig1 = FF.create_trisurf(x=x, y=y, z=z, - colormap="Blues", + colormap=['#604d9e', + 'rgb(50, 150, 255)', + (0.2, 0.2, 0.8)], simplices=simplices, - dist_func=dist_origin) + color_func=dist_origin) # Plot the data py.iplot(fig1, filename='Trisurf Plot - Custom Coloring') ``` @@ -1807,50 +1841,83 @@ def dist_origin(x, y, z): colormap = FigureFactory._unconvert_from_RGB_255(colormap) if isinstance(colormap, str): - if colormap not in plotly_scales: - scale_keys = list(plotly_scales.keys()) - raise exceptions.PlotlyError("You must pick a valid " - "plotly colorscale " - "name from " - "{}".format(scale_keys)) + if colormap in plotly_scales: + colormap = plotly_scales[colormap] + colormap = FigureFactory._unlabel_rgb(colormap) + colormap = FigureFactory._unconvert_from_RGB_255(colormap) - colormap = [plotly_scales[colormap][0], - plotly_scales[colormap][1]] - colormap = FigureFactory._unlabel_rgb(colormap) - colormap = FigureFactory._unconvert_from_RGB_255(colormap) + elif 'rgb' in colormap: + # put colormap in list + colors_list = [] + colors_list.append(colormap) + colormap = colors_list - else: - if not isinstance(colormap, list): - raise exceptions.PlotlyError("If 'colormap' is a list, then " - "its items must be tripets of " - "the form a,b,c or 'rgbx,y,z' " - "where a,b,c are between 0 and " - "1 inclusive and x,y,z are " - "between 0 and 255 inclusive.") - if 'rgb' in colormap[0]: colormap = FigureFactory._unlabel_rgb(colormap) - for color in colormap: - for index in range(3): - if color[index] > 255.0: + colormap = FigureFactory._unconvert_from_RGB_255(colormap) + + elif '#' in colormap: + colormap = FigureFactory._hex_to_rgb(colormap) + colormap = FigureFactory._unconvert_from_RGB_255(colormap) + + # put colormap in list + colors_list = [] + colors_list.append(colormap) + colormap = colors_list + + else: + scale_keys = list(plotly_scales.keys()) + raise exceptions.PlotlyError("If you input a string " + "for 'colormap', it must " + "either be a Plotly " + "colorscale, an 'rgb' " + "color or a hex color." + "Valid plotly colorscale " + "names are {}".format(scale_keys)) + elif isinstance(colormap, tuple): + colors_list = [] + colors_list.append(colormap) + colormap = colors_list + + elif isinstance(colormap, list): + new_colormap = [] + for color in colormap: + if 'rgb' in color: + color = FigureFactory._unlabel_rgb(color) + + for value in color: + if value > 255.0: raise exceptions.PlotlyError("Whoops! The " "elements in your " "rgb colormap " "tuples cannot " "exceed 255.0.") - colormap = FigureFactory._unconvert_from_RGB_255(colormap) - if isinstance(colormap[0], tuple): - for color in colormap: - for index in range(3): - if color[index] > 1.0: + color = FigureFactory._unconvert_from_RGB_255(color) + new_colormap.append(color) + elif '#' in color: + color = FigureFactory._hex_to_rgb(color) + color = FigureFactory._unconvert_from_RGB_255(color) + new_colormap.append(color) + elif isinstance(color, tuple): + + for value in color: + if value > 1.0: raise exceptions.PlotlyError("Whoops! The " - "elements in your " - "rgb colormap " + "elements in " + "your colormap " "tuples cannot " "exceed 1.0.") + new_colormap.append(color) + colormap = new_colormap + + else: + raise exceptions.PlotlyError("You must input a valid colormap. " + "Valid types include a plotly scale, " + "rgb, hex or tuple color, or lastly " + "a list of any color types.") data1 = FigureFactory._trisurf(x, y, z, simplices, - dist_func=dist_func, + color_func=color_func, colormap=colormap, plot_edges=True) axis = dict( @@ -1876,17 +1943,14 @@ def dist_origin(x, y, z): return graph_objs.Figure(data=data1, layout=layout) @staticmethod - def _scatterplot(dataframe, headers, - diag, size, - height, width, - title, **kwargs): + def _scatterplot(dataframe, headers, diag, size, + height, width, title, **kwargs): """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring. + Refer to FigureFactory.create_scatterplotmatrix() for docstring - Returns fig for scatterplotmatrix without index or theme. + Returns fig for scatterplotmatrix without index """ - from plotly.graph_objs import graph_objs dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) @@ -1953,23 +2017,25 @@ def _scatterplot(dataframe, headers, return fig @staticmethod - def _scatterplot_index(dataframe, headers, - diag, size, - height, width, - title, - index, index_vals, - **kwargs): + def _scatterplot_dict(dataframe, headers, diag, size, + height, width, title, index, index_vals, + endpts, colormap, colormap_type, **kwargs): """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring. + Refer to FigureFactory.create_scatterplotmatrix() for docstring - Returns fig for scatterplotmatrix with an index and no theme. + Returns fig for scatterplotmatrix with both index and colormap picked. + Used if colormap is a dictionary with index values as keys pointing to + colors. Forces colormap_type to behave categorically because it would + not make sense colors are assigned to each index value and thus + implies that a categorical approach should be taken """ from plotly.graph_objs import graph_objs + + theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] - legend_param = 0 # Work over all permutations of list pairs for listy in dataframe: @@ -1980,24 +2046,21 @@ def _scatterplot_index(dataframe, headers, if name not in unique_index_vals: unique_index_vals[name] = [] - c_indx = 0 # color index # Fill all the rest of the names into the dictionary - for name in unique_index_vals: + for name in sorted(unique_index_vals.keys()): new_listx = [] new_listy = [] - for j in range(len(index_vals)): if index_vals[j] == name: new_listx.append(listx[j]) new_listy.append(listy[j]) - # Generate trace with VISIBLE icon if legend_param == 1: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=True ) elif (listx == listy) and (diag == 'box'): @@ -2005,14 +2068,13 @@ def _scatterplot_index(dataframe, headers, y=new_listx, name=None, marker=dict( - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=True ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size - (kwargs['marker'] - ['color']) = DEFAULT_PLOTLY_COLORS[c_indx] + kwargs['marker']['color'] = theme[name] trace = graph_objs.Scatter( x=new_listx, y=new_listy, @@ -2029,7 +2091,7 @@ def _scatterplot_index(dataframe, headers, name=name, marker=dict( size=size, - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=True, **kwargs ) @@ -2039,22 +2101,21 @@ def _scatterplot_index(dataframe, headers, trace = graph_objs.Histogram( x=new_listx, marker=dict( - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=False - ) + ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size - (kwargs['marker'] - ['color']) = DEFAULT_PLOTLY_COLORS[c_indx] + kwargs['marker']['color'] = theme[name] trace = graph_objs.Scatter( x=new_listx, y=new_listy, @@ -2071,15 +2132,12 @@ def _scatterplot_index(dataframe, headers, name=name, marker=dict( size=size, - color=DEFAULT_PLOTLY_COLORS[c_indx]), + color=theme[name]), showlegend=False, **kwargs ) # Push the trace into dictionary unique_index_vals[name] = trace - if c_indx >= (len(DEFAULT_PLOTLY_COLORS) - 1): - c_indx = -1 - c_indx += 1 trace_list.append(unique_index_vals) legend_param += 1 @@ -2087,7 +2145,7 @@ def _scatterplot_index(dataframe, headers, indices = range(1, dim + 1) for y_index in indices: for x_index in indices: - for name in trace_list[trace_index]: + for name in sorted(trace_list[trace_index].keys()): fig.append_trace( trace_list[trace_index][name], y_index, @@ -2098,6 +2156,7 @@ def _scatterplot_index(dataframe, headers, for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) @@ -2107,7 +2166,7 @@ def _scatterplot_index(dataframe, headers, height=height, width=width, title=title, showlegend=True, - barmode="stack") + barmode='stack') return fig elif diag == 'box': @@ -2125,48 +2184,16 @@ def _scatterplot_index(dataframe, headers, return fig @staticmethod - def _scatterplot_theme(dataframe, headers, diag, size, height, width, - title, index, index_vals, endpts, - palette, **kwargs): + def _scatterplot_theme(dataframe, headers, diag, size, height, + width, title, index, index_vals, endpts, + colormap, colormap_type, **kwargs): """ - Refer to FigureFactory.create_scatterplotmatrix() for docstring. + Refer to FigureFactory.create_scatterplotmatrix() for docstring - Returns fig for scatterplotmatrix with both index and theme. + Returns fig for scatterplotmatrix with both index and colormap picked - :raises: (PlotlyError) If palette string is not a Plotly colorscale - :raises: (PlotlyError) If palette is not a string or list """ from plotly.graph_objs import graph_objs - plotly_scales = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], - 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], - 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], - 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], - 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], - 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], - 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], - 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], - 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], - 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], - 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], - 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], - 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} - - # Validate choice of palette - if isinstance(palette, str): - if palette not in plotly_scales: - raise exceptions.PlotlyError("You must pick a valid " - "plotly colorscale name.") - else: - if not isinstance(palette, list): - raise exceptions.PlotlyError("The items of 'palette' must be " - "tripets of the form a,b,c or " - "'rgbx,y,z' where a,b,c belong " - "to the interval 0,1 and x,y,z " - "belong to 0,255.") # Check if index is made of string values if isinstance(index_vals[0], str): @@ -2176,28 +2203,16 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, unique_index_vals.append(name) n_colors_len = len(unique_index_vals) - # Convert palette to list of n RGB tuples - if isinstance(palette, str): - if palette in plotly_scales: - foo = FigureFactory._unlabel_rgb(plotly_scales[palette]) - foo = FigureFactory._n_colors(foo[0], - foo[1], - n_colors_len) - theme = FigureFactory._label_rgb(foo) - - if isinstance(palette, list): - if 'rgb' in palette[0]: - foo = FigureFactory._unlabel_rgb(palette) - foo = FigureFactory._n_colors(foo[0], - foo[1], - n_colors_len) - theme = FigureFactory._label_rgb(foo) - else: - foo = FigureFactory._convert_to_RGB_255(palette) - foo = FigureFactory._n_colors(foo[0], - foo[1], - n_colors_len) - theme = FigureFactory._label_rgb(foo) + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = FigureFactory._unlabel_rgb(colormap) + foo = FigureFactory._n_colors(foo[0], + foo[1], + n_colors_len) + theme = FigureFactory._label_rgb(foo) + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) @@ -2310,7 +2325,6 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, c_indx += 1 trace_list.append(unique_index_vals) legend_param += 1 - #return trace_list trace_index = 0 indices = range(1, dim + 1) @@ -2358,30 +2372,16 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, if endpts: intervals = FigureFactory._endpts_to_intervals(endpts) - # Convert palette to list of n RGB tuples - if isinstance(palette, str): - if palette in plotly_scales: - foo = FigureFactory._unlabel_rgb( - plotly_scales[palette] - ) - foo = FigureFactory._n_colors(foo[0], - foo[1], - len(intervals)) - theme = FigureFactory._label_rgb(foo) - - if isinstance(palette, list): - if 'rgb' in palette[0]: - foo = FigureFactory._unlabel_rgb(palette) - foo = FigureFactory._n_colors(foo[0], - foo[1], - len(intervals)) - theme = FigureFactory._label_rgb(foo) - else: - foo = FigureFactory._convert_to_RGB_255(palette) - foo = FigureFactory._n_colors(foo[0], - foo[1], - len(intervals)) - theme = FigureFactory._label_rgb(foo) + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = FigureFactory._unlabel_rgb(colormap) + foo = FigureFactory._n_colors(foo[0], + foo[1], + len(intervals)) + theme = FigureFactory._label_rgb(foo) + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) @@ -2537,17 +2537,15 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, return fig else: - # Convert palette to list of 2 RGB tuples - if isinstance(palette, str): - if palette in plotly_scales: - theme = plotly_scales[palette] - - if isinstance(palette, list): - if 'rgb' in palette[0]: - theme = palette - else: - foo = FigureFactory._convert_to_RGB_255(palette) - theme = FigureFactory._label_rgb(foo) + theme = colormap + + # add a copy of rgb color to theme if it contains one color + if len(theme) <= 1: + theme.append(theme[0]) + + color = [] + for incr in range(len(theme)): + color.append([1./(len(theme)-1)*incr, theme[incr]]) dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) @@ -2576,10 +2574,7 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = index_vals - kwargs['marker']['colorscale'] = [ - [0, theme[0]], - [1, theme[1]] - ] + kwargs['marker']['colorscale'] = color kwargs['marker']['showscale'] = True trace = graph_objs.Scatter( x=listx, @@ -2596,8 +2591,7 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, marker=dict( size=size, color=index_vals, - colorscale=[[0, theme[0]], - [1, theme[1]]], + colorscale=color, showscale=True), showlegend=False, **kwargs @@ -2622,10 +2616,7 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = index_vals - kwargs['marker']['colorscale'] = [ - [0, theme[0]], - [1, theme[1]] - ] + kwargs['marker']['colorscale'] = color kwargs['marker']['showscale'] = False trace = graph_objs.Scatter( x=listx, @@ -2642,8 +2633,7 @@ def _scatterplot_theme(dataframe, headers, diag, size, height, width, marker=dict( size=size, color=index_vals, - colorscale=[[0, theme[0]], - [1, theme[1]]], + colorscale=color, showscale=False), showlegend=False, **kwargs @@ -2738,7 +2728,7 @@ def _validate_dataframe(array): "numbers or strings.") @staticmethod - def _validate_scatterplotmatrix(df, index, diag, **kwargs): + def _validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs): """ Validates basic inputs for FigureFactory.create_scatterplotmatrix() @@ -2747,6 +2737,7 @@ def _validate_scatterplotmatrix(df, index, diag, **kwargs): :raises: (PlotlyError) If pandas dataframe has <= 1 columns :raises: (PlotlyError) If diagonal plot choice (diag) is not one of the viable options + :raises: (PlotlyError) If colormap_type is not a valid choice :raises: (PlotlyError) If kwargs contains 'size', 'color' or 'colorscale' """ @@ -2766,11 +2757,18 @@ def _validate_scatterplotmatrix(df, index, diag, **kwargs): "use the scatterplot matrix, use at " "least 2 columns.") - # Check that diag parameter is selected properly + # Check that diag parameter is a valid selection if diag not in DIAG_CHOICES: raise exceptions.PlotlyError("Make sure diag is set to " "one of {}".format(DIAG_CHOICES)) + # Check that colormap_types is a valid selection + if colormap_type not in VALID_COLORMAP_TYPES: + raise exceptions.PlotlyError("Must choose a valid colormap type. " + "Either 'cat' or 'seq' for a cate" + "gorical and sequential colormap " + "respectively.") + # Check for not 'size' or 'color' in 'marker' of **kwargs if 'marker' in kwargs: FORBIDDEN_PARAMS = ['size', 'color', 'colorscale'] @@ -2791,7 +2789,7 @@ def _endpts_to_intervals(endpts): Accepts a list or tuple of sequentially increasing numbers and returns a list representation of the mathematical intervals with these numbers - as endpoints. For example, [1, 4, 6] returns [[1, 4], [4, 6]] + as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]] :raises: (PlotlyError) If input is not a list or tuple :raises: (PlotlyError) If the input contains a string @@ -2834,26 +2832,31 @@ def _endpts_to_intervals(endpts): @staticmethod def _convert_to_RGB_255(colors): """ - Return a list of tuples where each element gets multiplied by 255 + Return a (list of) tuple(s) where each element is multiplied by 255 - Takes a list of color tuples where each element is between 0 and 1 - and returns the same list where each tuple element is normalized to be - between 0 and 255 + Takes a tuple or a list of tuples where each element of each tuple is + between 0 and 1. Returns the same tuple(s) where each tuple element is + multiplied by 255 """ - colors_255 = [] - for color in colors: - rgb_color = (color[0]*255.0, color[1]*255.0, color[2]*255.0) - colors_255.append(rgb_color) - return colors_255 + if isinstance(colors, tuple): + return (colors[0]*255.0, colors[1]*255.0, colors[2]*255.0) + + else: + colors_255 = [] + for color in colors: + rgb_color = (color[0]*255.0, color[1]*255.0, color[2]*255.0) + colors_255.append(rgb_color) + return colors_255 @staticmethod def _n_colors(lowcolor, highcolor, n_colors): """ - Splits a low and high color into a list of #n_colors colors + Splits a low and high color into a list of n_colors colors in it Accepts two color tuples and returns a list of n_colors colors which form the intermediate colors between lowcolor and highcolor + from linearly interpolating through RGB space """ diff_0 = float(highcolor[0] - lowcolor[0]) @@ -2875,40 +2878,41 @@ def _n_colors(lowcolor, highcolor, n_colors): @staticmethod def _label_rgb(colors): """ - Takes colors (a, b, c) and returns tuples 'rgb(a, b, c)' + Takes tuple(s) (a, b, c) and returns rgb color(s) 'rgb(a, b, c)' - Takes a list of two color tuples of the form (a, b, c) and returns the - same list with each tuple replaced by a string 'rgb(a, b, c)' + Takes either a list or a single color tuple of the form (a, b, c) and + returns the same color(s) with each tuple replaced by a string + 'rgb(a, b, c)' """ - colors_label = [] - for color in colors: - color_label = 'rgb{}'.format(color) - colors_label.append(color_label) + if isinstance(colors, tuple): + return 'rgb{}'.format(colors) + else: + colors_label = [] + for color in colors: + color_label = 'rgb{}'.format(color) + colors_label.append(color_label) - return colors_label + return colors_label @staticmethod def _unlabel_rgb(colors): """ - Takes rgb colors 'rgb(a, b, c)' and returns the tuples (a, b, c) + Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c) - This function takes a list of two 'rgb(a, b, c)' color strings and - returns a list of the color tuples in tuple form without the 'rgb' - label. In particular, the output is a list of two tuples of the form - (a, b, c) + This function takes either an 'rgb(a, b, c)' color or a list of + such colors and returns the color tuples in tuple(s) (a, b, c) """ - unlabelled_colors = [] - for character in colors: + if isinstance(colors, str): str_vals = '' - for index in range(len(character)): + for index in range(len(colors)): try: - float(character[index]) - str_vals = str_vals + character[index] + float(colors[index]) + str_vals = str_vals + colors[index] except ValueError: - if (character[index] == ',') or (character[index] == '.'): - str_vals = str_vals + character[index] + if (colors[index] == ',') or (colors[index] == '.'): + str_vals = str_vals + colors[index] str_vals = str_vals + ',' numbers = [] @@ -2919,37 +2923,73 @@ def _unlabel_rgb(colors): else: numbers.append(float(str_num)) str_num = '' - unlabelled_tuple = (numbers[0], numbers[1], numbers[2]) - unlabelled_colors.append(unlabelled_tuple) + return (numbers[0], numbers[1], numbers[2]) + + if isinstance(colors, list): + unlabelled_colors = [] + for color in colors: + str_vals = '' + for index in range(len(color)): + try: + float(color[index]) + str_vals = str_vals + color[index] + except ValueError: + if (color[index] == ',') or (color[index] == '.'): + str_vals = str_vals + color[index] + + str_vals = str_vals + ',' + numbers = [] + str_num = '' + for char in str_vals: + if char != ',': + str_num = str_num + char + else: + numbers.append(float(str_num)) + str_num = '' + unlabelled_tuple = (numbers[0], numbers[1], numbers[2]) + unlabelled_colors.append(unlabelled_tuple) - return unlabelled_colors + return unlabelled_colors @staticmethod - def create_scatterplotmatrix(df, dataframe=None, headers=None, - index_vals=None, index=None, endpts=None, - diag='scatter', height=500, width=500, size=6, - title='Scatterplot Matrix', use_theme=False, - palette=None, **kwargs): + def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', + height=500, width=500, size=6, + title='Scatterplot Matrix', colormap=None, + colormap_type='cat', dataframe=None, + headers=None, index_vals=None, **kwargs): """ Returns data for a scatterplot matrix. :param (array) df: array of the data with column headers :param (str) index: name of the index column in data array - :param (list|tuple) endpts: this param takes an increasing sequece - of numbers that form intervals on the real line. They are used - to make a numeric index categorical under 'theme = True' by - grouping the data into these intervals. It only affects the non- - diagonal plots - :param (str) diag: sets graph type for the main diagonal plots - :param (int|float) height: sets the height of the graph - :param (int|float) width: sets the width of the graph - :param (int or float >= 0) size: sets the marker size (in px) + :param (list|tuple) endpts: takes an increasing sequece of numbers + that defines intervals on the real line. They are used to group + the entries in an index of numbers into their corresponding + interval and therefore can be treated as categorical data + :param (str) diag: sets the chart type for the main diagonal plots + :param (int|float) height: sets the height of the chart + :param (int|float) width: sets the width of the chart + :param (float) size: sets the marker size (in px) :param (str) title: the title label of the scatterplot matrix - :param (bool) use_theme: determines if a theme is applied - :param (str|list) palette: either a plotly scale name, or a list - containing 2 triplets. These triplets must be of the form (a,b,c) - or 'rgb(x,y,z)' where a,b,c belong to the interval [0,1] and x,y,z - belong to [0,255] + :param (str|tuple|list|dict) colormap: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colormap is a list, it must contain valid color types as its + members. + If colormap is a dictionary, all the string entries in + the index column must be a key in colormap. In this case, the + colormap_type is forced to 'cat' or categorical + :param (str) colormap_type: determines how colormap is interpreted. + Valid choices are 'seq' (sequential) and 'cat' (categorical). If + 'seq' is selected, only the first two colors in colormap will be + considered (when colormap is a list) and the index values will be + linearly interpolated between those two colors. This option is + forced if all index values are numeric. + If 'cat' is selected, a color from colormap will be assigned to + each category from index, including the intervals if endpts is + being used :param (dict) **kwargs: a dictionary of scatterplot arguments The only forbidden parameters are 'size', 'color' and 'colorscale' in 'marker' @@ -2971,7 +3011,7 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, fig = FF.create_scatterplotmatrix(df) # Plot - py.iplot(fig, filename='Scatterplot Matrix') + py.iplot(fig, filename='Vanilla Scatterplot Matrix') ``` Example 2: Indexing a Column @@ -2992,13 +3032,13 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, 'grape', 'pear', 'pear', 'apple', 'pear']) # Create scatterplot matrix - fig = FF.create_scatterplotmatrix(df, index = 'Fruit', size = 10) + fig = FF.create_scatterplotmatrix(df, index='Fruit', size=10) # Plot - py.iplot(fig, filename = 'Scatterplot Matrix') + py.iplot(fig, filename = 'Scatterplot Matrix with Index') ``` - Example 3: Styling the diagonal subplots + Example 3: Styling the Diagonal Subplots ``` import plotly.plotly as py from plotly.graph_objs import graph_objs @@ -3016,14 +3056,14 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, 'grape', 'pear', 'pear', 'apple', 'pear']) # Create scatterplot matrix - fig = FF.create_scatterplotmatrix(df, diag = 'box', index = 'Fruit', - height = 1000, width = 1000) + fig = FF.create_scatterplotmatrix(df, diag='box', index='Fruit', + height=1000, width=1000) # Plot - py.iplot(fig, filename = 'Scatterplot Matrix') + py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling') ``` - Example 4: Use a theme to Styling the subplots + Example 4: Use a Theme to Style the Subplots ``` import plotly.plotly as py from plotly.graph_objs import graph_objs @@ -3038,15 +3078,15 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, # Create scatterplot matrix using a built-in # Plotly palette scale and indexing column 'A' - fig = FF.create_scatterplotmatrix(df, diag = 'histogram', index = 'A', - use_theme=True, palette = 'Blues', - height = 800, width = 800) + fig = FF.create_scatterplotmatrix(df, diag='histogram', + index='A', colormap='Blues', + height=800, width=800) # Plot - py.iplot(fig, filename = 'Scatterplot Matrix') + py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme') ``` - Example 5: Example 4 with interval factoring + Example 5: Example 4 with Interval Factoring ``` import plotly.plotly as py from plotly.graph_objs import graph_objs @@ -3061,15 +3101,57 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, # Create scatterplot matrix using a list of 2 rgb tuples # and endpoints at -1, 0 and 1 - fig = FF.create_scatterplotmatrix(df, diag = 'histogram', index = 'A', - use_theme=True, - palette = ['rgb(140, 255, 50)', - 'rgb(170, 60, 115)'], - endpts = [-1, 0, 1], - height = 800, width = 800) + fig = FF.create_scatterplotmatrix(df, diag='histogram', index='A', + colormap=['rgb(140, 255, 50)', + 'rgb(170, 60, 115)', + '#6c4774', + (0.5, 0.1, 0.8)], + endpts=[-1, 0, 1], + height=800, width=800) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Intervals') + ``` + + Example 6: Using the colormap as a Dictionary + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.tools import FigureFactory as FF + + import numpy as np + import pandas as pd + import random + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['Column A', + 'Column B', + 'Column C']) + + # Add new color column to dataframe + new_column = [] + strange_colors = ['turquoise', 'limegreen', 'goldenrod'] + + for j in range(100): + new_column.append(random.choice(strange_colors)) + df['Colors'] = pd.Series(new_column, index=df.index) + + # Create scatterplot matrix using a dictionary of hex color values + # which correspond to actual color names in 'Colors' column + fig = FF.create_scatterplotmatrix( + df, diag='box', index='Colors', + colormap= dict( + turquoise = '#00F5FF', + limegreen = '#32CD32', + goldenrod = '#DAA520' + ), + colormap_type='cat', + height=800, width=800 + ) # Plot - py.iplot(fig, filename = 'Scatterplot Matrix') + py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ') ``` """ # TODO: protected until #282 @@ -3079,9 +3161,154 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, headers = [] if index_vals is None: index_vals = [] + plotly_scales = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], + 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], + 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], + 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], + 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], + 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], + 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], + 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], + 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], + 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], + 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], + 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], + 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} FigureFactory._validate_scatterplotmatrix(df, index, diag, - **kwargs) + colormap_type, **kwargs) + + # Validate colormap + if colormap is None: + colormap = DEFAULT_PLOTLY_COLORS + + if isinstance(colormap, str): + if colormap in plotly_scales: + colormap = plotly_scales[colormap] + + elif 'rgb' in colormap: + colormap = FigureFactory._unlabel_rgb(colormap) + for value in colormap: + if value > 255.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in your " + "rgb colormap " + "tuples cannot " + "exceed 255.0.") + colormap = FigureFactory._label_rgb(colormap) + + # put colormap in list + colors_list = [] + colors_list.append(colormap) + colormap = colors_list + + elif '#' in colormap: + colormap = FigureFactory._hex_to_rgb(colormap) + colormap = FigureFactory._label_rgb(colormap) + + # put colormap in list + colors_list = [] + colors_list.append(colormap) + colormap = colors_list + + else: + scale_keys = list(plotly_scales.keys()) + raise exceptions.PlotlyError("If you input a string " + "for 'colormap', it must " + "either be a Plotly " + "colorscale, an 'rgb' " + "color or a hex color." + "Valid plotly colorscale " + "names are {}".format(scale_keys)) + elif isinstance(colormap, tuple): + for value in colormap: + if value > 1.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in " + "your colormap " + "tuples cannot " + "exceed 1.0.") + + colors_list = [] + colors_list.append(colormap) + colormap = colors_list + + colormap = FigureFactory._convert_to_RGB_255(colormap) + colormap = FigureFactory._label_rgb(colormap) + + elif isinstance(colormap, list): + new_colormap = [] + for color in colormap: + if 'rgb' in color: + color = FigureFactory._unlabel_rgb(color) + + for value in color: + if value > 255.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in your " + "rgb colormap " + "tuples cannot " + "exceed 255.0.") + + color = FigureFactory._label_rgb(color) + new_colormap.append(color) + elif '#' in color: + color = FigureFactory._hex_to_rgb(color) + color = FigureFactory._label_rgb(color) + new_colormap.append(color) + elif isinstance(color, tuple): + for value in color: + if value > 1.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in " + "your colormap " + "tuples cannot " + "exceed 1.0.") + color = FigureFactory._convert_to_RGB_255(color) + color = FigureFactory._label_rgb(color) + new_colormap.append(color) + colormap = new_colormap + + elif isinstance(colormap, dict): + for name in colormap: + if 'rgb' in colormap[name]: + color = FigureFactory._unlabel_rgb(colormap[name]) + for value in color: + if value > 255.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in your " + "rgb colormap " + "tuples cannot " + "exceed 255.0.") + + elif '#' in colormap[name]: + color = FigureFactory._hex_to_rgb(colormap[name]) + color = FigureFactory._label_rgb(color) + colormap[name] = color + + elif isinstance(colormap[name], tuple): + for value in colormap[name]: + if value > 1.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in " + "your colormap " + "tuples cannot " + "exceed 1.0.") + color = FigureFactory._convert_to_RGB_255(colormap[name]) + color = FigureFactory._label_rgb(color) + colormap[name] = color + + else: + raise exceptions.PlotlyError("You must input a valid colormap. " + "Valid types include a plotly scale, " + "rgb, hex or tuple color, a list of " + "any color types, or a dictionary " + "with index names each assigned " + "to a color.") if not index: for name in df: headers.append(name) @@ -3106,25 +3333,42 @@ def create_scatterplotmatrix(df, dataframe=None, headers=None, headers.append(name) for name in headers: dataframe.append(df[name].values.tolist()) - # Check for same data-type in df columns + + # check for same data-type in each df column FigureFactory._validate_dataframe(dataframe) FigureFactory._validate_index(index_vals) - if use_theme is False: - figure = FigureFactory._scatterplot_index(dataframe, headers, - diag, size, - height, width, - title, index, - index_vals, - **kwargs) + # check if all colormap keys are in the index + # if colormap is a dictionary + if isinstance(colormap, dict): + for key in colormap: + if not all(index in colormap for index in index_vals): + raise exceptions.PlotlyError("If colormap is a " + "dictionary, all the " + "names in the index " + "must be keys.") + + figure = FigureFactory._scatterplot_dict(dataframe, + headers, + diag, + size, height, + width, title, + index, + index_vals, + endpts, + colormap, + colormap_type, + **kwargs) return figure + else: figure = FigureFactory._scatterplot_theme(dataframe, headers, diag, size, height, width, title, index, index_vals, - endpts, palette, + endpts, colormap, + colormap_type, **kwargs) return figure