diff --git a/test/test_utils.py b/test/test_utils.py index 5c0502d7bb5..1439a974368 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -162,6 +162,9 @@ def test_draw_invalid_boxes(): "colors", [ None, + "blue", + "#FF00FF", + (1, 34, 122), ["red", "blue"], ["#FF00FF", (1, 34, 122)], ], @@ -191,6 +194,8 @@ def test_draw_segmentation_masks(colors, alpha): if colors is None: colors = utils._generate_color_palette(num_masks) + elif isinstance(colors, str) or isinstance(colors, tuple): + colors = [colors] # Make sure each mask draws with its own color for mask, color in zip(masks, colors): diff --git a/torchvision/utils.py b/torchvision/utils.py index 6e8d63b1d7e..a71e0f234b4 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -160,9 +160,9 @@ def draw_bounding_boxes( the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and `0 <= ymin < ymax < H`. labels (List[str]): List containing the labels of bounding boxes. - colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors - or a single color for all of the bounding boxes. The colors can be represented as `str` or - `Tuple[int, int, int]`. + colors (color or list of colors, optional): List containing the colors + of the boxes or single color for all boxes. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. fill (bool): If `True` fills the bounding box with specified color. width (int): Width of bounding box. font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may @@ -231,7 +231,7 @@ def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, - colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, ) -> torch.Tensor: """ @@ -243,10 +243,10 @@ def draw_segmentation_masks( masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. - colors (list or None): List containing the colors of the masks. The colors can - be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. - When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list - with one element. By default, random colors are generated for each mask. + colors (color or list of colors, optional): List containing the colors + of the masks or single color for all masks. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for each mask. Returns: img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. @@ -289,8 +289,7 @@ def draw_segmentation_masks( for color in colors: if isinstance(color, str): color = ImageColor.getrgb(color) - color = torch.tensor(color, dtype=out_dtype) - colors_.append(color) + colors_.append(torch.tensor(color, dtype=out_dtype)) img_to_draw = image.detach().clone() # TODO: There might be a way to vectorize this @@ -301,6 +300,6 @@ def draw_segmentation_masks( return out.to(out_dtype) -def _generate_color_palette(num_masks): +def _generate_color_palette(num_masks: int): palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) return [tuple((i * palette) % 255) for i in range(num_masks)]