This tutorial walks through how to use the models downloaded from Rapid for your own deployment use cases.
Rapid generates the model in a mobile-optimized TorchScript format. TorchScript is an intermediate representation of a PyTorch model that can be run on an edge device in a high-performance environment such as C++. The TorchScript model file can then be loaded into a mobile application (iOS or Android) via the PyTorch C++ libraries or the PyTorch Android API. PyTorch provides documentation and examples of how to embed a TorchScript model in your mobile application.
The downloaded file will be named mobile_optimized.ptl and can be loaded using torch.jit.load.
To load the trained Rapid model, use the torch and torchvision packages:
import torch
import torchvision  # importing torchvision registers its custom operators, which detection models typically need

model = torch.jit.load("PATH_TO_FILE/mobile_optimized.ptl")
model.eval()
Calling eval() puts the model in inference mode rather than training mode.
Now that you've loaded your trained TorchScript model, let's walk through how to use it for instance segmentation.
First, read in an image and convert it to the format the trained model expects.
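import cv2

im = cv2.imread("./test_image.jpeg")
image = im[:, :, ::-1]  # reverse the channel order (OpenCV loads images as BGR)
image_height, image_width = image.shape[:2]
# Convert from (H, W, C) to a float32 tensor of shape (C, H, W)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))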
Note: the image can be any size. To speed up inference, however, you can downscale the image first. A brief example of how to do that:
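import detectron2.data.transforms as T

# Resize so the shorter side is 320 px and the longer side is at most 640 px.
# These sizes are illustrative only; choose values that suit your model.
aug = T.ResizeShortestEdge([320, 320], 640)

original_image = im[:, :, ::-1]  # reverse the channel order (OpenCV loads images as BGR)
image_height, image_width = original_image.shape[:2]
image = aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))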
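To get a feel for the image sizes a shortest-edge resize produces, the arithmetic can be sketched in plain Python. This is an illustrative helper only: the name `shortest_edge_size` and the example sizes are assumptions, not part of Rapid or Detectron2. It scales the shorter side to a target length, caps the longer side at a maximum, and preserves the aspect ratio.

```python
def shortest_edge_size(height, width, short_edge, max_size):
    """Return (new_height, new_width) after an aspect-preserving resize.

    Hypothetical helper for illustration: the shorter side is scaled to
    `short_edge`; if the longer side would then exceed `max_size`, both
    sides are scaled down again so the longer side equals `max_size`.
    """
    scale = short_edge / min(height, width)
    new_h, new_w = height * scale, width * scale
    longest = max(new_h, new_w)
    if longest > max_size:
        cap = max_size / longest
        new_h, new_w = new_h * cap, new_w * cap
    # Round to the nearest whole pixel.
    return int(new_h + 0.5), int(new_w + 0.5)


# A 1080x1920 frame, shorter side scaled to 320 px, longer side capped at 640 px.
resized_hw = shortest_edge_size(1080, 1920, short_edge=320, max_size=640)
print(resized_hw)
```

Smaller inputs reduce inference time roughly in proportion to the pixel count, at the cost of detail in small objects, so it is worth trying a few sizes on representative images.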
Using the model loaded in the earlier step, you can now pass the formatted image through the model.
with torch.no_grad():  # gradients are not needed for inference
    outputs = model(image)
The outputs of the instance segmentation model are a tuple of five elements, where N is the number of predicted objects:
predicted_bounding_boxes = outputs[0]  # (N, 4)
predicted_classes = outputs[1]  # (N, 1) where each element is the class label
predicted_instance_masks = outputs[2]  # (N, 1, 28, 28) - see post-processing below
predicted_scores = outputs[3]  # (N, 1) where each element is the score
(image_height, image_width) = outputs[4]
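A common next step is to drop low-confidence detections before post-processing. The sketch below assumes the boxes, classes, masks, and scores have already been unpacked from the model outputs as listed above; the helper name `filter_by_score` and the 0.5 cutoff are hypothetical choices for illustration, and dummy tensors stand in for real predictions.

```python
import torch


def filter_by_score(boxes, classes, masks, scores, min_score=0.5):
    """Keep only detections whose score is at least `min_score`.

    Hypothetical helper: works whether `scores` has shape (N,) or (N, 1).
    """
    keep = scores.reshape(-1) >= min_score  # boolean mask of shape (N,)
    return boxes[keep], classes[keep], masks[keep], scores[keep]


# Dummy tensors standing in for the model outputs (N = 3 detections).
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [5.0, 5.0, 20.0, 20.0],
                      [1.0, 1.0, 4.0, 4.0]])
classes = torch.tensor([0, 2, 1])
masks = torch.rand(3, 1, 28, 28)
scores = torch.tensor([0.9, 0.3, 0.75])

kept_boxes, kept_classes, kept_masks, kept_scores = filter_by_score(
    boxes, classes, masks, scores, min_score=0.5
)
print(kept_classes.tolist())
```

Filtering before the masks are pasted into the image keeps the more expensive post-processing step limited to the detections you actually intend to use.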
Note: the masks are output at a fixed 28x28 resolution and need to be pasted into the image.
For each predicted object, the model generates a 28x28 instance segmentation mask. During training, the ground-truth mask is downscaled to compute the loss against the predicted mask; during inference, the generated mask needs to be upscaled to the bounding box size. To do so, you may implement your own resizing function natively or use the post-processing functions provided below. These functions resize the instance segmentation masks and convert them to bitmasks (0 indicates no mask, 1 indicates mask), where the output has shape (N, image_height, image_width).
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F

BYTES_PER_FLOAT = 4
GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit


def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bounds all boxes, and return the results for this region only.
            An important optimization for CPU.

    Returns:
        if skip_empty == False, a mask of shape (N, img_h, img_w)
        if skip_empty == True, a mask of shape (N, h', w'), and the slice
            object for the corresponding region.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale dataset.
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    # Pixel-center coordinates, normalized to [-1, 1] relative to each box
    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()


def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height
        and width. img_masks[i] is a binary mask for object instance i.
    """
    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The actual implementation splits the input into chunks,
    # and pastes them chunk by chunk.
    if device.type == "cpu" or torch.jit.is_scripting():
        # CPU is most efficient when they are pasted one by one with skip_empty=True
        # so that it performs minimal number of operations.
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks, but may have memory issue
        # int(img_h) because shape may be tensors in tracing
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # for visualization and debugging
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():  # Scripting does not use the optimized codepath
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks


def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]  # drop the channel dimension: (N, 28, 28)
    # Convert the image size to plain ints so it can be used as a shape
    height, width = (int(v) for v in outputs[4])
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks
Alternatively, you can use a built-in Detectron2 utility function to do this:
from detectron2.layers.mask_ops import paste_masks_in_image


def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]  # drop the channel dimension: (N, 28, 28)
    # Convert the image size to plain ints so it can be used as a shape
    height, width = (int(v) for v in outputs[4])
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks
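If you would rather implement the resizing natively, as mentioned earlier, a minimal single-mask version might look like the sketch below. It is a simplified stand-in, not the Detectron2 implementation: it uses bilinear interpolation via `F.interpolate` instead of grid sampling and rounds the box to whole pixels, so mask edges may differ by a pixel. The helper name `paste_single_mask` is hypothetical.

```python
import torch
from torch.nn import functional as F


def paste_single_mask(mask, box, image_height, image_width, threshold=0.5):
    """Resize one (H, W) soft mask to its box and paste it into a full-size bitmask.

    Hypothetical helper: `box` is (x0, y0, x1, y1) in image pixel coordinates.
    """
    x0, y0, x1, y1 = [int(round(float(v))) for v in box]
    box_w, box_h = max(x1 - x0, 1), max(y1 - y0, 1)

    # interpolate expects a (batch, channel, H, W) tensor.
    resized = F.interpolate(
        mask[None, None].float(), size=(box_h, box_w),
        mode="bilinear", align_corners=False,
    )[0, 0]

    canvas = torch.zeros(image_height, image_width, dtype=torch.bool)
    # Clip the box to the image and copy only the overlapping region.
    cx0, cy0 = max(x0, 0), max(y0, 0)
    cx1, cy1 = min(x0 + box_w, image_width), min(y0 + box_h, image_height)
    if cx1 > cx0 and cy1 > cy0:
        canvas[cy0:cy1, cx0:cx1] = (
            resized[cy0 - y0:cy1 - y0, cx0 - x0:cx1 - x0] >= threshold
        )
    return canvas


# A fully "on" 28x28 mask pasted into a 4x5 box inside a 10x12 image.
bitmask = paste_single_mask(torch.ones(28, 28), [2.0, 3.0, 6.0, 8.0], 10, 12)
print(int(bitmask.sum()))
```

Looping this over N detections is slower than the batched functions above, but it is easy to port to platforms where grid sampling is unavailable.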
Now you have a fully functioning instance segmentation model with parsable bounding box and instance segmentation detections!