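"""Inference script for TextBraTS.

Loads a pretrained TextSwinUNETR checkpoint, runs it over the test split listed
in the dataset JSON, and reports per-channel Dice and 95th-percentile Hausdorff
distance (HD95) for the tumor core (TC), whole tumor (WT), and enhancing tumor
(ET) targets.
"""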
import argparse
import os
import time

import torch
import torch.nn.parallel
import torch.utils.data.distributed
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.utils.enums import MetricReduction

from utils.data_utils import get_loader
from utils.textswin_unetr import TextSwinUNETR
from utils.utils import AverageMeter

parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--exp_name", default="TextBraTS", type=str, help="experiment name")
parser.add_argument("--json_list", default="Test.json", type=str, help="dataset json file")
parser.add_argument("--fold", default=0, type=int, help="data fold")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--infer_overlap", default=0.6, type=float, help="sliding window inference overlap")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument("--pretrained_dir", default="./runs/TextBraTS/", type=str, help="pretrained checkpoint directory")


def main():
    args = parser.parse_args()
    args.test_mode = True
    output_directory = "./outputs/" + args.exp_name
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    test_loader = get_loader(args)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_pth = os.path.join(pretrained_dir, model_name)

    model = TextSwinUNETR(
        img_size=128,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint,
        text_dim=768,
    )
    # Load the pretrained weights; strict=False tolerates keys that are missing
    # from or unused by the current model definition.
    model_dict = torch.load(pretrained_pth)["state_dict"]
    model.load_state_dict(model_dict, strict=False)
    model.eval()
    model.to(device)

    def val_epoch(model, loader, acc_func, hd95_func):
        model.eval()
        start_time = time.time()
        run_acc = AverageMeter()
        run_hd95 = AverageMeter()

        with torch.no_grad():
            for idx, batch_data in enumerate(loader):
                data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
                data, target, text = data.cuda(), target.cuda(), text.cuda()
                logits = model(data, text)
                # Channel-wise sigmoid + 0.5 threshold: each output channel (TC/WT/ET)
                # is an independent binary mask.
                prob = torch.sigmoid(logits)
                prob = (prob > 0.5).int()

                # Reset the metrics so aggregate() reflects only the current batch,
                # then fold the per-batch result into the running averages.
                acc_func.reset()
                acc_func(y_pred=prob, y=target)
                acc, not_nans = acc_func.aggregate()
                acc = acc.cuda()
                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

                hd95_func.reset()
                hd95_func(y_pred=prob, y=target)
                hd95 = hd95_func.aggregate()
                run_hd95.update(hd95.cpu().numpy())

                Dice_TC = run_acc.avg[0]
                Dice_WT = run_acc.avg[1]
                Dice_ET = run_acc.avg[2]
                HD95_TC = run_hd95.avg[0]
                HD95_WT = run_hd95.avg[1]
                HD95_ET = run_hd95.avg[2]
                print(
                    "Val {}/{}".format(idx, len(loader)),
                    ", Dice_TC:", Dice_TC,
                    ", Dice_WT:", Dice_WT,
                    ", Dice_ET:", Dice_ET,
                    ", Avg Dice:", (Dice_ET + Dice_TC + Dice_WT) / 3,
                    ", HD95_TC:", HD95_TC,
                    ", HD95_WT:", HD95_WT,
                    ", HD95_ET:", HD95_ET,
                    ", Avg HD95:", (HD95_ET + HD95_TC + HD95_WT) / 3,
                    ", time {:.2f}s".format(time.time() - start_time),
                )
                start_time = time.time()

        # Append the final (dataset-level) averages to the experiment log.
        with open(output_directory + "/log.txt", "a") as log_file:
            log_file.write(
                f"Experiment name: {args.pretrained_dir.split('/')[-2]}, "
                f"Final Validation Results - Dice_TC: {Dice_TC}, Dice_WT: {Dice_WT}, Dice_ET: {Dice_ET}, "
                f"Avg Dice: {(Dice_ET + Dice_TC + Dice_WT) / 3}, "
                f"HD95_TC: {HD95_TC}, HD95_WT: {HD95_WT}, HD95_ET: {HD95_ET}, "
                f"Avg HD95: {(HD95_ET + HD95_TC + HD95_WT) / 3}\n"
            )
        return run_acc.avg

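    # Note: the CLI exposes --infer_overlap and --roi_x/y/z, but val_epoch above runs a
    # plain forward pass on the loaded volumes. If full-size volumes (rather than 128^3
    # crops) were fed at test time, MONAI's sliding_window_inference could be wired in
    # roughly as follows -- a sketch only, assuming the same text feature applies to
    # every window:
    #
    #     from monai.inferers import sliding_window_inference
    #     logits = sliding_window_inference(
    #         inputs=data,
    #         roi_size=(args.roi_x, args.roi_y, args.roi_z),
    #         sw_batch_size=1,
    #         predictor=lambda x: model(x, text),
    #         overlap=args.infer_overlap,
    #     )
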
    # Per-channel Dice and 95th-percentile Hausdorff distance (TC, WT, ET).
    dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    hd95_acc = HausdorffDistanceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, percentile=95.0)
    val_epoch(model, test_loader, acc_func=dice_acc, hd95_func=hd95_acc)


if __name__ == "__main__":
    main()
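
# Example invocation (the filename "test.py" is illustrative; the flags are the
# ones defined by the parser above):
#
#   python test.py --data_dir ./data/TextBraTSData --json_list Test.json \
#       --exp_name TextBraTS --pretrained_dir ./runs/TextBraTS/ --pretrained_model_name model.pt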