__author__ = "Simon Andermatt"
__copyright__ = "Copyright (C) 2017 Simon Andermatt"

import argparse
import logging
import os
import sys

logging.basicConfig(level=logging.INFO)

from mdgru.data.grid_collection import GridDataCollection, ThreadedGridDataCollection
# from options.parser import clean_datacollection_args
from mdgru.helper import compile_arguments, define_arguments
from mdgru.runner import Runner
def run_mdgru(args=None):
    """Executes a training, testing, or combined training and testing run for the MDGRU network."""
# Parse arguments
fullparameters = " ".join(args if args is not None else sys.argv)
parser = argparse.ArgumentParser(description="evaluate any data with given parameters", add_help=False)
pre_parameter = parser.add_argument_group('Options changing parameter. Use together with --help')
pre_parameter.add_argument('--use_pytorch', action='store_true', help='use experimental pytorch version. Only core functionality is provided')
pre_parameter.add_argument('--gpu', type=int, nargs='+', default=[0], help='set gpu ids')
    pre_parameter.add_argument('--nonthreaded', action="store_true",
                               help="disable threaded data loading (by default, data is preloaded in a background thread during training)")
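    # Dice-loss configuration; note that several of these flags also determine
    # which model class gets imported further below.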
    pre_parameter.add_argument('--dice_loss_label', default=None, type=int, nargs="+", help='labels for which the Dice losses shall be calculated')
    pre_parameter.add_argument('--dice_loss_weight', default=None, type=float, nargs="+", help='weights for the Dice losses of the individual classes; same size as dice_loss_label, or a scalar if dice_autoweighted is set. The final loss is sum(dice_loss_weight)*diceloss + (1-sum(dice_loss_weight))*crossentropy')
    pre_parameter.add_argument('--dice_autoweighted', action="store_true", help='weights each label Dice with the squared inverse of its gold standard area/volume; specify the labels with dice_loss_label. sum(dice_loss_weight) is used as the weighting between cross entropy and Dice loss')
    pre_parameter.add_argument('--dice_generalized', action="store_true", help='uses the total intersection of all labels over the total sum of all labels, instead of linearly combined per-class Dice losses')
    pre_parameter.add_argument('--dice_cc', action='store_true', help='computes the Dice loss per true connected component (binary segmentation only)')
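    # Pre-parse only the flags above: they decide which backend and model classes
    # to import, and those classes in turn contribute the remaining options.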
pre_args, _ = parser.parse_known_args(args=args)
    parser.add_argument('-h', '--help', action='store_true', help='print this help message')
# Set environment flag(s) and finally import the classes that depend upon them
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(g) for g in pre_args.gpu])
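    # Restricting CUDA_VISIBLE_DEVICES here, before the framework imports below,
    # is what makes the --gpu selection effective.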
if pre_args.use_pytorch:
if pre_args.dice_cc:
from mdgru.model_pytorch.mdgru_classification import MDGRUClassificationCC as modelcls
else:
from mdgru.model_pytorch.mdgru_classification import MDGRUClassification as modelcls
from mdgru.eval.torch import SupervisedEvaluationTorch as evalcls
else:
if pre_args.dice_generalized:
from mdgru.model.mdgru_classification import MDGRUClassificationWithGeneralizedDiceLoss as modelcls
        elif pre_args.dice_loss_label is not None or pre_args.dice_autoweighted:
from mdgru.model.mdgru_classification import MDGRUClassificationWithDiceLoss as modelcls
else:
from mdgru.model.mdgru_classification import MDGRUClassification as modelcls
from mdgru.eval.tf import SupervisedEvaluationTensorflow as evalcls
# Set the necessary classes
# dc = GridDataCollection
tdc = GridDataCollection if pre_args.nonthreaded else ThreadedGridDataCollection
define_arguments(modelcls, parser.add_argument_group('Model Parameters'))
define_arguments(evalcls, parser.add_argument_group('Evaluation Parameters'))
define_arguments(Runner, parser.add_argument_group('Runner Parameters'))
define_arguments(tdc, parser.add_argument_group('Data Parameters'))
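    # Each of these classes declares its own command line options; define_arguments
    # registers them on the shared parser so that --help covers all components.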
args = parser.parse_args(args=args)
if args.help:
parser.print_help()
return
    if not args.use_pytorch and args.gpubound != 1:
        modelcls.set_allowed_gpu_memory_fraction(args.gpubound)
# Set up datacollections
# args_tr, args_val, args_te = clean_datacollection_args(args)
# Set up model and evaluation
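    # compile_arguments picks out of the flat argument namespace the subset of
    # options each class declared, so every component receives only its own kwargs.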
kw = vars(args)
args_eval, _ = compile_arguments(evalcls, kw, True, keep_entries=True)
args_model, _ = compile_arguments(modelcls, kw, True, keep_entries=True)
args_data, _ = compile_arguments(tdc, kw, True, keep_entries=True)
args_eval.update(args_model)
args_eval.update(args_data)
    if not args.use_pytorch and args.checkpointfiles is not None:
        args_eval['namespace'] = modelcls.get_model_name_from_ckpt(args.checkpointfiles[0])
args_eval['channels_first'] = args.use_pytorch
# if args_tr is not None:
# traindc = tdc(**args_tr)
# if args_val is not None:
# valdc = tdc(**args_val)
# if args_te is not None:
# testdc = dc(**args_te)
# if args.only_test: #FIXME: this is not the smartest way of doing it, make sure that you can allow missing entries in this dict!
# datadict = {"train": testdc, "validation": testdc, "test": testdc}
# elif args.only_train:
# datadict = {"train": traindc, "validation": valdc, "test": valdc}
# else:
# datadict = {"train": traindc, "validation": valdc, "test": testdc}
ev = evalcls(modelcls, tdc, args_eval)
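    # The evaluation object ties model, data collection and options together;
    # the Runner below drives the actual training/testing run.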
# Set up runner
args_runner, _ = compile_arguments(Runner, kw, True, keep_entries=True)
args_runner.update({
"experimentloc": os.path.join(args.datapath, 'experiments'),
"fullparameters": fullparameters,
# "estimatefilenames": optionname
})
runner = Runner(ev, **args_runner)
# Run computation
return runner.run()
if __name__ == "__main__":
run_mdgru()
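
# Illustrative invocation, assuming this module is saved as run_mdgru.py (most
# flags are registered dynamically by the model/eval/runner/data classes, so
# run with --help for the authoritative list):
#
#   python run_mdgru.py --datapath /path/to/data --gpu 0 --nonthreaded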