# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
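"""Evaluation script for TextBraTS.

Loads a pretrained TextSwinUNETR checkpoint, runs it over the test split defined by
``--json_list``, and reports per-class Dice and 95th-percentile Hausdorff distance
(TC / WT / ET) along with their averages. Final results are appended to
``./outputs/<exp_name>/log.txt``.
"""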

import argparse
import os
import time

import torch
import torch.nn.parallel
import torch.utils.data.distributed
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.utils.enums import MetricReduction

from utils.data_utils import get_loader
from utils.textswin_unetr import TextSwinUNETR
from utils.utils import AverageMeter


parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--exp_name", default="TextBraTS", type=str, help="experiment name")
parser.add_argument("--json_list", default="Test.json", type=str, help="dataset json file")
parser.add_argument("--fold", default=0, type=int, help="data fold")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--infer_overlap", default=0.6, type=float, help="sliding window inference overlap")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument(
    "--pretrained_dir",
    default="./runs/TextBraTS/",
    type=str,
    help="pretrained checkpoint directory",
)


def main():
    args = parser.parse_args()
    args.test_mode = True
    output_directory = os.path.join("./outputs", args.exp_name)
    os.makedirs(output_directory, exist_ok=True)
    test_loader = get_loader(args)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_pth = os.path.join(pretrained_dir, model_name)
    model = TextSwinUNETR(
        img_size=128,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint,
        text_dim=768,
    )
    # map_location keeps checkpoint loading working on CPU-only machines as well.
    model_dict = torch.load(pretrained_pth, map_location=device)["state_dict"]
    model.load_state_dict(model_dict, strict=False)
    model.eval()
    model.to(device)

    def val_epoch(model, loader, acc_func, hd95_func):
        model.eval()
        start_time = time.time()
        run_acc = AverageMeter()
        run_hd95 = AverageMeter()

        with torch.no_grad():
            for idx, batch_data in enumerate(loader):
                data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
                data, target, text = data.to(device), target.to(device), text.to(device)
                logits = model(data, text)
                # Threshold the sigmoid outputs to binary masks, one channel per class (TC/WT/ET).
                prob = torch.sigmoid(logits)
                prob = (prob > 0.5).int()

                # Dice: reset the cumulative metric so aggregate() reflects only the current case,
                # then fold the per-class scores into the running average.
                acc_func.reset()
                acc_func(y_pred=prob, y=target)
                acc, not_nans = acc_func.aggregate()
                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

                # HD95: with MEAN_BATCH reduction, aggregate() returns one value per class.
                hd95_func.reset()
                hd95_func(y_pred=prob, y=target)
                hd95 = hd95_func.aggregate()
                run_hd95.update(hd95.cpu().numpy())

                Dice_TC = run_acc.avg[0]
                Dice_WT = run_acc.avg[1]
                Dice_ET = run_acc.avg[2]
                HD95_TC = run_hd95.avg[0]
                HD95_WT = run_hd95.avg[1]
                HD95_ET = run_hd95.avg[2]
                print(
                    "Val  {}/{}".format(idx, len(loader)),
                    ", Dice_TC:", Dice_TC,
                    ", Dice_WT:", Dice_WT,
                    ", Dice_ET:", Dice_ET,
                    ", Avg Dice:", (Dice_ET + Dice_TC + Dice_WT) / 3,
                    ", HD95_TC:", HD95_TC,
                    ", HD95_WT:", HD95_WT,
                    ", HD95_ET:", HD95_ET,
                    ", Avg HD95:", (HD95_ET + HD95_TC + HD95_WT) / 3,
                    ", time {:.2f}s".format(time.time() - start_time),
                )
                start_time = time.time()
            # Append the final running averages to a log file in the output directory.
            with open(os.path.join(output_directory, "log.txt"), "a") as log_file:
                log_file.write(
                    f"Experiment name: {args.pretrained_dir.split('/')[-2]}, "
                    f"Final Validation Results - Dice_TC: {Dice_TC}, Dice_WT: {Dice_WT}, Dice_ET: {Dice_ET}, "
                    f"Avg Dice: {(Dice_ET + Dice_TC + Dice_WT) / 3}, "
                    f"HD95_TC: {HD95_TC}, HD95_WT: {HD95_WT}, HD95_ET: {HD95_ET}, "
                    f"Avg HD95: {(HD95_ET + HD95_TC + HD95_WT) / 3}\n"
                )
        return run_acc.avg

    # Per-class metrics (TC/WT/ET): MEAN_BATCH reduction keeps the class dimension,
    # and get_not_nans lets cases with an absent class be excluded from the Dice average.
    dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    hd95_acc = HausdorffDistanceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, percentile=95.0)
    val_epoch(model, test_loader, acc_func=dice_acc, hd95_func=hd95_acc)

if __name__ == "__main__":
    main()
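
# Example invocation (a sketch; assumes this file is saved as test.py and that the
# checkpoint and data live at the default locations declared by the flags above):
#
#   python test.py \
#       --data_dir ./data/TextBraTSData --json_list Test.json --fold 0 \
#       --pretrained_dir ./runs/TextBraTS/ --pretrained_model_name model.pt \
#       --exp_name TextBraTS
#
# Per-class Dice and HD95 are printed per case, and the final averages are appended
# to ./outputs/<exp_name>/log.txt.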