# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
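# silence the FutureWarning emitted at import time by transformers.utils.generic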
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")

import argparse
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete
from monai.utils.enums import MetricReduction

from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from trainer import run_training
from utils.data_utils import get_loader
from utils.textswin_unetr import TextSwinUNETR

parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline for the TextBraTS image-text dataset")
parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
parser.add_argument("--logdir", default="TextBraTS", type=str, help="directory to save the tensorboard logs")
parser.add_argument("--fold", default=0, type=int, help="data fold, 0 for validation and 1 for training")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--json_list", default="./Train.json", type=str, help="dataset json file")
parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
parser.add_argument("--max_epochs", default=200, type=int, help="max number of training epochs")
parser.add_argument("--batch_size", default=2, type=int, help="number of batch size")
parser.add_argument("--sw_batch_size", default=4, type=int, help="number of sliding window batch size")
parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
parser.add_argument("--val_every", default=1, type=int, help="validation frequency")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
parser.add_argument("--norm_name", default="instance", type=str, help="normalization name")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--cache_dataset", action="store_true", help="use monai Dataset class")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs")
parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint")
parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_ssl_pretrained", action="store_true", help="use SSL pretrained ckpt")
parser.add_argument(
"--pretrained_dir",
default="./runs/TextBraTS/",
type=str,
help="pretrained checkpoint directory",
)
parser.add_argument("--squared_dice", action="store_true", help="use squared Dice")
parser.add_argument("--seed", type=int, default=23,help="use random seed")
def main():
args = parser.parse_args()
args.amp = not args.noamp
args.logdir = "./runs/" + args.logdir
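    # seed Python, NumPy, and PyTorch RNGs for reproducibility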
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.distributed:
torch.cuda.manual_seed_all(args.seed)
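        # launch one worker process per visible GPU; total world size scales with the node count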
args.ngpus_per_node = torch.cuda.device_count()
print("Found total gpus", args.ngpus_per_node)
args.world_size = args.ngpus_per_node * args.world_size
mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
else:
torch.cuda.manual_seed(args.seed)
main_worker(gpu=0, args=args)


def main_worker(gpu, args):
if args.distributed:
torch.multiprocessing.set_start_method("fork", force=True)
np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
args.gpu = gpu
if args.distributed:
args.rank = args.rank * args.ngpus_per_node + gpu
dist.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.cuda.set_device(args.gpu)
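    # input sizes are fixed, so let cudnn benchmark and cache the fastest conv algorithms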
torch.backends.cudnn.benchmark = True
args.test_mode = False
loader = get_loader(args)
print(args.rank, " gpu", args.gpu)
if args.rank == 0:
print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
pretrained_dir = args.pretrained_dir
model_name = args.pretrained_model_name
pretrained_pth = os.path.join(pretrained_dir, model_name)
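    # TextSwinUNETR pairs a Swin UNETR backbone with report-text features;
    # text_dim=768 matches the hidden size of BERT-base-style text encoders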
model = TextSwinUNETR(
img_size=(args.roi_x, args.roi_y, args.roi_z),
in_channels=args.in_channels,
out_channels=args.out_channels,
feature_size=args.feature_size,
use_checkpoint=args.use_checkpoint,
text_dim=768,
)
if args.resume_ckpt:
model_dict = torch.load(pretrained_pth)["state_dict"]
        # strip DistributedDataParallel's "module." prefix from checkpoint keys
        for key in list(model_dict.keys()):
            model_dict[key.replace("module.", "")] = model_dict.pop(key)
        model.load_state_dict(model_dict, strict=True)
print("Using pretrained weights")
if args.use_ssl_pretrained:
try:
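            # NOTE: machine-specific path; point this at your local copy of the
            # SSL-pretrained Swin-ViT checkpoint (model_swinvit.pt)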
            model_dict = torch.load("/media/iipl/disk1/swinunetr/model_swinvit.pt", weights_only=True)
state_dict = model_dict["state_dict"]
# fix potential differences in state dict keys from pre-training to
# fine-tuning
for key in list(state_dict.keys()):
state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
            # rename pre-training parameter names to their fine-tuning equivalents;
            # elif avoids a KeyError on a key already popped by the "fc" branch
            for key in list(state_dict.keys()):
                if "fc" in key:
                    state_dict[key.replace("fc", "linear")] = state_dict.pop(key)
                elif "patch_embed" in key:
                    state_dict[key.replace("patch_embed", "")] = state_dict.pop(key)
model.load_state_dict(state_dict, strict=False)
        except ValueError:
            raise ValueError("Self-supervised pre-trained weights not available for " + str(args.pretrained_model_name))
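
    # BraTS sub-regions (TC, WT, ET) overlap, so this is multi-label segmentation:
    # per-channel sigmoid rather than a mutually exclusive softmax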
if args.squared_dice:
dice_loss = DiceLoss(
to_onehot_y=False, sigmoid=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
)
else:
dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True)
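    # validation post-processing: sigmoid, then threshold each channel at 0.5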
post_sigmoid = Activations(sigmoid=True)
    post_pred = AsDiscrete(argmax=False, threshold=0.5)
dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters count", pytorch_total_params)
best_acc = 0
start_epoch = 0
if args.checkpoint is not None:
checkpoint = torch.load(args.checkpoint, map_location="cpu")
from collections import OrderedDict
new_state_dict = OrderedDict()
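        # strip any "backbone." prefix so checkpoint keys match the bare model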
for k, v in checkpoint["state_dict"].items():
new_state_dict[k.replace("backbone.", "")] = v
model.load_state_dict(new_state_dict, strict=False)
if "epoch" in checkpoint:
start_epoch = checkpoint["epoch"]
if "best_acc" in checkpoint:
best_acc = checkpoint["best_acc"]
print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))
model.cuda(args.gpu)
if args.distributed:
torch.cuda.set_device(args.gpu)
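        # if using batch norm, synchronize its statistics across ranks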
if args.norm_name == "batch":
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters=False
        )
if args.optim_name == "adam":
optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
elif args.optim_name == "adamw":
optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
elif args.optim_name == "sgd":
optimizer = torch.optim.SGD(
model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight
)
else:
raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))
if args.lrschedule == "warmup_cosine":
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs
)
elif args.lrschedule == "cosine_anneal":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs)
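        # when resuming, fast-forward the annealing schedule to the checkpointed epoch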
if args.checkpoint is not None:
scheduler.step(epoch=start_epoch)
else:
scheduler = None
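
    # per-channel validation Dice labels: tumor core (TC), whole tumor (WT), enhancing tumor (ET)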
semantic_classes = ["Dice_Val_TC", "Dice_Val_WT", "Dice_Val_ET"]
accuracy = run_training(
model=model,
train_loader=loader[0],
val_loader=loader[1],
optimizer=optimizer,
loss_func=dice_loss,
acc_func=dice_acc,
args=args,
scheduler=scheduler,
start_epoch=start_epoch,
post_sigmoid=post_sigmoid,
post_pred=post_pred,
semantic_classes=semantic_classes,
)
return accuracy


if __name__ == "__main__":
main()