Open in Colab

Image Alignment by Homography Optimization#

In this tutorial we are going to learn how to perform the task of image alignment by optimising the homography transformation between two images.

%%capture
!pip install kornia
%%capture
!wget https://github.com/kornia/data/raw/main/homography/H1to2p
!wget https://github.com/kornia/data/raw/main/homography/img1.ppm
!wget https://github.com/kornia/data/raw/main/homography/img2.ppm

Import needed libraries

from typing import List
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# computer vision libs :D

import cv2
import kornia as K
from kornia.geometry import resize
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Define the hyper parameters to perform the online optimisation

# --- optimisation hyper parameters ---
# step size used for each gradient update
learning_rate: float = 1e-3
# upper bound on the gradient steps run at each pyramid level
num_iterations: int = 100
# how many levels each image pyramid has
num_levels: int = 6
# tolerance on the optimisation error
error_tol: float = 1e-8

# --- logging and device selection ---
# print log every N iterations
log_interval: int = 100
use_cuda: bool = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Using ', device)
Using  cpu
def load_image(file_name: str) -> torch.Tensor:
    """Load an image with OpenCV and convert it to a batched torch.Tensor.

    Args:
        file_name: path of the image file to read.

    Returns:
        The image as a float tensor divided by 255, with a leading batch
        dimension (1xCxHxW). The channel order is whatever OpenCV returned.

    Raises:
        AssertionError: if ``file_name`` does not point to an existing file.
        OSError: if the file exists but OpenCV cannot decode it.
    """
    assert os.path.isfile(file_name), "Invalid file {}".format(file_name)

    # cv2.imread reports a decoding failure by returning None instead of
    # raising, so check explicitly rather than failing later with an
    # unrelated error inside the tensor conversion.
    img: np.ndarray = cv2.imread(file_name, cv2.IMREAD_COLOR)
    if img is None:
        raise OSError("Could not decode image {}".format(file_name))

    # convert image to torch tensor
    tensor: torch.Tensor = K.utils.image_to_tensor(img).float() / 255.  # CxHxW
    return tensor[None]  # 1xCxHxW

Define a container to hold the homography as an nn.Parameter so that it can be used by autograd within the torch.optim framework.

We initialize the homography with the identity transformation.

class MyHomography(nn.Module):
    """Module that stores a single learnable 3x3 homography matrix."""

    def __init__(self) -> None:
        super().__init__()
        # storage only; the actual values are written by reset_parameters()
        self.homography = nn.Parameter(torch.empty(3, 3))
        self.reset_parameters()

    def reset_parameters(self):
        """Overwrite the homography with the identity transformation."""
        nn.init.eye_(self.homography)

    def forward(self) -> torch.Tensor:
        """Return the homography with a leading batch dimension (1x3x3)."""
        return self.homography.unsqueeze(0)

Read the images and the ground truth homography and convert them to tensors. In addition, we normalize the homography in order to smooth the gradients during the optimisation process.

# read both views as batched tensors and move them to the selected device
img_src: torch.Tensor = load_image('img1.ppm').to(device)
img_dst: torch.Tensor = load_image('img2.ppm').to(device)
for _img in (img_src, img_dst):
    print(_img.shape)

# ground truth homography: text file -> float tensor with a batch dimension
dst_homo_src_gt = (
    torch.from_numpy(np.loadtxt('H1to2p')).unsqueeze(0).float().to(device)
)
print(dst_homo_src_gt.shape)
print(dst_homo_src_gt)

height, width = img_src.shape[-2], img_src.shape[-1]

# matrix taking pixel coordinates to normalized coordinates
normal_transform_pixel: torch.Tensor = K.geometry.normal_transform_pixel(
    height, width, device=device)

# express the ground truth homography in normalized coordinates
dst_homo_src_gt_norm: torch.Tensor = (
    normal_transform_pixel @ dst_homo_src_gt @ normal_transform_pixel.inverse()
)

# warp the source image with the inverse of the normalized ground truth
img_src_to_dst_gt: torch.Tensor = K.geometry.homography_warp(
    img_src, dst_homo_src_gt_norm.inverse(), (height, width))

