|
|
from .resnet import * |
|
|
import logging |
|
|
# Module logger. It uses the shared name 'base' rather than __name__,
# presumably so messages reach the handlers the entry script configures for
# 'base' — confirm against the logging setup.
logger = logging.getLogger('base')
|
|
|
|
|
def create_CD_model(opt):
    """Build the change-detection network described by ``opt['model']``.

    Args:
        opt: Options mapping. ``opt['model']['name']`` selects the
            architecture. The other ``opt['model']`` entries read below are
            forwarded to that architecture's constructor.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``opt['model']['name']`` is not a supported
            architecture (currently only ``'STNR'``).
        KeyError: If a required ``opt['model']`` entry is missing.
    """
    model_opt = opt['model']
    name = model_opt['name']

    if name != 'STNR':
        # Previously an unknown name fell through silently and the caller got
        # None; fail loudly so a configuration typo is caught immediately.
        raise NotImplementedError(
            f"Change-detection model [{name}] is not recognized."
        )

    # Kept as a function-local import, as in the original, so the STNR module
    # is only imported when this architecture is actually requested.
    from models.STNR import STNR as stnr

    cd_model = stnr(
        spatial_dims=model_opt['spatial_dims'],
        in_channels=model_opt['in_channels'],
        init_filters=model_opt['init_filters'],
        # The option file calls it 'n_classes'; the constructor 'out_channels'.
        out_channels=model_opt['n_classes'],
        mode=model_opt['mode'],
        conv_mode=model_opt['conv_mode'],
        up_mode=model_opt['up_mode'],
        up_conv_mode=model_opt['up_conv_mode'],
        norm=model_opt['norm'],
        blocks_down=model_opt['blocks_down'],
        blocks_up=model_opt['blocks_up'],
        # 'resdiual' (sic) is the spelling used by both the option key and the
        # STNR keyword argument; rename both together or neither.
        resdiual=model_opt['resdiual'],
        diff_abs=model_opt['diff_abs'],
        stage=model_opt['stage'],
        mamba_act=model_opt['mamba_act'],
        local_query_model=model_opt['local_query_model'],
    )

    logger.info('Change-detection model [%s] is created.', name)
    return cd_model
|
|
|