Todo

This documentation needs to be updated.

How to write your own registration model

This note explains how to write your own registration model. For simplicity, we pretend that the stationary velocity field (SVF) registration model does not exist and explain, step by step, how to recreate it. We also assume that a new similarity measure should be added. We start with the similarity measure, as it is the easiest to implement.

Writing a similarity measure

Similarity measures are derived from SimilarityMeasure. To create a new similarity measure, simply derive your own class from SimilarityMeasure and implement the method compute_similarity(). Assume you want to reimplement the sum of squared differences (SSD) similarity measure; the class would then look as follows:

class MySSD(SimilarityMeasure):
    def compute_similarity(self,I0,I1,I0Source=None,phi=None):
        # sigma weights the similarity term relative to the regularization energy
        sigma = 0.1
        # sum of squared differences, scaled by the pixel/voxel volume so it approximates an integral
        return ((I0 - I1) ** 2).sum() / (sigma**2) * self.volumeElement

Here, self.volumeElement is defined in the base class SimilarityMeasure and indicates the volume occupied by a pixel or voxel.
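
As a second illustration of the same interface, a (negative) normalized cross-correlation measure could be sketched along the same lines. This is a hypothetical example, not part of mermaid: the class name MyNCC and the weighting are made up, and it only relies on standard torch tensor operations and the compute_similarity() signature shown above:

import torch

class MyNCC(SimilarityMeasure):
    def compute_similarity(self, I0, I1, I0Source=None, phi=None):
        sigma = 0.1
        I0m = I0 - I0.mean()
        I1m = I1 - I1.mean()
        # normalized cross-correlation in [-1, 1]; the small epsilon avoids division by zero
        ncc = (I0m * I1m).sum() / (torch.sqrt((I0m ** 2).sum() * (I1m ** 2).sum()) + 1e-10)
        # turn the correlation into a dissimilarity and scale it like the SSD example above
        return (1.0 - ncc ** 2) / (sigma ** 2) * self.volumeElement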

Since the machinery for making a new similarity measure available to all registration methods is rather heavy, there is a convenience method that can be accessed through the optimizer interface.

Assuming the parameter structure being used is called params (a ParameterDict object), we first declare that we want to use our own similarity measure via

params['registration_model']['similarity_measure']['type'] = 'mySSD'
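
If params does not exist yet, it can be created first. The following assumes that mermaid's module_parameters module provides the ParameterDict class (as used in the mermaid examples); adapt the import if your setup differs:

import mermaid.module_parameters as pars

params = pars.ParameterDict()
params['registration_model']['similarity_measure']['type'] = 'mySSD'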

Now, once we have a multi-scale optimizer

import mermaid.multiscale_optimizer as MO
mo = MO.MultiScaleRegistrationOptimizer(modelName,sz,spacing,useMap,mapLowResFactor,params)

we can simply instruct it to use our new similarity measure

mo.add_similarity_measure('mySSD', MySSD)

This will propagate through all registration models, so all of them can immediately use the new similarity measure.
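
For completeness, the constructor arguments used above (sz, spacing, useMap, mapLowResFactor) describe the input images and the type of solution. The sketch below shows a hypothetical 2D setup; it assumes mermaid's convention of images sized [batch, channel, X, Y] with the image domain normalized to the unit square, and all concrete values (including the placeholder model name) are illustrative only:

import numpy as np

sz = np.array(ISource.size())  # e.g. [1, 1, 64, 64] for one single-channel 2D image
spacing = 1.0 / (sz[2::] - 1)  # assumed convention: pixel spacing on a [0, 1]^2 domain
useMap = False                 # warp the image directly rather than via a coordinate map
mapLowResFactor = None         # do not compute the map at a lower resolution
modelName = 'svf_image'        # placeholder; we add and select our own model further below anyway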

Writing a new registration model

The goal of this package is to make writing new models as easy as possible while keeping the package itself as simple to use as possible. These are somewhat contradictory goals. As a compromise, there is a relatively simple interface that allows new models to be defined without integrating them into the overall machinery.

Let’s first import a few packages that are needed to write the new network module

import mermaid.registration_networks as RN
import mermaid.utils as utils
import mermaid.image_sampling as IS
import mermaid.rungekutta_integrators as RK
import mermaid.forward_models as FM
import mermaid.regularizer_factory as RF

A new network is derived from the abstract class RegistrationNet. To create a working new class, it is required to define the following methods:

  • create_registration_parameters(): sets up the registration parameters required by the model. The parameters need to be of type torch.nn.Parameter.
  • get_registration_parameters(): simply returns the registration parameters.
  • set_registration_parameters(): sets the parameters; needed by the multi-scale optimizer to propagate parameters from one level to the next.
  • create_integrator(): since we are dealing with time-dependent problems, this sets up (and returns!) an integrator for the system that is to be solved.
  • forward(): this is the method where all the magic happens, i.e., where we solve the forward problem by integrating the model forward in time.
  • upsample_registration_parameters(): spatially upsamples the registration parameters. Needs to be defined if the multi-scale solver is to be used; not needed when solving on a single scale.

Let’s start with the simplest possible class:

class MySVFNet(RN.RegistrationNet):
    def __init__(self,sz,spacing,params):
        super(MySVFNet, self).__init__(sz,spacing,params)
        # the stationary velocity field that parameterizes the registration
        self.v = self.create_registration_parameters()
        self.integrator = self.create_integrator()

    def create_registration_parameters(self):
        # one vector field per image in the batch
        return utils.create_ND_vector_field_parameter_multiN(self.sz[2::], self.nrOfImages)

    def get_registration_parameters(self):
        return self.v

    def set_registration_parameters(self, p, sz, spacing):
        # used by the multi-scale optimizer to initialize a level from the previous one
        self.v.data = p.data
        self.sz = sz
        self.spacing = spacing

    def create_integrator(self):
        cparams = self.params[('forward_model',{},'settings for the forward model')]
        # advect the image with the velocity field, integrated with a fourth-order Runge-Kutta scheme
        advection = FM.AdvectImage(self.sz, self.spacing)
        return RK.RK4(advection.f, advection.u, self.v, cparams)

    def forward(self, I):
        # solve the forward problem by integrating the advection model from tFrom to tTo
        I1 = self.integrator.solve([I], self.tFrom, self.tTo)
        return I1[0]

If desired (for the multi-scale optimizer), also define the following method within MySVFNet:

def upsample_registration_parameters(self, desiredSz):
    sampler = IS.ResampleImage()
    vUpsampled, upsampled_spacing = sampler.upsample_image_to_size(self.v, self.spacing, desiredSz)
    return vUpsampled, upsampled_spacing

Lastly, we also need to define our own loss function. Loss functions are derived from RegistrationImageLoss or RegistrationMapLoss, depending on whether the source image is warped directly or via a coordinate map. The only method that needs to be defined is compute_regularization_energy(). For the SVF model we just created, this could, for example, look like this:

class MySVFImageLoss(RN.RegistrationImageLoss):
    def __init__(self,v,sz,spacing,params):
        super(MySVFImageLoss, self).__init__(sz,spacing,params)
        self.v = v
        cparams = params[('loss',{},'settings for the loss function')]
        self.regularizer = (RF.RegularizerFactory(self.spacing).
                            create_regularizer(cparams))

    def compute_regularization_energy(self, I0_source):
        return self.regularizer.compute_regularizer_multiN(self.v)
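
If the model were map-based (i.e., its forward() returned a coordinate map rather than a warped image), the loss would derive from RegistrationMapLoss instead. The following is only a rough sketch which assumes that the constructor and compute_regularization_energy() follow the same pattern as for the image loss above; check the RegistrationMapLoss base class for the exact signatures:

class MySVFMapLoss(RN.RegistrationMapLoss):
    def __init__(self,v,sz,spacing,params):
        super(MySVFMapLoss, self).__init__(sz,spacing,params)
        self.v = v
        cparams = params[('loss',{},'settings for the loss function')]
        self.regularizer = (RF.RegularizerFactory(self.spacing).
                            create_regularizer(cparams))

    def compute_regularization_energy(self, I0_source):
        # the regularization only depends on the velocity field, so it is identical to the image loss
        return self.regularizer.compute_regularizer_multiN(self.v)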

Now that the model and its loss are defined, we need to use them. Just as for the custom similarity measure above, we do this by adding the model to the multi-scale optimizer and then setting it as the model to be used for the solution.

myModelName = 'mySVF'
mo.add_model(myModelName,MySVFNet,MySVFImageLoss)
mo.set_model(myModelName)

If desired, it is possible to choose a custom optimizer (the default is LBFGS with some default settings). The following selects Adam as the optimizer and sets one of its optimization parameters. Any optimizer supported by PyTorch works in principle. However, be advised that especially the shooting formulations for registration may require reasonably sophisticated optimizers to converge.

import torch

mo.set_optimizer(torch.optim.Adam)
mo.set_optimizer_params(dict(lr=0.01))

By default, visualization output is turned on, but it can also be set manually:

mo.set_visualization(True)
mo.set_visualize_step(10)  # visualize every 10 iterations

As before, the model can then be solved:

mo.set_source_image(ISource)
mo.set_target_image(ITarget)

# solve on three scales: full, half, and quarter resolution
mo.set_scale_factors([1.0, 0.5, 0.25])
# number of iterations for each of these scales
mo.set_number_of_iterations_per_scale([5, 10, 10])

mo.optimize()