Download and Use Your Model

This tutorial walks through how to use the models you download from Rapid in your own deployment use cases.

Model format

Rapid generates the model in a mobile-optimized TorchScript format. TorchScript is an intermediate representation of a PyTorch model that can be run in a high-performance environment, such as C++, on an edge device. The TorchScript model file can then be loaded into a mobile application (iOS or Android) via the PyTorch C++ libraries or the PyTorch Android API. PyTorch provides documentation and examples of how to embed a TorchScript model in your mobile application.

The downloaded file is named mobile_optimized.ptl and can be loaded using the torch library.

Loading the TorchScript model

To load the trained Rapid model, use the torch and torchvision libraries:

import torch
import torchvision  # importing torchvision also registers the custom ops (e.g., nms, roi_align) the scripted model may need

model = torch.jit.load("PATH_TO_FILE/mobile_optimized.ptl")
model.eval()

Calling eval() puts the model in inference mode so that training-only behavior (such as dropout) is disabled.

Instance Segmentation Model

Now that you've loaded your trained TorchScript model, let's walk through how to use it for instance segmentation.

Loading the image

Let's first read in an image and convert it to the format expected by the trained model.

import cv2

im = cv2.imread("./test_image.jpeg")
image = im[:, :, ::-1]  # convert OpenCV's BGR channel order to RGB
image_height, image_width = image.shape[:2]
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))  # HWC -> CHW float tensor

Note: the image can be any size. However, to make model inference faster, you may want to resize the image down. A brief example of how to do this with detectron2's transforms is shown below:

import detectron2.data.transforms as T

# Define the resizing augmentation; 800 and 1333 are example values for the
# shortest-edge length and maximum size -- adjust them to your needs.
aug = T.ResizeShortestEdge([800, 800], 1333)

original_image = im[:, :, ::-1]
image_height, image_width = original_image.shape[:2]
image = aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
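If detectron2 isn't available in your deployment environment, an equivalent shortest-edge resize can be done with plain OpenCV. This is a minimal sketch: resize_shortest_edge is a hypothetical helper, and the 800/1333 limits are only example values, not something the model requires.

import cv2
import torch

def resize_shortest_edge(img, short_edge=800, max_size=1333):
    # Scale so the shorter side equals `short_edge`, capped so the longer side stays <= `max_size`.
    h, w = img.shape[:2]
    scale = short_edge / min(h, w)
    if max(h, w) * scale > max_size:
        scale = max_size / max(h, w)
    new_w, new_h = int(w * scale + 0.5), int(h * scale + 0.5)
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

resized = resize_shortest_edge(im)      # resize the BGR image from cv2.imread
resized = resized[:, :, ::-1]           # then convert to RGB
image = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))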

Running the Model

Using the model loaded in the earlier step, you can now pass the formatted image through the model.

with torch.no_grad():
    outputs = model(image)

The output of the instance segmentation model is a tuple. In the shapes below, N is the number of predicted objects.

predicted_bounding_boxes = outputs[0] # (N, 4)
predicted_classes = outputs[1] # (N, 1) where each element is the class label
predicted_instance_masks = outputs[2] # (N, 1, 28, 28) - see post-processing below
predicted_scores = outputs[3] # (N, 1) where each element is the score
(image_height, image_width) = outputs[4]

Note: the masks are output at a fixed 28x28 resolution and need to be pasted back onto the full-size image (see post-processing below).
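
Before using the detections, you will often want to keep only confident predictions. Here is a minimal sketch; the 0.5 threshold is just an example value. Note that the post-processing functions below operate on the raw outputs tuple, so apply any filtering consistently to their results as well.

score_threshold = 0.5  # example value; tune for your application

scores = predicted_scores.flatten()
keep = scores >= score_threshold  # boolean mask of confident detections

boxes_kept = predicted_bounding_boxes[keep]
classes_kept = predicted_classes[keep]
masks_kept = predicted_instance_masks[keep]

print(f"kept {int(keep.sum())} of {len(scores)} detections")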

Post-processing Instance Segmentation Outputs

For each predicted object, the model generates a 28x28 instance segmentation mask. During training, the ground-truth mask is downscaled to compute the loss against the predicted mask; during inference, the predicted mask needs to be upscaled to the bounding-box size. You may implement your own resizing function, or use the post-processing functions provided below. These functions resize the masks and convert them to bitmasks (0 indicates no mask, 1 indicates mask) with output shape (N, image_height, image_width).


import numpy as np
from typing import Tuple
from torch.nn import functional as F

BYTES_PER_FLOAT = 4
GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit

def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bound all boxes, and returns the results for this region only.
            An important optimization for CPU.
    Returns:
        if skip_empty == False, a mask of shape (N, img_h, img_w)
        if skip_empty == True, a mask of shape (N, h', w'), and the slice
            object for the corresponding region.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale dataset.
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()

def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.
    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and mask
            width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.
    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height and
        width. img_masks[i] is a binary mask for object instance i.
    """

    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The actual implementation splits the input into chunks
    # and pastes them chunk by chunk.
    if device.type == "cpu" or torch.jit.is_scripting():
        # CPU is most efficient when they are pasted one by one with skip_empty=True
        # so that it performs minimal number of operations.
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks, but may have memory issue
        # int(img_h) because shape may be tensors in tracing
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # for visualization and debugging
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():  # Scripting does not use the optimized codepath
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks

def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]
    height, width = outputs[4]
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks

Alternatively, you can use detectron2's built-in utility function to do the same thing:

from detectron2.layers.mask_ops import paste_masks_in_image

def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]
    height, width = outputs[4]
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks
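
Putting it all together, here is a hedged end-to-end sketch using the pieces defined above (the 0.5 score threshold is only an example value):

with torch.no_grad():
    outputs = model(image)

masks = post_process_instance_seg_masks(outputs)  # (N, image_height, image_width) boolean bitmasks
boxes = outputs[0]
scores = outputs[3].flatten()

for i in range(len(scores)):
    if scores[i] < 0.5:  # example confidence threshold
        continue
    area = int(masks[i].sum())  # number of pixels covered by this instance
    print(f"instance {i}: score={scores[i].item():.2f}, box={boxes[i].tolist()}, mask_area={area}")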

Now you have a fully functioning instance segmentation model with parsable bounding box and instance segmentation detections!
