"""
This package implements pytorch functions for Fourier-based convolutions.
While this may not be relevant for GPU-implementations, convolutions in the spatial domain are slow on CPUs. Hence, this function should be useful for memory-intensive models that need to be run on the CPU or CPU-based computations involving convolutions in general.
.. todo::
Create a CUDA version of these convolutions functions. There is already a CUDA based FFT implementation available which could be built upon. Alternatively, spatial smoothing may be sufficiently fast on the GPU.
"""
from __future__ import print_function
from __future__ import absolute_import
# TODO
from builtins import range
from builtins import object
import torch
from torch.autograd import Function
import numpy as np
from torch.autograd import gradcheck
from .data_wrapper import USE_CUDA, FFTVal,AdaptVal, MyTensor
# if USE_CUDA:
# import pytorch_fft.fft as fft
from . import utils
def _symmetrize_filter_center_at_zero_1D(filter):
sz = filter.shape
if sz[0] % 2 == 0:
# symmetrize if it is even
filter[1:sz[0] // 2] = filter[sz[0]:sz[0] // 2:-1]
else:
# symmetrize if it is odd
filter[1:sz[0] // 2 + 1] = filter[sz[0]:sz[0] // 2:-1]
def _symmetrize_filter_center_at_zero_2D(filter):
sz = filter.shape
if sz[0] % 2 == 0:
# symmetrize if it is even
filter[1:sz[0] // 2,:] = filter[sz[0]:sz[0] // 2:-1,:]
else:
# symmetrize if it is odd
filter[1:sz[0] // 2 + 1,:] = filter[sz[0]:sz[0] // 2:-1,:]
if sz[1] % 2 == 0:
# symmetrize if it is even
filter[:,1:sz[1] // 2] = filter[:,sz[1]:sz[1] // 2:-1]
else:
# symmetrize if it is odd
filter[:,1:sz[1] // 2 + 1] = filter[:,sz[1]:sz[1] // 2:-1]
def _symmetrize_filter_center_at_zero_3D(filter):
sz = filter.shape
if sz[0] % 2 == 0:
# symmetrize if it is even
filter[1:sz[0] // 2,:,:] = filter[sz[0]:sz[0] // 2:-1,:,:]
else:
# symmetrize if it is odd
filter[1:sz[0] // 2 + 1,:,:] = filter[sz[0]:sz[0] // 2:-1,:,:]
if sz[1] % 2 == 0:
# symmetrize if it is even
filter[:,1:sz[1] // 2,:] = filter[:,sz[1]:sz[1] // 2:-1,:]
else:
# symmetrize if it is odd
filter[:,1:sz[1] // 2 + 1,:] = filter[:,sz[1]:sz[1] // 2:-1,:]
if sz[2] % 2 == 0:
# symmetrize if it is even
filter[:,:,1:sz[2] // 2] = filter[:,:,sz[2]:sz[2] // 2:-1]
else:
# symmetrize if it is odd
filter[:,:,1:sz[2] // 2 + 1] = filter[:,:,sz[2]:sz[2] // 2:-1]
[docs]def symmetrize_filter_center_at_zero(filter,renormalize=False):
"""
Symmetrizes filter. The assumption is that the filter is already in the format for input to an FFT.
I.e., that it has been transformed so that the center of the pixel is at zero.
:param filter: Input filter (in spatial domain). Will be symmetrized (i.e., will change its value)
:param renormalize: (bool) if true will normalize so that the sum is one
:return: n/a (returns via call by reference)
"""
sz = filter.shape
dim = len(sz)
if dim==1:
_symmetrize_filter_center_at_zero_1D(filter)
elif dim==2:
_symmetrize_filter_center_at_zero_2D(filter)
elif dim==3:
_symmetrize_filter_center_at_zero_3D(filter)
else:
raise ValueError('Only implemented for dimensions 1,2, and 3 so far')
if renormalize:
filter = filter / filter.sum()
[docs]def are_indices_close(loc):
"""
This function takes a set of indices (as produced by np.where) and determines
if they are roughly closeby. If not it returns *False* otherwise *True*.
:param loc: Index locations as outputted by np.where
:return: Returns if the indices are roughly closeby or not
.. todo::
There should be a better check for closeness of points. The implemented one is very crude.
"""
# TODO: potentially do a better check here, this one is very crude
for cloc in loc:
cMaxDist = (abs(cloc - cloc.max())).max()
if cMaxDist > 2:
return False
return True
[docs]def create_complex_fourier_filter(spatial_filter, sz, enforceMaxSymmetry=True, maxIndex=None, renormalize=False):
"""
Creates a filter in the Fourier domain given a spatial array defining the filter
:param spatial_filter: Array defining the filter.
:param sz: Desired size of the filter in the Fourier domain.
:param enforceMaxSymmetry: If set to *True* (default) forces the filter to be real and hence forces the filter
in the spatial domain to be symmetric
:param maxIndex: specifies the index of the maximum which will be used to enforceMaxSymmetry. If it is not
defined, the maximum is simply computed
:param renormalize: (bool) if true, the filter is renormalized to sum to one (useful for Gaussians for example)
:return: Returns the complex coefficients for the filter in the Fourier domain and the maxIndex
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
sz = np.array(sz)
if enforceMaxSymmetry:
if maxIndex is None:
maxIndex = np.unravel_index(np.argmax(spatial_filter), spatial_filter.shape)
maxValue = spatial_filter[maxIndex]
loc = np.where(spatial_filter == maxValue)
nrOfMaxValues = len(loc[0])
if nrOfMaxValues > 1:
# now need to check if they are close to each other
if not are_indices_close(loc):
raise ValueError('Cannot enforce max symmetry as maximum is not unique')
spatial_filter_max_at_zero = np.roll(spatial_filter, -np.array(maxIndex),
list(range(len(spatial_filter.shape))))
symmetrize_filter_center_at_zero(spatial_filter_max_at_zero,renormalize=renormalize)
# we assume this is symmetric and hence take the absolute value
# as the FT of a symmetric kernel has to be real
if USE_CUDA:
f_filter = create_cuda_filter(spatial_filter_max_at_zero, sz)
ret_filter = f_filter[...,0] # only the real part
else:
f_filter = create_numpy_filter(spatial_filter_max_at_zero, sz)
ret_filter = f_filter.real
return ret_filter,maxIndex
else:
if USE_CUDA:
return create_cuda_filter(spatial_filter),maxIndex
else:
return create_numpy_filter(spatial_filter, sz),maxIndex
[docs]def create_cuda_filter(spatial_filter, sz):
"""
create cuda version filter, another one dimension is added to the output for computational convenient
besides the output will not be full complex result of shape (∗,2),
where ∗ is the shape of input, but instead the last dimension will be halfed as of size ⌊Nd/2⌋+1.
:param spatial_filter: N1 x...xNd, no batch dimension, no channel dimension
:param sz: [N1,..., Nd]
:return: filter, with size [1,N1,..Nd-1,⌊Nd/2⌋+1,2⌋
"""
fftn = torch.rfft
spatial_filter_th = torch.from_numpy(spatial_filter).float().cuda()
spatial_filter_th = spatial_filter_th[None, ...]
spatial_filter_th_fft = fftn(spatial_filter_th, len(sz))
return spatial_filter_th_fft
[docs]def create_numpy_filter(spatial_filter, sz):
return np.fft.fftn(spatial_filter, s=sz)
# todo: maybe check if we can use rfft's here for better performance
[docs]def sel_fftn(dim):
"""
sel the gpu and cpu version of the fft
:param dim:
:return: function pointer
"""
if USE_CUDA:
if dim in[1,2,3]:
f= torch.rfft
else:
print('Warning, fft more than 3d is supported but not tested')
return f
else:
if dim == 1:
f = np.fft.fft
elif dim == 2:
f = np.fft.fft2
elif dim == 3:
f = np.fft.fftn
else:
raise ValueError('Only 3D cpu ifft supported')
return f
[docs]def sel_ifftn(dim):
"""
select the cpu and gpu version of the ifft
:param dim:
:return: function pointer
"""
if USE_CUDA:
if dim in [1,2,3]:
f = torch.irfft
else:
print('Warning, fft more than 3d is supported but not tested')
else:
if dim == 1:
f = np.fft.ifft
elif dim == 2:
f = np.fft.ifft2
elif dim == 3:
f = np.fft.ifftn
else:
raise ValueError('Only 3D cpu ifft supported')
return f
[docs]class FourierConvolution(Function):
"""
pyTorch function to compute convolutions in the Fourier domain: f = g*h
"""
def __init__(self, complex_fourier_filter):
"""
Constructor for the Fouier-based convolution
:param complex_fourier_filter: Filter in the Fourier domain as created by *createComplexFourierFilter*
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(FourierConvolution, self).__init__()
self.complex_fourier_filter = complex_fourier_filter
if USE_CUDA:
self.dim = complex_fourier_filter.dim() -1
else:
self.dim = len(complex_fourier_filter.shape)
self.fftn = sel_fftn(self.dim)
self.ifftn = sel_ifftn(self.dim)
"""The filter in the Fourier domain"""
[docs] def forward(self, input):
"""
Performs the Fourier-based filtering
the 3d cpu fft is not implemented in fftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because fft and fft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the rfft is used for efficiency, which means the filter should be symmetric
(input_real+input_img)(filter_real+filter_img) = (input_real*filter_real-input_img*filter_img) + (input_img*filter_real+input_real*filter_img)i
filter_img =0, then get input_real*filter_real + (input_img*filter_real)i ac + bci
:param input: Image
:return: Filtered-image
"""
if USE_CUDA:
input = FFTVal(input,ini=1)
f_input = self.fftn(input,self.dim,onesided=True)
f_filter_real = self.complex_fourier_filter[0]
f_filter_real=f_filter_real.expand_as(f_input[...,0])
f_filter_real = torch.stack((f_filter_real,f_filter_real),-1)
f_conv = f_input * f_filter_real
dim_input = len(input.shape)
dim_input_batch = dim_input-self.dim
conv_ouput_real = self.ifftn(f_conv, self.dim,onesided=True,signal_sizes=input.shape[dim_input_batch::])
result = conv_ouput_real
return FFTVal(result, ini=-1)
else:
if self.dim <3:
conv_output = self.ifftn(self.fftn(input.detach().cpu().numpy()) * self.complex_fourier_filter)
result = conv_output.real # should in principle be real
elif self.dim==3:
result = np.zeros(input.shape)
for batch in range(input.size()[0]):
for ch in range(input.size()[1]):
conv_output = self.ifftn(self.fftn(input[batch,ch].detach().cpu().numpy()) * self.complex_fourier_filter)
result[batch,ch] = conv_output.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
return torch.FloatTensor(result)
# print( 'max(imag) = ' + str( (abs( conv_output.imag )).max() ) )
# print( 'max(real) = ' + str( (abs( conv_output.real )).max() ) )
# This function has only a single output, so it gets only one gradient
[docs] def backward(self, grad_output):
"""
Computes the gradient
the 3d cpu ifft is not implemented in ifftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because ifft and ifft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the irfft is used for efficiency, which means the filter should be symmetric
:param grad_output: Gradient output of previous layer
:return: Gradient including the Fourier-based convolution
"""
# Initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
grad_input = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# input_imag =0, then get ac + bci
if USE_CUDA:
grad_output = FFTVal(grad_output, ini=1)
#print grad_output.view(-1,1).sum()
f_go = self.fftn(grad_output,self.dim,onesided=True)
f_filter_real = self.complex_fourier_filter[0]
f_filter_real = f_filter_real.expand_as(f_go[..., 0])
f_filter_real = torch.stack((f_filter_real, f_filter_real), -1)
f_conv = f_go * f_filter_real
dim_input = len(grad_output.shape)
dim_input_batch = dim_input - self.dim
grad_input = self.ifftn(f_conv,self.dim,onesided=True,signal_sizes=grad_output.shape[dim_input_batch::])
# print(grad_input)
# print((grad_input[0,0,12:15]))
return FFTVal(grad_input, ini=-1)
else:
# if self.needs_input_grad[0]:
numpy_go = grad_output.detach().cpu().numpy()
# we use the conjugate because the assumption was that the spatial filter is real
# THe following two lines should be correct
if self.dim < 3:
grad_input_c = (self.ifftn(np.conjugate(self.complex_fourier_filter) * self.fftn(numpy_go)))
grad_input = grad_input_c.real
elif self.dim == 3:
grad_input = np.zeros(numpy_go.shape)
assert grad_output.dim() == 5 # to ensure the behavior correct, we avoid more than 3 dimension fftn method
for batch in range(grad_output.size()[0]):
for ch in range(grad_output.size()[1]):
grad_input_c = (self.ifftn(np.conjugate(self.complex_fourier_filter) *self.fftn(numpy_go[batch,ch])))
grad_input[batch,ch] = grad_input_c.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
# print(grad_input)
# print((grad_input[0,0,12:15]))
return torch.FloatTensor(grad_input)
# print( 'grad max(imag) = ' + str( (abs( grad_input_c.imag )).max() ) )
# print( 'grad max(real) = ' + str( (abs( grad_input_c.real )).max() ) )
[docs]class InverseFourierConvolution(Function):
"""
pyTorch function to compute convolutions in the Fourier domain: f = g*h
But uses the inverse of the smoothing filter
"""
def __init__(self, complex_fourier_filter):
"""
Constructor for the Fouier-based convolution (WARNING: EXPERIMENTAL)
:param complex_fourier_filter: Filter in the Fourier domain as created by *createComplexFourierFilter*
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(InverseFourierConvolution, self).__init__()
self.complex_fourier_filter = complex_fourier_filter
if USE_CUDA:
self.dim = complex_fourier_filter.dim() - 1
else:
self.dim = len(complex_fourier_filter.shape)
self.fftn = sel_fftn(self.dim)
self.ifftn = sel_ifftn(self.dim)
"""Fourier filter"""
self.alpha = 0.1
"""Regularizing weight"""
[docs] def set_alpha(self, alpha):
"""
Sets the regularizing weight
:param alpha: regularizing weight
"""
self.alpha = alpha
[docs] def get_alpha(self):
"""
Returns the regularizing weight
:return: regularizing weight
"""
return self.alpha
[docs] def forward(self, input):
"""
Performs the Fourier-based filtering
:param input: Image
:return: Filtered-image
"""
# do the filtering in the Fourier domain
# (a+bi)/(c) = (a/c) + (b/c)i
if USE_CUDA:
input = FFTVal(input, ini=1)
f_input = self.fftn(input,self.dim,onesided=True)
f_filter_real = self.complex_fourier_filter[0]
f_filter_real += self.alpha
f_filter_real = f_filter_real.expand_as(f_input[..., 0])
f_filter_real = torch.stack((f_filter_real, f_filter_real), -1)
f_conv = f_input/f_filter_real
dim_input = len(input.shape)
dim_input_batch = dim_input - self.dim
conv_ouput_real = self.ifftn(f_conv,self.dim,onesided=True,signal_sizes=input.shape[dim_input_batch::])
result = conv_ouput_real
return FFTVal(result, ini=-1)
else:
result = np.zeros(input.shape)
if self.dim <3:
conv_output = self.ifftn(self.fftn(input.detach().cpu().numpy()) / (self.alpha + self.complex_fourier_filter))
# result = abs(conv_output) # should in principle be real
result = conv_output.real
elif self.dim == 3:
result = np.zeros(input.shape)
for batch in range(input.size()[0]):
for ch in range(input.size()[1]):
conv_output = self.ifftn(
self.fftn(input[batch,ch].detach().cpu().numpy()) / (self.alpha + self.complex_fourier_filter))
result[batch, ch] = conv_output.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
return torch.FloatTensor(result)
# This function has only a single output, so it gets only one gradient
[docs] def backward(self, grad_output):
"""
Computes the gradient
:param grad_output: Gradient output of previous layer
:return: Gradient including the Fourier-based convolution
"""
# Initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
grad_input = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# if self.needs_input_grad[0]:
if USE_CUDA:
grad_output =FFTVal(grad_output, ini=1)
f_go = self.fftn(grad_output, self.dim, onesided=True)
f_filter_real = self.complex_fourier_filter[0]
f_filter_real += self.alpha
f_filter_real = f_filter_real.expand_as(f_go[..., 0])
f_filter_real = torch.stack((f_filter_real, f_filter_real), -1)
f_conv = f_go / f_filter_real
dim_input = len(grad_output.shape)
dim_input_batch = dim_input - self.dim
grad_input = self.ifftn(f_conv, self.dim, onesided=True, signal_sizes=grad_output.shape[dim_input::])
return FFTVal(grad_input, ini=-1)
else:
# if self.needs_input_grad[0]:
numpy_go = grad_output.detach().cpu().numpy()
# we use the conjugate because the assumption was that the spatial filter is real
# THe following two lines should be correct
if self.dim<3:
grad_input_c = (self.ifftn(self.fftn(numpy_go) / (self.alpha + np.conjugate(self.complex_fourier_filter))))
grad_input = grad_input_c.real
elif self.dim == 3:
grad_input = np.zeros(numpy_go.shape)
for batch in range(grad_output.size()[0]):
for ch in range(grad_output.size()[1]):
grad_input_c = (
self.ifftn(self.fftn(numpy_go[batch,ch]) / (self.alpha + np.conjugate(self.complex_fourier_filter))))
grad_input[batch, ch] = grad_input_c.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
return torch.FloatTensor(grad_input)
[docs]def fourier_convolution(input, complex_fourier_filter):
"""
Convenience function for Fourier-based convolutions. Make sure to use this one (instead of directly
using the class FourierConvolution). This will assure that each call generates its own instance
and hence autograd will work properly
:param input: Input image
:param complex_fourier_filter: Filter in Fourier domain as generated by *createComplexFourierFilter*
:return:
"""
# First braces create a Function object. Any arguments given here
# will be passed to __init__. Second braces will invoke the __call__
# operator, that will then use forward() to compute the result and
# return it.
return FourierConvolution(complex_fourier_filter)(input)
[docs]def inverse_fourier_convolution(input, complex_fourier_filter):
# just filtering with inverse filter
return InverseFourierConvolution(complex_fourier_filter)(input)
[docs]class GaussianFourierFilterGenerator(object):
def __init__(self, sz, spacing, nr_of_slots=1):
self.sz = sz
"""image size"""
self.spacing = spacing
"""image spacing"""
self.volumeElement = self.spacing.prod()
"""volume of pixel/voxel"""
self.dim = len(spacing)
"""dimension"""
self.nr_of_slots = nr_of_slots
"""number of slots to hold Gaussians (to be able to support multi-Gaussian); this is related to storage"""
"""typically should be set to the number of total desired Gaussians (so that none of them need to be recomputed)"""
self.mus = np.zeros(self.dim)
# TODO: storing the identity map may be a little wasteful
self.centered_id = utils.centered_identity_map(self.sz,self.spacing)
self.complex_gaussian_fourier_filters = [None] * self.nr_of_slots
self.max_indices = [None]*self.nr_of_slots
self.sigmas_complex_gaussian_fourier_filters = [None]*self.nr_of_slots
self.complex_gaussian_fourier_xsqr_filters = [None]*self.nr_of_slots
self.sigmas_complex_gaussian_fourier_xsqr_filters = [None]*self.nr_of_slots
self.sigmas_complex_gaussian_fourier_filters_np=[]
[docs] def get_number_of_slots(self):
return self.nr_of_slots
[docs] def get_number_of_currently_stored_gaussians(self):
nr_of_gaussians = 0
for s in self.sigmas_complex_gaussian_fourier_filters:
if s is not None:
nr_of_gaussians += 1
return nr_of_gaussians
[docs] def get_dimension(self):
return self.dim
def _compute_complex_gaussian_fourier_filter(self,sigma):
stds = sigma.detach().cpu().numpy() * np.ones(self.dim)
gaussian_spatial_filter = utils.compute_normalized_gaussian(self.centered_id, self.mus, stds)
complex_gaussian_fourier_filter,max_index = create_complex_fourier_filter(gaussian_spatial_filter,self.sz,True)
return complex_gaussian_fourier_filter,max_index
def _compute_complex_gaussian_fourier_xsqr_filter(self,sigma,max_index=None):
if max_index is None:
raise ValueError('A Gaussian filter needs to be generated / requested *before* any other filter')
# TODO: maybe compute this jointly with the gaussian filter itself to avoid computing the spatial filter twice
stds = sigma.detach().cpu().numpy() * np.ones(self.dim)
gaussian_spatial_filter = utils.compute_normalized_gaussian(self.centered_id, self.mus, stds)
gaussian_spatial_xsqr_filter = gaussian_spatial_filter*(self.centered_id**2).sum(axis=0)
complex_gaussian_fourier_xsqr_filter,max_index = create_complex_fourier_filter(gaussian_spatial_xsqr_filter,self.sz,True,max_index)
return complex_gaussian_fourier_xsqr_filter,max_index
def _find_closest_sigma_index(self, sigma, available_sigmas):
"""
For a given sigma, finds the closest one in a list of available sigmas
- If a sigma is already computed it finds its index
- If the sigma has not been computed (it finds the next empty slot (None)
- If no empty slots are available it replaces the closest
:param available_sigmas: a list of sigmas that have already been computed (or None if they have not)
:return: returns the index for the closest sigma among the available_sigmas
"""
closest_i = None
same_i = None
empty_slot_i = None
current_dist_sqr = None
for i,s in enumerate(available_sigmas):
if s is not None:
# keep track of the one with the closest distance
new_dist_sqr = (s-sigma)**2
if current_dist_sqr is None:
current_dist_sqr = new_dist_sqr
closest_i = i
else:
if new_dist_sqr<current_dist_sqr:
current_dist_sqr = new_dist_sqr
closest_i = i
# also check if this is the same
# if it is records the first occurrence
if torch.isclose(sigma,s):
if same_i is None:
same_i = i
else:
# found an empty slot, record it if it is the first one that was found
if empty_slot_i is None:
empty_slot_i = i
# if we found the same we return it
if same_i is not None:
# we found the same; i.e., already computed
return same_i
elif empty_slot_i is not None:
# it was not already computed, but we found an empty slot to put it in
return empty_slot_i
elif closest_i is not None:
# no empty slot, so just overwrite the closest one if there is one
return closest_i
else:
# nothing has been computed yet, so return the 0 index (this should never execute, as it should be taken care of by the empty slot
return 0
[docs] def get_gaussian_xsqr_filters(self,sigmas):
"""
Returns complex Gaussian Fourier filter multiplied with x**2 with standard deviation sigma.
Only recomputes the filter if sigma has changed.
:param sigmas: standard deviation of the filter as a list
:return: Returns the complex Gaussian Fourier filters as a list (in the same order as requested)
"""
current_complex_gaussian_fourier_xsqr_filters = []
# only recompute the ones that need to be recomputed
for sigma in sigmas:
# now find the index that corresponds to this
i = self._find_closest_sigma_index(sigma, self.sigmas_complex_gaussian_fourier_xsqr_filters)
if self.sigmas_complex_gaussian_fourier_xsqr_filters[i] is None:
need_to_recompute = True
elif self.complex_gaussian_fourier_xsqr_filters[i] is None:
need_to_recompute = True
elif torch.isclose(sigma,self.sigmas_complex_gaussian_fourier_xsqr_filters[i]):
need_to_recompute = False
else:
need_to_recompute = True
if need_to_recompute:
print('INFO: Recomputing gaussian xsqr filter for sigma={:.2f}'.format(sigma))
self.sigmas_complex_gaussian_fourier_xsqr_filters[i] = sigma #.clone()
self.complex_gaussian_fourier_xsqr_filters[i],_ = self._compute_complex_gaussian_fourier_xsqr_filter(sigma,self.max_indices[i])
current_complex_gaussian_fourier_xsqr_filters.append(self.complex_gaussian_fourier_xsqr_filters[i])
return current_complex_gaussian_fourier_xsqr_filters
[docs] def get_gaussian_filters(self,sigmas):
"""
Returns a complex Gaussian Fourier filter with standard deviation sigma.
Only recomputes the filter if sigma has changed.
:param sigma: standard deviation of filter.
:return: Returns the complex Gaussian Fourier filter
"""
current_complex_gaussian_fourier_filters = []
# only recompute the ones that need to be recomputed
for sigma in sigmas:
# now find the index that corresponds to this
sigma_value = sigma.item()
if sigma_value in self.sigmas_complex_gaussian_fourier_filters_np:
i = self.sigmas_complex_gaussian_fourier_filters_np.index(sigma_value)
else:
i = self._find_closest_sigma_index(sigma,self.sigmas_complex_gaussian_fourier_filters)
if self.sigmas_complex_gaussian_fourier_filters[i] is None:
need_to_recompute = True
elif self.complex_gaussian_fourier_filters[i] is None:
need_to_recompute = True
elif torch.isclose(sigma,self.sigmas_complex_gaussian_fourier_filters[i]):
need_to_recompute = False
else:
need_to_recompute = True
if need_to_recompute: # todo not comment this warning
print('INFO: Recomputing gaussian filter for sigma={:.2f}'.format(sigma))
self.sigmas_complex_gaussian_fourier_filters[i] = sigma #.clone()
self.sigmas_complex_gaussian_fourier_filters_np.append(sigma_value)
self.complex_gaussian_fourier_filters[i], self.max_indices[i] = self._compute_complex_gaussian_fourier_filter(sigma)
current_complex_gaussian_fourier_filters.append(self.complex_gaussian_fourier_filters[i])
return current_complex_gaussian_fourier_filters
[docs]class FourierGaussianConvolution(Function):
"""
pyTorch function to compute Gaussian convolutions in the Fourier domain: f = g*h.
Also allows to differentiate through the Gaussian standard deviation.
"""
def __init__(self, gaussian_fourier_filter_generator):
"""
Constructor for the Fouier-based convolution
:param sigma: standard deviation for the filter
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(FourierGaussianConvolution, self).__init__()
self.gaussian_fourier_filter_generator = gaussian_fourier_filter_generator
self.dim = self.gaussian_fourier_filter_generator.get_dimension()
self.fftn = sel_fftn(self.dim)
self.ifftn = sel_ifftn(self.dim)
def _compute_convolution_CUDA(self,input,complex_fourier_filter):
input = FFTVal(input, ini=1)
f_input = self.fftn(input, self.dim, onesided=True)
f_filter_real = complex_fourier_filter[0]
f_filter_real = f_filter_real.expand_as(f_input[..., 0])
f_filter_real = torch.stack((f_filter_real, f_filter_real), -1)
f_conv = f_input * f_filter_real
dim_input = len(input.shape)
dim_input_batch = dim_input - self.dim
conv_ouput_real = self.ifftn(f_conv, self.dim, onesided=True, signal_sizes=input.shape[dim_input_batch::])
result = conv_ouput_real
return FFTVal(result, ini=-1)
def _compute_convolution_CPU(self,input,complex_fourier_filter):
if self.dim < 3:
conv_output = self.ifftn(self.fftn(input.detach().cpu().numpy()) * complex_fourier_filter)
result = conv_output.real # should in principle be real
elif self.dim == 3:
result = np.zeros(input.shape)
for batch in range(input.size()[0]):
for ch in range(input.size()[1]):
conv_output = self.ifftn(self.fftn(input[batch, ch].detach().cpu().numpy()) * complex_fourier_filter)
result[batch, ch] = conv_output.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
return torch.FloatTensor(result)
# print( 'max(imag) = ' + str( (abs( conv_output.imag )).max() ) )
# print( 'max(real) = ' + str( (abs( conv_output.real )).max() ) )
def _compute_input_gradient_CUDA(self,grad_output,complex_fourier_filter):
grad_output = FFTVal(grad_output, ini=1)
# print grad_output.view(-1,1).sum()
f_go = self.fftn(grad_output, self.dim, onesided=True)
f_filter_real = complex_fourier_filter[0]
f_filter_real = f_filter_real.expand_as(f_go[..., 0])
f_filter_real = torch.stack((f_filter_real, f_filter_real), -1)
f_conv = f_go * f_filter_real
dim_input = len(grad_output.shape)
dim_input_batch = dim_input - self.dim
grad_input = self.ifftn(f_conv, self.dim, onesided=True, signal_sizes=grad_output.shape[dim_input_batch::])
return FFTVal(grad_input, ini=-1)
def _compute_input_gradient_CPU(self,grad_output,complex_fourier_filter):
numpy_go = grad_output.detach().cpu().numpy()
# we use the conjugate because the assumption was that the spatial filter is real
# THe following two lines should be correct
if self.dim < 3:
grad_input_c = (self.ifftn(np.conjugate(complex_fourier_filter) * self.fftn(numpy_go)))
grad_input = grad_input_c.real
elif self.dim == 3:
grad_input = np.zeros(numpy_go.shape)
assert grad_output.dim() == 5 # to ensure the behavior correct, we avoid more than 3 dimension fftn method
for batch in range(grad_output.size()[0]):
for ch in range(grad_output.size()[1]):
grad_input_c = (
self.ifftn(np.conjugate(complex_fourier_filter) * self.fftn(numpy_go[batch, ch])))
grad_input[batch, ch] = grad_input_c.real
else:
raise ValueError("cpu fft smooth should be 1d-3d")
return torch.FloatTensor(grad_input)
def _compute_sigma_gradient_CUDA(self,input,sigma,grad_output,complex_fourier_filter,complex_fourier_xsqr_filter):
convolved_input = self._compute_convolution_CUDA(input, complex_fourier_filter)
grad_sigma = -1. / sigma * self.dim * (grad_output.detach().cpu().numpy() * convolved_input).sum()
convolved_input_xsqr = self._compute_convolution_CUDA(input, complex_fourier_xsqr_filter)
grad_sigma += 1. / (sigma ** 3) * (grad_output.detach().cpu().numpy() * convolved_input_xsqr).sum()
return grad_sigma
# TODO: gradient appears to be incorrect
def _compute_sigma_gradient_CPU(self,input,sigma,grad_output,complex_fourier_filter,complex_fourier_xsqr_filter):
convolved_input = self._compute_convolution_CPU(input,complex_fourier_filter)
grad_sigma = -1./sigma*self.dim*(grad_output.detach().cpu().numpy()*convolved_input).sum()
convolved_input_xsqr = self._compute_convolution_CPU(input,complex_fourier_xsqr_filter)
grad_sigma += 1./(sigma**3)*(grad_output.detach().cpu().numpy()*convolved_input_xsqr).sum()
return grad_sigma
[docs]class FourierSingleGaussianConvolution(FourierGaussianConvolution):
"""
pyTorch function to compute Gaussian convolutions in the Fourier domain: f = g*h.
Also allows to differentiate through the Gaussian standard deviation.
"""
def __init__(self, gaussian_fourier_filter_generator, compute_std_gradient):
"""
Constructor for the Fouier-based convolution
:param sigma: standard deviation for the filter
:param compute_std_gradient: if True computes the gradient with respect to the std, otherwise set to 0
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(FourierSingleGaussianConvolution, self).__init__(gaussian_fourier_filter_generator)
self.gaussian_fourier_filter_generator = gaussian_fourier_filter_generator
self.complex_fourier_filter = None
self.complex_fourier_xsqr_filter = None
self.input = None
self.sigma = None
self.compute_std_gradient = compute_std_gradient
[docs] def forward(self, input, sigma):
"""
Performs the Fourier-based filtering
the 3d cpu fft is not implemented in fftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because fft and fft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the rfft is used for efficiency, which means the filter should be symmetric
:param input: Image
:return: Filtered-image
"""
self.input = input
self.sigma = sigma
self.complex_fourier_filter = self.gaussian_fourier_filter_generator.get_gaussian_filters(self.sigma)[0]
self.complex_fourier_xsqr_filter = self.gaussian_fourier_filter_generator.get_gaussian_xsqr_filters(self.sigma)[0]
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# filter_imag =0, then get ac + bci
if USE_CUDA:
return self._compute_convolution_CUDA(input,self.complex_fourier_filter)
else:
return self._compute_convolution_CPU(input,self.complex_fourier_filter)
# This function has only a single output, so it gets only one gradient
[docs] def backward(self, grad_output):
"""
Computes the gradient
the 3d cpu ifft is not implemented in ifftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because ifft and ifft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the irfft is used for efficiency, which means the filter should be symmetric
:param grad_output: Gradient output of previous layer
:return: Gradient including the Fourier-based convolution
"""
# Initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
grad_input = grad_sigma = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# first compute the gradient with respect to the input
if self.needs_input_grad[0]:
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# input_imag =0, then get ac + bci
if USE_CUDA:
grad_input = self._compute_input_gradient_CUDA(grad_output,self.complex_fourier_filter)
else:
grad_input = self._compute_input_gradient_CPU(grad_output,self.complex_fourier_filter)
# now compute the gradient with respect to the standard deviation of the filter
if self.compute_std_gradient:
if self.needs_input_grad[1]:
if USE_CUDA:
grad_sigma = self._compute_sigma_gradient_CUDA(self.input,self.sigma,grad_output,self.complex_fourier_filter,self.complex_fourier_xsqr_filter)
else:
grad_sigma = self._compute_sigma_gradient_CPU(self.input,self.sigma,grad_output,self.complex_fourier_filter,self.complex_fourier_xsqr_filter)
else:
grad_sigma = torch.zeros_like(self.sigma)
# now return the computed gradients
return grad_input, grad_sigma
[docs]def fourier_single_gaussian_convolution(input, gaussian_fourier_filter_generator,sigma,compute_std_gradient):
"""
Convenience function for Fourier-based Gaussian convolutions. Make sure to use this one (instead of directly
using the class FourierGaussianConvolution). This will assure that each call generates its own instance
and hence autograd will work properly
:param input: Input image
:param gaussian_fourier_filter_generator: generator which will create Gaussian Fourier filter (and caches them)
:param sigma: standard deviation for the Gaussian filter
:param compute_std_gradient: if set to True computes the gradient otherwise sets it to 0
:return:
"""
# First braces create a Function object. Any arguments given here
# will be passed to __init__. Second braces will invoke the __call__
# operator, that will then use forward() to compute the result and
# return it.
return FourierSingleGaussianConvolution(gaussian_fourier_filter_generator,compute_std_gradient)(input,sigma)
[docs]class FourierMultiGaussianConvolution(FourierGaussianConvolution):
"""
pyTorch function to compute multi Gaussian convolutions in the Fourier domain: f = g*h.
Also allows to differentiate through the Gaussian standard deviation.
"""
def __init__(self, gaussian_fourier_filter_generator,compute_std_gradients,compute_weight_gradients):
"""
Constructor for the Fouier-based convolution
:param gaussian_fourier_filter_generator: class instance that creates and caches the Gaussian filters
:param compute_std_gradients: if set to True the gradients for std are computed, otherwise they are filled w/ zero
:param compute_weight_gradients: if set to True the gradients for weights are computed, otherwise they are filled w/ zero
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(FourierMultiGaussianConvolution, self).__init__(gaussian_fourier_filter_generator)
self.gaussian_fourier_filter_generator = gaussian_fourier_filter_generator
self.complex_fourier_filters = None
self.complex_fourier_xsqr_filters = None
self.input = None
self.weights = None
self.sigmas = None
self.nr_of_gaussians = None
self.compute_std_gradients = compute_std_gradients
self.compute_weight_gradients = compute_weight_gradients
[docs] def forward(self, input, sigmas, weights):
"""
Performs the Fourier-based filtering
the 3d cpu fft is not implemented in fftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because fft and fft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the rfft is used for efficiency, which means the filter should be symmetric
:param input: Image
:return: Filtered-image
"""
self.input = input
self.sigmas = sigmas
self.weights = weights
self.nr_of_gaussians = len(self.sigmas)
nr_of_weights = len(self.weights)
assert(self.nr_of_gaussians==nr_of_weights)
self.complex_fourier_filters = self.gaussian_fourier_filter_generator.get_gaussian_filters(self.sigmas)
self.complex_fourier_xsqr_filters = self.gaussian_fourier_filter_generator.get_gaussian_xsqr_filters(self.sigmas)
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# filter_imag =0, then get ac + bci
ret = torch.zeros_like(input)
for i in range(self.nr_of_gaussians):
if USE_CUDA:
ret += self.weights[i]*self._compute_convolution_CUDA(input,self.complex_fourier_filters[i])
else:
ret+= self.weights[i]*self._compute_convolution_CPU(input,self.complex_fourier_filters[i])
return ret
def _compute_input_gradient_CUDA_multi_gaussian(self,grad_output,complex_fourier_filters):
grad_input = torch.zeros_like(self.input)
for i in range(self.nr_of_gaussians):
grad_input += self.weights[i]*self._compute_input_gradient_CUDA(grad_output, complex_fourier_filters[i])
return grad_input
def _compute_input_gradient_CPU_multi_gaussian(self,grad_output,complex_fourier_filters):
grad_input = torch.zeros_like(self.input)
for i in range(self.nr_of_gaussians):
grad_input += self.weights[i] * self._compute_input_gradient_CPU(grad_output,complex_fourier_filters[i])
return grad_input
def _compute_sigmas_gradient_CUDA_multi_gaussian(self,input,sigmas,grad_output,complex_fourier_filters,complex_fourier_xsqr_filters):
grad_sigmas = torch.zeros_like(sigmas)
for i in range(self.nr_of_gaussians):
grad_sigmas[i] = self.weights[i] * self._compute_sigma_gradient_CUDA(input,sigmas[i],grad_output,
complex_fourier_filters[i],
complex_fourier_xsqr_filters[i])
return grad_sigmas
def _compute_sigmas_gradient_CPU_multi_gaussian(self,input,sigmas,grad_output,complex_fourier_filters,complex_fourier_xsqr_filters):
grad_sigmas = torch.zeros_like(sigmas)
for i in range(self.nr_of_gaussians):
grad_sigmas[i] = self.weights[i] * self._compute_sigma_gradient_CPU(input,sigmas[i],grad_output,
complex_fourier_filters[i],
complex_fourier_xsqr_filters[i])
return grad_sigmas
def _compute_weights_gradient_CUDA_multi_gaussian(self,input,weights,grad_output,complex_fourier_filters):
grad_weights = torch.zeros_like(weights)
for i in range(self.nr_of_gaussians):
grad_weights[i] = (grad_output*self._compute_convolution_CUDA(input,complex_fourier_filters[i])).sum()
return grad_weights
def _compute_weights_gradient_CPU_multi_gaussian(self,input,weights,grad_output,complex_fourier_filters):
grad_weights = torch.zeros_like(weights)
for i in range(self.nr_of_gaussians):
grad_weights[i] = (grad_output * self._compute_convolution_CPU(input, complex_fourier_filters[i])).sum()
return grad_weights
# This function has only a single output, so it gets only one gradient
[docs] def backward(self, grad_output):
"""
Computes the gradient
the 3d cpu ifft is not implemented in ifftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because ifft and ifft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the irfft is used for efficiency, which means the filter should be symmetric
:param grad_output: Gradient output of previous layer
:return: Gradient including the Fourier-based convolution
"""
# Initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
grad_input = grad_sigmas = grad_weights = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# first compute the gradient with respect to the input
if self.needs_input_grad[0]:
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# input_imag =0, then get ac + bci
if USE_CUDA:
grad_input = self._compute_input_gradient_CUDA_multi_gaussian(grad_output,self.complex_fourier_filters)
else:
grad_input = self._compute_input_gradient_CPU_multi_gaussian(grad_output,self.complex_fourier_filters)
# now compute the gradient with respect to the standard deviation of the filter
if self.needs_input_grad[1]:
if self.compute_std_gradients:
if USE_CUDA:
grad_sigmas = self._compute_sigmas_gradient_CUDA_multi_gaussian(self.input,self.sigmas,grad_output,self.complex_fourier_filters,self.complex_fourier_xsqr_filters)
else:
grad_sigmas = self._compute_sigmas_gradient_CPU_multi_gaussian(self.input,self.sigmas,grad_output,self.complex_fourier_filters,self.complex_fourier_xsqr_filters)
else:
grad_sigmas = torch.zeros_like(self.sigmas)
if self.needs_input_grad[2]:
if self.compute_weight_gradients:
if USE_CUDA:
grad_weights = self._compute_weights_gradient_CUDA_multi_gaussian(self.input,self.weights,grad_output,self.complex_fourier_filters)
else:
grad_weights = self._compute_weights_gradient_CPU_multi_gaussian(self.input,self.weights,grad_output,self.complex_fourier_filters)
else:
grad_weights = torch.zeros_like(self.weights)
# now return the computed gradients
#print('gsigmas: min=' + str(grad_sigmas.min()) + '; max=' + str(grad_sigmas.max()))
#print('gweight: min=' + str(grad_weights.min()) + '; max=' + str(grad_weights.max()))
#print( 'gsigmas = ' + str( grad_sigmas))
#print( 'gweight = ' + str( grad_weights))
return grad_input, grad_sigmas, grad_weights
[docs]def fourier_multi_gaussian_convolution(input, gaussian_fourier_filter_generator,sigma,weights,compute_std_gradients=True,compute_weight_gradients=True):
"""
Convenience function for Fourier-based multi Gaussian convolutions. Make sure to use this one (instead of directly
using the class FourierGaussianConvolution). This will assure that each call generates its own instance
and hence autograd will work properly
:param input: Input image
:param gaussian_fourier_filter_generator: generator which will create Gaussian Fourier filter (and caches them)
:param sigma: standard deviations for the Gaussian filter (need to be positive)
:param weights: weights for the multi-Gaussian kernel (need to sum up to one and need to be positive)
:param compute_std_gradients: if set to True computes the gradients with respect to the standard deviation
:param compute_weight_gradients: if set to True then gradients for weight are computed, otherwise they are replaced w/ zero
:return:
"""
# First braces create a Function object. Any arguments given here
# will be passed to __init__. Second braces will invoke the __call__
# operator, that will then use forward() to compute the result and
# return it.
return FourierMultiGaussianConvolution(gaussian_fourier_filter_generator,compute_std_gradients,compute_weight_gradients)(input,sigma,weights)
[docs]class FourierSetOfGaussianConvolutions(FourierGaussianConvolution):
"""
pyTorch function to compute a set of Gaussian convolutions (as in the multi-Gaussian) in the Fourier domain: f = g*h.
Also allows to differentiate through the standard deviations. THe output is not a smoothed field, but the
set of all of them. This can then be fed into a subsequent neural network for further processing.
"""
def __init__(self, gaussian_fourier_filter_generator,compute_std_gradients):
"""
Constructor for the Fouier-based convolution
:param gaussian_fourier_filter_generator: class instance that creates and caches the Gaussian filters
:param compute_std_gradients: if set to True the gradients for the stds are computed, otherwise they are filled w/ zero
"""
# we assume this is a spatial filter, F, hence conj(F(w))=F(-w)
super(FourierSetOfGaussianConvolutions, self).__init__(gaussian_fourier_filter_generator)
self.gaussian_fourier_filter_generator = gaussian_fourier_filter_generator
self.complex_fourier_filters = None
self.complex_fourier_xsqr_filters = None
self.input = None
self.sigmas = None
self.nr_of_gaussians = None
self.compute_std_gradients = compute_std_gradients
[docs] def forward(self, input, sigmas):
"""
Performs the Fourier-based filtering
the 3d cpu fft is not implemented in fftn, to avoid fusing with batch and channel, here 3d is calculated in loop
1d 2d cpu works well because fft and fft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the rfft is used for efficiency, which means the filter should be symmetric
:param input: Image
:return: Filtered-image
"""
self.input = input
self.sigmas = sigmas
self.nr_of_gaussians = len(self.sigmas)
self.complex_fourier_filters = self.gaussian_fourier_filter_generator.get_gaussian_filters(self.sigmas)
if self.compute_std_gradients:
self.complex_fourier_xsqr_filters = self.gaussian_fourier_filter_generator.get_gaussian_xsqr_filters(self.sigmas)
# TODO check if the xsqr should be put into an if statement here
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# filter_imag =0, then get ac + bci
sz = input.size()
new_sz = [self.nr_of_gaussians] + list(sz)
ret = AdaptVal(MyTensor(*new_sz))
for i in range(self.nr_of_gaussians):
if USE_CUDA:
ret[i,...] = self._compute_convolution_CUDA(input,self.complex_fourier_filters[i])
else:
ret[i,...] = self._compute_convolution_CPU(input,self.complex_fourier_filters[i])
return ret
def _compute_input_gradient_CUDA_multi_gaussian(self,grad_output,complex_fourier_filters):
grad_input = torch.zeros_like(self.input)
for i in range(self.nr_of_gaussians):
grad_input += self._compute_input_gradient_CUDA(grad_output[i,...], complex_fourier_filters[i])
return grad_input
def _compute_input_gradient_CPU_multi_gaussian(self,grad_output,complex_fourier_filters):
grad_input = torch.zeros_like(self.input)
for i in range(self.nr_of_gaussians):
grad_input += self._compute_input_gradient_CPU(grad_output[i,...],complex_fourier_filters[i])
return grad_input
def _compute_sigmas_gradient_CUDA_multi_gaussian(self,input,sigmas,grad_output,complex_fourier_filters,complex_fourier_xsqr_filters):
grad_sigmas = torch.zeros_like(sigmas)
for i in range(self.nr_of_gaussians):
grad_sigmas[i] = self._compute_sigma_gradient_CUDA(input,sigmas[i],grad_output[i,...],
complex_fourier_filters[i],
complex_fourier_xsqr_filters[i])
return grad_sigmas
def _compute_sigmas_gradient_CPU_multi_gaussian(self,input,sigmas,grad_output,complex_fourier_filters,complex_fourier_xsqr_filters):
grad_sigmas = torch.zeros_like(sigmas)
for i in range(self.nr_of_gaussians):
grad_sigmas[i] = self._compute_sigma_gradient_CPU(input,sigmas[i],grad_output[i,...],
complex_fourier_filters[i],
complex_fourier_xsqr_filters[i])
return grad_sigmas
# This function has only a single output, so it gets only one gradient
[docs] def backward(self, grad_output):
"""
Computes the gradient
the 3d cpu ifft is not implemented in ifftn, to avoid fusing with batch and channel, here 3d is calcuated in loop
1d 2d cpu works well because ifft and ifft2 is inbuilt, similarly , 1d 2d 3d gpu fft also is inbuilt
in gpu implementation, the irfft is used for efficiency, which means the filter should be symmetric
:param grad_output: Gradient output of previous layer
:return: Gradient including the Fourier-based convolution
"""
# Initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
grad_input = grad_sigmas = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# first compute the gradient with respect to the input
if self.needs_input_grad[0]:
# (a+bi)(c+di) = (ac-bd) + (bc+ad)i
# input_imag =0, then get ac + bci
if USE_CUDA:
grad_input = self._compute_input_gradient_CUDA_multi_gaussian(grad_output,self.complex_fourier_filters)
else:
grad_input = self._compute_input_gradient_CPU_multi_gaussian(grad_output,self.complex_fourier_filters)
# now compute the gradient with respect to the standard deviation of the filter
if self.needs_input_grad[1]:
if self.compute_std_gradients:
if USE_CUDA:
grad_sigmas = self._compute_sigmas_gradient_CUDA_multi_gaussian(self.input,self.sigmas,grad_output,self.complex_fourier_filters,self.complex_fourier_xsqr_filters)
else:
grad_sigmas = self._compute_sigmas_gradient_CPU_multi_gaussian(self.input,self.sigmas,grad_output,self.complex_fourier_filters,self.complex_fourier_xsqr_filters)
else:
grad_sigmas = torch.zeros_like(self.sigmas)
# now return the computed gradients
return grad_input, grad_sigmas
[docs]def fourier_set_of_gaussian_convolutions(input, gaussian_fourier_filter_generator,sigma,compute_std_gradients=False):
"""
Convenience function for Fourier-based multi Gaussian convolutions. Make sure to use this one (instead of directly
using the class FourierGaussianConvolution). This will assure that each call generates its own instance
and hence autograd will work properly
:param input: Input image
:param gaussian_fourier_filter_generator: generator which will create Gaussian Fourier filter (and caches them)
:param sigma: standard deviations for the Gaussian filter (need to be positive)
:param compute_weight_std_gradients: if set to True then gradients for standard deviation are computed, otherwise they are replaced w/ zero
:return:
"""
# First braces create a Function object. Any arguments given here
# will be passed to __init__. Second braces will invoke the __call__
# operator, that will then use forward() to compute the result and
# return it.
return FourierSetOfGaussianConvolutions(gaussian_fourier_filter_generator,compute_std_gradients)(input,sigma)
[docs]def check_fourier_conv():
"""
Convenience function to check the gradient. Fails, as pytorch's check appears to have difficulty
:return: True if analytical and numerical gradient are the same
.. todo::
The current check seems to fail in pyTorch. However, the gradient appears to be correct. Potentially an issue with the numerical gradient approximiaton.
"""
# gradcheck takes a tuple of tensor as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
# TODO: Seems to fail at the moment, check why if there are issues with the gradient
sz = np.array([20, 20], dtype='int64')
# f = np.ones(sz)
f = 1 / 400. * np.ones(sz)
dim = len(sz)
mus = np.zeros(dim)
stds = np.ones(dim)
spacing = np.ones(dim)
centered_id = utils.centered_identity_map(sz,spacing)
g = 100 * utils.compute_normalized_gaussian(centered_id, mus, stds)
FFilter,_ = create_complex_fourier_filter(g, sz)
input = AdaptVal(torch.randn([1, 1] + list(sz)))
input.requires_grad = True
test = gradcheck(FourierConvolution(FFilter), input, eps=1e-6, atol=1e-4)
print(test)
[docs]def check_run_forward_and_backward():
"""
Convenience function to check running the function forward and backward
s
:return:
"""
sz = [20, 20]
f = 1 / 400. * np.ones(sz)
FFilter,_ = create_complex_fourier_filter(f, sz, False)
input = torch.randn(sz).float()
input.requires_grad = True
fc = FourierConvolution(FFilter)(input)
# print( fc )
fc.backward(torch.randn(sz).float())
print(input.grad)