Source code for mdgru.model_pytorch.mdrnn

__author__ = "Simon Andermatt"
__copyright__ = "Copyright (C) 2017 Simon Andermatt"

import torch as th
from mdgru.helper import argget, generate_defaults_info
from mdgru.helper import compile_arguments
from .mdgru import MDRNN


[docs]class MDGRUBlock(th.nn.Module): """Convenience class combining attributes to be used for multiple MDRNN and voxel-wise fully connected layers. :param num_spatial_dims: Number of spatial dimensions to consider :param dropout: Dropout rate provided as "keep" rate :param num_input: Nuber of input units/channels :param num_hidden: number of hidden units, output units of the MDRNN :param num_output: number of output units of the voxel-wise fully connected layer (Can be None -> no voxel-wise fully connected layer) """ _defaults = { "resmdgru": {'value': False, 'help': 'Add a residual connection from each mdgru input to its output, possibly homogenizing dimensions using one 1 conv layer'}, "vwfc_activation": th.nn.Tanh, "noactivation": False, "name": None, } def __init__(self, num_spatial_dims, dropout, num_input, num_hidden, num_output, kw): super(MDGRUBlock, self).__init__() mdrnn_net_kw, kw = compile_arguments(MDGRUBlock, kw, transitive=False) for k, v in mdrnn_net_kw.items(): setattr(self, k, v) self.mdrnn_kw, kw = compile_arguments(MDRNN, kw, transitive=True) self.crnn_kw, kw = compile_arguments(self.mdrnn_kw['crnn_class'], kw, transitive=True) spatial_dimensions = argget(kw, "dimensions", None) if spatial_dimensions is None: spatial_dimensions = [i for i in range(num_spatial_dims)] mdrnn_kw = {} mdrnn_kw.update(self.mdrnn_kw) mdrnn_kw.update(self.crnn_kw) mdrnn_kw.update(kw) mdrnn_kw["num_hidden"] = num_hidden mdrnn_kw["num_input"] = num_input mdrnn_kw["name"] = "mdgru" model = [MDRNN(dropout, spatial_dimensions, mdrnn_kw)] if num_spatial_dims == 2: convop = th.nn.Conv2d kernel = [1, 1] elif num_spatial_dims == 3: convop = th.nn.Conv3d kernel = [1, 1, 1] else: raise Exception('pytorch cannot handle more than 3 dimensions for convolution') if num_output is not None: model += [convop(num_hidden, num_output, kernel)] if not self.noactivation: model += [self.vwfc_activation()] self.model = th.nn.Sequential(*model)
[docs] def forward(self, input): return self.model.forward(input)
generate_defaults_info(MDGRUBlock)