import warnings

# Filter transformers' FutureWarning before importing any module that (transitively) loads it.
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")

import argparse
import os
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete
from monai.utils.enums import MetricReduction

from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from trainer import run_training
from utils.data_utils import get_loader
from utils.textswin_unetr import TextSwinUNETR
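
# Command-line options. Defaults follow the Swin UNETR BraTS recipe: four input MRI
# modalities, three output channels (TC/WT/ET), and a 128^3 sliding-window ROI.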
parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline for the TextBraTS image-text dataset")
parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
parser.add_argument("--logdir", default="TextBraTS", type=str, help="directory to save the tensorboard logs")
parser.add_argument("--fold", default=0, type=int, help="data fold, 0 for validation and 1 for training")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--json_list", default="./Train.json", type=str, help="dataset json file")
parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
parser.add_argument("--max_epochs", default=200, type=int, help="max number of training epochs")
parser.add_argument("--batch_size", default=2, type=int, help="batch size")
parser.add_argument("--sw_batch_size", default=4, type=int, help="sliding window batch size")
parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
parser.add_argument("--val_every", default=1, type=int, help="validation frequency")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
parser.add_argument("--norm_name", default="instance", type=str, help="normalization name")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--cache_dataset", action="store_true", help="use monai CacheDataset class")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs")
parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint")
parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_ssl_pretrained", action="store_true", help="use SSL pretrained ckpt")
parser.add_argument("--pretrained_dir", default="./runs/TextBraTS/", type=str, help="pretrained checkpoint directory")
parser.add_argument("--squared_dice", action="store_true", help="use squared Dice")
parser.add_argument("--seed", default=23, type=int, help="random seed")
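

# Entry point: seed all RNGs for reproducibility, then either spawn one worker per
# local GPU (distributed) or run a single worker on GPU 0.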
def main():
    args = parser.parse_args()
    args.amp = not args.noamp
    args.logdir = "./runs/" + args.logdir
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.distributed:
        torch.cuda.manual_seed_all(args.seed)
        args.ngpus_per_node = torch.cuda.device_count()
        print("Found total gpus", args.ngpus_per_node)
        args.world_size = args.ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
    else:
        torch.cuda.manual_seed(args.seed)
        main_worker(gpu=0, args=args)
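

# Per-process worker: builds the data loaders, model, loss, optimizer, and scheduler,
# then hands everything to run_training.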
def main_worker(gpu, args):
    if args.distributed:
        torch.multiprocessing.set_start_method("fork", force=True)
    np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
    args.gpu = gpu
    if args.distributed:
        args.rank = args.rank * args.ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
        )
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True
    args.test_mode = False
    loader = get_loader(args)
    print(args.rank, " gpu", args.gpu)
    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    pretrained_pth = os.path.join(pretrained_dir, model_name)
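
    # Text-conditioned Swin UNETR; text_dim=768 matches the hidden size of
    # BERT-base-style text encoders.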
    model = TextSwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        use_checkpoint=args.use_checkpoint,
        text_dim=768,
    )
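
    # Warm-start from a full pretrained checkpoint, stripping any DataParallel
    # "module." prefix so the keys match this (unwrapped) model.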
    if args.resume_ckpt:
        model_dict = torch.load(pretrained_pth)["state_dict"]
        for key in list(model_dict.keys()):
            model_dict[key.replace("module.", "")] = model_dict.pop(key)
        model.load_state_dict(model_dict, strict=True)
        print("Using pretrained weights")
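
    # Optionally initialize the Swin encoder from self-supervised weights. The path
    # below is hard-coded; point it at your local copy of model_swinvit.pt. Keys are
    # remapped to this model's naming, and strict=False skips any unmatched keys.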
    if args.use_ssl_pretrained:
        try:
            model_dict = torch.load("/media/iipl/disk1/swinunetr/model_swinvit.pt", weights_only=True)
            state_dict = model_dict["state_dict"]
            for key in list(state_dict.keys()):
                state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
            for key in list(state_dict.keys()):
                if "fc" in key:
                    state_dict[key.replace("fc", "linear")] = state_dict.pop(key)
                if "patch_embed" in key:
                    state_dict[key.replace("patch_embed", "")] = state_dict.pop(key)
            model.load_state_dict(state_dict, strict=False)
        except (FileNotFoundError, ValueError) as e:
            raise ValueError(
                "Self-supervised pretrained weights not available for " + str(args.pretrained_model_name)
            ) from e
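
    # BraTS regions (TC/WT/ET) overlap, so the task is treated as multi-label:
    # sigmoid activations with per-channel Dice rather than mutually exclusive softmax.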
    if args.squared_dice:
        dice_loss = DiceLoss(
            to_onehot_y=False, sigmoid=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
        )
    else:
        dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True)
    post_sigmoid = Activations(sigmoid=True)
    post_pred = AsDiscrete(argmax=False, threshold=0.5)
    dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total parameters count", pytorch_total_params)

    best_acc = 0
    start_epoch = 0
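
    # Resume from a training checkpoint: strip any "backbone." prefix from the saved
    # keys and restore the epoch counter and best validation accuracy.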
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        new_state_dict = OrderedDict()
        for k, v in checkpoint["state_dict"].items():
            new_state_dict[k.replace("backbone.", "")] = v
        model.load_state_dict(new_state_dict, strict=False)
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)
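
    # Wrap the model for multi-GPU training; batch norm (if selected) is converted to
    # SyncBatchNorm so statistics are synchronized across processes.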
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        if args.norm_name == "batch":
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters=False
        )
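
    # Optimizer selection; --reg_weight is passed through as weight decay.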
    if args.optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
    elif args.optim_name == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
    elif args.optim_name == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight
        )
    else:
        raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))
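
    # Learning-rate schedule: linear warmup followed by cosine annealing by default.
    # When resuming under cosine_anneal, the scheduler is fast-forwarded to the
    # checkpoint's epoch.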
    if args.lrschedule == "warmup_cosine":
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs
        )
    elif args.lrschedule == "cosine_anneal":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs)
        if args.checkpoint is not None:
            scheduler.step(epoch=start_epoch)
    else:
        scheduler = None
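
    # Per-class validation Dice labels: tumor core, whole tumor, enhancing tumor.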
    semantic_classes = ["Dice_Val_TC", "Dice_Val_WT", "Dice_Val_ET"]

    accuracy = run_training(
        model=model,
        train_loader=loader[0],
        val_loader=loader[1],
        optimizer=optimizer,
        loss_func=dice_loss,
        acc_func=dice_acc,
        args=args,
        scheduler=scheduler,
        start_epoch=start_epoch,
        post_sigmoid=post_sigmoid,
        post_pred=post_pred,
        semantic_classes=semantic_classes,
    )
    return accuracy


if __name__ == "__main__":
    main()