Open in Colab

Image Alignment by Homography Optimization

In this tutorial we are going to learn how to perform the task of image alignment by optimising the homography transformation between two images.

%%capture
!pip install kornia
%%capture
!wget https://github.com/kornia/data/raw/main/homography/H1to2p
!wget https://github.com/kornia/data/raw/main/homography/img1.ppm
!wget https://github.com/kornia/data/raw/main/homography/img2.ppm

Import needed libraries

import os
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# computer vision libs :D

import cv2
import kornia as K
from kornia.geometry import resize

Define the hyperparameters for the online optimisation.

# Hyperparameters of the coarse-to-fine homography optimisation.
learning_rate: float = 1e-3  # step size given to the Adam optimizer (lr=learning_rate)
num_iterations: int = 100  # gradient-step budget per pyramid level (passed to compute_scale_loss)
num_levels: int = 6  # the total number of image pyramid levels
error_tol: float = 1e-8  # the optimisation error tolerance (passed to compute_scale_loss)

log_interval: int = 100  # print log every N iterations; NOTE(review): not referenced by any cell shown here
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda: bool = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Using ', device)
Using  cpu
def load_image(file_name: str) -> torch.Tensor:
    """Load an image with OpenCV and convert it to a batched float tensor.

    The file is read with ``cv2.IMREAD_COLOR`` (3 channels, OpenCV's BGR
    channel order) and the 8-bit values are scaled to ``[0, 1]``.

    Args:
        file_name: path of the image file to read.

    Returns:
        A ``1xCxHxW`` float tensor (a leading batch dimension is added).

    Raises:
        FileNotFoundError: if ``file_name`` is not an existing file.
        ValueError: if OpenCV cannot decode the file as an image.
    """
    # Explicit check rather than ``assert``: asserts are removed under ``python -O``.
    if not os.path.isfile(file_name):
        raise FileNotFoundError("Invalid file {}".format(file_name))

    # load image with OpenCV; imread does not raise on undecodable data, it
    # returns None, which would otherwise fail later with an unrelated error.
    img: np.ndarray = cv2.imread(file_name, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Could not decode image file {}".format(file_name))

    # convert image to torch tensor: HxWxC uint8 -> CxHxW float in [0, 1]
    tensor: torch.Tensor = K.utils.image_to_tensor(img).float() / 255.0
    return tensor[None]  # 1xCxHxW

Define a container that holds the homography as an nn.Parameter, so that it is tracked by autograd and can be updated by a torch.optim optimizer.

We initialize the homography with the identity transformation.

class MyHomography(nn.Module):
    """Learnable 3x3 homography stored as a single ``nn.Parameter``.

    Registering the matrix as a parameter exposes it through
    ``module.parameters()``, so a ``torch.optim`` optimizer can update it.
    The homography is initialised to the identity transformation.
    """

    def __init__(self) -> None:
        super().__init__()
        # ``torch.empty`` replaces the legacy ``torch.Tensor(3, 3)`` constructor;
        # the uninitialised values are overwritten by ``reset_parameters``.
        self.homography = nn.Parameter(torch.empty(3, 3))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the homography, in place, to the 3x3 identity matrix."""
        torch.nn.init.eye_(self.homography)

    def forward(self) -> torch.Tensor:
        """Return the homography with a leading batch dimension."""
        return torch.unsqueeze(self.homography, dim=0)  # 1x3x3

Read the images and the ground-truth homography and convert them to tensors. In addition, we express the homography in normalized coordinates in order to smooth the gradients during the optimisation process.

# Load both images as 1x3xHxW float tensors in [0, 1] (see load_image) and move
# them to the compute device.
img_src: torch.Tensor = load_image('img1.ppm').to(device)
img_dst: torch.Tensor = load_image('img2.ppm').to(device)
print(img_src.shape)
print(img_dst.shape)

# Ground-truth homography in pixel coordinates, read as a 3x3 text matrix and
# given a leading batch dimension (1x3x3). Per the dst_homo_src naming it is
# presumably the src -> dst mapping; verify against the dataset description.
dst_homo_src_gt = np.loadtxt('H1to2p')
dst_homo_src_gt = torch.from_numpy(dst_homo_src_gt)[None].float().to(device)
print(dst_homo_src_gt.shape)
print(dst_homo_src_gt)

height, width = img_src.shape[-2:]

# warp image in normalized coordinates
# normal_transform_pixel maps pixel coordinates into kornia's normalized frame;
# conjugating the GT homography with it (N @ H @ N^-1) expresses the same
# transformation in that normalized frame.
normal_transform_pixel: torch.Tensor = (
    K.geometry.normal_transform_pixel(height, width, device=device)
)

dst_homo_src_gt_norm: torch.Tensor  = (
    normal_transform_pixel @ dst_homo_src_gt @ torch.inverse(normal_transform_pixel)
)

# The warp receives the inverse (dst -> src) homography, the same convention as
# src_homo_dst = inverse(dst_homo_src) used in the optimisation loop.
img_src_to_dst_gt: torch.Tensor = K.geometry.homography_warp(
    img_src, torch.inverse(dst_homo_src_gt_norm), (height, width))

# Convert to numpy images for matplotlib; OpenCV loaded the data in BGR order,
# so swap to RGB first.
img_src_vis: np.ndarray = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_src))
img_dst_vis: np.ndarray = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_dst))
img_src_to_dst_gt_vis: np.ndarray = K.utils.tensor_to_image(
    K.color.bgr_to_rgb(img_src_to_dst_gt))
torch.Size([1, 3, 640, 800])
torch.Size([1, 3, 640, 800])
torch.Size([1, 3, 3])
tensor([[[ 8.7977e-01,  3.1245e-01, -3.9431e+01],
         [-1.8389e-01,  9.3847e-01,  1.5316e+02],
         [ 1.9641e-04, -1.6015e-05,  1.0000e+00]]])
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

Show the source image, the target and the source image warped to the target using the ground truth homography transformation.

# Side-by-side view: source, destination, and the source warped into the
# destination frame with the ground-truth homography.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True)
fig.set_figheight(15)
fig.set_figwidth(15)

ax1.imshow(img_src_vis)
ax1.set_title('Source image')

ax2.imshow(img_dst_vis)
ax2.set_title('Destination image')

# If the GT homography is right, this panel should line up with the destination image.
ax3.imshow(img_src_to_dst_gt_vis)
ax3.set_title('Source to Destination image')
plt.show()
_images/homography_13_0.png

Initialize the homography parameters and pass them to the torch.optim.Adam optimizer, which performs online gradient-descent optimisation to approximate the mapping between the two images.

# create homography parameters (MyHomography starts from the identity matrix)
dst_homo_src = MyHomography().to(device)

# create optimizer: Adam over the module's single homography parameter
optimizer = optim.Adam(dst_homo_src.parameters(), lr=learning_rate)

# send data to device
# NOTE(review): redundant — img_src and img_dst were already moved with
# .to(device) when they were loaded.
img_src, img_dst = img_src.to(device), img_dst.to(device)

To perform the online optimisation we apply the well-known coarse-to-fine strategy: the homography is first estimated on the smallest images and then refined at each higher resolution. For this reason, we precompute a Gaussian pyramid with a given number of levels for each image.

### compute Gaussian Pyramid

def get_gaussian_pyramid(img: torch.Tensor, num_levels: int) -> List[torch.Tensor]:
    r"""Build a Gaussian pyramid by repeatedly downsampling ``img``.

    Args:
        img: the full-resolution image; it is stored, unchanged, as level 0.
        num_levels: requested number of levels. Values below 1 still yield the
            single-element list ``[img]``.

    Returns:
        A list where each entry is ``K.geometry.pyrdown`` of the previous one.
    """
    levels: List[torch.Tensor] = [img]
    while len(levels) < num_levels:
        levels.append(K.geometry.pyrdown(levels[-1]))
    return levels

# compute the gaussian pyramids: entry 0 is the full-resolution image and each
# following entry is the pyrdown of the previous one (num_levels entries).
img_src_pyr: List[torch.Tensor] = get_gaussian_pyramid(img_src, num_levels)
img_dst_pyr: List[torch.Tensor] = get_gaussian_pyramid(img_dst, num_levels)

Main optimization loop

Define the loss function to minimize the photometric error at each pyramid level:

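\( L = \frac{1}{|\Omega|} \sum_{x \in \Omega} \left| \omega(I_{src}, H_{src}^{dst})(x) - I_{dst}(x) \right| \), where \( \omega \) warps the source image into the destination frame and \( \Omega \) is the set of destination pixels covered by the warped source image.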

def compute_scale_loss(img_src: torch.Tensor,
                       img_dst: torch.Tensor,
                       dst_homo_src: Union[nn.Module, torch.Tensor],
                       optimizer: optim.Optimizer,
                       num_iterations: int,
                       error_tol: float) -> torch.Tensor:
    """Optimise the homography at a single pyramid level.

    Each iteration performs three steps:

    1. Warp ``img_src`` into the destination frame with the current homography.
    2. Compute the mean L1 photometric error against ``img_dst`` over the
       pixels covered by the warped source.
    3. Take one optimizer step.

    Args:
        img_src: source image at this pyramid level (assumed BxCxHxW).
        img_dst: destination image at this level; must have the same rank.
        dst_homo_src: either the homography module (re-evaluated every
            iteration) or a tensor view of its parameter, e.g. ``module()``,
            which the optimizer updates in place.
        optimizer: optimizer that owns the homography parameter.
        num_iterations: maximum number of gradient steps.
        error_tol: stop early once the absolute change of the loss between two
            consecutive iterations is below this value.

    Returns:
        The detached loss of the last step taken, or a NaN scalar if no step
        was taken.

    Raises:
        ValueError: if the two images do not have the same number of dimensions.
    """
    if img_src.dim() != img_dst.dim():
        raise ValueError(
            "img_src and img_dst must have the same number of dimensions, "
            "got {} and {}".format(tuple(img_src.shape), tuple(img_dst.shape)))

    last_loss = torch.tensor(float('nan'))
    if num_iterations <= 0:
        return last_loss

    # The warper depends only on the spatial size, which is fixed at this level,
    # so build it once instead of once per iteration.
    height, width = img_src.shape[-2:]
    warper = K.geometry.HomographyWarper(height, width)

    loss_prev = float('inf')
    for _ in range(num_iterations):
        # A module is re-evaluated so every step sees the latest parameter values;
        # a tensor is assumed to be a view the optimizer updates in place.
        if isinstance(dst_homo_src, nn.Module):
            homography = dst_homo_src()
        else:
            homography = dst_homo_src
        src_homo_dst = torch.inverse(homography)

        img_src_to_dst = warper(img_src, src_homo_dst)

        # Warping an all-ones image marks the destination pixels covered by the
        # source (out-of-bounds samples presumably come back as zero padding).
        # No gradient flows through a boolean mask, so skip building a graph.
        with torch.no_grad():
            valid = warper(torch.ones_like(img_src), src_homo_dst) > 0.9
        if not bool(valid.any()):
            # mean() over an empty selection is NaN; stepping on it would turn
            # the homography into NaNs, so stop optimising at this level.
            break

        # compute and mask loss (per-pixel L1 has the same shape as the images)
        loss = F.l1_loss(img_src_to_dst, img_dst, reduction='none')
        loss = loss.masked_select(valid).mean()

        # compute gradient and update optimizer parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        last_loss = loss.detach()
        loss_value = last_loss.item()
        # Converged: the loss barely changed since the previous step.
        if abs(loss_prev - loss_value) < error_tol:
            break
        loss_prev = loss_value

    return last_loss

Run the main loop: for each pyramid level, from the coarsest to the finest, warp the source image, evaluate the loss and update the homography with gradient steps.

# pyramid loop
# Coarse-to-fine: start at the smallest pyramid level so that each level's
# estimate initialises the optimisation at the next, finer level.
# Figures are drawn with plt.show(); run `%matplotlib inline` once in its own
# cell beforehand if your notebook frontend does not render inline by default.

for iter_idx in range(num_levels):
    # get current pyramid data (index num_levels - 1 holds the smallest image)
    scale: int = (num_levels - 1) - iter_idx
    img_src = img_src_pyr[scale]
    img_dst = img_dst_pyr[scale]

    # compute scale loss: optimise the shared homography at this level
    compute_scale_loss(img_src, img_dst, dst_homo_src(),
                       optimizer, num_iterations, error_tol)

    # 1-based progress so the final level reports N/N rather than (N-1)/N
    print('Optimization iteration: {}/{}'.format(iter_idx + 1, num_levels))

    # The rest of the iteration only evaluates and plots; disable autograd so
    # no graph is built through the homography parameter.
    with torch.no_grad():
        # merge warped and target image for visualization
        h, w = img_src.shape[-2:]
        warper = K.geometry.HomographyWarper(h, w)
        img_src_to_dst = warper(img_src, torch.inverse(dst_homo_src()))
        img_src_to_dst_gt = warper(img_src, torch.inverse(dst_homo_src_gt_norm))

        # compute the reprojection error against the ground-truth warp
        error = F.l1_loss(img_src_to_dst, img_src_to_dst_gt, reduction='none')
        print('Reprojection error: {}'.format(error.mean()))

        # show data (OpenCV-loaded BGR tensors -> RGB images for matplotlib)
        img_src_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_src))
        img_dst_vis = K.utils.tensor_to_image(K.color.bgr_to_rgb(img_dst))
        img_src_to_dst_merge = 0.65 * img_src_to_dst + 0.35 * img_dst
        img_src_to_dst_vis = K.utils.tensor_to_image(
            K.color.bgr_to_rgb(img_src_to_dst_merge))
        img_src_to_dst_gt_vis = K.utils.tensor_to_image(
            K.color.bgr_to_rgb(img_src_to_dst_gt))

        # per-pixel error averaged over the colour channels (single-channel map)
        error_mean = error.mean(dim=1, keepdim=True)
        error_vis = K.utils.tensor_to_image(error_mean)

    # show the original images at each scale level, the result of warping using
    # the homography at moment, and the estimated error against the GT homography.
    fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, sharey=False)
    fig.set_figheight(15)
    fig.set_figwidth(15)

    ax1.imshow(img_src_vis)
    ax1.set_title('Source image')

    ax2.imshow(img_dst_vis)
    ax2.set_title('Destination image')

    ax3.imshow(img_src_to_dst_vis)
    ax3.set_title('Source to Destination image')

    ax4.imshow(img_src_to_dst_gt_vis)
    ax4.set_title('Source to Destination image GT')

    ax5.imshow(error_vis, cmap='gray', vmin=0, vmax=1)
    ax5.set_title('Error')
    plt.show()
Optimization iteration: 0/6
Reprojection error: 0.1722082793712616
_images/homography_21_1.png
Optimization iteration: 1/6
Reprojection error: 0.11617553979158401
_images/homography_21_3.png
Optimization iteration: 2/6
Reprojection error: 0.01905934512615204
_images/homography_21_5.png
Optimization iteration: 3/6
Reprojection error: 0.013178573921322823
_images/homography_21_7.png
Optimization iteration: 4/6
Reprojection error: 0.008066886104643345
_images/homography_21_9.png
Optimization iteration: 5/6
Reprojection error: 0.0053191822953522205
_images/homography_21_11.png