# BGR tensors -> RGB images for plotting
img_src_vis, img_dst_vis, img_src_to_dst_gt_vis = (
    K.utils.tensor_to_image(K.color.bgr_to_rgb(_tensor))
    for _tensor in (img_src, img_dst, img_src_to_dst_gt)
)
torch.Size([1, 3, 640, 800])
torch.Size([1, 3, 640, 800])
torch.Size([1, 3, 3])
tensor([[[ 8.7977e-01,  3.1245e-01, -3.9431e+01],
         [-1.8389e-01,  9.3847e-01,  1.5316e+02],
         [ 1.9641e-04, -1.6015e-05,  1.0000e+00]]])

Show the source image, the target and the source image warped to the target using the ground truth homography transformation.

fig, axes = plt.subplots(1, 3, sharey=True)
ax1, ax2, ax3 = axes
fig.set_figheight(15)
fig.set_figwidth(15)

# one (axis, image, title) entry per panel, drawn left to right
panels = (
    (ax1, img_src_vis, 'Source image'),
    (ax2, img_dst_vis, 'Destination image'),
    (ax3, img_src_to_dst_gt_vis, 'Source to Destination image'),
)
for _axis, _image, _title in panels:
    _axis.imshow(_image)
    _axis.set_title(_title)
plt.show()
_images/ea07a72fd9a6286216d6d4cb3ee87b959db2f2c5caefedbe96dd9ac8e0f9dec3.png

Initialize the homography warper and pass the parameters to the torch.optim.Adam optimizer to perform an online gradient descent optimisation to approximate the mapping transformation between the two images.

# learnable homography module, moved to the selected device
dst_homo_src = MyHomography()
dst_homo_src = dst_homo_src.to(device)

# Adam optimizer over the homography parameters
optimizer = optim.Adam(params=dst_homo_src.parameters(), lr=learning_rate)

# make sure both images are on the selected device
img_src = img_src.to(device)
img_dst = img_dst.to(device)

In order to perform the online optimisation, we will apply a well-known coarse-to-fine strategy. For this reason, we precompute a gaussian pyramid from each image with a certain number of levels.

### compute Gaussian Pyramid

def get_gaussian_pyramid(img: torch.Tensor, num_levels: int) -> List[torch.Tensor]:
    r"""Build a gaussian pyramid by repeatedly downsampling ``img``.

    The first entry is ``img`` itself; every following entry is the result of
    ``K.geometry.pyrdown`` applied to the entry before it. The returned list
    always contains at least the input image.
    """
    levels: List[torch.Tensor] = [img]
    # keep downsampling the last level until the requested depth is reached
    while len(levels) < num_levels:
        levels.append(K.geometry.pyrdown(levels[-1]))
    return levels

# build one gaussian pyramid per image, both with the same number of levels
img_src_pyr, img_dst_pyr = (
    get_gaussian_pyramid(_image, num_levels) for _image in (img_src, img_dst)
)

Main optimization loop#

Define the loss function to minimize the photometric error at each pyramid level:

$ L = \sum |I_{ref} - \omega(I_{dst}, H_{ref}^{dst})|$

def compute_scale_loss(img_src: torch.Tensor,
                       img_dst: torch.Tensor,
                       dst_homo_src: nn.Module,
                       optimizer: torch.optim.Optimizer,
                       num_iterations: int,
                       error_tol: float) -> torch.Tensor:
    """Optimise the homography on one pair of images (one pyramid level).

    Each iteration warps ``img_src`` with the inverse of the current
    homography, computes the L1 difference to ``img_dst`` restricted to the
    pixels where a warped image of ones stays above 0.9, and takes one
    optimizer step on the mean of that masked difference.

    Args:
        img_src: source image tensor.
        img_dst: destination image tensor; must have the same number of
            dimensions as ``img_src``.
        dst_homo_src: either a module whose call returns the homography, or
            the homography tensor itself. A module is re-evaluated on every
            iteration.
        optimizer: optimizer holding the homography parameters.
        num_iterations: maximum number of gradient steps.
        error_tol: stop early once the loss changes by less than this amount
            between two consecutive iterations.

    Returns:
        The detached loss of the last executed iteration. When
        ``num_iterations`` is not positive no step is run and the largest
        finite value of ``img_src.dtype`` is returned.

    Raises:
        AssertionError: if the two images differ in number of dimensions.
    """
    assert len(img_src.shape) == len(img_dst.shape), (img_src.shape, img_dst.shape)

    # sentinel "previous loss" so the first iteration can never trigger the stop
    loss_prev: float = torch.finfo(img_src.dtype).max
    if num_iterations <= 0:
        return torch.tensor(loss_prev, dtype=img_src.dtype, device=img_src.device)

    # the warper and the ones image depend only on the image, not on the
    # iteration, so build them once
    _height, _width = img_src.shape[-2:]
    warper = K.geometry.HomographyWarper(_height, _width)
    ones_src = torch.ones_like(img_src)

    for _ in range(num_iterations):
        homography = dst_homo_src() if isinstance(dst_homo_src, nn.Module) else dst_homo_src
        src_homo_dst: torch.Tensor = torch.inverse(homography)

        img_src_to_dst = warper(img_src, src_homo_dst)

        # per-pixel photometric error, 1x3xHxW
        loss = F.l1_loss(img_src_to_dst, img_dst, reduction='none')

        # keep only the pixels where the warped ones image is still close to 1
        ones = warper(ones_src, src_homo_dst)
        loss = loss.masked_select((ones > 0.9)).mean()

        # compute gradient and update optimizer parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # convergence test on the change of the loss between iterations
        loss_value = loss.item()
        if abs(loss_prev - loss_value) < error_tol:
            break
        loss_prev = loss_value

    return loss.detach()

