
Image Prompter: Segment Anything#

This tutorial shows how to use our high-level Image Prompter API. This API allows you to set an image once and then run multiple queries on it. These queries can be done with three types of prompts:

  1. Points: keypoints with (x, y) coordinates and a corresponding label, where 0 indicates a background point and 1 indicates a foreground point.

  2. Boxes: bounding boxes delimiting different regions.

  3. Masks: Logits generated by the model in a previous run.

Read more on our docs: https://kornia.readthedocs.io/en/latest/models/segment_anything.html

This tutorial has the following steps:

  1. Set up the desired SAM model and import the necessary packages

    1. Utility functions to read and plot the data

  2. How to instantiate the Image Prompter

    1. How to set an image

    2. The supported prompt types

    3. Example of how to query the image

    4. Prediction structure

  3. Using the Image Prompter

    1. Soccer player segmentation

    2. Car parts segmentation

    3. Satellite image - Sentinel-2

Setup#

%%capture
!pip install kornia
!pip install kornia-rs

First, let's choose the SAM model type to be used in our Image Prompter.

The options are (from smallest to largest):

model_type | official checkpoint
vit_b | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
vit_l | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
vit_h | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

model_type = "vit_h"
checkpoint = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

Then let’s import all necessary packages and modules

from __future__ import annotations

import os

import matplotlib.pyplot as plt
import torch
from kornia.contrib.image_prompter import ImagePrompter
from kornia.contrib.models.sam import SamConfig
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints
from kornia.io import ImageLoadType, load_image
from kornia.utils import get_cuda_device_if_available, tensor_to_image
device = get_cuda_device_if_available()
print(device)
cuda:0

Utility functions#

import io

import requests


def download_image(url: str, filename: str = "") -> str:
    filename = url.split("/")[-1] if len(filename) == 0 else filename
    # Download
    bytesio = io.BytesIO(requests.get(url).content)
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(bytesio.getbuffer())

    return filename


soccer_image_path = download_image("https://raw.githubusercontent.com/kornia/data/main/soccer.jpg")
car_image_path = download_image("https://raw.githubusercontent.com/kornia/data/main/simple_car.jpg")
satellite_image_path = download_image("https://raw.githubusercontent.com/kornia/data/main/satellite_sentinel2_example.tif")
soccer_image_path, car_image_path, satellite_image_path
('soccer.jpg', 'simple_car.jpg', 'satellite_sentinel2_example.tif')
def colorize_masks(binary_masks: torch.Tensor, merge: bool = True, alpha: None | float = None) -> list[torch.Tensor]:
    """Convert binary masks (B, C, H, W), boolean tensors, into colored masks (B, (3, 4), H, W) - RGB or RGBA.

    Where C refers to the number of masks.

    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        merge: if True, will join the batch dimension into a unique mask.
        alpha: alpha channel value. If None, will generate RGB images.

    Returns:
        A list of `C` colored masks.
    """
    B, C, H, W = binary_masks.shape
    OUT_C = 4 if alpha else 3

    output_masks = []

    for idx in range(C):
        _out = torch.zeros(B, OUT_C, H, W, device=binary_masks.device, dtype=torch.float32)
        for b in range(B):
            color = torch.rand(1, 3, 1, 1, device=binary_masks.device, dtype=torch.float32)
            if alpha:
                color = torch.cat([color, torch.tensor([[[[alpha]]]], device=binary_masks.device, dtype=torch.float32)], dim=1)

            to_colorize = binary_masks[b, idx, ...].view(1, 1, H, W).repeat(1, OUT_C, 1, 1)
            _out[b, ...] = torch.where(to_colorize, color, _out[b, ...])
        output_masks.append(_out)

    if merge:
        output_masks = [c.max(dim=0)[0] for c in output_masks]

    return output_masks


def show_binary_masks(binary_masks: torch.Tensor, axes) -> None:
    """Plot binary masks, with shape (B, C, H, W), where C refers to the number of masks.

    Will merge the `B` channel into a unique mask.

    Args:
        binary_masks: a batched boolean tensor (B, C, H, W)
        axes: a list of matplotlib axes with length of C
    """
    colored_masks = colorize_masks(binary_masks, True, 0.6)

    for ax, mask in zip(axes, colored_masks):
        ax.imshow(tensor_to_image(mask))


def show_boxes(boxes: Boxes, ax) -> None:
    boxes_tensor = boxes.to_tensor(mode="xywh").detach().cpu().numpy()
    for batched_boxes in boxes_tensor:
        for box in batched_boxes:
            x0, y0, w, h = box
            ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="orange", facecolor=(0, 0, 0, 0), lw=2))


def show_points(points: tuple[Keypoints, torch.Tensor], ax, marker_size=200):
    coords, labels = points
    pos_points = coords[labels == 1].to_tensor().detach().cpu().numpy()
    neg_points = coords[labels == 0].to_tensor().detach().cpu().numpy()

    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="+", s=marker_size, linewidth=2)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="x", s=marker_size, linewidth=2)
from kornia.contrib.models import SegmentationResults


def show_image(image: torch.Tensor):
    plt.imshow(tensor_to_image(image))
    plt.axis("off")
    plt.show()


def show_predictions(
    image: torch.Tensor,
    predictions: SegmentationResults,
    points: tuple[Keypoints, torch.Tensor] | None = None,
    boxes: Boxes | None = None,
) -> None:
    n_masks = predictions.logits.shape[1]

    fig, axes = plt.subplots(1, n_masks, figsize=(21, 16))
    axes = [axes] if n_masks == 1 else axes

    for idx, ax in enumerate(axes):
        score = predictions.scores[:, idx, ...].mean()
        ax.imshow(tensor_to_image(image))
        ax.set_title(f"Mask {idx+1}, Score: {score:.3f}", fontsize=18)

        if points:
            show_points(points, ax)

        if boxes:
            show_boxes(boxes, ax)

        ax.axis("off")

    show_binary_masks(predictions.binary_masks, axes)
    plt.show()

Exploring the Image Prompter#

The ImagePrompter can be initialized from a ModelConfig structure; currently only the SAM model is supported, through SamConfig. Through this config the ImagePrompter will initialize the SAM model and load the weights (from a path or a URL).

What can the ImagePrompter do?

  1. Based on the ModelConfig, besides initializing the model, it sets up the required transformations for the images and prompts using the kornia.augmentation API within the AugmentationSequential container.

  2. You can benefit from the torch.compile(...) API (dynamo) for torch >= 2.0.0. To compile with dynamo we provide the method ImagePrompter.compile(...), which will optimize the right parts of the backend model and the prompter itself (see the sketch after this list).

  3. Caching the image features and transformations. With the ImagePrompter.set_image(...) method, we transform the image and encode it with the model, caching its embeddings to be queried later.

  4. Query multiple times with multiple prompts. Using ImagePrompter.predict(...), we query the cached embeddings with Keypoints, Boxes and Masks as prompts.
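A minimal sketch of item 2, assuming torch >= 2.0.0 is installed; prompter is the ImagePrompter instance created later in this tutorial, and compile options are left at their defaults:

# Opt into dynamo: optimizes parts of the backend SAM model and the prompter
prompter.compile()
# The first set_image(...) / predict(...) calls pay the compilation cost;
# later calls run through the optimized graphs.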

What do the ImagePrompter and Kornia provide? Easy high-level structures to be used as prompts, as well as for the prediction results. Using the kornia geometry module you can easily encapsulate Keypoints and Boxes, which allows the API to be flexible about the desired mode (mainly for boxes, where there are multiple modes of representing them).
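For example, the same Boxes object can be built in one mode and read back in another. A small sketch, using only the modes that appear in this tutorial:

# Build boxes from "xyxy" corner coordinates (x1, y1, x2, y2)...
example_boxes = Boxes.from_tensor(torch.tensor([[[10.0, 20.0, 50.0, 80.0]]]), mode="xyxy")
# ...and read them back as "xywh" (top-left corner plus width and height)
print(example_boxes.to_tensor(mode="xywh"))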

The Kornia ImagePrompter and model config for SAM can be imported as follows:

from kornia.contrib.image_prompter import ImagePrompter
from kornia.contrib.models import SamConfig
# Setting up a SamConfig with the desired model type and checkpoint
config = SamConfig(model_type, checkpoint)

# Initialize the ImagePrompter
prompter = ImagePrompter(config, device=device)

Set image#

First, before setting the image into the prompter, we need to read the image. For this we can use kornia.io, which internally uses kornia-rs, so make sure you have kornia-rs installed (pip install kornia-rs). This API implements the DLPack protocol natively in Rust to reduce the memory footprint during decoding and type conversion, allowing us to read the image from disk directly into a tensor.

# Load the image
image = load_image(soccer_image_path, ImageLoadType.RGB32, device)  # 3 x H x W

# Display the loaded image
show_image(image)
../_images/c5bc6570829fbea2232a4e36171c71791681103973197b22486f4a2d00f69301.png

With the image loaded on the same device as the model, and with the right shape 3xHxW, let's set the image into our image prompter. Attention: when doing this, the model already computes the embeddings of this image. This means the image is passed through the encoder, which uses a lot of memory. It is possible to use the largest model (vit_h) with a graphics card (GPU) that has at least 8 GB of VRAM.
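If your GPU has less VRAM than that, a smaller variant is a drop-in swap. A quick sketch reusing the vit_b entry from the table above:

# Hypothetical lighter setup: vit_b is the smallest (and least accurate) variant
light_config = SamConfig("vit_b", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")
light_prompter = ImagePrompter(light_config, device=device)

Continuing with the vit_h prompter from above, let's set the image: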

prompter.set_image(image)

If no error occurred, the features needed to run queries are now cached. You can verify this through the prompter.is_image_set property.

prompter.is_image_set
True

Examples of prompts#

The ImagePrompter output will have the same batch size as its prompts: the output shape is (B, C, H, W), where B is the number of input prompts and C is determined by the multimask_output parameter. If multimask_output is True then C=3, otherwise C=1.

Keypoints#

Keypoint prompts are a tensor or a Keypoints structure with coordinates in (x, y) order, of shape BxNx2.

Each coordinate pair should have a corresponding label, where 0 indicates a background point and 1 indicates a foreground point. These labels should be in a tensor of shape BxN.

The model will try to find an object that contains all the foreground points and excludes the background points. In other words, foreground points select the desired object, and background points exclude regions from it; see the mixed example after the next cell.

keypoints_tensor = torch.tensor([[[960, 540]]], device=device, dtype=torch.float32)
keypoints = Keypoints(keypoints_tensor)

labels = torch.tensor([[1]], device=device, dtype=torch.float32)
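For illustration, a prompt mixing foreground and background points could look like the sketch below (the coordinates are arbitrary, not tuned for the soccer image):

# The first point selects the object (label 1); the second excludes a region (label 0)
mixed_keypoints = Keypoints(torch.tensor([[[960.0, 540.0], [700.0, 300.0]]], device=device))
mixed_labels = torch.tensor([[1, 0]], device=device, dtype=torch.float32)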

Boxes#

Box prompts are a tensor with boxes in "xyxy" format/mode, or a Boxes structure. The tensor should have shape BxNx4.

boxes_tensor = torch.tensor([[[1841.7, 739.0, 1906.5, 890.6]]], device=device, dtype=torch.float32)
boxes = Boxes.from_tensor(boxes_tensor, mode="xyxy")

Masks#

Mask prompts should be provided from a previous model output, with shape Bx1x256x256:

# first run
predictions = prompter.predict(...)

# use previous results as prompt
predictions = prompter.predict(..., mask=predictions.logits)
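As a concrete sketch of this flow, using the keypoints defined above (the mask keyword follows the pseudocode here; treat the exact parameter name as an assumption for your kornia version):

# First query: keep the low-resolution logits around
first = prompter.predict(keypoints, labels, multimask_output=False)

# Second query: feed the Bx1x256x256 logits back so the model can refine them
refined = prompter.predict(keypoints, labels, mask=first.logits, multimask_output=False)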

Example of prediction#

# Using keypoints
prediction_by_keypoint = prompter.predict(keypoints, labels, multimask_output=False)

show_image(prediction_by_keypoint.binary_masks)
../_images/25af69026109977d9454cbca915126c1c89f10dbde5017417c061c509d560b40.png
# Using boxes
prediction_by_boxes = prompter.predict(boxes=boxes, multimask_output=False)

show_image(prediction_by_boxes.binary_masks)
../_images/b275c038eaac800f57077c8a1461c41fe97e76b0904ddaa7c3ab970a5023337c.png

Exploring the prediction result structure#

The ImagePrompter prediction structure is a SegmentationResults, which holds the upscaled (default) logits when output_original_size=True is passed to predict.

The segmentation results have:

  • logits: result logits with shape (B, C, H, W), where C refers to the number of predicted masks

  • scores: the scores from the logits. Shape (B,)

  • binary_masks: binary masks generated from the logits considering the mask_threshold (see the sketch below). The size depends on original_res_logits: if False, the binary masks will have the same shape as the logits, Bx1x256x256
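For intuition, the binary masks are essentially a thresholding of the logits. A minimal sketch, assuming the SAM default mask_threshold of 0.0:

# Threshold the low-resolution logits directly (0.0 is the SAM default
# mask_threshold; an assumption here, check your config for the actual value)
manual_binary = prediction_by_boxes.logits > 0.0
print(manual_binary.shape)  # Bx1x256x256, matching the logits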

prediction_by_boxes.scores
tensor([[0.9317]], device='cuda:0')
prediction_by_boxes.binary_masks.shape
torch.Size([1, 1, 1080, 1920])
prediction_by_boxes.logits.shape
torch.Size([1, 1, 256, 256])

Using the Image Prompter on examples#

Soccer players#

Using an example image from the dataset: https://www.kaggle.com/datasets/ihelon/football-player-segmentation

Let's segment the players on the field using boxes.

# Prompts
boxes = Boxes.from_tensor(
    torch.tensor(
        [
            [
                [1841.7000, 739.0000, 1906.5000, 890.6000],
                [879.3000, 545.9000, 948.2000, 669.2000],
                [55.7000, 595.0000, 127.4000, 745.9000],
                [1005.4000, 128.7000, 1031.5000, 212.0000],
                [387.4000, 424.1000, 438.2000, 539.0000],
                [921.0000, 377.7000, 963.3000, 483.0000],
                [1213.2000, 885.8000, 1276.2000, 1060.1000],
                [40.8900, 725.9600, 105.4100, 886.5800],
                [848.9600, 283.6200, 896.0600, 368.6200],
                [1109.6500, 499.0400, 1153.0400, 622.1700],
                [576.3000, 860.8000, 671.7000, 1018.8000],
                [1039.8000, 389.9000, 1072.5000, 493.2000],
                [1647.1000, 315.1000, 1694.0000, 406.0000],
                [1231.2000, 214.0000, 1294.1000, 297.3000],
            ]
        ],
        device=device,
    ),
    mode="xyxy",
)
# Load the image
image = load_image(soccer_image_path, ImageLoadType.RGB32, device)  # 3 x H x W

# Set the image
prompter.set_image(image)
predictions = prompter.predict(boxes=boxes, multimask_output=True)

Let's see the results. Since we used multimask_output=True, the model output 3 masks.

show_predictions(image, predictions, boxes=boxes)
../_images/dd2191c6ee2477a2c080e916c30ced46003c5678d8f331d1ebc8056c52bcd2a2.png

Car parts#

Segmenting car parts of an example from the dataset: https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset

# Prompts
boxes = Boxes.from_tensor(
    torch.tensor(
        [
            [
                [56.2800, 369.1100, 187.3000, 579.4300],
                [412.5600, 426.5800, 592.9900, 608.1600],
                [609.0800, 366.8200, 682.6400, 431.1700],
                [925.1300, 366.8200, 959.6100, 423.1300],
                [756.1900, 416.2300, 904.4400, 473.7000],
                [489.5600, 285.2200, 676.8900, 343.8300],
            ]
        ],
        device=device,
    ),
    mode="xyxy",
)

keypoints = Keypoints(
    torch.tensor(
        [[[535.0, 227.0], [349.0, 215.0], [237.0, 219.0], [301.0, 373.0], [641.0, 397.0], [489.0, 513.0]]], device=device
    )
)
labels = torch.ones(keypoints.shape[:2], device=device, dtype=torch.float32)
# Image
image = load_image(car_image_path, ImageLoadType.RGB32, device)

# Set the image
prompter.set_image(image)

Querying with boxes#

predictions = prompter.predict(boxes=boxes, multimask_output=True)
show_predictions(image, predictions, boxes=boxes)
../_images/b661405149257ab39e508101cee765c4d2d85f4054e2c2eb7fef01c49b0cdcaf.png

Querying with keypoints#

Considering N points in 1 batch#

This way the model will try to find one object that encompasses all the points.

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels)
show_predictions(image, predictions, points=(keypoints, labels))
../_images/8ebad37cb0fa0fce5803797ebfa0bbc10d04ea85f24e1ec27d9b7ec0f5231731.png
Considering 1 point in N batches#

The prompt encoder does not currently work for a batch of points :/

Considering 1 point per batch in N queries#

This way the model will find an object for each point.

k = 2  # number of times/points to query

for idx in range(min(keypoints.data.size(1), k)):
    print("-" * 79, f"\nQuery {idx}:")
    _kpts = keypoints[:, idx, ...][None, ...]
    _lbl = labels[:, idx, ...][None, ...]

    predictions = prompter.predict(keypoints=_kpts, keypoints_labels=_lbl)

    show_predictions(image, predictions, points=(_kpts, _lbl))
------------------------------------------------------------------------------- 
Query 0:
../_images/ee2eecddf7017ad3b5d88b10802521493bb81294bef84db7048d3ed7efb81a5f.png
------------------------------------------------------------------------------- 
Query 1:
../_images/501795eb27f73dcf6df0b6a198dc1885c77edea293cd963f8d77b7bb0f2464d4.png

Satellite image#

Image from Sentinel-2.

Product: a tile of the TCI (True Color Image, 10 m per pixel). Product name: S2B_MSIL1C_20230324T130249_N0509_R095_T23KPQ_20230324T174312

# Prompts
keypoints = Keypoints(
    torch.tensor(
        [
            [
                # Urban
                [74.0, 104.5],
                [335, 110],
                [702, 65],
                [636, 479],
                [408, 820],
                # Forest
                [40, 425],
                [680, 566],
                [405, 439],
                [73, 689],
                [865, 460],
                # Ocean/water
                [981, 154],
                [705, 714],
                [357, 683],
                [259, 908],
                [1049, 510],
            ]
        ],
        device=device,
    )
)
labels = torch.ones(keypoints.shape[:2], device=device, dtype=torch.float32)
# Image
image = load_image(satellite_image_path, ImageLoadType.RGB32, device)

# Set the image
prompter.set_image(image)

Query urban points#

# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., 5:] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
../_images/c5ec618b63c6118f00e6f29377cc2b50a1e1df2d247657eec531dfc5454272c7.png

Query forest points#

# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., :5] = 0
labels_to_query[..., 10:] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
../_images/e2a9a44da4b6ccbaae5b57e17ebb36030931a61b55ce36c23aff51390bee1920.png

Query ocean/water points#

# Query the prompts
labels_to_query = labels.clone()
labels_to_query[..., :10] = 0

predictions = prompter.predict(keypoints=keypoints, keypoints_labels=labels_to_query)
show_predictions(image, predictions, points=(keypoints, labels_to_query))
../_images/900a48259e715048956d98f7f4773b175753c009e21c26f6fb168956225a3498.png