# Source code for mermaid.libraries.modules.stn_nd

"""
This module implements spatial transformations in 1D, 2D, and 3D.
These are needed, for example, for map-based registration.

.. todo::
  Implement a CUDA version. A 2D CUDA version is already available (in the source directory here),
  but it needs to be extended to 1D and 3D. It also uses a different image convention, which needs
  to be accounted for, as we use the BxCxXxYxZ format for images and BxdimxXxYxZ for maps.
"""
#TODO

from torch.nn.modules.module import Module
###########TODO imports temporarily commented out for torch1 compatibility
# from mermaid.libraries.functions.stn_nd import  STNFunction_ND_BCXYZ, STNFunction_ND_BCXYZ_Compile
# from mermaid.libraries.functions.nn_interpolation import get_nn_interpolationf
################################################################3
from ..functions.stn_nd import  STNFunction_ND_BCXYZ
from functools import partial
# class STN_ND(Module):
#     """
#     Legacy code for nD spatial transforms. Ignore for now. Implements spatial transforms, but in BXYZC format.
#     """
#     def __init__(self, dim):
#         super(STN_ND, self).__init__()
#         self.dim = dim
#         """spatial dimension"""
#         self.f = STNFunction_ND( self.dim )
#         """spatial transform function"""
#     def forward(self, input1, input2):
#         """
#         Simply returns the transformed input
#
#         :param input1: image in BCXYZ format
#         :param input2: map in BdimXYZ format
#         :return: returns the transformed image
#         """
#         return self.f(input1, input2)

class STN_ND_BCXYZ(Module):
    """
    Spatial transform code for nD spatial transforms. Uses the BCXYZ image format.
    """

    def __init__(self, spacing, zero_boundary=False, use_bilinear=True, use_01_input=True, use_compile_version=False):
        """
        Select and store the function that performs the spatial transform.

        :param spacing: grid spacing; stored as ``self.spacing`` and forwarded to the transform function
        :param zero_boundary: forwarded as ``zero_boundary`` to the bilinear transform functions
            (not used by the compiled nearest-neighbour path)
        :param use_bilinear: forwarded as ``using_bilinear`` in the default path; in the compiled path,
            False selects ``get_nn_interpolation`` instead of ``STNFunction_ND_BCXYZ_Compile``
        :param use_01_input: forwarded as ``using_01_input`` in the default path; ignored when
            ``use_compile_version`` is True
        :param use_compile_version: if True, use the compiled transform functions instead of
            ``STNFunction_ND_BCXYZ``
        :raises NotImplementedError: if ``use_compile_version`` is True but the compiled functions
            cannot be imported
        """
        super(STN_ND_BCXYZ, self).__init__()
        # Grid spacing, forwarded to the selected transform function.
        self.spacing = spacing
        if use_compile_version:
            # The module-level imports of the compiled functions are commented out
            # (torch1 compatibility), so resolve them lazily here. This reports a clear
            # error instead of failing with an opaque NameError.
            try:
                if use_bilinear:
                    from ..functions.stn_nd import STNFunction_ND_BCXYZ_Compile
                else:
                    from ..functions.nn_interpolation import get_nn_interpolation
            except ImportError as err:
                raise NotImplementedError(
                    "use_compile_version=True requires the compiled spatial transform "
                    "functions, which are not available in this build"
                ) from err
            if use_bilinear:
                self.f = STNFunction_ND_BCXYZ_Compile(self.spacing, zero_boundary)
            else:
                self.f = partial(get_nn_interpolation, spacing=self.spacing)
        else:
            self.f = STNFunction_ND_BCXYZ(
                self.spacing,
                zero_boundary=zero_boundary,
                using_bilinear=use_bilinear,
                using_01_input=use_01_input,
            )
        # self.f is the spatial transform function, called as self.f(image, map).

    def forward(self, input1, input2):
        """
        Simply returns the transformed input.

        :param input1: image in BCXYZ format
        :param input2: map in BdimXYZ format
        :return: returns the transformed image
        """
        return self.f(input1, input2)