Open in Colab

Image patch generation

In this tutorial we will learn how to generate image patches using kornia.augmentation and kornia.geometry components.

%%capture
!pip install kornia matplotlib
%%capture
!wget https://github.com/kornia/data/raw/main/homography/img1.ppm

First, load the libraries and define helper functions to read and display images.

%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np
import cv2

import torch
import kornia as K
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
def imshow(image: np.ndarray, height: int = 10, width: int = 10):
    """Plot an image in a new figure with the axes hidden.

    Args:
        image: image array passed straight to ``plt.imshow``.
        height: figure height in inches.
        width: figure width in inches.
    """
    # Matplotlib expects ``figsize`` as ``(width, height)`` in inches.
    plt.figure(figsize=(width, height))
    plt.imshow(image)
    plt.axis('off')
    plt.show()
def imread(data_path: str) -> torch.Tensor:
    """Load an image from disk and convert it to a float RGB torch tensor.

    Args:
        data_path: path to an image file readable by OpenCV.

    Returns:
        Tensor with a leading batch dimension (BxCxHxW), RGB channel order,
        values divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``data_path``.
    """
    # open image using OpenCV (HxWxC, BGR). cv2.imread signals failure by
    # returning None rather than raising, so check explicitly.
    img: np.ndarray = cv2.imread(data_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"could not read image: {data_path}")

    # cast image to torch tensor (keepdim=False adds the batch dim) and
    # convert BGR -> RGB
    img_t: torch.Tensor = K.utils.image_to_tensor(img, keepdim=False)  # BxCxHxW
    img_t = K.color.bgr_to_rgb(img_t)

    return img_t.float() / 255.

Load and show the original image

# Seed the global RNG so the random crops/transforms below are reproducible.
torch.manual_seed(0)

# Read the sample image as a float tensor and display it.
timg = imread('img1.ppm')
imshow(K.tensor_to_image(timg), height=10, width=10)
_images/6085d38bbe86c58587a7c74593aa108d0508bc5269e0b18c286c4a0f910a7ce6.png

In the following section we take the original image, generate 15 random crops of size 64x64, and display them side by side.

random_crop = K.augmentation.RandomCrop((64, 64))

# Draw 15 independent crops, then tile them horizontally (along the width axis).
crops = [random_crop(timg) for _ in range(15)]
patch = torch.cat(crops, dim=-1)

imshow(K.tensor_to_image(patch[0]), 22, 22)
_images/d5e81c7aaeb7335a6eef8c0ad53f048ed41739f2c2a9851cd2e3f460254dfcda.png

Next, we show how to crop a patch, apply a random affine transformation to it, and then undo that transformation by warping with the inverse matrix.

# Transform a patch, then undo the transformation.

random_crop = K.augmentation.RandomCrop((64, 64))
# ``return_transform`` is no longer accepted by kornia augmentations; the
# matrix used by the most recent call is read from ``.transform_matrix``.
random_affine = K.augmentation.RandomAffine([-15, 15], [0., 0.25])

# crop
patch = random_crop(timg)

# transform, then retrieve the transformation applied by that call
# (presumably Bx3x3, as warp_perspective expects - confirm for your version)
patch_affine = random_affine(patch)
transformation = random_affine.transform_matrix

# invert patch: warp the transformed patch back with the inverse matrix
_, _, H, W = patch.shape
patch_inv = K.geometry.warp_perspective(
    patch_affine, torch.inverse(transformation), (H, W)
)

# visualise - (original, transformed, reconstructed)
patches_vis = torch.cat([patch, patch_affine, patch_inv], dim=-1)
imshow(K.tensor_to_image(patches_vis), 15, 15)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_885/1906952233.py in <module>
      3 random_crop = K.augmentation.RandomCrop((64, 64))
      4 random_affine = K.augmentation.RandomAffine(
----> 5     [-15, 15], [0., 0.25], return_transform=True)
      6 
      7 # crop

~/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/kornia/augmentation/_2d/geometric/affine.py in __init__(self, degrees, translate, scale, shear, resample, same_on_batch, align_corners, padding_mode, p, keepdim, return_transform)
     91         return_transform: Optional[bool] = None,
     92     ) -> None:
---> 93         super().__init__(p=p, return_transform=return_transform, same_on_batch=same_on_batch, keepdim=keepdim)
     94         self._param_generator = cast(rg.AffineGenerator, rg.AffineGenerator(degrees, translate, scale, shear))
     95         self.flags = dict(

~/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/kornia/augmentation/base.py in __init__(self, return_transform, same_on_batch, p, p_batch, keepdim)
    215         if return_transform is not None:
    216             raise ValueError(
--> 217                 "`return_transform` is deprecated. Please access the transformation matrix with "
    218                 "`.transform_matrix`. For chained matrices, please use `AugmentationSequential`.",
    219             )

ValueError: `return_transform` is deprecated. Please access the transformation matrix with `.transform_matrix`. For chained matrices, please use `AugmentationSequential`.