# TextBraTS / trainer.py
# Uploaded by Jupitern52 ("Upload 16 files", commit 2a5693e, verified)
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import time
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.amp import GradScaler, autocast
from utils.utils import AverageMeter, distributed_all_gather
from monai.data import decollate_batch
def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):
    """Train ``model`` for one pass over ``loader`` and return the running mean loss.

    Each batch is either a 3-item list ``[image, label, text_feature]`` or a
    dict with ``"image"``, ``"label"`` and ``"text_feature"`` keys. All three
    tensors are moved to GPU ``args.rank`` and the model is called as
    ``model(image, text_feature)``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of training batches. In distributed mode its
            ``sampler`` must expose ``valid_length``.
        optimizer: optimizer stepped once per batch.
        scaler: gradient scaler used only when ``args.amp`` is true
            (``run_training`` passes ``None`` otherwise).
        epoch: current epoch index, used only in the progress print.
        loss_func: callable ``loss_func(logits, target)`` returning a scalar
            tensor (``.item()`` is called on it).
        args: run configuration; reads ``rank``, ``amp``, ``distributed``,
            ``batch_size``, ``world_size`` and ``max_epochs``.

    Returns:
        ``run_loss.avg`` of an ``AverageMeter`` updated once per batch with
        weight ``args.batch_size`` (``args.batch_size * args.world_size`` when
        distributed).
    """
    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    for idx, batch_data in enumerate(loader):
        if isinstance(batch_data, list):
            data, target, text = batch_data
        else:
            data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
        data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)

        optimizer.zero_grad(set_to_none=True)
        # Only the forward pass and the loss run under autocast; backward runs outside it.
        with autocast("cuda", enabled=args.amp):
            logits = model(data, text)
            loss = loss_func(logits, target)
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if args.distributed:
            # Gather this batch's loss from every rank. ``is_valid`` is computed
            # from the sampler's ``valid_length``; how invalid entries are
            # filtered is defined by utils.utils.distributed_all_gather.
            loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
            run_loss.update(
                np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
            )
        else:
            run_loss.update(loss.item(), n=args.batch_size)

        if args.rank == 0:
            print(
                "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                "loss: {:.4f}".format(run_loss.avg),
                "time {:.2f}s".format(time.time() - start_time),
            )
        start_time = time.time()

    # Release the gradient tensors left over from the last step before returning.
    optimizer.zero_grad(set_to_none=True)
    return run_loss.avg
def val_epoch(model, loader, epoch, acc_func, args, post_sigmoid=None, post_pred=None):
    """Evaluate ``model`` on ``loader`` and return the running per-channel metric.

    Batches use the same layouts accepted by ``train_epoch``: a 3-item list
    ``[image, label, text_feature]`` or a dict with ``"image"``, ``"label"`` and
    ``"text_feature"`` keys. Inference runs under ``torch.no_grad()``.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of validation batches. In distributed mode its
            ``sampler`` must expose ``valid_length``.
        epoch: current epoch index, used only in the progress print.
        acc_func: metric object supporting ``reset()``,
            ``acc_func(y_pred=..., y=...)`` and ``aggregate()``, which returns
            ``(acc, not_nans)``.
        args: run configuration; reads ``rank``, ``amp``, ``distributed`` and
            ``max_epochs``.
        post_sigmoid, post_pred: per-sample transforms applied as
            ``post_pred(post_sigmoid(pred))``. Although they default to ``None``,
            both are called unconditionally, so callers must supply them.

    Returns:
        ``run_acc.avg``: the running average of the per-batch metric, weighted
        by ``not_nans``. The progress print reads indices 0, 1 and 2 as
        Dice_TC / Dice_WT / Dice_ET, so at least three channels are assumed.
    """
    model.eval()
    start_time = time.time()
    run_acc = AverageMeter()
    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            # Mirror train_epoch: accept list batches as well as dict batches.
            if isinstance(batch_data, list):
                data, target, text = batch_data
            else:
                data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
            data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
            with autocast("cuda", enabled=args.amp):
                logits = model(data, text)

            # Split the batch into per-sample tensors and post-process each prediction.
            val_labels_list = decollate_batch(target)
            val_outputs_list = decollate_batch(logits)
            val_output_convert = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs_list]

            # The metric is reset every batch, so ``acc`` covers this batch only.
            # The cross-batch average is accumulated in ``run_acc``.
            acc_func.reset()
            acc_func(y_pred=val_output_convert, y=val_labels_list)
            acc, not_nans = acc_func.aggregate()
            acc = acc.cuda(args.rank)

            if args.distributed:
                acc_list, not_nans_list = distributed_all_gather(
                    [acc, not_nans], out_numpy=True, is_valid=idx < loader.sampler.valid_length
                )
                for al, nl in zip(acc_list, not_nans_list):
                    run_acc.update(al, n=nl)
            else:
                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

            if args.rank == 0:
                Dice_TC = run_acc.avg[0]
                Dice_WT = run_acc.avg[1]
                Dice_ET = run_acc.avg[2]
                print(
                    "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                    ", Dice_TC:",
                    Dice_TC,
                    ", Dice_WT:",
                    Dice_WT,
                    ", Dice_ET:",
                    Dice_ET,
                    ", time {:.2f}s".format(time.time() - start_time),
                )
            start_time = time.time()
    return run_acc.avg
