From ed86cb89524a91eb3f30e9b5ca9e1978ed811cf0 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Wed, 21 Aug 2024 16:18:27 -0400 Subject: [PATCH 1/7] adding detector --- .gitignore | 3 +- directai_fastapi/.env | 1 + directai_fastapi/Dockerfile | 4 +- directai_fastapi/modeling/image_classifier.py | 241 +++++++ directai_fastapi/modeling/object_detector.py | 631 ++++++++++++++++++ directai_fastapi/modeling/tensor_utils.py | 122 ++++ directai_fastapi/requirements.txt | 8 +- directai_fastapi/unit_tests/test.py | 4 + .../unit_tests/test_modules/test_detector.py | 54 ++ docker-compose.yml | 25 +- redis_data/redis_entrypoint.sh | 2 +- 11 files changed, 1081 insertions(+), 14 deletions(-) create mode 100644 directai_fastapi/.env create mode 100644 directai_fastapi/modeling/image_classifier.py create mode 100644 directai_fastapi/modeling/object_detector.py create mode 100644 directai_fastapi/modeling/tensor_utils.py create mode 100644 directai_fastapi/unit_tests/test_modules/test_detector.py diff --git a/.gitignore b/.gitignore index 8c61b8b..c3f8a1b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ redis_data/appendonlydir/ -logs/** \ No newline at end of file +logs/** +.cache \ No newline at end of file diff --git a/directai_fastapi/.env b/directai_fastapi/.env new file mode 100644 index 0000000..54b8adb --- /dev/null +++ b/directai_fastapi/.env @@ -0,0 +1 @@ +LOGGING_LEVEL= diff --git a/directai_fastapi/Dockerfile b/directai_fastapi/Dockerfile index 9be78c2..6fe6666 100644 --- a/directai_fastapi/Dockerfile +++ b/directai_fastapi/Dockerfile @@ -1,9 +1,11 @@ -FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime +FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel WORKDIR /directai_fastapi RUN apt-get update RUN apt-get install libgl1 libglib2.0-0 libsm6 libxrender1 libxext6 -y +RUN apt-get install git -y + RUN apt-get install cmake build-essential -y COPY requirements.txt . 
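# NOTE: the switch to the -devel base image and the git install are presumably
# needed so that flash-attn (added to requirements.txt in this patch) can compile
# its CUDA extension with nvcc at install time; the commit does not state this
# rationale, so treat it as an assumption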
RUN pip install -r requirements.txt diff --git a/directai_fastapi/modeling/image_classifier.py b/directai_fastapi/modeling/image_classifier.py new file mode 100644 index 0000000..074a876 --- /dev/null +++ b/directai_fastapi/modeling/image_classifier.py @@ -0,0 +1,241 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch_scatter import scatter_max # type: ignore +import open_clip # type: ignore +from functools import partial + +from modeling.tensor_utils import ( + batch_encode_cache_missed_list_elements, + image_bytes_to_tensor, + squish_labels, +) +from modeling.prompt_templates import noop_hypothesis_formats, many_hypothesis_formats +from lru import LRU + + +class ZeroShotImageClassifierWithFeedback(nn.Module): + def __init__( + self, + base_model_name: str = "ViT-H-14-quickgelu", + dataset_name: str = "dfn5b", + max_text_batch_size: int = 256, + max_image_batch_size: int = 256, + device: torch.device | str = "cuda", + lru_cache_size: int = 4096, # set to 0 to disable caching + jit: bool = True, + fp16: bool = True, + ): + super().__init__() + + self.device = torch.device(device) if type(device) is str else device + self.fp16 = fp16 + + # TODO: just do create_model, not create_model_and_transforms + self.model, _, _ = open_clip.create_model_and_transforms( + base_model_name, + pretrained=dataset_name, + jit=jit, + image_resize_mode="squash", + precision="fp16" if fp16 else "fp32", + ) + self.tokenizer = open_clip.get_tokenizer(base_model_name) + + self.model = self.model.to(self.device) + self.model.eval() + + self.max_text_batch_size = max_text_batch_size + self.max_image_batch_size = max_image_batch_size + + # we cache the text embeddings to avoid recomputing them + # we use an LRU cache to avoid running out of memory + # especially because the tensors are likely to be large and stored in GPU memory + self.augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + self.not_augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + + preprocess_config = self.model.visual.preprocess_cfg + self.img_mean = ( + torch.tensor(preprocess_config["mean"]).view(1, 3, 1, 1).to(self.device) + ) + self.img_std = ( + torch.tensor(preprocess_config["std"]).view(1, 3, 1, 1).to(self.device) + ) + self.img_size = preprocess_config["size"] + if type(self.img_size) is int: + self.img_size = (self.img_size, self.img_size) + + def encode_image(self, image: torch.Tensor | bytes) -> torch.Tensor: + # accept the raw file bytes instead of a PIL image to save bandwidth during remote gRPC calls + if isinstance(image, bytes): + image = image_bytes_to_tensor(image, self.img_size) + + if len(image.shape) == 3: + image = image.unsqueeze(0) + + image = image.to(self.device) + + # NOTE: we are doing the normalization here instead of the data loader + # to take advantage of easier access to the model's specific normalization values + image = image.float() / 255.0 + image -= self.img_mean + image /= self.img_std + + if self.fp16: + image = image.half() + + feature_list = [] + for i in range(0, image.size(0), self.max_image_batch_size): + features_subset = self.model.encode_image( + image[i : i + self.max_image_batch_size] + ) + feature_list.append(features_subset) + + features = torch.cat(feature_list, dim=0) + features /= torch.norm(features, dim=1, keepdim=True) + + return features + + def _encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + # we apply the prompt
templates commonly used with CLIP-based models unless otherwise specified + templates = many_hypothesis_formats if augment else noop_hypothesis_formats + augmented_text = [template.format(t) for t in text for template in templates] + + tokenized = self.tokenizer(augmented_text).to(self.device) + + features_list = [] + for i in range(0, len(tokenized), self.max_text_batch_size): + features_subset = self.model.encode_text( + tokenized[i : i + self.max_text_batch_size] + ) + features_list.append(features_subset) + + features = torch.cat(features_list, dim=0) + features /= torch.norm(features, dim=1, keepdim=True) + features = features.view(len(text), len(templates), features.shape[1]) + features = features.mean(dim=1) + features /= torch.norm(features, dim=1, keepdim=True) + + return features + + def encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + if augment: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=True), + text, + self.augmented_label_encoding_cache, + ) + else: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=False), + text, + self.not_augmented_label_encoding_cache, + ) + + def forward( + self, + image: torch.Tensor | bytes, + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]] | None = None, + augment_examples: bool = True, + ) -> torch.Tensor: + # run an image classifier parameterized by explicit statements on what each label should include or exclude + # return a tensor of scores for each label + # each label must include at least one sub-label, and may exclude any number of sub-labels + + if len(labels) == 0: + raise ValueError("At least one label must be provided") + + if any([len(sub_labels) == 0 for sub_labels in inc_sub_labels_dict.values()]): + raise ValueError("Each label must include at least one sub-label") + + image_features = self.encode_image(image) + + exc_sub_labels_dict = {} if exc_sub_labels_dict is None else exc_sub_labels_dict + # filter out empty excs lists + exc_sub_labels_dict = { + label: excs for label, excs in exc_sub_labels_dict.items() if len(excs) > 0 + } + + all_labels, all_labels_to_inds = squish_labels( + labels, inc_sub_labels_dict, exc_sub_labels_dict + ) + text_features = self.encode_text(all_labels, augment=augment_examples) + + scores = (1 + image_features @ text_features.T) / 2 + + label_to_ind = {label: i for i, label in enumerate(labels)} + + pos_labels_to_master_inds, pos_labels_list = zip( + *[ + v + for label, incs in inc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(incs), incs) + ] + ) + pos_labels_inds = [all_labels_to_inds[label] for label in pos_labels_list] + + pos_scores = scores[:, pos_labels_inds] + + # pos_labels_to_master_inds indicates which indices we should be taking the max over for each label + # since our scatter_max will be batched, we need to offset this for each image + num_labels = len(labels) + num_images = image_features.shape[0] + num_incs = len(pos_labels_to_master_inds) + offsets = ( + torch.arange(num_images).unsqueeze(1).expand(-1, num_incs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = ( + torch.tensor(pos_labels_to_master_inds).to(self.device).repeat(num_images) + + offsets + ) + + max_pos_scores_flat, _ = scatter_max( + pos_scores.view(-1), indices_for_max, dim_size=num_images * num_labels + ) + max_pos_scores = max_pos_scores_flat.view(num_images, num_labels) + + # compute the same for the 
negative labels, if any + if len(exc_sub_labels_dict) > 0: + neg_labels_to_master_inds, neg_labels_list = zip( + *[ + v + for label, excs in exc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(excs), excs) + ] + ) + neg_labels_inds = [all_labels_to_inds[label] for label in neg_labels_list] + + neg_scores = scores[:, neg_labels_inds] + + num_excs = len(neg_labels_to_master_inds) + offsets = ( + torch.arange(num_images).unsqueeze(1).expand(-1, num_excs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = ( + torch.tensor(neg_labels_to_master_inds) + .to(self.device) + .repeat(num_images) + + offsets + ) + + max_neg_scores_flat, _ = scatter_max( + neg_scores.view(-1), indices_for_max, dim_size=num_images * num_labels + ) + max_neg_scores = max_neg_scores_flat.view(num_images, num_labels) + + raw_scores = torch.where( + max_pos_scores > max_neg_scores, max_pos_scores, 1 - max_neg_scores + ) + else: + raw_scores = max_pos_scores + + return raw_scores diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py new file mode 100644 index 0000000..5b03d41 --- /dev/null +++ b/directai_fastapi/modeling/object_detector.py @@ -0,0 +1,631 @@ +from typing import List, Optional, Tuple +import torch +from PIL import Image +import torch +from torch import nn +import torchvision +import numpy as np +from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel +from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention +import time +from typing import Union +from torch_scatter import scatter_max +from flash_attn import flash_attn_func +import io +from lru import LRU +from functools import partial + +from modeling.prompt_templates import medium_hypothesis_formats, noop_hypothesis_formats +from modeling.tensor_utils import ( + batch_encode_cache_missed_list_elements, + resize_pil_image, + squish_labels, +) + + +def created_padded_tensor_from_bytes( + image_bytes: bytes, image_size: tuple[int, int] +) -> tuple[torch.Tensor, torch.Tensor]: + padded_image_tensor = torch.ones((1, 3, *image_size)) * 114.0 + + # TODO: add nonblocking streaming to GPU + + image_buffer = io.BytesIO(image_bytes) + pil_image = Image.open(image_buffer) + current_size = pil_image.size + + r = min(image_size[0] / current_size[0], image_size[1] / current_size[1]) + target_size = (int(r * current_size[0]), int(r * current_size[1])) + + pil_image = resize_pil_image(pil_image, target_size) + + np_image = np.asarray(pil_image) + torch_image = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0) + + padded_image_tensor[:, :, :torch_image.shape[2], :torch_image.shape[3]] = torch_image + + image_scale_ratios = torch.tensor([r,]) + + return padded_image_tensor, image_scale_ratios + + +def flash_attn_owl_vit_encoder_forward( + self: Owlv2Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, None]: + assert not output_attentions, "output_attentions not supported for flash attention implementation" + assert attention_mask is None, "attention_mask not supported for flash attention implementation" + # technically flash_attn DOES support causal attention + # but the OWL usage of causal attention mask does not limit it to true causal attention + # we don't support generalized attention, so we're just going to assert causal attention mask is ALSO None + assert 
causal_attention_mask is None, "causal_attention_mask not supported for flash attention implementation" + + bsz, tgt_len, embed_dim = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) + key_states = key_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) + + # convert to appropriate dtype + # NOTE: bf16 may be more appropriate than fp16 + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=0, + softmax_scale=self.scale, + ) + + attn_output = attn_output.view(bsz, tgt_len, embed_dim) + + # convert back to appropriate dtype + attn_output = attn_output.to(hidden_states.dtype) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +# we're copying the function signature from the original +# and just replacing the method with a faster one based on flash_attn +# we could subclass the original, but that would require us to subclass the entire model +# so we're just going to monkey patch it, as the output should be identical with the same inputs +# Owlv2Attention.forward = flash_attn_owl_vit_encoder_forward +# for owlv2_vision_model_encoder_layer in Owlv2VisionModel.vision_model.encoder.layers: +# owlv2_vision_model_encoder_layer.self_attn.forward = flash_attn_owl_vit_encoder_forward + + +class VisionModelWrapper(nn.Module): + def __init__(self, vision_model: Owlv2VisionModel) -> None: + super().__init__() + + self.vision_model = vision_model + + # we're going to monkey patch the forward method of the attention layers + # to replace it with a faster one based on flash_attn + # the alternative is to subclass the entire model, but that's a lot of work + # so we're just going to define a replacement with the same function signature + # and assert that the input is as expected + for owlv2_vision_model_encoder_layer in self.vision_model.encoder.layers: + owlv2_vision_model_encoder_layer.self_attn.forward = partial( + flash_attn_owl_vit_encoder_forward, owlv2_vision_model_encoder_layer.self_attn + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + vision_outputs = self.vision_model(pixel_values=image, return_dict=True) + + # Get image embedding + last_hidden_states = vision_outputs[0] + image_embeds = self.vision_model.post_layernorm(last_hidden_states) + + return image_embeds + + +class WrappedImageEmbedder(nn.Module): + def __init__(self, model: Owlv2ForObjectDetection) -> None: + super().__init__() + + self.model = model + self.wrapped_vision_model = VisionModelWrapper(self.model.owlv2.vision_model) + + def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image_embeds = self.wrapped_vision_model(image) + + # Resize class token + new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0))) + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.model.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches, num_patches, hidden_size] + new_size = ( + image_embeds.shape[0], + 
int(np.sqrt(image_embeds.shape[1])), + int(np.sqrt(image_embeds.shape[1])), + image_embeds.shape[-1], + ) + feature_map = image_embeds.reshape(new_size) + + # Get class head features + # we do dot prod between image_class_embeds and query embeddings, then do (pred + shift) * scale + image_class_embeds = self.model.class_head.dense0(image_embeds) + image_class_embeds = image_class_embeds / image_class_embeds.norm(dim=-1, keepdim=True, p=2) + logit_shift = self.model.class_head.logit_shift(image_embeds) + logit_scale = self.model.class_head.logit_scale(image_embeds) + logit_scale = self.model.class_head.elu(logit_scale) + 1 + + # Get box head features + # NOTE: these boxes come out in normalized cx_cy_w_h format; they are converted downstream + pred_boxes = self.model.box_predictor(image_embeds, feature_map) + + # filter out patches that are unlikely to map to objects + # the paper takes the top 10% during training, but we'll take the top 300 to be more in line with DETR + objectness_scores = self.model.objectness_predictor(image_embeds) + # compute the top 300 objectness indices + indices = torch.topk(objectness_scores, 300, dim=1).indices + # keep only the entries corresponding to those indices + image_class_embeds = image_class_embeds.gather(1, indices.unsqueeze(-1).expand(-1, -1, image_class_embeds.shape[-1])) + logit_shift = logit_shift.gather(1, indices.unsqueeze(-1).expand(-1, -1, logit_shift.shape[-1])) + logit_scale = logit_scale.gather(1, indices.unsqueeze(-1).expand(-1, -1, logit_scale.shape[-1])) + pred_boxes = pred_boxes.gather(1, indices.unsqueeze(-1).expand(-1, -1, pred_boxes.shape[-1])) + + assert image_class_embeds.shape[1] == 300 + + return image_class_embeds, logit_shift, logit_scale, pred_boxes + + +class ZeroShotObjectDetectorWithFeedback(nn.Module): + def __init__( + self, + model_name: str = "google/owlv2-large-patch14-ensemble", + image_size: tuple[int, int] = (1008, 1008), + device: torch.device | str = "cuda", + lru_cache_size: int = 4096, + jit: bool = True, + ): + super().__init__() + + self.device = device + self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(device).eval() + self.processor = Owlv2Processor.from_pretrained(model_name) + + if jit: + self.wrapped_image_embedder = torch.jit.trace_module( + WrappedImageEmbedder(self.model), + {"forward": (torch.randn(1, 3, *image_size, device=device),)}, + ) + else: + self.wrapped_image_embedder = WrappedImageEmbedder(self.model) + + # we cache the text embeddings to avoid recomputing them + # we use an LRU cache to avoid running out of memory + # especially because the tensors are likely to be large and stored in GPU memory + self.augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + self.not_augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + + self.image_size = image_size + self.rgb_means = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1) + self.rgb_stds = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1) + + def _encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + # NOTE: object detector literature tends to use fewer templates than image classifiers + templates = medium_hypothesis_formats if augment else noop_hypothesis_formats + augmented_text = [template.format(t) for t in text for template in templates] + + processor_output = self.processor(text=augmented_text, return_tensors="pt", padding=True, truncation=True) + input_ids = processor_output.input_ids.to(self.device) + attn_mask =
processor_output.attention_mask.to(self.device) + + # TODO: add appropriate batching to avoid OOM + text_output = self.model.owlv2.text_model( + input_ids=input_ids, attention_mask=attn_mask, return_dict=True + ) + + embeddings = text_output[1] + embeddings = self.model.owlv2.text_projection(embeddings) + embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) + + embeddings = embeddings.reshape(len(text), len(templates), embeddings.shape[1]) + embeddings = embeddings.mean(dim=1) + embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) + + return embeddings + + def encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + if augment: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=True), + text, + self.augmented_label_encoding_cache, + ) + else: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=False), + text, + self.not_augmented_label_encoding_cache, + ) + + def get_image_data(self, image: torch.Tensor) -> dict[str, torch.Tensor]: + # we do the normalization here to make sure we have access to the right parameters + image = image / 255.0 + image = (image - self.rgb_means) / self.rgb_stds + + image_class_embeds, logit_shift, logit_scale, pred_boxes = self.wrapped_image_embedder(image) + + return { + "image_class_embeds": image_class_embeds, + "logit_shift": logit_shift, + "logit_scale": logit_scale, + "pred_boxes": pred_boxes, + } + + def forward( + self, + image: torch.Tensor | bytes, + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]] | None = None, + label_conf_thres: dict[str, float] | None = None, + augment_examples: bool = True, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, + image_scale_ratios: torch.Tensor | None = None, + ) -> list[list[torch.Tensor]]: + assert not run_class_agnostic_nms, "Class-agnostic NMS not yet implemented" + + if isinstance(image, bytes): + assert image_scale_ratios is None, "image_scale_ratios must be None if image is bytes as we define the scale internally" + image_tensor, image_scale_ratios = created_padded_tensor_from_bytes(image, self.image_size) + else: + assert image_scale_ratios is not None, "image_scale_ratios must be provided if image is a tensor as we cannot derive the scale internally" + image_tensor = image + + if label_conf_thres is None: + label_conf_thres = {} + + if len(labels) == 0: + raise ValueError("At least one label must be provided") + + if any([len(sub_labels) == 0 for sub_labels in inc_sub_labels_dict.values()]): + raise ValueError("Each label must include at least one sub-label") + + image_tensor = image_tensor.to(self.device) + + image_data = self.get_image_data(image_tensor) + + exc_sub_labels_dict = {} if exc_sub_labels_dict is None else exc_sub_labels_dict + # filter out empty excs lists + exc_sub_labels_dict = { + label: excs for label, excs in exc_sub_labels_dict.items() if len(excs) > 0 + } + + all_labels, all_labels_to_inds = squish_labels( + labels, inc_sub_labels_dict, exc_sub_labels_dict + ) + text_features = self.encode_text(all_labels, augment=augment_examples) + + scores_by_image_and_box = compute_query_fit( + text_features, + image_data["image_class_embeds"], + image_data["logit_shift"], + image_data["logit_scale"], + ) + # NOTE that scores_by_image_and_box is of shape [num_images, num_boxes, len(all_labels)] + # when extracting the per-box pro and con scores, we don't care about differentiating the first two
dimensions + # so we flatten them to make the scatter_max operation easier + # and then we reshape them back to the original shape + scores = scores_by_image_and_box.view(-1, len(all_labels)) + # now we can proceed in the same way as the image classifier + + label_to_ind = {label: i for i, label in enumerate(labels)} + + pos_labels_to_master_inds, pos_labels_list = zip( + *[ + v + for label, incs in inc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(incs), incs) + ] + ) + pos_labels_inds = [all_labels_to_inds[label] for label in pos_labels_list] + + pos_scores = scores[:, pos_labels_inds] + + # pos_labels_to_master_inds indicates which indices we should be taking the max over for each label + # since our scatter_max will be batched, we need to offset this for each box + num_labels = len(labels) + num_boxes = scores.shape[0] + num_incs = len(pos_labels_to_master_inds) + offsets = ( + torch.arange(num_boxes).unsqueeze(1).expand(-1, num_incs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = ( + torch.tensor(pos_labels_to_master_inds).to(self.device).repeat(num_boxes) + + offsets + ) + + max_pos_scores_flat, _ = scatter_max( + pos_scores.view(-1), indices_for_max, dim_size=num_boxes * num_labels + ) + max_pos_scores = max_pos_scores_flat.view(num_boxes, num_labels) + + # compute the same for the negative labels, if any + if len(exc_sub_labels_dict) > 0: + neg_labels_to_master_inds, neg_labels_list = zip( + *[ + v + for label, excs in exc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(excs), excs) + ] + ) + neg_labels_inds = [all_labels_to_inds[label] for label in neg_labels_list] + + neg_scores = scores[:, neg_labels_inds] + + num_excs = len(neg_labels_to_master_inds) + offsets = ( + torch.arange(num_boxes).unsqueeze(1).expand(-1, num_excs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = ( + torch.tensor(neg_labels_to_master_inds) + .to(self.device) + .repeat(num_boxes) + + offsets + ) + + max_neg_scores_flat, _ = scatter_max( + neg_scores.view(-1), indices_for_max, dim_size=num_boxes * num_labels + ) + max_neg_scores = max_neg_scores_flat.view(num_boxes, num_labels) + else: + # if we have no negative labels, we just set the max neg scores to zero + # NOTE: possible to speed things up by skipping the ops conditional on having negative labels + max_neg_scores = torch.zeros_like(max_pos_scores) + + # now reshape the scores to [num_images, num_boxes, num_labels] + max_pos_scores = max_pos_scores.view(image_data["pred_boxes"].shape[0], image_data["pred_boxes"].shape[1], num_labels) + max_neg_scores = max_neg_scores.view(image_data["pred_boxes"].shape[0], image_data["pred_boxes"].shape[1], num_labels) + + # unlike the image classifier, we have to suppress boxes based on the scores of their neighbors + # we do this via a modified NMS algorithm + # because it operates over a variable-sized graph of boxes, it's hard to vectorize + # so we dump it into a script function that does fork-based async processing + # the output is a per-image list of per-object boxes in tlbr-score format + batched_predicted_boxes = batched_run_nms_based_box_suppression_for_all_objects( + max_pos_scores, + max_neg_scores, + image_data["pred_boxes"], + image_tensor.shape[2] / image_scale_ratios, + torch.tensor([label_conf_thres.get(label, 0.0) for label in labels], device=self.device), + nms_thre, + run_class_agnostic_nms, + ) + + return batched_predicted_boxes + + +@torch.jit.script +def compute_query_fit( + 
query_embeds: torch.Tensor, + image_class_embeds: torch.Tensor, + logit_shift: torch.Tensor, + logit_scale: torch.Tensor, +) -> torch.Tensor: + # Compute query fit + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + pred_logits = (pred_logits + logit_shift) * logit_scale + + return torch.sigmoid(pred_logits) + + +@torch.jit.script +def compute_iou_adjacency_list(boxes: torch.Tensor, nms_thre: float) -> list[torch.Tensor]: + boxes = boxes.clone() + # convert in place from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2) for torchvision's IoU + boxes[:, 0] -= boxes[:, 2] / 2 + boxes[:, 1] -= boxes[:, 3] / 2 + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + ious = torchvision.ops.box_iou(boxes, boxes) + ious = ious >= nms_thre + # Set diagonal elements to zero: no self-loops! + ious.fill_diagonal_(0) + + edges = torch.nonzero(ious).unbind(-1) + + return edges + + +@torch.jit.script +def find_in_sorted_tensor(sorted_tensor: torch.Tensor, query: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + indices = torch.searchsorted(sorted_tensor, query) + indices_clamped = torch.clamp(indices, max=sorted_tensor.size(0) - 1) + present = sorted_tensor[indices_clamped] == query + in_bounds = indices < sorted_tensor.size(0) + found_mask = present & in_bounds + return found_mask, indices + + +@torch.jit.script +def compute_candidate_nms_via_adjacency_list(pro_max: torch.Tensor, con_max: torch.Tensor, adjacency_list: list[torch.Tensor], conf_thre: float) -> torch.Tensor: + # we use a scatter_max to efficiently compute, for each bounding box, the max con score of its adjacent boxes + expanded_con_max = con_max[adjacency_list[0]] + adjacent_con_max = scatter_max(expanded_con_max, adjacency_list[1], dim_size=con_max.shape[0])[0] + # we then filter down to boxes that both exceed the confidence threshold and are not suppressed by negative examples + # we do this by filtering on three expressions: + # 1. pro_max >= conf_thre: the box has a high enough confidence + # 2. pro_max >= adjacent_con_max: the box's confidence exceeds the negative confidence of every adjacent box + # 3.
pro_max >= con_max: the box's confidence exceeds its own negative confidence + # NOTE: we could perhaps make this more efficient by filtering out the easy ones prior to the scatter_max + pro_valid = (pro_max >= conf_thre) * (pro_max >= adjacent_con_max) * (pro_max >= con_max) + pro_valid_inds = pro_valid.nonzero().squeeze(1) + + if pro_valid_inds.numel() == 0 or adjacency_list[0].numel() == 0: + # no boxes are valid or no boxes have any overlap with any other boxes + # either way, we can skip the NMS step + return pro_valid_inds + + # remove reported overlaps with boxes that are not valid + # this shrinks the graph that we need to do NMS over + first_node_valid, _ = find_in_sorted_tensor(pro_valid_inds, adjacency_list[0]) + second_node_valid, _ = find_in_sorted_tensor(pro_valid_inds, adjacency_list[1]) + + nms_inds = torch.nonzero(first_node_valid * second_node_valid).squeeze(1) + modified_adjacency_list = [adjacency_list[0][nms_inds], adjacency_list[1][nms_inds]] + + if nms_inds.numel() == 0: + # none of the remaining boxes have any overlap with any other remaining boxes + # so we can skip the NMS step + survive_inds = pro_valid.nonzero().squeeze(1) + return survive_inds + + # we compute the indices of the start and end of each box's adjacent boxes + # since our graph representation is just a list of edges, we would like to know which edges correspond to which nodes + # as the first node is sorted, we can just take the difference between adjacent nodes to get the start and end of each node's edges + first_node = modified_adjacency_list[0] + zero_tensor = torch.tensor([0], device=pro_max.device) + change_inds = (first_node[1:] != first_node[:-1]).nonzero()[:, 0] + len_tensor = torch.tensor([first_node.shape[0]], device=pro_max.device) + inds_of_adj_boxes = torch.cat([zero_tensor, change_inds + 1, len_tensor]) + + # we then run an NMS pass over the (remaining) boxes + sorted_pro_valid_inds = pro_valid_inds[pro_max[pro_valid].argsort(descending=True)] + + # check which boxes have graph connections + unique_nodes = first_node[inds_of_adj_boxes[:-1]] + has_connection, graph_node_indices = find_in_sorted_tensor(unique_nodes, sorted_pro_valid_inds) + connected_sorted_pro_valid_inds = sorted_pro_valid_inds[has_connection] + graph_indices = graph_node_indices[has_connection] + + for i, j in zip(connected_sorted_pro_valid_inds, graph_indices): + if pro_valid[i] == 0: + continue + + remapped_start_ind = inds_of_adj_boxes[j] + remapped_end_ind = inds_of_adj_boxes[j+1] + adj_boxes = modified_adjacency_list[1][remapped_start_ind:remapped_end_ind] + pro_valid[adj_boxes] = 0 + + survive_inds = pro_valid.nonzero().squeeze(1) + + return survive_inds + + +@torch.jit.script +def run_nms_based_box_suppression_for_one_object( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + adjacency_list: list[torch.Tensor], + image_scale: float, + conf_thre: float = 0.001, +) -> torch.Tensor: + survive_inds = compute_candidate_nms_via_adjacency_list(pro_max, con_max, adjacency_list, conf_thre) + + boxes = pred_boxes.squeeze(0)[survive_inds] + logits = pro_max[survive_inds] + + logits = logits.unsqueeze(-1) + boxes = boxes * image_scale + + # Convert boxes from center_x, center_y, width, height (cx_cy_w_h) to top_left_x, top_left_y, bottom_right_x, bottom_right_y (tlbr) + cx, cy, w, h = boxes.unbind(-1) + tl_x = cx - 0.5 * w + tl_y = cy - 0.5 * h + br_x = cx + 0.5 * w + br_y = cy + 0.5 * h + boxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=-1) + + boxes_with_scores = torch.cat([boxes, logits],
dim=-1) + boxes_with_scores = boxes_with_scores[boxes_with_scores[:, 4].argsort(descending=True)] + + return boxes_with_scores + + +@torch.jit.script +def run_nms_based_box_suppression_for_all_objects( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + image_scale: float, + conf_thres: torch.Tensor, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, +) -> list[torch.Tensor]: + # pred_boxes is assumed to be [num_boxes, 4] + # pro_max and con_max are assumed to be [num_boxes, num_objects] + # conf_thres is assumed to be [num_objects] + adjacency_list = compute_iou_adjacency_list(pred_boxes, nms_thre) + + futures = [ + torch.jit.fork( + run_nms_based_box_suppression_for_one_object, + pro_max[:, i], + con_max[:, i], + pred_boxes, + adjacency_list, + image_scale, + conf_thres[i], + ) for i in range(pro_max.shape[1]) + ] + + predicted_boxes = [torch.jit.wait(fut) for fut in futures] + + # TODO: add class-agnostic NMS + assert not run_class_agnostic_nms, "Class-agnostic NMS not yet implemented" + + return predicted_boxes + + +@torch.jit.script +def batched_run_nms_based_box_suppression_for_all_objects( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + image_scales: torch.Tensor, + conf_thres: torch.Tensor, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, +) -> list[list[torch.Tensor]]: + # pred_boxes is assumed to be [num_images, num_boxes, 4] + # pro_max and con_max are assumed to be [num_images, num_boxes, num_objects] + # conf_thres is assumed to be [num_objects] + # image_scales is assumed to be [num_images] + futures = [ + torch.jit.fork( + run_nms_based_box_suppression_for_all_objects, + pro_max[i], + con_max[i], + pred_boxes[i], + image_scales[i].item(), + conf_thres, + nms_thre, + run_class_agnostic_nms, + ) for i in range(pro_max.shape[0]) + ] + + batched_predicted_boxes = [torch.jit.wait(fut) for fut in futures] + + return batched_predicted_boxes \ No newline at end of file diff --git a/directai_fastapi/modeling/tensor_utils.py b/directai_fastapi/modeling/tensor_utils.py new file mode 100644 index 0000000..cf160ae --- /dev/null +++ b/directai_fastapi/modeling/tensor_utils.py @@ -0,0 +1,122 @@ +import torch +from typing import Callable +import numpy as np +from PIL import Image +import io +from lru import LRU + + +def batch_encode_cache_missed_list_elements( + encode_fn: Callable[[list], torch.Tensor], args_list: list, cache: dict | LRU | None +) -> torch.Tensor: + if len(args_list) == 0: + raise ValueError("args_list should not be empty") + + # NOTE: the batch size may be larger than the cache size + # NOTE: by passing cache=None, we can disable caching, in which case this is just a straight-through operation + + if cache is not None: + # first we retrieve any cached values + cache_hit_inds = [] + cache_miss_inds = [] + cache_hits_tensor_list = [] + cache_misses = [] + for i, arg in enumerate(args_list): + if (tensor := cache.get(arg)) is not None: + cache_hit_inds.append(i) + cache_hits_tensor_list.append(tensor) + else: + cache_miss_inds.append(i) + cache_misses.append(arg) + + cache_hits_tensor = ( + torch.stack(cache_hits_tensor_list, dim=0) + if len(cache_hits_tensor_list) > 0 + else None + ) + else: + cache_hit_inds = [] + cache_miss_inds = list(range(len(args_list))) + cache_hits_tensor = None + cache_misses = args_list + + # then we batch encode any cache misses + if len(cache_misses) > 0: + cache_misses_tensor =
encode_fn(cache_misses) + for i, arg in enumerate(cache_misses): + if cache is not None: + cache[arg] = cache_misses_tensor[i] + else: + cache_misses_tensor = None + + # now we merge the cache hits and cache misses + # such that the order of the output tensors matches the order of the input args + # NOTE: at least one of cache_hit_inds and cache_miss_inds will be non-empty + # NOTE: we assume device, dtype, and shape are unchanged between different calls to encode_fn + if cache_hits_tensor is None: + output_tensor = cache_misses_tensor + elif cache_misses_tensor is None: + output_tensor = cache_hits_tensor + else: + output_tensor = torch.empty( + len(args_list), + *cache_misses_tensor.shape[1:], + dtype=cache_misses_tensor.dtype, + device=cache_misses_tensor.device, + ) + output_tensor[cache_hit_inds] = cache_hits_tensor + output_tensor[cache_miss_inds] = cache_misses_tensor + + assert isinstance(output_tensor, torch.Tensor) # just to make mypy happy + + return output_tensor + + +def resize_pil_image(pil_image: Image.Image, image_size: tuple[int, int]) -> Image.Image: + if pil_image.format == "JPEG": + # try requesting a format-specific conversion + # this significantly speeds up the subsequent resize operation + # note that torchvision does NOT try this internally, and is therefore much slower + # (plus likely using draft here leads to a more accurate resize operation) + pil_image.draft("RGB", image_size) + pil_image = pil_image.convert("RGB") + pil_image = pil_image.resize(image_size, Image.BICUBIC) + return pil_image + + +def image_bytes_to_tensor(image: bytes, image_size: tuple[int, int]) -> torch.Tensor: + image_buffer = io.BytesIO(image) + pil_image = Image.open(image_buffer) + pil_image = resize_pil_image(pil_image, image_size) + np_image = np.asarray(pil_image) + tensor = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0) + return tensor + + +def squish_labels( + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]], + ) -> tuple[list[str], dict[str, int]]: + # build one list of labels to encode, without duplicates + # and lists / dicts containing the indices of each label + # and the indices of each label's sub-labels + all_labels_to_inds: dict[str, int] = {} + all_labels = [] + + for label in labels: + inc_subs = inc_sub_labels_dict.get(label) + if inc_subs is not None: + for inc_sub in inc_subs: + if inc_sub not in all_labels_to_inds: + all_labels_to_inds[inc_sub] = len(all_labels_to_inds) + all_labels.append(inc_sub) + + exc_subs = exc_sub_labels_dict.get(label) + if exc_subs is not None: + for exc_sub in exc_subs: + if exc_sub not in all_labels_to_inds: + all_labels_to_inds[exc_sub] = len(all_labels_to_inds) + all_labels.append(exc_sub) + + return all_labels, all_labels_to_inds diff --git a/directai_fastapi/requirements.txt b/directai_fastapi/requirements.txt index 342c26b..26d3ec1 100644 --- a/directai_fastapi/requirements.txt +++ b/directai_fastapi/requirements.txt @@ -5,4 +5,10 @@ pillow==9.2.0 opencv-python==4.10.0.84 uvicorn[standard]==0.20.0 gunicorn==22.0.0 -ray[serve]==2.34.0 \ No newline at end of file +ray[serve]==2.34.0 +mypy==1.11.1 +open_clip_torch==2.24.0 +https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl +lru-dict==1.3.0 +transformers==4.35 +flash-attn==2.6.3 \ No newline at end of file diff --git a/directai_fastapi/unit_tests/test.py b/directai_fastapi/unit_tests/test.py index 1be6c3f..32359cf 100644 --- a/directai_fastapi/unit_tests/test.py +++ 
b/directai_fastapi/unit_tests/test.py @@ -2,6 +2,10 @@ import unittest from unit_tests.test_modules.test_async_redis import * +from unit_tests.test_modules.test_utils import * +from unit_tests.test_modules.test_tensor_utils import * +from unit_tests.test_modules.test_classifier import * +from unit_tests.test_modules.test_detector import * if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/directai_fastapi/unit_tests/test_modules/test_detector.py b/directai_fastapi/unit_tests/test_modules/test_detector.py new file mode 100644 index 0000000..266b6cb --- /dev/null +++ b/directai_fastapi/unit_tests/test_modules/test_detector.py @@ -0,0 +1,54 @@ +import unittest +import torch +from typing_extensions import ClassVar + +from modeling.object_detector import ( + ZeroShotObjectDetectorWithFeedback, + created_padded_tensor_from_bytes, +) + + +class TestObjectDetector(unittest.TestCase): + # we have to define these here because mypy doesn't see the assignments made inside the setUpClass classmethod + object_detector = NotImplemented # type: ClassVar[ZeroShotObjectDetectorWithFeedback] + coke_bottle_image_bytes = NotImplemented # type: ClassVar[bytes] + default_labels = NotImplemented # type: ClassVar[list[str]] + default_incs = NotImplemented # type: ClassVar[dict[str, list[str]]] + default_excs = NotImplemented # type: ClassVar[dict[str, list[str]]] + default_nms_thre = NotImplemented # type: ClassVar[float] + default_conf_thres = NotImplemented # type: ClassVar[dict[str, float]] + + @classmethod + def setUpClass(cls) -> None: + cls.object_detector = ZeroShotObjectDetectorWithFeedback() + + coke_bottle_filepath = "unit_tests/sample_data/coke_through_the_ages.jpeg" + with open(coke_bottle_filepath, "rb") as f: + cls.coke_bottle_image_bytes = f.read() + + cls.default_labels = ["bottle", "moose"] + cls.default_incs = { + "bottle": ["bottle", "glass bottle", "plastic bottle", "water bottle"], + "moose": ["moose", "elk", "deer"], + } + cls.default_excs = { + "bottle": ["can", "soda can", "aluminum can"], + } + cls.default_nms_thre = 0.4 + cls.default_conf_thres = { + "bottle": 0.1, + "moose": 0.1, + } + + def test_detect_objects_from_image_bytes(self) -> None: + with torch.no_grad(): + batched_predicted_boxes = self.object_detector( + self.coke_bottle_image_bytes, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=self.default_excs, + nms_thre=self.default_nms_thre, + label_conf_thres=self.default_conf_thres, + ) + + # the detector returns one entry per input image; each entry holds one [N, 5] tlbr-score tensor per label + self.assertEqual(len(batched_predicted_boxes), 1) + per_label_boxes = batched_predicted_boxes[0] + self.assertEqual(len(per_label_boxes), len(self.default_labels)) + for boxes in per_label_boxes: + self.assertEqual(boxes.shape[-1], 5) \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index a4f20ef..7d2d1d2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,30 +1,35 @@ version: '2.3' services: - local_redis: + local_redis_isaac_2: build: redis_data/ - container_name: local_redis + container_name: local_redis_isaac_2 ports: - - 6379:6379 + - 63790:63790 volumes: - ./redis_data/:/data networks: - deploy_network - local_fastapi: + local_fastapi_isaac_2: build: directai_fastapi/ ports: - - 8000:8000 + - 10000:10000 networks: - deploy_network - container_name: local_fastapi + container_name: local_fastapi_isaac_2 environment: - PYTHONUNBUFFERED=1 - - NVIDIA_VISIBLE_DEVICES=all - - HF_HOME=/.cache/huggingface + - NVIDIA_VISIBLE_DEVICES=1 + - HF_HOME=/directai_fastapi/.cache/huggingface + - CACHE_REDIS_PORT=63790 + env_file: + - directai_fastapi/.env runtime: nvidia volumes: - - ./logs:/logs + - ./logs:/directai_fastapi/logs + - ./.cache:/directai_fastapi/.cache + shm_size: 10.24g # because Ray
complains if it's less depends_on: - - local_redis + - local_redis_isaac_2 extra_hosts: - "host.docker.internal:host-gateway" diff --git a/redis_data/redis_entrypoint.sh b/redis_data/redis_entrypoint.sh index 4381723..d7a3ecb 100755 --- a/redis_data/redis_entrypoint.sh +++ b/redis_data/redis_entrypoint.sh @@ -8,7 +8,7 @@ cleanup() { trap 'cleanup' SIGTERM #Execute a command in the background -redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" & +redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" --port 63790 & #Save the PID of the background process REDIS_PID=$! From a2e2a536b5342dd87a358b52b6c0ccaf5a9c44de Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Fri, 23 Aug 2024 18:01:33 -0400 Subject: [PATCH 2/7] integration tests passing --- .../modeling/distributed_backend.py | 55 +- directai_fastapi/modeling/object_detector.py | 452 ++++++--- directai_fastapi/pydantic_models.py | 47 +- directai_fastapi/server.py | 32 +- docker-compose.yml | 16 +- .../test_modules/test_detector.py | 868 ++++++++++-------- mypy.sh | 6 +- 7 files changed, 891 insertions(+), 585 deletions(-) diff --git a/directai_fastapi/modeling/distributed_backend.py b/directai_fastapi/modeling/distributed_backend.py index b83a0bc..45c5120 100644 --- a/directai_fastapi/modeling/distributed_backend.py +++ b/directai_fastapi/modeling/distributed_backend.py @@ -9,6 +9,7 @@ from typing import List from pydantic_models import ClassifierResponse, SingleDetectionResponse from modeling.image_classifier import ZeroShotImageClassifierWithFeedback +from modeling.object_detector import ZeroShotObjectDetectorWithFeedback serve.start(http_options={"port": 8100}) @@ -16,15 +17,48 @@ @serve.deployment class ObjectDetector: - async def __call__(self, image: Image.Image) -> List[List[SingleDetectionResponse]]: - # Placeholder implementation - single_detection = { - "tlbr": [0.0, 0.0, 1.0, 1.0], - "score": random.random(), - "class": "dog", - } - sdr = SingleDetectionResponse.parse_obj(single_detection) - return [[sdr]] + def __init__(self) -> None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = ZeroShotObjectDetectorWithFeedback(device=device) + + async def __call__( + self, + image: bytes, + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]] | None = None, + label_conf_thres: dict[str, float] | None = None, + augment_examples: bool = True, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, + ) -> list[SingleDetectionResponse]: + with torch.inference_mode(), torch.autocast(str(self.model.device)): + batched_predicted_boxes = self.model( + image, + labels=labels, + inc_sub_labels_dict=inc_sub_labels_dict, + exc_sub_labels_dict=exc_sub_labels_dict, + label_conf_thres=label_conf_thres, + augment_examples=augment_examples, + nms_thre=nms_thre, + run_class_agnostic_nms=run_class_agnostic_nms, + ) + + # since we are not batching, we can assume the output has batch size 1 + per_label_boxes = batched_predicted_boxes[0] + + # predicted_boxes is a list in order of labels, with each box of the form [x1, y1, x2, y2, confidence] + detection_responses = [] + for label, boxes in zip(labels, per_label_boxes): + for detection in boxes: + single_detection_response = SingleDetectionResponse( + tlbr=detection[:4].tolist(), + score=detection[4].item(), + class_=label, # type: ignore + ) + detection_responses.append(single_detection_response) + + return detection_responses 
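# For orientation, a minimal sketch of how this deployment might be exercised once
# bound and served. This is an illustrative assumption rather than part of the
# patch: it presumes Ray Serve 2.x DeploymentHandle semantics and reuses the
# sample image from the unit tests.
#
#     from ray import serve
#     handle = serve.run(ObjectDetector.bind())
#     with open("unit_tests/sample_data/coke_through_the_ages.jpeg", "rb") as f:
#         image_bytes = f.read()
#     detections = handle.remote(
#         image_bytes,
#         labels=["bottle"],
#         inc_sub_labels_dict={"bottle": ["bottle", "glass bottle"]},
#     ).result()
#     # detections is a list of SingleDetectionResponse with tlbr, score, and class fields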
@serve.deployment @@ -41,8 +75,7 @@ async def __call__( exc_sub_labels_dict: dict[str, list[str]] | None = None, augment_examples: bool = True, ) -> ClassifierResponse: - - with torch.no_grad(), torch.autocast(str(self.model.device)): + with torch.inference_mode(), torch.autocast(str(self.model.device)): raw_scores = self.model( image, labels=labels, diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py index 5b03d41..6e2e9ca 100644 --- a/directai_fastapi/modeling/object_detector.py +++ b/directai_fastapi/modeling/object_detector.py @@ -3,14 +3,14 @@ from PIL import Image import torch from torch import nn -import torchvision +import torchvision # type: ignore import numpy as np -from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel -from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention +from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel # type: ignore +from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention # type: ignore import time from typing import Union -from torch_scatter import scatter_max -from flash_attn import flash_attn_func +from torch_scatter import scatter_max # type: ignore +from flash_attn import flash_attn_func # type: ignore import io from lru import LRU from functools import partial @@ -29,56 +29,74 @@ def created_padded_tensor_from_bytes( padded_image_tensor = torch.ones((1, 3, *image_size)) * 114.0 # TODO: add nonblocking streaming to GPU - + image_buffer = io.BytesIO(image_bytes) pil_image = Image.open(image_buffer) current_size = pil_image.size - + r = min(image_size[0] / current_size[0], image_size[1] / current_size[1]) target_size = (int(r * current_size[0]), int(r * current_size[1])) - + pil_image = resize_pil_image(pil_image, target_size) np_image = np.asarray(pil_image) torch_image = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0) - - padded_image_tensor[:, :, :torch_image.shape[2], :torch_image.shape[3]] = torch_image - - image_scale_ratios = torch.tensor([r,]) - + + padded_image_tensor[:, :, : torch_image.shape[2], : torch_image.shape[3]] = ( + torch_image + ) + + image_scale_ratios = torch.tensor( + [ + r, + ] + ) + return padded_image_tensor, image_scale_ratios - + def flash_attn_owl_vit_encoder_forward( - self: Owlv2Attention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.Tensor, None]: - assert not output_attentions, "output_attentions not supported for flash attention implementation" - assert attention_mask is None, "attention_mask not supported for flash attention implementation" + self: Owlv2Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, +) -> tuple[torch.Tensor, None]: + assert ( + not output_attentions + ), "output_attentions not supported for flash attention implementation" + assert ( + attention_mask is None + ), "attention_mask not supported for flash attention implementation" # technically flash_attn DOES support causal attention # but the OWL usage of causal attention mask does not limit it to true causal attention # we don't support generalized attention, so we're just going to assert causal attention mask is ALSO None - assert causal_attention_mask is None, "causal_attention_mask not supported for flash 
attention implementation" + assert ( + causal_attention_mask is None + ), "causal_attention_mask not supported for flash attention implementation" bsz, tgt_len, embed_dim = hidden_states.shape - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - - query_states = query_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) - key_states = key_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) - value_states = value_states.contiguous().view(bsz, tgt_len, self.num_heads, self.head_dim) - + + query_states = query_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + key_states = key_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + value_states = value_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + # convert to appropriate dtype # NOTE: bf16 may be more appropriate than fp16 query_states = query_states.to(torch.float16) key_states = key_states.to(torch.float16) value_states = value_states.to(torch.float16) - + attn_output = flash_attn_func( query_states, key_states, @@ -86,16 +104,16 @@ def flash_attn_owl_vit_encoder_forward( dropout_p=0, softmax_scale=self.scale, ) - + attn_output = attn_output.view(bsz, tgt_len, embed_dim) - + # convert back to appropriate dtype attn_output = attn_output.to(hidden_states.dtype) - + attn_output = self.out_proj(attn_output) - + return attn_output, None - + # we're copying the function signature from the original # and just replacing the method with a faster one based on flash_attn @@ -109,9 +127,9 @@ def flash_attn_owl_vit_encoder_forward( class VisionModelWrapper(nn.Module): def __init__(self, vision_model: Owlv2VisionModel) -> None: super().__init__() - + self.vision_model = vision_model - + # we're going to monkey patch the forward method of the attention layers # to replace it with a faster one based on flash_attn # the alternative is to subclass the entire model, but that's a lot of work @@ -119,29 +137,32 @@ def __init__(self, vision_model: Owlv2VisionModel) -> None: # and assert that the input is as expected for owlv2_vision_model_encoder_layer in self.vision_model.encoder.layers: owlv2_vision_model_encoder_layer.self_attn.forward = partial( - flash_attn_owl_vit_encoder_forward, owlv2_vision_model_encoder_layer.self_attn + flash_attn_owl_vit_encoder_forward, + owlv2_vision_model_encoder_layer.self_attn, ) def forward(self, image: torch.Tensor) -> torch.Tensor: vision_outputs = self.vision_model(pixel_values=image, return_dict=True) - + # Get image embedding last_hidden_states = vision_outputs[0] image_embeds = self.vision_model.post_layernorm(last_hidden_states) - + return image_embeds class WrappedImageEmbedder(nn.Module): def __init__(self, model: Owlv2ForObjectDetection) -> None: super().__init__() - + self.model = model self.wrapped_vision_model = VisionModelWrapper(self.model.owlv2.vision_model) - - def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + def forward( + self, image: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: image_embeds = self.wrapped_vision_model(image) - + # Resize class token new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0))) class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size) @@ -158,32 +179,42 @@ def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torc image_embeds.shape[-1], ) 
feature_map = image_embeds.reshape(new_size) - + # Get class head features # we do dot prod between image_class_embeds and query embeddings, then do (pred + shift) * scale image_class_embeds = self.model.class_head.dense0(image_embeds) - image_class_embeds = image_class_embeds / image_class_embeds.norm(dim=-1, keepdim=True, p=2) + image_class_embeds = image_class_embeds / image_class_embeds.norm( + dim=-1, keepdim=True, p=2 + ) logit_shift = self.model.class_head.logit_shift(image_embeds) logit_scale = self.model.class_head.logit_scale(image_embeds) logit_scale = self.model.class_head.elu(logit_scale) + 1 - + # Get box head features # NOTE: these boxes come out in normalized cx_cy_w_h format; they are converted downstream pred_boxes = self.model.box_predictor(image_embeds, feature_map) - + # filter out patches that are unlikely to map to objects # the paper takes the top 10% during training, but we'll take the top 300 to be more in line with DETR objectness_scores = self.model.objectness_predictor(image_embeds) # compute the top 300 objectness indices indices = torch.topk(objectness_scores, 300, dim=1).indices # keep only the entries corresponding to those indices - image_class_embeds = image_class_embeds.gather(1, indices.unsqueeze(-1).expand(-1, -1, image_class_embeds.shape[-1])) - logit_shift = logit_shift.gather(1, indices.unsqueeze(-1).expand(-1, -1, logit_shift.shape[-1])) - logit_scale = logit_scale.gather(1, indices.unsqueeze(-1).expand(-1, -1, logit_scale.shape[-1])) - pred_boxes = pred_boxes.gather(1, indices.unsqueeze(-1).expand(-1, -1, pred_boxes.shape[-1])) - + image_class_embeds = image_class_embeds.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, image_class_embeds.shape[-1]) + ) + logit_shift = logit_shift.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, logit_shift.shape[-1]) + ) + logit_scale = logit_scale.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, logit_scale.shape[-1]) + ) + pred_boxes = pred_boxes.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, pred_boxes.shape[-1]) + ) + assert image_class_embeds.shape[1] == 300 - + return image_class_embeds, logit_shift, logit_scale, pred_boxes @@ -197,11 +228,13 @@ def __init__( jit: bool = True, ): super().__init__() - + self.device = device - self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(device).eval() + self.model = ( + Owlv2ForObjectDetection.from_pretrained(model_name).to(device).eval() + ) self.processor = Owlv2Processor.from_pretrained(model_name) - + if jit: + self.wrapped_image_embedder = torch.jit.trace_module( WrappedImageEmbedder(self.model), @@ -219,35 +252,41 @@ def __init__( self.not_augmented_label_encoding_cache: LRU | None = ( LRU(lru_cache_size) if lru_cache_size > 0 else None ) - + self.image_size = image_size - self.rgb_means = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1) - self.rgb_stds = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1) + self.rgb_means = torch.tensor([0.485, 0.456, 0.406], device=device).view( + 1, 3, 1, 1 + ) + self.rgb_stds = torch.tensor([0.229, 0.224, 0.225], device=device).view( + 1, 3, 1, 1 + ) def _encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: # NOTE: object detector literature tends to use fewer templates than image classifiers templates = medium_hypothesis_formats if augment else noop_hypothesis_formats augmented_text = [template.format(t) for t in text for template in templates] - - processor_output = self.processor(text=augmented_text, return_tensors="pt", padding=True, truncation=True) + + processor_output = self.processor( + text=augmented_text,
return_tensors="pt", padding=True, truncation=True + ) input_ids = processor_output.input_ids.to(self.device) attn_mask = processor_output.attention_mask.to(self.device) - + # TODO: add appropriate batching to avoid OOM text_output = self.model.owlv2.text_model( input_ids=input_ids, attention_mask=attn_mask, return_dict=True ) - + embeddings = text_output[1] embeddings = self.model.owlv2.text_projection(embeddings) embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) - + embeddings = embeddings.reshape(len(text), len(templates), embeddings.shape[1]) embeddings = embeddings.mean(dim=1) embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) - + return embeddings - + def encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: if augment: return batch_encode_cache_missed_list_elements( @@ -266,9 +305,11 @@ def get_image_data(self, image: torch.Tensor) -> dict[str, torch.Tensor]: # we do the normalization here to make sure we have access to the right parameters image = image / 255.0 image = (image - self.rgb_means) / self.rgb_stds - - image_class_embeds, logit_shift, logit_scale, pred_boxes = self.wrapped_image_embedder(image) - + + image_class_embeds, logit_shift, logit_scale, pred_boxes = ( + self.wrapped_image_embedder(image) + ) + return { "image_class_embeds": image_class_embeds, "logit_shift": logit_shift, @@ -285,18 +326,22 @@ def forward( label_conf_thres: dict[str, float] | None = None, augment_examples: bool = True, nms_thre: float = 0.4, - run_class_agnostic_nms: bool = False, + run_class_agnostic_nms: bool = True, image_scale_ratios: torch.Tensor | None = None, ) -> list[list[torch.Tensor]]: - assert not run_class_agnostic_nms, "Class-agnostic NMS not yet implemented" - if isinstance(image, bytes): - assert image_scale_ratios is None, "image_scale_ratios must be None if image is bytes as we define the scale internally" - image_tensor, image_scale_ratios = created_padded_tensor_from_bytes(image, self.image_size) + assert ( + image_scale_ratios is None + ), "image_scale_ratios must be None if image is bytes as we define the scale internally" + image_tensor, image_scale_ratios = created_padded_tensor_from_bytes( + image, self.image_size + ) else: - assert image_scale_ratios is not None, "image_scale_ratios must be provided if image is a tensor as we cannot derive the scale internally" + assert ( + image_scale_ratios is not None + ), "image_scale_ratios must be provided if image is a tensor as we cannot derive the scale internally" image_tensor = image - + if label_conf_thres is None: label_conf_thres = {} @@ -401,11 +446,19 @@ def forward( # if we have no negative labels, we just set the max neg scores to zero # NOTE: possible to speed things up by skipping the ops conditional on having negative labels max_neg_scores = torch.zeros_like(max_pos_scores) - + # now reshape the scores to [num_images, num_boxes, num_labels] - max_pos_scores = max_pos_scores.view(image_data["pred_boxes"].shape[0], image_data["pred_boxes"].shape[1], num_labels) - max_neg_scores = max_neg_scores.view(image_data["pred_boxes"].shape[0], image_data["pred_boxes"].shape[1], num_labels) - + max_pos_scores = max_pos_scores.view( + image_data["pred_boxes"].shape[0], + image_data["pred_boxes"].shape[1], + num_labels, + ) + max_neg_scores = max_neg_scores.view( + image_data["pred_boxes"].shape[0], + image_data["pred_boxes"].shape[1], + num_labels, + ) + # unlike the image classifier, we have to suppress boxes based on the scores of their neighbors # we do this via a modified 
NMS algorithm # because it operates over a variable-sized graph of boxes, it's hard to vectorize @@ -416,11 +469,14 @@ def forward( max_neg_scores, image_data["pred_boxes"], image_tensor.shape[2] / image_scale_ratios, - torch.tensor([label_conf_thres.get(label, 0.0) for label in labels], device=self.device), + torch.tensor( + [label_conf_thres.get(label, 0.0) for label in labels], + device=self.device, + ), nms_thre, run_class_agnostic_nms, ) - + return batched_predicted_boxes @@ -434,12 +490,14 @@ def compute_query_fit( # Compute query fit pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) pred_logits = (pred_logits + logit_shift) * logit_scale - + return torch.sigmoid(pred_logits) @torch.jit.script -def compute_iou_adjacency_list(boxes: torch.Tensor, nms_thre: float) -> list[torch.Tensor]: +def compute_iou_adjacency_list( + boxes: torch.Tensor, nms_thre: float +) -> list[torch.Tensor]: boxes = boxes.clone() boxes[:, 0] -= boxes[:, 2] / 2 boxes[:, 1] -= boxes[:, 3] / 2 boxes[:, 2] = boxes[:, 0] + boxes[:, 2] @@ -449,14 +507,16 @@ def compute_iou_adjacency_list(boxes: torch.Tensor, nms_thre: float) -> list[tor ious = ious >= nms_thre # Set diagonal elements to zero ... no self-loops! ious.fill_diagonal_(0) - + edges = torch.nonzero(ious).unbind(-1) - + return edges @torch.jit.script -def find_in_sorted_tensor(sorted_tensor: torch.Tensor, query: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def find_in_sorted_tensor( + sorted_tensor: torch.Tensor, query: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: indices = torch.searchsorted(sorted_tensor, query) indices_clamped = torch.clamp(indices, max=sorted_tensor.size(0) - 1) present = sorted_tensor[indices_clamped] == query @@ -466,19 +526,89 @@ def find_in_sorted_tensor(sorted_tensor: torch.Tensor, query: torch.Tensor) -> t @torch.jit.script -def compute_candidate_nms_via_adjacency_list(pro_max: torch.Tensor, con_max: torch.Tensor, adjacency_list: list[torch.Tensor], conf_thre: float) -> torch.Tensor: +def run_nms_via_adjacency_list( + valid_box_indices_by_descending_score: torch.Tensor, + adjacency_list: list[torch.Tensor], +) -> torch.Tensor: + # we compute the indices of the start and end of each box's adjacent boxes + # since our graph representation is just a list of edges, we would like to know which edges correspond to which nodes + # as the first node is sorted, we can just take the difference between adjacent nodes to get the start and end of each node's edges + # NOTE: we've already computed the graph so we don't need to supply an NMS threshold + first_node = adjacency_list[0] + zero_tensor = torch.tensor([0], device=first_node.device) + change_inds = (first_node[1:] != first_node[:-1]).nonzero()[:, 0] + len_tensor = torch.tensor([first_node.shape[0]], device=first_node.device) + inds_of_adj_boxes = torch.cat([zero_tensor, change_inds + 1, len_tensor]) + + # we then run an NMS over the boxes + # need to keep track of which boxes have survived, which means an index per possible box + # note that some boxes may have already been filtered out + # so valid_box_indices_by_descending_score may not contain all indices + # TODO: switch this to a sparse tensor for efficiency + highest_possible_ind = max( + valid_box_indices_by_descending_score.max().item(), + adjacency_list[0].max().item(), + adjacency_list[1].max().item(), + ) + assert isinstance(highest_possible_ind, int) # this is just to make mypy happy + has_survived = torch.zeros( + highest_possible_ind + 1, + device=valid_box_indices_by_descending_score.device, + dtype=torch.bool, 
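+ # the survivor mask is indexed by absolute box index, so it is sized to the largest index + # reachable from either the candidate list or the edge list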
) + has_survived[valid_box_indices_by_descending_score] = 1 + + # check which boxes have graph connections + unique_nodes = first_node[inds_of_adj_boxes[:-1]] + has_connection, graph_node_indices = find_in_sorted_tensor( + unique_nodes, valid_box_indices_by_descending_score + ) + connected_sorted_pro_valid_inds = valid_box_indices_by_descending_score[ + has_connection + ] + graph_indices = graph_node_indices[has_connection] + + # suppress boxes whose (unsuppressed) neighbors have higher scores + for i, j in zip(connected_sorted_pro_valid_inds, graph_indices): + if has_survived[i] == 0: + continue + + remapped_start_ind = inds_of_adj_boxes[j] + remapped_end_ind = inds_of_adj_boxes[j + 1] + adj_boxes = adjacency_list[1][remapped_start_ind:remapped_end_ind] + has_survived[adj_boxes] = 0 + + survive_inds = valid_box_indices_by_descending_score[ + has_survived[valid_box_indices_by_descending_score] + ] + + return survive_inds + + +@torch.jit.script +def compute_candidate_nms_via_adjacency_list( + pro_max: torch.Tensor, + con_max: torch.Tensor, + adjacency_list: list[torch.Tensor], + conf_thre: float, + run_nms: bool, +) -> torch.Tensor: # we use a scatter_max to efficiently compute, for each bounding box, the max con score of its adjacent boxes expanded_con_max = con_max[adjacency_list[0]] - adjacent_con_max = scatter_max(expanded_con_max, adjacency_list[1], dim_size=con_max.shape[0])[0] + adjacent_con_max = scatter_max( + expanded_con_max, adjacency_list[1], dim_size=con_max.shape[0] + )[0] # we then filter down to boxes that both exceed the confidence threshold and are not suppressed by negative examples # we do this by filtering on three expressions: # 1. pro_max >= conf_thre: the box has a high enough confidence # 2. pro_max >= adjacent_con_max: the box's confidence exceeds the negative confidence of every adjacent box # 3. 
pro_max >= con_max: the box has a higher confidence than its own negative confidence # NOTE: could make this more efficient perhaps by filtering out the easy ones prior to the scatter_max - pro_valid = (pro_max >= conf_thre) * (pro_max >= adjacent_con_max) * (pro_max >= con_max) + pro_valid = ( + (pro_max >= conf_thre) * (pro_max >= adjacent_con_max) * (pro_max >= con_max) + ) pro_valid_inds = pro_valid.nonzero().squeeze(1) - + if pro_valid_inds.numel() == 0 or adjacency_list[0].numel() == 0: # no boxes are valid or no boxes have any overlap with any other boxes # either way, we can skip the NMS step @@ -491,42 +621,19 @@ def compute_candidate_nms_via_adjacency_list(pro_max: torch.Tensor, con_max: tor nms_inds = torch.nonzero(first_node_valid * second_node_valid).squeeze(1) modified_adjacency_list = [adjacency_list[0][nms_inds], adjacency_list[1][nms_inds]] - + if nms_inds.numel() == 0: # none of the remaining boxes have any overlap with any other remaining boxes # so we can skip the NMS step survive_inds = pro_valid.nonzero().squeeze(1) return survive_inds - # we compute the indices of the start and end of each box's adjacent boxes - # since our graph representation is just a list of edges, we would like to know which edges correspond to which nodes - # as the first node is sorted, we can just take the difference between adjacent nodes to get the start and end of each node's edges - first_node = modified_adjacency_list[0] - zero_tensor = torch.tensor([0], device=pro_max.device) - change_inds = (first_node[1:] != first_node[:-1]).nonzero()[:, 0] - len_tensor = torch.tensor([first_node.shape[0]], device=pro_max.device) - inds_of_adj_boxes = torch.cat([zero_tensor, change_inds + 1, len_tensor]) - # we then run a NMS over the (remaining) boxes - sorted_pro_valid_inds = pro_valid_inds[pro_max[pro_valid].argsort(descending=True)] - - # check which boxes have graph connections - unique_nodes = first_node[inds_of_adj_boxes[:-1]] - has_connection, graph_node_indices = find_in_sorted_tensor(unique_nodes, sorted_pro_valid_inds) - connected_sorted_pro_valid_inds = sorted_pro_valid_inds[has_connection] - graph_indices = graph_node_indices[has_connection] - - for i, j in zip(connected_sorted_pro_valid_inds, graph_indices): - if pro_valid[i] == 0: - continue - - remapped_start_ind = inds_of_adj_boxes[j] - remapped_end_ind = inds_of_adj_boxes[j+1] - adj_boxes = modified_adjacency_list[1][remapped_start_ind:remapped_end_ind] - pro_valid[adj_boxes] = 0 - - survive_inds = pro_valid.nonzero().squeeze(1) - + survive_inds = pro_valid_inds[pro_max[pro_valid].argsort(descending=True)] + + if run_nms: + survive_inds = run_nms_via_adjacency_list(survive_inds, modified_adjacency_list) + return survive_inds @@ -538,15 +645,18 @@ def run_nms_based_box_suppression_for_one_object( adjacency_list: list[torch.Tensor], image_scale: float, conf_thre: float = 0.001, -) -> torch.Tensor: - survive_inds = compute_candidate_nms_via_adjacency_list(pro_max, con_max, adjacency_list, conf_thre) - + run_nms: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + survive_inds = compute_candidate_nms_via_adjacency_list( + pro_max, con_max, adjacency_list, conf_thre, run_nms + ) + boxes = pred_boxes.squeeze(0)[survive_inds] logits = pro_max[survive_inds] - + logits = logits.unsqueeze(-1) boxes = boxes * image_scale - + # Convert boxes from center_x, center_y, width, height (cx_cy_w_h) to top_left_x, top_left_y, bottom_right_x, bottom_right_y (tlbr) cx, cy, w, h = boxes.unbind(-1) tl_x = cx - 0.5 * w @@ -554,14 +664,13 @@ def 
run_nms_based_box_suppression_for_one_object( br_x = cx + 0.5 * w br_y = cy + 0.5 * h boxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=-1) - + boxes_with_scores = torch.cat([boxes, logits], dim=-1) - boxes_with_scores = boxes_with_scores[boxes_with_scores[:, 4].argsort(descending=True)] - - print(boxes_with_scores) - print(boxes_with_scores.shape) - - return boxes_with_scores + boxes_with_scores = boxes_with_scores[ + boxes_with_scores[:, 4].argsort(descending=True) + ] + + return boxes_with_scores, survive_inds @torch.jit.script @@ -572,13 +681,13 @@ def run_nms_based_box_suppression_for_all_objects( image_scale: float, conf_thres: torch.Tensor, nms_thre: float = 0.4, - run_class_agnostic_nms: bool = False, + run_class_agnostic_nms: bool = True, ) -> list[torch.Tensor]: # pred_boxes is assumed to be [num_boxes, 4] # pro_max and con_max are assumed to be [num_boxes, num_objects] # conf_thres is assumed to be [num_objects] adjacency_list = compute_iou_adjacency_list(pred_boxes, nms_thre) - + futures = [ torch.jit.fork( run_nms_based_box_suppression_for_one_object, @@ -588,15 +697,63 @@ def run_nms_based_box_suppression_for_all_objects( adjacency_list, image_scale, conf_thres[i], - ) for i in range(pro_max.shape[1]) + run_nms=not run_class_agnostic_nms, + ) + for i in range(pro_max.shape[1]) ] - - predicted_boxes = [torch.jit.wait(fut) for fut in futures] - - # TODO: add class-agnostic NMS - assert not run_class_agnostic_nms, "Class-agnostic NMS not yet implemented" - - return predicted_boxes + + # there appears to be a bug in JIT related to star expansion that stops us from just using the following list comprehension: + # object_boxes, box_indices = zip(*[torch.jit.wait(fut) for fut in futures]) + object_boxes: list[torch.Tensor] = [] + box_indices: list[torch.Tensor] = [] + for fut in futures: + boxes_with_scores, survive_inds = torch.jit.wait(fut) + object_boxes.append(boxes_with_scores) + box_indices.append(survive_inds) + + if run_class_agnostic_nms: + # first take the top class prediction for each box + # we will of course use a scatter_max for this + all_object_boxes = torch.cat(object_boxes, dim=0) + all_box_indices = torch.cat(box_indices, dim=0) + all_box_class_assignments = torch.cat( + [ + torch.full((box.shape[0],), i, dtype=torch.long, device=box.device) + for i, box in enumerate(object_boxes) + ], + dim=0, + ) + all_object_confidences = all_object_boxes[:, 4] + + # we use an out argument here so we can control the default value + box_max_confidences = torch.zeros(pred_boxes.shape[0], device=pred_boxes.device) + _, max_confidence_indices = scatter_max( + all_object_confidences, all_box_indices, out=box_max_confidences + ) + + survived_inds = box_max_confidences.nonzero().squeeze(1) + + # now filter the survived inds via NMS + sorted_survived_inds = survived_inds[ + box_max_confidences[survived_inds].argsort(descending=True) + ] + post_nms_survived_inds = run_nms_via_adjacency_list( + sorted_survived_inds, adjacency_list + ) + + # and accumulate the boxes + object_survived_inds = max_confidence_indices[post_nms_survived_inds] + survived_class_assignments = all_box_class_assignments[object_survived_inds] + survived_object_boxes = all_object_boxes[object_survived_inds] + + # loop through based on class_assignments to create list of per-class boxes + # TODO: this could be done more efficiently + object_boxes = [ + survived_object_boxes[survived_class_assignments == i] + for i in range(pro_max.shape[1]) + ] + + return object_boxes @torch.jit.script @@ -623,9 +780,10 @@ def 
batched_run_nms_based_box_suppression_for_all_objects( conf_thres, nms_thre, run_class_agnostic_nms, - ) for i in range(pro_max.shape[0]) + ) + for i in range(pro_max.shape[0]) ] - + batched_predicted_boxes = [torch.jit.wait(fut) for fut in futures] - - return batched_predicted_boxes \ No newline at end of file + + return batched_predicted_boxes diff --git a/directai_fastapi/pydantic_models.py b/directai_fastapi/pydantic_models.py index d9d6f3b..51772cf 100644 --- a/directai_fastapi/pydantic_models.py +++ b/directai_fastapi/pydantic_models.py @@ -103,28 +103,33 @@ class Config: orm_mode = True async def save_configuration(self, config_cache: redis.Redis) -> dict: + logger.info(f"Detector Configs: {self.detector_configs}") for detector_config in self.detector_configs: + logger.info(detector_config.examples_to_include) if len(detector_config.examples_to_include) == 0: raise HTTPException( status_code=422, detail=f"Model lacks examples_to_include for {detector_config.name} class.", ) - # Translating into Backend - config_dict = self.dict() - for i, single_config in enumerate(config_dict["detector_configs"]): - single_config["incs"] = single_config["examples_to_include"] - single_config["excs"] = single_config["examples_to_exclude"] - single_config["img_incs"] = [] - single_config["img_excs"] = [] - single_config["thresh"] = single_config["detection_threshold"] - del single_config["examples_to_include"] - del single_config["examples_to_exclude"] - del single_config["detection_threshold"] - config_dict["detector_configs"][i] = single_config - config_dict["nms_thresh"] = config_dict["nms_threshold"] - del config_dict["nms_threshold"] - config_dict["augment_examples"] = config_dict.get("augment_examples", True) - config_dict["class_agnostic_nms"] = config_dict.get("class_agnostic_nms", True) + labels = [c.name for c in self.detector_configs] + inc_sub_labels_dict: dict[str, List[str]] = { + c.name: c.examples_to_include for c in self.detector_configs + } + exc_sub_labels_dict: dict[str, List[str]] = { + c.name: c.examples_to_exclude for c in self.detector_configs + } + label_conf_thres: dict[str, float] = { + c.name: c.detection_threshold for c in self.detector_configs + } + config_dict = { + "labels": labels, + "inc_sub_labels_dict": inc_sub_labels_dict, + "exc_sub_labels_dict": exc_sub_labels_dict, + "augment_examples": self.augment_examples, + "nms_threshold": self.nms_threshold, + "class_agnostic_nms": self.class_agnostic_nms, + "label_conf_thres": label_conf_thres, + } if self.deployed_id is not None: key_exists = await config_cache.exists(self.deployed_id) @@ -139,6 +144,7 @@ async def save_configuration(self, config_cache: redis.Redis) -> dict: self.deployed_id = str(uuid.uuid4()) else: message = "Model updated." 
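+ # at this point deployed_id has either been validated against the cache or freshly generated above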
+ assert ( self.deployed_id is not None ), "deployed_id should not be None at this point" @@ -154,12 +160,3 @@ class SingleDetectionResponse(BaseModel): class Config: allow_population_by_field_name = True - - -class VerboseDetectorConfig(BaseModel): - name: str - - incs: List[str] - excs: List[str] = [] - - thresh: Optional[float] = None diff --git a/directai_fastapi/server.py b/directai_fastapi/server.py index 8085000..0303be0 100644 --- a/directai_fastapi/server.py +++ b/directai_fastapi/server.py @@ -15,7 +15,6 @@ ClassifierResponse, DetectorDeploy, SingleDetectionResponse, - VerboseDetectorConfig, ) from utils import raise_if_cannot_open from modeling.distributed_backend import deploy_backend_models @@ -191,13 +190,28 @@ async def run_detector( print(f"Got request for {deployed_id}, which is a detector model") image = data.file.read() raise_if_cannot_open(image) + logger.info(f"Got request for {deployed_id}, which is a detector model") detector_configs = await grab_config(deployed_id) - ## NOTE: This might break if we have embedded BaseModel-inheriting objects inside the json object - verbose_detector_configs = [ - VerboseDetectorConfig(**json.loads(d) if isinstance(d, str) else d) - for d in detector_configs["detector_configs"] - ] - print(f"augment_examples: {detector_configs.get('augment_examples', None)}") + labels = detector_configs["labels"] + assert isinstance(labels, list), "Labels should be a list of strings" + inc_sub_labels_dict = detector_configs.get("inc_sub_labels_dict", None) + exc_sub_labels_dict = detector_configs.get("exc_sub_labels_dict", None) + label_conf_thres = detector_configs.get("label_conf_thres", None) + augment_examples = detector_configs.get("augment_examples", True) + nms_threshold = detector_configs.get("nms_threshold", 0.4) + class_agnostic_nms = detector_configs.get("class_agnostic_nms", True) + + bboxes = await app.state.detector_handle.remote( + image, + labels=labels, + inc_sub_labels_dict=inc_sub_labels_dict, + exc_sub_labels_dict=exc_sub_labels_dict, + label_conf_thres=label_conf_thres, + augment_examples=augment_examples, + nms_thre=nms_threshold, + run_class_agnostic_nms=class_agnostic_nms, + ) - bboxes = await app.state.detector_handle.remote(None) - return bboxes + return [ + bboxes, + ] diff --git a/docker-compose.yml b/docker-compose.yml index 879c675..3bebc46 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,26 +1,26 @@ version: '2.3' services: - local_redis_isaac_2: + local_redis: build: redis_data/ - container_name: local_redis_isaac_2 + container_name: local_redis ports: - - 63790:63790 + - 6379:6379 volumes: - ./redis_data/:/data networks: - deploy_network - local_fastapi_isaac_2: + local_fastapi: build: directai_fastapi/ ports: - - 10000:10000 + - 8000:8000 networks: - deploy_network - container_name: local_fastapi_isaac_2 + container_name: local_fastapi environment: - PYTHONUNBUFFERED=1 - NVIDIA_VISIBLE_DEVICES=1 - HF_HOME=/directai_fastapi/.cache/huggingface - - CACHE_REDIS_PORT=63790 + - CACHE_REDIS_PORT=6379 env_file: - directai_fastapi/.env runtime: nvidia @@ -29,7 +29,7 @@ services: - ./.cache:/directai_fastapi/.cache shm_size: 10.24g # because Ray complains if it's less depends_on: - - local_redis_isaac_2 + - local_redis extra_hosts: - "host.docker.internal:host-gateway" diff --git a/integration_tests/test_modules/test_detector.py b/integration_tests/test_modules/test_detector.py index 1e982a9..f52d8e5 100644 --- a/integration_tests/test_modules/test_detector.py +++ b/integration_tests/test_modules/test_detector.py 
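For reference, the deploy-then-detect flow exercised by the integration tests below can be driven by a minimal client sketch like this (the base URL and file path are illustrative assumptions, not part of the patch):

import requests

BASE = "http://localhost:8000"  # assumed address; the tests use host.docker.internal:8000

# deploy a single-class detector configuration (same shape as the test bodies below)
body = {
    "detector_configs": [
        {
            "name": "bottle",
            "examples_to_include": ["bottle"],
            "examples_to_exclude": [],
            "detection_threshold": 0.1,
        }
    ],
    "nms_threshold": 0.4,
}
deployed_id = requests.post(BASE + "/deploy_detector", json=body).json()["deployed_id"]

# run detection against the deployed configuration
with open("sample_data/coke_through_the_ages.jpeg", "rb") as f:
    files = {"data": ("image.jpeg", f.read(), "image/jpg")}
detections = requests.post(
    BASE + "/detect", params={"deployed_id": deployed_id}, files=files
).json()
# expected shape: [[{"tlbr": [...], "score": ..., "class": "bottle"}, ...]]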
@@ -4,16 +4,18 @@ import numpy as np from pydantic import BaseModel, Field, conlist from typing import Dict, List, Union -from scipy.optimize import linear_sum_assignment +from scipy.optimize import linear_sum_assignment # type: ignore FASTAPI_HOST = "host.docker.internal" + # copying over from directai_fastapi/pydantic_models.py class SingleDetectionResponse(BaseModel): # see discussion: https://github.com/pydantic/pydantic/issues/975 - tlbr: conlist(float, min_items=4, max_items=4) # type: ignore[valid-type] + tlbr: conlist(float, min_items=4, max_items=4) # type: ignore[valid-type] score: float - class_: str = Field(alias='class') + class_: str = Field(alias="class") + def bbox_iou(box1: List[float], box2: List[float]) -> float: """ @@ -43,8 +45,11 @@ def bbox_iou(box1: List[float], box2: List[float]) -> float: return iou - -def compute_naive_bipartite_detection_loss(detections_set_1: List[SingleDetectionResponse], detections_set_2: List[SingleDetectionResponse], unmatched_loss: float =1.0) -> float: +def compute_naive_bipartite_detection_loss( + detections_set_1: List[SingleDetectionResponse], + detections_set_2: List[SingleDetectionResponse], + unmatched_loss: float = 1.0, +) -> float: """ Computes the bipartite matching average loss between two sets of detections with potentially different sizes. If the classes are not equal, the loss is set to 1. Otherwise, the loss is the IoU of the bounding boxes. @@ -72,584 +77,683 @@ def compute_naive_bipartite_detection_loss(detections_set_1: List[SingleDetectio # Compute the total loss including unmatched detections total_loss = cost_matrix[row_ind, col_ind].sum() - total_loss += unmatched_loss * (max(num_detections_1, num_detections_2) - len(row_ind)) # Add loss for unmatched detections + total_loss += unmatched_loss * ( + max(num_detections_1, num_detections_2) - len(row_ind) + ) # Add loss for unmatched detections # Compute the average loss - average_loss = total_loss / max(num_detections_1, num_detections_2) # Normalizing by the larger set size + average_loss = total_loss / max( + num_detections_1, num_detections_2 + ) # Normalizing by the larger set size return average_loss + class TestDetect(unittest.TestCase): - def __init__(self, methodName: str ='runTest') -> None: + def __init__(self, methodName: str = "runTest") -> None: super().__init__(methodName=methodName) self.endpoint = f"http://{FASTAPI_HOST}:8000/" - + def test_deploy_detector_config_missing(self) -> None: body: Dict[str, str] = {} expected_result = { - 'data': None, - 'message': '1 validation error for Request body -> detector_configs field required (type=value_error.missing)', - 'status_code': 422 + "data": None, + "message": "1 validation error for Request body -> detector_configs field required (type=value_error.missing)", + "status_code": 422, } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + response = requests.post(self.endpoint + "deploy_detector", json=body) self.assertEqual(response.json(), expected_result) - + def test_deploy_detector_failure_include_missing(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "bottle", + "examples_to_exclude": [], + "detection_threshold": 0.1, } ], - "nms_threshold": 0.4 - } + "nms_threshold": 0.4, + } expected_result = { - 'status_code': 422, - 'message': '1 validation error for Request body -> detector_configs -> 0 -> examples_to_include field required (type=value_error.missing)', - 'data': None + 
"status_code": 422, + "message": "1 validation error for Request body -> detector_configs -> 0 -> examples_to_include field required (type=value_error.missing)", + "data": None, } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() self.assertEqual(response_json, expected_result) - + def test_deploy_detector_success(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.1, } ], - "nms_threshold": 0.4 - } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + "nms_threshold": 0.4, + } + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() - self.assertTrue('deployed_id' in response_json) - self.assertEqual(response_json['message'], "New model deployed.") - + self.assertTrue("deployed_id" in response_json) + self.assertEqual(response_json["message"], "New model deployed.") + def test_deploy_detector_success_without_exclude(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "detection_threshold": 0.1 + "name": "bottle", + "examples_to_include": ["bottle"], + "detection_threshold": 0.1, } ], - "nms_threshold": 0.4 - } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + "nms_threshold": 0.4, + } + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() - self.assertTrue('deployed_id' in response_json) - self.assertEqual(response_json['message'], "New model deployed.") - + self.assertTrue("deployed_id" in response_json) + self.assertEqual(response_json["message"], "New model deployed.") + def test_deploy_detector_success_without_detect_threshold(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], } ], - "nms_threshold": 0.4 - } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + "nms_threshold": 0.4, + } + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() - self.assertTrue('deployed_id' in response_json) - self.assertEqual(response_json['message'], "New model deployed.") - + self.assertTrue("deployed_id" in response_json) + self.assertEqual(response_json["message"], "New model deployed.") + def test_deploy_detector_success_without_nms_threshold(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.1, } ], } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() - self.assertTrue('deployed_id' in response_json) - self.assertEqual(response_json['message'], "New model deployed.") - + self.assertTrue("deployed_id" in response_json) + self.assertEqual(response_json["message"], "New model deployed.") + def 
test_deploy_detector_success_without_augment_examples(self) -> None: body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.1, } ], "augment_examples": False, } - response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + response = requests.post(self.endpoint + "deploy_detector", json=body) response_json = response.json() - self.assertTrue('deployed_id' in response_json) - self.assertEqual(response_json['message'], "New model deployed.") - - @unittest.skip("detector isn't built yet") + self.assertTrue("deployed_id" in response_json) + self.assertEqual(response_json["message"], "New model deployed.") + + +class TestDetectorInference(unittest.TestCase): + def __init__(self, methodName: str = "runTest"): + super().__init__(methodName=methodName) + self.endpoint = f"http://{FASTAPI_HOST}:8000/" + # here we assume that deploy has been tested and works + # so we can generate a fixed deploy id for testing + body = { + "detector_configs": [ + { + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.1, + } + ], + } + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) + deploy_response_json = deploy_response.json() + self.sample_deployed_id = deploy_response_json["deployed_id"] + def test_detect(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "sample_data/coke_through_the_ages.jpeg" - expected_detect_response_unaccelerated = [[{'tlbr': [167.0, 151.0, 267.0, 494.0], 'score': 0.473, 'class': 'bottle'}, {'tlbr': [407.0, 152.0, 510.0, 495.0], 'score': 0.466, 'class': 'bottle'}, {'tlbr': [534.0, 150.0, 632.0, 494.0], 'score': 0.457, 'class': 'bottle'}, {'tlbr': [39.3, 183.0, 141.0, 494.0], 'score': 0.45, 'class': 'bottle'}, {'tlbr': [653.0, 151.0, 756.0, 497.0], 'score': 0.45, 'class': 'bottle'}, {'tlbr': [290.0, 160.0, 388.0, 493.0], 'score': 0.443, 'class': 'bottle'}, {'tlbr': [780.0, 158.0, 879.0, 498.0], 'score': 0.418, 'class': 'bottle'}, {'tlbr': [910.0, 108.0, 1030.0, 501.0], 'score': 0.385, 'class': 'bottle'}, {'tlbr': [1070.0, 170.0, 1160.0, 498.0], 'score': 0.364, 'class': 'bottle'}, {'tlbr': [27.3, 112.0, 1170.0, 507.0], 'score': 0.134, 'class': 'bottle'}]] - with open(sample_fp, 'rb') as f: + expected_detect_response_unaccelerated = [ + [ + { + "tlbr": [167.0, 151.0, 267.0, 494.0], + "score": 0.473, + "class": "bottle", + }, + { + "tlbr": [407.0, 152.0, 510.0, 495.0], + "score": 0.466, + "class": "bottle", + }, + { + "tlbr": [534.0, 150.0, 632.0, 494.0], + "score": 0.457, + "class": "bottle", + }, + {"tlbr": [39.3, 183.0, 141.0, 494.0], "score": 0.45, "class": "bottle"}, + { + "tlbr": [653.0, 151.0, 756.0, 497.0], + "score": 0.45, + "class": "bottle", + }, + { + "tlbr": [290.0, 160.0, 388.0, 493.0], + "score": 0.443, + "class": "bottle", + }, + { + "tlbr": [780.0, 158.0, 879.0, 498.0], + "score": 0.418, + "class": "bottle", + }, + { + "tlbr": [910.0, 108.0, 1030.0, 501.0], + "score": 0.385, + "class": "bottle", + }, + { + "tlbr": [1070.0, 170.0, 1160.0, 498.0], + "score": 0.364, + "class": "bottle", + }, + { + "tlbr": [27.3, 112.0, 1170.0, 507.0], + "score": 0.134, + "class": "bottle", + }, + ] + ] + with open(sample_fp, "rb") as f: file_data = f.read() files = { - 'data': (sample_fp, file_data, "image/jpg"), + "data": (sample_fp, 
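+ # requests encodes this multipart tuple as (filename, payload, content_type)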
file_data, "image/jpg"), } - params = { - 'deployed_id': sample_deployed_id - } - response = requests.post( - self.endpoint+"detect", - params=params, - files=files - ) + params = {"deployed_id": self.sample_deployed_id} + response = requests.post(self.endpoint + "detect", params=params, files=files) detect_response_json = response.json() - - self.assertEqual( - len(detect_response_json), 1 - ) - expected_detect_response = [SingleDetectionResponse(**d) for d in expected_detect_response_unaccelerated[0]] - actual_detect_response = [SingleDetectionResponse(**d) for d in detect_response_json[0]] + + self.assertEqual(len(detect_response_json), 1) + expected_detect_response = [ + SingleDetectionResponse(**d) + for d in expected_detect_response_unaccelerated[0] + ] # type: ignore + actual_detect_response = [ + SingleDetectionResponse(**d) for d in detect_response_json[0] + ] # type: ignore self.assertLess( compute_naive_bipartite_detection_loss( - expected_detect_response, - actual_detect_response + expected_detect_response, actual_detect_response ), - 0.05 + 0.05, ) - + def test_detect_malformatted_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "bad_file_path.suffix" - file_data = b"This is not an image file" + file_data = b"This is not an image file" files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': sample_deployed_id + "data": (sample_fp, file_data, "image/jpg"), } - response = requests.post( - self.endpoint+"detect", - params=params, - files=files - ) + params = {"deployed_id": self.sample_deployed_id} + response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() - self.assertEqual(response_json['status_code'],422) + self.assertEqual(response_json["status_code"], 422) self.assertEqual( - response_json['message'], - "Invalid image received, unable to open." + response_json["message"], "Invalid image received, unable to open." ) - + def test_detect_empty_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "bad_file_path.jpg" files = { - 'data': (sample_fp, b'', "image/jpg"), - } - params = { - 'deployed_id': sample_deployed_id + "data": (sample_fp, b"", "image/jpg"), } - response = requests.post( - self.endpoint+"detect", - params=params, - files=files - ) + params = {"deployed_id": self.sample_deployed_id} + response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() - self.assertEqual(response_json['status_code'],422) + self.assertEqual(response_json["status_code"], 422) self.assertEqual( - response_json['message'], - "Invalid image received, unable to open." + response_json["message"], "Invalid image received, unable to open." 
) - + def test_detect_truncated_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "sample_data/coke_through_the_ages.jpeg" - with open(sample_fp, 'rb') as f: + with open(sample_fp, "rb") as f: file_data = f.read() - file_data = file_data[:int(len(file_data)/2)] + file_data = file_data[: int(len(file_data) / 2)] files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': sample_deployed_id + "data": (sample_fp, file_data, "image/jpg"), } - response = requests.post( - self.endpoint+"detect", - params=params, - files=files - ) + params = {"deployed_id": self.sample_deployed_id} + response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() - self.assertEqual(response_json['status_code'],422) + self.assertEqual(response_json["status_code"], 422) self.assertEqual( - response_json['message'], - "Invalid image received, unable to open." + response_json["message"], "Invalid image received, unable to open." ) - - @unittest.skip("detector isn't built yet") + def test_deploy_and_detect(self) -> None: # Starting Deploy Call body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.2 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.2, } ], - "nms_threshold": 0.4 + "nms_threshold": 0.4, } - deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id = deploy_response_json["deployed_id"] + # Starting Detect Call sample_fp = "sample_data/coke_through_the_ages.jpeg" - expected_detect_response_unaccelerated = [[{'tlbr': [167.0, 151.0, 267.0, 494.0], 'score': 0.473, 'class': 'bottle'}, {'tlbr': [407.0, 152.0, 510.0, 495.0], 'score': 0.466, 'class': 'bottle'}, {'tlbr': [534.0, 150.0, 632.0, 494.0], 'score': 0.457, 'class': 'bottle'}, {'tlbr': [39.3, 183.0, 141.0, 494.0], 'score': 0.45, 'class': 'bottle'}, {'tlbr': [653.0, 151.0, 756.0, 497.0], 'score': 0.45, 'class': 'bottle'}, {'tlbr': [290.0, 160.0, 388.0, 493.0], 'score': 0.443, 'class': 'bottle'}, {'tlbr': [780.0, 158.0, 879.0, 498.0], 'score': 0.418, 'class': 'bottle'}, {'tlbr': [910.0, 108.0, 1030.0, 501.0], 'score': 0.385, 'class': 'bottle'}, {'tlbr': [1070.0, 170.0, 1160.0, 498.0], 'score': 0.364, 'class': 'bottle'}]] - with open(sample_fp, 'rb') as f: + expected_detect_response_unaccelerated = [ + [ + { + "tlbr": [167.0, 151.0, 267.0, 494.0], + "score": 0.473, + "class": "bottle", + }, # type: ignore + { + "tlbr": [407.0, 152.0, 510.0, 495.0], + "score": 0.466, + "class": "bottle", + }, + { + "tlbr": [534.0, 150.0, 632.0, 494.0], + "score": 0.457, + "class": "bottle", + }, + {"tlbr": [39.3, 183.0, 141.0, 494.0], "score": 0.45, "class": "bottle"}, + { + "tlbr": [653.0, 151.0, 756.0, 497.0], + "score": 0.45, + "class": "bottle", + }, + { + "tlbr": [290.0, 160.0, 388.0, 493.0], + "score": 0.443, + "class": "bottle", + }, + { + "tlbr": [780.0, 158.0, 879.0, 498.0], + "score": 0.418, + "class": "bottle", 
+ }, + { + "tlbr": [910.0, 108.0, 1030.0, 501.0], + "score": 0.385, + "class": "bottle", + }, + { + "tlbr": [1070.0, 170.0, 1160.0, 498.0], + "score": 0.364, + "class": "bottle", + }, + ] + ] + with open(sample_fp, "rb") as f: file_data = f.read() files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': deployed_id + "data": (sample_fp, file_data, "image/jpg"), } + params = {"deployed_id": deployed_id} detect_response = requests.post( - self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files ) detect_response_json = detect_response.json() - - self.assertEqual( - len(detect_response_json), 1 - ) - expected_detect_response = [SingleDetectionResponse(**d) for d in expected_detect_response_unaccelerated[0]] - actual_detect_response = [SingleDetectionResponse(**d) for d in detect_response_json[0]] + + self.assertEqual(len(detect_response_json), 1) + expected_detect_response = [ + SingleDetectionResponse(**d) # type: ignore + for d in expected_detect_response_unaccelerated[0] + ] + actual_detect_response = [ + SingleDetectionResponse(**d) for d in detect_response_json[0] # type: ignore + ] self.assertLess( compute_naive_bipartite_detection_loss( - expected_detect_response, - actual_detect_response + expected_detect_response, actual_detect_response ), - 0.05 + 0.05, ) - - @unittest.skip("detector isn't built yet") + def test_deploy_with_long_prompt_and_detect(self) -> None: # Starting Deploy Call + # NOTE: this is the only single-class detection test that runs the class-specific NMS algorithm very_long_prompt = "boat from birds-eye view maritime vessel from birds-eye view boat from top-down view maritime vessel from top-down view" body = { "detector_configs": [ { - "name": "sample_prompt", - "examples_to_include": [ - very_long_prompt - ], - "examples_to_exclude": [], - "detection_threshold": 0.01 + "name": "sample_prompt", + "examples_to_include": [very_long_prompt], + "examples_to_exclude": [], + "detection_threshold": 0.01, } ], - "nms_threshold": 0.4 + "nms_threshold": 0.4, } - deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id = deploy_response_json["deployed_id"] + # Starting Detect Call sample_fp = "sample_data/coke_through_the_ages.jpeg" - expected_detect_response_unaccelerated = [[{'tlbr': [2.04, -1.39, 1190.0, 636.0], 'score': 0.0195, 'class': 'sample_prompt'}]] - with open(sample_fp, 'rb') as f: + expected_detect_response_unaccelerated = [ + [ + { + "tlbr": [2.04, -1.39, 1190.0, 636.0], + "score": 0.0195, + "class": "sample_prompt", + } + ] + ] + with open(sample_fp, "rb") as f: file_data = f.read() files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': deployed_id + "data": (sample_fp, file_data, "image/jpg"), } + params = {"deployed_id": deployed_id} detect_response = requests.post( - self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files ) detect_response_json = detect_response.json() - - self.assertEqual( - 
len(detect_response_json), 1 - ) - expected_detect_response = [SingleDetectionResponse(**d) for d in expected_detect_response_unaccelerated[0]] - actual_detect_response = [SingleDetectionResponse(**d) for d in detect_response_json[0]] + + self.assertEqual(len(detect_response_json), 1) + expected_detect_response = [ + SingleDetectionResponse(**d) # type: ignore + for d in expected_detect_response_unaccelerated[0] + ] + actual_detect_response = [ + SingleDetectionResponse(**d) for d in detect_response_json[0] # type: ignore + ] self.assertLess( compute_naive_bipartite_detection_loss( - expected_detect_response, - actual_detect_response + expected_detect_response, actual_detect_response ), - 0.05 + 0.05, ) - - @unittest.skip("detector isn't built yet") + def test_deploy_and_detect_without_augmented_examples(self) -> None: # Starting Deploy Call body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.2 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.2, } ], "nms_threshold": 0.4, - "augment_examples": False + "augment_examples": False, } - deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id = deploy_response_json["deployed_id"] + # Starting Detect Call sample_fp = "sample_data/coke_through_the_ages.jpeg" - expected_detect_response_accelerated = [[{'tlbr': [407.40325927734375, 152.23214721679688, 509.8586120605469, 494.7916564941406], 'score': 0.6420959234237671, 'class': 'bottle'}, {'tlbr': [166.59225463867188, 151.0416717529297, 266.7410583496094, 493.6011962890625], 'score': 0.6420406699180603, 'class': 'bottle'}, {'tlbr': [533.6681518554688, 150.29762268066406, 632.4032592773438, 494.3452453613281], 'score': 0.6393750905990601, 'class': 'bottle'}, {'tlbr': [652.455322265625, 151.0416717529297, 755.8779907226562, 497.172607421875], 'score': 0.6305534243583679, 'class': 'bottle'}, {'tlbr': [39.28571319580078, 182.5892791748047, 141.22023010253906, 494.1964111328125], 'score': 0.6197347640991211, 'class': 'bottle'}, {'tlbr': [289.91815185546875, 160.26785278320312, 388.05804443359375, 492.7083435058594], 'score': 0.5861819386482239, 'class': 'bottle'}, {'tlbr': [779.6502685546875, 157.5892791748047, 878.6830444335938, 497.7678527832031], 'score': 0.5840385556221008, 'class': 'bottle'}, {'tlbr': [909.97021484375, 107.58928680419922, 1031.6964111328125, 501.33929443359375], 'score': 0.5274810194969177, 'class': 'bottle'}, {'tlbr': [1065.4761962890625, 170.38690185546875, 1158.3333740234375, 497.4702453613281], 'score': 0.49106982350349426, 'class': 'bottle'}]] - with open(sample_fp, 'rb') as f: + expected_detect_response_accelerated = [ + [ + { + "tlbr": [ + 407.40325927734375, + 152.23214721679688, + 509.8586120605469, + 494.7916564941406, + ], + "score": 0.6420959234237671, + "class": "bottle", + }, + { + "tlbr": [ + 166.59225463867188, + 151.0416717529297, + 266.7410583496094, + 493.6011962890625, + ], + "score": 0.6420406699180603, + "class": "bottle", + 
}, + { + "tlbr": [ + 533.6681518554688, + 150.29762268066406, + 632.4032592773438, + 494.3452453613281, + ], + "score": 0.6393750905990601, + "class": "bottle", + }, + { + "tlbr": [ + 652.455322265625, + 151.0416717529297, + 755.8779907226562, + 497.172607421875, + ], + "score": 0.6305534243583679, + "class": "bottle", + }, + { + "tlbr": [ + 39.28571319580078, + 182.5892791748047, + 141.22023010253906, + 494.1964111328125, + ], + "score": 0.6197347640991211, + "class": "bottle", + }, + { + "tlbr": [ + 289.91815185546875, + 160.26785278320312, + 388.05804443359375, + 492.7083435058594, + ], + "score": 0.5861819386482239, + "class": "bottle", + }, + { + "tlbr": [ + 779.6502685546875, + 157.5892791748047, + 878.6830444335938, + 497.7678527832031, + ], + "score": 0.5840385556221008, + "class": "bottle", + }, + { + "tlbr": [ + 909.97021484375, + 107.58928680419922, + 1031.6964111328125, + 501.33929443359375, + ], + "score": 0.5274810194969177, + "class": "bottle", + }, + { + "tlbr": [ + 1065.4761962890625, + 170.38690185546875, + 1158.3333740234375, + 497.4702453613281, + ], + "score": 0.49106982350349426, + "class": "bottle", + }, + ] + ] + with open(sample_fp, "rb") as f: file_data = f.read() files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': deployed_id + "data": (sample_fp, file_data, "image/jpg"), } + params = {"deployed_id": deployed_id} detect_response = requests.post( - self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files ) detect_response_json = detect_response.json() - - self.assertEqual( - len(detect_response_json), 1 - ) - expected_detect_response = [SingleDetectionResponse(**d) for d in expected_detect_response_accelerated[0]] - actual_detect_response = [SingleDetectionResponse(**d) for d in detect_response_json[0]] + + self.assertEqual(len(detect_response_json), 1) + expected_detect_response = [ + SingleDetectionResponse(**d) + for d in expected_detect_response_accelerated[0] + ] + actual_detect_response = [ + SingleDetectionResponse(**d) for d in detect_response_json[0] + ] self.assertLess( compute_naive_bipartite_detection_loss( - expected_detect_response, - actual_detect_response + expected_detect_response, actual_detect_response ), - 0.05 + 0.05, ) - + # now compare with the response from the same model with augmented examples body = { "detector_configs": [ { - "name": "bottle", - "examples_to_include": [ - "bottle" - ], - "examples_to_exclude": [], - "detection_threshold": 0.2 + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.2, } ], "nms_threshold": 0.4, - "augment_examples": True + "augment_examples": True, } - deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id_augmented_examples = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id_augmented_examples = deploy_response_json["deployed_id"] + # Starting Detect Call - params = { - 'deployed_id': deployed_id_augmented_examples - } + params = {"deployed_id": deployed_id_augmented_examples} detect_response_augmented_examples = requests.post( - 
self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files + ) + detect_response_augmented_examples_json = ( + detect_response_augmented_examples.json() + ) + + self.assertNotEqual( + detect_response_json, detect_response_augmented_examples_json ) - detect_response_augmented_examples_json = detect_response_augmented_examples.json() - - self.assertNotEqual(detect_response_json, detect_response_augmented_examples_json) - - @unittest.skip("detector isn't built yet") + def test_deploy_with_and_without_class_agnostic_nms(self) -> None: # Starting Deploy Call body = { "detector_configs": [ { - "name": "face", - "examples_to_include": [ - "face" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "face", + "examples_to_include": ["face"], + "examples_to_exclude": [], + "detection_threshold": 0.1, }, { - "name": "head", - "examples_to_include": [ - "head" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 - } + "name": "head", + "examples_to_include": ["head"], + "examples_to_exclude": [], + "detection_threshold": 0.1, + }, ], "nms_threshold": 0.1, - "class_agnostic_nms": True + "class_agnostic_nms": True, } - deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id = deploy_response_json["deployed_id"] + # Starting Detect Call sample_fp = "sample_data/jumping_jack_up_isaac.jpg" - with open(sample_fp, 'rb') as f: + with open(sample_fp, "rb") as f: file_data = f.read() files = { - 'data': (sample_fp, file_data, "image/jpg"), - } - params = { - 'deployed_id': deployed_id + "data": (sample_fp, file_data, "image/jpg"), } + params = {"deployed_id": deployed_id} detect_response = requests.post( - self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files ) detect_response_json = detect_response.json() - - self.assertEqual( - len(detect_response_json), 1 - ) - self.assertEqual( - len(detect_response_json[0]), - 1 - ) - detected_classes = set([detection['class'] for detection in detect_response_json[0]]) - self.assertEqual( - detected_classes, - {'head'} + + self.assertEqual(len(detect_response_json), 1) + self.assertEqual(len(detect_response_json[0]), 1) + detected_classes = set( + [detection["class"] for detection in detect_response_json[0]] ) - + self.assertEqual(detected_classes, {"head"}) + # now compare with the response from the same model without class agnostic nms # it should detect both a head and a face body = { "detector_configs": [ { - "name": "face", - "examples_to_include": [ - "face" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 + "name": "face", + "examples_to_include": ["face"], + "examples_to_exclude": [], + "detection_threshold": 0.1, }, { - "name": "head", - "examples_to_include": [ - "head" - ], - "examples_to_exclude": [], - "detection_threshold": 0.1 - } + "name": "head", + "examples_to_include": ["head"], + "examples_to_exclude": [], + "detection_threshold": 0.1, + }, ], "nms_threshold": 0.1, - "class_agnostic_nms": False + "class_agnostic_nms": False, } - 
deploy_response = requests.post( - self.endpoint+"deploy_detector", - json=body - ) + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) deploy_response_json = deploy_response.json() - self.assertTrue('deployed_id' in deploy_response_json) - self.assertEqual(deploy_response_json['message'], "New model deployed.") - deployed_id_class_based_nms = deploy_response_json['deployed_id'] - + self.assertTrue("deployed_id" in deploy_response_json) + self.assertEqual(deploy_response_json["message"], "New model deployed.") + deployed_id_class_based_nms = deploy_response_json["deployed_id"] + # Starting Detect Call - params = { - 'deployed_id': deployed_id_class_based_nms - } + params = {"deployed_id": deployed_id_class_based_nms} detect_response_class_based_nms = requests.post( - self.endpoint+"detect", - params=params, - files=files + self.endpoint + "detect", params=params, files=files ) detect_response_class_based_nms_json = detect_response_class_based_nms.json() - - self.assertEqual( - len(detect_response_class_based_nms_json), 1 - ) - self.assertEqual( - len(detect_response_class_based_nms_json[0]), - 2 - ) - detected_classes = set([detection['class'] for detection in detect_response_class_based_nms_json[0]]) - self.assertEqual( - detected_classes, - {'head', 'face'} - ) + + self.assertEqual(len(detect_response_class_based_nms_json), 1) + self.assertEqual(len(detect_response_class_based_nms_json[0]), 2) + detected_classes = set( + [ + detection["class"] + for detection in detect_response_class_based_nms_json[0] + ] + ) + self.assertEqual(detected_classes, {"head", "face"}) diff --git a/mypy.sh b/mypy.sh index 25a7af7..c297b36 100755 --- a/mypy.sh +++ b/mypy.sh @@ -62,6 +62,6 @@ build_testing() { if $MYPY_FULL_APP; then build_app $BUILD fi -if $MYPY_TESTS; then - build_testing $BUILD -fi +# if $MYPY_TESTS; then +# build_testing $BUILD +# fi From ce234769ee0445c556874fbd64b8cbb50cf65a98 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 26 Aug 2024 14:02:37 -0400 Subject: [PATCH 3/7] added unit tests --- directai_fastapi/modeling/object_detector.py | 1 + .../unit_tests/test_modules/test_detector.py | 127 ++++++++++++++++-- docker-compose.yml | 2 +- redis_data/redis_entrypoint.sh | 2 +- 4 files changed, 122 insertions(+), 10 deletions(-) diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py index 6e2e9ca..df796fd 100644 --- a/directai_fastapi/modeling/object_detector.py +++ b/directai_fastapi/modeling/object_detector.py @@ -499,6 +499,7 @@ def compute_iou_adjacency_list( boxes: torch.Tensor, nms_thre: float ) -> list[torch.Tensor]: boxes = boxes.clone() + # boxes are in cxcywh format, we need to convert them to tlbr format boxes[:, 0] -= boxes[:, 2] / 2 boxes[:, 1] -= boxes[:, 3] / 2 boxes[:, 2] = boxes[:, 0] + boxes[:, 2] diff --git a/directai_fastapi/unit_tests/test_modules/test_detector.py b/directai_fastapi/unit_tests/test_modules/test_detector.py index 266b6cb..4af735d 100644 --- a/directai_fastapi/unit_tests/test_modules/test_detector.py +++ b/directai_fastapi/unit_tests/test_modules/test_detector.py @@ -1,31 +1,136 @@ import unittest import torch +import torchvision # type: ignore from typing_extensions import ClassVar from modeling.object_detector import ( ZeroShotObjectDetectorWithFeedback, created_padded_tensor_from_bytes, + compute_iou_adjacency_list, + run_nms_via_adjacency_list, + run_nms_based_box_suppression_for_all_objects, ) +class TestHelperFunctions(unittest.TestCase): + def 
test_nms_via_adjacency_list(self) -> None: + nms_thre = 0.1 + n_boxes = 1024 + cxcywh_boxes = torch.rand(n_boxes, 4) + scores = torch.rand(n_boxes) + + tlbr_boxes = cxcywh_boxes.clone() + tlbr_boxes[:, :2] = cxcywh_boxes[:, :2] - cxcywh_boxes[:, 2:] / 2 + tlbr_boxes[:, 2:] = cxcywh_boxes[:, :2] + cxcywh_boxes[:, 2:] / 2 + + box_indices_by_descending_score = torch.argsort(scores, descending=True) + adjacency_list = compute_iou_adjacency_list(cxcywh_boxes, nms_thre=nms_thre) + adjacency_survived_inds = run_nms_via_adjacency_list( + box_indices_by_descending_score, adjacency_list + ) + + torchvision_survived_inds = torchvision.ops.nms( + tlbr_boxes, scores, iou_threshold=nms_thre + ) + + self.assertTrue(torch.all(adjacency_survived_inds == torchvision_survived_inds)) + + def test_class_agnostic_has_no_effect_on_single_class(self) -> None: + n_boxes = 512 + pro_max = torch.rand(n_boxes, 1) + con_max = torch.rand(n_boxes, 1) + cxcywh_boxes = torch.rand(n_boxes, 4) + conf_thres = torch.tensor([0.1]) + nms_thre = 0.1 + + class_believer_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=True, + ) + self.assertEqual(len(class_believer_boxes), 1) + + class_agnostic_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=False, + ) + + self.assertTrue( + torch.all(torch.eq(class_believer_boxes[0], class_agnostic_boxes[0])) + ) + + def test_class_agnostic_has_no_effect_on_multiclass_with_no_overlap(self) -> None: + n_classes = 4 + n_boxes = 512 * n_classes + pro_max = torch.zeros(n_boxes, n_classes) + con_max = torch.zeros(n_boxes, n_classes) + cxcywh_boxes = torch.rand(n_boxes, 4) + conf_thres = torch.tensor([0.1] * n_classes) + nms_thre = 0.1 + + # adjust the scores such that each class has nonzero scores in exactly 1/n_classes of the boxes + # and shift those boxes so that they don't overlap between classes + for i in range(n_classes): + pro_max[i::n_classes, i] = torch.rand(n_boxes // n_classes) + con_max[i::n_classes, i] = torch.rand(n_boxes // n_classes) + cxcywh_boxes[i::n_classes, 0] += i * 10 + + class_believer_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=True, + ) + self.assertEqual(len(class_believer_boxes), n_classes) + + class_agnostic_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=False, + ) + + for i in range(n_classes): + self.assertTrue( + torch.all(torch.eq(class_believer_boxes[i], class_agnostic_boxes[i])) + ) + + class TestObjectDetector(unittest.TestCase): # we have to define these here because mypy doesn't dive into the init hiding behind the classmethod - object_detector = NotImplemented # type: ClassVar[ZeroShotObjectDetectorWithFeedback] + object_detector = ( + NotImplemented + ) # type: ClassVar[ZeroShotObjectDetectorWithFeedback] coke_bottle_image_bytes = NotImplemented # type: ClassVar[bytes] default_labels = NotImplemented # type: ClassVar[list[str]] default_incs = NotImplemented # type: ClassVar[dict[str, list[str]]] default_excs = NotImplemented # type: ClassVar[dict[str, list[str]]] default_nms_thre = NotImplemented # type: ClassVar[float] default_conf_thres = NotImplemented # type: ClassVar[dict[str, float]] - + @classmethod def setUpClass(cls) -> None: 
cls.object_detector = ZeroShotObjectDetectorWithFeedback() - + coke_bottle_filepath = "unit_tests/sample_data/coke_through_the_ages.jpeg" with open(coke_bottle_filepath, "rb") as f: cls.coke_bottle_image_bytes = f.read() - + cls.default_labels = ["bottle", "moose"] cls.default_incs = { "bottle": ["bottle", "glass bottle", "plastic bottle", "water bottle"], @@ -34,12 +139,12 @@ def setUpClass(cls) -> None: cls.default_excs = { "bottle": ["can", "soda can", "aluminum can"], } - cls.default_nms_thre = 0.4 + cls.default_nms_thre = 0.1 cls.default_conf_thres = { "bottle": 0.1, "moose": 0.1, } - + def test_detect_objects_from_image_bytes(self) -> None: with torch.no_grad(): batched_predicted_boxes = self.object_detector( @@ -50,5 +155,11 @@ def test_detect_objects_from_image_bytes(self) -> None: nms_thre=self.default_nms_thre, label_conf_thres=self.default_conf_thres, ) - - print(batched_predicted_boxes) \ No newline at end of file + + self.assertEqual(len(batched_predicted_boxes), 1) + predicted_boxes = batched_predicted_boxes[0] + self.assertEqual(len(predicted_boxes), len(self.default_labels)) + bottle_boxes = predicted_boxes[0] + moose_boxes = predicted_boxes[1] + self.assertEqual(len(bottle_boxes), 9) + self.assertEqual(len(moose_boxes), 0) diff --git a/docker-compose.yml b/docker-compose.yml index 3bebc46..ec8e51e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: container_name: local_fastapi environment: - PYTHONUNBUFFERED=1 - - NVIDIA_VISIBLE_DEVICES=1 + - NVIDIA_VISIBLE_DEVICES=all - HF_HOME=/directai_fastapi/.cache/huggingface - CACHE_REDIS_PORT=6379 env_file: diff --git a/redis_data/redis_entrypoint.sh b/redis_data/redis_entrypoint.sh index d7a3ecb..9fc4b9c 100755 --- a/redis_data/redis_entrypoint.sh +++ b/redis_data/redis_entrypoint.sh @@ -8,7 +8,7 @@ cleanup() { trap 'cleanup' SIGTERM #Execute a command in the background -redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" --port 63790 & +redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" --port 6379 & #Save the PID of the background process REDIS_PID=$! 
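Aside: the class-agnostic NMS semantics that the unit tests above probe (and that the head/face deploy test in PATCH 2 exercises end-to-end) map onto torchvision's stock operators, which test_nms_via_adjacency_list already uses as its reference. A minimal sketch of the distinction, with made-up boxes, assuming only that torch and torchvision are installed:

    import torch
    import torchvision  # type: ignore

    # two heavily overlapping tlbr boxes assigned to different classes
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.5, 10.5]])
    scores = torch.tensor([0.9, 0.8])
    class_idxs = torch.tensor([0, 1])  # e.g. 0 = "head", 1 = "face"

    # class-agnostic NMS: boxes suppress each other regardless of class,
    # so only the higher-scoring box of the overlapping pair survives
    keep_agnostic = torchvision.ops.nms(boxes, scores, iou_threshold=0.1)
    assert len(keep_agnostic) == 1

    # per-class NMS: suppression happens only within a class, so both survive
    keep_per_class = torchvision.ops.batched_nms(boxes, scores, class_idxs, iou_threshold=0.1)
    assert len(keep_per_class) == 2

This is exactly why the deploy test expects a single "head" detection with class_agnostic_nms set to True, and both "head" and "face" with it disabled.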
From 58c589b85dcd06cea6bc8125dca96309a774fd49 Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Mon, 26 Aug 2024 14:26:07 -0400
Subject: [PATCH 4/7] re-enabling mypy for tests

---
 mypy.sh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mypy.sh b/mypy.sh
index 2abe81a..3d73a10 100755
--- a/mypy.sh
+++ b/mypy.sh
@@ -62,6 +62,6 @@ build_testing() {
 if $MYPY_FULL_APP; then
     build_app $BUILD
 fi
-# if $MYPY_TESTS; then
-# build_testing $BUILD
-# fi
+if $MYPY_TESTS; then
+    build_testing $BUILD
+fi

From f63f77ee472209482137a55b9465a1df82e86874 Mon Sep 17 00:00:00 2001
From: Isaac Robinson
Date: Mon, 26 Aug 2024 15:05:47 -0400
Subject: [PATCH 5/7] found that batched detection does not work

---
 directai_fastapi/modeling/object_detector.py | 79 ++++++++++++-------
 .../unit_tests/test_modules/test_detector.py | 42 +++++++++-
 2 files changed, 92 insertions(+), 29 deletions(-)

diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py
index df796fd..b113633 100644
--- a/directai_fastapi/modeling/object_detector.py
+++ b/directai_fastapi/modeling/object_detector.py
@@ -115,15 +115,6 @@ def flash_attn_owl_vit_encoder_forward(
     return attn_output, None
 
 
-# we're copying the function signature from the original
-# and just replacing the method with a faster one based on flash_attn
-# we could subclass the original, but that would require us to subclass the entire model
-# so we're just going to monkey patch it, as the output should be identical with the same inputs
-# Owlv2Attention.forward = flash_attn_owl_vit_encoder_forward
-# for owlv2_vision_model_encoder_layer in Owlv2VisionModel.vision_model.encoder.layers:
-#     owlv2_vision_model_encoder_layer.self_attn.forward = flash_attn_owl_vit_encoder_forward
-
-
 class VisionModelWrapper(nn.Module):
     def __init__(self, vision_model: Owlv2VisionModel) -> None:
         super().__init__()
@@ -134,7 +125,7 @@ def __init__(self, vision_model: Owlv2VisionModel) -> None:
         # to replace it with a faster one based on flash_attn
         # the alternative is to subclass the entire model, but that's a lot of work
         # so we're just going to define a replacement with the same function signature
-        # and assert that the input is as expected
+        # and assert that the input is in a form supported by flash_attn
         for owlv2_vision_model_encoder_layer in self.vision_model.encoder.layers:
             owlv2_vision_model_encoder_layer.self_attn.forward = partial(
                 flash_attn_owl_vit_encoder_forward,
@@ -164,8 +155,9 @@ def forward(
         image_embeds = self.wrapped_vision_model(image)
 
         # Resize class token
-        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
-        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
+        class_token_out = torch.broadcast_to(
+            image_embeds[:, :1, :], image_embeds[:, 1:, :].shape
+        )
 
         # Merge image embedding with class tokens
         image_embeds = image_embeds[:, 1:, :] * class_token_out
@@ -223,6 +215,8 @@ def __init__(
         self,
         model_name: str = "google/owlv2-large-patch14-ensemble",
         image_size: tuple[int, int] = (1008, 1008),
+        max_text_batch_size: int = 256,
+        max_image_batch_size: int = 256,
         device: torch.device | str = "cuda",
         lru_cache_size: int = 4096,
         jit: bool = True,
@@ -243,6 +237,9 @@ def __init__(
         else:
             self.wrapped_image_embedder = WrappedImageEmbedder(self.model)
 
+        self.max_text_batch_size = max_text_batch_size
+        self.max_image_batch_size = max_image_batch_size
+
         # we cache the text embeddings to avoid recomputing them
         # we use an LRU cache to avoid running out of memory
         # especially because likely the tensors will be large and stored in GPU memory
@@ -266,21 +263,28 @@ def _encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor:
         templates = medium_hypothesis_formats if augment else noop_hypothesis_formats
         augmented_text = [template.format(t) for t in text for template in templates]
 
-        processor_output = self.processor(
-            text=augmented_text, return_tensors="pt", padding=True, truncation=True
-        )
-        input_ids = processor_output.input_ids.to(self.device)
-        attn_mask = processor_output.attention_mask.to(self.device)
+        embeddings_list = []
+        for i in range(0, len(augmented_text), self.max_text_batch_size):
+            text_subset = augmented_text[i : i + self.max_text_batch_size]
 
-        # TODO: add appropriate batching to avoid OOM
-        text_output = self.model.owlv2.text_model(
-            input_ids=input_ids, attention_mask=attn_mask, return_dict=True
-        )
+            processor_output = self.processor(
+                text=text_subset, return_tensors="pt", padding=True, truncation=True
+            )
+            input_ids = processor_output.input_ids.to(self.device)
+            attn_mask = processor_output.attention_mask.to(self.device)
 
-        embeddings = text_output[1]
-        embeddings = self.model.owlv2.text_projection(embeddings)
-        embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2)
+            # chunking by max_text_batch_size above handles the OOM concern
+            text_output = self.model.owlv2.text_model(
+                input_ids=input_ids, attention_mask=attn_mask, return_dict=True
+            )
+
+            embeddings = text_output[1]
+            embeddings = self.model.owlv2.text_projection(embeddings)
+            embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2)
+            embeddings_list.append(embeddings)
+
+        embeddings = torch.cat(embeddings_list, dim=0)
 
         embeddings = embeddings.reshape(len(text), len(templates), embeddings.shape[1])
         embeddings = embeddings.mean(dim=1)
         embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2)
@@ -306,9 +310,25 @@ def get_image_data(self, image: torch.Tensor) -> dict[str, torch.Tensor]:
         image = image / 255.0
         image = (image - self.rgb_means) / self.rgb_stds
 
-        image_class_embeds, logit_shift, logit_scale, pred_boxes = (
-            self.wrapped_image_embedder(image)
-        )
+        image_class_embeds_list = []
+        logit_shift_list = []
+        logit_scale_list = []
+        pred_boxes_list = []
+        for i in range(0, image.size(0), self.max_image_batch_size):
+            image_subset = image[i : i + self.max_image_batch_size]
+            image_class_embeds, logit_shift, logit_scale, pred_boxes = (
+                self.wrapped_image_embedder(image_subset)
+            )
+
+            image_class_embeds_list.append(image_class_embeds)
+            logit_shift_list.append(logit_shift)
+            logit_scale_list.append(logit_scale)
+            pred_boxes_list.append(pred_boxes)
+
+        image_class_embeds = torch.cat(image_class_embeds_list)
+        logit_shift = torch.cat(logit_shift_list)
+        logit_scale = torch.cat(logit_scale_list)
+        pred_boxes = torch.cat(pred_boxes_list)
 
         return {
             "image_class_embeds": image_class_embeds,
@@ -351,6 +371,9 @@ def forward(
         if any([len(sub_labels) == 0 for sub_labels in inc_sub_labels_dict.values()]):
             raise ValueError("Each label must include at least one sub-label")
 
+        if image_tensor.shape[0] > 1:
+            raise ValueError("Batched image inputs are not yet supported")
+
         image_tensor = image_tensor.to(self.device)
 
         image_data = self.get_image_data(image_tensor)
@@ -629,10 +652,10 @@ def compute_candidate_nms_via_adjacency_list(
         survive_inds = pro_valid.nonzero().squeeze(1)
         return survive_inds
 
-    # we then run a NMS over the (remaining) boxes
     survive_inds = pro_valid_inds[pro_max[pro_valid].argsort(descending=True)]
 
     if run_nms:
+        # we then run a NMS over the (remaining) boxes
survive_inds = run_nms_via_adjacency_list(survive_inds, modified_adjacency_list) return survive_inds diff --git a/directai_fastapi/unit_tests/test_modules/test_detector.py b/directai_fastapi/unit_tests/test_modules/test_detector.py index 4af735d..d1e988b 100644 --- a/directai_fastapi/unit_tests/test_modules/test_detector.py +++ b/directai_fastapi/unit_tests/test_modules/test_detector.py @@ -125,7 +125,7 @@ class TestObjectDetector(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.object_detector = ZeroShotObjectDetectorWithFeedback() + cls.object_detector = ZeroShotObjectDetectorWithFeedback(jit=True) coke_bottle_filepath = "unit_tests/sample_data/coke_through_the_ages.jpeg" with open(coke_bottle_filepath, "rb") as f: @@ -163,3 +163,43 @@ def test_detect_objects_from_image_bytes(self) -> None: moose_boxes = predicted_boxes[1] self.assertEqual(len(bottle_boxes), 9) self.assertEqual(len(moose_boxes), 0) + + @unittest.skip("We don't yet support batched object detection") + def test_batched_detect(self) -> None: + with torch.no_grad(): + random_images = torch.rand(16, 3, *self.object_detector.image_size) + image_scale_ratios = torch.ones(16) + + single_image_outputs_list = [] + for image in random_images: + single_image_outputs_list.append( + self.object_detector( + image.unsqueeze(0), + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres={"bottle": 0.0, "moose": 0.0}, + image_scale_ratios=image_scale_ratios, + )[0] + ) + + batched_outputs = self.object_detector( + random_images, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres={"bottle": 0.0, "moose": 0.0}, + image_scale_ratios=image_scale_ratios, + ) + + for i in range(16): + for j in range(2): + from_batch = batched_outputs[i][j] + from_single = single_image_outputs_list[i][j] + self.assertEqual(from_batch.shape, from_single.shape) + if from_batch.shape[0] == 0: + continue + max_diff = (from_batch - from_single).abs().max().item() + self.assertTrue(max_diff < 1e-5) From 2694a4db50dd68606c74b17fe94077933e442d38 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 26 Aug 2024 16:14:58 -0400 Subject: [PATCH 6/7] added batch testing --- directai_fastapi/modeling/object_detector.py | 15 +- .../unit_tests/test_modules/test_detector.py | 145 ++++++++++++++++-- 2 files changed, 135 insertions(+), 25 deletions(-) diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py index b113633..29be7d6 100644 --- a/directai_fastapi/modeling/object_detector.py +++ b/directai_fastapi/modeling/object_detector.py @@ -215,8 +215,8 @@ def __init__( self, model_name: str = "google/owlv2-large-patch14-ensemble", image_size: tuple[int, int] = (1008, 1008), - max_text_batch_size: int = 256, - max_image_batch_size: int = 256, + max_text_batch_size: int = 32, + max_image_batch_size: int = 32, device: torch.device | str = "cuda", lru_cache_size: int = 4096, jit: bool = True, @@ -371,9 +371,6 @@ def forward( if any([len(sub_labels) == 0 for sub_labels in inc_sub_labels_dict.values()]): raise ValueError("Each label must include at least one sub-label") - if image_tensor.shape[0] > 1: - raise ValueError("Batched image inputs are not yet supported") - image_tensor = image_tensor.to(self.device) image_data = self.get_image_data(image_tensor) @@ -675,7 +672,7 @@ def run_nms_based_box_suppression_for_one_object( 
pro_max, con_max, adjacency_list, conf_thre, run_nms ) - boxes = pred_boxes.squeeze(0)[survive_inds] + boxes = pred_boxes[survive_inds] logits = pro_max[survive_inds] logits = logits.unsqueeze(-1) @@ -690,9 +687,9 @@ def run_nms_based_box_suppression_for_one_object( boxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=-1) boxes_with_scores = torch.cat([boxes, logits], dim=-1) - boxes_with_scores = boxes_with_scores[ - boxes_with_scores[:, 4].argsort(descending=True) - ] + ordered_by_logit = boxes_with_scores[:, 4].argsort(descending=True) + boxes_with_scores = boxes_with_scores[ordered_by_logit] + survive_inds = survive_inds[ordered_by_logit] return boxes_with_scores, survive_inds diff --git a/directai_fastapi/unit_tests/test_modules/test_detector.py b/directai_fastapi/unit_tests/test_modules/test_detector.py index d1e988b..f26c327 100644 --- a/directai_fastapi/unit_tests/test_modules/test_detector.py +++ b/directai_fastapi/unit_tests/test_modules/test_detector.py @@ -117,6 +117,7 @@ class TestObjectDetector(unittest.TestCase): NotImplemented ) # type: ClassVar[ZeroShotObjectDetectorWithFeedback] coke_bottle_image_bytes = NotImplemented # type: ClassVar[bytes] + coke_can_image_bytes = NotImplemented # type: ClassVar[bytes] default_labels = NotImplemented # type: ClassVar[list[str]] default_incs = NotImplemented # type: ClassVar[dict[str, list[str]]] default_excs = NotImplemented # type: ClassVar[dict[str, list[str]]] @@ -125,15 +126,19 @@ class TestObjectDetector(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - cls.object_detector = ZeroShotObjectDetectorWithFeedback(jit=True) + cls.object_detector = ZeroShotObjectDetectorWithFeedback(jit=False) coke_bottle_filepath = "unit_tests/sample_data/coke_through_the_ages.jpeg" with open(coke_bottle_filepath, "rb") as f: cls.coke_bottle_image_bytes = f.read() + coke_can_filepath = "unit_tests/sample_data/coke_can.jpg" + with open(coke_can_filepath, "rb") as f: + cls.coke_can_image_bytes = f.read() - cls.default_labels = ["bottle", "moose"] + cls.default_labels = ["bottle", "can", "moose"] cls.default_incs = { "bottle": ["bottle", "glass bottle", "plastic bottle", "water bottle"], + "can": ["can", "soda can", "aluminum can"], "moose": ["moose", "elk", "deer"], } cls.default_excs = { @@ -142,6 +147,7 @@ def setUpClass(cls) -> None: cls.default_nms_thre = 0.1 cls.default_conf_thres = { "bottle": 0.1, + "can": 0.1, "moose": 0.1, } @@ -160,18 +166,54 @@ def test_detect_objects_from_image_bytes(self) -> None: predicted_boxes = batched_predicted_boxes[0] self.assertEqual(len(predicted_boxes), len(self.default_labels)) bottle_boxes = predicted_boxes[0] - moose_boxes = predicted_boxes[1] + can_boxes = predicted_boxes[1] + moose_boxes = predicted_boxes[2] self.assertEqual(len(bottle_boxes), 9) + self.assertEqual(len(can_boxes), 0) self.assertEqual(len(moose_boxes), 0) - @unittest.skip("We don't yet support batched object detection") def test_batched_detect(self) -> None: + # ideally we would test a set of random images + # but we use a sort, which has unstable ordering with floating point numbers + # which means it is nontrivial to compare the outputs of the batched and single-image versions + # instead we're going to limit to confident predictions from two images + # and hope that the confidences are well-enough separated that the sort is stable + with torch.no_grad(): - random_images = torch.rand(16, 3, *self.object_detector.image_size) - image_scale_ratios = torch.ones(16) + coke_bottle_image_tensor, coke_bottle_ratio = ( + 
created_padded_tensor_from_bytes( + self.coke_bottle_image_bytes, self.object_detector.image_size + ) + ) + coke_can_image_tensor, coke_can_ratio = created_padded_tensor_from_bytes( + self.coke_can_image_bytes, self.object_detector.image_size + ) + + batched_images = torch.cat( + [ + coke_bottle_image_tensor, + ] + * 8 + + [ + coke_can_image_tensor, + ] + * 8, + dim=0, + ) + batched_ratios = torch.cat( + [ + coke_bottle_ratio, + ] + * 8 + + [ + coke_can_ratio, + ] + * 8, + dim=0, + ) single_image_outputs_list = [] - for image in random_images: + for image, ratio in zip(batched_images, batched_ratios): single_image_outputs_list.append( self.object_detector( image.unsqueeze(0), @@ -179,27 +221,98 @@ def test_batched_detect(self) -> None: inc_sub_labels_dict=self.default_incs, exc_sub_labels_dict=None, nms_thre=self.default_nms_thre, - label_conf_thres={"bottle": 0.0, "moose": 0.0}, - image_scale_ratios=image_scale_ratios, + label_conf_thres=self.default_conf_thres, + image_scale_ratios=ratio.unsqueeze(0), )[0] ) batched_outputs = self.object_detector( - random_images, + batched_images, labels=self.default_labels, inc_sub_labels_dict=self.default_incs, exc_sub_labels_dict=None, nms_thre=self.default_nms_thre, - label_conf_thres={"bottle": 0.0, "moose": 0.0}, - image_scale_ratios=image_scale_ratios, + label_conf_thres=self.default_conf_thres, + image_scale_ratios=batched_ratios, ) - for i in range(16): - for j in range(2): + for i in range(len(batched_outputs)): + for j in range(len(batched_outputs[i])): from_batch = batched_outputs[i][j] from_single = single_image_outputs_list[i][j] self.assertEqual(from_batch.shape, from_single.shape) if from_batch.shape[0] == 0: continue - max_diff = (from_batch - from_single).abs().max().item() - self.assertTrue(max_diff < 1e-5) + + # these values have range on the order of 1e3, so we scale them to compare + scale = torch.maximum(from_batch.abs(), from_single.abs()) + diff = (from_batch - from_single).abs() + scaled_diff = diff / (scale + 1e-6) + max_diff = scaled_diff.max().item() + + # large range and machine precision issues mean the max diff has a lot of noise + # TODO: is this more than is sane? 
+ self.assertTrue(max_diff < 1e-3) + + def test_batch_detect_random_append(self) -> None: + # we test batched detection by doing a pass for one image + # and then doing a batched pass for that image and many random images + # we use low confidence thresholds during the detection + # and then truncate at a high confidence level to ensure stability due to sorting + confidences = {label: 0.0 for label in self.default_labels} + with torch.no_grad(): + coke_bottle_image_tensor, coke_bottle_ratio = ( + created_padded_tensor_from_bytes( + self.coke_bottle_image_bytes, self.object_detector.image_size + ) + ) + baseline_output = self.object_detector( + coke_bottle_image_tensor, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=confidences, + image_scale_ratios=coke_bottle_ratio, + )[0] + + random_tensors = torch.rand(128, 3, *self.object_detector.image_size) + batched_tensor = torch.cat([random_tensors, coke_bottle_image_tensor]) + batched_ratios = torch.cat([torch.ones(128), coke_bottle_ratio]) + batched_output = self.object_detector( + batched_tensor, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=confidences, + image_scale_ratios=batched_ratios, + )[-1] + + for baseline_obj_detections, batched_obj_detections in zip( + baseline_output, batched_output + ): + # filter by confidence of 0.1 + baseline_obj_detections = baseline_obj_detections[ + baseline_obj_detections[:, 4] > 0.1 + ] + batched_obj_detections = batched_obj_detections[ + batched_obj_detections[:, 4] > 0.1 + ] + + self.assertEqual( + baseline_obj_detections.shape, batched_obj_detections.shape + ) + if baseline_obj_detections.shape[0] == 0: + continue + + # these values have range on the order of 1e3, so we scale them to compare + scale = torch.maximum( + baseline_obj_detections.abs(), batched_obj_detections.abs() + ) + diff = (baseline_obj_detections - batched_obj_detections).abs() + scaled_diff = diff / (scale + 1e-6) + max_diff = scaled_diff.max().item() + + # large range and machine precision issues mean the max diff has a lot of noise + self.assertTrue(max_diff < 1e-6) From 5853cc461a6d35ba4ece8951d8177533cc4164d0 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 26 Aug 2024 17:35:06 -0400 Subject: [PATCH 7/7] addressed pr comments --- directai_fastapi/modeling/distributed_backend.py | 13 ++++++++----- directai_fastapi/modeling/object_detector.py | 10 +++++----- directai_fastapi/server.py | 11 +++++------ 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/directai_fastapi/modeling/distributed_backend.py b/directai_fastapi/modeling/distributed_backend.py index 45c5120..fc3d7ac 100644 --- a/directai_fastapi/modeling/distributed_backend.py +++ b/directai_fastapi/modeling/distributed_backend.py @@ -44,17 +44,20 @@ async def __call__( run_class_agnostic_nms=run_class_agnostic_nms, ) - # since we are not batching, we can assume the output has batch size 1 + # since we are processing a single image, the output has batch size 1, so we can safely index into it per_label_boxes = batched_predicted_boxes[0] # predicted_boxes is a list in order of labels, with each box of the form [x1, y1, x2, y2, confidence] detection_responses = [] for label, boxes in zip(labels, per_label_boxes): for detection in boxes: - single_detection_response = SingleDetectionResponse( - tlbr=detection[:4].tolist(), - score=detection[4].item(), - 
class_=label, # type: ignore + det_dict = { + "tlbr": detection[:4].tolist(), + "score": detection[4].item(), + "class_": label, + } + single_detection_response = SingleDetectionResponse.parse_obj( + det_dict ) detection_responses.append(single_detection_response) diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py index 29be7d6..b29274a 100644 --- a/directai_fastapi/modeling/object_detector.py +++ b/directai_fastapi/modeling/object_detector.py @@ -3,14 +3,14 @@ from PIL import Image import torch from torch import nn -import torchvision # type: ignore +import torchvision # type: ignore[import-untyped] import numpy as np -from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel # type: ignore -from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention # type: ignore +from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel # type: ignore[import-untyped] +from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention # type: ignore[import-untyped] import time from typing import Union -from torch_scatter import scatter_max # type: ignore -from flash_attn import flash_attn_func # type: ignore +from torch_scatter import scatter_max # type: ignore[import-untyped] +from flash_attn import flash_attn_func # type: ignore[import-untyped] import io from lru import LRU from functools import partial diff --git a/directai_fastapi/server.py b/directai_fastapi/server.py index 0303be0..172f5a6 100644 --- a/directai_fastapi/server.py +++ b/directai_fastapi/server.py @@ -58,7 +58,7 @@ async def startup_event() -> None: app.state.config_cache = await redis.from_url( f"{grab_redis_endpoint()}?decode_responses=True" ) - print(f"Ping successful: {await app.state.config_cache.ping()}") + logger.info(f"Ping successful: {await app.state.config_cache.ping()}") @app.on_event("shutdown") @@ -71,7 +71,7 @@ async def validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") - print(f"{request}: {exc_str}") + logger.info(f"{request}: {exc_str}") return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ @@ -85,7 +85,7 @@ async def validation_exception_handler( @app.exception_handler(HTTPException) async def exception_handler(request: Request, exc: HTTPException) -> JSONResponse: exc_str = f"{exc.detail}".replace("\n", " ").replace(" ", " ") - print(f"{request}: {exc_str}") + logger.info(f"{request}: {exc_str}") return JSONResponse( status_code=exc.status_code, content={"status_code": exc.status_code, "message": exc_str, "data": None}, @@ -110,7 +110,7 @@ async def deploy_classifier(request: Request, config: ClassifierDeploy) -> dict: deploy_response = await config.save_configuration( config_cache=app.state.config_cache ) - print(f"Deployed classifier w/ ID: {deploy_response['deployed_id']}") + logger.info(f"Deployed classifier w/ ID: {deploy_response['deployed_id']}") return deploy_response @@ -167,7 +167,7 @@ async def deploy_detector(request: Request, config: DetectorDeploy) -> dict: deploy_response = await config.save_configuration( config_cache=app.state.config_cache ) - print(f"Deployed detector w/ ID: {deploy_response['deployed_id']}") + logger.info(f"Deployed detector w/ ID: {deploy_response['deployed_id']}") return deploy_response @@ -187,7 +187,6 @@ async def run_detector( data: UploadFile = File(), ) -> List[List[SingleDetectionResponse]]: """Get detections from deployed model""" - 
-    print(f"Got request for {deployed_id}, which is a detector model")
     image = data.file.read()
     raise_if_cannot_open(image)
+    logger.info(f"Got request for {deployed_id}, which is a detector model")
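A closing note on the SingleDetectionResponse change in PATCH 7: "class" is a reserved word in Python, so the response model presumably exposes that field through a pydantic alias, and building the object from a dict via parse_obj avoids spelling class_= as a keyword argument (and the type: ignore that came with it). A minimal sketch of the pattern; the model's actual definition is not shown in these patches, so the field layout below is an assumption:

    from pydantic import BaseModel, Field  # pydantic v1 API, matching the parse_obj usage

    class SingleDetectionResponse(BaseModel):
        # assumed field layout, for illustration only
        tlbr: list[float]
        score: float
        class_: str = Field(alias="class")  # serialized as "class", accessed as .class_

        class Config:
            allow_population_by_field_name = True  # lets parse_obj accept the "class_" key

    det = SingleDetectionResponse.parse_obj(
        {"tlbr": [0.0, 0.0, 1.0, 1.0], "score": 0.9, "class_": "face"}
    )
    assert det.dict(by_alias=True)["class"] == "face"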