Run the main body loop to warp the images from each pyramid level and evaluate the loss to perform gradient update.

# pyramid loop

for iter_idx in range(num_levels):
    # get current pyramid data
    scale: int = (num_levels - 1) - iter_idx
    img_src = img_src_pyr[scale]
    img_dst = img_dst_pyr[scale]

    # compute scale loss
    compute_scale_loss(img_src, img_dst, dst_homo_src(),
                       optimizer, num_iterations, error_tol)

    print('Optimization iteration: {}/{}'.format(iter_idx, num_levels))
       
    # merge warped and target image for visualization
    h, w = img_src.shape[-2:]
    warper = K.geometry.HomographyWarper(h, w)
    img_src_to_dst = warper(img_src, torch.inverse(dst_homo_src()))
    img_src_to_dst_gt = warper(img_src, torch.inverse(dst_homo_src_gt_norm))

    # compute the reprojection error
    error = F.l1_loss(img_src_to_dst, img_src_to_dst_gt, reduction='none')
    print('Reprojection error: {}'.format(error.mean()))
    
    # show data
    img_src_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_src))
    img_dst_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_dst))
    img_src_to_dst_merge = 0.65 * img_src_to_dst + 0.35 * img_dst
    img_src_to_dst_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_src_to_dst_merge))
    img_src_to_dst_gt_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_src_to_dst_gt))
    
    error_sum = error.mean(dim=1, keepdim=True)
    error_vis = K.utils.tensor_to_image(error_sum)

    # show the original images at each scale level, the result of warping using
    # the homography at moment, and the estimated error against the GT homography.

    %matplotlib inline
    fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, sharey=False)
    fig.set_figheight(15)
    fig.set_figwidth(15)

    ax1.imshow(img_src_vis)
    ax1.set_title('Source image')

    ax2.imshow(img_dst_vis)
    ax2.set_title('Destination image')

    ax3.imshow(img_src_to_dst_vis)
    ax3.set_title('Source to Destination image')
    
    ax4.imshow(img_src_to_dst_gt_vis)
    ax4.set_title('Source to Destination image GT')
    
    ax5.imshow(error_vis, cmap='gray', vmin=0, vmax=1)
    ax5.set_title('Error')
    plt.show()
Optimization iteration: 0/6
Reprojection error: 0.1722082793712616
_images/1abce50f6a5923a3484df54e8a7b2e4448600ba450798ffad14371f10b5e2a09.png
Optimization iteration: 1/6
Reprojection error: 0.11617553979158401
_images/6fd7ff2396130cfa784a813917ee5fb045750bb5ffa84e516ca830ec87510d05.png
Optimization iteration: 2/6
Reprojection error: 0.01905934512615204
_images/ab0a3a1ed5a93581eaa5301a82ddec5b199e4048b42f0ac18ba95ee2bc12c320.png
Optimization iteration: 3/6
Reprojection error: 0.013178573921322823
_images/1e70128e453e2dd02984bbfa637758c09fcc6b00ba3200a5287d51d2cc252dcb.png
Optimization iteration: 4/6
Reprojection error: 0.008066886104643345
_images/30fadc7e7ed46963e03063e2443442160ce54b5b66224232e5029e45da8d155c.png
Optimization iteration: 5/6
Reprojection error: 0.0053191822953522205
_images/e16aa3ee1a08b4929911578471fabcaff57a4ed84987fccf80a466a49f27a387.png