Source code for mermaid.model_evaluation

import os
import torch
from . import model_factory as MF
from . import visualize_registration_results as vizReg
from . import image_sampling as IS
from .data_wrapper import AdaptVal
from .data_wrapper import USE_CUDA
from . import utils
from . import module_parameters as pars


[docs]def evaluate_model(ISource_in, ITarget_in, sz, spacing, model_name=None, use_map=None, compute_inverse_map=False, map_low_res_factor=None, compute_similarity_measure_at_low_res=None, spline_order=None, individual_parameters=None,shared_parameters=None,params=None,extra_info=None,visualize=True,visual_param=None,given_weight=False): """ #todo: Support initial maps which are not identity :param ISource_in: source image (BxCxXxYxZ format) :param ITarget_in: target image (BxCxXxYxZ format) :param sz: size of the images (BxCxXxYxZ format) :param spacing: spacing for the images :param model_name: name of the desired model (string) :param use_map: if set to True then map-based mode is used :param compute_inverse_map: if set to True the inverse map will be computed :param map_low_res_factor: if set to None then computations will be at full resolution, otherwise at a fraction of the resolution :param compute_similarity_measure_at_low_res: :param spline_order: desired spline order for the sampler :param individual_parameters: individual registration parameters :param shared_parameters: shared registration parameters :param params: parameter dictionary (of model_dictionary type) which configures the model :param visualize: if set to True results will be visualized :return: returns a tuple (I_warped,phi,phi_inverse,model_dictionary), here I_warped = I_source\circ\phi, and phi_inverse is the inverse of phi; model_dictionary contains various intermediate results """ ISource = AdaptVal(ISource_in) ITarget = AdaptVal(ITarget_in) if params is None: print('INFO: WARNING: no params specified, creating empty parameter dictionary') params = pars.ParameterDict() if model_name is None: model_name = params['model']['registration_model']['type'] if use_map is None: use_map = params['model']['deformation']['use_map'] if map_low_res_factor is None: map_low_res_factor = params['model']['deformation'][('map_low_res_factor', None, 'low_res_factor')] if compute_similarity_measure_at_low_res is None: compute_similarity_measure_at_low_res = params['model']['deformation'][('compute_similarity_measure_at_low_res', False, 'to compute Sim at lower resolution')] if spline_order is None: spline_order = params['model']['registration_model'][('spline_order', 1, 'Spline interpolation order; 1 is linear interpolation (default); 3 is cubic spline')] # set some initial default values lowResSize = None lowResSpacing = None lowResISource = None lowResIdentityMap = None # compute low-res versions if needed if map_low_res_factor is not None: lowResSize = utils._get_low_res_size_from_size(sz, map_low_res_factor) lowResSpacing = utils._get_low_res_spacing_from_spacing(spacing, sz, lowResSize) lowResISource = utils._compute_low_res_image(ISource, spacing, lowResSize,spline_order) # todo: can be removed to save memory; is more experimental at this point lowResITarget = utils._compute_low_res_image(ITarget, spacing, lowResSize,spline_order) # computes model at a lower resolution than the image similarity if compute_similarity_measure_at_low_res: mf = MF.ModelFactory(lowResSize, lowResSpacing, lowResSize, lowResSpacing) else: mf = MF.ModelFactory(sz, spacing, lowResSize, lowResSpacing) else: # computes model and similarity at the same resolution mf = MF.ModelFactory(sz, spacing, sz, spacing) model, criterion = mf.create_registration_model(model_name, params['model'], compute_inverse_map=compute_inverse_map) # set it to evaluation mode model.eval() print(model) if use_map: # create the identity map [-1,1]^d, since we will use a map-based implementation id = utils.identity_map_multiN(sz, spacing) identityMap = AdaptVal(torch.from_numpy(id)) if map_low_res_factor is not None: # create a lower resolution map for the computations lowres_id = utils.identity_map_multiN(lowResSize, lowResSpacing) lowResIdentityMap = AdaptVal(torch.from_numpy(lowres_id)) sampler = IS.ResampleImage() if USE_CUDA: model = model.cuda() dictionary_to_pass_to_integrator = dict() if map_low_res_factor is not None: dictionary_to_pass_to_integrator['I0'] = lowResISource dictionary_to_pass_to_integrator['I1'] = lowResITarget else: dictionary_to_pass_to_integrator['I0'] = ISource dictionary_to_pass_to_integrator['I1'] = ITarget model.set_dictionary_to_pass_to_integrator(dictionary_to_pass_to_integrator) if shared_parameters is not None: model.load_shared_state_dict(shared_parameters) model_pars = utils.individual_parameters_to_model_parameters(individual_parameters) model.set_individual_registration_parameters(model_pars) if given_weight: model.m.data = AdaptVal(individual_parameters['m']) if 'local_weights'in individual_parameters and individual_parameters['local_weights'] is not None: model.local_weights.data= AdaptVal(individual_parameters['local_weights']) opt_variables = {'iter': 0, 'epoch': 0,'extra_info':extra_info,'over_scale_iter_count':0} # now let's run the model rec_IWarped, rec_phiWarped, rec_phiInverseWarped = evaluate_model_low_level_interface( model=model, I_source=ISource, opt_variables=opt_variables, use_map=use_map, initial_map=identityMap, compute_inverse_map=compute_inverse_map, initial_inverse_map=identityMap, map_low_res_factor=map_low_res_factor, sampler=sampler, low_res_spacing=lowResSpacing, spline_order=spline_order, low_res_I_source=lowResISource, low_res_initial_map=lowResIdentityMap, low_res_initial_inverse_map=lowResIdentityMap, compute_similarity_measure_at_low_res=compute_similarity_measure_at_low_res) if use_map: rec_IWarped = utils.compute_warped_image_multiNC(ISource, rec_phiWarped, spacing, spline_order, zero_boundary=True) if use_map and map_low_res_factor is not None: vizImage, vizName = model.get_parameter_image_and_name_to_visualize(lowResISource) else: vizImage, vizName = model.get_parameter_image_and_name_to_visualize(ISource) if use_map: phi_or_warped_image = rec_phiWarped else: phi_or_warped_image = rec_IWarped if visual_param is None: visual_param = {} visual_param['visualize'] = visualize visual_param['save_fig'] = False visual_param['save_fig_num'] = -1 else: visual_param['visualize'] = visualize save_fig_path = visual_param['save_fig_path'] visual_param['save_fig_path_byname'] = os.path.join(save_fig_path, 'byname') visual_param['save_fig_path_byiter'] = os.path.join(save_fig_path, 'byiter') visual_param['iter'] = 'scale_' + str(0) + '_iter_' + str(0) if use_map: if compute_similarity_measure_at_low_res: I1Warped = utils.compute_warped_image_multiNC(lowResISource, phi_or_warped_image, lowResSpacing, spline_order) if visualize or visual_param['save_fig']: vizReg.show_current_images(iter=iter, iS=lowResISource, iT=lowResITarget, iW=I1Warped, vizImages=vizImage, vizName=vizName, phiWarped=phi_or_warped_image, visual_param=visual_param) else: I1Warped = utils.compute_warped_image_multiNC(ISource, phi_or_warped_image, spacing, spline_order) if visualize or visual_param['save_fig']: vizReg.show_current_images(iter=iter, iS=ISource, iT=ITarget, iW=I1Warped, vizImages=vizImage, vizName=vizName, phiWarped=phi_or_warped_image, visual_param=visual_param) else: if visualize or visual_param['save_fig']: vizReg.show_current_images(iter=iter, iS=ISource, iT=ITarget, iW=phi_or_warped_image, vizImages=vizImage, vizName=vizName, phiWarped=None, visual_param=visual_param) dictionary_to_pass_to_smoother = dict() if map_low_res_factor is not None: dictionary_to_pass_to_smoother['I'] = lowResISource dictionary_to_pass_to_smoother['I0'] = lowResISource dictionary_to_pass_to_smoother['I1'] = lowResITarget else: dictionary_to_pass_to_smoother['I'] = ISource dictionary_to_pass_to_smoother['I0'] = ISource dictionary_to_pass_to_smoother['I1'] = ITarget variables_from_forward_model = model.get_variables_to_transfer_to_loss_function() smoother = variables_from_forward_model['smoother'] try: smoother.set_debug_retain_computed_local_weights(True) except: pass model_pars = model.get_registration_parameters() if 'lam' in model_pars: m = utils.compute_vector_momentum_from_scalar_momentum_multiNC(model_pars['lam'], lowResISource, lowResSize,lowResSpacing) elif 'm' in model_pars: m = model_pars['m'] else: raise ValueError('Expected a scalar or a vector momentum in model (use SVF for example)') if not given_weight: v = smoother.smooth(m, None, dictionary_to_pass_to_smoother) local_weights = smoother.get_debug_computed_local_weights() local_pre_weights = smoother.get_debug_computed_local_pre_weights() else: v = None local_weights=None local_pre_weights=None try: default_multi_gaussian_weights = smoother.get_default_multi_gaussian_weights() gaussian_stds=smoother.get_gaussian_stds() except: default_multi_gaussian_weights=None gaussian_stds=None model_dict = dict() model_dict['use_map'] = use_map model_dict['lowResISource'] = lowResISource model_dict['lowResITarget'] = lowResITarget model_dict['lowResSpacing'] = lowResSpacing model_dict['lowResSize'] = lowResSize model_dict['local_weights'] = local_weights model_dict['local_pre_weights'] = local_pre_weights model_dict['default_multi_gaussian_weights'] = default_multi_gaussian_weights model_dict['stds'] =gaussian_stds model_dict['model'] = model if 'lam' in model_pars: model_dict['lam'] = model_pars['lam'] model_dict['m'] = m model_dict['v'] = v if use_map: model_dict['id'] = identityMap if map_low_res_factor is not None: model_dict['map_low_res_factor'] = map_low_res_factor model_dict['low_res_id'] = lowResIdentityMap return rec_IWarped,rec_phiWarped, rec_phiInverseWarped, model_dict
[docs]def evaluate_model_low_level_interface(model,I_source,opt_variables=None,use_map=False,initial_map=None,compute_inverse_map=False,initial_inverse_map=None, map_low_res_factor=None, sampler=None,low_res_spacing=None,spline_order=1, low_res_I_source=None,low_res_initial_map=None,low_res_initial_inverse_map=None,compute_similarity_measure_at_low_res=False): """ Evaluates a registration model. Core functionality for optimizer. Use evaluate_model for a convenience implementation which recomputes settings on the fly :param model: registration model :param I_source: source image (may not be used for map-based approaches) :param opt_variables: dictionary to be passed to an optimizer or here the evaluation routine (e.g., {'iter': self.iter_count,'epoch': self.current_epoch}) :param use_map: if set to True then map-based mode is used :param initial_map: initial full-resolution map (will in most cases be the identity) :param compute_inverse_map: if set to True the inverse map will be computed :param initial_inverse_map: initial inverse map (will in most cases be the identity) :param map_low_res_factor: if set to None then computations will be at full resolution, otherwise at a fraction of the resolution :param sampler: sampler which takes care of upsampling maps from their low-resolution variants (when map_low_res_factor<1.) :param low_res_spacing: spacing of the low res map/image :param spline_order: desired spline order for the sampler :param low_res_I_source: low resolution source image :param low_res_initial_map: low resolution version of the initial map :param low_res_initial_inverse_map: low resolution version of the initial inverse map :param compute_similarity_measure_at_low_res: if set to True the similarity measure is also evaluated at low resolution (otherwise at full resolution) :return: returns a tuple (I_warped,phi,phi_inverse), here I_warped = I_source\circ\phi, and phi_inverse is the inverse of phi IMPORTANT: note that \phi is the map that maps from source to target image; I_warped is None for map_based solution, phi is None for image-based solution; phi_inverse is None if inverse is not computed """ # do some checks to make sure everything is properly defined if use_map: if initial_map is None: raise ValueError('initial_map needs to be defined when using map mode') if compute_inverse_map: if initial_inverse_map is None: raise ValueError('inital_inverse_map needs to be defined when using map mode and requesting inverse map') if map_low_res_factor is not None: if sampler is None: raise ValueError('sampler needs to be specified when using map mode and computing at lower resolution') if low_res_spacing is None: raise ValueError('low_res_spacing needs to be speficied when using map mode and computing at lower resolution') if low_res_I_source is None: raise ValueError('low_res_I_source needs to be defined when using map mode and computing at lower resolution') if low_res_initial_map is None: raise ValueError('low_res_initial_map needs to be defined when using map mode and computing at lower resolution') if compute_inverse_map: if low_res_initial_inverse_map is None: raise ValueError('low_res_initial_inverse_map needs to be defined when using map mode at lower resolution and requesting inverse map') else: # not using map if compute_inverse_map: raise ValueError('Cannot compute inverse map in non-map mode') if compute_similarity_measure_at_low_res: raise ValueError('Low res similarity measure computations are only supported in map mode') # actual evaluation code starts here rec_phiWarped = None rec_phiInverseWarped = None rec_IWarped = None if use_map: if map_low_res_factor is not None: if compute_similarity_measure_at_low_res: ret = model(low_res_initial_map, low_res_I_source, phi_inv=low_res_initial_inverse_map, variables_from_optimizer=opt_variables) if compute_inverse_map: if type(ret) == tuple: # if it is a tuple it is returning the inverse rec_phiWarped = ret[0] rec_phiInverseWarped = ret[1] else: rec_phiWarped = ret else: rec_phiWarped = ret else: ret = model(low_res_initial_map, low_res_I_source, phi_inv=low_res_initial_inverse_map, variables_from_optimizer=opt_variables) if compute_inverse_map: if type(ret) == tuple: rec_tmp = ret[0] rec_inv_tmp = ret[1] else: rec_tmp = ret rec_inv_tmp = None else: rec_tmp = ret # now upsample to correct resolution desiredSz = initial_map.size()[2::] rec_phiWarped, _ = sampler.upsample_image_to_size(rec_tmp, low_res_spacing, desiredSz, spline_order, zero_boundary=False) if compute_inverse_map and rec_inv_tmp is not None: rec_phiInverseWarped, _ = sampler.upsample_image_to_size(rec_inv_tmp, low_res_spacing, desiredSz, spline_order, zero_boundary=False) else: rec_phiInverseWarped = None else: # computing at full resolution ret = model(initial_map, I_source, phi_inv=initial_inverse_map, variables_from_optimizer=opt_variables) if compute_inverse_map: if type(ret) == tuple: rec_phiWarped = ret[0] rec_phiInverseWarped = ret[1] else: rec_phiWarped = ret else: rec_phiWarped = ret else: rec_IWarped = model(I_source, opt_variables) return rec_IWarped,rec_phiWarped,rec_phiInverseWarped