# Source code for mermaid.data_wrapper

from __future__ import absolute_import
import torch
from mermaid.config_parser import CUDA_ON, USE_FLOAT16

# ----------------- global setting ----------------------------------------
# GPU code paths are used only when the config enables CUDA *and* a CUDA
# device is actually available at import time.
USE_CUDA = CUDA_ON and torch.cuda.is_available()

# --------------------   My Tensor -------------------------
# Wrapped tensor constructors that adapt to GPU/CPU and, on GPU only, to
# float16:
#   MyTensor     -- floating-point tensor type; half precision only when
#                   both USE_CUDA and USE_FLOAT16 are set
#   MyLongTensor -- long (integer) tensor type on the matching device
if not USE_CUDA:
    MyTensor = torch.FloatTensor
    MyLongTensor = torch.LongTensor
else:
    MyLongTensor = torch.cuda.LongTensor
    MyTensor = torch.cuda.HalfTensor if USE_FLOAT16 else torch.cuda.FloatTensor

# ------------------  AdaptVal --------------------------
# Adaptive wrapper: adapts the data type and device of an existing
# Tensor/Variable to the module-level configuration.
def AdaptVal(x):
    """Adapt ``x`` to the configured device (gpu/cpu) and precision (float32/16).

    float16 is not recommended because it is not numerically stable.

    :param x: tensor to adapt
    :return: ``x`` itself when CUDA is not in use; otherwise ``x.cuda()``,
        additionally converted with ``.half()`` when ``USE_FLOAT16`` is set
    """
    if not USE_CUDA:
        # CPU runs keep the tensor untouched (no float16 on CPU).
        return x
    on_gpu = x.cuda()
    return on_gpu.half() if USE_FLOAT16 else on_gpu
# -------------------- STN ------------------------------
# Tensor type specific to the STN function. Its CUDA implementation works in
# float32, so this is a float32 type even when USE_FLOAT16 is set.
STNTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
def STNVal(x, ini):
    """Convert ``x`` around a call into the float32-only CUDA STN function.

    The CUDA version of the STN is written in float32. When running on GPU
    with float16 enabled, the input is first converted to float32
    (``ini=1``), and the output is converted back to half precision
    (``ini=-1``). In every other configuration ``x`` is returned unchanged.

    :param x: tensor to convert
    :param ini: ``1`` to convert to float32 before the STN call, ``-1`` to
        convert back to half precision afterwards
    :return: the converted tensor, or ``x`` itself when no conversion applies
    :raises ValueError: if ``ini`` is neither ``1`` nor ``-1``
    """
    # Validate in every configuration so a bad ``ini`` is caught in
    # float32/CPU runs as well, not only once float16 is switched on.
    if ini not in (1, -1):
        raise ValueError('ini should be 1 or -1')
    if not (USE_CUDA and USE_FLOAT16):
        return x
    return x.float() if ini == 1 else x.half()


# ------------------ FFT ----------------------------
# Specific to the FFT function: it needs the same float32 conversion as the
# STN function, so it reuses STNVal directly.
FFTVal = STNVal