Upload 16 files
- .gitattributes +1 -0
- README.md +77 -3
- Test.json +1027 -0
- Train.json +0 -0
- assets/datasample.PNG +3 -0
- assets/overview.PNG +0 -0
- environment.yml +23 -0
- main.py +245 -0
- merge.py +32 -0
- optimizers/__init__.py +0 -0
- optimizers/lr_scheduler.py +100 -0
- test.py +145 -0
- trainer.py +223 -0
- utils/__init__.py +0 -0
- utils/data_utils.py +169 -0
- utils/textswin_unetr.py +1081 -0
- utils/utils.py +69 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/datasample.PNG filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,77 @@
# TextBraTS

A public volume-level text-image dataset, with a novel text-guided 3D brain tumor segmentation method, built on the BraTS challenge.

---

## Introduction

**TextBraTS** is an open-access dataset designed to advance research in text-guided 3D brain tumor segmentation. It includes paired multi-modal brain MRI scans and expertly annotated radiology reports, enabling the development and evaluation of multi-modal deep learning models that bridge vision and language in neuro-oncology. Our work has been accepted at MICCAI 2025. The paper is also available on arXiv: [2506.16784](https://arxiv.org/abs/2506.16784).



## Features

- Multi-modal 3D brain MRI scans (T1, T1ce, T2, FLAIR) with expert-annotated segmentations, from the BraTS 2020 challenge training set
- Structured radiology reports for each case
- A text-image alignment method for research on multi-modal fusion



## Usage

You can use this dataset for:

- Developing and benchmarking text-guided segmentation models
- Evaluating multi-modal fusion algorithms in medical imaging
- Research in language-driven medical AI

## Installing Dependencies

Run the following commands to set up the environment:

<pre>conda env create -f environment.yml
pip install git+https://github.com/Project-MONAI/MONAI.git@07de215c</pre>

To activate the environment, use:

<pre>conda activate TextBraTS</pre>
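
As a quick sanity check (an optional step, not part of the original instructions), you can confirm that the pinned packages import and that the GPU is visible:

```python
# Optional sanity check: run inside the activated TextBraTS environment.
import monai
import torch

print("MONAI:", monai.__version__)
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```
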
## Dataset

In accordance with official BraTS guidelines, the MRI images must be downloaded directly from the [BraTS 2020 challenge website](https://www.med.upenn.edu/cbica/brats2020/data.html) (training set).

**Download our text, feature, and prompt files:**
You can download our dataset from [TextBraTSData](https://drive.google.com/file/d/1i1R6_bVY4VbNtxEIQVsiXUSWuVAtgJhg/view?usp=sharing).
Our text reports, feature files, and prompt files are named to match the original BraTS folder IDs exactly, so you can set the paths in `merge.py` and merge them into the downloaded MRI data:

<pre>python merge.py</pre>

If you would like to change the dataset split, modify `Train.json` and `Test.json` accordingly; a sketch of the entry format follows below.
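
For reference, each split entry pairs the four MRI modalities with the segmentation label and the per-case text feature. A minimal sketch for writing a custom split in the same schema (`make_entry`, the output file name, and the three case IDs are illustrative, not from the repository):

```python
import json

def make_entry(case_id: str) -> dict:
    # Build one datalist entry in the Train.json/Test.json schema.
    case = f"BraTS20_Training_{case_id}"
    return {
        "fold": 0,
        "image": [f"{case}/{case}_{m}.nii.gz" for m in ("flair", "t1", "t1ce", "t2")],
        "label": f"{case}/{case}_seg.nii.gz",
        "text_feature": f"{case}/{case}_flair_text.npy",
    }

# Example: a custom split containing three cases.
split = {"training": [make_entry(cid) for cid in ("101", "005", "173")]}
with open("MyTest.json", "w") as f:
    json.dump(split, f, indent=4)
```
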
## Inference

We provide our pre-trained weights for direct inference and evaluation.
Download the weights from [checkpoint](https://drive.google.com/file/d/147283LL2fRDcTYR_vQA-95vbZysjjD1v/view?usp=sharing).

After downloading, place the weights in your desired directory, then run `test.py` with the following command for inference:

<pre>python test.py --pretrained_dir=/path/to/your/weights/ --exp_name=TextBraTS</pre>
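
`test.py` contains the actual evaluation loop; purely for orientation, volumetric MONAI pipelines like this one typically evaluate with sliding-window inference along these lines (the stand-in predictor, ROI size, and channel count are assumptions, not values read from the repository code):

```python
import torch
from monai.inferers import sliding_window_inference

# Stand-in for the trained TextSwinUNETR (which additionally consumes the
# per-case text feature, omitted in this sketch).
def predictor(patch: torch.Tensor) -> torch.Tensor:
    return torch.zeros(patch.shape[0], 3, *patch.shape[2:])  # 3 BraTS channels (TC/WT/ET)

image = torch.zeros(1, 4, 240, 240, 155)  # 4 modalities at the standard BraTS volume size
logits = sliding_window_inference(
    inputs=image,
    roi_size=(128, 128, 128),  # assumed patch size
    sw_batch_size=4,
    predictor=predictor,
    overlap=0.5,
)
pred = (torch.sigmoid(logits) > 0.5).float()  # per-channel sigmoid + threshold
print(pred.shape)  # torch.Size([1, 3, 240, 240, 155])
```
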
## Training

If you would like to train the model from scratch, modify the training code in `main.py` as needed and use the following command:

<pre>python main.py --distributed --use_ssl_pretrained --save_checkpoint --logdir=TextBraTS</pre>

- The `--use_ssl_pretrained` option initializes the model with NVIDIA's self-supervised Swin UNETR pre-trained weights.
- Download the Swin UNETR pre-trained weights from [Pre-trained weights](https://drive.google.com/file/d/1FJ0N_Xo3olzAV-oojEkAsbsUgiFsoPdl/view?usp=sharing).
- Place the downloaded weights in the appropriate directory, as specified in your configuration or script (see the sketch after this list).
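
For orientation, MONAI's stock `SwinUNETR` ships a `load_from` hook for exactly these SSL weights; whether `TextSwinUNETR` reuses that hook is an assumption here, as are the checkpoint path and model hyper-parameters:

```python
import torch
from monai.networks.nets import SwinUNETR

# Assumed hyper-parameters; main.py defines the real configuration.
model = SwinUNETR(img_size=(128, 128, 128), in_channels=4, out_channels=3, feature_size=48)

# "model_swinvit.pt" is the file name NVIDIA uses for the released SSL checkpoint,
# assumed here rather than taken from this repository.
ssl_weights = torch.load("./pretrained_models/model_swinvit.pt", map_location="cpu")
model.load_from(weights=ssl_weights)  # copies the pre-trained swinViT encoder weights
```
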
## Citation

If you use TextBraTS in your research, please cite:

```bibtex
@inproceedings{shi2025textbrats,
  title     = {TextBraTS: Text-Guided Volumetric Brain Tumor Segmentation with Innovative Dataset Development and Fusion Module Exploration},
  author    = {Shi, Xiaoyu and Jain, Rahul Kumar and Li, Yinhao and Hou, Ruibo and Cheng, Jingliang and Bai, Jie and Zhao, Guohua and Lin, Lanfen and Xu, Rui and Chen, Yen-wei},
  booktitle = {Proceedings of the International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI)},
  year      = {2025},
  note      = {to appear}
}
```
Test.json
ADDED
@@ -0,0 +1,1027 @@
{
  "training": [
    {"fold": 0, "image": ["BraTS20_Training_101/BraTS20_Training_101_flair.nii.gz", "BraTS20_Training_101/BraTS20_Training_101_t1.nii.gz", "BraTS20_Training_101/BraTS20_Training_101_t1ce.nii.gz", "BraTS20_Training_101/BraTS20_Training_101_t2.nii.gz"], "label": "BraTS20_Training_101/BraTS20_Training_101_seg.nii.gz", "text_feature": "BraTS20_Training_101/BraTS20_Training_101_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_005/BraTS20_Training_005_flair.nii.gz", "BraTS20_Training_005/BraTS20_Training_005_t1.nii.gz", "BraTS20_Training_005/BraTS20_Training_005_t1ce.nii.gz", "BraTS20_Training_005/BraTS20_Training_005_t2.nii.gz"], "label": "BraTS20_Training_005/BraTS20_Training_005_seg.nii.gz", "text_feature": "BraTS20_Training_005/BraTS20_Training_005_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_173/BraTS20_Training_173_flair.nii.gz", "BraTS20_Training_173/BraTS20_Training_173_t1.nii.gz", "BraTS20_Training_173/BraTS20_Training_173_t1ce.nii.gz", "BraTS20_Training_173/BraTS20_Training_173_t2.nii.gz"], "label": "BraTS20_Training_173/BraTS20_Training_173_seg.nii.gz", "text_feature": "BraTS20_Training_173/BraTS20_Training_173_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_241/BraTS20_Training_241_flair.nii.gz", "BraTS20_Training_241/BraTS20_Training_241_t1.nii.gz", "BraTS20_Training_241/BraTS20_Training_241_t1ce.nii.gz", "BraTS20_Training_241/BraTS20_Training_241_t2.nii.gz"], "label": "BraTS20_Training_241/BraTS20_Training_241_seg.nii.gz", "text_feature": "BraTS20_Training_241/BraTS20_Training_241_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_309/BraTS20_Training_309_flair.nii.gz", "BraTS20_Training_309/BraTS20_Training_309_t1.nii.gz", "BraTS20_Training_309/BraTS20_Training_309_t1ce.nii.gz", "BraTS20_Training_309/BraTS20_Training_309_t2.nii.gz"], "label": "BraTS20_Training_309/BraTS20_Training_309_seg.nii.gz", "text_feature": "BraTS20_Training_309/BraTS20_Training_309_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_258/BraTS20_Training_258_flair.nii.gz", "BraTS20_Training_258/BraTS20_Training_258_t1.nii.gz", "BraTS20_Training_258/BraTS20_Training_258_t1ce.nii.gz", "BraTS20_Training_258/BraTS20_Training_258_t2.nii.gz"], "label": "BraTS20_Training_258/BraTS20_Training_258_seg.nii.gz", "text_feature": "BraTS20_Training_258/BraTS20_Training_258_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_194/BraTS20_Training_194_flair.nii.gz", "BraTS20_Training_194/BraTS20_Training_194_t1.nii.gz", "BraTS20_Training_194/BraTS20_Training_194_t1ce.nii.gz", "BraTS20_Training_194/BraTS20_Training_194_t2.nii.gz"], "label": "BraTS20_Training_194/BraTS20_Training_194_seg.nii.gz", "text_feature": "BraTS20_Training_194/BraTS20_Training_194_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_103/BraTS20_Training_103_flair.nii.gz", "BraTS20_Training_103/BraTS20_Training_103_t1.nii.gz", "BraTS20_Training_103/BraTS20_Training_103_t1ce.nii.gz", "BraTS20_Training_103/BraTS20_Training_103_t2.nii.gz"], "label": "BraTS20_Training_103/BraTS20_Training_103_seg.nii.gz", "text_feature": "BraTS20_Training_103/BraTS20_Training_103_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_170/BraTS20_Training_170_flair.nii.gz", "BraTS20_Training_170/BraTS20_Training_170_t1.nii.gz", "BraTS20_Training_170/BraTS20_Training_170_t1ce.nii.gz", "BraTS20_Training_170/BraTS20_Training_170_t2.nii.gz"], "label": "BraTS20_Training_170/BraTS20_Training_170_seg.nii.gz", "text_feature": "BraTS20_Training_170/BraTS20_Training_170_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_268/BraTS20_Training_268_flair.nii.gz", "BraTS20_Training_268/BraTS20_Training_268_t1.nii.gz", "BraTS20_Training_268/BraTS20_Training_268_t1ce.nii.gz", "BraTS20_Training_268/BraTS20_Training_268_t2.nii.gz"], "label": "BraTS20_Training_268/BraTS20_Training_268_seg.nii.gz", "text_feature": "BraTS20_Training_268/BraTS20_Training_268_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_346/BraTS20_Training_346_flair.nii.gz", "BraTS20_Training_346/BraTS20_Training_346_t1.nii.gz", "BraTS20_Training_346/BraTS20_Training_346_t1ce.nii.gz", "BraTS20_Training_346/BraTS20_Training_346_t2.nii.gz"], "label": "BraTS20_Training_346/BraTS20_Training_346_seg.nii.gz", "text_feature": "BraTS20_Training_346/BraTS20_Training_346_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_149/BraTS20_Training_149_flair.nii.gz", "BraTS20_Training_149/BraTS20_Training_149_t1.nii.gz", "BraTS20_Training_149/BraTS20_Training_149_t1ce.nii.gz", "BraTS20_Training_149/BraTS20_Training_149_t2.nii.gz"], "label": "BraTS20_Training_149/BraTS20_Training_149_seg.nii.gz", "text_feature": "BraTS20_Training_149/BraTS20_Training_149_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_367/BraTS20_Training_367_flair.nii.gz", "BraTS20_Training_367/BraTS20_Training_367_t1.nii.gz", "BraTS20_Training_367/BraTS20_Training_367_t1ce.nii.gz", "BraTS20_Training_367/BraTS20_Training_367_t2.nii.gz"], "label": "BraTS20_Training_367/BraTS20_Training_367_seg.nii.gz", "text_feature": "BraTS20_Training_367/BraTS20_Training_367_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_220/BraTS20_Training_220_flair.nii.gz", "BraTS20_Training_220/BraTS20_Training_220_t1.nii.gz", "BraTS20_Training_220/BraTS20_Training_220_t1ce.nii.gz", "BraTS20_Training_220/BraTS20_Training_220_t2.nii.gz"], "label": "BraTS20_Training_220/BraTS20_Training_220_seg.nii.gz", "text_feature": "BraTS20_Training_220/BraTS20_Training_220_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_368/BraTS20_Training_368_flair.nii.gz", "BraTS20_Training_368/BraTS20_Training_368_t1.nii.gz", "BraTS20_Training_368/BraTS20_Training_368_t1ce.nii.gz", "BraTS20_Training_368/BraTS20_Training_368_t2.nii.gz"], "label": "BraTS20_Training_368/BraTS20_Training_368_seg.nii.gz", "text_feature": "BraTS20_Training_368/BraTS20_Training_368_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_289/BraTS20_Training_289_flair.nii.gz", "BraTS20_Training_289/BraTS20_Training_289_t1.nii.gz", "BraTS20_Training_289/BraTS20_Training_289_t1ce.nii.gz", "BraTS20_Training_289/BraTS20_Training_289_t2.nii.gz"], "label": "BraTS20_Training_289/BraTS20_Training_289_seg.nii.gz", "text_feature": "BraTS20_Training_289/BraTS20_Training_289_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_084/BraTS20_Training_084_flair.nii.gz", "BraTS20_Training_084/BraTS20_Training_084_t1.nii.gz", "BraTS20_Training_084/BraTS20_Training_084_t1ce.nii.gz", "BraTS20_Training_084/BraTS20_Training_084_t2.nii.gz"], "label": "BraTS20_Training_084/BraTS20_Training_084_seg.nii.gz", "text_feature": "BraTS20_Training_084/BraTS20_Training_084_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_277/BraTS20_Training_277_flair.nii.gz", "BraTS20_Training_277/BraTS20_Training_277_t1.nii.gz", "BraTS20_Training_277/BraTS20_Training_277_t1ce.nii.gz", "BraTS20_Training_277/BraTS20_Training_277_t2.nii.gz"], "label": "BraTS20_Training_277/BraTS20_Training_277_seg.nii.gz", "text_feature": "BraTS20_Training_277/BraTS20_Training_277_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_202/BraTS20_Training_202_flair.nii.gz", "BraTS20_Training_202/BraTS20_Training_202_t1.nii.gz", "BraTS20_Training_202/BraTS20_Training_202_t1ce.nii.gz", "BraTS20_Training_202/BraTS20_Training_202_t2.nii.gz"], "label": "BraTS20_Training_202/BraTS20_Training_202_seg.nii.gz", "text_feature": "BraTS20_Training_202/BraTS20_Training_202_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_151/BraTS20_Training_151_flair.nii.gz", "BraTS20_Training_151/BraTS20_Training_151_t1.nii.gz", "BraTS20_Training_151/BraTS20_Training_151_t1ce.nii.gz", "BraTS20_Training_151/BraTS20_Training_151_t2.nii.gz"], "label": "BraTS20_Training_151/BraTS20_Training_151_seg.nii.gz", "text_feature": "BraTS20_Training_151/BraTS20_Training_151_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_142/BraTS20_Training_142_flair.nii.gz", "BraTS20_Training_142/BraTS20_Training_142_t1.nii.gz", "BraTS20_Training_142/BraTS20_Training_142_t1ce.nii.gz", "BraTS20_Training_142/BraTS20_Training_142_t2.nii.gz"], "label": "BraTS20_Training_142/BraTS20_Training_142_seg.nii.gz", "text_feature": "BraTS20_Training_142/BraTS20_Training_142_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_229/BraTS20_Training_229_flair.nii.gz", "BraTS20_Training_229/BraTS20_Training_229_t1.nii.gz", "BraTS20_Training_229/BraTS20_Training_229_t1ce.nii.gz", "BraTS20_Training_229/BraTS20_Training_229_t2.nii.gz"], "label": "BraTS20_Training_229/BraTS20_Training_229_seg.nii.gz", "text_feature": "BraTS20_Training_229/BraTS20_Training_229_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_322/BraTS20_Training_322_flair.nii.gz", "BraTS20_Training_322/BraTS20_Training_322_t1.nii.gz", "BraTS20_Training_322/BraTS20_Training_322_t1ce.nii.gz", "BraTS20_Training_322/BraTS20_Training_322_t2.nii.gz"], "label": "BraTS20_Training_322/BraTS20_Training_322_seg.nii.gz", "text_feature": "BraTS20_Training_322/BraTS20_Training_322_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_278/BraTS20_Training_278_flair.nii.gz", "BraTS20_Training_278/BraTS20_Training_278_t1.nii.gz", "BraTS20_Training_278/BraTS20_Training_278_t1ce.nii.gz", "BraTS20_Training_278/BraTS20_Training_278_t2.nii.gz"], "label": "BraTS20_Training_278/BraTS20_Training_278_seg.nii.gz", "text_feature": "BraTS20_Training_278/BraTS20_Training_278_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_206/BraTS20_Training_206_flair.nii.gz", "BraTS20_Training_206/BraTS20_Training_206_t1.nii.gz", "BraTS20_Training_206/BraTS20_Training_206_t1ce.nii.gz", "BraTS20_Training_206/BraTS20_Training_206_t2.nii.gz"], "label": "BraTS20_Training_206/BraTS20_Training_206_seg.nii.gz", "text_feature": "BraTS20_Training_206/BraTS20_Training_206_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_049/BraTS20_Training_049_flair.nii.gz", "BraTS20_Training_049/BraTS20_Training_049_t1.nii.gz", "BraTS20_Training_049/BraTS20_Training_049_t1ce.nii.gz", "BraTS20_Training_049/BraTS20_Training_049_t2.nii.gz"], "label": "BraTS20_Training_049/BraTS20_Training_049_seg.nii.gz", "text_feature": "BraTS20_Training_049/BraTS20_Training_049_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_115/BraTS20_Training_115_flair.nii.gz", "BraTS20_Training_115/BraTS20_Training_115_t1.nii.gz", "BraTS20_Training_115/BraTS20_Training_115_t1ce.nii.gz", "BraTS20_Training_115/BraTS20_Training_115_t2.nii.gz"], "label": "BraTS20_Training_115/BraTS20_Training_115_seg.nii.gz", "text_feature": "BraTS20_Training_115/BraTS20_Training_115_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_147/BraTS20_Training_147_flair.nii.gz", "BraTS20_Training_147/BraTS20_Training_147_t1.nii.gz", "BraTS20_Training_147/BraTS20_Training_147_t1ce.nii.gz", "BraTS20_Training_147/BraTS20_Training_147_t2.nii.gz"], "label": "BraTS20_Training_147/BraTS20_Training_147_seg.nii.gz", "text_feature": "BraTS20_Training_147/BraTS20_Training_147_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_226/BraTS20_Training_226_flair.nii.gz", "BraTS20_Training_226/BraTS20_Training_226_t1.nii.gz", "BraTS20_Training_226/BraTS20_Training_226_t1ce.nii.gz", "BraTS20_Training_226/BraTS20_Training_226_t2.nii.gz"], "label": "BraTS20_Training_226/BraTS20_Training_226_seg.nii.gz", "text_feature": "BraTS20_Training_226/BraTS20_Training_226_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_066/BraTS20_Training_066_flair.nii.gz", "BraTS20_Training_066/BraTS20_Training_066_t1.nii.gz", "BraTS20_Training_066/BraTS20_Training_066_t1ce.nii.gz", "BraTS20_Training_066/BraTS20_Training_066_t2.nii.gz"], "label": "BraTS20_Training_066/BraTS20_Training_066_seg.nii.gz", "text_feature": "BraTS20_Training_066/BraTS20_Training_066_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_124/BraTS20_Training_124_flair.nii.gz", "BraTS20_Training_124/BraTS20_Training_124_t1.nii.gz", "BraTS20_Training_124/BraTS20_Training_124_t1ce.nii.gz", "BraTS20_Training_124/BraTS20_Training_124_t2.nii.gz"], "label": "BraTS20_Training_124/BraTS20_Training_124_seg.nii.gz", "text_feature": "BraTS20_Training_124/BraTS20_Training_124_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_274/BraTS20_Training_274_flair.nii.gz", "BraTS20_Training_274/BraTS20_Training_274_t1.nii.gz", "BraTS20_Training_274/BraTS20_Training_274_t1ce.nii.gz", "BraTS20_Training_274/BraTS20_Training_274_t2.nii.gz"], "label": "BraTS20_Training_274/BraTS20_Training_274_seg.nii.gz", "text_feature": "BraTS20_Training_274/BraTS20_Training_274_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_290/BraTS20_Training_290_flair.nii.gz", "BraTS20_Training_290/BraTS20_Training_290_t1.nii.gz", "BraTS20_Training_290/BraTS20_Training_290_t1ce.nii.gz", "BraTS20_Training_290/BraTS20_Training_290_t2.nii.gz"], "label": "BraTS20_Training_290/BraTS20_Training_290_seg.nii.gz", "text_feature": "BraTS20_Training_290/BraTS20_Training_290_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_200/BraTS20_Training_200_flair.nii.gz", "BraTS20_Training_200/BraTS20_Training_200_t1.nii.gz", "BraTS20_Training_200/BraTS20_Training_200_t1ce.nii.gz", "BraTS20_Training_200/BraTS20_Training_200_t2.nii.gz"], "label": "BraTS20_Training_200/BraTS20_Training_200_seg.nii.gz", "text_feature": "BraTS20_Training_200/BraTS20_Training_200_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_121/BraTS20_Training_121_flair.nii.gz", "BraTS20_Training_121/BraTS20_Training_121_t1.nii.gz", "BraTS20_Training_121/BraTS20_Training_121_t1ce.nii.gz", "BraTS20_Training_121/BraTS20_Training_121_t2.nii.gz"], "label": "BraTS20_Training_121/BraTS20_Training_121_seg.nii.gz", "text_feature": "BraTS20_Training_121/BraTS20_Training_121_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_082/BraTS20_Training_082_flair.nii.gz", "BraTS20_Training_082/BraTS20_Training_082_t1.nii.gz", "BraTS20_Training_082/BraTS20_Training_082_t1ce.nii.gz", "BraTS20_Training_082/BraTS20_Training_082_t2.nii.gz"], "label": "BraTS20_Training_082/BraTS20_Training_082_seg.nii.gz", "text_feature": "BraTS20_Training_082/BraTS20_Training_082_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_052/BraTS20_Training_052_flair.nii.gz", "BraTS20_Training_052/BraTS20_Training_052_t1.nii.gz", "BraTS20_Training_052/BraTS20_Training_052_t1ce.nii.gz", "BraTS20_Training_052/BraTS20_Training_052_t2.nii.gz"], "label": "BraTS20_Training_052/BraTS20_Training_052_seg.nii.gz", "text_feature": "BraTS20_Training_052/BraTS20_Training_052_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_104/BraTS20_Training_104_flair.nii.gz", "BraTS20_Training_104/BraTS20_Training_104_t1.nii.gz", "BraTS20_Training_104/BraTS20_Training_104_t1ce.nii.gz", "BraTS20_Training_104/BraTS20_Training_104_t2.nii.gz"], "label": "BraTS20_Training_104/BraTS20_Training_104_seg.nii.gz", "text_feature": "BraTS20_Training_104/BraTS20_Training_104_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_062/BraTS20_Training_062_flair.nii.gz", "BraTS20_Training_062/BraTS20_Training_062_t1.nii.gz", "BraTS20_Training_062/BraTS20_Training_062_t1ce.nii.gz", "BraTS20_Training_062/BraTS20_Training_062_t2.nii.gz"], "label": "BraTS20_Training_062/BraTS20_Training_062_seg.nii.gz", "text_feature": "BraTS20_Training_062/BraTS20_Training_062_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_214/BraTS20_Training_214_flair.nii.gz", "BraTS20_Training_214/BraTS20_Training_214_t1.nii.gz", "BraTS20_Training_214/BraTS20_Training_214_t1ce.nii.gz", "BraTS20_Training_214/BraTS20_Training_214_t2.nii.gz"], "label": "BraTS20_Training_214/BraTS20_Training_214_seg.nii.gz", "text_feature": "BraTS20_Training_214/BraTS20_Training_214_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_360/BraTS20_Training_360_flair.nii.gz", "BraTS20_Training_360/BraTS20_Training_360_t1.nii.gz", "BraTS20_Training_360/BraTS20_Training_360_t1ce.nii.gz", "BraTS20_Training_360/BraTS20_Training_360_t2.nii.gz"], "label": "BraTS20_Training_360/BraTS20_Training_360_seg.nii.gz", "text_feature": "BraTS20_Training_360/BraTS20_Training_360_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_041/BraTS20_Training_041_flair.nii.gz", "BraTS20_Training_041/BraTS20_Training_041_t1.nii.gz", "BraTS20_Training_041/BraTS20_Training_041_t1ce.nii.gz", "BraTS20_Training_041/BraTS20_Training_041_t2.nii.gz"], "label": "BraTS20_Training_041/BraTS20_Training_041_seg.nii.gz", "text_feature": "BraTS20_Training_041/BraTS20_Training_041_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_009/BraTS20_Training_009_flair.nii.gz", "BraTS20_Training_009/BraTS20_Training_009_t1.nii.gz", "BraTS20_Training_009/BraTS20_Training_009_t1ce.nii.gz", "BraTS20_Training_009/BraTS20_Training_009_t2.nii.gz"], "label": "BraTS20_Training_009/BraTS20_Training_009_seg.nii.gz", "text_feature": "BraTS20_Training_009/BraTS20_Training_009_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_347/BraTS20_Training_347_flair.nii.gz", "BraTS20_Training_347/BraTS20_Training_347_t1.nii.gz", "BraTS20_Training_347/BraTS20_Training_347_t1ce.nii.gz", "BraTS20_Training_347/BraTS20_Training_347_t2.nii.gz"], "label": "BraTS20_Training_347/BraTS20_Training_347_seg.nii.gz", "text_feature": "BraTS20_Training_347/BraTS20_Training_347_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_330/BraTS20_Training_330_flair.nii.gz", "BraTS20_Training_330/BraTS20_Training_330_t1.nii.gz", "BraTS20_Training_330/BraTS20_Training_330_t1ce.nii.gz", "BraTS20_Training_330/BraTS20_Training_330_t2.nii.gz"], "label": "BraTS20_Training_330/BraTS20_Training_330_seg.nii.gz", "text_feature": "BraTS20_Training_330/BraTS20_Training_330_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_122/BraTS20_Training_122_flair.nii.gz", "BraTS20_Training_122/BraTS20_Training_122_t1.nii.gz", "BraTS20_Training_122/BraTS20_Training_122_t1ce.nii.gz", "BraTS20_Training_122/BraTS20_Training_122_t2.nii.gz"], "label": "BraTS20_Training_122/BraTS20_Training_122_seg.nii.gz", "text_feature": "BraTS20_Training_122/BraTS20_Training_122_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_340/BraTS20_Training_340_flair.nii.gz", "BraTS20_Training_340/BraTS20_Training_340_t1.nii.gz", "BraTS20_Training_340/BraTS20_Training_340_t1ce.nii.gz", "BraTS20_Training_340/BraTS20_Training_340_t2.nii.gz"], "label": "BraTS20_Training_340/BraTS20_Training_340_seg.nii.gz", "text_feature": "BraTS20_Training_340/BraTS20_Training_340_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_028/BraTS20_Training_028_flair.nii.gz", "BraTS20_Training_028/BraTS20_Training_028_t1.nii.gz", "BraTS20_Training_028/BraTS20_Training_028_t1ce.nii.gz", "BraTS20_Training_028/BraTS20_Training_028_t2.nii.gz"], "label": "BraTS20_Training_028/BraTS20_Training_028_seg.nii.gz", "text_feature": "BraTS20_Training_028/BraTS20_Training_028_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_265/BraTS20_Training_265_flair.nii.gz", "BraTS20_Training_265/BraTS20_Training_265_t1.nii.gz", "BraTS20_Training_265/BraTS20_Training_265_t1ce.nii.gz", "BraTS20_Training_265/BraTS20_Training_265_t2.nii.gz"], "label": "BraTS20_Training_265/BraTS20_Training_265_seg.nii.gz", "text_feature": "BraTS20_Training_265/BraTS20_Training_265_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_192/BraTS20_Training_192_flair.nii.gz", "BraTS20_Training_192/BraTS20_Training_192_t1.nii.gz", "BraTS20_Training_192/BraTS20_Training_192_t1ce.nii.gz", "BraTS20_Training_192/BraTS20_Training_192_t2.nii.gz"], "label": "BraTS20_Training_192/BraTS20_Training_192_seg.nii.gz", "text_feature": "BraTS20_Training_192/BraTS20_Training_192_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_255/BraTS20_Training_255_flair.nii.gz", "BraTS20_Training_255/BraTS20_Training_255_t1.nii.gz", "BraTS20_Training_255/BraTS20_Training_255_t1ce.nii.gz", "BraTS20_Training_255/BraTS20_Training_255_t2.nii.gz"], "label": "BraTS20_Training_255/BraTS20_Training_255_seg.nii.gz", "text_feature": "BraTS20_Training_255/BraTS20_Training_255_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_137/BraTS20_Training_137_flair.nii.gz", "BraTS20_Training_137/BraTS20_Training_137_t1.nii.gz", "BraTS20_Training_137/BraTS20_Training_137_t1ce.nii.gz", "BraTS20_Training_137/BraTS20_Training_137_t2.nii.gz"], "label": "BraTS20_Training_137/BraTS20_Training_137_seg.nii.gz", "text_feature": "BraTS20_Training_137/BraTS20_Training_137_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_001/BraTS20_Training_001_flair.nii.gz", "BraTS20_Training_001/BraTS20_Training_001_t1.nii.gz", "BraTS20_Training_001/BraTS20_Training_001_t1ce.nii.gz", "BraTS20_Training_001/BraTS20_Training_001_t2.nii.gz"], "label": "BraTS20_Training_001/BraTS20_Training_001_seg.nii.gz", "text_feature": "BraTS20_Training_001/BraTS20_Training_001_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_182/BraTS20_Training_182_flair.nii.gz", "BraTS20_Training_182/BraTS20_Training_182_t1.nii.gz", "BraTS20_Training_182/BraTS20_Training_182_t1ce.nii.gz", "BraTS20_Training_182/BraTS20_Training_182_t2.nii.gz"], "label": "BraTS20_Training_182/BraTS20_Training_182_seg.nii.gz", "text_feature": "BraTS20_Training_182/BraTS20_Training_182_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_235/BraTS20_Training_235_flair.nii.gz", "BraTS20_Training_235/BraTS20_Training_235_t1.nii.gz", "BraTS20_Training_235/BraTS20_Training_235_t1ce.nii.gz", "BraTS20_Training_235/BraTS20_Training_235_t2.nii.gz"], "label": "BraTS20_Training_235/BraTS20_Training_235_seg.nii.gz", "text_feature": "BraTS20_Training_235/BraTS20_Training_235_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_299/BraTS20_Training_299_flair.nii.gz", "BraTS20_Training_299/BraTS20_Training_299_t1.nii.gz", "BraTS20_Training_299/BraTS20_Training_299_t1ce.nii.gz", "BraTS20_Training_299/BraTS20_Training_299_t2.nii.gz"], "label": "BraTS20_Training_299/BraTS20_Training_299_seg.nii.gz", "text_feature": "BraTS20_Training_299/BraTS20_Training_299_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_019/BraTS20_Training_019_flair.nii.gz", "BraTS20_Training_019/BraTS20_Training_019_t1.nii.gz", "BraTS20_Training_019/BraTS20_Training_019_t1ce.nii.gz", "BraTS20_Training_019/BraTS20_Training_019_t2.nii.gz"], "label": "BraTS20_Training_019/BraTS20_Training_019_seg.nii.gz", "text_feature": "BraTS20_Training_019/BraTS20_Training_019_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_061/BraTS20_Training_061_flair.nii.gz", "BraTS20_Training_061/BraTS20_Training_061_t1.nii.gz", "BraTS20_Training_061/BraTS20_Training_061_t1ce.nii.gz", "BraTS20_Training_061/BraTS20_Training_061_t2.nii.gz"], "label": "BraTS20_Training_061/BraTS20_Training_061_seg.nii.gz", "text_feature": "BraTS20_Training_061/BraTS20_Training_061_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_250/BraTS20_Training_250_flair.nii.gz", "BraTS20_Training_250/BraTS20_Training_250_t1.nii.gz", "BraTS20_Training_250/BraTS20_Training_250_t1ce.nii.gz", "BraTS20_Training_250/BraTS20_Training_250_t2.nii.gz"], "label": "BraTS20_Training_250/BraTS20_Training_250_seg.nii.gz", "text_feature": "BraTS20_Training_250/BraTS20_Training_250_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_249/BraTS20_Training_249_flair.nii.gz", "BraTS20_Training_249/BraTS20_Training_249_t1.nii.gz", "BraTS20_Training_249/BraTS20_Training_249_t1ce.nii.gz", "BraTS20_Training_249/BraTS20_Training_249_t2.nii.gz"], "label": "BraTS20_Training_249/BraTS20_Training_249_seg.nii.gz", "text_feature": "BraTS20_Training_249/BraTS20_Training_249_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_168/BraTS20_Training_168_flair.nii.gz", "BraTS20_Training_168/BraTS20_Training_168_t1.nii.gz", "BraTS20_Training_168/BraTS20_Training_168_t1ce.nii.gz", "BraTS20_Training_168/BraTS20_Training_168_t2.nii.gz"], "label": "BraTS20_Training_168/BraTS20_Training_168_seg.nii.gz", "text_feature": "BraTS20_Training_168/BraTS20_Training_168_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_313/BraTS20_Training_313_flair.nii.gz", "BraTS20_Training_313/BraTS20_Training_313_t1.nii.gz", "BraTS20_Training_313/BraTS20_Training_313_t1ce.nii.gz", "BraTS20_Training_313/BraTS20_Training_313_t2.nii.gz"], "label": "BraTS20_Training_313/BraTS20_Training_313_seg.nii.gz", "text_feature": "BraTS20_Training_313/BraTS20_Training_313_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_248/BraTS20_Training_248_flair.nii.gz", "BraTS20_Training_248/BraTS20_Training_248_t1.nii.gz", "BraTS20_Training_248/BraTS20_Training_248_t1ce.nii.gz", "BraTS20_Training_248/BraTS20_Training_248_t2.nii.gz"], "label": "BraTS20_Training_248/BraTS20_Training_248_seg.nii.gz", "text_feature": "BraTS20_Training_248/BraTS20_Training_248_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_280/BraTS20_Training_280_flair.nii.gz", "BraTS20_Training_280/BraTS20_Training_280_t1.nii.gz", "BraTS20_Training_280/BraTS20_Training_280_t1ce.nii.gz", "BraTS20_Training_280/BraTS20_Training_280_t2.nii.gz"], "label": "BraTS20_Training_280/BraTS20_Training_280_seg.nii.gz", "text_feature": "BraTS20_Training_280/BraTS20_Training_280_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_156/BraTS20_Training_156_flair.nii.gz", "BraTS20_Training_156/BraTS20_Training_156_t1.nii.gz", "BraTS20_Training_156/BraTS20_Training_156_t1ce.nii.gz", "BraTS20_Training_156/BraTS20_Training_156_t2.nii.gz"], "label": "BraTS20_Training_156/BraTS20_Training_156_seg.nii.gz", "text_feature": "BraTS20_Training_156/BraTS20_Training_156_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_275/BraTS20_Training_275_flair.nii.gz", "BraTS20_Training_275/BraTS20_Training_275_t1.nii.gz", "BraTS20_Training_275/BraTS20_Training_275_t1ce.nii.gz", "BraTS20_Training_275/BraTS20_Training_275_t2.nii.gz"], "label": "BraTS20_Training_275/BraTS20_Training_275_seg.nii.gz", "text_feature": "BraTS20_Training_275/BraTS20_Training_275_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_076/BraTS20_Training_076_flair.nii.gz", "BraTS20_Training_076/BraTS20_Training_076_t1.nii.gz", "BraTS20_Training_076/BraTS20_Training_076_t1ce.nii.gz", "BraTS20_Training_076/BraTS20_Training_076_t2.nii.gz"], "label": "BraTS20_Training_076/BraTS20_Training_076_seg.nii.gz", "text_feature": "BraTS20_Training_076/BraTS20_Training_076_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_327/BraTS20_Training_327_flair.nii.gz", "BraTS20_Training_327/BraTS20_Training_327_t1.nii.gz", "BraTS20_Training_327/BraTS20_Training_327_t1ce.nii.gz", "BraTS20_Training_327/BraTS20_Training_327_t2.nii.gz"], "label": "BraTS20_Training_327/BraTS20_Training_327_seg.nii.gz", "text_feature": "BraTS20_Training_327/BraTS20_Training_327_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_059/BraTS20_Training_059_flair.nii.gz", "BraTS20_Training_059/BraTS20_Training_059_t1.nii.gz", "BraTS20_Training_059/BraTS20_Training_059_t1ce.nii.gz", "BraTS20_Training_059/BraTS20_Training_059_t2.nii.gz"], "label": "BraTS20_Training_059/BraTS20_Training_059_seg.nii.gz", "text_feature": "BraTS20_Training_059/BraTS20_Training_059_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_199/BraTS20_Training_199_flair.nii.gz", "BraTS20_Training_199/BraTS20_Training_199_t1.nii.gz", "BraTS20_Training_199/BraTS20_Training_199_t1ce.nii.gz", "BraTS20_Training_199/BraTS20_Training_199_t2.nii.gz"], "label": "BraTS20_Training_199/BraTS20_Training_199_seg.nii.gz", "text_feature": "BraTS20_Training_199/BraTS20_Training_199_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_044/BraTS20_Training_044_flair.nii.gz", "BraTS20_Training_044/BraTS20_Training_044_t1.nii.gz", "BraTS20_Training_044/BraTS20_Training_044_t1ce.nii.gz", "BraTS20_Training_044/BraTS20_Training_044_t2.nii.gz"], "label": "BraTS20_Training_044/BraTS20_Training_044_seg.nii.gz", "text_feature": "BraTS20_Training_044/BraTS20_Training_044_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_320/BraTS20_Training_320_flair.nii.gz", "BraTS20_Training_320/BraTS20_Training_320_t1.nii.gz", "BraTS20_Training_320/BraTS20_Training_320_t1ce.nii.gz", "BraTS20_Training_320/BraTS20_Training_320_t2.nii.gz"], "label": "BraTS20_Training_320/BraTS20_Training_320_seg.nii.gz", "text_feature": "BraTS20_Training_320/BraTS20_Training_320_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_093/BraTS20_Training_093_flair.nii.gz", "BraTS20_Training_093/BraTS20_Training_093_t1.nii.gz", "BraTS20_Training_093/BraTS20_Training_093_t1ce.nii.gz", "BraTS20_Training_093/BraTS20_Training_093_t2.nii.gz"], "label": "BraTS20_Training_093/BraTS20_Training_093_seg.nii.gz", "text_feature": "BraTS20_Training_093/BraTS20_Training_093_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_224/BraTS20_Training_224_flair.nii.gz", "BraTS20_Training_224/BraTS20_Training_224_t1.nii.gz", "BraTS20_Training_224/BraTS20_Training_224_t1ce.nii.gz", "BraTS20_Training_224/BraTS20_Training_224_t2.nii.gz"], "label": "BraTS20_Training_224/BraTS20_Training_224_seg.nii.gz", "text_feature": "BraTS20_Training_224/BraTS20_Training_224_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_225/BraTS20_Training_225_flair.nii.gz", "BraTS20_Training_225/BraTS20_Training_225_t1.nii.gz", "BraTS20_Training_225/BraTS20_Training_225_t1ce.nii.gz", "BraTS20_Training_225/BraTS20_Training_225_t2.nii.gz"], "label": "BraTS20_Training_225/BraTS20_Training_225_seg.nii.gz", "text_feature": "BraTS20_Training_225/BraTS20_Training_225_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_218/BraTS20_Training_218_flair.nii.gz", "BraTS20_Training_218/BraTS20_Training_218_t1.nii.gz", "BraTS20_Training_218/BraTS20_Training_218_t1ce.nii.gz", "BraTS20_Training_218/BraTS20_Training_218_t2.nii.gz"], "label": "BraTS20_Training_218/BraTS20_Training_218_seg.nii.gz", "text_feature": "BraTS20_Training_218/BraTS20_Training_218_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_014/BraTS20_Training_014_flair.nii.gz", "BraTS20_Training_014/BraTS20_Training_014_t1.nii.gz", "BraTS20_Training_014/BraTS20_Training_014_t1ce.nii.gz", "BraTS20_Training_014/BraTS20_Training_014_t2.nii.gz"], "label": "BraTS20_Training_014/BraTS20_Training_014_seg.nii.gz", "text_feature": "BraTS20_Training_014/BraTS20_Training_014_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_264/BraTS20_Training_264_flair.nii.gz", "BraTS20_Training_264/BraTS20_Training_264_t1.nii.gz", "BraTS20_Training_264/BraTS20_Training_264_t1ce.nii.gz", "BraTS20_Training_264/BraTS20_Training_264_t2.nii.gz"], "label": "BraTS20_Training_264/BraTS20_Training_264_seg.nii.gz", "text_feature": "BraTS20_Training_264/BraTS20_Training_264_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_071/BraTS20_Training_071_flair.nii.gz", "BraTS20_Training_071/BraTS20_Training_071_t1.nii.gz", "BraTS20_Training_071/BraTS20_Training_071_t1ce.nii.gz", "BraTS20_Training_071/BraTS20_Training_071_t2.nii.gz"], "label": "BraTS20_Training_071/BraTS20_Training_071_seg.nii.gz", "text_feature": "BraTS20_Training_071/BraTS20_Training_071_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_167/BraTS20_Training_167_flair.nii.gz", "BraTS20_Training_167/BraTS20_Training_167_t1.nii.gz", "BraTS20_Training_167/BraTS20_Training_167_t1ce.nii.gz", "BraTS20_Training_167/BraTS20_Training_167_t2.nii.gz"], "label": "BraTS20_Training_167/BraTS20_Training_167_seg.nii.gz", "text_feature": "BraTS20_Training_167/BraTS20_Training_167_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_087/BraTS20_Training_087_flair.nii.gz", "BraTS20_Training_087/BraTS20_Training_087_t1.nii.gz", "BraTS20_Training_087/BraTS20_Training_087_t1ce.nii.gz", "BraTS20_Training_087/BraTS20_Training_087_t2.nii.gz"], "label": "BraTS20_Training_087/BraTS20_Training_087_seg.nii.gz", "text_feature": "BraTS20_Training_087/BraTS20_Training_087_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_004/BraTS20_Training_004_flair.nii.gz", "BraTS20_Training_004/BraTS20_Training_004_t1.nii.gz", "BraTS20_Training_004/BraTS20_Training_004_t1ce.nii.gz", "BraTS20_Training_004/BraTS20_Training_004_t2.nii.gz"], "label": "BraTS20_Training_004/BraTS20_Training_004_seg.nii.gz", "text_feature": "BraTS20_Training_004/BraTS20_Training_004_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_133/BraTS20_Training_133_flair.nii.gz", "BraTS20_Training_133/BraTS20_Training_133_t1.nii.gz", "BraTS20_Training_133/BraTS20_Training_133_t1ce.nii.gz", "BraTS20_Training_133/BraTS20_Training_133_t2.nii.gz"], "label": "BraTS20_Training_133/BraTS20_Training_133_seg.nii.gz", "text_feature": "BraTS20_Training_133/BraTS20_Training_133_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_072/BraTS20_Training_072_flair.nii.gz", "BraTS20_Training_072/BraTS20_Training_072_t1.nii.gz", "BraTS20_Training_072/BraTS20_Training_072_t1ce.nii.gz", "BraTS20_Training_072/BraTS20_Training_072_t2.nii.gz"], "label": "BraTS20_Training_072/BraTS20_Training_072_seg.nii.gz", "text_feature": "BraTS20_Training_072/BraTS20_Training_072_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_078/BraTS20_Training_078_flair.nii.gz", "BraTS20_Training_078/BraTS20_Training_078_t1.nii.gz", "BraTS20_Training_078/BraTS20_Training_078_t1ce.nii.gz", "BraTS20_Training_078/BraTS20_Training_078_t2.nii.gz"], "label": "BraTS20_Training_078/BraTS20_Training_078_seg.nii.gz", "text_feature": "BraTS20_Training_078/BraTS20_Training_078_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_119/BraTS20_Training_119_flair.nii.gz", "BraTS20_Training_119/BraTS20_Training_119_t1.nii.gz", "BraTS20_Training_119/BraTS20_Training_119_t1ce.nii.gz", "BraTS20_Training_119/BraTS20_Training_119_t2.nii.gz"], "label": "BraTS20_Training_119/BraTS20_Training_119_seg.nii.gz", "text_feature": "BraTS20_Training_119/BraTS20_Training_119_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_344/BraTS20_Training_344_flair.nii.gz", "BraTS20_Training_344/BraTS20_Training_344_t1.nii.gz", "BraTS20_Training_344/BraTS20_Training_344_t1ce.nii.gz", "BraTS20_Training_344/BraTS20_Training_344_t2.nii.gz"], "label": "BraTS20_Training_344/BraTS20_Training_344_seg.nii.gz", "text_feature": "BraTS20_Training_344/BraTS20_Training_344_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_171/BraTS20_Training_171_flair.nii.gz", "BraTS20_Training_171/BraTS20_Training_171_t1.nii.gz", "BraTS20_Training_171/BraTS20_Training_171_t1ce.nii.gz", "BraTS20_Training_171/BraTS20_Training_171_t2.nii.gz"], "label": "BraTS20_Training_171/BraTS20_Training_171_seg.nii.gz", "text_feature": "BraTS20_Training_171/BraTS20_Training_171_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_297/BraTS20_Training_297_flair.nii.gz", "BraTS20_Training_297/BraTS20_Training_297_t1.nii.gz", "BraTS20_Training_297/BraTS20_Training_297_t1ce.nii.gz", "BraTS20_Training_297/BraTS20_Training_297_t2.nii.gz"], "label": "BraTS20_Training_297/BraTS20_Training_297_seg.nii.gz", "text_feature": "BraTS20_Training_297/BraTS20_Training_297_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_021/BraTS20_Training_021_flair.nii.gz", "BraTS20_Training_021/BraTS20_Training_021_t1.nii.gz", "BraTS20_Training_021/BraTS20_Training_021_t1ce.nii.gz", "BraTS20_Training_021/BraTS20_Training_021_t2.nii.gz"], "label": "BraTS20_Training_021/BraTS20_Training_021_seg.nii.gz", "text_feature": "BraTS20_Training_021/BraTS20_Training_021_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_359/BraTS20_Training_359_flair.nii.gz", "BraTS20_Training_359/BraTS20_Training_359_t1.nii.gz", "BraTS20_Training_359/BraTS20_Training_359_t1ce.nii.gz", "BraTS20_Training_359/BraTS20_Training_359_t2.nii.gz"], "label": "BraTS20_Training_359/BraTS20_Training_359_seg.nii.gz", "text_feature": "BraTS20_Training_359/BraTS20_Training_359_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_328/BraTS20_Training_328_flair.nii.gz", "BraTS20_Training_328/BraTS20_Training_328_t1.nii.gz", "BraTS20_Training_328/BraTS20_Training_328_t1ce.nii.gz", "BraTS20_Training_328/BraTS20_Training_328_t2.nii.gz"], "label": "BraTS20_Training_328/BraTS20_Training_328_seg.nii.gz", "text_feature": "BraTS20_Training_328/BraTS20_Training_328_flair_text.npy"},
    {"fold": 0, "image": ["BraTS20_Training_233/BraTS20_Training_233_flair.nii.gz", "BraTS20_Training_233/BraTS20_Training_233_t1.nii.gz", "BraTS20_Training_233/BraTS20_Training_233_t1ce.nii.gz", "BraTS20_Training_233/BraTS20_Training_233_t2.nii.gz"], "label": "BraTS20_Training_233/BraTS20_Training_233_seg.nii.gz", "text_feature": "BraTS20_Training_233/BraTS20_Training_233_flair_text.npy"}
  ]
}
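
A minimal sketch of consuming this datalist with the standard library (`data_dir` and the fold filtering are assumptions about how `utils/data_utils.py` reads the file):

```python
import json
import os

data_dir = "/path/to/TextBraTS"  # assumed dataset root produced by merge.py

with open("Test.json") as f:
    datalist = json.load(f)["training"]

fold = 0  # every entry above carries fold 0
cases = [d for d in datalist if d["fold"] == fold]
for d in cases[:3]:
    images = [os.path.join(data_dir, p) for p in d["image"]]
    label = os.path.join(data_dir, d["label"])
    text_feature = os.path.join(data_dir, d["text_feature"])
    print(images[0], label, text_feature)
```
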
Train.json
ADDED
The diff for this file is too large to render.
See raw diff
assets/datasample.PNG
ADDED
assets/overview.PNG
ADDED
environment.yml
ADDED
@@ -0,0 +1,23 @@
1 |
+
name: TextBraTS
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- python=3.11
|
9 |
+
- pytorch=2.5.0
|
10 |
+
- torchvision
|
11 |
+
- torchaudio
|
12 |
+
- pytorch-cuda=12.1
|
13 |
+
- cudnn=8.9.7
|
14 |
+
- numpy=1.26.4
|
15 |
+
- sympy=1.13.1
|
16 |
+
- fsspec=2025.2
|
17 |
+
- tensorboardX=2.6.2.2
|
18 |
+
- pip
|
19 |
+
- pip:
|
20 |
+
- monai
|
21 |
+
- nibabel
|
22 |
+
- einops
|
23 |
+
- scipy
|
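A minimal sanity check (an addition, not part of the commit) for the key pins above, to be run inside the activated environment; monai is installed via pip here, so its exact version string may vary:
<pre>
# Minimal sketch: confirm the resolved environment matches environment.yml.
import numpy as np
import torch
import monai

print("torch", torch.__version__)    # pinned to 2.5.0 above
print("numpy", np.__version__)       # pinned to 1.26.4 above
print("monai", monai.__version__)    # installed via pip
print("CUDA available:", torch.cuda.is_available())
</pre>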
main.py
ADDED
@@ -0,0 +1,245 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")

import argparse
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from trainer import run_training
from utils.data_utils import get_loader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from utils.textswin_unetr import TextSwinUNETR
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction
import random


parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline for TextBRATS image-text dataset")
parser.add_argument("--checkpoint", default=None, help="start training from saved checkpoint")
parser.add_argument("--logdir", default="TextBraTS", type=str, help="directory to save the tensorboard logs")
parser.add_argument("--fold", default=0, type=int, help="data fold, 0 for validation and 1 for training")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--json_list", default="./Train.json", type=str, help="dataset json file")
parser.add_argument("--save_checkpoint", action="store_true", help="save checkpoint during training")
parser.add_argument("--max_epochs", default=200, type=int, help="max number of training epochs")
parser.add_argument("--batch_size", default=2, type=int, help="number of batch size")
parser.add_argument("--sw_batch_size", default=4, type=int, help="number of sliding window batch size")
parser.add_argument("--optim_lr", default=1e-4, type=float, help="optimization learning rate")
parser.add_argument("--optim_name", default="adamw", type=str, help="optimization algorithm")
parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
parser.add_argument("--val_every", default=1, type=int, help="validation frequency")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str, help="distributed url")
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
parser.add_argument("--norm_name", default="instance", type=str, help="normalization name")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--cache_dataset", action="store_true", help="use monai Dataset class")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--dropout_path_rate", default=0.0, type=float, help="drop path rate")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
parser.add_argument("--warmup_epochs", default=50, type=int, help="number of warmup epochs")
parser.add_argument("--resume_ckpt", action="store_true", help="resume training from pretrained checkpoint")
parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_ssl_pretrained", action="store_true", help="use SSL pretrained ckpt")
parser.add_argument(
    "--pretrained_dir",
    default="./runs/TextBraTS/",
    type=str,
    help="pretrained checkpoint directory",
)
parser.add_argument("--squared_dice", action="store_true", help="use squared Dice")
parser.add_argument("--seed", type=int, default=23, help="random seed")


def main():
    args = parser.parse_args()
    args.amp = not args.noamp
    args.logdir = "./runs/" + args.logdir
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.distributed:
        torch.cuda.manual_seed_all(args.seed)
        args.ngpus_per_node = torch.cuda.device_count()
        print("Found total gpus", args.ngpus_per_node)
        args.world_size = args.ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args,))
    else:
        torch.cuda.manual_seed(args.seed)
        main_worker(gpu=0, args=args)


def main_worker(gpu, args):
    if args.distributed:
        torch.multiprocessing.set_start_method("fork", force=True)
    np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
    args.gpu = gpu
    if args.distributed:
        args.rank = args.rank * args.ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
        )
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True
    args.test_mode = False
    loader = get_loader(args)
    print(args.rank, " gpu", args.gpu)
    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    pretrained_pth = os.path.join(pretrained_dir, model_name)

    model = TextSwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        use_checkpoint=args.use_checkpoint,
        text_dim=768,
    )

    if args.resume_ckpt:
        model_dict = torch.load(pretrained_pth)["state_dict"]
        for key in list(model_dict.keys()):
            model_dict[key.replace("module.", "")] = model_dict.pop(key)
        model.load_state_dict(model_dict, strict=True)
        print("Using pretrained weights")

    if args.use_ssl_pretrained:
        try:
            model_dict = torch.load("/media/iipl/disk1/swinunetr/model_swinvit.pt", weights_only=True)
            state_dict = model_dict["state_dict"]
            # fix potential differences in state dict keys from pre-training to
            # fine-tuning
            for key in list(state_dict.keys()):
                state_dict[key.replace("module.", "swinViT.")] = state_dict.pop(key)
            for key in list(state_dict.keys()):
                if "fc" in key:
                    state_dict[key.replace("fc", "linear")] = state_dict.pop(key)
                if "patch_embed" in key:
                    state_dict[key.replace("patch_embed", "")] = state_dict.pop(key)
            model.load_state_dict(state_dict, strict=False)
        except ValueError:
            raise ValueError("Self-supervised pre-trained weights not available for " + str(args.pretrained_model_name))

    if args.squared_dice:
        dice_loss = DiceLoss(
            to_onehot_y=False, sigmoid=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr
        )
    else:
        dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True)
    post_sigmoid = Activations(sigmoid=True)
    post_pred = AsDiscrete(argmax=False, logit_thresh=0.5)
    dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Total parameters count", pytorch_total_params)

    best_acc = 0
    start_epoch = 0

    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        from collections import OrderedDict

        new_state_dict = OrderedDict()
        for k, v in checkpoint["state_dict"].items():
            new_state_dict[k.replace("backbone.", "")] = v
        model.load_state_dict(new_state_dict, strict=False)
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        if args.norm_name == "batch":
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu, find_unused_parameters=False
        )
    if args.optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
    elif args.optim_name == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
    elif args.optim_name == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight
        )
    else:
        raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))

    if args.lrschedule == "warmup_cosine":
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=args.warmup_epochs, max_epochs=args.max_epochs
        )
    elif args.lrschedule == "cosine_anneal":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs)
        if args.checkpoint is not None:
            scheduler.step(epoch=start_epoch)
    else:
        scheduler = None

    semantic_classes = ["Dice_Val_TC", "Dice_Val_WT", "Dice_Val_ET"]

    accuracy = run_training(
        model=model,
        train_loader=loader[0],
        val_loader=loader[1],
        optimizer=optimizer,
        loss_func=dice_loss,
        acc_func=dice_acc,
        args=args,
        scheduler=scheduler,
        start_epoch=start_epoch,
        post_sigmoid=post_sigmoid,
        post_pred=post_pred,
        semantic_classes=semantic_classes,
    )
    return accuracy


if __name__ == "__main__":
    main()
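A minimal smoke-test sketch (an addition, not part of the commit) that drives the entry point above programmatically; every flag value and path here is an assumption chosen for a short single-GPU run:
<pre>
# Minimal sketch: one-epoch, single-GPU smoke test of the script above.
import sys
import main  # the training script added in this commit

sys.argv = [
    "main.py",
    "--data_dir", "./data/TextBraTSData",  # assumed local dataset path
    "--json_list", "./Train.json",
    "--max_epochs", "1",                   # short run for testing only
]
main.main()
</pre>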
merge.py
ADDED
@@ -0,0 +1,32 @@
import os
import shutil

# === Please set your own paths below! ===
img_root = "/path/to/MICCAI_BraTS2020_TrainingData"
txt_root = "/path/to/Download/TextBraTSData"
out_root = "/path/to/TextBraTS/TextBraTSData"

# Loop over all cases in the image folder
for case in os.listdir(img_root):
    img_case_dir = os.path.join(img_root, case)
    txt_case_dir = os.path.join(txt_root, case)
    out_case_dir = os.path.join(out_root, case)

    if not os.path.isdir(img_case_dir):
        continue  # Skip non-directory files

    # Create output folder for each case
    os.makedirs(out_case_dir, exist_ok=True)

    # Copy all imaging files and segmentation labels
    for file in os.listdir(img_case_dir):
        shutil.copy2(os.path.join(img_case_dir, file), os.path.join(out_case_dir, file))

    # Copy text reports and feature files if available
    if os.path.exists(txt_case_dir):
        for file in os.listdir(txt_case_dir):
            shutil.copy2(os.path.join(txt_case_dir, file), os.path.join(out_case_dir, file))
    else:
        print(f"Warning: {txt_case_dir} does not exist, skipping.")

print("Merge done! All cases are in:", out_root)
optimizers/__init__.py
ADDED
File without changes
optimizers/lr_scheduler.py
ADDED
@@ -0,0 +1,100 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class LinearWarmupCosineAnnealingLR(_LRScheduler):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup.
            max_epochs (int): Maximum number of iterations.
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """
        Compute learning rate using chainable form of the scheduler.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
            )

        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        elif self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        elif self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """
        Called when epoch is passed as a param to the `step` function of the scheduler.
        """
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]
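The scheduler above ramps the learning rate linearly over warmup_epochs and then decays it along a cosine toward eta_min. A minimal sketch (an addition, not part of the commit) that prints the schedule for a dummy parameter; the hyperparameters are placeholders:
<pre>
# Minimal sketch: trace the warmup-cosine schedule implemented above.
import torch
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = LinearWarmupCosineAnnealingLR(opt, warmup_epochs=5, max_epochs=20)

for epoch in range(20):
    opt.step()                            # step the optimizer before the scheduler
    sched.step()
    print(epoch, sched.get_last_lr()[0])  # rises for 5 epochs, then decays
</pre>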
test.py
ADDED
@@ -0,0 +1,145 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from utils.data_utils import get_loader
from utils.textswin_unetr import TextSwinUNETR
import os
import time
import torch
import torch.nn.parallel
import torch.utils.data.distributed
from utils.utils import AverageMeter
from monai.utils.enums import MetricReduction
from monai.metrics import DiceMetric, HausdorffDistanceMetric


parser = argparse.ArgumentParser(description="TextBraTS segmentation pipeline")
parser.add_argument("--data_dir", default="./data/TextBraTSData", type=str, help="dataset directory")
parser.add_argument("--exp_name", default="TextBraTS", type=str, help="experiment name")
parser.add_argument("--json_list", default="Test.json", type=str, help="dataset json file")
parser.add_argument("--fold", default=0, type=int, help="data fold")
parser.add_argument("--pretrained_model_name", default="model.pt", type=str, help="pretrained model name")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--infer_overlap", default=0.6, type=float, help="sliding window inference overlap")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=128, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=128, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=128, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument(
    "--pretrained_dir",
    default="./runs/TextBraTS/",
    type=str,
    help="pretrained checkpoint directory",
)


def main():
    args = parser.parse_args()
    args.test_mode = True
    output_directory = "./outputs/" + args.exp_name
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    test_loader = get_loader(args)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_pth = os.path.join(pretrained_dir, model_name)
    model = TextSwinUNETR(
        img_size=128,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint,
        text_dim=768,
    )
    model_dict = torch.load(pretrained_pth)["state_dict"]
    model.load_state_dict(model_dict, strict=False)
    model.eval()
    model.to(device)

    def val_epoch(model, loader, acc_func, hd95_func):
        model.eval()
        start_time = time.time()
        run_acc = AverageMeter()
        run_hd95 = AverageMeter()

        with torch.no_grad():
            for idx, batch_data in enumerate(loader):
                data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
                data, target, text = data.cuda(), target.cuda(), text.cuda()
                logits = model(data, text)
                prob = torch.sigmoid(logits)
                prob = (prob > 0.5).int()

                acc_func(y_pred=prob, y=target)
                acc, not_nans = acc_func.aggregate()
                acc = acc.cuda()

                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

                # HD95 metric (per-class values: TC, WT, ET)
                hd95_func(y_pred=prob, y=target)
                hd95 = hd95_func.aggregate()
                run_hd95.update(hd95.cpu().numpy())

        Dice_TC = run_acc.avg[0]
        Dice_WT = run_acc.avg[1]
        Dice_ET = run_acc.avg[2]
        HD95_TC = run_hd95.avg[0]
        HD95_WT = run_hd95.avg[1]
        HD95_ET = run_hd95.avg[2]
        print(
            "Val {}/{}".format(idx, len(loader)),
            ", Dice_TC:", Dice_TC,
            ", Dice_WT:", Dice_WT,
            ", Dice_ET:", Dice_ET,
            ", Avg Dice:", (Dice_ET + Dice_TC + Dice_WT) / 3,
            ", HD95_TC:", HD95_TC,
            ", HD95_WT:", HD95_WT,
            ", HD95_ET:", HD95_ET,
            ", Avg HD95:", (HD95_ET + HD95_TC + HD95_WT) / 3,
            ", time {:.2f}s".format(time.time() - start_time),
        )
        start_time = time.time()
        with open(output_directory + "/log.txt", "a") as log_file:
            log_file.write(f"Experiment name:{args.pretrained_dir.split('/')[-2]}, "
                           f"Final Validation Results - Dice_TC: {Dice_TC}, Dice_WT: {Dice_WT}, Dice_ET: {Dice_ET}, "
                           f"Avg Dice: {(Dice_ET + Dice_TC + Dice_WT) / 3}, "
                           f"HD95_TC: {HD95_TC}, HD95_WT: {HD95_WT}, HD95_ET: {HD95_ET}, "
                           f"Avg HD95: {(HD95_ET + HD95_TC + HD95_WT) / 3}\n")
        return run_acc.avg

    dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
    hd95_acc = HausdorffDistanceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, percentile=95.0)
    val_epoch(model, test_loader, acc_func=dice_acc, hd95_func=hd95_acc)


if __name__ == "__main__":
    main()
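The loop above relies on MONAI's cumulative metric objects: each call buffers one batch, and aggregate() returns the running per-class result. A self-contained sketch (an addition, not part of the commit) with dummy binarized volumes:
<pre>
# Minimal sketch: the DiceMetric call pattern used in val_epoch above,
# on dummy 3-class binary volumes of shape (batch, class, H, W, D).
import torch
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction

metric = DiceMetric(include_background=True, reduction=MetricReduction.MEAN_BATCH, get_not_nans=True)
pred = (torch.rand(2, 3, 8, 8, 8) > 0.5).int()
target = (torch.rand(2, 3, 8, 8, 8) > 0.5).int()

metric(y_pred=pred, y=target)        # buffer one batch
dice, not_nans = metric.aggregate()  # per-class running mean
print(dice)                          # three scores (TC, WT, ET order in this repo)
</pre>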
trainer.py
ADDED
@@ -0,0 +1,223 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import time

import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.amp import GradScaler, autocast
from utils.utils import AverageMeter, distributed_all_gather

from monai.data import decollate_batch


def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):
    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    for idx, batch_data in enumerate(loader):
        if isinstance(batch_data, list):
            data, target, text = batch_data
        else:
            data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
        data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=args.amp):
            logits = model(data, text)
            loss = loss_func(logits, target)
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        if args.distributed:
            loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
            run_loss.update(
                np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
            )
        else:
            run_loss.update(loss.item(), n=args.batch_size)
        if args.rank == 0:
            print(
                "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                "loss: {:.4f}".format(run_loss.avg),
                "time {:.2f}s".format(time.time() - start_time),
            )
        start_time = time.time()
    '''for param in model.parameters():
        param.grad = None'''
    optimizer.zero_grad(set_to_none=True)
    return run_loss.avg


def val_epoch(model, loader, epoch, acc_func, args, post_sigmoid=None, post_pred=None):
    model.eval()
    start_time = time.time()
    run_acc = AverageMeter()

    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
            data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
            with autocast("cuda", enabled=args.amp):
                logits = model(data, text)
            val_labels_list = decollate_batch(target)
            val_outputs_list = decollate_batch(logits)
            val_output_convert = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs_list]
            acc_func.reset()
            acc_func(y_pred=val_output_convert, y=val_labels_list)
            acc, not_nans = acc_func.aggregate()
            acc = acc.cuda(args.rank)
            if args.distributed:
                acc_list, not_nans_list = distributed_all_gather(
                    [acc, not_nans], out_numpy=True, is_valid=idx < loader.sampler.valid_length
                )
                for al, nl in zip(acc_list, not_nans_list):
                    run_acc.update(al, n=nl)
            else:
                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

            if args.rank == 0:
                Dice_TC = run_acc.avg[0]
                Dice_WT = run_acc.avg[1]
                Dice_ET = run_acc.avg[2]
                print(
                    "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                    ", Dice_TC:", Dice_TC,
                    ", Dice_WT:", Dice_WT,
                    ", Dice_ET:", Dice_ET,
                    ", time {:.2f}s".format(time.time() - start_time),
                )
                start_time = time.time()

    return run_acc.avg


def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
    state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
    if optimizer is not None:
        save_dict["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        save_dict["scheduler"] = scheduler.state_dict()
    filename = os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    print("Saving checkpoint", filename)


def run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_func,
    acc_func,
    args,
    scheduler=None,
    start_epoch=0,
    post_sigmoid=None,
    post_pred=None,
    semantic_classes=None,
):
    writer = None
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        if args.rank == 0:
            print("Writing Tensorboard logs to ", args.logdir)
    scaler = None
    if args.amp:
        scaler = GradScaler()
    val_acc_max = 0.0
    for epoch in range(start_epoch, args.max_epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.distributed.barrier()
        print(args.rank, time.ctime(), "Epoch:", epoch)
        epoch_time = time.time()
        train_loss = train_epoch(
            model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args
        )
        if args.rank == 0:
            print(
                "Final training {}/{}".format(epoch, args.max_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "time {:.2f}s".format(time.time() - epoch_time),
            )
        if args.rank == 0 and writer is not None:
            writer.add_scalar("train_loss", train_loss, epoch)
        b_new_best = False
        if (epoch + 1) % args.val_every == 0:
            if args.distributed:
                torch.distributed.barrier()
            epoch_time = time.time()
            val_acc = val_epoch(
                model,
                val_loader,
                epoch=epoch,
                acc_func=acc_func,
                args=args,
                post_sigmoid=post_sigmoid,
                post_pred=post_pred,
            )

            if args.rank == 0:
                Dice_TC = val_acc[0]
                Dice_WT = val_acc[1]
                Dice_ET = val_acc[2]
                print(
                    "Final validation stats {}/{}".format(epoch, args.max_epochs - 1),
                    ", Dice_TC:", Dice_TC,
                    ", Dice_WT:", Dice_WT,
                    ", Dice_ET:", Dice_ET,
                    ", time {:.2f}s".format(time.time() - epoch_time),
                )

                if writer is not None:
                    writer.add_scalar("Mean_Val_Dice", np.mean(val_acc), epoch)
                    if semantic_classes is not None:
                        for val_channel_ind in range(len(semantic_classes)):
                            if val_channel_ind < val_acc.size:
                                writer.add_scalar(semantic_classes[val_channel_ind], val_acc[val_channel_ind], epoch)
                val_avg_acc = np.mean(val_acc)
                if val_avg_acc > val_acc_max:
                    print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
                    val_acc_max = val_avg_acc
                    b_new_best = True
                    if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
                        save_checkpoint(
                            model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
                        )
            if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
                print("Saving")
                save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")
                if b_new_best:
                    print("Copying to model.pt new best model!!!!")
                    shutil.copyfile(os.path.join(args.logdir, "model_final.pt"), os.path.join(args.logdir, "model.pt"))

        if scheduler is not None:
            scheduler.step()

    print("Training Finished !, Best Accuracy: ", val_acc_max)

    return val_acc_max
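save_checkpoint above stores the epoch and best accuracy next to the weights, so a finished or interrupted run can be inspected later. A minimal sketch (an addition, not part of the commit); the checkpoint path is an assumption based on the default logdir:
<pre>
# Minimal sketch: inspect a checkpoint written by save_checkpoint above.
import torch

ckpt = torch.load("./runs/TextBraTS/model.pt", map_location="cpu")
print("epoch:", ckpt["epoch"])
print("best_acc:", ckpt["best_acc"])
print("tensors in state_dict:", len(ckpt["state_dict"]))
</pre>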
utils/__init__.py
ADDED
File without changes
utils/data_utils.py
ADDED
@@ -0,0 +1,169 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os

import numpy as np
import torch

from monai import data, transforms
from monai.data import NibabelReader
from monai.transforms import MapTransform

# Load BioBERT features
class LoadNumpyd(MapTransform):
    def __init__(self, keys):
        super().__init__(keys)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = np.load(d[key])
            d[key] = np.squeeze(d[key], axis=0)
        return d

class Sampler(torch.utils.data.Sampler):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            if len(indices) < self.total_size:
                if self.total_size - len(indices) < len(indices):
                    indices += indices[: (self.total_size - len(indices))]
                else:
                    extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices))
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def datafold_read(datalist, basedir, fold=0, key="training"):
    with open(datalist) as f:
        json_data = json.load(f)

    json_data = json_data[key]

    for d in json_data:
        for k, v in d.items():
            if isinstance(d[k], list):
                d[k] = [os.path.join(basedir, iv) for iv in d[k]]
            elif isinstance(d[k], str):
                d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
    tr = []
    val = []
    for d in json_data:
        if "fold" in d and d["fold"] == fold:
            val.append(d)
        else:
            tr.append(d)
    return tr, val


def get_loader(args):
    data_dir = args.data_dir
    datalist_json = args.json_list
    train_files, validation_files = datafold_read(datalist=datalist_json, basedir=data_dir, fold=args.fold)
    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"], reader=NibabelReader()),
            LoadNumpyd(keys=["text_feature"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.Resized(keys=["image", "label"], spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
            transforms.ToTensord(keys=["image", "label", "text_feature"]),
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"], reader=NibabelReader()),
            LoadNumpyd(keys=["text_feature"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.Resized(keys=["image", "label"], spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.ToTensord(keys=["image", "label", "text_feature"]),
        ]
    )

    test_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"], reader=NibabelReader()),
            LoadNumpyd(keys=["text_feature"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.Resized(keys=["image", "label"], spatial_size=[args.roi_x, args.roi_y, args.roi_z]),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.ToTensord(keys=["image", "label", "text_feature"]),
        ]
    )

    if args.test_mode:
        val_ds = data.Dataset(data=validation_files, transform=test_transform)
        val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
        test_loader = data.DataLoader(
            val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
        )

        loader = test_loader
    else:
        train_ds = data.Dataset(data=train_files, transform=train_transform)

        train_sampler = Sampler(train_ds) if args.distributed else None
        train_loader = data.DataLoader(
            train_ds,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            sampler=train_sampler,
            pin_memory=True,
        )
        val_ds = data.Dataset(data=validation_files, transform=val_transform)
        val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
        val_loader = data.DataLoader(
            val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
        )
        loader = [train_loader, val_loader]

    return loader
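datafold_read above splits one JSON manifest into training and validation lists: entries whose "fold" matches the requested fold become validation, everything else training, with all paths prefixed by basedir. A self-contained sketch (an addition, not part of the commit) on a tiny in-memory manifest:
<pre>
# Minimal sketch: exercise datafold_read above on a two-case manifest.
import json
import tempfile
from utils.data_utils import datafold_read

manifest = {"training": [
    {"fold": 0, "image": ["a_flair.nii.gz"], "label": "a_seg.nii.gz", "text_feature": "a.npy"},
    {"fold": 1, "image": ["b_flair.nii.gz"], "label": "b_seg.nii.gz", "text_feature": "b.npy"},
]}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(manifest, f)
    path = f.name

train, val = datafold_read(path, basedir="/data", fold=0)
print(len(train), "training /", len(val), "validation")  # prints: 1 training / 1 validation
</pre>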
utils/textswin_unetr.py
ADDED
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 - 2022 -> (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
from typing import Sequence, Tuple, Type, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.utils.checkpoint as checkpoint
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
|
21 |
+
from monai.networks.blocks import MLPBlock as Mlp
|
22 |
+
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
|
23 |
+
from monai.networks.layers import DropPath, trunc_normal_
|
24 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
25 |
+
import math
|
26 |
+
|
27 |
+
rearrange, _ = optional_import("einops", name="rearrange")
|
28 |
+
|
29 |
+
|
30 |
+
class TextSwinUNETR(nn.Module):
    """
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"
    """

    def __init__(
        self,
        img_size: Union[Sequence[int], int],
        in_channels: int,
        out_channels: int,
        text_dim: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            img_size: dimension of input image.
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            text_dim: dimension of the text embedding used for cross-attention fusion.
            feature_size: dimension of network feature size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = TextSwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, text_dim=768, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = TextSwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, text_dim=768, depths=(2,4,2,2))

        """

        super().__init__()

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        if not (spatial_dims == 2 or spatial_dims == 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        for m, p in zip(img_size, patch_size):
            for i in range(5):
                if m % np.power(p, i + 1) != 0:
                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")

        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")

        if not (0 <= attn_drop_rate <= 1):
            raise ValueError("attention dropout rate should be between 0 and 1.")

        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            text_dim=text_dim,
        )

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder10 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder1 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.out = UnetOutBlock(
            spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels
        )  # type: ignore

    def load_from(self, weights):
        """Copy pretrained SwinViT encoder weights (Swin UNETR checkpoint format) into this model."""

        with torch.no_grad():
            self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"])
            self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"])
            for bname, block in self.swinViT.layers1[0].blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers1")
            self.swinViT.layers1[0].downsample.reduction.weight.copy_(
                weights["state_dict"]["module.layers1.0.downsample.reduction.weight"]
            )
            self.swinViT.layers1[0].downsample.norm.weight.copy_(
                weights["state_dict"]["module.layers1.0.downsample.norm.weight"]
            )
            self.swinViT.layers1[0].downsample.norm.bias.copy_(
                weights["state_dict"]["module.layers1.0.downsample.norm.bias"]
            )
            for bname, block in self.swinViT.layers2[0].blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers2")
            self.swinViT.layers2[0].downsample.reduction.weight.copy_(
                weights["state_dict"]["module.layers2.0.downsample.reduction.weight"]
            )
            self.swinViT.layers2[0].downsample.norm.weight.copy_(
                weights["state_dict"]["module.layers2.0.downsample.norm.weight"]
            )
            self.swinViT.layers2[0].downsample.norm.bias.copy_(
                weights["state_dict"]["module.layers2.0.downsample.norm.bias"]
            )
            for bname, block in self.swinViT.layers3[0].blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers3")
            self.swinViT.layers3[0].downsample.reduction.weight.copy_(
                weights["state_dict"]["module.layers3.0.downsample.reduction.weight"]
            )
            self.swinViT.layers3[0].downsample.norm.weight.copy_(
                weights["state_dict"]["module.layers3.0.downsample.norm.weight"]
            )
            self.swinViT.layers3[0].downsample.norm.bias.copy_(
                weights["state_dict"]["module.layers3.0.downsample.norm.bias"]
            )
            for bname, block in self.swinViT.layers4[0].blocks.named_children():
                block.load_from(weights, n_block=bname, layer="layers4")
            self.swinViT.layers4[0].downsample.reduction.weight.copy_(
                weights["state_dict"]["module.layers4.0.downsample.reduction.weight"]
            )
            self.swinViT.layers4[0].downsample.norm.weight.copy_(
                weights["state_dict"]["module.layers4.0.downsample.norm.weight"]
            )
            self.swinViT.layers4[0].downsample.norm.bias.copy_(
                weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
            )

    def forward(self, x_in, text_in):
        hidden_states_out = self.swinViT(x_in, text_in, self.normalize)
        enc0 = self.encoder1(x_in)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        dec4 = self.encoder10(hidden_states_out[4])
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)
        logits = self.out(out)
        return logits


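# Shape walk-through for the forward pass above, assuming the configuration of
# the __main__ demo at the bottom of this file (img_size=(128,)*3, in_channels=4,
# feature_size=48, text_dim=768):
#   hidden_states_out[0]: (B,  48, 64, 64, 64) -> encoder2 skip
#   hidden_states_out[1]: (B,  96, 32, 32, 32) -> encoder3 skip
#   hidden_states_out[2]: (B, 192, 16, 16, 16) -> encoder4 skip
#   hidden_states_out[3]: (B, 384,  8,  8,  8) -> decoder5 skip
#   hidden_states_out[4]: (B, 768,  4,  4,  4) -> encoder10 (text-fused bottleneck)

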
def window_partition(x, window_size):
    """window partition operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x: input tensor.
        window_size: local window size.
    """
    x_shape = x.size()
    if len(x_shape) == 5:
        b, d, h, w, c = x_shape
        x = x.view(
            b,
            d // window_size[0],
            window_size[0],
            h // window_size[1],
            window_size[1],
            w // window_size[2],
            window_size[2],
            c,
        )
        windows = (
            x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
        )
    elif len(x_shape) == 4:
        b, h, w, c = x.shape
        x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
    return windows


def window_reverse(windows, window_size, dims):
    """window reverse operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        windows: windows tensor.
        window_size: local window size.
        dims: dimension values.
    """
    if len(dims) == 4:
        b, d, h, w = dims
        x = windows.view(
            b,
            d // window_size[0],
            h // window_size[1],
            w // window_size[2],
            window_size[0],
            window_size[1],
            window_size[2],
            -1,
        )
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)

    elif len(dims) == 3:
        b, h, w = dims
        # the second spatial dim must be grouped by window_size[1] (mirrors window_partition)
        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x


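def _demo_window_roundtrip():
    # Illustrative sanity sketch (assumed sizes, not used by the model):
    # window_partition followed by window_reverse reproduces the input whenever
    # the window size evenly divides every spatial dimension.
    x = torch.randn(2, 8, 8, 8, 24)  # (b, d, h, w, c)
    windows = window_partition(x, (4, 4, 4))  # -> (2 * 2*2*2, 4*4*4, 24) = (16, 64, 24)
    assert torch.equal(window_reverse(windows, (4, 4, 4), (2, 8, 8, 8)), x)

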
def get_window_size(x_size, window_size, shift_size=None):
    """Computing window size based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x_size: input size.
        window_size: local window size.
        shift_size: window shifting size.
    """

    use_window_size = list(window_size)
    if shift_size is not None:
        use_shift_size = list(shift_size)
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if shift_size is not None:
                use_shift_size[i] = 0

    if shift_size is None:
        return tuple(use_window_size)
    else:
        return tuple(use_window_size), tuple(use_shift_size)


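def _demo_get_window_size():
    # Minimal sketch (assumed sizes): windows are clamped to the input size, and
    # shifting is disabled along any dimension that fits in a single window.
    ws, ss = get_window_size((4, 16, 16), (7, 7, 7), (3, 3, 3))
    assert ws == (4, 7, 7) and ss == (0, 3, 3)

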
class WindowAttention(nn.Module):
    """
    Window based multi-head self attention module with relative position bias based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.
        """

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        mesh_args = torch.meshgrid.__kwdefaults__

        if len(self.window_size) == 3:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                    num_heads,
                )
            )
            coords_d = torch.arange(self.window_size[0])
            coords_h = torch.arange(self.window_size[1])
            coords_w = torch.arange(self.window_size[2])
            if mesh_args is not None:
                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
            else:
                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
            coords_flatten = torch.flatten(coords, 1)

            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 2] += self.window_size[2] - 1
            relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
            relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        elif len(self.window_size) == 2:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
            )
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            if mesh_args is not None:
                coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
            else:
                coords = torch.stack(torch.meshgrid(coords_h, coords_w))
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            relative_coords[:, :, 0] += self.window_size[0] - 1
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:n, :n].reshape(-1)
        ].reshape(n, n, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


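def _demo_window_attention():
    # Minimal sketch (assumed sizes): self-attention over a batch of 3D windows,
    # 64 tokens per 4x4x4 window, 24 channels split over 3 heads.
    attn = WindowAttention(dim=24, num_heads=3, window_size=(4, 4, 4), qkv_bias=True)
    x = torch.randn(8, 64, 24)  # (num_windows * b, tokens_per_window, dim)
    assert attn(x, mask=None).shape == (8, 64, 24)

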
class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        shift_size: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: str = "GELU",
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")

    def forward_part1(self, x, mask_matrix):
        x_shape = x.size()
        x = self.norm1(x)
        if len(x_shape) == 5:
            b, d, h, w, c = x.shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            pad_l = pad_t = pad_d0 = 0
            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
            _, dp, hp, wp, _ = x.shape
            dims = [b, dp, hp, wp]

        elif len(x_shape) == 4:
            b, h, w, c = x.shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            pad_l = pad_t = 0
            # F.pad pads w with (pad_l, pad_r) and h with (pad_t, pad_b),
            # so pad_b must come from h and pad_r from w.
            pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
            pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
            _, hp, wp, _ = x.shape
            dims = [b, hp, wp]

        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        x_windows = window_partition(shifted_x, window_size)
        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = attn_windows.view(-1, *(window_size + (c,)))
        shifted_x = window_reverse(attn_windows, window_size, dims)
        if any(i > 0 for i in shift_size):
            if len(x_shape) == 5:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
            elif len(x_shape) == 4:
                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        if len(x_shape) == 5:
            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
                x = x[:, :d, :h, :w, :].contiguous()
        elif len(x_shape) == 4:
            if pad_r > 0 or pad_b > 0:
                x = x[:, :h, :w, :].contiguous()

        return x

    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))

    def load_from(self, weights, n_block, layer):
        root = f"module.{layer}.0.blocks.{n_block}."
        block_names = [
            "norm1.weight",
            "norm1.bias",
            "attn.relative_position_bias_table",
            "attn.relative_position_index",
            "attn.qkv.weight",
            "attn.qkv.bias",
            "attn.proj.weight",
            "attn.proj.bias",
            "norm2.weight",
            "norm2.bias",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ]
        with torch.no_grad():
            self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
            self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
            self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
            self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
            self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
            self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
            self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
            self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
            self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
            self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
            self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
            self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
            self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
            self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])

    def forward(self, x, mask_matrix):
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x


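def _demo_swin_block():
    # Minimal sketch (assumed sizes): one non-shifted 3D block on an 8^3 token
    # grid; no attention mask is needed when shift_size is all zeros.
    blk = SwinTransformerBlock(dim=24, num_heads=3, window_size=(4, 4, 4), shift_size=(0, 0, 0))
    x = torch.randn(1, 8, 8, 8, 24)  # (b, d, h, w, c)
    assert blk(x, mask_matrix=None).shape == x.shape

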
class PatchMerging(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
    ) -> None:  # type: ignore
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        """

        super().__init__()
        self.dim = dim
        if spatial_dims == 3:
            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(8 * dim)
        elif spatial_dims == 2:
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)

    def forward(self, x):

        x_shape = x.size()
        if len(x_shape) == 5:
            b, d, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
            if pad_input:
                # pad order matches the (b, d, h, w, c) layout: c, w, h, d from last to first
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
            # NOTE: x5 and x6 repeat the slices of x2 and x3, so two octants are
            # sampled twice and two are skipped; this mirrors the original MONAI
            # SwinUNETR PatchMerging and is kept so the pretrained weights loaded
            # in load_from remain compatible.
            x0 = x[:, 0::2, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, 0::2, :]
            x3 = x[:, 0::2, 0::2, 1::2, :]
            x4 = x[:, 1::2, 0::2, 1::2, :]
            x5 = x[:, 0::2, 1::2, 0::2, :]
            x6 = x[:, 0::2, 0::2, 1::2, :]
            x7 = x[:, 1::2, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)

        elif len(x_shape) == 4:
            b, h, w, c = x_shape
            pad_input = (h % 2 == 1) or (w % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            x0 = x[:, 0::2, 0::2, :]
            x1 = x[:, 1::2, 0::2, :]
            x2 = x[:, 0::2, 1::2, :]
            x3 = x[:, 1::2, 1::2, :]
            x = torch.cat([x0, x1, x2, x3], -1)

        x = self.norm(x)
        x = self.reduction(x)
        return x


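def _demo_patch_merging():
    # Minimal sketch (assumed sizes): 3D patch merging halves every spatial
    # dimension and doubles the channel dimension (24 -> 48).
    merge = PatchMerging(dim=24, spatial_dims=3)
    x = torch.randn(1, 8, 8, 8, 24)
    assert merge(x).shape == (1, 4, 4, 4, 48)

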
def compute_mask(dims, window_size, shift_size, device):
    """Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        dims: dimension values.
        window_size: local window size.
        shift_size: shift size.
        device: device.
    """

    cnt = 0

    if len(dims) == 3:
        d, h, w = dims
        img_mask = torch.zeros((1, d, h, w, 1), device=device)
        for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
                    img_mask[:, d, h, w, :] = cnt
                    cnt += 1

    elif len(dims) == 2:
        h, w = dims
        img_mask = torch.zeros((1, h, w, 1), device=device)
        for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
            for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                img_mask[:, h, w, :] = cnt
                cnt += 1

    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.squeeze(-1)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask


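def _demo_compute_mask():
    # Minimal sketch (assumed sizes): attention mask for shifted 4^3 windows on
    # an 8^3 grid; one (64, 64) mask per window, -100.0 blocks cross-region pairs.
    mask = compute_mask([8, 8, 8], (4, 4, 4), (2, 2, 2), device=torch.device("cpu"))
    assert mask.shape == (8, 64, 64)

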
class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        downsample: isinstance = None,  # type: ignore
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of blocks in this stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: downsample layer at the end of the layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
                for i in range(depth)
            ]
        )
        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))

    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            b, c, d, h, w = x_shape
            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
            x = rearrange(x, "b c d h w -> b d h w c")
            dp = int(np.ceil(d / window_size[0])) * window_size[0]
            hp = int(np.ceil(h / window_size[1])) * window_size[1]
            wp = int(np.ceil(w / window_size[2])) * window_size[2]
            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, d, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b d h w c -> b c d h w")

        elif len(x_shape) == 4:
            b, c, h, w = x_shape
            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
            x = rearrange(x, "b c h w -> b h w c")
            hp = int(np.ceil(h / window_size[0])) * window_size[0]
            wp = int(np.ceil(w / window_size[1])) * window_size[1]
            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
            for blk in self.blocks:
                x = blk(x, attn_mask)
            x = x.view(b, h, w, -1)
            if self.downsample is not None:
                x = self.downsample(x)
            x = rearrange(x, "b h w c -> b c h w")
        return x


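def _demo_basic_layer():
    # Minimal sketch (assumed sizes): a 2-block stage (one shifted, one not)
    # followed by patch merging, operating channels-first like the encoder.
    layer = BasicLayer(
        dim=24, depth=2, num_heads=3, window_size=(4, 4, 4), drop_path=[0.0, 0.0], downsample=PatchMerging
    )
    x = torch.randn(1, 24, 8, 8, 8)  # (b, c, d, h, w)
    assert layer(x).shape == (1, 48, 4, 4, 4)

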
class SwinTransformer(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        text_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,  # type: ignore
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            text_dim: dimension of the input text embedding.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
        """

        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=PatchMerging,
                use_checkpoint=use_checkpoint,
            )
            if i_layer == 0:
                self.layers1.append(layer)
            elif i_layer == 1:
                self.layers2.append(layer)
            elif i_layer == 2:
                self.layers3.append(layer)
            elif i_layer == 3:
                self.layers4.append(layer)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        # Projects the text embedding into the 768-d space used for fusion.
        self.mlp_text = nn.Sequential(
            nn.Conv1d(text_dim, 1024, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(1024, 768, kernel_size=1),
        )

        self.fw_mlp = nn.Sequential(
            nn.Conv3d(768, 768, kernel_size=1),
            nn.ReLU(),
        )

        self.mlpk = nn.Conv1d(768, 768, kernel_size=1)
        self.mlpv = nn.Conv1d(768, 768, kernel_size=1)

        # NOTE: normalized_shape=(4, 4, 4) assumes a 4 x 4 x 4 bottleneck, e.g. a
        # 128^3 input downsampled 32x by the patch embedding plus four merging
        # stages; the fixed 768 channels likewise assume feature_size=48 (48 * 16).
        self.ly_norm = nn.LayerNorm(normalized_shape=(4, 4, 4))

        self.mlp_text_q = nn.Conv1d(768, 768, kernel_size=1)
        self.mlp_image_k = nn.Conv1d(768, 768, kernel_size=1)
        self.mlp_image_v = nn.Conv1d(768, 768, kernel_size=1)
        self.mlp_image_q = nn.Conv1d(768, 768, kernel_size=1)

    def proj_out(self, x, normalize=False):
        if normalize:
            x_shape = x.size()
            if len(x_shape) == 5:
                n, ch, d, h, w = x_shape
                x = rearrange(x, "n c d h w -> n d h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n d h w c -> n c d h w")
            elif len(x_shape) == 4:
                n, ch, h, w = x_shape
                x = rearrange(x, "n c h w -> n h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n h w c -> n c h w")
        return x

    def sequential_cross_attention(self, image_features, text_features):

        """
        Cross attention between image and text features.
        Args:
            image_features: Tensor of shape (B, C, H, W, D)
            text_features: Tensor of shape (B, T_len, T_dim)
        Returns:
            Processed image features with the same shape as input (B, C, H, W, D)
        """
        B, C, H, W, D = image_features.shape
        _, T_len, T_dim = text_features.shape

        # Step 1: Text-to-Image Cross Attention (Text as Q, Image as K/V)
        # Project text features into the shared 768-d space; Conv1d expects (B, T_dim, T_len)
        text_features = self.mlp_text(text_features.permute(0, 2, 1).contiguous())  # (B, 768, T_len)
        text_q = self.mlp_text_q(text_features).permute(0, 2, 1).contiguous()  # (B, T_len, d_k)

        # Flatten image features and project to Key and Value
        image_features_flat = image_features.view(B, C, -1).contiguous()  # (B, C, N_img)
        image_k = self.mlp_image_k(image_features_flat).permute(0, 2, 1).contiguous()  # (B, N_img, d_k)
        image_v = self.mlp_image_v(image_features_flat).permute(0, 2, 1).contiguous()  # (B, N_img, d_v)

        # Compute attention scores and weights
        attn_scores_t2i = torch.matmul(text_q, image_k.transpose(-2, -1)) / math.sqrt(
            text_q.size(-1))  # (B, T_len, N_img)
        attn_weights_t2i = F.softmax(attn_scores_t2i, dim=-1)  # (B, T_len, N_img)

        # Get text-attended image features, back in channels-first layout for Conv1d
        attended_image_features = torch.matmul(attn_weights_t2i, image_v).permute(0, 2, 1).contiguous()  # (B, d_v, T_len)

        # Step 2: Image-to-AttendedImage(Text) Cross Attention (Image as Q, AttendedImage as K/V)
        # Project image features to Query
        image_q = self.mlp_image_q(image_features_flat).permute(0, 2, 1).contiguous()  # (B, N_img, d_k)

        # Project the text-attended features to Key and Value
        attended_image_k = self.mlpk(attended_image_features).permute(0, 2, 1).contiguous()  # (B, T_len, d_k)
        attended_image_v = self.mlpv(attended_image_features).permute(0, 2, 1).contiguous()  # (B, T_len, d_v)

        # Compute attention scores and weights
        attn_scores_i2t = torch.matmul(image_q, attended_image_k.transpose(-2, -1)) / math.sqrt(
            image_q.size(-1))  # (B, N_img, T_len)
        attn_weights_i2t = F.softmax(attn_scores_i2t, dim=-1)  # (B, N_img, T_len)

        # Get attended image features
        attn_output_image = torch.matmul(attn_weights_i2t, attended_image_v)  # (B, N_img, d_v)

        # Reshape back to original image feature shape
        attn_output_image = attn_output_image.permute(0, 2, 1).contiguous()  # (B, d_v, N_img)
        attn_output_image = attn_output_image.view(B, C, H, W, D)

        # Apply layer normalization and final MLP processing
        processed_image_features = self.ly_norm(attn_output_image)
        processed_image_features = self.fw_mlp(processed_image_features.float())
        processed_image_features = self.ly_norm(processed_image_features)

        return processed_image_features

    def forward(self, x, text, normalize=True):
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        x1 = self.layers1[0](x0.contiguous())
        x1_out = self.proj_out(x1, normalize)
        x2 = self.layers2[0](x1.contiguous())
        x2_out = self.proj_out(x2, normalize)
        x3 = self.layers3[0](x2.contiguous())
        x3_out = self.proj_out(x3, normalize)
        x4 = self.layers4[0](x3.contiguous())
        # Sequential cross-attention fusion at the bottleneck
        x4 = self.sequential_cross_attention(x4, text)
        x4_out = self.proj_out(x4, normalize)
        return [x0_out, x1_out, x2_out, x3_out, x4_out]


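def _demo_sequential_cross_attention_flow():
    # Pure-tensor sketch of the two-step fusion in sequential_cross_attention
    # (assumed sizes, no learned projections): text queries first attend over
    # image tokens, then image queries attend over the text-conditioned features.
    b, c, t_len = 1, 768, 128
    n_img = 4 * 4 * 4
    img = torch.randn(b, n_img, c)
    txt = torch.randn(b, t_len, c)
    w_t2i = F.softmax(txt @ img.transpose(-2, -1) / math.sqrt(c), dim=-1)  # (b, t_len, n_img)
    attended = w_t2i @ img  # (b, t_len, c)
    w_i2t = F.softmax(img @ attended.transpose(-2, -1) / math.sqrt(c), dim=-1)  # (b, n_img, t_len)
    fused = w_i2t @ attended  # (b, n_img, c)
    assert fused.shape == (b, n_img, c)

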
if __name__ == "__main__":
    model = TextSwinUNETR(
        img_size=(128, 128, 128),
        in_channels=4,
        out_channels=3,
        feature_size=48,
        text_dim=768,
        use_checkpoint=False,
    ).cuda()

    input = torch.randn(1, 4, 128, 128, 128).cuda()
    text = torch.randn(1, 128, 768).cuda()
    output = model(input, text)
    print(output[0].shape)

utils/utils.py
ADDED
@@ -0,0 +1,69 @@
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch


def dice(x, y):
    """Dice coefficient between two binary masks; returns 0.0 when the reference mask is empty."""
    intersect = np.sum(x * y)
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)


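def _demo_dice():
    # Minimal sketch: two binary masks sharing one voxel; three positives total
    # give dice = 2 * 1 / (2 + 1) = 2/3.
    x = np.array([1, 1, 0])
    y = np.array([1, 0, 0])
    assert abs(dice(x, y) - 2 / 3) < 1e-8

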
class AverageMeter(object):
    """Tracks the current value, running sum, count, and running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)


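def _demo_average_meter():
    # Minimal sketch: running average over two updates with different batch sizes.
    meter = AverageMeter()
    meter.update(1.0, n=2)
    meter.update(4.0, n=1)
    assert abs(float(meter.avg) - 2.0) < 1e-8  # (1.0*2 + 4.0*1) / 3

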
def distributed_all_gather(
    tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None
):
    if world_size is None:
        world_size = torch.distributed.get_world_size()
    if valid_batch_size is not None:
        valid_batch_size = min(valid_batch_size, world_size)
    elif is_valid is not None:
        is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)
    if not no_barrier:
        torch.distributed.barrier()
    tensor_list_out = []
    with torch.no_grad():
        if is_valid is not None:
            is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
            torch.distributed.all_gather(is_valid_list, is_valid)
            is_valid = [x.item() for x in is_valid_list]
        for tensor in tensor_list:
            gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(gather_list, tensor)
            if valid_batch_size is not None:
                gather_list = gather_list[:valid_batch_size]
            elif is_valid is not None:
                gather_list = [g for g, v in zip(gather_list, is_valid_list) if v]
            if out_numpy:
                gather_list = [t.cpu().numpy() for t in gather_list]
            tensor_list_out.append(gather_list)
    return tensor_list_out