Open in Colab

Data Augmentation for Semantic Segmentation

In this tutorial we show how to quickly perform data augmentation for semantic segmentation using the `kornia.augmentation` API.

Install and get data

We install Kornia and a few dependencies, and download a simple data sample.

!pip install kornia opencv-python matplotlib
# import the libraries
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np

import torch
import torch.nn as nn
import kornia as K
/home/docs/checkouts/ TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See
  from .autonotebook import tqdm as notebook_tqdm

Define the augmentation pipeline

We implement our augmentation pipeline as a class that subclasses `nn.Module`.

class MyAugmentation(nn.Module):
    """Jointly augment an image and its segmentation label map.

    Color jitter is applied to the image only. The random affine transform is
    sampled on the image and the same sampled parameters are then replayed on
    the mask, so image and labels stay spatially aligned.
    """

    def __init__(self):
        super().__init__()
        # Operators are created once and cached as submodules.
        # Photometric op (brightness, contrast, saturation, hue) -- image only.
        self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
        # Geometric op (rotation degrees, translate, scale, shear).
        self.k2 = K.augmentation.RandomAffine([-45., 45.], [0., 0.15], [0.5, 1.5], [0., 0.15])

    def forward(
        self, img: torch.Tensor, mask: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Augment ``img`` and ``mask`` with a shared geometric transform.

        Args:
            img: image batch to color-jitter and geometrically transform.
            mask: label batch; only the geometric transform is applied.

        Returns:
            ``(img_out, mask_out)``: the augmented image and mask.
        """
        # 1. apply color only to the image
        # 2. apply the geometric transform (this samples fresh random params)
        img_out = self.k2(self.k1(img))

        # 3. replay the params sampled in step 2 on the mask instead of
        #    sampling new ones, so the mask gets the identical warp.
        # TODO: this will change in future so that no need to infer params
        mask_out = self.k2(mask, params=self.k2._params)

        return img_out, mask_out

Load the data and apply the transforms.

def load_data(
    data_path: str, split_col: int = 571
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Load a sample file and split it column-wise into image and labels.

    Args:
        data_path: path to an image file readable by OpenCV.
        split_col: column where the image part ends. The label part starts
            at ``split_col + 1``, so column ``split_col`` itself is dropped
            (presumably a separator in the sample file -- confirm). The
            default matches the original hard-coded value.

    Returns:
        ``(img, labels)``: RGB tensors of shape (1, 3, H, W), scaled from
        [0, 255] to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``data_path``.
    """
    data: np.ndarray = cv2.imread(data_path, cv2.IMREAD_COLOR)
    if data is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {data_path}")
    # (H, W, 3) array -> (1, 3, H, W) tensor
    data_t: torch.Tensor = K.image_to_tensor(data, keepdim=False)
    # OpenCV loads channels in BGR order
    data_t = K.color.bgr_to_rgb(data_t)
    # (x - 0) / 255 maps pixel values into [0, 1]
    data_t = K.enhance.normalize(data_t, torch.tensor(0.), torch.tensor(255.))
    img, labels = data_t[..., :split_col], data_t[..., split_col + 1:]
    return img, labels

# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")

# create augmentation instance
aug = MyAugmentation()

# apply the augmentation pipeline to our batch of data
img_aug, labels_aug = aug(img, labels)

# visualize: original image|labels on top, augmented image|labels below
img_out = torch.cat([img, labels], dim=-1)
img_aug_out = torch.cat([img_aug, labels_aug], dim=-1)
plt.imshow(K.tensor_to_image(torch.cat([img_out, img_aug_out], dim=-2)))
plt.axis('off')

# generate several samples
num_samples: int = 10

for img_id in range(num_samples):
    # generate data: every call samples new random augmentation parameters
    img_aug, labels_aug = aug(img, labels)
    img_out = torch.cat([img_aug, labels_aug], dim=-1)

    # draw each sample on its own figure so images do not stack on one axes
    plt.figure()
    plt.imshow(K.tensor_to_image(img_out))
    plt.axis('off')

    # save data
    plt.savefig(f"img_{img_id}.png", bbox_inches='tight')
_images/257c40c4f5ba1fa29f81423ec4b95d9901f12e0881ff57c7e9b519b4874b62e0.png _images/fefc00e113d2ff4725bb2a8dcff9787edf4ce4f1b020f54a6a18bdd67e66ffd1.png _images/339db63af69402f188aacce718b2ddcfd9fd86f3cf78e78bae3ecedbc9f0e083.png _images/d00d7702f81c0da521c8ab98abce211fb10d54296958398f103dd635c41b601c.png _images/db92178a52e6f809ef702e2a6bfc986a546cce86dae015aa7fd8225b6486795d.png _images/d1969e82eb6bf66ba25cb7caaf6b01188aadcb3190a773f12a9c21fe384ba119.png _images/5c87f996c8ead5e670829b09b757d641db305051bea461ac778002f49095be2f.png _images/4305b093b5b4c3f38f08e544db1efe1a128678aa7fcd9a69dab0776f7c433ec8.png _images/fb2ea6dd5e1f81a75a1fb9b5fadd3763996e02bfe66d615d15806f0b284e3c24.png _images/c2237c3af9bc9c47dac185ba4debd40243067687164220cf233e36125138b2d4.png _images/b20aee717be37f609475a2f7ed180150be177ecc7f2fa055f7ed91b3e2a40820.png