# Copyright 2020 - 2022 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic") import argparse import os import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.parallel import torch.utils.data.distributed from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from trainer import run_training from utils.data_utils import get_loader from monai.losses import DiceLoss from monai.metrics import DiceMetric from utils.textswin_unetr import TextSwinUNETR from monai.transforms import Activations, AsDiscrete, Compose from monai.utils.enums import MetricReduction import random parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline for TextBRATS image-text dataset") parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint") parser.add_argument("--logdir", default="TextBraTS", type=str, help="directory to save the tensorboard logs") parser.add_argument("--fold", default=0, type=int, help="data fold, 0 for validation and 1 for training") parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name") parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory") parser.add_argument("--json_list", default="./Train.json", type=str, help="dataset json file") parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training") parser.add_argument("--max_epochs", default=200, type=int, help="max number of training epochs") parser.add_argument("--batch_size", default=2, type=int, help="number of batch size") parser.add_argument("--sw_batch_size", default=4, type=int, help="number of sliding window batch size") parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate") parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm") parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight") parser.add_argument("--momentum", default=0.99, type=float, help="momentum") parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") parser.add_argument("--val_every", default=1, type=int, help="validation frequency") parser.add_argument("--distributed", action="store_true", help="start distributed training") parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url") parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") parser.add_argument("--norm_name", default="instance", type=str, help="normalization name") parser.add_argument("--workers", default=8, type=int, help="number of workers") parser.add_argument("--feature_size", default=48, type=int, help="feature size") parser.add_argument("--in_channels", default=4, type=int, help="number of input channels") parser.add_argument("--out_channels", default=3, type=int, help="number of output channels") parser.add_argument("--cache_dataset", action="store_true", help="use monai Dataset class") parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged") parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged") parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged") parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged") parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction") parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction") parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction") parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction") parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction") parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction") parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate") parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate") parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability") parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability") parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap") parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler") parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs") parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint") parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan") parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero") parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory") parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data") parser.add_argument("--use_ssl_pretrained", action="store_true", help="use SSL pretrained ckpt") parser.add_argument( "--pretrained_dir", default="./runs/TextBraTS/", type=str, help="pretrained checkpoint directory", ) parser.add_argument("--squared_dice", action="store_true", help="use squared Dice") parser.add_argument("--seed", type=int, default=23,help="use random seed") def main(): args = parser.parse_args() args.amp = not args.noamp args.logdir = "./runs/" + args.logdir random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if args.distributed: torch.cuda.manual_seed_all(args.seed) args.ngpus_per_node = torch.cuda.device_count() print("Found total gpus", args.ngpus_per_node) args.world_size = args.ngpus_per_node * args.world_size mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,)) else: torch.cuda.manual_seed(args.seed) main_worker(gpu=0, args=args) def main_worker(gpu, args): if args.distributed: torch.multiprocessing.set_start_method("fork", force=True) np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True) args.gpu = gpu if args.distributed: args.rank = args.rank * args.ngpus_per_node + gpu dist.init_process_group( backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank ) torch.cuda.set_device(args.gpu) torch.backends.cudnn.benchmark = True args.test_mode = False loader = get_loader(args) print(args.rank, " gpu", args.gpu) if args.rank == 0: print("Batch size is:", args.batch_size, "epochs", args.max_epochs) pretrained_dir = args.pretrained_dir model_name = args.pretrained_model_name pretrained_pth = os.path.join(pretrained_dir, model_name) model = TextSwinUNETR( img_size=(args.roi_x, args.roi_y, args.roi_z), in_channels=args.in_channels, out_channels=args.out_channels, feature_size=args.feature_size, use_checkpoint=args.use_checkpoint, text_dim=768, ) if args.resume_ckpt: model_dict = torch.load(pretrained_pth)["state_dict"] for key in list(model_dict.keys()): model_dict[key.replace("module.", "")] = model_dict.pop(key) model.load_state_dict(model_dict,strict=True) print("Using pretrained weights") if args.use_ssl_pretrained: try: model_dict = torch.load("/media/iipl/disk1/swinunetr/model_swinvit.pt",weights_only=True) state_dict = model_dict["state_dict"] # fix potential differences in state dict keys from pre-training to # fine-tuning for key in list(state_dict.keys()): state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key) for key in list(state_dict.keys()): if "fc" in key: state_dict[key.replace("fc","linear")] = state_dict.pop(key) if "patch_embed" in key: state_dict[key.replace("patch_embed","")] = state_dict.pop(key) model.load_state_dict(state_dict, strict=False) except ValueError: raise ValueError("Self-supervised pre-trained weights not available for" + str(args.model_name)) if args.squared_dice: dice_loss = DiceLoss( to_onehot_y=False, sigmoid=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr ) else: dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True) post_sigmoid = Activations(sigmoid=True) post_pred = AsDiscrete(argmax=False, logit_thresh=0.5) dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True) pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Total parameters count", pytorch_total_params) best_acc = 0 start_epoch = 0 if args.checkpoint is not None: checkpoint = torch.load(args.checkpoint, map_location="cpu") from collections import OrderedDict new_state_dict = OrderedDict() for k, v in checkpoint["state_dict"].items(): new_state_dict[k.replace("backbone.", "")] = v model.load_state_dict(new_state_dict, strict=False) if "epoch" in checkpoint: start_epoch = checkpoint["epoch"] if "best_acc" in checkpoint: best_acc = checkpoint["best_acc"] print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc)) model.cuda(args.gpu) if args.distributed: torch.cuda.set_device(args.gpu) if args.norm_name == "batch": model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model.cuda(args.gpu) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters = False,) if args.optim_name == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) elif args.optim_name == "adamw": optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) elif args.optim_name == "sgd": optimizer = torch.optim.SGD( model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight ) else: raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name)) if args.lrschedule == "warmup_cosine": scheduler = LinearWarmupCosineAnnealingLR( optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs ) elif args.lrschedule == "cosine_anneal": scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs) if args.checkpoint is not None: scheduler.step(epoch=start_epoch) else: scheduler = None semantic_classes = ["Dice_Val_TC", "Dice_Val_WT", "Dice_Val_ET"] accuracy = run_training( model=model, train_loader=loader[0], val_loader=loader[1], optimizer=optimizer, loss_func=dice_loss, acc_func=dice_acc, args=args, scheduler=scheduler, start_epoch=start_epoch, post_sigmoid=post_sigmoid, post_pred=post_pred, semantic_classes=semantic_classes, ) return accuracy if __name__ == "__main__": main()