def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
    """Write a checkpoint dict to ``os.path.join(args.logdir, filename)`` with ``torch.save``.

    The saved dict always holds ``"epoch"``, ``"best_acc"`` and ``"state_dict"``.
    It also holds ``"optimizer"`` and/or ``"scheduler"`` state dicts when those
    objects are given.

    Args:
        model: network to save. When ``args.distributed`` is truthy, the state
            of ``model.module`` is saved instead (presumably the network wrapped
            by DistributedDataParallel).
        epoch: epoch index stored in the checkpoint.
        args: run configuration; reads ``distributed`` and ``logdir``.
        filename: file name inside ``args.logdir``.
        best_acc: best validation metric so far, stored as-is.
        optimizer, scheduler: optional objects whose ``state_dict()`` is stored.
    """
    network = model.module if args.distributed else model
    checkpoint = {
        "epoch": epoch,
        "best_acc": best_acc,
        "state_dict": network.state_dict(),
    }
    # Optional components keep the order optimizer -> scheduler in the saved dict.
    for key, component in (("optimizer", optimizer), ("scheduler", scheduler)):
        if component is not None:
            checkpoint[key] = component.state_dict()
    target_path = os.path.join(args.logdir, filename)
    torch.save(checkpoint, target_path)
    print("Saving checkpoint", target_path)
def run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_func,
    acc_func,
    args,
    scheduler=None,
    start_epoch=0,
    post_sigmoid=None,
    post_pred=None,
    semantic_classes=None,
):
    """Run the train / validate / checkpoint loop and return the best mean validation score.

    For every epoch in ``range(start_epoch, args.max_epochs)`` this trains once
    via ``train_epoch``. Every ``args.val_every`` epochs it also runs
    ``val_epoch`` and logs the results. On rank 0, when ``args.logdir`` is set
    and ``args.save_checkpoint`` is true, it writes to ``args.logdir``:

    * ``model.pt``: written whenever the mean validation score improves. It
      includes optimizer and scheduler state.
    * ``model_final.pt``: written after every validation round. It holds the
      weights, epoch and best score only.

    Args:
        model: network being trained.
        train_loader: training data loader (its sampler gets ``set_epoch`` in
            distributed mode).
        val_loader: validation data loader.
        optimizer: optimizer forwarded to ``train_epoch`` and saved in ``model.pt``.
        loss_func: training loss forwarded to ``train_epoch``.
        acc_func: metric object forwarded to ``val_epoch``.
        args: run configuration; reads ``logdir``, ``rank``, ``amp``,
            ``distributed``, ``max_epochs``, ``val_every`` and ``save_checkpoint``.
        scheduler: optional LR scheduler, stepped once per epoch.
        start_epoch: first epoch index.
        post_sigmoid, post_pred: post-processing transforms forwarded to ``val_epoch``.
        semantic_classes: optional TensorBoard tag names, one per output channel.

    Returns:
        The best mean validation score seen by this process. It is only tracked
        on rank 0, so other ranks return 0.0.
    """
    writer = None
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        print("Writing Tensorboard logs to ", args.logdir)
    scaler = GradScaler() if args.amp else None
    val_acc_max = 0.0
    try:
        for epoch in range(start_epoch, args.max_epochs):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
                torch.distributed.barrier()
            print(args.rank, time.ctime(), "Epoch:", epoch)
            epoch_time = time.time()
            train_loss = train_epoch(
                model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args
            )
            if args.rank == 0:
                print(
                    "Final training {}/{}".format(epoch, args.max_epochs - 1),
                    "loss: {:.4f}".format(train_loss),
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("train_loss", train_loss, epoch)

            if (epoch + 1) % args.val_every == 0:
                if args.distributed:
                    torch.distributed.barrier()
                epoch_time = time.time()
                val_acc = val_epoch(
                    model,
                    val_loader,
                    epoch=epoch,
                    acc_func=acc_func,
                    args=args,
                    post_sigmoid=post_sigmoid,
                    post_pred=post_pred,
                )
                should_save = args.rank == 0 and args.logdir is not None and args.save_checkpoint

                if args.rank == 0:
                    Dice_TC = val_acc[0]
                    Dice_WT = val_acc[1]
                    Dice_ET = val_acc[2]
                    print(
                        "Final validation stats {}/{}".format(epoch, args.max_epochs - 1),
                        ", Dice_TC:",
                        Dice_TC,
                        ", Dice_WT:",
                        Dice_WT,
                        ", Dice_ET:",
                        Dice_ET,
                        ", time {:.2f}s".format(time.time() - epoch_time),
                    )
                    if writer is not None:
                        writer.add_scalar("Mean_Val_Dice", np.mean(val_acc), epoch)
                        if semantic_classes is not None:
                            for val_channel_ind, class_name in enumerate(semantic_classes):
                                # Skip tags for channels the metric did not produce.
                                if val_channel_ind < val_acc.size:
                                    writer.add_scalar(class_name, val_acc[val_channel_ind], epoch)
                    val_avg_acc = np.mean(val_acc)
                    if val_avg_acc > val_acc_max:
                        print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
                        val_acc_max = val_avg_acc
                        if should_save:
                            # Best checkpoint, including optimizer/scheduler state. It is
                            # deliberately NOT overwritten by a copy of model_final.pt:
                            # that file has no optimizer/scheduler state, so the copy
                            # would discard what is saved here.
                            save_checkpoint(
                                model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
                            )

                if should_save:
                    print("Saving")
                    save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")

            if scheduler is not None:
                scheduler.step()
    finally:
        # Flush and release the TensorBoard writer even if training raises.
        if writer is not None:
            writer.close()

    print("Training Finished !, Best Accuracy: ", val_acc_max)
    return val_acc_max