Open in Colab

Warp image using perspective transform#

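In this tutorial we are going to learn how to use the functions kornia.geometry.get_perspective_transform and kornia.geometry.warp_perspective to compute a perspective transform from four point correspondences and warp an image with it.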

Install libraries and get the data#

%%capture
!pip install kornia
%%capture
!wget https://github.com/kornia/data/raw/main/bruce.png

Import libraries and load the data#

import torch
import kornia as K

import cv2
import numpy as np
import matplotlib.pyplot as plt
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
# read the image from disk; the name reflects that it is converted from BGR to RGB below
img_bgr: np.ndarray = cv2.imread('bruce.png')  # HxWxC / np.uint8
# cv2.imread reports a missing or unreadable file by returning None instead of raising,
# so fail here with a clear message rather than deeper inside kornia
if img_bgr is None:
    raise FileNotFoundError("could not read 'bruce.png'; make sure it was downloaded first")

# keepdim=False adds the leading batch dimension
img: torch.Tensor = K.image_to_tensor(img_bgr, keepdim=False)  # 1xCxHxW / torch.uint8
img = K.color.bgr_to_rgb(img)
print(img.shape)
torch.Size([1, 3, 372, 600])

Define the points to warp, compute the homography and warp#

# the source points are the corners of the region to crop, given as (x, y) pixel
# coordinates in the same order as the destination corners below
points_src = torch.tensor([[
    [125., 150.], [562., 40.], [562., 282.], [54., 328.],
]])

# the destination points are the corners of the output image:
# top-left, top-right, bottom-right, bottom-left
h, w = 64, 128  # destination size
points_dst = torch.tensor([[
    [0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
]])

# compute the perspective transform (homography) mapping points_src onto points_dst
M: torch.Tensor = K.geometry.get_perspective_transform(points_src, points_dst)

# warp the original image by the found transform;
# the uint8 image is cast to float before warping and dsize is given as (height, width)
img_warp: torch.Tensor = K.geometry.warp_perspective(img.float(), M, dsize=(h, w))
print(img_warp.shape)
torch.Size([1, 3, 64, 128])

Plot the warped data#

# convert back to numpy
img_np: np.ndarray = K.tensor_to_image(img.byte())
img_warp_np: np.ndarray = K.tensor_to_image(img_warp.byte())

# draw the source points into the original image
# copy once up front: cv2.circle draws in place, and the array returned by
# tensor_to_image may share its buffer with the tensor (NOTE(review): confirm)
img_np = img_np.copy()
# .tolist() yields plain Python ints, which cv2 accepts as pixel coordinates
for x, y in points_src[0].long().tolist():
    img_np = cv2.circle(img_np, (x, y), 5, (0, 255, 0), -1)  # filled green dot, radius 5

# create the plot: source image with the marked corners next to the warped crop
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
axs = axs.ravel()

axs[0].axis('off')
axs[0].set_title('image source')
axs[0].imshow(img_np)

axs[1].axis('off')
axs[1].set_title('image destination')
axs[1].imshow(img_warp_np)
<matplotlib.image.AxesImage at 0x7f44f98d9ad0>
_images/b8d62fe01b06552c8289feee22e64abc6b4cae9579701408f78b9f4f15370e17.png