Jupitern52 committed on
Commit 2a5693e · verified · 1 Parent(s): 8f3c22b

Upload 16 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/datasample.PNG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,77 @@
- ---
- license: mit
- ---
# TextBraTS

A volume-level, publicly available text-image dataset with a novel text-guided 3D brain tumor segmentation method, built on the BraTS challenge.

---

## Introduction

**TextBraTS** is an open-access dataset designed to advance research in text-guided 3D brain tumor segmentation. It includes paired multi-modal brain MRI scans and expertly annotated radiology reports, enabling the development and evaluation of multi-modal deep learning models that bridge vision and language in neuro-oncology. Our work has been accepted at MICCAI 2025; the paper is also available on arXiv: [2506.16784](https://arxiv.org/abs/2506.16784).

![TextBraTS datasample](assets/datasample.PNG)
## Features

- Multi-modal 3D brain MRI scans (T1, T1ce, T2, FLAIR) with expert-annotated segmentations, from the BraTS 2020 challenge training set
- Structured radiology reports for each case
- A text-image alignment method for research on multi-modal fusion

![TextBraTS Overview](assets/overview.PNG)
## Usage

You can use this dataset for:
- Developing and benchmarking text-guided segmentation models
- Evaluating multi-modal fusion algorithms in medical imaging
- Research in language-driven medical AI

## Installing Dependencies

Run the following commands to set up the environment:
<pre>conda env create -f environment.yml
pip install git+https://github.com/Project-MONAI/MONAI.git@07de215c</pre>
To activate the environment, use:
<pre>conda activate TextBraTS</pre>
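
As an optional sanity check (a minimal sketch, not part of the repository scripts, assuming the `TextBraTS` environment is active), you can confirm that PyTorch, CUDA, and MONAI are importable:

```python
# Optional environment check; assumes the TextBraTS conda environment is active.
import torch
import monai

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("MONAI:", monai.__version__)
```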
## Dataset

In accordance with the official BraTS guidelines, the MRI images must be downloaded directly from the [BraTS 2020 challenge website](https://www.med.upenn.edu/cbica/brats2020/data.html) (training set).

**Download our text, feature, and prompt files:**
You can download our dataset from [TextBraTSData](https://drive.google.com/file/d/1i1R6_bVY4VbNtxEIQVsiXUSWuVAtgJhg/view?usp=sharing).
The provided text reports, feature files, and prompt files are named to match the original BraTS folder IDs exactly. Set the data paths and merge them with the downloaded MRI data using `merge.py`:
<pre>python merge.py</pre>

If you would like to change the dataset split, modify the `Train.json` and `Test.json` files accordingly.
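
After merging, each case folder should contain the four MRI modalities, the segmentation label, and the text-feature file, following the per-case layout listed in `Test.json`. Below is a minimal verification sketch (standard library only; it assumes the merged data lives under `./data/TextBraTSData`, the default `--data_dir` in `main.py`):

```python
import json
import os

DATA_DIR = "./data/TextBraTSData"  # default --data_dir in main.py; adjust if needed

with open("Test.json") as f:
    cases = json.load(f)["training"]  # Test.json stores its entries under the "training" key

missing = []
for case in cases:
    # Each entry lists four modality images, one segmentation label, and one text feature.
    for rel_path in case["image"] + [case["label"], case["text_feature"]]:
        full_path = os.path.join(DATA_DIR, rel_path)
        if not os.path.exists(full_path):
            missing.append(full_path)

print(f"Checked {len(cases)} cases; {len(missing)} files missing.")
```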
## Inference

We provide pre-trained weights for direct inference and evaluation.
Download the weights from [checkpoint](https://drive.google.com/file/d/147283LL2fRDcTYR_vQA-95vbZysjjD1v/view?usp=sharing).

After downloading, place the weights in your desired directory, then run `test.py` with the following command:

<pre>python test.py --pretrained_dir=/path/to/your/weights/ --exp_name=TextBraTS</pre>
## Training

If you would like to train the model from scratch, adjust the training code in `main.py` as needed and run:

<pre>python main.py --distributed --use_ssl_pretrained --save_checkpoint --logdir=TextBraTS</pre>

- The `--use_ssl_pretrained` option loads NVIDIA's self-supervised Swin UNETR pre-trained weights.
- Download the Swin UNETR pre-trained weights from [Pre-trained weights](https://drive.google.com/file/d/1FJ0N_Xo3olzAV-oojEkAsbsUgiFsoPdl/view?usp=sharing).
- Place the downloaded weights where your configuration or script expects them (see the note after this list).
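
Note that, in this commit, `main.py` loads the SSL checkpoint from a hard-coded path (`/media/iipl/disk1/swinunetr/model_swinvit.pt`), so you will likely need to edit that path to point at your copy of `model_swinvit.pt`. As a minimal sketch of what `--use_ssl_pretrained` does with the checkpoint (the local path below is a placeholder, not part of the repo):

```python
import torch

# Placeholder path: change to wherever you saved model_swinvit.pt.
ckpt = torch.load("./pretrained/model_swinvit.pt", weights_only=True)
state_dict = ckpt["state_dict"]

# main.py renames the self-supervised pre-training keys onto the TextSwinUNETR
# backbone by replacing the "module." prefix with "swinViT.".
for key in list(state_dict.keys()):
    state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
```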
## Citation

If you use TextBraTS in your research, please cite:

```bibtex
@inproceedings{shi2025textbrats,
  title     = {TextBraTS: Text-Guided Volumetric Brain Tumor Segmentation with Innovative Dataset Development and Fusion Module Exploration},
  author    = {Shi, Xiaoyu and Jain, Rahul Kumar and Li, Yinhao and Hou, Ruibo and Cheng, Jingliang and Bai, Jie and Zhao, Guohua and Lin, Lanfen and Xu, Rui and Chen, Yen-wei},
  booktitle = {Proceedings of the International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI)},
  year      = {2025},
  note      = {to appear}
}
```
Test.json ADDED
@@ -0,0 +1,1027 @@
1
+ {
2
+ "training": [
3
+ {
4
+ "fold": 0,
5
+ "image": [
6
+ "BraTS20_Training_101/BraTS20_Training_101_flair.nii.gz",
7
+ "BraTS20_Training_101/BraTS20_Training_101_t1.nii.gz",
8
+ "BraTS20_Training_101/BraTS20_Training_101_t1ce.nii.gz",
9
+ "BraTS20_Training_101/BraTS20_Training_101_t2.nii.gz"
10
+ ],
11
+ "label": "BraTS20_Training_101/BraTS20_Training_101_seg.nii.gz",
12
+ "text_feature": "BraTS20_Training_101/BraTS20_Training_101_flair_text.npy"
13
+ },
14
+ {
15
+ "fold": 0,
16
+ "image": [
17
+ "BraTS20_Training_005/BraTS20_Training_005_flair.nii.gz",
18
+ "BraTS20_Training_005/BraTS20_Training_005_t1.nii.gz",
19
+ "BraTS20_Training_005/BraTS20_Training_005_t1ce.nii.gz",
20
+ "BraTS20_Training_005/BraTS20_Training_005_t2.nii.gz"
21
+ ],
22
+ "label": "BraTS20_Training_005/BraTS20_Training_005_seg.nii.gz",
23
+ "text_feature": "BraTS20_Training_005/BraTS20_Training_005_flair_text.npy"
24
+ },
25
+ {
26
+ "fold": 0,
27
+ "image": [
28
+ "BraTS20_Training_173/BraTS20_Training_173_flair.nii.gz",
29
+ "BraTS20_Training_173/BraTS20_Training_173_t1.nii.gz",
30
+ "BraTS20_Training_173/BraTS20_Training_173_t1ce.nii.gz",
31
+ "BraTS20_Training_173/BraTS20_Training_173_t2.nii.gz"
32
+ ],
33
+ "label": "BraTS20_Training_173/BraTS20_Training_173_seg.nii.gz",
34
+ "text_feature": "BraTS20_Training_173/BraTS20_Training_173_flair_text.npy"
35
+ },
36
+ {
37
+ "fold": 0,
38
+ "image": [
39
+ "BraTS20_Training_241/BraTS20_Training_241_flair.nii.gz",
40
+ "BraTS20_Training_241/BraTS20_Training_241_t1.nii.gz",
41
+ "BraTS20_Training_241/BraTS20_Training_241_t1ce.nii.gz",
42
+ "BraTS20_Training_241/BraTS20_Training_241_t2.nii.gz"
43
+ ],
44
+ "label": "BraTS20_Training_241/BraTS20_Training_241_seg.nii.gz",
45
+ "text_feature": "BraTS20_Training_241/BraTS20_Training_241_flair_text.npy"
46
+ },
47
+ {
48
+ "fold": 0,
49
+ "image": [
50
+ "BraTS20_Training_309/BraTS20_Training_309_flair.nii.gz",
51
+ "BraTS20_Training_309/BraTS20_Training_309_t1.nii.gz",
52
+ "BraTS20_Training_309/BraTS20_Training_309_t1ce.nii.gz",
53
+ "BraTS20_Training_309/BraTS20_Training_309_t2.nii.gz"
54
+ ],
55
+ "label": "BraTS20_Training_309/BraTS20_Training_309_seg.nii.gz",
56
+ "text_feature": "BraTS20_Training_309/BraTS20_Training_309_flair_text.npy"
57
+ },
58
+ {
59
+ "fold": 0,
60
+ "image": [
61
+ "BraTS20_Training_258/BraTS20_Training_258_flair.nii.gz",
62
+ "BraTS20_Training_258/BraTS20_Training_258_t1.nii.gz",
63
+ "BraTS20_Training_258/BraTS20_Training_258_t1ce.nii.gz",
64
+ "BraTS20_Training_258/BraTS20_Training_258_t2.nii.gz"
65
+ ],
66
+ "label": "BraTS20_Training_258/BraTS20_Training_258_seg.nii.gz",
67
+ "text_feature": "BraTS20_Training_258/BraTS20_Training_258_flair_text.npy"
68
+ },
69
+ {
70
+ "fold": 0,
71
+ "image": [
72
+ "BraTS20_Training_194/BraTS20_Training_194_flair.nii.gz",
73
+ "BraTS20_Training_194/BraTS20_Training_194_t1.nii.gz",
74
+ "BraTS20_Training_194/BraTS20_Training_194_t1ce.nii.gz",
75
+ "BraTS20_Training_194/BraTS20_Training_194_t2.nii.gz"
76
+ ],
77
+ "label": "BraTS20_Training_194/BraTS20_Training_194_seg.nii.gz",
78
+ "text_feature": "BraTS20_Training_194/BraTS20_Training_194_flair_text.npy"
79
+ },
80
+ {
81
+ "fold": 0,
82
+ "image": [
83
+ "BraTS20_Training_103/BraTS20_Training_103_flair.nii.gz",
84
+ "BraTS20_Training_103/BraTS20_Training_103_t1.nii.gz",
85
+ "BraTS20_Training_103/BraTS20_Training_103_t1ce.nii.gz",
86
+ "BraTS20_Training_103/BraTS20_Training_103_t2.nii.gz"
87
+ ],
88
+ "label": "BraTS20_Training_103/BraTS20_Training_103_seg.nii.gz",
89
+ "text_feature": "BraTS20_Training_103/BraTS20_Training_103_flair_text.npy"
90
+ },
91
+ {
92
+ "fold": 0,
93
+ "image": [
94
+ "BraTS20_Training_170/BraTS20_Training_170_flair.nii.gz",
95
+ "BraTS20_Training_170/BraTS20_Training_170_t1.nii.gz",
96
+ "BraTS20_Training_170/BraTS20_Training_170_t1ce.nii.gz",
97
+ "BraTS20_Training_170/BraTS20_Training_170_t2.nii.gz"
98
+ ],
99
+ "label": "BraTS20_Training_170/BraTS20_Training_170_seg.nii.gz",
100
+ "text_feature": "BraTS20_Training_170/BraTS20_Training_170_flair_text.npy"
101
+ },
102
+ {
103
+ "fold": 0,
104
+ "image": [
105
+ "BraTS20_Training_268/BraTS20_Training_268_flair.nii.gz",
106
+ "BraTS20_Training_268/BraTS20_Training_268_t1.nii.gz",
107
+ "BraTS20_Training_268/BraTS20_Training_268_t1ce.nii.gz",
108
+ "BraTS20_Training_268/BraTS20_Training_268_t2.nii.gz"
109
+ ],
110
+ "label": "BraTS20_Training_268/BraTS20_Training_268_seg.nii.gz",
111
+ "text_feature": "BraTS20_Training_268/BraTS20_Training_268_flair_text.npy"
112
+ },
113
+ {
114
+ "fold": 0,
115
+ "image": [
116
+ "BraTS20_Training_346/BraTS20_Training_346_flair.nii.gz",
117
+ "BraTS20_Training_346/BraTS20_Training_346_t1.nii.gz",
118
+ "BraTS20_Training_346/BraTS20_Training_346_t1ce.nii.gz",
119
+ "BraTS20_Training_346/BraTS20_Training_346_t2.nii.gz"
120
+ ],
121
+ "label": "BraTS20_Training_346/BraTS20_Training_346_seg.nii.gz",
122
+ "text_feature": "BraTS20_Training_346/BraTS20_Training_346_flair_text.npy"
123
+ },
124
+ {
125
+ "fold": 0,
126
+ "image": [
127
+ "BraTS20_Training_149/BraTS20_Training_149_flair.nii.gz",
128
+ "BraTS20_Training_149/BraTS20_Training_149_t1.nii.gz",
129
+ "BraTS20_Training_149/BraTS20_Training_149_t1ce.nii.gz",
130
+ "BraTS20_Training_149/BraTS20_Training_149_t2.nii.gz"
131
+ ],
132
+ "label": "BraTS20_Training_149/BraTS20_Training_149_seg.nii.gz",
133
+ "text_feature": "BraTS20_Training_149/BraTS20_Training_149_flair_text.npy"
134
+ },
135
+ {
136
+ "fold": 0,
137
+ "image": [
138
+ "BraTS20_Training_367/BraTS20_Training_367_flair.nii.gz",
139
+ "BraTS20_Training_367/BraTS20_Training_367_t1.nii.gz",
140
+ "BraTS20_Training_367/BraTS20_Training_367_t1ce.nii.gz",
141
+ "BraTS20_Training_367/BraTS20_Training_367_t2.nii.gz"
142
+ ],
143
+ "label": "BraTS20_Training_367/BraTS20_Training_367_seg.nii.gz",
144
+ "text_feature": "BraTS20_Training_367/BraTS20_Training_367_flair_text.npy"
145
+ },
146
+ {
147
+ "fold": 0,
148
+ "image": [
149
+ "BraTS20_Training_220/BraTS20_Training_220_flair.nii.gz",
150
+ "BraTS20_Training_220/BraTS20_Training_220_t1.nii.gz",
151
+ "BraTS20_Training_220/BraTS20_Training_220_t1ce.nii.gz",
152
+ "BraTS20_Training_220/BraTS20_Training_220_t2.nii.gz"
153
+ ],
154
+ "label": "BraTS20_Training_220/BraTS20_Training_220_seg.nii.gz",
155
+ "text_feature": "BraTS20_Training_220/BraTS20_Training_220_flair_text.npy"
156
+ },
157
+ {
158
+ "fold": 0,
159
+ "image": [
160
+ "BraTS20_Training_368/BraTS20_Training_368_flair.nii.gz",
161
+ "BraTS20_Training_368/BraTS20_Training_368_t1.nii.gz",
162
+ "BraTS20_Training_368/BraTS20_Training_368_t1ce.nii.gz",
163
+ "BraTS20_Training_368/BraTS20_Training_368_t2.nii.gz"
164
+ ],
165
+ "label": "BraTS20_Training_368/BraTS20_Training_368_seg.nii.gz",
166
+ "text_feature": "BraTS20_Training_368/BraTS20_Training_368_flair_text.npy"
167
+ },
168
+ {
169
+ "fold": 0,
170
+ "image": [
171
+ "BraTS20_Training_289/BraTS20_Training_289_flair.nii.gz",
172
+ "BraTS20_Training_289/BraTS20_Training_289_t1.nii.gz",
173
+ "BraTS20_Training_289/BraTS20_Training_289_t1ce.nii.gz",
174
+ "BraTS20_Training_289/BraTS20_Training_289_t2.nii.gz"
175
+ ],
176
+ "label": "BraTS20_Training_289/BraTS20_Training_289_seg.nii.gz",
177
+ "text_feature": "BraTS20_Training_289/BraTS20_Training_289_flair_text.npy"
178
+ },
179
+ {
180
+ "fold": 0,
181
+ "image": [
182
+ "BraTS20_Training_084/BraTS20_Training_084_flair.nii.gz",
183
+ "BraTS20_Training_084/BraTS20_Training_084_t1.nii.gz",
184
+ "BraTS20_Training_084/BraTS20_Training_084_t1ce.nii.gz",
185
+ "BraTS20_Training_084/BraTS20_Training_084_t2.nii.gz"
186
+ ],
187
+ "label": "BraTS20_Training_084/BraTS20_Training_084_seg.nii.gz",
188
+ "text_feature": "BraTS20_Training_084/BraTS20_Training_084_flair_text.npy"
189
+ },
190
+ {
191
+ "fold": 0,
192
+ "image": [
193
+ "BraTS20_Training_277/BraTS20_Training_277_flair.nii.gz",
194
+ "BraTS20_Training_277/BraTS20_Training_277_t1.nii.gz",
195
+ "BraTS20_Training_277/BraTS20_Training_277_t1ce.nii.gz",
196
+ "BraTS20_Training_277/BraTS20_Training_277_t2.nii.gz"
197
+ ],
198
+ "label": "BraTS20_Training_277/BraTS20_Training_277_seg.nii.gz",
199
+ "text_feature": "BraTS20_Training_277/BraTS20_Training_277_flair_text.npy"
200
+ },
201
+ {
202
+ "fold": 0,
203
+ "image": [
204
+ "BraTS20_Training_202/BraTS20_Training_202_flair.nii.gz",
205
+ "BraTS20_Training_202/BraTS20_Training_202_t1.nii.gz",
206
+ "BraTS20_Training_202/BraTS20_Training_202_t1ce.nii.gz",
207
+ "BraTS20_Training_202/BraTS20_Training_202_t2.nii.gz"
208
+ ],
209
+ "label": "BraTS20_Training_202/BraTS20_Training_202_seg.nii.gz",
210
+ "text_feature": "BraTS20_Training_202/BraTS20_Training_202_flair_text.npy"
211
+ },
212
+ {
213
+ "fold": 0,
214
+ "image": [
215
+ "BraTS20_Training_151/BraTS20_Training_151_flair.nii.gz",
216
+ "BraTS20_Training_151/BraTS20_Training_151_t1.nii.gz",
217
+ "BraTS20_Training_151/BraTS20_Training_151_t1ce.nii.gz",
218
+ "BraTS20_Training_151/BraTS20_Training_151_t2.nii.gz"
219
+ ],
220
+ "label": "BraTS20_Training_151/BraTS20_Training_151_seg.nii.gz",
221
+ "text_feature": "BraTS20_Training_151/BraTS20_Training_151_flair_text.npy"
222
+ },
223
+ {
224
+ "fold": 0,
225
+ "image": [
226
+ "BraTS20_Training_142/BraTS20_Training_142_flair.nii.gz",
227
+ "BraTS20_Training_142/BraTS20_Training_142_t1.nii.gz",
228
+ "BraTS20_Training_142/BraTS20_Training_142_t1ce.nii.gz",
229
+ "BraTS20_Training_142/BraTS20_Training_142_t2.nii.gz"
230
+ ],
231
+ "label": "BraTS20_Training_142/BraTS20_Training_142_seg.nii.gz",
232
+ "text_feature": "BraTS20_Training_142/BraTS20_Training_142_flair_text.npy"
233
+ },
234
+ {
235
+ "fold": 0,
236
+ "image": [
237
+ "BraTS20_Training_229/BraTS20_Training_229_flair.nii.gz",
238
+ "BraTS20_Training_229/BraTS20_Training_229_t1.nii.gz",
239
+ "BraTS20_Training_229/BraTS20_Training_229_t1ce.nii.gz",
240
+ "BraTS20_Training_229/BraTS20_Training_229_t2.nii.gz"
241
+ ],
242
+ "label": "BraTS20_Training_229/BraTS20_Training_229_seg.nii.gz",
243
+ "text_feature": "BraTS20_Training_229/BraTS20_Training_229_flair_text.npy"
244
+ },
245
+ {
246
+ "fold": 0,
247
+ "image": [
248
+ "BraTS20_Training_322/BraTS20_Training_322_flair.nii.gz",
249
+ "BraTS20_Training_322/BraTS20_Training_322_t1.nii.gz",
250
+ "BraTS20_Training_322/BraTS20_Training_322_t1ce.nii.gz",
251
+ "BraTS20_Training_322/BraTS20_Training_322_t2.nii.gz"
252
+ ],
253
+ "label": "BraTS20_Training_322/BraTS20_Training_322_seg.nii.gz",
254
+ "text_feature": "BraTS20_Training_322/BraTS20_Training_322_flair_text.npy"
255
+ },
256
+ {
257
+ "fold": 0,
258
+ "image": [
259
+ "BraTS20_Training_278/BraTS20_Training_278_flair.nii.gz",
260
+ "BraTS20_Training_278/BraTS20_Training_278_t1.nii.gz",
261
+ "BraTS20_Training_278/BraTS20_Training_278_t1ce.nii.gz",
262
+ "BraTS20_Training_278/BraTS20_Training_278_t2.nii.gz"
263
+ ],
264
+ "label": "BraTS20_Training_278/BraTS20_Training_278_seg.nii.gz",
265
+ "text_feature": "BraTS20_Training_278/BraTS20_Training_278_flair_text.npy"
266
+ },
267
+ {
268
+ "fold": 0,
269
+ "image": [
270
+ "BraTS20_Training_206/BraTS20_Training_206_flair.nii.gz",
271
+ "BraTS20_Training_206/BraTS20_Training_206_t1.nii.gz",
272
+ "BraTS20_Training_206/BraTS20_Training_206_t1ce.nii.gz",
273
+ "BraTS20_Training_206/BraTS20_Training_206_t2.nii.gz"
274
+ ],
275
+ "label": "BraTS20_Training_206/BraTS20_Training_206_seg.nii.gz",
276
+ "text_feature": "BraTS20_Training_206/BraTS20_Training_206_flair_text.npy"
277
+ },
278
+ {
279
+ "fold": 0,
280
+ "image": [
281
+ "BraTS20_Training_049/BraTS20_Training_049_flair.nii.gz",
282
+ "BraTS20_Training_049/BraTS20_Training_049_t1.nii.gz",
283
+ "BraTS20_Training_049/BraTS20_Training_049_t1ce.nii.gz",
284
+ "BraTS20_Training_049/BraTS20_Training_049_t2.nii.gz"
285
+ ],
286
+ "label": "BraTS20_Training_049/BraTS20_Training_049_seg.nii.gz",
287
+ "text_feature": "BraTS20_Training_049/BraTS20_Training_049_flair_text.npy"
288
+ },
289
+ {
290
+ "fold": 0,
291
+ "image": [
292
+ "BraTS20_Training_115/BraTS20_Training_115_flair.nii.gz",
293
+ "BraTS20_Training_115/BraTS20_Training_115_t1.nii.gz",
294
+ "BraTS20_Training_115/BraTS20_Training_115_t1ce.nii.gz",
295
+ "BraTS20_Training_115/BraTS20_Training_115_t2.nii.gz"
296
+ ],
297
+ "label": "BraTS20_Training_115/BraTS20_Training_115_seg.nii.gz",
298
+ "text_feature": "BraTS20_Training_115/BraTS20_Training_115_flair_text.npy"
299
+ },
300
+ {
301
+ "fold": 0,
302
+ "image": [
303
+ "BraTS20_Training_147/BraTS20_Training_147_flair.nii.gz",
304
+ "BraTS20_Training_147/BraTS20_Training_147_t1.nii.gz",
305
+ "BraTS20_Training_147/BraTS20_Training_147_t1ce.nii.gz",
306
+ "BraTS20_Training_147/BraTS20_Training_147_t2.nii.gz"
307
+ ],
308
+ "label": "BraTS20_Training_147/BraTS20_Training_147_seg.nii.gz",
309
+ "text_feature": "BraTS20_Training_147/BraTS20_Training_147_flair_text.npy"
310
+ },
311
+ {
312
+ "fold": 0,
313
+ "image": [
314
+ "BraTS20_Training_226/BraTS20_Training_226_flair.nii.gz",
315
+ "BraTS20_Training_226/BraTS20_Training_226_t1.nii.gz",
316
+ "BraTS20_Training_226/BraTS20_Training_226_t1ce.nii.gz",
317
+ "BraTS20_Training_226/BraTS20_Training_226_t2.nii.gz"
318
+ ],
319
+ "label": "BraTS20_Training_226/BraTS20_Training_226_seg.nii.gz",
320
+ "text_feature": "BraTS20_Training_226/BraTS20_Training_226_flair_text.npy"
321
+ },
322
+ {
323
+ "fold": 0,
324
+ "image": [
325
+ "BraTS20_Training_066/BraTS20_Training_066_flair.nii.gz",
326
+ "BraTS20_Training_066/BraTS20_Training_066_t1.nii.gz",
327
+ "BraTS20_Training_066/BraTS20_Training_066_t1ce.nii.gz",
328
+ "BraTS20_Training_066/BraTS20_Training_066_t2.nii.gz"
329
+ ],
330
+ "label": "BraTS20_Training_066/BraTS20_Training_066_seg.nii.gz",
331
+ "text_feature": "BraTS20_Training_066/BraTS20_Training_066_flair_text.npy"
332
+ },
333
+ {
334
+ "fold": 0,
335
+ "image": [
336
+ "BraTS20_Training_124/BraTS20_Training_124_flair.nii.gz",
337
+ "BraTS20_Training_124/BraTS20_Training_124_t1.nii.gz",
338
+ "BraTS20_Training_124/BraTS20_Training_124_t1ce.nii.gz",
339
+ "BraTS20_Training_124/BraTS20_Training_124_t2.nii.gz"
340
+ ],
341
+ "label": "BraTS20_Training_124/BraTS20_Training_124_seg.nii.gz",
342
+ "text_feature": "BraTS20_Training_124/BraTS20_Training_124_flair_text.npy"
343
+ },
344
+ {
345
+ "fold": 0,
346
+ "image": [
347
+ "BraTS20_Training_274/BraTS20_Training_274_flair.nii.gz",
348
+ "BraTS20_Training_274/BraTS20_Training_274_t1.nii.gz",
349
+ "BraTS20_Training_274/BraTS20_Training_274_t1ce.nii.gz",
350
+ "BraTS20_Training_274/BraTS20_Training_274_t2.nii.gz"
351
+ ],
352
+ "label": "BraTS20_Training_274/BraTS20_Training_274_seg.nii.gz",
353
+ "text_feature": "BraTS20_Training_274/BraTS20_Training_274_flair_text.npy"
354
+ },
355
+ {
356
+ "fold": 0,
357
+ "image": [
358
+ "BraTS20_Training_290/BraTS20_Training_290_flair.nii.gz",
359
+ "BraTS20_Training_290/BraTS20_Training_290_t1.nii.gz",
360
+ "BraTS20_Training_290/BraTS20_Training_290_t1ce.nii.gz",
361
+ "BraTS20_Training_290/BraTS20_Training_290_t2.nii.gz"
362
+ ],
363
+ "label": "BraTS20_Training_290/BraTS20_Training_290_seg.nii.gz",
364
+ "text_feature": "BraTS20_Training_290/BraTS20_Training_290_flair_text.npy"
365
+ },
366
+ {
367
+ "fold": 0,
368
+ "image": [
369
+ "BraTS20_Training_200/BraTS20_Training_200_flair.nii.gz",
370
+ "BraTS20_Training_200/BraTS20_Training_200_t1.nii.gz",
371
+ "BraTS20_Training_200/BraTS20_Training_200_t1ce.nii.gz",
372
+ "BraTS20_Training_200/BraTS20_Training_200_t2.nii.gz"
373
+ ],
374
+ "label": "BraTS20_Training_200/BraTS20_Training_200_seg.nii.gz",
375
+ "text_feature": "BraTS20_Training_200/BraTS20_Training_200_flair_text.npy"
376
+ },
377
+ {
378
+ "fold": 0,
379
+ "image": [
380
+ "BraTS20_Training_121/BraTS20_Training_121_flair.nii.gz",
381
+ "BraTS20_Training_121/BraTS20_Training_121_t1.nii.gz",
382
+ "BraTS20_Training_121/BraTS20_Training_121_t1ce.nii.gz",
383
+ "BraTS20_Training_121/BraTS20_Training_121_t2.nii.gz"
384
+ ],
385
+ "label": "BraTS20_Training_121/BraTS20_Training_121_seg.nii.gz",
386
+ "text_feature": "BraTS20_Training_121/BraTS20_Training_121_flair_text.npy"
387
+ },
388
+ {
389
+ "fold": 0,
390
+ "image": [
391
+ "BraTS20_Training_082/BraTS20_Training_082_flair.nii.gz",
392
+ "BraTS20_Training_082/BraTS20_Training_082_t1.nii.gz",
393
+ "BraTS20_Training_082/BraTS20_Training_082_t1ce.nii.gz",
394
+ "BraTS20_Training_082/BraTS20_Training_082_t2.nii.gz"
395
+ ],
396
+ "label": "BraTS20_Training_082/BraTS20_Training_082_seg.nii.gz",
397
+ "text_feature": "BraTS20_Training_082/BraTS20_Training_082_flair_text.npy"
398
+ },
399
+ {
400
+ "fold": 0,
401
+ "image": [
402
+ "BraTS20_Training_052/BraTS20_Training_052_flair.nii.gz",
403
+ "BraTS20_Training_052/BraTS20_Training_052_t1.nii.gz",
404
+ "BraTS20_Training_052/BraTS20_Training_052_t1ce.nii.gz",
405
+ "BraTS20_Training_052/BraTS20_Training_052_t2.nii.gz"
406
+ ],
407
+ "label": "BraTS20_Training_052/BraTS20_Training_052_seg.nii.gz",
408
+ "text_feature": "BraTS20_Training_052/BraTS20_Training_052_flair_text.npy"
409
+ },
410
+ {
411
+ "fold": 0,
412
+ "image": [
413
+ "BraTS20_Training_104/BraTS20_Training_104_flair.nii.gz",
414
+ "BraTS20_Training_104/BraTS20_Training_104_t1.nii.gz",
415
+ "BraTS20_Training_104/BraTS20_Training_104_t1ce.nii.gz",
416
+ "BraTS20_Training_104/BraTS20_Training_104_t2.nii.gz"
417
+ ],
418
+ "label": "BraTS20_Training_104/BraTS20_Training_104_seg.nii.gz",
419
+ "text_feature": "BraTS20_Training_104/BraTS20_Training_104_flair_text.npy"
420
+ },
421
+ {
422
+ "fold": 0,
423
+ "image": [
424
+ "BraTS20_Training_062/BraTS20_Training_062_flair.nii.gz",
425
+ "BraTS20_Training_062/BraTS20_Training_062_t1.nii.gz",
426
+ "BraTS20_Training_062/BraTS20_Training_062_t1ce.nii.gz",
427
+ "BraTS20_Training_062/BraTS20_Training_062_t2.nii.gz"
428
+ ],
429
+ "label": "BraTS20_Training_062/BraTS20_Training_062_seg.nii.gz",
430
+ "text_feature": "BraTS20_Training_062/BraTS20_Training_062_flair_text.npy"
431
+ },
432
+ {
433
+ "fold": 0,
434
+ "image": [
435
+ "BraTS20_Training_214/BraTS20_Training_214_flair.nii.gz",
436
+ "BraTS20_Training_214/BraTS20_Training_214_t1.nii.gz",
437
+ "BraTS20_Training_214/BraTS20_Training_214_t1ce.nii.gz",
438
+ "BraTS20_Training_214/BraTS20_Training_214_t2.nii.gz"
439
+ ],
440
+ "label": "BraTS20_Training_214/BraTS20_Training_214_seg.nii.gz",
441
+ "text_feature": "BraTS20_Training_214/BraTS20_Training_214_flair_text.npy"
442
+ },
443
+ {
444
+ "fold": 0,
445
+ "image": [
446
+ "BraTS20_Training_360/BraTS20_Training_360_flair.nii.gz",
447
+ "BraTS20_Training_360/BraTS20_Training_360_t1.nii.gz",
448
+ "BraTS20_Training_360/BraTS20_Training_360_t1ce.nii.gz",
449
+ "BraTS20_Training_360/BraTS20_Training_360_t2.nii.gz"
450
+ ],
451
+ "label": "BraTS20_Training_360/BraTS20_Training_360_seg.nii.gz",
452
+ "text_feature": "BraTS20_Training_360/BraTS20_Training_360_flair_text.npy"
453
+ },
454
+ {
455
+ "fold": 0,
456
+ "image": [
457
+ "BraTS20_Training_041/BraTS20_Training_041_flair.nii.gz",
458
+ "BraTS20_Training_041/BraTS20_Training_041_t1.nii.gz",
459
+ "BraTS20_Training_041/BraTS20_Training_041_t1ce.nii.gz",
460
+ "BraTS20_Training_041/BraTS20_Training_041_t2.nii.gz"
461
+ ],
462
+ "label": "BraTS20_Training_041/BraTS20_Training_041_seg.nii.gz",
463
+ "text_feature": "BraTS20_Training_041/BraTS20_Training_041_flair_text.npy"
464
+ },
465
+ {
466
+ "fold": 0,
467
+ "image": [
468
+ "BraTS20_Training_009/BraTS20_Training_009_flair.nii.gz",
469
+ "BraTS20_Training_009/BraTS20_Training_009_t1.nii.gz",
470
+ "BraTS20_Training_009/BraTS20_Training_009_t1ce.nii.gz",
471
+ "BraTS20_Training_009/BraTS20_Training_009_t2.nii.gz"
472
+ ],
473
+ "label": "BraTS20_Training_009/BraTS20_Training_009_seg.nii.gz",
474
+ "text_feature": "BraTS20_Training_009/BraTS20_Training_009_flair_text.npy"
475
+ },
476
+ {
477
+ "fold": 0,
478
+ "image": [
479
+ "BraTS20_Training_347/BraTS20_Training_347_flair.nii.gz",
480
+ "BraTS20_Training_347/BraTS20_Training_347_t1.nii.gz",
481
+ "BraTS20_Training_347/BraTS20_Training_347_t1ce.nii.gz",
482
+ "BraTS20_Training_347/BraTS20_Training_347_t2.nii.gz"
483
+ ],
484
+ "label": "BraTS20_Training_347/BraTS20_Training_347_seg.nii.gz",
485
+ "text_feature": "BraTS20_Training_347/BraTS20_Training_347_flair_text.npy"
486
+ },
487
+ {
488
+ "fold": 0,
489
+ "image": [
490
+ "BraTS20_Training_330/BraTS20_Training_330_flair.nii.gz",
491
+ "BraTS20_Training_330/BraTS20_Training_330_t1.nii.gz",
492
+ "BraTS20_Training_330/BraTS20_Training_330_t1ce.nii.gz",
493
+ "BraTS20_Training_330/BraTS20_Training_330_t2.nii.gz"
494
+ ],
495
+ "label": "BraTS20_Training_330/BraTS20_Training_330_seg.nii.gz",
496
+ "text_feature": "BraTS20_Training_330/BraTS20_Training_330_flair_text.npy"
497
+ },
498
+ {
499
+ "fold": 0,
500
+ "image": [
501
+ "BraTS20_Training_122/BraTS20_Training_122_flair.nii.gz",
502
+ "BraTS20_Training_122/BraTS20_Training_122_t1.nii.gz",
503
+ "BraTS20_Training_122/BraTS20_Training_122_t1ce.nii.gz",
504
+ "BraTS20_Training_122/BraTS20_Training_122_t2.nii.gz"
505
+ ],
506
+ "label": "BraTS20_Training_122/BraTS20_Training_122_seg.nii.gz",
507
+ "text_feature": "BraTS20_Training_122/BraTS20_Training_122_flair_text.npy"
508
+ },
509
+ {
510
+ "fold": 0,
511
+ "image": [
512
+ "BraTS20_Training_340/BraTS20_Training_340_flair.nii.gz",
513
+ "BraTS20_Training_340/BraTS20_Training_340_t1.nii.gz",
514
+ "BraTS20_Training_340/BraTS20_Training_340_t1ce.nii.gz",
515
+ "BraTS20_Training_340/BraTS20_Training_340_t2.nii.gz"
516
+ ],
517
+ "label": "BraTS20_Training_340/BraTS20_Training_340_seg.nii.gz",
518
+ "text_feature": "BraTS20_Training_340/BraTS20_Training_340_flair_text.npy"
519
+ },
520
+ {
521
+ "fold": 0,
522
+ "image": [
523
+ "BraTS20_Training_028/BraTS20_Training_028_flair.nii.gz",
524
+ "BraTS20_Training_028/BraTS20_Training_028_t1.nii.gz",
525
+ "BraTS20_Training_028/BraTS20_Training_028_t1ce.nii.gz",
526
+ "BraTS20_Training_028/BraTS20_Training_028_t2.nii.gz"
527
+ ],
528
+ "label": "BraTS20_Training_028/BraTS20_Training_028_seg.nii.gz",
529
+ "text_feature": "BraTS20_Training_028/BraTS20_Training_028_flair_text.npy"
530
+ },
531
+ {
532
+ "fold": 0,
533
+ "image": [
534
+ "BraTS20_Training_265/BraTS20_Training_265_flair.nii.gz",
535
+ "BraTS20_Training_265/BraTS20_Training_265_t1.nii.gz",
536
+ "BraTS20_Training_265/BraTS20_Training_265_t1ce.nii.gz",
537
+ "BraTS20_Training_265/BraTS20_Training_265_t2.nii.gz"
538
+ ],
539
+ "label": "BraTS20_Training_265/BraTS20_Training_265_seg.nii.gz",
540
+ "text_feature": "BraTS20_Training_265/BraTS20_Training_265_flair_text.npy"
541
+ },
542
+ {
543
+ "fold": 0,
544
+ "image": [
545
+ "BraTS20_Training_192/BraTS20_Training_192_flair.nii.gz",
546
+ "BraTS20_Training_192/BraTS20_Training_192_t1.nii.gz",
547
+ "BraTS20_Training_192/BraTS20_Training_192_t1ce.nii.gz",
548
+ "BraTS20_Training_192/BraTS20_Training_192_t2.nii.gz"
549
+ ],
550
+ "label": "BraTS20_Training_192/BraTS20_Training_192_seg.nii.gz",
551
+ "text_feature": "BraTS20_Training_192/BraTS20_Training_192_flair_text.npy"
552
+ },
553
+ {
554
+ "fold": 0,
555
+ "image": [
556
+ "BraTS20_Training_255/BraTS20_Training_255_flair.nii.gz",
557
+ "BraTS20_Training_255/BraTS20_Training_255_t1.nii.gz",
558
+ "BraTS20_Training_255/BraTS20_Training_255_t1ce.nii.gz",
559
+ "BraTS20_Training_255/BraTS20_Training_255_t2.nii.gz"
560
+ ],
561
+ "label": "BraTS20_Training_255/BraTS20_Training_255_seg.nii.gz",
562
+ "text_feature": "BraTS20_Training_255/BraTS20_Training_255_flair_text.npy"
563
+ },
564
+ {
565
+ "fold": 0,
566
+ "image": [
567
+ "BraTS20_Training_137/BraTS20_Training_137_flair.nii.gz",
568
+ "BraTS20_Training_137/BraTS20_Training_137_t1.nii.gz",
569
+ "BraTS20_Training_137/BraTS20_Training_137_t1ce.nii.gz",
570
+ "BraTS20_Training_137/BraTS20_Training_137_t2.nii.gz"
571
+ ],
572
+ "label": "BraTS20_Training_137/BraTS20_Training_137_seg.nii.gz",
573
+ "text_feature": "BraTS20_Training_137/BraTS20_Training_137_flair_text.npy"
574
+ },
575
+ {
576
+ "fold": 0,
577
+ "image": [
578
+ "BraTS20_Training_001/BraTS20_Training_001_flair.nii.gz",
579
+ "BraTS20_Training_001/BraTS20_Training_001_t1.nii.gz",
580
+ "BraTS20_Training_001/BraTS20_Training_001_t1ce.nii.gz",
581
+ "BraTS20_Training_001/BraTS20_Training_001_t2.nii.gz"
582
+ ],
583
+ "label": "BraTS20_Training_001/BraTS20_Training_001_seg.nii.gz",
584
+ "text_feature": "BraTS20_Training_001/BraTS20_Training_001_flair_text.npy"
585
+ },
586
+ {
587
+ "fold": 0,
588
+ "image": [
589
+ "BraTS20_Training_182/BraTS20_Training_182_flair.nii.gz",
590
+ "BraTS20_Training_182/BraTS20_Training_182_t1.nii.gz",
591
+ "BraTS20_Training_182/BraTS20_Training_182_t1ce.nii.gz",
592
+ "BraTS20_Training_182/BraTS20_Training_182_t2.nii.gz"
593
+ ],
594
+ "label": "BraTS20_Training_182/BraTS20_Training_182_seg.nii.gz",
595
+ "text_feature": "BraTS20_Training_182/BraTS20_Training_182_flair_text.npy"
596
+ },
597
+ {
598
+ "fold": 0,
599
+ "image": [
600
+ "BraTS20_Training_235/BraTS20_Training_235_flair.nii.gz",
601
+ "BraTS20_Training_235/BraTS20_Training_235_t1.nii.gz",
602
+ "BraTS20_Training_235/BraTS20_Training_235_t1ce.nii.gz",
603
+ "BraTS20_Training_235/BraTS20_Training_235_t2.nii.gz"
604
+ ],
605
+ "label": "BraTS20_Training_235/BraTS20_Training_235_seg.nii.gz",
606
+ "text_feature": "BraTS20_Training_235/BraTS20_Training_235_flair_text.npy"
607
+ },
608
+ {
609
+ "fold": 0,
610
+ "image": [
611
+ "BraTS20_Training_299/BraTS20_Training_299_flair.nii.gz",
612
+ "BraTS20_Training_299/BraTS20_Training_299_t1.nii.gz",
613
+ "BraTS20_Training_299/BraTS20_Training_299_t1ce.nii.gz",
614
+ "BraTS20_Training_299/BraTS20_Training_299_t2.nii.gz"
615
+ ],
616
+ "label": "BraTS20_Training_299/BraTS20_Training_299_seg.nii.gz",
617
+ "text_feature": "BraTS20_Training_299/BraTS20_Training_299_flair_text.npy"
618
+ },
619
+ {
620
+ "fold": 0,
621
+ "image": [
622
+ "BraTS20_Training_019/BraTS20_Training_019_flair.nii.gz",
623
+ "BraTS20_Training_019/BraTS20_Training_019_t1.nii.gz",
624
+ "BraTS20_Training_019/BraTS20_Training_019_t1ce.nii.gz",
625
+ "BraTS20_Training_019/BraTS20_Training_019_t2.nii.gz"
626
+ ],
627
+ "label": "BraTS20_Training_019/BraTS20_Training_019_seg.nii.gz",
628
+ "text_feature": "BraTS20_Training_019/BraTS20_Training_019_flair_text.npy"
629
+ },
630
+ {
631
+ "fold": 0,
632
+ "image": [
633
+ "BraTS20_Training_061/BraTS20_Training_061_flair.nii.gz",
634
+ "BraTS20_Training_061/BraTS20_Training_061_t1.nii.gz",
635
+ "BraTS20_Training_061/BraTS20_Training_061_t1ce.nii.gz",
636
+ "BraTS20_Training_061/BraTS20_Training_061_t2.nii.gz"
637
+ ],
638
+ "label": "BraTS20_Training_061/BraTS20_Training_061_seg.nii.gz",
639
+ "text_feature": "BraTS20_Training_061/BraTS20_Training_061_flair_text.npy"
640
+ },
641
+ {
642
+ "fold": 0,
643
+ "image": [
644
+ "BraTS20_Training_250/BraTS20_Training_250_flair.nii.gz",
645
+ "BraTS20_Training_250/BraTS20_Training_250_t1.nii.gz",
646
+ "BraTS20_Training_250/BraTS20_Training_250_t1ce.nii.gz",
647
+ "BraTS20_Training_250/BraTS20_Training_250_t2.nii.gz"
648
+ ],
649
+ "label": "BraTS20_Training_250/BraTS20_Training_250_seg.nii.gz",
650
+ "text_feature": "BraTS20_Training_250/BraTS20_Training_250_flair_text.npy"
651
+ },
652
+ {
653
+ "fold": 0,
654
+ "image": [
655
+ "BraTS20_Training_249/BraTS20_Training_249_flair.nii.gz",
656
+ "BraTS20_Training_249/BraTS20_Training_249_t1.nii.gz",
657
+ "BraTS20_Training_249/BraTS20_Training_249_t1ce.nii.gz",
658
+ "BraTS20_Training_249/BraTS20_Training_249_t2.nii.gz"
659
+ ],
660
+ "label": "BraTS20_Training_249/BraTS20_Training_249_seg.nii.gz",
661
+ "text_feature": "BraTS20_Training_249/BraTS20_Training_249_flair_text.npy"
662
+ },
663
+ {
664
+ "fold": 0,
665
+ "image": [
666
+ "BraTS20_Training_168/BraTS20_Training_168_flair.nii.gz",
667
+ "BraTS20_Training_168/BraTS20_Training_168_t1.nii.gz",
668
+ "BraTS20_Training_168/BraTS20_Training_168_t1ce.nii.gz",
669
+ "BraTS20_Training_168/BraTS20_Training_168_t2.nii.gz"
670
+ ],
671
+ "label": "BraTS20_Training_168/BraTS20_Training_168_seg.nii.gz",
672
+ "text_feature": "BraTS20_Training_168/BraTS20_Training_168_flair_text.npy"
673
+ },
674
+ {
675
+ "fold": 0,
676
+ "image": [
677
+ "BraTS20_Training_313/BraTS20_Training_313_flair.nii.gz",
678
+ "BraTS20_Training_313/BraTS20_Training_313_t1.nii.gz",
679
+ "BraTS20_Training_313/BraTS20_Training_313_t1ce.nii.gz",
680
+ "BraTS20_Training_313/BraTS20_Training_313_t2.nii.gz"
681
+ ],
682
+ "label": "BraTS20_Training_313/BraTS20_Training_313_seg.nii.gz",
683
+ "text_feature": "BraTS20_Training_313/BraTS20_Training_313_flair_text.npy"
684
+ },
685
+ {
686
+ "fold": 0,
687
+ "image": [
688
+ "BraTS20_Training_248/BraTS20_Training_248_flair.nii.gz",
689
+ "BraTS20_Training_248/BraTS20_Training_248_t1.nii.gz",
690
+ "BraTS20_Training_248/BraTS20_Training_248_t1ce.nii.gz",
691
+ "BraTS20_Training_248/BraTS20_Training_248_t2.nii.gz"
692
+ ],
693
+ "label": "BraTS20_Training_248/BraTS20_Training_248_seg.nii.gz",
694
+ "text_feature": "BraTS20_Training_248/BraTS20_Training_248_flair_text.npy"
695
+ },
696
+ {
697
+ "fold": 0,
698
+ "image": [
699
+ "BraTS20_Training_280/BraTS20_Training_280_flair.nii.gz",
700
+ "BraTS20_Training_280/BraTS20_Training_280_t1.nii.gz",
701
+ "BraTS20_Training_280/BraTS20_Training_280_t1ce.nii.gz",
702
+ "BraTS20_Training_280/BraTS20_Training_280_t2.nii.gz"
703
+ ],
704
+ "label": "BraTS20_Training_280/BraTS20_Training_280_seg.nii.gz",
705
+ "text_feature": "BraTS20_Training_280/BraTS20_Training_280_flair_text.npy"
706
+ },
707
+ {
708
+ "fold": 0,
709
+ "image": [
710
+ "BraTS20_Training_156/BraTS20_Training_156_flair.nii.gz",
711
+ "BraTS20_Training_156/BraTS20_Training_156_t1.nii.gz",
712
+ "BraTS20_Training_156/BraTS20_Training_156_t1ce.nii.gz",
713
+ "BraTS20_Training_156/BraTS20_Training_156_t2.nii.gz"
714
+ ],
715
+ "label": "BraTS20_Training_156/BraTS20_Training_156_seg.nii.gz",
716
+ "text_feature": "BraTS20_Training_156/BraTS20_Training_156_flair_text.npy"
717
+ },
718
+ {
719
+ "fold": 0,
720
+ "image": [
721
+ "BraTS20_Training_275/BraTS20_Training_275_flair.nii.gz",
722
+ "BraTS20_Training_275/BraTS20_Training_275_t1.nii.gz",
723
+ "BraTS20_Training_275/BraTS20_Training_275_t1ce.nii.gz",
724
+ "BraTS20_Training_275/BraTS20_Training_275_t2.nii.gz"
725
+ ],
726
+ "label": "BraTS20_Training_275/BraTS20_Training_275_seg.nii.gz",
727
+ "text_feature": "BraTS20_Training_275/BraTS20_Training_275_flair_text.npy"
728
+ },
729
+ {
730
+ "fold": 0,
731
+ "image": [
732
+ "BraTS20_Training_076/BraTS20_Training_076_flair.nii.gz",
733
+ "BraTS20_Training_076/BraTS20_Training_076_t1.nii.gz",
734
+ "BraTS20_Training_076/BraTS20_Training_076_t1ce.nii.gz",
735
+ "BraTS20_Training_076/BraTS20_Training_076_t2.nii.gz"
736
+ ],
737
+ "label": "BraTS20_Training_076/BraTS20_Training_076_seg.nii.gz",
738
+ "text_feature": "BraTS20_Training_076/BraTS20_Training_076_flair_text.npy"
739
+ },
740
+ {
741
+ "fold": 0,
742
+ "image": [
743
+ "BraTS20_Training_327/BraTS20_Training_327_flair.nii.gz",
744
+ "BraTS20_Training_327/BraTS20_Training_327_t1.nii.gz",
745
+ "BraTS20_Training_327/BraTS20_Training_327_t1ce.nii.gz",
746
+ "BraTS20_Training_327/BraTS20_Training_327_t2.nii.gz"
747
+ ],
748
+ "label": "BraTS20_Training_327/BraTS20_Training_327_seg.nii.gz",
749
+ "text_feature": "BraTS20_Training_327/BraTS20_Training_327_flair_text.npy"
750
+ },
751
+ {
752
+ "fold": 0,
753
+ "image": [
754
+ "BraTS20_Training_059/BraTS20_Training_059_flair.nii.gz",
755
+ "BraTS20_Training_059/BraTS20_Training_059_t1.nii.gz",
756
+ "BraTS20_Training_059/BraTS20_Training_059_t1ce.nii.gz",
757
+ "BraTS20_Training_059/BraTS20_Training_059_t2.nii.gz"
758
+ ],
759
+ "label": "BraTS20_Training_059/BraTS20_Training_059_seg.nii.gz",
760
+ "text_feature": "BraTS20_Training_059/BraTS20_Training_059_flair_text.npy"
761
+ },
762
+ {
763
+ "fold": 0,
764
+ "image": [
765
+ "BraTS20_Training_199/BraTS20_Training_199_flair.nii.gz",
766
+ "BraTS20_Training_199/BraTS20_Training_199_t1.nii.gz",
767
+ "BraTS20_Training_199/BraTS20_Training_199_t1ce.nii.gz",
768
+ "BraTS20_Training_199/BraTS20_Training_199_t2.nii.gz"
769
+ ],
770
+ "label": "BraTS20_Training_199/BraTS20_Training_199_seg.nii.gz",
771
+ "text_feature": "BraTS20_Training_199/BraTS20_Training_199_flair_text.npy"
772
+ },
773
+ {
774
+ "fold": 0,
775
+ "image": [
776
+ "BraTS20_Training_044/BraTS20_Training_044_flair.nii.gz",
777
+ "BraTS20_Training_044/BraTS20_Training_044_t1.nii.gz",
778
+ "BraTS20_Training_044/BraTS20_Training_044_t1ce.nii.gz",
779
+ "BraTS20_Training_044/BraTS20_Training_044_t2.nii.gz"
780
+ ],
781
+ "label": "BraTS20_Training_044/BraTS20_Training_044_seg.nii.gz",
782
+ "text_feature": "BraTS20_Training_044/BraTS20_Training_044_flair_text.npy"
783
+ },
784
+ {
785
+ "fold": 0,
786
+ "image": [
787
+ "BraTS20_Training_320/BraTS20_Training_320_flair.nii.gz",
788
+ "BraTS20_Training_320/BraTS20_Training_320_t1.nii.gz",
789
+ "BraTS20_Training_320/BraTS20_Training_320_t1ce.nii.gz",
790
+ "BraTS20_Training_320/BraTS20_Training_320_t2.nii.gz"
791
+ ],
792
+ "label": "BraTS20_Training_320/BraTS20_Training_320_seg.nii.gz",
793
+ "text_feature": "BraTS20_Training_320/BraTS20_Training_320_flair_text.npy"
794
+ },
795
+ {
796
+ "fold": 0,
797
+ "image": [
798
+ "BraTS20_Training_093/BraTS20_Training_093_flair.nii.gz",
799
+ "BraTS20_Training_093/BraTS20_Training_093_t1.nii.gz",
800
+ "BraTS20_Training_093/BraTS20_Training_093_t1ce.nii.gz",
801
+ "BraTS20_Training_093/BraTS20_Training_093_t2.nii.gz"
802
+ ],
803
+ "label": "BraTS20_Training_093/BraTS20_Training_093_seg.nii.gz",
804
+ "text_feature": "BraTS20_Training_093/BraTS20_Training_093_flair_text.npy"
805
+ },
806
+ {
807
+ "fold": 0,
808
+ "image": [
809
+ "BraTS20_Training_224/BraTS20_Training_224_flair.nii.gz",
810
+ "BraTS20_Training_224/BraTS20_Training_224_t1.nii.gz",
811
+ "BraTS20_Training_224/BraTS20_Training_224_t1ce.nii.gz",
812
+ "BraTS20_Training_224/BraTS20_Training_224_t2.nii.gz"
813
+ ],
814
+ "label": "BraTS20_Training_224/BraTS20_Training_224_seg.nii.gz",
815
+ "text_feature": "BraTS20_Training_224/BraTS20_Training_224_flair_text.npy"
816
+ },
817
+ {
818
+ "fold": 0,
819
+ "image": [
820
+ "BraTS20_Training_225/BraTS20_Training_225_flair.nii.gz",
821
+ "BraTS20_Training_225/BraTS20_Training_225_t1.nii.gz",
822
+ "BraTS20_Training_225/BraTS20_Training_225_t1ce.nii.gz",
823
+ "BraTS20_Training_225/BraTS20_Training_225_t2.nii.gz"
824
+ ],
825
+ "label": "BraTS20_Training_225/BraTS20_Training_225_seg.nii.gz",
826
+ "text_feature": "BraTS20_Training_225/BraTS20_Training_225_flair_text.npy"
827
+ },
828
+ {
829
+ "fold": 0,
830
+ "image": [
831
+ "BraTS20_Training_218/BraTS20_Training_218_flair.nii.gz",
832
+ "BraTS20_Training_218/BraTS20_Training_218_t1.nii.gz",
833
+ "BraTS20_Training_218/BraTS20_Training_218_t1ce.nii.gz",
834
+ "BraTS20_Training_218/BraTS20_Training_218_t2.nii.gz"
835
+ ],
836
+ "label": "BraTS20_Training_218/BraTS20_Training_218_seg.nii.gz",
837
+ "text_feature": "BraTS20_Training_218/BraTS20_Training_218_flair_text.npy"
838
+ },
839
+ {
840
+ "fold": 0,
841
+ "image": [
842
+ "BraTS20_Training_014/BraTS20_Training_014_flair.nii.gz",
843
+ "BraTS20_Training_014/BraTS20_Training_014_t1.nii.gz",
844
+ "BraTS20_Training_014/BraTS20_Training_014_t1ce.nii.gz",
845
+ "BraTS20_Training_014/BraTS20_Training_014_t2.nii.gz"
846
+ ],
847
+ "label": "BraTS20_Training_014/BraTS20_Training_014_seg.nii.gz",
848
+ "text_feature": "BraTS20_Training_014/BraTS20_Training_014_flair_text.npy"
849
+ },
850
+ {
851
+ "fold": 0,
852
+ "image": [
853
+ "BraTS20_Training_264/BraTS20_Training_264_flair.nii.gz",
854
+ "BraTS20_Training_264/BraTS20_Training_264_t1.nii.gz",
855
+ "BraTS20_Training_264/BraTS20_Training_264_t1ce.nii.gz",
856
+ "BraTS20_Training_264/BraTS20_Training_264_t2.nii.gz"
857
+ ],
858
+ "label": "BraTS20_Training_264/BraTS20_Training_264_seg.nii.gz",
859
+ "text_feature": "BraTS20_Training_264/BraTS20_Training_264_flair_text.npy"
860
+ },
861
+ {
862
+ "fold": 0,
863
+ "image": [
864
+ "BraTS20_Training_071/BraTS20_Training_071_flair.nii.gz",
865
+ "BraTS20_Training_071/BraTS20_Training_071_t1.nii.gz",
866
+ "BraTS20_Training_071/BraTS20_Training_071_t1ce.nii.gz",
867
+ "BraTS20_Training_071/BraTS20_Training_071_t2.nii.gz"
868
+ ],
869
+ "label": "BraTS20_Training_071/BraTS20_Training_071_seg.nii.gz",
870
+ "text_feature": "BraTS20_Training_071/BraTS20_Training_071_flair_text.npy"
871
+ },
872
+ {
873
+ "fold": 0,
874
+ "image": [
875
+ "BraTS20_Training_167/BraTS20_Training_167_flair.nii.gz",
876
+ "BraTS20_Training_167/BraTS20_Training_167_t1.nii.gz",
877
+ "BraTS20_Training_167/BraTS20_Training_167_t1ce.nii.gz",
878
+ "BraTS20_Training_167/BraTS20_Training_167_t2.nii.gz"
879
+ ],
880
+ "label": "BraTS20_Training_167/BraTS20_Training_167_seg.nii.gz",
881
+ "text_feature": "BraTS20_Training_167/BraTS20_Training_167_flair_text.npy"
882
+ },
883
+ {
884
+ "fold": 0,
885
+ "image": [
886
+ "BraTS20_Training_087/BraTS20_Training_087_flair.nii.gz",
887
+ "BraTS20_Training_087/BraTS20_Training_087_t1.nii.gz",
888
+ "BraTS20_Training_087/BraTS20_Training_087_t1ce.nii.gz",
889
+ "BraTS20_Training_087/BraTS20_Training_087_t2.nii.gz"
890
+ ],
891
+ "label": "BraTS20_Training_087/BraTS20_Training_087_seg.nii.gz",
892
+ "text_feature": "BraTS20_Training_087/BraTS20_Training_087_flair_text.npy"
893
+ },
894
+ {
895
+ "fold": 0,
896
+ "image": [
897
+ "BraTS20_Training_004/BraTS20_Training_004_flair.nii.gz",
898
+ "BraTS20_Training_004/BraTS20_Training_004_t1.nii.gz",
899
+ "BraTS20_Training_004/BraTS20_Training_004_t1ce.nii.gz",
900
+ "BraTS20_Training_004/BraTS20_Training_004_t2.nii.gz"
901
+ ],
902
+ "label": "BraTS20_Training_004/BraTS20_Training_004_seg.nii.gz",
903
+ "text_feature": "BraTS20_Training_004/BraTS20_Training_004_flair_text.npy"
904
+ },
905
+ {
906
+ "fold": 0,
907
+ "image": [
908
+ "BraTS20_Training_133/BraTS20_Training_133_flair.nii.gz",
909
+ "BraTS20_Training_133/BraTS20_Training_133_t1.nii.gz",
910
+ "BraTS20_Training_133/BraTS20_Training_133_t1ce.nii.gz",
911
+ "BraTS20_Training_133/BraTS20_Training_133_t2.nii.gz"
912
+ ],
913
+ "label": "BraTS20_Training_133/BraTS20_Training_133_seg.nii.gz",
914
+ "text_feature": "BraTS20_Training_133/BraTS20_Training_133_flair_text.npy"
915
+ },
916
+ {
917
+ "fold": 0,
918
+ "image": [
919
+ "BraTS20_Training_072/BraTS20_Training_072_flair.nii.gz",
920
+ "BraTS20_Training_072/BraTS20_Training_072_t1.nii.gz",
921
+ "BraTS20_Training_072/BraTS20_Training_072_t1ce.nii.gz",
922
+ "BraTS20_Training_072/BraTS20_Training_072_t2.nii.gz"
923
+ ],
924
+ "label": "BraTS20_Training_072/BraTS20_Training_072_seg.nii.gz",
925
+ "text_feature": "BraTS20_Training_072/BraTS20_Training_072_flair_text.npy"
926
+ },
927
+ {
928
+ "fold": 0,
929
+ "image": [
930
+ "BraTS20_Training_078/BraTS20_Training_078_flair.nii.gz",
931
+ "BraTS20_Training_078/BraTS20_Training_078_t1.nii.gz",
932
+ "BraTS20_Training_078/BraTS20_Training_078_t1ce.nii.gz",
933
+ "BraTS20_Training_078/BraTS20_Training_078_t2.nii.gz"
934
+ ],
935
+ "label": "BraTS20_Training_078/BraTS20_Training_078_seg.nii.gz",
936
+ "text_feature": "BraTS20_Training_078/BraTS20_Training_078_flair_text.npy"
937
+ },
938
+ {
939
+ "fold": 0,
940
+ "image": [
941
+ "BraTS20_Training_119/BraTS20_Training_119_flair.nii.gz",
942
+ "BraTS20_Training_119/BraTS20_Training_119_t1.nii.gz",
943
+ "BraTS20_Training_119/BraTS20_Training_119_t1ce.nii.gz",
944
+ "BraTS20_Training_119/BraTS20_Training_119_t2.nii.gz"
945
+ ],
946
+ "label": "BraTS20_Training_119/BraTS20_Training_119_seg.nii.gz",
947
+ "text_feature": "BraTS20_Training_119/BraTS20_Training_119_flair_text.npy"
948
+ },
949
+ {
950
+ "fold": 0,
951
+ "image": [
952
+ "BraTS20_Training_344/BraTS20_Training_344_flair.nii.gz",
953
+ "BraTS20_Training_344/BraTS20_Training_344_t1.nii.gz",
954
+ "BraTS20_Training_344/BraTS20_Training_344_t1ce.nii.gz",
955
+ "BraTS20_Training_344/BraTS20_Training_344_t2.nii.gz"
956
+ ],
957
+ "label": "BraTS20_Training_344/BraTS20_Training_344_seg.nii.gz",
958
+ "text_feature": "BraTS20_Training_344/BraTS20_Training_344_flair_text.npy"
959
+ },
960
+ {
961
+ "fold": 0,
962
+ "image": [
963
+ "BraTS20_Training_171/BraTS20_Training_171_flair.nii.gz",
964
+ "BraTS20_Training_171/BraTS20_Training_171_t1.nii.gz",
965
+ "BraTS20_Training_171/BraTS20_Training_171_t1ce.nii.gz",
966
+ "BraTS20_Training_171/BraTS20_Training_171_t2.nii.gz"
967
+ ],
968
+ "label": "BraTS20_Training_171/BraTS20_Training_171_seg.nii.gz",
969
+ "text_feature": "BraTS20_Training_171/BraTS20_Training_171_flair_text.npy"
970
+ },
971
+ {
972
+ "fold": 0,
973
+ "image": [
974
+ "BraTS20_Training_297/BraTS20_Training_297_flair.nii.gz",
975
+ "BraTS20_Training_297/BraTS20_Training_297_t1.nii.gz",
976
+ "BraTS20_Training_297/BraTS20_Training_297_t1ce.nii.gz",
977
+ "BraTS20_Training_297/BraTS20_Training_297_t2.nii.gz"
978
+ ],
979
+ "label": "BraTS20_Training_297/BraTS20_Training_297_seg.nii.gz",
980
+ "text_feature": "BraTS20_Training_297/BraTS20_Training_297_flair_text.npy"
981
+ },
982
+ {
983
+ "fold": 0,
984
+ "image": [
985
+ "BraTS20_Training_021/BraTS20_Training_021_flair.nii.gz",
986
+ "BraTS20_Training_021/BraTS20_Training_021_t1.nii.gz",
987
+ "BraTS20_Training_021/BraTS20_Training_021_t1ce.nii.gz",
988
+ "BraTS20_Training_021/BraTS20_Training_021_t2.nii.gz"
989
+ ],
990
+ "label": "BraTS20_Training_021/BraTS20_Training_021_seg.nii.gz",
991
+ "text_feature": "BraTS20_Training_021/BraTS20_Training_021_flair_text.npy"
992
+ },
993
+ {
994
+ "fold": 0,
995
+ "image": [
996
+ "BraTS20_Training_359/BraTS20_Training_359_flair.nii.gz",
997
+ "BraTS20_Training_359/BraTS20_Training_359_t1.nii.gz",
998
+ "BraTS20_Training_359/BraTS20_Training_359_t1ce.nii.gz",
999
+ "BraTS20_Training_359/BraTS20_Training_359_t2.nii.gz"
1000
+ ],
1001
+ "label": "BraTS20_Training_359/BraTS20_Training_359_seg.nii.gz",
1002
+ "text_feature": "BraTS20_Training_359/BraTS20_Training_359_flair_text.npy"
1003
+ },
1004
+ {
1005
+ "fold": 0,
1006
+ "image": [
1007
+ "BraTS20_Training_328/BraTS20_Training_328_flair.nii.gz",
1008
+ "BraTS20_Training_328/BraTS20_Training_328_t1.nii.gz",
1009
+ "BraTS20_Training_328/BraTS20_Training_328_t1ce.nii.gz",
1010
+ "BraTS20_Training_328/BraTS20_Training_328_t2.nii.gz"
1011
+ ],
1012
+ "label": "BraTS20_Training_328/BraTS20_Training_328_seg.nii.gz",
1013
+ "text_feature": "BraTS20_Training_328/BraTS20_Training_328_flair_text.npy"
1014
+ },
1015
+ {
1016
+ "fold": 0,
1017
+ "image": [
1018
+ "BraTS20_Training_233/BraTS20_Training_233_flair.nii.gz",
1019
+ "BraTS20_Training_233/BraTS20_Training_233_t1.nii.gz",
1020
+ "BraTS20_Training_233/BraTS20_Training_233_t1ce.nii.gz",
1021
+ "BraTS20_Training_233/BraTS20_Training_233_t2.nii.gz"
1022
+ ],
1023
+ "label": "BraTS20_Training_233/BraTS20_Training_233_seg.nii.gz",
1024
+ "text_feature": "BraTS20_Training_233/BraTS20_Training_233_flair_text.npy"
1025
+ }
1026
+ ]
1027
+ }
Train.json ADDED
The diff for this file is too large to render. See raw diff
 
assets/datasample.PNG ADDED

Git LFS Details

  • SHA256: d4596574c72d3c7113f1b66cb98afd0493e9c0890f84b5ba4bb6c1a677099e49
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
assets/overview.PNG ADDED
environment.yml ADDED
@@ -0,0 +1,23 @@
name: TextBraTS
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - pytorch=2.5.0
  - torchvision
  - torchaudio
  - pytorch-cuda=12.1
  - cudnn=8.9.7
  - numpy=1.26.4
  - sympy=1.13.1
  - fsspec=2025.2
  - tensorboardX=2.6.2.2
  - pip
  - pip:
      - monai
      - nibabel
      - einops
      - scipy
main.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ import warnings
12
+ warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")
13
+
14
+ import argparse
15
+ import os
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.multiprocessing as mp
20
+ import torch.nn.parallel
21
+ import torch.utils.data.distributed
22
+ from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
23
+ from trainer import run_training
24
+ from utils.data_utils import get_loader
25
+ from monai.losses import DiceLoss
26
+ from monai.metrics import DiceMetric
27
+ from utils.textswin_unetr import TextSwinUNETR
28
+ from monai.transforms import Activations, AsDiscrete, Compose
29
+ from monai.utils.enums import MetricReduction
30
+ import random
31
+
32
+
33
+ parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline for TextBRATS image-text dataset")
34
+ parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
35
+ parser.add_argument("--logdir", default="TextBraTS", type=str, help="directory to save the tensorboard logs")
36
+ parser.add_argument("--fold", default=0, type=int, help="data fold, 0 for validation and 1 for training")
37
+ parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
38
+ parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
39
+ parser.add_argument("--json_list", default="./Train.json", type=str, help="dataset json file")
40
+ parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
41
+ parser.add_argument("--max_epochs", default=200, type=int, help="max number of training epochs")
42
+ parser.add_argument("--batch_size", default=2, type=int, help="number of batch size")
43
+ parser.add_argument("--sw_batch_size", default=4, type=int, help="number of sliding window batch size")
44
+ parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
45
+ parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
46
+ parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
47
+ parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
48
+ parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
49
+ parser.add_argument("--val_every", default=1, type=int, help="validation frequency")
50
+ parser.add_argument("--distributed", action="store_true", help="start distributed training")
51
+ parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
52
+ parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
53
+ parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
54
+ parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
55
+ parser.add_argument("--norm_name", default="instance", type=str, help="normalization name")
56
+ parser.add_argument("--workers", default=8, type=int, help="number of workers")
57
+ parser.add_argument("--feature_size", default=48, type=int, help="feature size")
58
+ parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
59
+ parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
60
+ parser.add_argument("--cache_dataset", action="store_true", help="use monai Dataset class")
61
+ parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
62
+ parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
63
+ parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
64
+ parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
65
+ parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
66
+ parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
67
+ parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
68
+ parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
69
+ parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
70
+ parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
71
+ parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
72
+ parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
73
+ parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
74
+ parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
75
+ parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
76
+ parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
77
+ parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs")
78
+ parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint")
79
+ parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
80
+ parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
81
+ parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
82
+ parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
83
+ parser.add_argument("--use_ssl_pretrained", action="store_true", help="use SSL pretrained ckpt")
84
+ parser.add_argument(
85
+ "--pretrained_dir",
86
+ default="./runs/TextBraTS/",
87
+ type=str,
88
+ help="pretrained checkpoint directory",
89
+ )
90
+ parser.add_argument("--squared_dice", action="store_true", help="use squared Dice")
91
+ parser.add_argument("--seed", type=int, default=23,help="use random seed")
92
+
93
+
94
+ def main():
95
+ args = parser.parse_args()
96
+ args.amp = not args.noamp
97
+ args.logdir = "./runs/" + args.logdir
98
+ random.seed(args.seed)
99
+ np.random.seed(args.seed)
100
+ torch.manual_seed(args.seed)
101
+ if args.distributed:
102
+ torch.cuda.manual_seed_all(args.seed)
103
+ args.ngpus_per_node = torch.cuda.device_count()
104
+ print("Found total gpus", args.ngpus_per_node)
105
+ args.world_size = args.ngpus_per_node * args.world_size
106
+ mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
107
+ else:
108
+ torch.cuda.manual_seed(args.seed)
109
+ main_worker(gpu=0, args=args)
110
+
111
+
112
+ def main_worker(gpu, args):
113
+ if args.distributed:
114
+ torch.multiprocessing.set_start_method("fork", force=True)
115
+ np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
116
+ args.gpu = gpu
117
+ if args.distributed:
118
+ args.rank = args.rank * args.ngpus_per_node + gpu
119
+ dist.init_process_group(
120
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
121
+ )
122
+ torch.cuda.set_device(args.gpu)
123
+ torch.backends.cudnn.benchmark = True
124
+ args.test_mode = False
125
+ loader = get_loader(args)
126
+ print(args.rank, " gpu", args.gpu)
127
+ if args.rank == 0:
128
+ print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
129
+ pretrained_dir = args.pretrained_dir
130
+ model_name = args.pretrained_model_name
131
+ pretrained_pth = os.path.join(pretrained_dir, model_name)
132
+
133
+ model = TextSwinUNETR(
134
+ img_size=(args.roi_x, args.roi_y, args.roi_z),
135
+ in_channels=args.in_channels,
136
+ out_channels=args.out_channels,
137
+ feature_size=args.feature_size,
138
+ use_checkpoint=args.use_checkpoint,
139
+ text_dim=768,
140
+ )
141
+
142
+ if args.resume_ckpt:
143
+ model_dict = torch.load(pretrained_pth)["state_dict"]
144
+ for key in list(model_dict.keys()):
145
+ model_dict[key.replace("module.", "")] = model_dict.pop(key)
146
+ model.load_state_dict(model_dict,strict=True)
147
+ print("Using pretrained weights")
148
+
149
+ if args.use_ssl_pretrained:
150
+ try:
151
+ model_dict = torch.load("/media/iipl/disk1/swinunetr/model_swinvit.pt",weights_only=True)
152
+ state_dict = model_dict["state_dict"]
153
+ # fix potential differences in state dict keys from pre-training to
154
+ # fine-tuning
155
+ for key in list(state_dict.keys()):
156
+ state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
157
+ for key in list(state_dict.keys()):
158
+ if "fc" in key:
159
+ state_dict[key.replace("fc","linear")] = state_dict.pop(key)
160
+ if "patch_embed" in key:
161
+ state_dict[key.replace("patch_embed","")] = state_dict.pop(key)
162
+ model.load_state_dict(state_dict, strict=False)
163
+ except ValueError:
164
+ raise ValueError("Self-supervised pre-trained weights not available for" + str(args.model_name))
165
+
166
+ if args.squared_dice:
167
+ dice_loss = DiceLoss(
168
+ to_onehot_y=False, sigmoid=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
169
+ )
170
+ else:
171
+ dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True)
172
+ post_sigmoid = Activations(sigmoid=True)
173
+ post_pred = AsDiscrete(argmax=False, logit_thresh=0.5)
174
+ dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
175
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
176
+ print("Total parameters count", pytorch_total_params)
177
+
178
+ best_acc = 0
179
+ start_epoch = 0
180
+
181
+ if args.checkpoint is not None:
182
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
183
+ from collections import OrderedDict
184
+
185
+ new_state_dict = OrderedDict()
186
+ for k, v in checkpoint["state_dict"].items():
187
+ new_state_dict[k.replace("backbone.", "")] = v
188
+ model.load_state_dict(new_state_dict, strict=False)
189
+ if "epoch" in checkpoint:
190
+ start_epoch = checkpoint["epoch"]
191
+ if "best_acc" in checkpoint:
192
+ best_acc = checkpoint["best_acc"]
193
+ print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))
194
+
195
+ model.cuda(args.gpu)
196
+
197
+ if args.distributed:
198
+ torch.cuda.set_device(args.gpu)
199
+ if args.norm_name == "batch":
200
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
201
+ model.cuda(args.gpu)
202
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters = False,)
203
+ if args.optim_name == "adam":
204
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
205
+ elif args.optim_name == "adamw":
206
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
207
+ elif args.optim_name == "sgd":
208
+ optimizer = torch.optim.SGD(
209
+ model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight
210
+ )
211
+ else:
212
+ raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))
213
+
214
+ if args.lrschedule == "warmup_cosine":
215
+ scheduler = LinearWarmupCosineAnnealingLR(
216
+ optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs
217
+ )
218
+ elif args.lrschedule == "cosine_anneal":
219
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs)
220
+ if args.checkpoint is not None:
221
+ scheduler.step(epoch=start_epoch)
222
+ else:
223
+ scheduler = None
224
+
225
+ semantic_classes = ["Dice_Val_TC", "Dice_Val_WT", "Dice_Val_ET"]
226
+
227
+ accuracy = run_training(
228
+ model=model,
229
+ train_loader=loader[0],
230
+ val_loader=loader[1],
231
+ optimizer=optimizer,
232
+ loss_func=dice_loss,
233
+ acc_func=dice_acc,
234
+ args=args,
235
+ scheduler=scheduler,
236
+ start_epoch=start_epoch,
237
+ post_sigmoid=post_sigmoid,
238
+ post_pred=post_pred,
239
+ semantic_classes=semantic_classes,
240
+ )
241
+ return accuracy
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
merge.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import shutil
3
+
4
+ # === Please set your own paths below! ===
5
+ img_root = "/path/to/MICCAI_BraTS2020_TrainingData"
6
+ txt_root = "/path/to/Download/TextBraTSData"
7
+ out_root = "/path/to/TextBraTS/TextBraTSData"
8
+
9
+ # Loop over all cases in the image folder
10
+ for case in os.listdir(img_root):
11
+ img_case_dir = os.path.join(img_root, case)
12
+ txt_case_dir = os.path.join(txt_root, case)
13
+ out_case_dir = os.path.join(out_root, case)
14
+
15
+ if not os.path.isdir(img_case_dir):
16
+ continue # Skip non-directory files
17
+
18
+ # Create output folder for each case
19
+ os.makedirs(out_case_dir, exist_ok=True)
20
+
21
+ # Copy all imaging files and segmentation labels
22
+ for file in os.listdir(img_case_dir):
23
+ shutil.copy2(os.path.join(img_case_dir, file), os.path.join(out_case_dir, file))
24
+
25
+ # Copy text reports and feature files if available
26
+ if os.path.exists(txt_case_dir):
27
+ for file in os.listdir(txt_case_dir):
28
+ shutil.copy2(os.path.join(txt_case_dir, file), os.path.join(out_case_dir, file))
29
+ else:
30
+ print(f"Warning: {txt_case_dir} does not exist, skipping.")
31
+
32
+ print("Merge done! All cases are in:", out_root)
optimizers/__init__.py ADDED
File without changes
optimizers/lr_scheduler.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2020 - 2021 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import math
13
+ import warnings
14
+ from typing import List
15
+
16
+ from torch.optim import Optimizer
17
+ from torch.optim.lr_scheduler import _LRScheduler
18
+
19
+
20
+ class LinearWarmupCosineAnnealingLR(_LRScheduler):
21
+ def __init__(
22
+ self,
23
+ optimizer: Optimizer,
24
+ warmup_epochs: int,
25
+ max_epochs: int,
26
+ warmup_start_lr: float = 0.0,
27
+ eta_min: float = 0.0,
28
+ last_epoch: int = -1,
29
+ ) -> None:
30
+ """
31
+ Args:
32
+ optimizer (Optimizer): Wrapped optimizer.
33
+ warmup_epochs (int): Maximum number of iterations for linear warmup
34
+ max_epochs (int): Maximum number of iterations
35
+ warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
36
+ eta_min (float): Minimum learning rate. Default: 0.
37
+ last_epoch (int): The index of last epoch. Default: -1.
38
+ """
39
+ self.warmup_epochs = warmup_epochs
40
+ self.max_epochs = max_epochs
41
+ self.warmup_start_lr = warmup_start_lr
42
+ self.eta_min = eta_min
43
+
44
+ super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
45
+
46
+ def get_lr(self) -> List[float]:
47
+ """
48
+ Compute learning rate using chainable form of the scheduler
49
+ """
50
+ if not self._get_lr_called_within_step:
51
+ warnings.warn(
52
+ "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
53
+ )
54
+
55
+ if self.last_epoch == 0:
56
+ return [self.warmup_start_lr] * len(self.base_lrs)
57
+ elif self.last_epoch < self.warmup_epochs:
58
+ return [
59
+ group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
60
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
61
+ ]
62
+ elif self.last_epoch == self.warmup_epochs:
63
+ return self.base_lrs
64
+ elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
65
+ return [
66
+ group["lr"]
67
+ + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
68
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
69
+ ]
70
+
71
+ return [
72
+ (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
73
+ / (
74
+ 1
75
+ + math.cos(
76
+ math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
77
+ )
78
+ )
79
+ * (group["lr"] - self.eta_min)
80
+ + self.eta_min
81
+ for group in self.optimizer.param_groups
82
+ ]
83
+
84
+ def _get_closed_form_lr(self) -> List[float]:
85
+ """
86
+ Called when epoch is passed as a param to the `step` function of the scheduler.
87
+ """
88
+ if self.last_epoch < self.warmup_epochs:
89
+ return [
90
+ self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
91
+ for base_lr in self.base_lrs
92
+ ]
93
+
94
+ return [
95
+ self.eta_min
96
+ + 0.5
97
+ * (base_lr - self.eta_min)
98
+ * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
99
+ for base_lr in self.base_lrs
100
+ ]
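`LinearWarmupCosineAnnealingLR` above is the scheduler main.py selects for `--lrschedule=warmup_cosine`. As a rough illustration only (not part of the repository; the dummy parameter and the printed epochs are arbitrary), the resulting learning-rate curve can be previewed like this, assuming the `optimizers/` package layout added in this commit:
<pre>
import torch

from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

# A throwaway parameter/optimizer just to drive the schedule; main.py uses AdamW with lr=1e-4.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=50, max_epochs=200)

for epoch in range(200):
    # ... one training epoch would run here ...
    if epoch in (0, 25, 49, 50, 125, 199):
        print(epoch, scheduler.get_last_lr()[0])  # linear ramp for 50 epochs, then cosine decay
    scheduler.step()
</pre>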
test.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import argparse
13
+ from utils.data_utils import get_loader
14
+ from utils.textswin_unetr import TextSwinUNETR
15
+ import os
16
+ import time
17
+ import torch
18
+ import torch.nn.parallel
19
+ import torch.utils.data.distributed
20
+ from utils.utils import AverageMeter
21
+ from monai.utils.enums import MetricReduction
22
+ from monai.metrics import DiceMetric, HausdorffDistanceMetric
23
+
24
+
25
+ parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline")
26
+ parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
27
+ parser.add_argument("--exp_name", default="TextBraTS", type=str, help="experiment name")
28
+ parser.add_argument("--json_list", default="Test.json", type=str, help="dataset json file")
29
+ parser.add_argument("--fold", default=0, type=int, help="data fold")
30
+ parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
31
+ parser.add_argument("--feature_size", default=48, type=int, help="feature size")
32
+ parser.add_argument("--infer_overlap", default=0.6, type=float, help="sliding window inference overlap")
33
+ parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
34
+ parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
35
+ parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
36
+ parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
37
+ parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
38
+ parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
39
+ parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
40
+ parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
41
+ parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
42
+ parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
43
+ parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
44
+ parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
45
+ parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
46
+ parser.add_argument("--distributed", action="store_true", help="start distributed training")
47
+ parser.add_argument("--workers", default=8, type=int, help="number of workers")
48
+ parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
49
+ parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
50
+ parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
51
+ parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
52
+ parser.add_argument(
53
+ "--pretrained_dir",
54
+ default="./runs/TextBraTS/",
55
+ type=str,
56
+ help="pretrained checkpoint directory",
57
+ )
58
+
59
+
60
+ def main():
61
+ args = parser.parse_args()
62
+ args.test_mode = True
63
+ output_directory = "./outputs/" + args.exp_name
64
+ if not os.path.exists(output_directory):
65
+ os.makedirs(output_directory)
66
+ test_loader = get_loader(args)
67
+ pretrained_dir = args.pretrained_dir
68
+ model_name = args.pretrained_model_name
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ pretrained_pth = os.path.join(pretrained_dir, model_name)
71
+ model = TextSwinUNETR(
72
+ img_size=128,
73
+ in_channels=args.in_channels,
74
+ out_channels=args.out_channels,
75
+ feature_size=args.feature_size,
76
+ drop_rate=0.0,
77
+ attn_drop_rate=0.0,
78
+ dropout_path_rate=0.0,
79
+ use_checkpoint=args.use_checkpoint,
80
+ text_dim=768,
81
+ )
82
+ model_dict = torch.load(pretrained_pth)["state_dict"]
83
+ model.load_state_dict(model_dict, strict=False)
84
+ model.eval()
85
+ model.to(device)
86
+
87
+ def val_epoch(model, loader, acc_func, hd95_func):
88
+ model.eval()
89
+ start_time = time.time()
90
+ run_acc = AverageMeter()
91
+ run_hd95 = AverageMeter()
92
+
93
+ with torch.no_grad():
94
+ for idx, batch_data in enumerate(loader):
95
+ data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
96
+ data, target, text = data.cuda(), target.cuda(), text.cuda()
97
+ logits = model(data,text)
98
+ prob = torch.sigmoid(logits)
99
+ prob = (prob > 0.5).int()
100
+
101
+ acc_func(y_pred=prob, y=target)
102
+ acc, not_nans = acc_func.aggregate()
103
+ acc = acc.cuda()
104
+
105
+ run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
106
+
107
+ # HD95 Metric
108
+ hd95_func(y_pred=prob, y=target)
109
+ hd95 = hd95_func.aggregate()  # per-channel HD95 (TC, WT, ET) under MEAN_BATCH reduction
110
+ run_hd95.update(hd95.cpu().numpy())
111
+
112
+
113
+ Dice_TC = run_acc.avg[0]
114
+ Dice_WT = run_acc.avg[1]
115
+ Dice_ET = run_acc.avg[2]
116
+ HD95_TC = run_hd95.avg[0]
117
+ HD95_WT = run_hd95.avg[1]
118
+ HD95_ET = run_hd95.avg[2]
119
+ print(
120
+ "Val {}/{}".format(idx, len(loader)),
121
+ ", Dice_TC:", Dice_TC,
122
+ ", Dice_WT:", Dice_WT,
123
+ ", Dice_ET:", Dice_ET,
124
+ ", Avg Dice:", (Dice_ET + Dice_TC + Dice_WT) / 3,
125
+ ", HD95_TC:", HD95_TC,
126
+ ", HD95_WT:", HD95_WT,
127
+ ", HD95_ET:", HD95_ET,
128
+ ", Avg HD95:", (HD95_ET + HD95_TC + HD95_WT) / 3,
129
+ ", time {:.2f}s".format(time.time() - start_time),
130
+ )
131
+ start_time = time.time()
132
+ with open(output_directory+'/log.txt', "a") as log_file:
133
+ log_file.write(f"Experiment name:{args.pretrained_dir.split('/')[-2]}, "
134
+ f"Final Validation Results - Dice_TC: {Dice_TC}, Dice_WT: {Dice_WT}, Dice_ET: {Dice_ET}, "
135
+ f"Avg Dice: {(Dice_ET + Dice_TC + Dice_WT) / 3}, "
136
+ f"HD95_TC: {HD95_TC}, HD95_WT: {HD95_WT}, HD95_ET: {HD95_ET}, "
137
+ f"Avg HD95: {(HD95_ET + HD95_TC + HD95_WT) / 3}\n")
138
+ return run_acc.avg
139
+
140
+ dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
141
+ hd95_acc = HausdorffDistanceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, percentile=95.0)
142
+ val_epoch(model, test_loader, acc_func=dice_acc,hd95_func=hd95_acc)
143
+
144
+ if __name__ == "__main__":
145
+ main()
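For readers unfamiliar with MONAI's cumulative metrics, here is a minimal sketch (not repository code; the random tensors are placeholders) of how the `DiceMetric` configured above returns one value per output channel, which test.py reads back as Dice_TC, Dice_WT and Dice_ET:
<pre>
import torch
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction

dice = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)

pred = torch.randint(0, 2, (1, 3, 8, 8, 8)).float()   # binarized prediction, one channel per region
label = torch.randint(0, 2, (1, 3, 8, 8, 8)).float()  # multi-channel ground truth
dice(y_pred=pred, y=label)                             # accumulates one batch

per_channel, not_nans = dice.aggregate()               # shape (3,): scores for TC, WT, ET
print(per_channel.numpy(), not_nans.numpy())
</pre>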
trainer.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import os
13
+ import shutil
14
+ import time
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.parallel
19
+ import torch.utils.data.distributed
20
+ from tensorboardX import SummaryWriter
21
+ from torch.amp import GradScaler, autocast
22
+ from utils.utils import AverageMeter, distributed_all_gather
23
+
24
+ from monai.data import decollate_batch
25
+
26
+
27
+ def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):
28
+ model.train()
29
+ start_time = time.time()
30
+ run_loss = AverageMeter()
31
+ for idx, batch_data in enumerate(loader):
32
+ if isinstance(batch_data, list):
33
+ data, target, text = batch_data
34
+ else:
35
+ data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
36
+ data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
37
+ optimizer.zero_grad(set_to_none=True)
38
+ with autocast('cuda',enabled=args.amp):
39
+ logits = model(data,text)
40
+ loss = loss_func(logits, target)
41
+ if args.amp:
42
+ scaler.scale(loss).backward()
43
+ scaler.step(optimizer)
44
+ scaler.update()
45
+ else:
46
+ loss.backward()
47
+ optimizer.step()
48
+ if args.distributed:
49
+ loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
50
+ run_loss.update(
51
+ np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
52
+ )
53
+ else:
54
+ run_loss.update(loss.item(), n=args.batch_size)
55
+ if args.rank == 0:
56
+ print(
57
+ "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
58
+ "loss: {:.4f}".format(run_loss.avg),
59
+ "time {:.2f}s".format(time.time() - start_time),
60
+ )
61
+ start_time = time.time()
62
+ '''for param in model.parameters():
63
+ param.grad = None'''
64
+ optimizer.zero_grad(set_to_none=True)
65
+ return run_loss.avg
66
+
67
+
68
+ def val_epoch(model, loader, epoch, acc_func, args, post_sigmoid=None, post_pred=None):
69
+ model.eval()
70
+ start_time = time.time()
71
+ run_acc = AverageMeter()
72
+
73
+ with torch.no_grad():
74
+ for idx, batch_data in enumerate(loader):
75
+ data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
76
+ data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
77
+ with autocast('cuda',enabled=args.amp):
78
+ logits = model(data,text)
79
+ val_labels_list = decollate_batch(target)
80
+ val_outputs_list = decollate_batch(logits)
81
+ val_output_convert = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs_list]
82
+ acc_func.reset()
83
+ acc_func(y_pred=val_output_convert, y=val_labels_list)
84
+ acc, not_nans = acc_func.aggregate()
85
+ acc = acc.cuda(args.rank)
86
+ if args.distributed:
87
+ acc_list, not_nans_list = distributed_all_gather(
88
+ [acc, not_nans], out_numpy=True, is_valid=idx < loader.sampler.valid_length
89
+ )
90
+ for al, nl in zip(acc_list, not_nans_list):
91
+ run_acc.update(al, n=nl)
92
+ else:
93
+ run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
94
+
95
+ if args.rank == 0:
96
+ Dice_TC = run_acc.avg[0]
97
+ Dice_WT = run_acc.avg[1]
98
+ Dice_ET = run_acc.avg[2]
99
+ print(
100
+ "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
101
+ ", Dice_TC:",
102
+ Dice_TC,
103
+ ", Dice_WT:",
104
+ Dice_WT,
105
+ ", Dice_ET:",
106
+ Dice_ET,
107
+ ", time {:.2f}s".format(time.time() - start_time),
108
+ )
109
+ start_time = time.time()
110
+
111
+ return run_acc.avg
112
+
113
+
114
+ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
115
+ state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
116
+ save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
117
+ if optimizer is not None:
118
+ save_dict["optimizer"] = optimizer.state_dict()
119
+ if scheduler is not None:
120
+ save_dict["scheduler"] = scheduler.state_dict()
121
+ filename = os.path.join(args.logdir, filename)
122
+ torch.save(save_dict, filename)
123
+ print("Saving checkpoint", filename)
124
+
125
+
126
+ def run_training(
127
+ model,
128
+ train_loader,
129
+ val_loader,
130
+ optimizer,
131
+ loss_func,
132
+ acc_func,
133
+ args,
134
+ scheduler=None,
135
+ start_epoch=0,
136
+ post_sigmoid=None,
137
+ post_pred=None,
138
+ semantic_classes=None,
139
+ ):
140
+ writer = None
141
+ if args.logdir is not None and args.rank == 0:
142
+ writer = SummaryWriter(log_dir=args.logdir)
143
+ if args.rank == 0:
144
+ print("Writing Tensorboard logs to ", args.logdir)
145
+ scaler = None
146
+ if args.amp:
147
+ scaler = GradScaler()
148
+ val_acc_max = 0.0
149
+ for epoch in range(start_epoch, args.max_epochs):
150
+ if args.distributed:
151
+ train_loader.sampler.set_epoch(epoch)
152
+ torch.distributed.barrier()
153
+ print(args.rank, time.ctime(), "Epoch:", epoch)
154
+ epoch_time = time.time()
155
+ train_loss = train_epoch(
156
+ model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args
157
+ )
158
+ if args.rank == 0:
159
+ print(
160
+ "Final training {}/{}".format(epoch, args.max_epochs - 1),
161
+ "loss: {:.4f}".format(train_loss),
162
+ "time {:.2f}s".format(time.time() - epoch_time),
163
+ )
164
+ if args.rank == 0 and writer is not None:
165
+ writer.add_scalar("train_loss", train_loss, epoch)
166
+ b_new_best = False
167
+ if (epoch + 1) % args.val_every == 0:
168
+ if args.distributed:
169
+ torch.distributed.barrier()
170
+ epoch_time = time.time()
171
+ val_acc = val_epoch(
172
+ model,
173
+ val_loader,
174
+ epoch=epoch,
175
+ acc_func=acc_func,
176
+ args=args,
177
+ post_sigmoid=post_sigmoid,
178
+ post_pred=post_pred,
179
+ )
180
+
181
+ if args.rank == 0:
182
+ Dice_TC = val_acc[0]
183
+ Dice_WT = val_acc[1]
184
+ Dice_ET = val_acc[2]
185
+ print(
186
+ "Final validation stats {}/{}".format(epoch, args.max_epochs - 1),
187
+ ", Dice_TC:",
188
+ Dice_TC,
189
+ ", Dice_WT:",
190
+ Dice_WT,
191
+ ", Dice_ET:",
192
+ Dice_ET,
193
+ ", time {:.2f}s".format(time.time() - epoch_time),
194
+ )
195
+
196
+ if writer is not None:
197
+ writer.add_scalar("Mean_Val_Dice", np.mean(val_acc), epoch)
198
+ if semantic_classes is not None:
199
+ for val_channel_ind in range(len(semantic_classes)):
200
+ if val_channel_ind < val_acc.size:
201
+ writer.add_scalar(semantic_classes[val_channel_ind], val_acc[val_channel_ind], epoch)
202
+ val_avg_acc = np.mean(val_acc)
203
+ if val_avg_acc > val_acc_max:
204
+ print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
205
+ val_acc_max = val_avg_acc
206
+ b_new_best = True
207
+ if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
208
+ save_checkpoint(
209
+ model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
210
+ )
211
+ if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
212
+ print("Saving")
213
+ save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")
214
+ if b_new_best:
215
+ print("Copying to model.pt new best model!!!!")
216
+ shutil.copyfile(os.path.join(args.logdir, "model_final.pt"), os.path.join(args.logdir, "model.pt"))
217
+
218
+ if scheduler is not None:
219
+ scheduler.step()
220
+
221
+ print("Training Finished !, Best Accuracy: ", val_acc_max)
222
+
223
+ return val_acc_max
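A minimal sketch (not repository code; it assumes a checkpoint has already been written under the default `./runs/TextBraTS/` log directory) of reading back what `save_checkpoint` stores:
<pre>
import torch

# save_checkpoint writes "epoch", "best_acc" and "state_dict" (plus optimizer/scheduler
# states when they are passed in); map_location="cpu" lets you inspect it without a GPU.
ckpt = torch.load("./runs/TextBraTS/model.pt", map_location="cpu")
print("epoch:", ckpt["epoch"], "best mean Dice:", ckpt["best_acc"])
print("tensors in state_dict:", len(ckpt["state_dict"]))
</pre>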
utils/__init__.py ADDED
File without changes
utils/data_utils.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import json
13
+ import math
14
+ import os
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ from monai import data, transforms
20
+ from monai.data import NibabelReader
21
+ from monai.transforms import MapTransform
22
+
23
+ # Load pre-extracted BioBERT text features saved as .npy files
24
+ class LoadNumpyd(MapTransform):
25
+ def __init__(self, keys):
26
+ super().__init__(keys)
27
+
28
+ def __call__(self, data):
29
+ d = dict(data)
30
+ for key in self.keys:
31
+ d[key] = np.load(d[key])
32
+ d[key] = np.squeeze(d[key],axis=0)
33
+ return d
34
+
35
+ class Sampler(torch.utils.data.Sampler):
36
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
37
+ if num_replicas is None:
38
+ if not torch.distributed.is_available():
39
+ raise RuntimeError("Requires distributed package to be available")
40
+ num_replicas = torch.distributed.get_world_size()
41
+ if rank is None:
42
+ if not torch.distributed.is_available():
43
+ raise RuntimeError("Requires distributed package to be available")
44
+ rank = torch.distributed.get_rank()
45
+ self.shuffle = shuffle
46
+ self.make_even = make_even
47
+ self.dataset = dataset
48
+ self.num_replicas = num_replicas
49
+ self.rank = rank
50
+ self.epoch = 0
51
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
52
+ self.total_size = self.num_samples * self.num_replicas
53
+ indices = list(range(len(self.dataset)))
54
+ self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])
55
+
56
+ def __iter__(self):
57
+ if self.shuffle:
58
+ g = torch.Generator()
59
+ g.manual_seed(self.epoch)
60
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
61
+ else:
62
+ indices = list(range(len(self.dataset)))
63
+ if self.make_even:
64
+ if len(indices) < self.total_size:
65
+ if self.total_size - len(indices) < len(indices):
66
+ indices += indices[: (self.total_size - len(indices))]
67
+ else:
68
+ extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices))
69
+ indices += [indices[ids] for ids in extra_ids]
70
+ assert len(indices) == self.total_size
71
+ indices = indices[self.rank : self.total_size : self.num_replicas]
72
+ self.num_samples = len(indices)
73
+ return iter(indices)
74
+
75
+ def __len__(self):
76
+ return self.num_samples
77
+
78
+ def set_epoch(self, epoch):
79
+ self.epoch = epoch
80
+
81
+
82
+ def datafold_read(datalist, basedir, fold=0, key="training"):
83
+ with open(datalist) as f:
84
+ json_data = json.load(f)
85
+
86
+ json_data = json_data[key]
87
+
88
+ for d in json_data:
89
+ for k, v in d.items():
90
+ if isinstance(d[k], list):
91
+ d[k] = [os.path.join(basedir, iv) for iv in d[k]]
92
+ elif isinstance(d[k], str):
93
+ d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
94
+ tr = []
95
+ val = []
96
+ for d in json_data:
97
+ if "fold" in d and d["fold"] == fold:
98
+ val.append(d)
99
+ else:
100
+ tr.append(d)
101
+ return tr, val
102
+
103
+
104
+ def get_loader(args):
105
+ data_dir = args.data_dir
106
+ datalist_json = args.json_list
107
+ train_files, validation_files = datafold_read(datalist=datalist_json, basedir=data_dir, fold=args.fold)
108
+ train_transform = transforms.Compose(
109
+ [
110
+ transforms.LoadImaged(keys=["image", "label"],reader=NibabelReader()),
111
+ LoadNumpyd(keys=["text_feature"]),
112
+ transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
113
+ transforms.Resized(keys=["image","label"],spatial_size=[args.roi_x,args.roi_y,args.roi_z]),
114
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
115
+ transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
116
+ transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
117
+ transforms.ToTensord(keys=["image", "label", "text_feature"]),
118
+ ]
119
+ )
120
+ val_transform = transforms.Compose(
121
+ [
122
+ transforms.LoadImaged(keys=["image", "label"],reader=NibabelReader()),
123
+ LoadNumpyd(keys=["text_feature"]),
124
+ transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
125
+ transforms.Resized(keys=["image", "label"], spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
126
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
127
+ transforms.ToTensord(keys=["image", "label", "text_feature"]),
128
+ ]
129
+ )
130
+
131
+ test_transform = transforms.Compose(
132
+ [
133
+ transforms.LoadImaged(keys=["image", "label"],reader=NibabelReader()),
134
+ LoadNumpyd(keys=["text_feature"]),
135
+ transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
136
+ transforms.Resized(keys=["image", "label"], spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
137
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
138
+ transforms.ToTensord(keys=["image", "label", "text_feature"]),
139
+ ]
140
+ )
141
+
142
+ if args.test_mode:
143
+ val_ds = data.Dataset(data=validation_files, transform=test_transform)
144
+ val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
145
+ test_loader = data.DataLoader(
146
+ val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
147
+ )
148
+
149
+ loader = test_loader
150
+ else:
151
+ train_ds = data.Dataset(data=train_files, transform=train_transform)
152
+
153
+ train_sampler = Sampler(train_ds) if args.distributed else None
154
+ train_loader = data.DataLoader(
155
+ train_ds,
156
+ batch_size=args.batch_size,
157
+ shuffle=(train_sampler is None),
158
+ num_workers=args.workers,
159
+ sampler=train_sampler,
160
+ pin_memory=True,
161
+ )
162
+ val_ds = data.Dataset(data=validation_files, transform=val_transform)
163
+ val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
164
+ val_loader = data.DataLoader(
165
+ val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
166
+ )
167
+ loader = [train_loader, val_loader]
168
+
169
+ return loader
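A minimal sketch (not repository code; the `(1, 768)` shape is an assumption based on the 768-dimensional `text_dim` used by the model) of the per-case `.npy` text feature that `LoadNumpyd` expects and squeezes before `ToTensord`:
<pre>
import numpy as np

from utils.data_utils import LoadNumpyd

# Hypothetical feature file standing in for a real BioBERT embedding of a report.
np.save("demo_feature.npy", np.zeros((1, 768), dtype=np.float32))

sample = {"text_feature": "demo_feature.npy"}
out = LoadNumpyd(keys=["text_feature"])(sample)
print(out["text_feature"].shape)  # (768,) after the squeeze on axis 0
</pre>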
utils/textswin_unetr.py ADDED
@@ -0,0 +1,1081 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import Sequence, Tuple, Type, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from torch.nn import LayerNorm
20
+
21
+ from monai.networks.blocks import MLPBlock as Mlp
22
+ from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
23
+ from monai.networks.layers import DropPath, trunc_normal_
24
+ from monai.utils import ensure_tuple_rep, optional_import
25
+ import math
26
+
27
+ rearrange, _ = optional_import("einops", name="rearrange")
28
+
29
+
30
+ class TextSwinUNETR(nn.Module):
31
+ """
32
+ Swin UNETR based on: "Hatamizadeh et al.,
33
+ Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
34
+ <https://arxiv.org/abs/2201.01266>"
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[Sequence[int], int],
40
+ in_channels: int,
41
+ out_channels: int,
42
+ text_dim: int,
43
+ depths: Sequence[int] = (2, 2, 2, 2),
44
+ num_heads: Sequence[int] = (3, 6, 12, 24),
45
+ feature_size: int = 24,
46
+ norm_name: Union[Tuple, str] = "instance",
47
+ drop_rate: float = 0.0,
48
+ attn_drop_rate: float = 0.0,
49
+ dropout_path_rate: float = 0.0,
50
+ normalize: bool = True,
51
+ use_checkpoint: bool = False,
52
+ spatial_dims: int = 3,
53
+ ) -> None:
54
+ """
55
+ Args:
56
+ img_size: dimension of input image.
57
+ in_channels: dimension of input channels.
58
+ out_channels: dimension of output channels.
59
+ feature_size: dimension of network feature size.
60
+ depths: number of layers in each stage.
61
+ num_heads: number of attention heads.
62
+ norm_name: feature normalization type and arguments.
63
+ drop_rate: dropout rate.
64
+ attn_drop_rate: attention dropout rate.
65
+ dropout_path_rate: drop path rate.
66
+ normalize: normalize output intermediate features in each stage.
67
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
68
+ spatial_dims: number of spatial dims.
69
+
70
+ Examples::
71
+
72
+ # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
73
+ #>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
74
+
75
+ # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
76
+ #>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
77
+
78
+ # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
79
+ #>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
80
+
81
+ """
82
+
83
+ super().__init__()
84
+
85
+ img_size = ensure_tuple_rep(img_size, spatial_dims)
86
+ patch_size = ensure_tuple_rep(2, spatial_dims)
87
+ window_size = ensure_tuple_rep(7, spatial_dims)
88
+
89
+ if not (spatial_dims == 2 or spatial_dims == 3):
90
+ raise ValueError("spatial dimension should be 2 or 3.")
91
+
92
+ for m, p in zip(img_size, patch_size):
93
+ for i in range(5):
94
+ if m % np.power(p, i + 1) != 0:
95
+ raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
96
+
97
+ if not (0 <= drop_rate <= 1):
98
+ raise ValueError("dropout rate should be between 0 and 1.")
99
+
100
+ if not (0 <= attn_drop_rate <= 1):
101
+ raise ValueError("attention dropout rate should be between 0 and 1.")
102
+
103
+ if not (0 <= dropout_path_rate <= 1):
104
+ raise ValueError("drop path rate should be between 0 and 1.")
105
+
106
+ if feature_size % 12 != 0:
107
+ raise ValueError("feature_size should be divisible by 12.")
108
+
109
+ self.normalize = normalize
110
+
111
+ self.swinViT = SwinTransformer(
112
+ in_chans=in_channels,
113
+ embed_dim=feature_size,
114
+ window_size=window_size,
115
+ patch_size=patch_size,
116
+ depths=depths,
117
+ num_heads=num_heads,
118
+ mlp_ratio=4.0,
119
+ qkv_bias=True,
120
+ drop_rate=drop_rate,
121
+ attn_drop_rate=attn_drop_rate,
122
+ drop_path_rate=dropout_path_rate,
123
+ norm_layer=nn.LayerNorm,
124
+ use_checkpoint=use_checkpoint,
125
+ spatial_dims=spatial_dims,
126
+ text_dim=text_dim,
127
+ )
128
+
129
+ self.encoder1 = UnetrBasicBlock(
130
+ spatial_dims=spatial_dims,
131
+ in_channels=in_channels,
132
+ out_channels=feature_size,
133
+ kernel_size=3,
134
+ stride=1,
135
+ norm_name=norm_name,
136
+ res_block=True,
137
+ )
138
+
139
+ self.encoder2 = UnetrBasicBlock(
140
+ spatial_dims=spatial_dims,
141
+ in_channels=feature_size,
142
+ out_channels=feature_size,
143
+ kernel_size=3,
144
+ stride=1,
145
+ norm_name=norm_name,
146
+ res_block=True,
147
+ )
148
+
149
+ self.encoder3 = UnetrBasicBlock(
150
+ spatial_dims=spatial_dims,
151
+ in_channels=2 * feature_size,
152
+ out_channels=2 * feature_size,
153
+ kernel_size=3,
154
+ stride=1,
155
+ norm_name=norm_name,
156
+ res_block=True,
157
+ )
158
+
159
+ self.encoder4 = UnetrBasicBlock(
160
+ spatial_dims=spatial_dims,
161
+ in_channels=4 * feature_size,
162
+ out_channels=4 * feature_size,
163
+ kernel_size=3,
164
+ stride=1,
165
+ norm_name=norm_name,
166
+ res_block=True,
167
+ )
168
+
169
+ self.encoder10 = UnetrBasicBlock(
170
+ spatial_dims=spatial_dims,
171
+ in_channels=16 * feature_size,
172
+ out_channels=16 * feature_size,
173
+ kernel_size=3,
174
+ stride=1,
175
+ norm_name=norm_name,
176
+ res_block=True,
177
+ )
178
+
179
+ self.decoder5 = UnetrUpBlock(
180
+ spatial_dims=spatial_dims,
181
+ in_channels=16 * feature_size,
182
+ out_channels=8 * feature_size,
183
+ kernel_size=3,
184
+ upsample_kernel_size=2,
185
+ norm_name=norm_name,
186
+ res_block=True,
187
+ )
188
+
189
+ self.decoder4 = UnetrUpBlock(
190
+ spatial_dims=spatial_dims,
191
+ in_channels=feature_size * 8,
192
+ out_channels=feature_size * 4,
193
+ kernel_size=3,
194
+ upsample_kernel_size=2,
195
+ norm_name=norm_name,
196
+ res_block=True,
197
+ )
198
+
199
+ self.decoder3 = UnetrUpBlock(
200
+ spatial_dims=spatial_dims,
201
+ in_channels=feature_size * 4,
202
+ out_channels=feature_size * 2,
203
+ kernel_size=3,
204
+ upsample_kernel_size=2,
205
+ norm_name=norm_name,
206
+ res_block=True,
207
+ )
208
+ self.decoder2 = UnetrUpBlock(
209
+ spatial_dims=spatial_dims,
210
+ in_channels=feature_size * 2,
211
+ out_channels=feature_size,
212
+ kernel_size=3,
213
+ upsample_kernel_size=2,
214
+ norm_name=norm_name,
215
+ res_block=True,
216
+ )
217
+
218
+ self.decoder1 = UnetrUpBlock(
219
+ spatial_dims=spatial_dims,
220
+ in_channels=feature_size,
221
+ out_channels=feature_size,
222
+ kernel_size=3,
223
+ upsample_kernel_size=2,
224
+ norm_name=norm_name,
225
+ res_block=True,
226
+ )
227
+
228
+ self.out = UnetOutBlock(
229
+ spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels
230
+ ) # type: ignore
231
+
232
+ def load_from(self, weights):
233
+
234
+ with torch.no_grad():
235
+ self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
236
+ self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
237
+ for bname, block in self.swinViT.layers1[0].blocks.named_children():
238
+ block.load_from(weights, n_block=bname, layer="layers1")
239
+ self.swinViT.layers1[0].downsample.reduction.weight.copy_(
240
+ weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
241
+ )
242
+ self.swinViT.layers1[0].downsample.norm.weight.copy_(
243
+ weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
244
+ )
245
+ self.swinViT.layers1[0].downsample.norm.bias.copy_(
246
+ weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
247
+ )
248
+ for bname, block in self.swinViT.layers2[0].blocks.named_children():
249
+ block.load_from(weights, n_block=bname, layer="layers2")
250
+ self.swinViT.layers2[0].downsample.reduction.weight.copy_(
251
+ weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
252
+ )
253
+ self.swinViT.layers2[0].downsample.norm.weight.copy_(
254
+ weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
255
+ )
256
+ self.swinViT.layers2[0].downsample.norm.bias.copy_(
257
+ weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
258
+ )
259
+ for bname, block in self.swinViT.layers3[0].blocks.named_children():
260
+ block.load_from(weights, n_block=bname, layer="layers3")
261
+ self.swinViT.layers3[0].downsample.reduction.weight.copy_(
262
+ weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
263
+ )
264
+ self.swinViT.layers3[0].downsample.norm.weight.copy_(
265
+ weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
266
+ )
267
+ self.swinViT.layers3[0].downsample.norm.bias.copy_(
268
+ weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
269
+ )
270
+ for bname, block in self.swinViT.layers4[0].blocks.named_children():
271
+ block.load_from(weights, n_block=bname, layer="layers4")
272
+ self.swinViT.layers4[0].downsample.reduction.weight.copy_(
273
+ weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
274
+ )
275
+ self.swinViT.layers4[0].downsample.norm.weight.copy_(
276
+ weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
277
+ )
278
+ self.swinViT.layers4[0].downsample.norm.bias.copy_(
279
+ weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
280
+ )
281
+
282
+ def forward(self, x_in, text_in):
283
+ hidden_states_out = self.swinViT(x_in, text_in, self.normalize)
284
+ enc0 = self.encoder1(x_in)
285
+ enc1 = self.encoder2(hidden_states_out[0])
286
+ enc2 = self.encoder3(hidden_states_out[1])
287
+ enc3 = self.encoder4(hidden_states_out[2])
288
+ dec4 = self.encoder10(hidden_states_out[4])
289
+ dec3 = self.decoder5(dec4, hidden_states_out[3])
290
+ dec2 = self.decoder4(dec3, enc3)
291
+ dec1 = self.decoder3(dec2, enc2)
292
+ dec0 = self.decoder2(dec1, enc1)
293
+ out = self.decoder1(dec0, enc0)
294
+ logits = self.out(out)
295
+ return logits
296
+
297
+
298
+ def window_partition(x, window_size):
299
+ """window partition operation based on: "Liu et al.,
300
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
301
+ <https://arxiv.org/abs/2103.14030>"
302
+ https://github.com/microsoft/Swin-Transformer
303
+
304
+ Args:
305
+ x: input tensor.
306
+ window_size: local window size.
307
+ """
308
+ x_shape = x.size()
309
+ if len(x_shape) == 5:
310
+ b, d, h, w, c = x_shape
311
+ x = x.view(
312
+ b,
313
+ d // window_size[0],
314
+ window_size[0],
315
+ h // window_size[1],
316
+ window_size[1],
317
+ w // window_size[2],
318
+ window_size[2],
319
+ c,
320
+ )
321
+ windows = (
322
+ x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
323
+ )
324
+ elif len(x_shape) == 4:
325
+ b, h, w, c = x.shape
326
+ x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
327
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
328
+ return windows
329
+
330
+
331
+ def window_reverse(windows, window_size, dims):
332
+ """window reverse operation based on: "Liu et al.,
333
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
334
+ <https://arxiv.org/abs/2103.14030>"
335
+ https://github.com/microsoft/Swin-Transformer
336
+
337
+ Args:
338
+ windows: windows tensor.
339
+ window_size: local window size.
340
+ dims: dimension values.
341
+ """
342
+ if len(dims) == 4:
343
+ b, d, h, w = dims
344
+ x = windows.view(
345
+ b,
346
+ d // window_size[0],
347
+ h // window_size[1],
348
+ w // window_size[2],
349
+ window_size[0],
350
+ window_size[1],
351
+ window_size[2],
352
+ -1,
353
+ )
354
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
355
+
356
+ elif len(dims) == 3:
357
+ b, h, w = dims
358
+ x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
359
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
360
+ return x
361
+
362
+
363
+ def get_window_size(x_size, window_size, shift_size=None):
364
+ """Computing window size based on: "Liu et al.,
365
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
366
+ <https://arxiv.org/abs/2103.14030>"
367
+ https://github.com/microsoft/Swin-Transformer
368
+
369
+ Args:
370
+ x_size: input size.
371
+ window_size: local window size.
372
+ shift_size: window shifting size.
373
+ """
374
+
375
+ use_window_size = list(window_size)
376
+ if shift_size is not None:
377
+ use_shift_size = list(shift_size)
378
+ for i in range(len(x_size)):
379
+ if x_size[i] <= window_size[i]:
380
+ use_window_size[i] = x_size[i]
381
+ if shift_size is not None:
382
+ use_shift_size[i] = 0
383
+
384
+ if shift_size is None:
385
+ return tuple(use_window_size)
386
+ else:
387
+ return tuple(use_window_size), tuple(use_shift_size)
388
+
389
+
390
+ class WindowAttention(nn.Module):
391
+ """
392
+ Window based multi-head self attention module with relative position bias based on: "Liu et al.,
393
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
394
+ <https://arxiv.org/abs/2103.14030>"
395
+ https://github.com/microsoft/Swin-Transformer
396
+ """
397
+
398
+ def __init__(
399
+ self,
400
+ dim: int,
401
+ num_heads: int,
402
+ window_size: Sequence[int],
403
+ qkv_bias: bool = False,
404
+ attn_drop: float = 0.0,
405
+ proj_drop: float = 0.0,
406
+ ) -> None:
407
+ """
408
+ Args:
409
+ dim: number of feature channels.
410
+ num_heads: number of attention heads.
411
+ window_size: local window size.
412
+ qkv_bias: add a learnable bias to query, key, value.
413
+ attn_drop: attention dropout rate.
414
+ proj_drop: dropout rate of output.
415
+ """
416
+
417
+ super().__init__()
418
+ self.dim = dim
419
+ self.window_size = window_size
420
+ self.num_heads = num_heads
421
+ head_dim = dim // num_heads
422
+ self.scale = head_dim**-0.5
423
+ mesh_args = torch.meshgrid.__kwdefaults__
424
+
425
+ if len(self.window_size) == 3:
426
+ self.relative_position_bias_table = nn.Parameter(
427
+ torch.zeros(
428
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
429
+ num_heads,
430
+ )
431
+ )
432
+ coords_d = torch.arange(self.window_size[0])
433
+ coords_h = torch.arange(self.window_size[1])
434
+ coords_w = torch.arange(self.window_size[2])
435
+ if mesh_args is not None:
436
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
437
+ else:
438
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
439
+ coords_flatten = torch.flatten(coords, 1)
440
+
441
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
442
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
443
+ relative_coords[:, :, 0] += self.window_size[0] - 1
444
+ relative_coords[:, :, 1] += self.window_size[1] - 1
445
+ relative_coords[:, :, 2] += self.window_size[2] - 1
446
+ relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
447
+ relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
448
+ elif len(self.window_size) == 2:
449
+ self.relative_position_bias_table = nn.Parameter(
450
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
451
+ )
452
+ coords_h = torch.arange(self.window_size[0])
453
+ coords_w = torch.arange(self.window_size[1])
454
+ if mesh_args is not None:
455
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
456
+ else:
457
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w))
458
+ coords_flatten = torch.flatten(coords, 1)
459
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
460
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
461
+ relative_coords[:, :, 0] += self.window_size[0] - 1
462
+ relative_coords[:, :, 1] += self.window_size[1] - 1
463
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
464
+
465
+ relative_position_index = relative_coords.sum(-1)
466
+ self.register_buffer("relative_position_index", relative_position_index)
467
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
468
+ self.attn_drop = nn.Dropout(attn_drop)
469
+ self.proj = nn.Linear(dim, dim)
470
+ self.proj_drop = nn.Dropout(proj_drop)
471
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
472
+ self.softmax = nn.Softmax(dim=-1)
473
+
474
+ def forward(self, x, mask):
475
+ b, n, c = x.shape
476
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
477
+ q, k, v = qkv[0], qkv[1], qkv[2]
478
+ q = q * self.scale
479
+ attn = q @ k.transpose(-2, -1)
480
+ relative_position_bias = self.relative_position_bias_table[
481
+ self.relative_position_index[:n, :n].reshape(-1)
482
+ ].reshape(n, n, -1)
483
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
484
+ attn = attn + relative_position_bias.unsqueeze(0)
485
+ if mask is not None:
486
+ nw = mask.shape[0]
487
+ attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
488
+ attn = attn.view(-1, self.num_heads, n, n)
489
+ attn = self.softmax(attn)
490
+ else:
491
+ attn = self.softmax(attn)
492
+
493
+ attn = self.attn_drop(attn)
494
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
495
+ x = self.proj(x)
496
+ x = self.proj_drop(x)
497
+ return x
498
+
499
+
500
+ class SwinTransformerBlock(nn.Module):
501
+ """
502
+ Swin Transformer block based on: "Liu et al.,
503
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
504
+ <https://arxiv.org/abs/2103.14030>"
505
+ https://github.com/microsoft/Swin-Transformer
506
+ """
507
+
508
+ def __init__(
509
+ self,
510
+ dim: int,
511
+ num_heads: int,
512
+ window_size: Sequence[int],
513
+ shift_size: Sequence[int],
514
+ mlp_ratio: float = 4.0,
515
+ qkv_bias: bool = True,
516
+ drop: float = 0.0,
517
+ attn_drop: float = 0.0,
518
+ drop_path: float = 0.0,
519
+ act_layer: str = "GELU",
520
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
521
+ use_checkpoint: bool = False,
522
+ ) -> None:
523
+ """
524
+ Args:
525
+ dim: number of feature channels.
526
+ num_heads: number of attention heads.
527
+ window_size: local window size.
528
+ shift_size: window shift size.
529
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
530
+ qkv_bias: add a learnable bias to query, key, value.
531
+ drop: dropout rate.
532
+ attn_drop: attention dropout rate.
533
+ drop_path: stochastic depth rate.
534
+ act_layer: activation layer.
535
+ norm_layer: normalization layer.
536
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
537
+ """
538
+
539
+ super().__init__()
540
+ self.dim = dim
541
+ self.num_heads = num_heads
542
+ self.window_size = window_size
543
+ self.shift_size = shift_size
544
+ self.mlp_ratio = mlp_ratio
545
+ self.use_checkpoint = use_checkpoint
546
+ self.norm1 = norm_layer(dim)
547
+ self.attn = WindowAttention(
548
+ dim,
549
+ window_size=self.window_size,
550
+ num_heads=num_heads,
551
+ qkv_bias=qkv_bias,
552
+ attn_drop=attn_drop,
553
+ proj_drop=drop,
554
+ )
555
+
556
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
557
+ self.norm2 = norm_layer(dim)
558
+ mlp_hidden_dim = int(dim * mlp_ratio)
559
+ self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
560
+
561
+ def forward_part1(self, x, mask_matrix):
562
+ x_shape = x.size()
563
+ x = self.norm1(x)
564
+ if len(x_shape) == 5:
565
+ b, d, h, w, c = x.shape
566
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
567
+ pad_l = pad_t = pad_d0 = 0
568
+ pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
569
+ pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
570
+ pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
571
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
572
+ _, dp, hp, wp, _ = x.shape
573
+ dims = [b, dp, hp, wp]
574
+
575
+ elif len(x_shape) == 4:
576
+ b, h, w, c = x.shape
577
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
578
+ pad_l = pad_t = 0
579
+ pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
580
+ pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
581
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
582
+ _, hp, wp, _ = x.shape
583
+ dims = [b, hp, wp]
584
+
585
+ if any(i > 0 for i in shift_size):
586
+ if len(x_shape) == 5:
587
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
588
+ elif len(x_shape) == 4:
589
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
590
+ attn_mask = mask_matrix
591
+ else:
592
+ shifted_x = x
593
+ attn_mask = None
594
+ x_windows = window_partition(shifted_x, window_size)
595
+ attn_windows = self.attn(x_windows, mask=attn_mask)
596
+ attn_windows = attn_windows.view(-1, *(window_size + (c,)))
597
+ shifted_x = window_reverse(attn_windows, window_size, dims)
598
+ if any(i > 0 for i in shift_size):
599
+ if len(x_shape) == 5:
600
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
601
+ elif len(x_shape) == 4:
602
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
603
+ else:
604
+ x = shifted_x
605
+
606
+ if len(x_shape) == 5:
607
+ if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
608
+ x = x[:, :d, :h, :w, :].contiguous()
609
+ elif len(x_shape) == 4:
610
+ if pad_r > 0 or pad_b > 0:
611
+ x = x[:, :h, :w, :].contiguous()
612
+
613
+ return x
614
+
615
+ def forward_part2(self, x):
616
+ return self.drop_path(self.mlp(self.norm2(x)))
617
+
618
+ def load_from(self, weights, n_block, layer):
619
+ root = f"module.{layer}.0.blocks.{n_block}."
620
+ block_names = [
621
+ "norm1.weight",
622
+ "norm1.bias",
623
+ "attn.relative_position_bias_table",
624
+ "attn.relative_position_index",
625
+ "attn.qkv.weight",
626
+ "attn.qkv.bias",
627
+ "attn.proj.weight",
628
+ "attn.proj.bias",
629
+ "norm2.weight",
630
+ "norm2.bias",
631
+ "mlp.fc1.weight",
632
+ "mlp.fc1.bias",
633
+ "mlp.fc2.weight",
634
+ "mlp.fc2.bias",
635
+ ]
636
+ with torch.no_grad():
637
+ self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
638
+ self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
639
+ self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
640
+ self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
641
+ self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
642
+ self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
643
+ self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
644
+ self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
645
+ self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
646
+ self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
647
+ self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
648
+ self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
649
+ self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
650
+ self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
651
+
652
+ def forward(self, x, mask_matrix):
653
+ shortcut = x
654
+ if self.use_checkpoint:
655
+ x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
656
+ else:
657
+ x = self.forward_part1(x, mask_matrix)
658
+ x = shortcut + self.drop_path(x)
659
+ if self.use_checkpoint:
660
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
661
+ else:
662
+ x = x + self.forward_part2(x)
663
+ return x
664
+
665
+
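The cyclic shift in `forward_part1` is the part that is hardest to see from the padding/roll bookkeeping alone. A minimal, self-contained 2D sketch of the idea (toy tensors only, not the model's own helpers):
<pre>
import torch

# 4x4 grid of token ids, 2x2 windows, shift of 1 (half the window size)
x = torch.arange(16).reshape(1, 4, 4)                 # (b, h, w)
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))

# partition the shifted grid into non-overlapping 2x2 windows
windows = shifted.unfold(1, 2, 2).unfold(2, 2, 2).reshape(1, -1, 4)
print(windows[0])   # each row is one window; tokens from different original
                    # windows now share a window, which is what mixes information
                    # across window borders (compute_mask hides invalid pairs)

# the reverse roll applied after attention restores the original layout
assert torch.equal(torch.roll(shifted, shifts=(1, 1), dims=(1, 2)), x)
</pre>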
666
+ class PatchMerging(nn.Module):
667
+ """
668
+ Patch merging layer based on: "Liu et al.,
669
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
670
+ <https://arxiv.org/abs/2103.14030>"
671
+ https://github.com/microsoft/Swin-Transformer
672
+ """
673
+
674
+ def __init__(
675
+ self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
676
+ ) -> None: # type: ignore
677
+ """
678
+ Args:
679
+ dim: number of feature channels.
680
+ norm_layer: normalization layer.
681
+ spatial_dims: number of spatial dims.
682
+ """
683
+
684
+ super().__init__()
685
+ self.dim = dim
686
+ if spatial_dims == 3:
687
+ self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
688
+ self.norm = norm_layer(8 * dim)
689
+ elif spatial_dims == 2:
690
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
691
+ self.norm = norm_layer(4 * dim)
692
+
693
+ def forward(self, x):
694
+
695
+ x_shape = x.size()
696
+ if len(x_shape) == 5:
697
+ b, d, h, w, c = x_shape
698
+ pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
699
+ if pad_input:
700
+ x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
701
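+ # NOTE: this octant selection follows MONAI's original PatchMerging: x2/x5 and
+ # x3/x6 duplicate two octants while the (1,1,0) and (0,1,1) octants are skipped.
+ # Changing it would break checkpoints trained with this channel ordering.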
+ x0 = x[:, 0::2, 0::2, 0::2, :]
702
+ x1 = x[:, 1::2, 0::2, 0::2, :]
703
+ x2 = x[:, 0::2, 1::2, 0::2, :]
704
+ x3 = x[:, 0::2, 0::2, 1::2, :]
705
+ x4 = x[:, 1::2, 0::2, 1::2, :]
706
+ x5 = x[:, 0::2, 1::2, 0::2, :]
707
+ x6 = x[:, 0::2, 0::2, 1::2, :]
708
+ x7 = x[:, 1::2, 1::2, 1::2, :]
709
+ x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
710
+
711
+ elif len(x_shape) == 4:
712
+ b, h, w, c = x_shape
713
+ pad_input = (h % 2 == 1) or (w % 2 == 1)
714
+ if pad_input:
715
+ x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
716
+ x0 = x[:, 0::2, 0::2, :]
717
+ x1 = x[:, 1::2, 0::2, :]
718
+ x2 = x[:, 0::2, 1::2, :]
719
+ x3 = x[:, 1::2, 1::2, :]
720
+ x = torch.cat([x0, x1, x2, x3], -1)
721
+
722
+ x = self.norm(x)
723
+ x = self.reduction(x)
724
+ return x
725
+
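As a quick shape check (a sketch of the 3D bookkeeping only, using all eight distinct octants rather than the ordering noted above), patch merging halves every spatial dimension and doubles the channel count through the `8*dim -> 2*dim` linear reduction:
<pre>
import itertools
import torch
import torch.nn as nn

dim = 48
x = torch.randn(1, 8, 8, 8, dim)                       # (b, d, h, w, c), as forward() expects

# one strided view per octant, stacked on the channel axis
octants = [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product((0, 1), repeat=3)]
merged = torch.cat(octants, dim=-1)                    # (1, 4, 4, 4, 8 * dim)

reduction = nn.Linear(8 * dim, 2 * dim, bias=False)    # same projection as PatchMerging
print(reduction(merged).shape)                         # torch.Size([1, 4, 4, 4, 96])
</pre>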
726
+
727
+ def compute_mask(dims, window_size, shift_size, device):
728
+ """Computing region masks based on: "Liu et al.,
729
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
730
+ <https://arxiv.org/abs/2103.14030>"
731
+ https://github.com/microsoft/Swin-Transformer
732
+
733
+ Args:
734
+ dims: dimension values.
735
+ window_size: local window size.
736
+ shift_size: shift size.
737
+ device: device.
738
+ """
739
+
740
+ cnt = 0
741
+
742
+ if len(dims) == 3:
743
+ d, h, w = dims
744
+ img_mask = torch.zeros((1, d, h, w, 1), device=device)
745
+ for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
746
+ for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
747
+ for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
748
+ img_mask[:, d, h, w, :] = cnt
749
+ cnt += 1
750
+
751
+ elif len(dims) == 2:
752
+ h, w = dims
753
+ img_mask = torch.zeros((1, h, w, 1), device=device)
754
+ for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
755
+ for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
756
+ img_mask[:, h, w, :] = cnt
757
+ cnt += 1
758
+
759
+ mask_windows = window_partition(img_mask, window_size)
760
+ mask_windows = mask_windows.squeeze(-1)
761
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
762
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
763
+
764
+ return attn_mask
765
+
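`compute_mask` labels the wrap-around regions created by the cyclic shift and converts label mismatches into a large negative attention bias. A rough shape walk-through for one 3D stage (illustrative numbers; the actual call happens inside `BasicLayer.forward` on the padded grid):
<pre>
import numpy as np

window_size, shift_size = (7, 7, 7), (3, 3, 3)
d = h = w = 32                                    # token grid of one stage

# BasicLayer pads the grid up to a multiple of the window size before masking
dp, hp, wp = (int(np.ceil(s / ws)) * ws for s, ws in zip((d, h, w), window_size))
num_windows = (dp // 7) * (hp // 7) * (wp // 7)   # 5 * 5 * 5 = 125
tokens_per_window = 7 * 7 * 7                     # 343

# compute_mask([dp, hp, wp], window_size, shift_size, device) then returns an additive
# mask of shape (num_windows, tokens_per_window, tokens_per_window), with 0 for pairs
# from the same region and -100 for pairs from different regions.
print(num_windows, tokens_per_window)             # 125 343
</pre>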
766
+
767
+ class BasicLayer(nn.Module):
768
+ """
769
+ Basic Swin Transformer layer in one stage based on: "Liu et al.,
770
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
771
+ <https://arxiv.org/abs/2103.14030>"
772
+ https://github.com/microsoft/Swin-Transformer
773
+ """
774
+
775
+ def __init__(
776
+ self,
777
+ dim: int,
778
+ depth: int,
779
+ num_heads: int,
780
+ window_size: Sequence[int],
781
+ drop_path: list,
782
+ mlp_ratio: float = 4.0,
783
+ qkv_bias: bool = False,
784
+ drop: float = 0.0,
785
+ attn_drop: float = 0.0,
786
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
787
+ downsample: isinstance = None, # type: ignore
788
+ use_checkpoint: bool = False,
789
+ ) -> None:
790
+ """
791
+ Args:
792
+ dim: number of feature channels.
793
+ depth: number of Swin Transformer blocks in this stage.
794
+ num_heads: number of attention heads.
795
+ window_size: local window size.
796
+ drop_path: stochastic depth rate.
797
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
798
+ qkv_bias: add a learnable bias to query, key, value.
799
+ drop: dropout rate.
800
+ attn_drop: attention dropout rate.
801
+ norm_layer: normalization layer.
802
+ downsample: downsample layer at the end of the layer.
803
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
804
+ """
805
+
806
+ super().__init__()
807
+ self.window_size = window_size
808
+ self.shift_size = tuple(i // 2 for i in window_size)
809
+ self.no_shift = tuple(0 for i in window_size)
810
+ self.depth = depth
811
+ self.use_checkpoint = use_checkpoint
812
+ self.blocks = nn.ModuleList(
813
+ [
814
+ SwinTransformerBlock(
815
+ dim=dim,
816
+ num_heads=num_heads,
817
+ window_size=self.window_size,
818
+ shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
819
+ mlp_ratio=mlp_ratio,
820
+ qkv_bias=qkv_bias,
821
+ drop=drop,
822
+ attn_drop=attn_drop,
823
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
824
+ norm_layer=norm_layer,
825
+ use_checkpoint=use_checkpoint,
826
+ )
827
+ for i in range(depth)
828
+ ]
829
+ )
830
+ self.downsample = downsample
831
+ if self.downsample is not None:
832
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
833
+
834
+ def forward(self, x):
835
+ x_shape = x.size()
836
+ if len(x_shape) == 5:
837
+ b, c, d, h, w = x_shape
838
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
839
+ x = rearrange(x, "b c d h w -> b d h w c")
840
+ dp = int(np.ceil(d / window_size[0])) * window_size[0]
841
+ hp = int(np.ceil(h / window_size[1])) * window_size[1]
842
+ wp = int(np.ceil(w / window_size[2])) * window_size[2]
843
+ attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
844
+ for blk in self.blocks:
845
+ x = blk(x, attn_mask)
846
+ x = x.view(b, d, h, w, -1)
847
+ if self.downsample is not None:
848
+ x = self.downsample(x)
849
+ x = rearrange(x, "b d h w c -> b c d h w")
850
+
851
+ elif len(x_shape) == 4:
852
+ b, c, h, w = x_shape
853
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
854
+ x = rearrange(x, "b c h w -> b h w c")
855
+ hp = int(np.ceil(h / window_size[0])) * window_size[0]
856
+ wp = int(np.ceil(w / window_size[1])) * window_size[1]
857
+ attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
858
+ for blk in self.blocks:
859
+ x = blk(x, attn_mask)
860
+ x = x.view(b, h, w, -1)
861
+ if self.downsample is not None:
862
+ x = self.downsample(x)
863
+ x = rearrange(x, "b h w c -> b c h w")
864
+ return x
865
+
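Within a stage the blocks simply alternate between regular and shifted windows, as set up in `__init__` above; a one-line check of that schedule (illustrative only):
<pre>
window_size = (7, 7, 7)
depth = 2
shift_size = tuple(i // 2 for i in window_size)
print([(0, 0, 0) if i % 2 == 0 else shift_size for i in range(depth)])   # [(0, 0, 0), (3, 3, 3)]
</pre>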
866
+
867
+ class SwinTransformer(nn.Module):
868
+ """
869
+ Swin Transformer based on: "Liu et al.,
870
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
871
+ <https://arxiv.org/abs/2103.14030>"
872
+ https://github.com/microsoft/Swin-Transformer
873
+ """
874
+
875
+ def __init__(
876
+ self,
877
+ in_chans: int,
878
+ embed_dim: int,
879
+ text_dim: int,
880
+ window_size: Sequence[int],
881
+ patch_size: Sequence[int],
882
+ depths: Sequence[int],
883
+ num_heads: Sequence[int],
884
+ mlp_ratio: float = 4.0,
885
+ qkv_bias: bool = True,
886
+ drop_rate: float = 0.0,
887
+ attn_drop_rate: float = 0.0,
888
+ drop_path_rate: float = 0.0,
889
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
890
+ patch_norm: bool = False,
891
+ use_checkpoint: bool = False,
892
+ spatial_dims: int = 3,
893
+ ) -> None:
894
+ """
895
+ Args:
896
+ in_chans: dimension of input channels.
897
+ embed_dim: number of linear projection output channels.
+ text_dim: dimension of the input text feature embeddings.
898
+ window_size: local window size.
899
+ patch_size: patch size.
900
+ depths: number of layers in each stage.
901
+ num_heads: number of attention heads.
902
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
903
+ qkv_bias: add a learnable bias to query, key, value.
904
+ drop_rate: dropout rate.
905
+ attn_drop_rate: attention dropout rate.
906
+ drop_path_rate: stochastic depth rate.
907
+ norm_layer: normalization layer.
908
+ patch_norm: add normalization after patch embedding.
909
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
910
+ spatial_dims: spatial dimension.
911
+ """
912
+
913
+ super().__init__()
914
+ self.num_layers = len(depths)
915
+ self.embed_dim = embed_dim
916
+ self.patch_norm = patch_norm
917
+ self.window_size = window_size
918
+ self.patch_size = patch_size
919
+ self.patch_embed = PatchEmbed(
920
+ patch_size=self.patch_size,
921
+ in_chans=in_chans,
922
+ embed_dim=embed_dim,
923
+ norm_layer=norm_layer if self.patch_norm else None, # type: ignore
924
+ spatial_dims=spatial_dims,
925
+ )
926
+ self.pos_drop = nn.Dropout(p=drop_rate)
927
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
928
+ self.layers1 = nn.ModuleList()
929
+ self.layers2 = nn.ModuleList()
930
+ self.layers3 = nn.ModuleList()
931
+ self.layers4 = nn.ModuleList()
932
+ for i_layer in range(self.num_layers):
933
+ layer = BasicLayer(
934
+ dim=int(embed_dim * 2**i_layer),
935
+ depth=depths[i_layer],
936
+ num_heads=num_heads[i_layer],
937
+ window_size=self.window_size,
938
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
939
+ mlp_ratio=mlp_ratio,
940
+ qkv_bias=qkv_bias,
941
+ drop=drop_rate,
942
+ attn_drop=attn_drop_rate,
943
+ norm_layer=norm_layer,
944
+ downsample=PatchMerging,
945
+ use_checkpoint=use_checkpoint,
946
+ )
947
+ if i_layer == 0:
948
+ self.layers1.append(layer)
949
+ elif i_layer == 1:
950
+ self.layers2.append(layer)
951
+ elif i_layer == 2:
952
+ self.layers3.append(layer)
953
+ elif i_layer == 3:
954
+ self.layers4.append(layer)
955
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
956
+ self.mlp_text = nn.Sequential(
957
+ nn.Conv1d(text_dim, 1024, kernel_size=1),
958
+ nn.ReLU(),
959
+ nn.Conv1d(1024, 768, kernel_size=1),
960
+ )
961
+
962
+ self.fw_mlp = nn.Sequential(
963
+ nn.Conv3d(768, 768, kernel_size=1),
964
+ nn.ReLU())
965
+
966
+ self.mlpk = nn.Conv1d(768, 768, kernel_size=1)
967
+ self.mlpv = nn.Conv1d(768, 768, kernel_size=1)
968
+
969
+ self.ly_norm = nn.LayerNorm(normalized_shape=(4,4,4))
970
+
971
+ self.mlp_text_q = nn.Conv1d(768, 768, kernel_size=1)
972
+ self.mlp_image_k = nn.Conv1d(768, 768, kernel_size=1)
973
+ self.mlp_image_v = nn.Conv1d(768, 768, kernel_size=1)
974
+ self.mlp_image_q = nn.Conv1d(768, 768, kernel_size=1)
975
+
976
+
977
+ def proj_out(self, x, normalize=False):
978
+ if normalize:
979
+ x_shape = x.size()
980
+ if len(x_shape) == 5:
981
+ n, ch, d, h, w = x_shape
982
+ x = rearrange(x, "n c d h w -> n d h w c")
983
+ x = F.layer_norm(x, [ch])
984
+ x = rearrange(x, "n d h w c -> n c d h w")
985
+ elif len(x_shape) == 4:
986
+ n, ch, h, w = x_shape
987
+ x = rearrange(x, "n c h w -> n h w c")
988
+ x = F.layer_norm(x, [ch])
989
+ x = rearrange(x, "n h w c -> n c h w")
990
+ return x
991
+
992
+ def sequential_cross_attention(self, image_features, text_features):
993
+
994
+ """
995
+ Cross attention between image and text features.
996
+ Args:
997
+ image_features: Tensor of shape (B, C, H, W, D)
998
+ text_features: Tensor of shape (B, T_len, T_dim)
999
+ Returns:
1000
+ Processed image features with the same shape as input (B, C, H, W, D)
1001
+ """
1002
+ B, C, H, W, D = image_features.shape
1003
+ _, T_len, T_dim = text_features.shape
1004
+
1005
+ # Step 1: Text-to-Image Cross Attention (Text as Q, Image as K/V)
1006
+ # Project text features to Query
1007
+ text_features = self.mlp_text(text_features.permute(0,2,1).contiguous())
1008
+ text_q = self.mlp_text_q(text_features).permute(0,2,1).contiguous() # Shape: (B, T_len, d_k)
1009
+
1010
+ # Flatten image features and project to Key and Value
1011
+ image_features_flat = image_features.view(B, C, -1).contiguous() # Shape: (B, C, N_img)
1012
+ image_k = self.mlp_image_k(image_features_flat).permute(0,2,1).contiguous() # Shape: (B, N_img, d_k)
1013
+ image_v = self.mlp_image_v(image_features_flat).permute(0,2,1).contiguous() # Shape: (B, N_img, d_v)
1014
+
1015
+ # Compute attention scores and weights
1016
+ attn_scores_t2i = torch.matmul(text_q, image_k.transpose(-2, -1)) / math.sqrt(
1017
+ text_q.size(-1)) # (B, T_len, N_img)
1018
+ attn_weights_t2i = F.softmax(attn_scores_t2i, dim=-1) # (B, T_len, N_img)
1019
+
1020
+ # Get attended image features
1021
+ attended_image_features = torch.matmul(attn_weights_t2i, image_v).permute(0,2,1).contiguous() # (B, T_len, d_v)
1022
+
1023
+ # Step 2: Image-to-AttendedImage(Text) Cross Attention (Image as Q, AttendedImage as K/V)
1024
+ # Project image features to Query
1025
+ image_q = self.mlp_image_q(image_features_flat).permute(0,2,1).contiguous() # (B, N_img, d_k)
1026
+
1027
+ # Project the text-attended image features to Key and Value
1028
+ attended_image_k = self.mlpk(attended_image_features).permute(0,2,1).contiguous() # (B, T_len, d_k)
1029
+ attended_image_v = self.mlpv(attended_image_features).permute(0,2,1).contiguous() # (B, T_len, d_v)
1030
+
1031
+ # Compute attention scores and weights
1032
+ attn_scores_i2t = torch.matmul(image_q, attended_image_k.transpose(-2, -1)) / math.sqrt(
1033
+ image_q.size(-1)) # (B, N_img, T_len)
1034
+ attn_weights_i2t = F.softmax(attn_scores_i2t, dim=-1) # (B, N_img, T_len)
1035
+
1036
+ # Aggregate the attended features back onto the image tokens
1037
+ attn_output_image = torch.matmul(attn_weights_i2t, attended_image_v) # (B, N_img, d_v)
1038
+
1039
+ # Reshape back to original image feature shape
1040
+ attn_output_image = attn_output_image.permute(0, 2, 1).contiguous() # (B, d_v, N_img)
1041
+ attn_output_image = attn_output_image.view(B, C, H, W, D)
1042
+
1043
+ # Apply layer normalization and final MLP processing
1044
+ processed_image_features = self.ly_norm(attn_output_image)
1045
+ processed_image_features = self.fw_mlp(processed_image_features.float())
1046
+ processed_image_features = self.ly_norm(processed_image_features)
1047
+
1048
+ return processed_image_features
1049
+
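For the bottleneck features used by this model, the two attention steps trace out the following shapes (a sketch with identity projections standing in for the learned `mlp_*` layers, assuming the 128^3 input of the example at the end of the file):
<pre>
import math
import torch
import torch.nn.functional as F

B, C, T_len = 1, 768, 128
image = torch.randn(B, C, 4 * 4 * 4)     # flattened bottleneck tokens, (B, 768, 64)
text = torch.randn(B, C, T_len)          # text tokens after mlp_text, (B, 768, 128)

# step 1: text queries attend over image keys/values
q_t, kv_i = text.transpose(1, 2), image.transpose(1, 2)
attended = F.softmax(q_t @ kv_i.transpose(-2, -1) / math.sqrt(C), dim=-1) @ kv_i
print(attended.shape)                    # (B, 128, 768): one image-aware vector per text token

# step 2: image queries attend over the text-attended tokens
q_i = image.transpose(1, 2)
out = F.softmax(q_i @ attended.transpose(-2, -1) / math.sqrt(C), dim=-1) @ attended
print(out.shape)                         # (B, 64, 768), reshaped back to (B, 768, 4, 4, 4)
</pre>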
1050
+
1051
+ def forward(self, x, text, normalize=True):
1052
+ x0 = self.patch_embed(x)
1053
+ x0 = self.pos_drop(x0)
1054
+ x0_out = self.proj_out(x0, normalize)
1055
+ x1 = self.layers1[0](x0.contiguous())
1056
+ x1_out = self.proj_out(x1, normalize)
1057
+ x2 = self.layers2[0](x1.contiguous())
1058
+ x2_out = self.proj_out(x2, normalize)
1059
+ x3 = self.layers3[0](x2.contiguous())
1060
+ x3_out = self.proj_out(x3, normalize)
1061
+ x4 = self.layers4[0](x3.contiguous())
1062
+ # Sequential cross-attention fusion
1063
+ x4 = self.sequential_cross_attention(x4,text)
1064
+ x4_out = self.proj_out(x4, normalize)
1065
+ return [x0_out, x1_out, x2_out, x3_out, x4_out]
1066
+
1067
+
1068
+ if __name__ == "__main__":
1069
+ model = TextSwinUNETR(
1070
+ img_size=(128,128,128),
1071
+ in_channels=4,
1072
+ out_channels=3,
1073
+ feature_size=48,
1074
+ text_dim=768,
1075
+ use_checkpoint=False,
1076
+ ).cuda()
1077
+
1078
+ input = torch.randn(1,4,128,128,128).cuda()
1079
+ text = torch.randn(1,128,768).cuda()
1080
+ output = model(input,text)
1081
+ print(output[0].shape)
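For reference, with the settings in this example and the usual SwinUNETR defaults of `patch_size=(2, 2, 2)` and `embed_dim = feature_size = 48` (both set inside `TextSwinUNETR`, outside this excerpt), the encoder returns five feature maps; the deepest one is also why `ly_norm` is built over `(4, 4, 4)` and the fusion layers over 768 channels:
<pre>
# Expected hidden-state shapes for a (1, 4, 128, 128, 128) input (assumed defaults):
#   x0_out: (1,  48, 64, 64, 64)   patch embedding (stride-2 patches)
#   x1_out: (1,  96, 32, 32, 32)   stage 1 + patch merging
#   x2_out: (1, 192, 16, 16, 16)   stage 2 + patch merging
#   x3_out: (1, 384,  8,  8,  8)   stage 3 + patch merging
#   x4_out: (1, 768,  4,  4,  4)   stage 4 + sequential text-image cross attention
</pre>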
utils/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ def dice(x, y):
17
+ intersect = np.sum(x * y)
18
+ y_sum = np.sum(y)
19
+ if y_sum == 0:
20
+ return 0.0
21
+ x_sum = np.sum(x)
22
+ return 2 * intersect / (x_sum + y_sum)
23
+
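A quick sanity check of `dice` on small binary masks (illustrative only; the import assumes the repository root is on `PYTHONPATH`):
<pre>
import numpy as np
from utils.utils import dice

pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
gt   = np.array([[1, 0, 0],
                 [0, 1, 1]])

# intersection = 2, |pred| = 3, |gt| = 3  ->  Dice = 2 * 2 / (3 + 3)
print(dice(pred, gt))   # 0.666...
</pre>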
24
+
25
+ class AverageMeter(object):
26
+ def __init__(self):
27
+ self.reset()
28
+
29
+ def reset(self):
30
+ self.val = 0
31
+ self.avg = 0
32
+ self.sum = 0
33
+ self.count = 0
34
+
35
+ def update(self, val, n=1):
36
+ self.val = val
37
+ self.sum += val * n
38
+ self.count += n
39
+ self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)
40
+
41
+
42
+ def distributed_all_gather(
43
+ tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None
44
+ ):
45
+ if world_size is None:
46
+ world_size = torch.distributed.get_world_size()
47
+ if valid_batch_size is not None:
48
+ valid_batch_size = min(valid_batch_size, world_size)
49
+ elif is_valid is not None:
50
+ is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)
51
+ if not no_barrier:
52
+ torch.distributed.barrier()
53
+ tensor_list_out = []
54
+ with torch.no_grad():
55
+ if is_valid is not None:
56
+ is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
57
+ torch.distributed.all_gather(is_valid_list, is_valid)
58
+ is_valid = [x.item() for x in is_valid_list]
59
+ for tensor in tensor_list:
60
+ gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
61
+ torch.distributed.all_gather(gather_list, tensor)
62
+ if valid_batch_size is not None:
63
+ gather_list = gather_list[:valid_batch_size]
64
+ elif is_valid is not None:
65
+ gather_list = [g for g, v in zip(gather_list, is_valid_list) if v]
66
+ if out_numpy:
67
+ gather_list = [t.cpu().numpy() for t in gather_list]
68
+ tensor_list_out.append(gather_list)
69
+ return tensor_list_out
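Finally, a minimal sketch of how `AverageMeter` and `dice` can be combined to track a running mean over validation cases (dummy data here; the repository's own training and test scripts handle the real evaluation):
<pre>
import numpy as np
from utils.utils import AverageMeter, dice   # assumes the repository root is on PYTHONPATH

rng = np.random.default_rng(0)
preds  = [(rng.random((8, 8, 8)) > 0.5).astype(np.float32) for _ in range(4)]
labels = [(rng.random((8, 8, 8)) > 0.5).astype(np.float32) for _ in range(4)]

run_acc = AverageMeter()
for p, y in zip(preds, labels):
    run_acc.update(dice(p, y), n=1)          # running mean over cases
print("mean Dice:", run_acc.avg)
</pre>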