This tutorial walks through how to use the models downloaded from Rapid in your own deployment use cases.
Model format
Rapid generates the model in a mobile-optimized TorchScript format. TorchScript is an intermediate representation of a PyTorch model that can be run on an edge device in a high-performance environment such as C++. The TorchScript model file can then be loaded into a mobile application (iOS or Android) via the PyTorch C++ libraries or the PyTorch Android API. PyTorch provides documentation and examples of how to embed a TorchScript model in your mobile application.
The downloaded file will be named mobile_optimized.ptl and can be loaded using the torch library.
Loading the TorchScript model
To load the trained Rapid model, use the torch and torchvision libraries:
import torch
import torchvision  # importing torchvision also registers its custom ops, which the model may need

model = torch.jit.load("PATH_TO_FILE/mobile_optimized.ptl")
model.eval()
Calling eval() ensures the model is used for inference rather than training.
Instance Segmentation Model
Now that you've loaded your trained TorchScript model, let's walk through how to use it for instance segmentation.
Loading the image
Let's first read in an image and convert it to the format the trained model expects.
import cv2

im = cv2.imread("./test_image.jpeg")
image = im[:, :, ::-1]  # OpenCV reads images as BGR; reverse the channels to get RGB
image_height, image_width = image.shape[:2]
# Convert from HWC uint8 to the CHW float32 tensor the model expects
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
Note: the image can be any size; however, to speed up model inference you may resize the image down first. Here is a brief example of how to do that:
import detectron2.data.transforms as T

# Resize the shorter edge to 800 pixels, capping the longer edge at 1333
# (detectron2's defaults; substitute the values used during training)
aug = T.ResizeShortestEdge(short_edge_length=800, max_size=1333)

original_image = im[:, :, ::-1]
image_height, image_width = original_image.shape[:2]
image = aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
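If you would rather not depend on detectron2 just for resizing, here is a minimal sketch of the same short-edge resize using OpenCV; the 800/1333 values are assumptions, so substitute whatever your model was trained with:

import cv2
import torch

# Scale so the shorter edge is about 800 px while capping the longer edge at 1333 px
# (assumed values; match your training configuration)
short_edge, max_size = 800, 1333
h, w = im.shape[:2]
scale = min(short_edge / min(h, w), max_size / max(h, w))
resized = cv2.resize(im, (int(w * scale), int(h * scale)))
image = torch.as_tensor(resized[:, :, ::-1].astype("float32").transpose(2, 0, 1))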
Running the Model
Using the model loaded in the earlier step, you can now pass the formatted image through the model.
with torch.no_grad():
    outputs = model(image)
The output of the instance segmentation model is a Tuple of tensors, where N is the number of predicted objects:
predicted_bounding_boxes = outputs[0]     # (N, 4)
predicted_classes = outputs[1]            # (N, 1) where each element is the class label
predicted_instance_masks = outputs[2]     # (N, 1, 28, 28) - see post-processing below
predicted_scores = outputs[3]             # (N, 1) where each element is the confidence score
(image_height, image_width) = outputs[4]  # size of the image the model ran on
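If the exported model does not already filter its detections by confidence, you will usually want to drop low-scoring predictions before post-processing. A minimal sketch, assuming a 0.5 cutoff (tune this for your data):

# Keep only detections above a confidence threshold (0.5 is an arbitrary choice)
score_threshold = 0.5
keep = predicted_scores.squeeze(1) > score_threshold
predicted_bounding_boxes = predicted_bounding_boxes[keep]
predicted_classes = predicted_classes[keep]
predicted_instance_masks = predicted_instance_masks[keep]
predicted_scores = predicted_scores[keep]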
Note: the masks are output at a fixed 28x28 resolution and need to be pasted back into the full-size image.
Post-processing Instance Segmentation Outputs
For each predicted object, the model generates a 28x28 instance segmentation mask. During training, the ground-truth mask is downscaled to compute the loss against the predicted mask; during inference, the predicted mask needs to be upscaled to the bounding box size. To do so, you can implement your own resizing function, or use the post-processing functions provided below. These functions resize the instance segmentation masks and convert them to bitmasks (0 indicates no mask, 1 indicates mask) with output shape (N, image_height, image_width).
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F

BYTES_PER_FLOAT = 4
GPU_MEM_LIMIT = 1024**3  # 1 GB memory limit
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bounds all boxes, and return the results in this region only.
            An important optimization for CPU.

    Returns:
        if skip_empty == False, a mask of shape (N, img_h, img_w)
        if skip_empty == True, a mask of shape (N, h', w'), and the slice
            object for the corresponding region.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks.
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale datasets.
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by its
    corresponding bounding box in boxes.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
            number of detected object instances and Himage, Wimage are the image height
            and width. img_masks[i] is a binary mask for object instance i.
    """
    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The actual implementation splits the input into chunks,
    # and pastes them chunk by chunk.
    if device.type == "cpu" or torch.jit.is_scripting():
        # CPU is most efficient when the masks are pasted one by one with skip_empty=True,
        # so that it performs a minimal number of operations.
        num_chunks = N
    else:
        # GPU benefits from parallelism for larger chunks, but may run out of memory;
        # int(img_h) because the shape may be a tensor in tracing
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # for visualization and debugging
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():  # Scripting does not use the optimized codepath
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks
def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]  # drop the channel dimension: (N, 28, 28)
    height, width = outputs[4]
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks
Alternatively, you can use detectron2's built-in utility function to do this:
from detectron2.layers.mask_ops import paste_masks_in_image

def post_process_instance_seg_masks(outputs):
    pred_boxes = outputs[0]
    pred_masks = outputs[2][:, 0, :, :]
    height, width = outputs[4]
    bitmasks = paste_masks_in_image(pred_masks, pred_boxes, (height, width), threshold=0.5)
    masks = bitmasks.to(torch.bool)
    return masks
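Either way, a quick way to sanity-check the masks is to overlay them on the image you ran inference on. A minimal sketch using OpenCV, where the green color and 50/50 blend are arbitrary choices and im is assumed to be the unresized image from earlier:

import cv2
import numpy as np

masks = post_process_instance_seg_masks(outputs)
overlay = im.copy()
for mask in masks.cpu().numpy():
    # Blend each instance's pixels with green
    overlay[mask] = (0.5 * overlay[mask] + 0.5 * np.array([0, 255, 0])).astype(np.uint8)
cv2.imwrite("masks_overlay.jpeg", overlay)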
Now you have a fully functioning instance segmentation model with parsable bounding box and instance mask detections!