Upload 26 files
Browse files- INSTALL.md +24 -0
- LICENSE +21 -0
- README.md +113 -3
- datasets.py +271 -0
- dynamic_utils.py +133 -0
- engine_for_finetuning.py +375 -0
- engine_for_pretraining.py +152 -0
- environment.yml +259 -0
- functional.py +89 -0
- kinetics.py +559 -0
- masking_generator.py +185 -0
- mixup.py +316 -0
- modeling_finetune.py +351 -0
- modeling_pretrain.py +398 -0
- optim_factory.py +175 -0
- rand_augment.py +531 -0
- random_erasing.py +173 -0
- run_class_finetuning.py +582 -0
- run_mae_pretraining.py +359 -0
- run_videomae_vis.py +198 -0
- ssv2.py +363 -0
- synthetic_tubelets.py +785 -0
- transforms.py +206 -0
- utils_mae.py +536 -0
- video_transforms.py +1281 -0
- volume_transforms.py +131 -0
INSTALL.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SMILE Installation
|
2 |
+
|
3 |
+
This project relies on several open-source libraries. We recommend using **`conda`** to manage your Python environment and installing dependencies via the provided `environment.yml` file.
|
4 |
+
|
5 |
+
## Installation Steps
|
6 |
+
1. **Clone the repository**
|
7 |
+
```bash
|
8 |
+
git clone https://github.com/fmthoker/SMILE.git
|
9 |
+
cd SMILE
|
10 |
+
```
|
11 |
+
2. **Create a conda environment**
|
12 |
+
```bash
|
13 |
+
conda env create -f environment.yml
|
14 |
+
```
|
15 |
+
3. **Activate the environment**
|
16 |
+
```bash
|
17 |
+
conda activate smile
|
18 |
+
```
|
19 |
+
4. **Download CLIP weights (Optional, only required for pretraining)**
|
20 |
+
```bash
|
21 |
+
mkdir clip_weights
|
22 |
+
```
|
23 |
+
For pretraining, please download the [CLIP weights](https://huggingface.co/fmthoker/SMILE/resolve/main/clip_weights/ViT-B-16.pt) and place them in the `clip_weights` folder created above.
|
24 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Fida Mohammad Thoker
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,3 +1,113 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Official PyTorch Implementation of SMILE (CVPR 2025).
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
[](https://opensource.org/licenses/MIT)<br>
|
6 |
+
[](https://huggingface.co/fmthoker/SMILE/tree/main/SMILE_MODELS)
|
7 |
+
|
8 |
+
|
9 |
+
> [**SMILE: Infusing Spatial and Motion Semantics in Masked Video Learning**](https://arxiv.org/abs/2504.00527)<br>
|
10 |
+
> [Fida Mohammad Thoker](https://fmthoker.github.io/), [Letian Jiang](https://tonnew5418.github.io/), [Chen Zhao](https://zhao-chen.com/), [Bernard Ghanem](https://cemse.kaust.edu.sa/profiles/bernard-ghanem)<br>King Abdullah University of Science and Technology (KAUST)
|
11 |
+
|
12 |
+
## 📰 News
|
13 |
+
**[2025.6.2]** Code and pre-trained models are available now! <br>
|
14 |
+
**[2025.5.28]** Code and pre-trained models will be released here. Welcome to **watch** this repository for the latest updates.
|
15 |
+
|
16 |
+
## ✨ Highlights
|
17 |
+
|
18 |
+
### 🔥 State-of-the-art on SSv2 and K400
|
19 |
+
|
20 |
+
Our method achieves state-of-the-art performance on **SSv2** and **K400** benchmarks with a ViT-B backbone, surpassing prior self-supervised video models by up to **2.5%**, thanks to efficient *CLIP-based semantic supervision*.
|
21 |
+
|
22 |
+
### ⚡️ Leading Results Across Generalization Challenges
|
23 |
+
|
24 |
+
We evaluate our method on the [**SEVERE benchmark**](https://bpiyush.github.io/SEVERE-website/), covering domain shift, low-shot learning, fine-grained actions, and task adaptability. Our model consistently outperforms prior methods and achieves a **3.0% average gain** over strong baselines, demonstrating superior generalization in diverse video understanding tasks.
|
25 |
+
|
26 |
+
### 😮 Superior Motion Representation Without Video-Text Alignment
|
27 |
+
|
28 |
+
Compared to CLIP-based methods such as [**ViCLIP**](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid) and [**UMT**](https://github.com/OpenGVLab/unmasked_teacher), our model achieves higher accuracy on motion-sensitive datasets, particularly under *linear probing*. This indicates stronger video representations learned with less data and without relying on video-text alignment.
|
29 |
+
|
30 |
+
## 🚀 Main Results and Models
|
31 |
+
|
32 |
+
### ✨ Something-Something V2
|
33 |
+
|
34 |
+
| Method | Pretrain Dataset | Pretrain Epochs | Backbone | Top-1 | Finetune |
|
35 |
+
| :------: | :--------------: | :-------------: | :------: | :---: | :------: |
|
36 |
+
| SMILE | K400 | 800 | ViT-S | 69.1 | TODO |
|
37 |
+
| SMILE | K400 | 600 | ViT-B | 72.1 | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_600_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_600_EPOCHS/ssv2_finetuned_after_k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) |
|
38 |
+
| SMILE | K400 | 1200 | ViT-B | 72.4 | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_1200_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_1200_EPOCHS/ssv2_finetuned_after_k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth)
|
39 |
+
| SMILE | SSv2 | 800 | ViT-B | 72.5 | TODO |
|
40 |
+
|
41 |
+
### ✨ Kinetics-400
|
42 |
+
|
43 |
+
| Method | Pretrain Dataset | Pretrain Epochs | Backbone | Top-1 | Pretrain | Finetune |
|
44 |
+
| :------: | :--------------: | :-------------: | :------: | :---: | :------: | :------: |
|
45 |
+
| SMILE | K400 | 800 | ViT-S | 79.5 | TODO | TODO |
|
46 |
+
| SMILE | K400 | 600 | ViT-B | 83.1 | [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/pretrain/k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_600_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_600_EPOCHS/k400_finetuned_after_k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) |
|
47 |
+
| SMILE | K400 | 1200 | ViT-B | 83.4 | [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/pretrain/k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth) | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_1200_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_1200_EPOCHS/k400_finetuned_after_k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth) |
|
48 |
+
|
49 |
+
## 🔨 Installation
|
50 |
+
|
51 |
+
Please follow the instructions in [INSTALL.md](INSTALL.md).
|
52 |
+
|
53 |
+
## ➡️ Data Preparation
|
54 |
+
|
55 |
+
We follow [VideoMAE Data preparation](https://github.com/MCG-NJU/VideoMAE/blob/main/DATASET.md) to prepare our datasets (K400 and SSv2). Here we provide our annotation files for those two datasets: [annotation_files](annotation_files). For pretraining, we use training sets (train.csv).
|
56 |
+
|
57 |
+
We provide the list of segmented object images used for pretraining in [object_instances.txt](annotation_files/object_instances.txt). The images will be released later.
|
58 |
+
|
59 |
+
|
60 |
+
## 🔄 Pre-training
|
61 |
+
|
62 |
+
Following the [VideoMAE pre-training guide](https://github.com/MCG-NJU/VideoMAE/blob/main/PRETRAIN.md), we provide scripts for pre-training on the Kinetics-400 (K400) dataset using the ViT-Base model: [scripts/pretrain/](./scripts/pretrain/)
|
63 |
+
|
64 |
+
As described in the paper, we adopt a two-stage training strategy. Please refer to the script names to identify which stage to run.
|
65 |
+
|
66 |
+
If you wish to perform your own pre-training, make sure to update the following parameters in the scripts:
|
67 |
+
|
68 |
+
- `DATA_PATH`: Path to your dataset
|
69 |
+
- `OUTPUT_DIR`: Directory to save output results
|
70 |
+
- `OBJECTS_PATH`: Path to the overlaying objects image dataset (image data to be released)
|
71 |
+
- `FIRST_STAGE_CKPT`: Path to the ckpt from first stage pretraining ( for second stage training)
|
72 |
+
|
73 |
+
> **Note:** Our pre-training experiments were conducted using 8 V100(32 GB) GPUs.
|
74 |
+
---
|
75 |
+
|
76 |
+
## ⤴️ Fine-tuning with Pre-trained Models
|
77 |
+
|
78 |
+
Following the [VideoMAE finetuning guide](https://github.com/MCG-NJU/VideoMAE/blob/main/FINETUNE.md), we provide scripts for fine-tuning on the Something-Something v2 (SSv2) and Kinetics-400 (K400) datasets using the ViT-Base model: [scripts/finetune/](./scripts/finetune)
|
79 |
+
|
80 |
+
|
81 |
+
To perform your own fine-tuning, please update the following parameters in the script:
|
82 |
+
|
83 |
+
- `DATA_PATH`: Path to your dataset
|
84 |
+
- `MODEL_PATH`: Path to the pre-trained model
|
85 |
+
- `OUTPUT_DIR`: Directory to save output results
|
86 |
+
|
87 |
+
> **Note:** Our finetuning experiments were conducted using 4 V100(32 GB) GPUs.
|
88 |
+
|
89 |
+
## ☎️ Contact
|
90 |
+
|
91 |
+
Fida Mohammad Thoker: [email protected]
|
92 |
+
|
93 |
+
## 👍 Acknowledgements
|
94 |
+
|
95 |
+
We sincerely thank [Michael Dorkenwald](https://mdorkenwald.com/) for providing the object image dataset that supports this work.<br>
|
96 |
+
This project is built upon [VideoMAE](https://github.com/MCG-NJU/VideoMAE) and [tubelet-contrast](https://github.com/fmthoker/tubelet-contrast). Thanks to the contributors of these great codebases.
|
97 |
+
|
98 |
+
## 🔒 License
|
99 |
+
|
100 |
+
This project is released under the MIT license. For more details, please refer to the [LICENSE](https://github.com/fmthoker/SMILE/blob/main/LICENSE) file.
|
101 |
+
|
102 |
+
## ✏️ Citation
|
103 |
+
|
104 |
+
If you think this project is helpful, please feel free to leave a star⭐️ and cite our paper:
|
105 |
+
|
106 |
+
```
|
107 |
+
@inproceedings{thoker2025smile,
|
108 |
+
author = {Thoker, Fida Mohammad and Jiang, Letian and Zhao, Chen and Ghanem, Bernard},
|
109 |
+
title = {SMILE: Infusing Spatial and Motion Semantics in Masked Video Learning},
|
110 |
+
journal = {CVPR},
|
111 |
+
year = {2025},
|
112 |
+
}
|
113 |
+
```
|
datasets.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torchvision import transforms
|
3 |
+
from transforms import *
|
4 |
+
from masking_generator import TubeMaskingGenerator, TubeletMaskingGenerator
|
5 |
+
from kinetics import VideoClsDataset, VideoMAE
|
6 |
+
from ssv2 import SSVideoClsDataset
|
7 |
+
import synthetic_tubelets as synthetic_tubelets
|
8 |
+
import ast
|
9 |
+
import random
|
10 |
+
|
11 |
+
class DataAugmentationForVideoMAE(object):
|
12 |
+
def __init__(self, args):
|
13 |
+
self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
|
14 |
+
self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
|
15 |
+
normalize = GroupNormalize(self.input_mean, self.input_std)
|
16 |
+
self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
|
17 |
+
self.add_tubelets = args.add_tubelets
|
18 |
+
self.mask_type = args.mask_type
|
19 |
+
|
20 |
+
# original transform without adding tubelets
|
21 |
+
self.transform_original = transforms.Compose([
|
22 |
+
self.train_augmentation,
|
23 |
+
Stack(roll=False),
|
24 |
+
ToTorchFormatTensor(div=True),
|
25 |
+
normalize,
|
26 |
+
])
|
27 |
+
|
28 |
+
# tubelet transform
|
29 |
+
if args.add_tubelets:
|
30 |
+
scales = ast.literal_eval(args.scales)
|
31 |
+
|
32 |
+
self.tubelets = synthetic_tubelets.PatchMask(
|
33 |
+
use_objects=args.use_objects,
|
34 |
+
objects_path=args.objects_path,
|
35 |
+
region_sampler=dict(
|
36 |
+
scales=scales,
|
37 |
+
ratios=[0.5, 0.67, 0.75, 1.0, 1.33, 1.50, 2.0],
|
38 |
+
scale_jitter=0.18,
|
39 |
+
num_rois=2,
|
40 |
+
),
|
41 |
+
key_frame_probs=[0.5, 0.3, 0.2],
|
42 |
+
loc_velocity=12,
|
43 |
+
rot_velocity=6,
|
44 |
+
size_velocity=0.025,
|
45 |
+
label_prob=1.0,
|
46 |
+
motion_type=args.motion_type,
|
47 |
+
patch_transformation='rotation',)
|
48 |
+
|
49 |
+
|
50 |
+
self.transform1 = transforms.Compose([
|
51 |
+
self.train_augmentation,
|
52 |
+
self.tubelets,
|
53 |
+
])
|
54 |
+
self.transform2 = transforms.Compose([Stack(roll=False),
|
55 |
+
ToTorchFormatTensor(div=True),
|
56 |
+
normalize,
|
57 |
+
])
|
58 |
+
else:
|
59 |
+
self.transform = self.transform_original
|
60 |
+
|
61 |
+
self.original_masked_position_generator = TubeMaskingGenerator(
|
62 |
+
args.window_size, args.mask_ratio
|
63 |
+
)
|
64 |
+
|
65 |
+
if args.mask_type == 'tube':
|
66 |
+
self.masked_position_generator = self.original_masked_position_generator
|
67 |
+
elif args.mask_type == 'tubelet':
|
68 |
+
self.masked_position_generator = TubeletMaskingGenerator(
|
69 |
+
args.window_size, args.mask_ratio, args.visible_frames, args.sub_mask_type
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
raise NotImplemented
|
73 |
+
|
74 |
+
|
75 |
+
def __call__(self, images):
|
76 |
+
process_data, _, traj_rois = self.ComposedTransform(images)
|
77 |
+
|
78 |
+
if self.mask_type == 'tubelet' and traj_rois is not None:
|
79 |
+
return process_data, self.masked_position_generator(traj_rois)
|
80 |
+
else:
|
81 |
+
return process_data, self.masked_position_generator()
|
82 |
+
|
83 |
+
def ComposedTransform(self, images):
|
84 |
+
traj_rois = None
|
85 |
+
|
86 |
+
if self.add_tubelets:
|
87 |
+
data = self.transform1(images)
|
88 |
+
process_data, traj_rois = data[:-1], data[-1]
|
89 |
+
process_data, _ = self.transform2(process_data)
|
90 |
+
else:
|
91 |
+
process_data, _ = self.transform(images)
|
92 |
+
|
93 |
+
return process_data, _, traj_rois
|
94 |
+
|
95 |
+
def __repr__(self):
|
96 |
+
repr = "(DataAugmentationForVideoMAE,\n"
|
97 |
+
try:
|
98 |
+
self.transform
|
99 |
+
except:
|
100 |
+
repr += " transform = %s,\n" % (str(self.transform1) + str(self.transform2))
|
101 |
+
else:
|
102 |
+
repr += " transform = %s,\n" % str(self.transform)
|
103 |
+
|
104 |
+
repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
|
105 |
+
repr += ")"
|
106 |
+
return repr
|
107 |
+
|
108 |
+
|
109 |
+
def build_pretraining_dataset(args):
|
110 |
+
transform = DataAugmentationForVideoMAE(args)
|
111 |
+
dataset = VideoMAE(
|
112 |
+
root=None,
|
113 |
+
setting=args.data_path,
|
114 |
+
video_ext='mp4',
|
115 |
+
is_color=True,
|
116 |
+
modality='rgb',
|
117 |
+
new_length=args.num_frames,
|
118 |
+
new_step=args.sampling_rate,
|
119 |
+
transform=transform,
|
120 |
+
temporal_jitter=False,
|
121 |
+
video_loader=True,
|
122 |
+
use_decord=True,
|
123 |
+
lazy_init=False)
|
124 |
+
print("Data Aug = %s" % str(transform))
|
125 |
+
return dataset
|
126 |
+
|
127 |
+
|
128 |
+
def build_dataset(is_train, test_mode, args):
|
129 |
+
if args.data_set == 'Kinetics-400' or args.data_set == "Mini-Kinetics":
|
130 |
+
mode = None
|
131 |
+
anno_path = None
|
132 |
+
if is_train is True:
|
133 |
+
mode = 'train'
|
134 |
+
if 'Mini' in args.data_set:
|
135 |
+
anno_path = os.path.join(args.data_path, 'train_mini_kinetics.csv')
|
136 |
+
else:
|
137 |
+
anno_path = os.path.join(args.data_path, 'train.csv')
|
138 |
+
elif test_mode is True:
|
139 |
+
mode = 'test'
|
140 |
+
if 'Mini' in args.data_set:
|
141 |
+
anno_path = os.path.join(args.data_path, 'test_mini_kinetics.csv')
|
142 |
+
else:
|
143 |
+
anno_path = os.path.join(args.data_path, 'test.csv')
|
144 |
+
else:
|
145 |
+
mode = 'validation'
|
146 |
+
if 'Mini' in args.data_set:
|
147 |
+
anno_path = os.path.join(args.data_path, 'val_mini_kinetics.csv')
|
148 |
+
else:
|
149 |
+
anno_path = os.path.join(args.data_path, 'val.csv')
|
150 |
+
|
151 |
+
dataset = VideoClsDataset(
|
152 |
+
anno_path=anno_path,
|
153 |
+
data_path='/',
|
154 |
+
mode=mode,
|
155 |
+
clip_len=args.num_frames,
|
156 |
+
frame_sample_rate=args.sampling_rate,
|
157 |
+
num_segment=1,
|
158 |
+
test_num_segment=args.test_num_segment,
|
159 |
+
test_num_crop=args.test_num_crop,
|
160 |
+
num_crop=1 if not test_mode else 3,
|
161 |
+
keep_aspect_ratio=True,
|
162 |
+
crop_size=args.input_size,
|
163 |
+
short_side_size=args.short_side_size,
|
164 |
+
new_height=256,
|
165 |
+
new_width=320,
|
166 |
+
args=args)
|
167 |
+
if 'Mini' in args.data_set:
|
168 |
+
nb_classes = 200
|
169 |
+
else:
|
170 |
+
nb_classes = 400
|
171 |
+
|
172 |
+
elif args.data_set == 'SSV2' or args.data_set == 'SSV2-Mini':
|
173 |
+
mode = None
|
174 |
+
anno_path = None
|
175 |
+
if is_train is True:
|
176 |
+
mode = 'train'
|
177 |
+
if 'Mini' in args.data_set:
|
178 |
+
anno_path = os.path.join(args.data_path, 'train_mini.csv')
|
179 |
+
else:
|
180 |
+
anno_path = os.path.join(args.data_path, 'train.csv')
|
181 |
+
elif test_mode is True:
|
182 |
+
mode = 'test'
|
183 |
+
anno_path = os.path.join(args.data_path, 'test.csv')
|
184 |
+
else:
|
185 |
+
mode = 'validation'
|
186 |
+
anno_path = os.path.join(args.data_path, 'val.csv')
|
187 |
+
|
188 |
+
dataset = SSVideoClsDataset(
|
189 |
+
anno_path=anno_path,
|
190 |
+
data_path='/',
|
191 |
+
mode=mode,
|
192 |
+
clip_len=1,
|
193 |
+
num_segment=args.num_frames,
|
194 |
+
test_num_segment=args.test_num_segment,
|
195 |
+
test_num_crop=args.test_num_crop,
|
196 |
+
num_crop=1 if not test_mode else 3,
|
197 |
+
keep_aspect_ratio=True,
|
198 |
+
crop_size=args.input_size,
|
199 |
+
short_side_size=args.short_side_size,
|
200 |
+
new_height=256,
|
201 |
+
new_width=320,
|
202 |
+
args=args)
|
203 |
+
nb_classes = 174
|
204 |
+
|
205 |
+
elif args.data_set == 'UCF101':
|
206 |
+
mode = None
|
207 |
+
anno_path = None
|
208 |
+
if is_train is True:
|
209 |
+
mode = 'train'
|
210 |
+
anno_path = os.path.join(args.data_path, 'train.csv')
|
211 |
+
elif test_mode is True:
|
212 |
+
mode = 'test'
|
213 |
+
anno_path = os.path.join(args.data_path, 'test.csv')
|
214 |
+
else:
|
215 |
+
mode = 'validation'
|
216 |
+
anno_path = os.path.join(args.data_path, 'val.csv')
|
217 |
+
|
218 |
+
dataset = VideoClsDataset(
|
219 |
+
anno_path=anno_path,
|
220 |
+
data_path='/',
|
221 |
+
mode=mode,
|
222 |
+
clip_len=args.num_frames,
|
223 |
+
frame_sample_rate=args.sampling_rate,
|
224 |
+
num_segment=1,
|
225 |
+
test_num_segment=args.test_num_segment,
|
226 |
+
test_num_crop=args.test_num_crop,
|
227 |
+
num_crop=1 if not test_mode else 3,
|
228 |
+
keep_aspect_ratio=True,
|
229 |
+
crop_size=args.input_size,
|
230 |
+
short_side_size=args.short_side_size,
|
231 |
+
new_height=256,
|
232 |
+
new_width=320,
|
233 |
+
args=args)
|
234 |
+
nb_classes = 101
|
235 |
+
|
236 |
+
elif args.data_set == 'HMDB51':
|
237 |
+
mode = None
|
238 |
+
anno_path = None
|
239 |
+
if is_train is True:
|
240 |
+
mode = 'train'
|
241 |
+
anno_path = os.path.join(args.data_path, 'train.csv')
|
242 |
+
elif test_mode is True:
|
243 |
+
mode = 'test'
|
244 |
+
anno_path = os.path.join(args.data_path, 'test.csv')
|
245 |
+
else:
|
246 |
+
mode = 'validation'
|
247 |
+
anno_path = os.path.join(args.data_path, 'val.csv')
|
248 |
+
|
249 |
+
dataset = VideoClsDataset(
|
250 |
+
anno_path=anno_path,
|
251 |
+
data_path='/',
|
252 |
+
mode=mode,
|
253 |
+
clip_len=args.num_frames,
|
254 |
+
frame_sample_rate=args.sampling_rate,
|
255 |
+
num_segment=1,
|
256 |
+
test_num_segment=args.test_num_segment,
|
257 |
+
test_num_crop=args.test_num_crop,
|
258 |
+
num_crop=1 if not test_mode else 3,
|
259 |
+
keep_aspect_ratio=True,
|
260 |
+
crop_size=args.input_size,
|
261 |
+
short_side_size=args.short_side_size,
|
262 |
+
new_height=256,
|
263 |
+
new_width=320,
|
264 |
+
args=args)
|
265 |
+
nb_classes = 51
|
266 |
+
else:
|
267 |
+
raise NotImplementedError()
|
268 |
+
assert nb_classes == args.nb_classes
|
269 |
+
print("Number of the class = %d" % args.nb_classes)
|
270 |
+
|
271 |
+
return dataset, nb_classes
|
dynamic_utils.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
|
8 |
+
def sample_key_frames(num_frames: int,
|
9 |
+
key_frame_probs: List[float]) -> np.ndarray:
|
10 |
+
""" Sample the indices of key frames.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
num_frames (int): number of frames in whole video
|
14 |
+
key_frame_probs (List[float]): the sampling probability of how many
|
15 |
+
key frames will be sampled. The sum of this array should be 1.0.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
frame_inds (np.ndarray): key frame index, in range
|
19 |
+
of [0, num_frames - 1]. Note that the first frame and the
|
20 |
+
last frame will always be key frames.
|
21 |
+
|
22 |
+
Examples:
|
23 |
+
>>> sample_key_frames(16, [1.0, ])
|
24 |
+
np.ndarray([0, 15])
|
25 |
+
>>> sample_key_frames(16, [0.5, 0.5])
|
26 |
+
np.ndarray([0, 15])
|
27 |
+
np.ndarray([0, 7, 15])
|
28 |
+
np.ndarray([0, 8, 15])
|
29 |
+
np.ndarray([0, 15])
|
30 |
+
"""
|
31 |
+
# how many key frames
|
32 |
+
num_key_frames = np.random.choice(len(key_frame_probs), p=key_frame_probs)
|
33 |
+
# if there is no inner key frame, we will directly
|
34 |
+
# sample the first frame and the last frame.
|
35 |
+
if num_key_frames == 0:
|
36 |
+
return np.array([0, num_frames - 1], dtype=np.int32)
|
37 |
+
avg_duration = num_frames / (num_key_frames + 1)
|
38 |
+
ticks = np.array([int(avg_duration * i)
|
39 |
+
for i in range(1, num_key_frames + 1)], dtype=np.int32)
|
40 |
+
|
41 |
+
# add random jitter
|
42 |
+
jitter_range = int(avg_duration / 3)
|
43 |
+
if jitter_range > 0:
|
44 |
+
jitter = np.random.randint(-jitter_range,
|
45 |
+
jitter_range, size=len(ticks))
|
46 |
+
else:
|
47 |
+
jitter = np.zeros((len(ticks),), np.int32)
|
48 |
+
|
49 |
+
ticks = ticks + jitter
|
50 |
+
# add the first frame and last frame
|
51 |
+
ticks = np.concatenate((ticks, np.array([0, num_frames - 1])), axis=0)
|
52 |
+
# remove duplication and sort array
|
53 |
+
ticks = np.sort(np.unique(ticks))
|
54 |
+
return ticks
|
55 |
+
|
56 |
+
|
57 |
+
def extend_key_frame_to_all(array: np.ndarray,
|
58 |
+
key_frame_inds: np.ndarray,
|
59 |
+
interpolate: str = 'uniform') -> np.ndarray:
|
60 |
+
""" Interpolate the values between key frames.
|
61 |
+
|
62 |
+
This function is used in some data augmentations for video clips. For
|
63 |
+
example, we first decide the color distortion values in some key frames,
|
64 |
+
then we can interpolate the values in the rest of frames. This strategy
|
65 |
+
will make the data augmentations more smooth over the entire video clip.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
array (np.ndarray): The values in the key frames, in shape of [K, *]
|
69 |
+
key_frame_inds (np.ndarray): the frame index list of key frames, in
|
70 |
+
shape of [K, ]
|
71 |
+
interpolate (str): interpolation type. 'uniform' means the linear
|
72 |
+
interpolation; 'accelerate' means the constant acceleration.
|
73 |
+
'decelerate' means the reverse order of 'accelerate'.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
out_array (np.ndarray): the interpolated values, in shape of [N, *].
|
77 |
+
N denotes the value of key_frame_inds[-1].
|
78 |
+
|
79 |
+
Examples:
|
80 |
+
>>> values = np.array([0.0, 5.0])
|
81 |
+
>>> inds = np.array([0, 10])
|
82 |
+
>>> extend_key_frame_to_all(values, inds)
|
83 |
+
array([0. , 0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. ])
|
84 |
+
>>> extend_key_frame_to_all(values, inds, 'accelerate')
|
85 |
+
array([0. , 0.05, 0.2 , 0.45, 0.8 , 1.25, 1.8 , 2.45, 3.2 , 4.05, 5.])
|
86 |
+
"""
|
87 |
+
|
88 |
+
def _uniform_interpolate(start_state, end_state, index_delta):
|
89 |
+
delta_state = (end_state - start_state) * (1.0 / index_delta)
|
90 |
+
return np.concatenate([start_state + _ * delta_state
|
91 |
+
for _ in range(index_delta+1)], axis=0)
|
92 |
+
|
93 |
+
def _accelerate_interpolate(start_state, end_state, index_delta):
|
94 |
+
a = 2 * (end_state - start_state) / (index_delta ** 2)
|
95 |
+
return np.concatenate([start_state + 0.5 * a * (_**2)
|
96 |
+
for _ in range(index_delta+1)], axis=0)
|
97 |
+
|
98 |
+
def _decelerate_interpolate(start_state, end_state, index_delta):
|
99 |
+
a = 2 * (start_state - end_state) / (index_delta ** 2)
|
100 |
+
return np.concatenate([end_state + 0.5 * a * ((index_delta-_)**2)
|
101 |
+
for _ in range(index_delta+1)], axis=0)
|
102 |
+
|
103 |
+
assert key_frame_inds[0] == 0 and key_frame_inds[-1] > 0
|
104 |
+
num_key_frames = len(key_frame_inds)
|
105 |
+
assert num_key_frames == len(array)
|
106 |
+
num_frames = key_frame_inds[-1] + 1
|
107 |
+
|
108 |
+
out_array = np.zeros((num_frames, ) + array.shape[1:], dtype=array.dtype)
|
109 |
+
for i in range(num_key_frames - 1):
|
110 |
+
# fill the values between i -> i+1
|
111 |
+
st_idx, end_idx = key_frame_inds[i:i+2]
|
112 |
+
if interpolate == 'uniform':
|
113 |
+
inter_func = _uniform_interpolate
|
114 |
+
elif interpolate == 'accelerate':
|
115 |
+
inter_func = _accelerate_interpolate
|
116 |
+
elif interpolate == 'decelerate':
|
117 |
+
inter_func = _decelerate_interpolate
|
118 |
+
elif interpolate == 'random':
|
119 |
+
inter_index = np.random.choice(3, p=[0.7, 0.15, 0.15])
|
120 |
+
if inter_index == 0:
|
121 |
+
inter_func = _uniform_interpolate
|
122 |
+
elif inter_index == 1:
|
123 |
+
inter_func = _accelerate_interpolate
|
124 |
+
else:
|
125 |
+
inter_func = _decelerate_interpolate
|
126 |
+
else:
|
127 |
+
raise NotImplementedError
|
128 |
+
i_out = inter_func(array[i:i+1],
|
129 |
+
array[i+1:i+2],
|
130 |
+
end_idx - st_idx)
|
131 |
+
out_array[st_idx:end_idx+1] = i_out
|
132 |
+
|
133 |
+
return out_array
|
engine_for_finetuning.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import sys
|
5 |
+
from typing import Iterable, Optional
|
6 |
+
import torch
|
7 |
+
from mixup import Mixup
|
8 |
+
from timm.utils import accuracy, ModelEma
|
9 |
+
import utils_mae as utils
|
10 |
+
from scipy.special import softmax
|
11 |
+
import gc
|
12 |
+
import pickle
|
13 |
+
|
14 |
+
def train_class_batch(model, samples, target, criterion):
|
15 |
+
outputs = model(samples)
|
16 |
+
loss = criterion(outputs, target)
|
17 |
+
return loss, outputs
|
18 |
+
|
19 |
+
|
20 |
+
def get_loss_scale_for_deepspeed(model):
|
21 |
+
optimizer = model.optimizer
|
22 |
+
return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
|
23 |
+
|
24 |
+
|
25 |
+
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
|
26 |
+
data_loader: Iterable, optimizer: torch.optim.Optimizer,
|
27 |
+
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
|
28 |
+
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
|
29 |
+
start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
|
30 |
+
num_training_steps_per_epoch=None, update_freq=None):
|
31 |
+
model.train(True)
|
32 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
33 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
34 |
+
metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
35 |
+
header = 'Epoch: [{}]'.format(epoch)
|
36 |
+
print_freq = 10
|
37 |
+
|
38 |
+
if loss_scaler is None:
|
39 |
+
model.zero_grad()
|
40 |
+
model.micro_steps = 0
|
41 |
+
else:
|
42 |
+
optimizer.zero_grad()
|
43 |
+
|
44 |
+
for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
45 |
+
step = data_iter_step // update_freq
|
46 |
+
if step >= num_training_steps_per_epoch:
|
47 |
+
continue
|
48 |
+
it = start_steps + step # global training iteration
|
49 |
+
# Update LR & WD for the first acc
|
50 |
+
if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
|
51 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
52 |
+
if lr_schedule_values is not None:
|
53 |
+
param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
|
54 |
+
if wd_schedule_values is not None and param_group["weight_decay"] > 0:
|
55 |
+
param_group["weight_decay"] = wd_schedule_values[it]
|
56 |
+
|
57 |
+
samples = samples.to(device, non_blocking=True)
|
58 |
+
targets = targets.to(device, non_blocking=True)
|
59 |
+
|
60 |
+
if mixup_fn is not None:
|
61 |
+
samples, targets = mixup_fn(samples, targets)
|
62 |
+
|
63 |
+
if loss_scaler is None:
|
64 |
+
samples = samples.half()
|
65 |
+
loss, output = train_class_batch(
|
66 |
+
model, samples, targets, criterion)
|
67 |
+
else:
|
68 |
+
with torch.cuda.amp.autocast():
|
69 |
+
loss, output = train_class_batch(
|
70 |
+
model, samples, targets, criterion)
|
71 |
+
|
72 |
+
loss_value = loss.item()
|
73 |
+
|
74 |
+
if not math.isfinite(loss_value):
|
75 |
+
print("Loss is {}, stopping training".format(loss_value))
|
76 |
+
sys.exit(1)
|
77 |
+
|
78 |
+
if loss_scaler is None:
|
79 |
+
loss /= update_freq
|
80 |
+
model.backward(loss)
|
81 |
+
model.step()
|
82 |
+
|
83 |
+
if (data_iter_step + 1) % update_freq == 0:
|
84 |
+
# model.zero_grad()
|
85 |
+
# Deepspeed will call step() & model.zero_grad() automatic
|
86 |
+
if model_ema is not None:
|
87 |
+
model_ema.update(model)
|
88 |
+
grad_norm = None
|
89 |
+
loss_scale_value = get_loss_scale_for_deepspeed(model)
|
90 |
+
else:
|
91 |
+
# this attribute is added by timm on one optimizer (adahessian)
|
92 |
+
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
|
93 |
+
loss /= update_freq
|
94 |
+
grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
|
95 |
+
parameters=model.parameters(), create_graph=is_second_order,
|
96 |
+
update_grad=(data_iter_step + 1) % update_freq == 0)
|
97 |
+
if (data_iter_step + 1) % update_freq == 0:
|
98 |
+
optimizer.zero_grad()
|
99 |
+
if model_ema is not None:
|
100 |
+
model_ema.update(model)
|
101 |
+
loss_scale_value = loss_scaler.state_dict()["scale"]
|
102 |
+
|
103 |
+
torch.cuda.synchronize()
|
104 |
+
|
105 |
+
if mixup_fn is None:
|
106 |
+
class_acc = (output.max(-1)[-1] == targets).float().mean()
|
107 |
+
else:
|
108 |
+
class_acc = None
|
109 |
+
metric_logger.update(loss=loss_value)
|
110 |
+
metric_logger.update(class_acc=class_acc)
|
111 |
+
metric_logger.update(loss_scale=loss_scale_value)
|
112 |
+
min_lr = 10.
|
113 |
+
max_lr = 0.
|
114 |
+
for group in optimizer.param_groups:
|
115 |
+
min_lr = min(min_lr, group["lr"])
|
116 |
+
max_lr = max(max_lr, group["lr"])
|
117 |
+
|
118 |
+
metric_logger.update(lr=max_lr)
|
119 |
+
metric_logger.update(min_lr=min_lr)
|
120 |
+
weight_decay_value = None
|
121 |
+
for group in optimizer.param_groups:
|
122 |
+
if group["weight_decay"] > 0:
|
123 |
+
weight_decay_value = group["weight_decay"]
|
124 |
+
metric_logger.update(weight_decay=weight_decay_value)
|
125 |
+
metric_logger.update(grad_norm=grad_norm)
|
126 |
+
|
127 |
+
if log_writer is not None:
|
128 |
+
log_writer.update(loss=loss_value, head="loss")
|
129 |
+
log_writer.update(class_acc=class_acc, head="loss")
|
130 |
+
log_writer.update(loss_scale=loss_scale_value, head="opt")
|
131 |
+
log_writer.update(lr=max_lr, head="opt")
|
132 |
+
log_writer.update(min_lr=min_lr, head="opt")
|
133 |
+
log_writer.update(weight_decay=weight_decay_value, head="opt")
|
134 |
+
log_writer.update(grad_norm=grad_norm, head="opt")
|
135 |
+
|
136 |
+
log_writer.set_step()
|
137 |
+
|
138 |
+
# gather the stats from all processes
|
139 |
+
metric_logger.synchronize_between_processes()
|
140 |
+
print("Averaged stats:", metric_logger)
|
141 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
142 |
+
|
143 |
+
|
144 |
+
@torch.no_grad()
|
145 |
+
def validation_one_epoch(data_loader, model, device):
|
146 |
+
criterion = torch.nn.CrossEntropyLoss()
|
147 |
+
|
148 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
149 |
+
header = 'Val:'
|
150 |
+
|
151 |
+
# switch to evaluation mode
|
152 |
+
model.eval()
|
153 |
+
|
154 |
+
for batch in metric_logger.log_every(data_loader, 10, header):
|
155 |
+
videos = batch[0]
|
156 |
+
target = batch[1]
|
157 |
+
videos = videos.to(device, non_blocking=True)
|
158 |
+
target = target.to(device, non_blocking=True)
|
159 |
+
|
160 |
+
# compute output
|
161 |
+
with torch.cuda.amp.autocast():
|
162 |
+
output = model(videos)
|
163 |
+
loss = criterion(output, target)
|
164 |
+
|
165 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
166 |
+
|
167 |
+
batch_size = videos.shape[0]
|
168 |
+
metric_logger.update(loss=loss.item())
|
169 |
+
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
170 |
+
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
171 |
+
# gather the stats from all processes
|
172 |
+
metric_logger.synchronize_between_processes()
|
173 |
+
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
|
174 |
+
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
|
175 |
+
|
176 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
@torch.no_grad()
|
181 |
+
def final_test(data_loader, model, device, file):
|
182 |
+
criterion = torch.nn.CrossEntropyLoss()
|
183 |
+
|
184 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
185 |
+
header = 'Test:'
|
186 |
+
|
187 |
+
# switch to evaluation mode
|
188 |
+
model.eval()
|
189 |
+
final_result = []
|
190 |
+
|
191 |
+
for batch in metric_logger.log_every(data_loader, 10, header):
|
192 |
+
videos = batch[0]
|
193 |
+
target = batch[1]
|
194 |
+
ids = batch[2]
|
195 |
+
chunk_nb = batch[3]
|
196 |
+
split_nb = batch[4]
|
197 |
+
videos = videos.to(device, non_blocking=True)
|
198 |
+
target = target.to(device, non_blocking=True)
|
199 |
+
|
200 |
+
# compute output
|
201 |
+
with torch.cuda.amp.autocast():
|
202 |
+
output = model(videos)
|
203 |
+
loss = criterion(output, target)
|
204 |
+
|
205 |
+
for i in range(output.size(0)):
|
206 |
+
string = "{} {} {} {} {}\n".format(ids[i], \
|
207 |
+
str(output.data[i].cpu().numpy().tolist()), \
|
208 |
+
str(int(target[i].cpu().numpy())), \
|
209 |
+
str(int(chunk_nb[i].cpu().numpy())), \
|
210 |
+
str(int(split_nb[i].cpu().numpy())))
|
211 |
+
final_result.append(string)
|
212 |
+
|
213 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
214 |
+
|
215 |
+
batch_size = videos.shape[0]
|
216 |
+
metric_logger.update(loss=loss.item())
|
217 |
+
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
218 |
+
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
219 |
+
|
220 |
+
if not os.path.exists(file):
|
221 |
+
os.mknod(file)
|
222 |
+
with open(file, 'w') as f:
|
223 |
+
f.write("{}, {}\n".format(acc1, acc5))
|
224 |
+
for line in final_result:
|
225 |
+
f.write(line)
|
226 |
+
# gather the stats from all processes
|
227 |
+
metric_logger.synchronize_between_processes()
|
228 |
+
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
|
229 |
+
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
|
230 |
+
|
231 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
232 |
+
|
233 |
+
|
234 |
+
def merge(eval_path, num_tasks):
|
235 |
+
dict_feats = {}
|
236 |
+
dict_label = {}
|
237 |
+
dict_pos = {}
|
238 |
+
print("Reading individual output files")
|
239 |
+
|
240 |
+
for x in range(num_tasks):
|
241 |
+
file = os.path.join(eval_path, str(x) + '.txt')
|
242 |
+
lines = open(file, 'r').readlines()[1:]
|
243 |
+
for line in lines:
|
244 |
+
line = line.strip()
|
245 |
+
name = line.split('[')[0]
|
246 |
+
label = line.split(']')[1].split(' ')[1]
|
247 |
+
chunk_nb = line.split(']')[1].split(' ')[2]
|
248 |
+
split_nb = line.split(']')[1].split(' ')[3]
|
249 |
+
data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
|
250 |
+
data = softmax(data)
|
251 |
+
if not name in dict_feats:
|
252 |
+
dict_feats[name] = []
|
253 |
+
dict_label[name] = 0
|
254 |
+
dict_pos[name] = []
|
255 |
+
if chunk_nb + split_nb in dict_pos[name]:
|
256 |
+
continue
|
257 |
+
dict_feats[name].append(data)
|
258 |
+
dict_pos[name].append(chunk_nb + split_nb)
|
259 |
+
dict_label[name] = label
|
260 |
+
print("Computing final results")
|
261 |
+
|
262 |
+
input_lst = []
|
263 |
+
print(len(dict_feats))
|
264 |
+
for i, item in enumerate(dict_feats):
|
265 |
+
input_lst.append([i, item, dict_feats[item], dict_label[item]])
|
266 |
+
from multiprocessing import Pool
|
267 |
+
p = Pool(64)
|
268 |
+
ans = p.map(compute_video, input_lst)
|
269 |
+
top1 = [x[1] for x in ans]
|
270 |
+
top5 = [x[2] for x in ans]
|
271 |
+
pred = [x[0] for x in ans]
|
272 |
+
label = [x[3] for x in ans]
|
273 |
+
final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
|
274 |
+
return final_top1*100 ,final_top5*100
|
275 |
+
|
276 |
+
def compute_video(lst):
|
277 |
+
i, video_id, data, label = lst
|
278 |
+
feat = [x for x in data]
|
279 |
+
feat = np.mean(feat, axis=0)
|
280 |
+
pred = np.argmax(feat)
|
281 |
+
top1 = (int(pred) == int(label)) * 1.0
|
282 |
+
top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
|
283 |
+
return [pred, top1, top5, int(label)]
|
284 |
+
|
285 |
+
def merge_mean_per_class(eval_path, num_tasks,nb_classes):
|
286 |
+
dict_feats = {}
|
287 |
+
dict_label = {}
|
288 |
+
dict_pos = {}
|
289 |
+
#print("Reading individual output files")
|
290 |
+
|
291 |
+
for x in range(num_tasks):
|
292 |
+
file = os.path.join(eval_path, str(x) + '.txt')
|
293 |
+
lines = open(file, 'r').readlines()[1:]
|
294 |
+
for line in lines:
|
295 |
+
line = line.strip()
|
296 |
+
name = line.split('[')[0]
|
297 |
+
label = line.split(']')[1].split(' ')[1]
|
298 |
+
chunk_nb = line.split(']')[1].split(' ')[2]
|
299 |
+
split_nb = line.split(']')[1].split(' ')[3]
|
300 |
+
data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
|
301 |
+
data = softmax(data)
|
302 |
+
if not name in dict_feats:
|
303 |
+
dict_feats[name] = []
|
304 |
+
dict_label[name] = 0
|
305 |
+
dict_pos[name] = []
|
306 |
+
if chunk_nb + split_nb in dict_pos[name]:
|
307 |
+
continue
|
308 |
+
dict_feats[name].append(data)
|
309 |
+
dict_pos[name].append(chunk_nb + split_nb)
|
310 |
+
dict_label[name] = label
|
311 |
+
print("Computing mean per class results")
|
312 |
+
|
313 |
+
input_lst = []
|
314 |
+
all_pred = []
|
315 |
+
all_label = []
|
316 |
+
|
317 |
+
classes = torch.arange(nb_classes)
|
318 |
+
classwise_top1 = [0 for c in classes]
|
319 |
+
classwise_top5 = [0 for c in classes]
|
320 |
+
actual_nb_classes = nb_classes
|
321 |
+
cnt = 0
|
322 |
+
|
323 |
+
for c in classes:
|
324 |
+
input_lst = []
|
325 |
+
for i, item in enumerate(dict_feats):
|
326 |
+
if int(dict_label[item]) == c:
|
327 |
+
input_lst.append([i, item, dict_feats[item], dict_label[item]])
|
328 |
+
cnt += len(input_lst)
|
329 |
+
|
330 |
+
# p = Pool(4)
|
331 |
+
# ans = p.map(compute_video, input_lst)
|
332 |
+
if len(input_lst) == 0:
|
333 |
+
actual_nb_classes -= 1
|
334 |
+
print(f"Class {c} is not present in test set, skip")
|
335 |
+
continue
|
336 |
+
|
337 |
+
ans = []
|
338 |
+
for i in input_lst:
|
339 |
+
ans.append(compute_video(i))
|
340 |
+
top1 = [x[1] for x in ans]
|
341 |
+
top5 = [x[2] for x in ans]
|
342 |
+
pred = [x[0] for x in ans]
|
343 |
+
label = [x[3] for x in ans]
|
344 |
+
|
345 |
+
# for i in pred:
|
346 |
+
# all_pred.append(i)
|
347 |
+
# for j in label:
|
348 |
+
# all_label.append(j)
|
349 |
+
final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
|
350 |
+
|
351 |
+
classwise_top1[c] = final_top1*100
|
352 |
+
classwise_top5[c] = final_top5*100
|
353 |
+
|
354 |
+
del input_lst
|
355 |
+
del ans
|
356 |
+
del top1
|
357 |
+
del top5
|
358 |
+
del pred
|
359 |
+
del label
|
360 |
+
gc.collect()
|
361 |
+
|
362 |
+
assert cnt == len(dict_feats)
|
363 |
+
# pred_cnt = 0
|
364 |
+
# for idx, p in enumerate(all_pred):
|
365 |
+
# if int(p) == int(all_label[idx]):
|
366 |
+
# pred_cnt += 1
|
367 |
+
# print(pred_cnt/len(all_pred))
|
368 |
+
classwise_top1_path = os.path.join(eval_path, "classwise_top1.pkl")
|
369 |
+
with open(classwise_top1_path, 'wb') as file:
|
370 |
+
pickle.dump(classwise_top1, file)
|
371 |
+
|
372 |
+
classwise_top1 = np.sum(classwise_top1) / actual_nb_classes
|
373 |
+
classwise_top5 = np.sum(classwise_top5) / actual_nb_classes
|
374 |
+
|
375 |
+
return classwise_top1,classwise_top5
|
engine_for_pretraining.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import sys
|
3 |
+
from typing import Iterable
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import utils_mae as utils
|
7 |
+
from einops import rearrange
|
8 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
9 |
+
|
10 |
+
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
|
11 |
+
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16,
|
12 |
+
normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None,
|
13 |
+
lr_schedule_values=None, wd_schedule_values=None,teacher_model=None,target_type='pixel', multiple_sampling=False):
|
14 |
+
|
15 |
+
model.train()
|
16 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
17 |
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
18 |
+
metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
19 |
+
header = 'Epoch: [{}]'.format(epoch)
|
20 |
+
print_freq = 10
|
21 |
+
|
22 |
+
loss_func = nn.MSELoss()
|
23 |
+
|
24 |
+
for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
25 |
+
# assign learning rate & weight decay for each step
|
26 |
+
it = start_steps + step # global training iteration
|
27 |
+
if lr_schedule_values is not None or wd_schedule_values is not None:
|
28 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
29 |
+
if lr_schedule_values is not None:
|
30 |
+
param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
|
31 |
+
if wd_schedule_values is not None and param_group["weight_decay"] > 0:
|
32 |
+
param_group["weight_decay"] = wd_schedule_values[it]
|
33 |
+
|
34 |
+
videos, bool_masked_pos = batch
|
35 |
+
videos = videos.to(device, non_blocking=True)
|
36 |
+
bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
|
37 |
+
#print("input_1",videos.size(),bool_masked_pos.size())
|
38 |
+
bs, _, nf, h, w = videos.shape
|
39 |
+
|
40 |
+
idx = torch.randperm(bool_masked_pos.size(0))
|
41 |
+
shuffled_bool_masked_pos = bool_masked_pos[idx,:]
|
42 |
+
|
43 |
+
if 'pixel' in target_type:
|
44 |
+
|
45 |
+
with torch.no_grad():
|
46 |
+
# calculate the predict label
|
47 |
+
mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
|
48 |
+
std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
|
49 |
+
unnorm_videos = videos * std + mean # in [0, 1]
|
50 |
+
|
51 |
+
if normlize_target:
|
52 |
+
videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size)
|
53 |
+
videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True)
|
54 |
+
) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
|
55 |
+
# we find that the mean is about 0.48 and standard deviation is about 0.08.
|
56 |
+
videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)')
|
57 |
+
else:
|
58 |
+
videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size)
|
59 |
+
|
60 |
+
B, _, C = videos_patch.shape
|
61 |
+
if not multiple_sampling:
|
62 |
+
labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
|
63 |
+
else:
|
64 |
+
labels_1 = videos_patch[bool_masked_pos].reshape(B, -1, C)
|
65 |
+
labels_2 = videos_patch[shuffled_bool_masked_pos].reshape(B, -1, C)
|
66 |
+
|
67 |
+
elif 'dino' in target_type or 'clip' in target_type:
|
68 |
+
|
69 |
+
with torch.no_grad():
|
70 |
+
permuted_video = videos.permute(0, 2, 1, 3, 4)
|
71 |
+
bs, nf, _, h, w = permuted_video.shape
|
72 |
+
permuted_video = permuted_video[:, ::2].flatten(0, 1)
|
73 |
+
permuted_video = permuted_video.to(device, non_blocking=True)
|
74 |
+
features = teacher_model(permuted_video)
|
75 |
+
_, np, dim = features.shape
|
76 |
+
features = features.reshape(bs, nf//2, np, dim)
|
77 |
+
features.requires_grad = False
|
78 |
+
|
79 |
+
features = features.to(device, non_blocking=True)
|
80 |
+
with torch.no_grad():
|
81 |
+
features_squeeze = rearrange(features, 'b n o c -> b (n o) c')
|
82 |
+
if normlize_target:
|
83 |
+
labels = (features_squeeze - features_squeeze.mean(dim=-2, keepdim=True)
|
84 |
+
) / (features_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
|
85 |
+
else:
|
86 |
+
labels = features_squeeze
|
87 |
+
B, _, C = labels.shape
|
88 |
+
if not multiple_sampling:
|
89 |
+
labels = labels[bool_masked_pos].reshape(B, -1, C)
|
90 |
+
else:
|
91 |
+
labels_1 = labels[bool_masked_pos].reshape(B, -1, C)
|
92 |
+
labels_2 = labels[shuffled_bool_masked_pos].reshape(B, -1, C)
|
93 |
+
|
94 |
+
|
95 |
+
with torch.cuda.amp.autocast():
|
96 |
+
if not multiple_sampling:
|
97 |
+
outputs = model(videos, bool_masked_pos)
|
98 |
+
else:
|
99 |
+
outputs_1 = model(videos, bool_masked_pos)
|
100 |
+
outputs_2 = model(videos,shuffled_bool_masked_pos)
|
101 |
+
|
102 |
+
labels = torch.cat((labels_1,labels_2),dim=0)
|
103 |
+
outputs = torch.cat((outputs_1,outputs_2),dim=0)
|
104 |
+
|
105 |
+
loss = loss_func(input=outputs, target=labels)
|
106 |
+
|
107 |
+
loss_value = loss.item()
|
108 |
+
if not math.isfinite(loss_value):
|
109 |
+
print("Loss is {}, stopping training".format(loss_value))
|
110 |
+
sys.exit(1)
|
111 |
+
|
112 |
+
optimizer.zero_grad()
|
113 |
+
# this attribute is added by timm on one optimizer (adahessian)
|
114 |
+
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
|
115 |
+
grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
|
116 |
+
parameters=model.parameters(), create_graph=is_second_order)
|
117 |
+
loss_scale_value = loss_scaler.state_dict()["scale"]
|
118 |
+
|
119 |
+
torch.cuda.synchronize()
|
120 |
+
|
121 |
+
metric_logger.update(loss=loss_value)
|
122 |
+
metric_logger.update(loss_scale=loss_scale_value)
|
123 |
+
min_lr = 10.
|
124 |
+
max_lr = 0.
|
125 |
+
for group in optimizer.param_groups:
|
126 |
+
min_lr = min(min_lr, group["lr"])
|
127 |
+
max_lr = max(max_lr, group["lr"])
|
128 |
+
|
129 |
+
metric_logger.update(lr=max_lr)
|
130 |
+
metric_logger.update(min_lr=min_lr)
|
131 |
+
weight_decay_value = None
|
132 |
+
for group in optimizer.param_groups:
|
133 |
+
if group["weight_decay"] > 0:
|
134 |
+
weight_decay_value = group["weight_decay"]
|
135 |
+
metric_logger.update(weight_decay=weight_decay_value)
|
136 |
+
metric_logger.update(grad_norm=grad_norm)
|
137 |
+
|
138 |
+
if log_writer is not None:
|
139 |
+
log_writer.update(loss=loss_value, head="loss")
|
140 |
+
log_writer.update(loss_scale=loss_scale_value, head="opt")
|
141 |
+
log_writer.update(lr=max_lr, head="opt")
|
142 |
+
log_writer.update(min_lr=min_lr, head="opt")
|
143 |
+
log_writer.update(weight_decay=weight_decay_value, head="opt")
|
144 |
+
log_writer.update(grad_norm=grad_norm, head="opt")
|
145 |
+
log_writer.set_step()
|
146 |
+
|
147 |
+
if lr_scheduler is not None:
|
148 |
+
lr_scheduler.step_update(start_steps + step)
|
149 |
+
# gather the stats from all processes
|
150 |
+
metric_logger.synchronize_between_processes()
|
151 |
+
print("Averaged stats:", metric_logger)
|
152 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
environment.yml
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: smile
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- anaconda
|
6 |
+
- conda-forge
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- _libgcc_mutex=0.1=conda_forge
|
10 |
+
- _openmp_mutex=4.5=2_gnu
|
11 |
+
- alsa-lib=1.2.8=h166bdaf_0
|
12 |
+
- aom=3.5.0=h27087fc_0
|
13 |
+
- appdirs=1.4.4=pyh9f0ad1d_0
|
14 |
+
- attr=2.5.1=h166bdaf_1
|
15 |
+
- blas=1.0=mkl
|
16 |
+
- bottleneck=1.3.5=py310ha9d4c09_0
|
17 |
+
- brotli-python=1.1.0=py310hc6cd4ac_1
|
18 |
+
- bzip2=1.0.8=hd590300_5
|
19 |
+
- c-ares=1.23.0=hd590300_0
|
20 |
+
- ca-certificates=2023.11.17=hbcca054_0
|
21 |
+
- cairo=1.16.0=ha61ee94_1014
|
22 |
+
- certifi=2023.11.17=py310h06a4308_0
|
23 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
24 |
+
- click=8.1.7=unix_pyh707e725_0
|
25 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
26 |
+
- cuda-cudart=11.8.89=0
|
27 |
+
- cuda-cupti=11.8.87=0
|
28 |
+
- cuda-libraries=11.8.0=0
|
29 |
+
- cuda-nvrtc=11.8.89=0
|
30 |
+
- cuda-nvtx=11.8.86=0
|
31 |
+
- cuda-runtime=11.8.0=0
|
32 |
+
- dbus=1.13.6=h5008d03_3
|
33 |
+
- docker-pycreds=0.4.0=py_0
|
34 |
+
- einops=0.7.0=pyhd8ed1ab_1
|
35 |
+
- expat=2.5.0=hcb278e6_1
|
36 |
+
- ffmpeg=5.1.2=gpl_h8dda1f0_106
|
37 |
+
- fftw=3.3.10=nompi_hc118613_108
|
38 |
+
- filelock=3.13.1=pyhd8ed1ab_0
|
39 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
40 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
41 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
42 |
+
- font-ttf-ubuntu=0.83=h77eed37_1
|
43 |
+
- fontconfig=2.14.2=h14ed4e7_0
|
44 |
+
- fonts-conda-ecosystem=1=0
|
45 |
+
- fonts-conda-forge=1=0
|
46 |
+
- freeglut=3.2.2=h9c3ff4c_1
|
47 |
+
- freetype=2.12.1=h267a509_2
|
48 |
+
- fsspec=2023.12.0=pyhca7485f_0
|
49 |
+
- gettext=0.21.1=h27087fc_0
|
50 |
+
- gitdb=4.0.11=pyhd8ed1ab_0
|
51 |
+
- gitpython=3.1.40=pyhd8ed1ab_0
|
52 |
+
- glib=2.78.1=hfc55251_1
|
53 |
+
- glib-tools=2.78.1=hfc55251_1
|
54 |
+
- gmp=6.3.0=h59595ed_0
|
55 |
+
- gmpy2=2.1.2=py310h3ec546c_1
|
56 |
+
- gnutls=3.7.9=hb077bed_0
|
57 |
+
- graphite2=1.3.13=h58526e2_1001
|
58 |
+
- gst-plugins-base=1.22.0=h4243ec0_2
|
59 |
+
- gstreamer=1.22.0=h25f0c4b_2
|
60 |
+
- gstreamer-orc=0.4.34=hd590300_0
|
61 |
+
- harfbuzz=6.0.0=h8e241bc_0
|
62 |
+
- hdf5=1.14.0=nompi_hb72d44e_103
|
63 |
+
- huggingface_hub=0.19.4=pyhd8ed1ab_0
|
64 |
+
- icu=70.1=h27087fc_0
|
65 |
+
- idna=3.6=pyhd8ed1ab_0
|
66 |
+
- intel-openmp=2023.1.0=hdb19cb5_46306
|
67 |
+
- jack=1.9.22=h11f4161_0
|
68 |
+
- jasper=2.0.33=h0ff4b12_1
|
69 |
+
- jinja2=3.1.2=pyhd8ed1ab_1
|
70 |
+
- jpeg=9e=h166bdaf_2
|
71 |
+
- keyutils=1.6.1=h166bdaf_0
|
72 |
+
- krb5=1.20.1=h81ceb04_0
|
73 |
+
- lame=3.100=h166bdaf_1003
|
74 |
+
- lcms2=2.15=hfd0df8a_0
|
75 |
+
- ld_impl_linux-64=2.40=h41732ed_0
|
76 |
+
- lerc=4.0.0=h27087fc_0
|
77 |
+
- libabseil=20230802.1=cxx17_h59595ed_0
|
78 |
+
- libaec=1.1.2=h59595ed_1
|
79 |
+
- libblas=3.9.0=1_h86c2bf4_netlib
|
80 |
+
- libcap=2.67=he9d0100_0
|
81 |
+
- libcblas=3.9.0=5_h92ddd45_netlib
|
82 |
+
- libclang=15.0.7=default_hb11cfb5_4
|
83 |
+
- libclang13=15.0.7=default_ha2b6cf4_4
|
84 |
+
- libcublas=11.11.3.6=0
|
85 |
+
- libcufft=10.9.0.58=0
|
86 |
+
- libcufile=1.8.1.2=0
|
87 |
+
- libcups=2.3.3=h36d4200_3
|
88 |
+
- libcurand=10.3.4.101=0
|
89 |
+
- libcurl=8.1.2=h409715c_0
|
90 |
+
- libcusolver=11.4.1.48=0
|
91 |
+
- libcusparse=11.7.5.86=0
|
92 |
+
- libdb=6.2.32=h9c3ff4c_0
|
93 |
+
- libdeflate=1.17=h0b41bf4_0
|
94 |
+
- libdrm=2.4.114=h166bdaf_0
|
95 |
+
- libedit=3.1.20191231=he28a2e2_2
|
96 |
+
- libev=4.33=h516909a_1
|
97 |
+
- libevent=2.1.10=h28343ad_4
|
98 |
+
- libexpat=2.5.0=hcb278e6_1
|
99 |
+
- libffi=3.4.2=h7f98852_5
|
100 |
+
- libflac=1.4.3=h59595ed_0
|
101 |
+
- libgcc-ng=13.2.0=h807b86a_3
|
102 |
+
- libgcrypt=1.10.3=hd590300_0
|
103 |
+
- libgfortran-ng=13.2.0=h69a702a_3
|
104 |
+
- libgfortran5=13.2.0=ha4646dd_3
|
105 |
+
- libglib=2.78.1=h783c2da_1
|
106 |
+
- libglu=9.0.0=he1b5a44_1001
|
107 |
+
- libgomp=13.2.0=h807b86a_3
|
108 |
+
- libgpg-error=1.47=h71f35ed_0
|
109 |
+
- libhwloc=2.9.1=hd6dc26d_0
|
110 |
+
- libiconv=1.17=h166bdaf_0
|
111 |
+
- libidn2=2.3.4=h166bdaf_0
|
112 |
+
- libjpeg-turbo=2.0.0=h9bf148f_0
|
113 |
+
- liblapack=3.9.0=5_h92ddd45_netlib
|
114 |
+
- liblapacke=3.9.0=5_h92ddd45_netlib
|
115 |
+
- libllvm15=15.0.7=hadd5161_1
|
116 |
+
- libnghttp2=1.58.0=h47da74e_0
|
117 |
+
- libnpp=11.8.0.86=0
|
118 |
+
- libnsl=2.0.1=hd590300_0
|
119 |
+
- libnvjpeg=11.9.0.86=0
|
120 |
+
- libogg=1.3.4=h7f98852_1
|
121 |
+
- libopencv=4.7.0=py310hb48cf42_1
|
122 |
+
- libopus=1.3.1=h7f98852_1
|
123 |
+
- libpciaccess=0.17=h166bdaf_0
|
124 |
+
- libpng=1.6.39=h753d276_0
|
125 |
+
- libpq=15.3=hbcd7760_1
|
126 |
+
- libprotobuf=3.21.12=hfc55251_2
|
127 |
+
- libsndfile=1.2.2=hc60ed4a_1
|
128 |
+
- libsqlite=3.44.2=h2797004_0
|
129 |
+
- libssh2=1.11.0=h0841786_0
|
130 |
+
- libstdcxx-ng=13.2.0=h7e041cc_3
|
131 |
+
- libsystemd0=253=h8c4010b_1
|
132 |
+
- libtasn1=4.19.0=h166bdaf_0
|
133 |
+
- libtiff=4.5.0=h6adf6a1_2
|
134 |
+
- libtool=2.4.7=h27087fc_0
|
135 |
+
- libudev1=253=h0b41bf4_1
|
136 |
+
- libunistring=0.9.10=h7f98852_0
|
137 |
+
- libuuid=2.38.1=h0b41bf4_0
|
138 |
+
- libva=2.18.0=h0b41bf4_0
|
139 |
+
- libvorbis=1.3.7=h9c3ff4c_0
|
140 |
+
- libvpx=1.11.0=h9c3ff4c_3
|
141 |
+
- libwebp-base=1.3.2=hd590300_0
|
142 |
+
- libxcb=1.13=h7f98852_1004
|
143 |
+
- libxkbcommon=1.5.0=h79f4944_1
|
144 |
+
- libxml2=2.10.3=hca2bb57_4
|
145 |
+
- libzlib=1.2.13=hd590300_5
|
146 |
+
- llvm-openmp=15.0.7=h0cdce71_0
|
147 |
+
- lz4-c=1.9.4=hcb278e6_0
|
148 |
+
- markupsafe=2.1.3=py310h2372a71_1
|
149 |
+
- mkl=2023.1.0=h213fc3f_46344
|
150 |
+
- mkl-service=2.4.0=py310h5eee18b_1
|
151 |
+
- mpc=1.3.1=hfe3b2da_0
|
152 |
+
- mpfr=4.2.1=h9458935_0
|
153 |
+
- mpg123=1.32.3=h59595ed_0
|
154 |
+
- mpmath=1.3.0=pyhd8ed1ab_0
|
155 |
+
- mysql-common=8.0.33=hf1915f5_6
|
156 |
+
- mysql-libs=8.0.33=hca2cd23_6
|
157 |
+
- ncurses=6.4=h59595ed_2
|
158 |
+
- nettle=3.9.1=h7ab15ed_0
|
159 |
+
- networkx=3.2.1=pyhd8ed1ab_0
|
160 |
+
- nspr=4.35=h27087fc_0
|
161 |
+
- nss=3.95=h1d7d5a4_0
|
162 |
+
- numexpr=2.8.7=py310h85018f9_0
|
163 |
+
- numpy=1.26.2=py310hb13e2d6_0
|
164 |
+
- opencv=4.7.0=py310hff52083_1
|
165 |
+
- openh264=2.3.1=hcb278e6_2
|
166 |
+
- openjpeg=2.5.0=hfec8fc6_2
|
167 |
+
- openssl=3.1.4=hd590300_0
|
168 |
+
- p11-kit=0.24.1=hc5aa10d_0
|
169 |
+
- packaging=23.2=pyhd8ed1ab_0
|
170 |
+
- pandas=2.1.1=py310h1128e8f_0
|
171 |
+
- pathtools=0.1.2=py_1
|
172 |
+
- pcre2=10.42=hcad00b1_0
|
173 |
+
- pillow=9.4.0=py310h023d228_1
|
174 |
+
- pip=23.3.1=pyhd8ed1ab_0
|
175 |
+
- pixman=0.42.2=h59595ed_0
|
176 |
+
- protobuf=4.21.12=py310heca2aa9_0
|
177 |
+
- pthread-stubs=0.4=h36c2ea0_1001
|
178 |
+
- pulseaudio=16.1=hcb278e6_3
|
179 |
+
- pulseaudio-client=16.1=h5195f5e_3
|
180 |
+
- pulseaudio-daemon=16.1=ha8d29e2_3
|
181 |
+
- py-opencv=4.7.0=py310hfdc917e_1
|
182 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
183 |
+
- python=3.10.13=hd12c33a_0_cpython
|
184 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
185 |
+
- python-tzdata=2023.3=pyhd3eb1b0_0
|
186 |
+
- python_abi=3.10=4_cp310
|
187 |
+
- pytorch=2.1.1=py3.10_cuda11.8_cudnn8.7.0_0
|
188 |
+
- pytorch-cuda=11.8=h7e8668a_5
|
189 |
+
- pytorch-mutex=1.0=cuda
|
190 |
+
- pytz=2023.3.post1=py310h06a4308_0
|
191 |
+
- pyyaml=6.0.1=py310h2372a71_1
|
192 |
+
- qt-main=5.15.8=h5d23da1_6
|
193 |
+
- readline=8.2=h8228510_1
|
194 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
195 |
+
- safetensors=0.3.3=py310hcb5633a_1
|
196 |
+
- scipy=1.11.3=py310h5f9d8c6_0
|
197 |
+
- sentry-sdk=1.38.0=pyhd8ed1ab_0
|
198 |
+
- setproctitle=1.3.3=py310h2372a71_0
|
199 |
+
- setuptools=68.2.2=pyhd8ed1ab_0
|
200 |
+
- six=1.16.0=pyh6c4a22f_0
|
201 |
+
- smmap=5.0.0=pyhd8ed1ab_0
|
202 |
+
- svt-av1=1.4.1=hcb278e6_0
|
203 |
+
- sympy=1.12=pypyh9d50eac_103
|
204 |
+
- tbb=2021.9.0=hf52228f_0
|
205 |
+
- tensorboardx=2.6.2.2=pyhd8ed1ab_0
|
206 |
+
- timm=0.9.12=pyhd8ed1ab_0
|
207 |
+
- tk=8.6.13=noxft_h4845f30_101
|
208 |
+
- torchaudio=2.1.1=py310_cu118
|
209 |
+
- torchtriton=2.1.0=py310
|
210 |
+
- torchvision=0.16.1=py310_cu118
|
211 |
+
- tqdm=4.66.1=pyhd8ed1ab_0
|
212 |
+
- typing-extensions=4.8.0=hd8ed1ab_0
|
213 |
+
- typing_extensions=4.8.0=pyha770c72_0
|
214 |
+
- tzdata=2023c=h71feb2d_0
|
215 |
+
- urllib3=2.1.0=pyhd8ed1ab_0
|
216 |
+
- wandb=0.15.12=pyhd8ed1ab_0
|
217 |
+
- wheel=0.42.0=pyhd8ed1ab_0
|
218 |
+
- x264=1!164.3095=h166bdaf_2
|
219 |
+
- x265=3.5=h924138e_3
|
220 |
+
- xcb-util=0.4.0=h516909a_0
|
221 |
+
- xcb-util-image=0.4.0=h166bdaf_0
|
222 |
+
- xcb-util-keysyms=0.4.0=h516909a_0
|
223 |
+
- xcb-util-renderutil=0.3.9=h166bdaf_0
|
224 |
+
- xcb-util-wm=0.4.1=h516909a_0
|
225 |
+
- xkeyboard-config=2.38=h0b41bf4_0
|
226 |
+
- xorg-fixesproto=5.0=h7f98852_1002
|
227 |
+
- xorg-inputproto=2.3.2=h7f98852_1002
|
228 |
+
- xorg-kbproto=1.0.7=h7f98852_1002
|
229 |
+
- xorg-libice=1.1.1=hd590300_0
|
230 |
+
- xorg-libsm=1.2.4=h7391055_0
|
231 |
+
- xorg-libx11=1.8.4=h0b41bf4_0
|
232 |
+
- xorg-libxau=1.0.11=hd590300_0
|
233 |
+
- xorg-libxdmcp=1.1.3=h7f98852_0
|
234 |
+
- xorg-libxext=1.3.4=h0b41bf4_2
|
235 |
+
- xorg-libxfixes=5.0.3=h7f98852_1004
|
236 |
+
- xorg-libxi=1.7.10=h7f98852_0
|
237 |
+
- xorg-libxrender=0.9.10=h7f98852_1003
|
238 |
+
- xorg-renderproto=0.11.1=h7f98852_1002
|
239 |
+
- xorg-xextproto=7.3.0=h0b41bf4_1003
|
240 |
+
- xorg-xproto=7.0.31=h7f98852_1007
|
241 |
+
- xz=5.2.6=h166bdaf_0
|
242 |
+
- yaml=0.2.5=h7f98852_2
|
243 |
+
- zlib=1.2.13=hd590300_5
|
244 |
+
- zstd=1.5.5=hfc55251_0
|
245 |
+
- pip:
|
246 |
+
- annotated-types==0.6.0
|
247 |
+
- decord==0.6.0
|
248 |
+
- hjson==3.1.0
|
249 |
+
- ninja==1.11.1.1
|
250 |
+
- psutil==5.9.6
|
251 |
+
- py-cpuinfo==9.0.0
|
252 |
+
- pydantic==2.5.2
|
253 |
+
- pydantic-core==2.14.5
|
254 |
+
- pynvml==11.5.0
|
255 |
+
- imutils==0.5.4
|
256 |
+
- transformers==4.31.0
|
257 |
+
- ftfy
|
258 |
+
- easydict
|
259 |
+
- matplotlib==3.10.0
|
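Once the `smile` environment is active, a quick sanity check that the pinned PyTorch 2.1.1 / CUDA 11.8 build and the pip-installed `decord` are importable can look like this minimal sketch (output values are illustrative; CUDA will report unavailable on CPU-only machines):

```python
# env_check.py -- illustrative sanity check for the pinned environment above
import torch
import torchvision
import decord

print("torch:", torch.__version__)              # expected 2.1.1 per environment.yml
print("torchvision:", torchvision.__version__)  # expected 0.16.1
print("decord:", decord.__version__)            # expected 0.6.0 (pip dependency)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```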
functional.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
import numbers
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import PIL
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def _is_tensor_clip(clip):
|
9 |
+
return torch.is_tensor(clip) and clip.ndimension() == 4
|
10 |
+
|
11 |
+
|
12 |
+
def crop_clip(clip, min_h, min_w, h, w):
|
13 |
+
if isinstance(clip[0], np.ndarray):
|
14 |
+
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
|
15 |
+
|
16 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
17 |
+
cropped = [
|
18 |
+
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
|
19 |
+
]
|
20 |
+
else:
|
21 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image ' +
|
22 |
+
'but got list of {0}'.format(type(clip[0])))
|
23 |
+
return cropped
|
24 |
+
|
25 |
+
|
26 |
+
def resize_clip(clip, size, interpolation='bilinear'):
|
27 |
+
if isinstance(clip[0], np.ndarray):
|
28 |
+
if isinstance(size, numbers.Number):
|
29 |
+
im_h, im_w, im_c = clip[0].shape
|
30 |
+
# Min spatial dim already matches minimal size
|
31 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
32 |
+
and im_h == size):
|
33 |
+
return clip
|
34 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
35 |
+
size = (new_w, new_h)
|
36 |
+
else:
|
37 |
+
size = size[0], size[1]
|
38 |
+
if interpolation == 'bilinear':
|
39 |
+
np_inter = cv2.INTER_LINEAR
|
40 |
+
else:
|
41 |
+
np_inter = cv2.INTER_NEAREST
|
42 |
+
scaled = [
|
43 |
+
cv2.resize(img, size, interpolation=np_inter) for img in clip
|
44 |
+
]
|
45 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
46 |
+
if isinstance(size, numbers.Number):
|
47 |
+
im_w, im_h = clip[0].size
|
48 |
+
# Min spatial dim already matches minimal size
|
49 |
+
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
50 |
+
and im_h == size):
|
51 |
+
return clip
|
52 |
+
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
53 |
+
size = (new_w, new_h)
|
54 |
+
else:
|
55 |
+
size = size[1], size[0]
|
56 |
+
if interpolation == 'bilinear':
|
57 |
+
pil_inter = PIL.Image.BILINEAR
|
58 |
+
else:
|
59 |
+
pil_inter = PIL.Image.NEAREST
|
60 |
+
scaled = [img.resize(size, pil_inter) for img in clip]
|
61 |
+
else:
|
62 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image ' +
|
63 |
+
'but got list of {0}'.format(type(clip[0])))
|
64 |
+
return scaled
|
65 |
+
|
66 |
+
|
67 |
+
def get_resize_sizes(im_h, im_w, size):
|
68 |
+
if im_w < im_h:
|
69 |
+
ow = size
|
70 |
+
oh = int(size * im_h / im_w)
|
71 |
+
else:
|
72 |
+
oh = size
|
73 |
+
ow = int(size * im_w / im_h)
|
74 |
+
return oh, ow
|
75 |
+
|
76 |
+
|
77 |
+
def normalize(clip, mean, std, inplace=False):
|
78 |
+
if not _is_tensor_clip(clip):
|
79 |
+
raise TypeError('tensor is not a torch clip.')
|
80 |
+
|
81 |
+
if not inplace:
|
82 |
+
clip = clip.clone()
|
83 |
+
|
84 |
+
dtype = clip.dtype
|
85 |
+
mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
|
86 |
+
std = torch.as_tensor(std, dtype=dtype, device=clip.device)
|
87 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
88 |
+
|
89 |
+
return clip
|
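For reference, the helpers in `functional.py` above operate on a list of frames (NumPy arrays or PIL images) and, once stacked into a `(C, T, H, W)` tensor, on a whole clip. A minimal usage sketch, with made-up frame sizes and the ImageNet statistics used elsewhere in this repo:

```python
# Illustrative only: compose resize_clip, crop_clip and normalize on dummy frames.
import numpy as np
import torch
from functional import crop_clip, resize_clip, normalize

# 16 dummy RGB frames of size 240x320 (H, W, C), as a video reader would return them
clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(16)]

clip = resize_clip(clip, 256, interpolation='bilinear')   # rescale so the short side is 256
clip = crop_clip(clip, min_h=16, min_w=16, h=224, w=224)  # take a 224x224 crop

# Stack to a (C, T, H, W) float tensor in [0, 1] before channel-wise normalization
tensor = torch.from_numpy(np.stack(clip)).permute(3, 0, 1, 2).float() / 255.0
tensor = normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(tensor.shape)  # torch.Size([3, 16, 224, 224])
```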
kinetics.py
ADDED
@@ -0,0 +1,559 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from numpy.lib.function_base import disp
|
4 |
+
import torch
|
5 |
+
import decord
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
from random_erasing import RandomErasing
|
9 |
+
import warnings
|
10 |
+
from decord import VideoReader, cpu
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
import video_transforms as video_transforms
|
13 |
+
import volume_transforms as volume_transforms
|
14 |
+
|
15 |
+
class VideoClsDataset(Dataset):
|
16 |
+
"""Load your own video classification dataset."""
|
17 |
+
|
18 |
+
def __init__(self, anno_path, data_path, mode='train', clip_len=8,
|
19 |
+
frame_sample_rate=2, crop_size=224, short_side_size=256,
|
20 |
+
new_height=256, new_width=340, keep_aspect_ratio=True,
|
21 |
+
num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,args=None):
|
22 |
+
self.anno_path = anno_path
|
23 |
+
self.data_path = data_path
|
24 |
+
self.mode = mode
|
25 |
+
self.clip_len = clip_len
|
26 |
+
self.frame_sample_rate = frame_sample_rate
|
27 |
+
self.crop_size = crop_size
|
28 |
+
self.short_side_size = short_side_size
|
29 |
+
self.new_height = new_height
|
30 |
+
self.new_width = new_width
|
31 |
+
self.keep_aspect_ratio = keep_aspect_ratio
|
32 |
+
self.num_segment = num_segment
|
33 |
+
self.test_num_segment = test_num_segment
|
34 |
+
self.num_crop = num_crop
|
35 |
+
self.test_num_crop = test_num_crop
|
36 |
+
self.args = args
|
37 |
+
self.aug = False
|
38 |
+
self.rand_erase = False
|
39 |
+
if self.mode in ['train']:
|
40 |
+
self.aug = True
|
41 |
+
if self.args.reprob > 0:
|
42 |
+
self.rand_erase = True
|
43 |
+
if VideoReader is None:
|
44 |
+
raise ImportError("Unable to import `decord` which is required to read videos.")
|
45 |
+
|
46 |
+
import pandas as pd
|
47 |
+
cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
|
48 |
+
self.dataset_samples = list(cleaned.values[:, 0])
|
49 |
+
self.label_array = list(cleaned.values[:, 1])
|
50 |
+
|
51 |
+
if (mode == 'train'):
|
52 |
+
pass
|
53 |
+
|
54 |
+
elif (mode == 'validation'):
|
55 |
+
self.data_transform = video_transforms.Compose([
|
56 |
+
video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
|
57 |
+
video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
|
58 |
+
volume_transforms.ClipToTensor(),
|
59 |
+
video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
60 |
+
std=[0.229, 0.224, 0.225])
|
61 |
+
])
|
62 |
+
elif mode == 'test':
|
63 |
+
self.data_resize = video_transforms.Compose([
|
64 |
+
video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
|
65 |
+
])
|
66 |
+
self.data_transform = video_transforms.Compose([
|
67 |
+
volume_transforms.ClipToTensor(),
|
68 |
+
video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
69 |
+
std=[0.229, 0.224, 0.225])
|
70 |
+
])
|
71 |
+
self.test_seg = []
|
72 |
+
self.test_dataset = []
|
73 |
+
self.test_label_array = []
|
74 |
+
for ck in range(self.test_num_segment):
|
75 |
+
for cp in range(self.test_num_crop):
|
76 |
+
for idx in range(len(self.label_array)):
|
77 |
+
sample_label = self.label_array[idx]
|
78 |
+
self.test_label_array.append(sample_label)
|
79 |
+
self.test_dataset.append(self.dataset_samples[idx])
|
80 |
+
self.test_seg.append((ck, cp))
|
81 |
+
|
82 |
+
def __getitem__(self, index):
|
83 |
+
if self.mode == 'train':
|
84 |
+
args = self.args
|
85 |
+
scale_t = 1
|
86 |
+
|
87 |
+
sample = self.dataset_samples[index]
|
88 |
+
buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
|
89 |
+
if len(buffer) == 0:
|
90 |
+
while len(buffer) == 0:
|
91 |
+
warnings.warn("video {} not correctly loaded during training".format(sample))
|
92 |
+
index = np.random.randint(self.__len__())
|
93 |
+
sample = self.dataset_samples[index]
|
94 |
+
buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
|
95 |
+
|
96 |
+
if args.num_sample > 1:
|
97 |
+
frame_list = []
|
98 |
+
label_list = []
|
99 |
+
index_list = []
|
100 |
+
for _ in range(args.num_sample):
|
101 |
+
new_frames = self._aug_frame(buffer, args)
|
102 |
+
label = self.label_array[index]
|
103 |
+
frame_list.append(new_frames)
|
104 |
+
label_list.append(label)
|
105 |
+
index_list.append(index)
|
106 |
+
return frame_list, label_list, index_list, {}
|
107 |
+
else:
|
108 |
+
buffer = self._aug_frame(buffer, args)
|
109 |
+
|
110 |
+
return buffer, self.label_array[index], index, {}
|
111 |
+
|
112 |
+
elif self.mode == 'validation':
|
113 |
+
sample = self.dataset_samples[index]
|
114 |
+
buffer = self.loadvideo_decord(sample)
|
115 |
+
if len(buffer) == 0:
|
116 |
+
while len(buffer) == 0:
|
117 |
+
warnings.warn("video {} not correctly loaded during validation".format(sample))
|
118 |
+
index = np.random.randint(self.__len__())
|
119 |
+
sample = self.dataset_samples[index]
|
120 |
+
buffer = self.loadvideo_decord(sample)
|
121 |
+
buffer = self.data_transform(buffer)
|
122 |
+
return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
|
123 |
+
|
124 |
+
elif self.mode == 'test':
|
125 |
+
sample = self.test_dataset[index]
|
126 |
+
chunk_nb, split_nb = self.test_seg[index]
|
127 |
+
buffer = self.loadvideo_decord(sample)
|
128 |
+
|
129 |
+
while len(buffer) == 0:
|
130 |
+
warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
|
131 |
+
str(self.test_dataset[index]), chunk_nb, split_nb))
|
132 |
+
index = np.random.randint(self.__len__())
|
133 |
+
sample = self.test_dataset[index]
|
134 |
+
chunk_nb, split_nb = self.test_seg[index]
|
135 |
+
buffer = self.loadvideo_decord(sample)
|
136 |
+
|
137 |
+
buffer = self.data_resize(buffer)
|
138 |
+
if isinstance(buffer, list):
|
139 |
+
buffer = np.stack(buffer, 0)
|
140 |
+
|
141 |
+
spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
|
142 |
+
/ (self.test_num_crop - 1)
|
143 |
+
temporal_step = max(1.0 * (buffer.shape[0] - self.clip_len) \
|
144 |
+
/ (self.test_num_segment - 1), 0)
|
145 |
+
temporal_start = int(chunk_nb * temporal_step)
|
146 |
+
spatial_start = int(split_nb * spatial_step)
|
147 |
+
if buffer.shape[1] >= buffer.shape[2]:
|
148 |
+
buffer = buffer[temporal_start:temporal_start + self.clip_len, \
|
149 |
+
spatial_start:spatial_start + self.short_side_size, :, :]
|
150 |
+
else:
|
151 |
+
buffer = buffer[temporal_start:temporal_start + self.clip_len, \
|
152 |
+
:, spatial_start:spatial_start + self.short_side_size, :]
|
153 |
+
|
154 |
+
buffer = self.data_transform(buffer)
|
155 |
+
return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
|
156 |
+
chunk_nb, split_nb
|
157 |
+
else:
|
158 |
+
raise NameError('mode {} unknown'.format(self.mode))
|
159 |
+
|
160 |
+
def _aug_frame(
|
161 |
+
self,
|
162 |
+
buffer,
|
163 |
+
args,
|
164 |
+
):
|
165 |
+
|
166 |
+
aug_transform = video_transforms.create_random_augment(
|
167 |
+
input_size=(self.crop_size, self.crop_size),
|
168 |
+
auto_augment=args.aa,
|
169 |
+
interpolation=args.train_interpolation,
|
170 |
+
)
|
171 |
+
|
172 |
+
buffer = [
|
173 |
+
transforms.ToPILImage()(frame) for frame in buffer
|
174 |
+
]
|
175 |
+
|
176 |
+
buffer = aug_transform(buffer)
|
177 |
+
|
178 |
+
buffer = [transforms.ToTensor()(img) for img in buffer]
|
179 |
+
buffer = torch.stack(buffer) # T C H W
|
180 |
+
buffer = buffer.permute(0, 2, 3, 1) # T H W C
|
181 |
+
|
182 |
+
# T H W C
|
183 |
+
buffer = tensor_normalize(
|
184 |
+
buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
185 |
+
)
|
186 |
+
# T H W C -> C T H W.
|
187 |
+
buffer = buffer.permute(3, 0, 1, 2)
|
188 |
+
# Perform data augmentation.
|
189 |
+
scl, asp = (
|
190 |
+
[0.25, 1.0],
|
191 |
+
[0.75, 1.3333],
|
192 |
+
)
|
193 |
+
|
194 |
+
buffer = spatial_sampling(
|
195 |
+
buffer,
|
196 |
+
spatial_idx=-1,
|
197 |
+
min_scale=256,
|
198 |
+
max_scale=320,
|
199 |
+
crop_size=self.crop_size,
|
200 |
+
random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
|
201 |
+
inverse_uniform_sampling=False,
|
202 |
+
aspect_ratio=asp,
|
203 |
+
scale=scl,
|
204 |
+
motion_shift=False
|
205 |
+
)
|
206 |
+
|
207 |
+
if self.rand_erase:
|
208 |
+
erase_transform = RandomErasing(
|
209 |
+
args.reprob,
|
210 |
+
mode=args.remode,
|
211 |
+
max_count=args.recount,
|
212 |
+
num_splits=args.recount,
|
213 |
+
device="cpu",
|
214 |
+
)
|
215 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
216 |
+
buffer = erase_transform(buffer)
|
217 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
218 |
+
|
219 |
+
return buffer
|
220 |
+
|
221 |
+
|
222 |
+
def loadvideo_decord(self, sample, sample_rate_scale=1):
|
223 |
+
"""Load video content using Decord"""
|
224 |
+
fname = sample
|
225 |
+
|
226 |
+
if not (os.path.exists(fname)):
|
227 |
+
return []
|
228 |
+
|
229 |
+
# avoid hanging issue
|
230 |
+
if os.path.getsize(fname) < 1 * 1024:
|
231 |
+
print('SKIP: ', fname, " - ", os.path.getsize(fname))
|
232 |
+
return []
|
233 |
+
try:
|
234 |
+
if self.keep_aspect_ratio:
|
235 |
+
vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
|
236 |
+
else:
|
237 |
+
vr = VideoReader(fname, width=self.new_width, height=self.new_height,
|
238 |
+
num_threads=1, ctx=cpu(0))
|
239 |
+
except Exception:
|
240 |
+
print("video cannot be loaded by decord: ", fname)
|
241 |
+
return []
|
242 |
+
|
243 |
+
if self.mode == 'test':
|
244 |
+
all_index = [x for x in range(0, len(vr), self.frame_sample_rate)]
|
245 |
+
while len(all_index) < self.clip_len:
|
246 |
+
all_index.append(all_index[-1])
|
247 |
+
vr.seek(0)
|
248 |
+
buffer = vr.get_batch(all_index).asnumpy()
|
249 |
+
return buffer
|
250 |
+
|
251 |
+
# handle temporal segments
|
252 |
+
converted_len = int(self.clip_len * self.frame_sample_rate)
|
253 |
+
seg_len = len(vr) // self.num_segment
|
254 |
+
|
255 |
+
all_index = []
|
256 |
+
for i in range(self.num_segment):
|
257 |
+
if seg_len <= converted_len:
|
258 |
+
index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate)
|
259 |
+
index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len))
|
260 |
+
index = np.clip(index, 0, seg_len - 1).astype(np.int64)
|
261 |
+
else:
|
262 |
+
end_idx = np.random.randint(converted_len, seg_len)
|
263 |
+
str_idx = end_idx - converted_len
|
264 |
+
index = np.linspace(str_idx, end_idx, num=self.clip_len)
|
265 |
+
index = np.clip(index, str_idx, end_idx - 1).astype(np.int64)
|
266 |
+
index = index + i*seg_len
|
267 |
+
all_index.extend(list(index))
|
268 |
+
|
269 |
+
all_index = all_index[::int(sample_rate_scale)]
|
270 |
+
vr.seek(0)
|
271 |
+
buffer = vr.get_batch(all_index).asnumpy()
|
272 |
+
return buffer
|
273 |
+
|
274 |
+
def __len__(self):
|
275 |
+
if self.mode != 'test':
|
276 |
+
return len(self.dataset_samples)
|
277 |
+
else:
|
278 |
+
return len(self.test_dataset)
|
279 |
+
|
280 |
+
|
281 |
+
def spatial_sampling(
|
282 |
+
frames,
|
283 |
+
spatial_idx=-1,
|
284 |
+
min_scale=256,
|
285 |
+
max_scale=320,
|
286 |
+
crop_size=224,
|
287 |
+
random_horizontal_flip=True,
|
288 |
+
inverse_uniform_sampling=False,
|
289 |
+
aspect_ratio=None,
|
290 |
+
scale=None,
|
291 |
+
motion_shift=False,
|
292 |
+
):
|
293 |
+
"""
|
294 |
+
Perform spatial sampling on the given video frames. If spatial_idx is
|
295 |
+
-1, perform random scale, random crop, and random flip on the given
|
296 |
+
frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
|
297 |
+
with the given spatial_idx.
|
298 |
+
Args:
|
299 |
+
frames (tensor): frames of images sampled from the video. The
|
300 |
+
dimension is `num frames` x `height` x `width` x `channel`.
|
301 |
+
spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
|
302 |
+
or 2, perform left, center, right crop if width is larger than
|
303 |
+
height, and perform top, center, bottom crop if height is larger
|
304 |
+
than width.
|
305 |
+
min_scale (int): the minimal size of scaling.
|
306 |
+
max_scale (int): the maximal size of scaling.
|
307 |
+
crop_size (int): the size of height and width used to crop the
|
308 |
+
frames.
|
309 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
310 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
311 |
+
scale. If False, take a uniform sample from [min_scale,
|
312 |
+
max_scale].
|
313 |
+
aspect_ratio (list): Aspect ratio range for resizing.
|
314 |
+
scale (list): Scale range for resizing.
|
315 |
+
motion_shift (bool): Whether to apply motion shift for resizing.
|
316 |
+
Returns:
|
317 |
+
frames (tensor): spatially sampled frames.
|
318 |
+
"""
|
319 |
+
assert spatial_idx in [-1, 0, 1, 2]
|
320 |
+
if spatial_idx == -1:
|
321 |
+
if aspect_ratio is None and scale is None:
|
322 |
+
frames, _ = video_transforms.random_short_side_scale_jitter(
|
323 |
+
images=frames,
|
324 |
+
min_size=min_scale,
|
325 |
+
max_size=max_scale,
|
326 |
+
inverse_uniform_sampling=inverse_uniform_sampling,
|
327 |
+
)
|
328 |
+
frames, _ = video_transforms.random_crop(frames, crop_size)
|
329 |
+
else:
|
330 |
+
transform_func = (
|
331 |
+
video_transforms.random_resized_crop_with_shift
|
332 |
+
if motion_shift
|
333 |
+
else video_transforms.random_resized_crop
|
334 |
+
)
|
335 |
+
frames = transform_func(
|
336 |
+
images=frames,
|
337 |
+
target_height=crop_size,
|
338 |
+
target_width=crop_size,
|
339 |
+
scale=scale,
|
340 |
+
ratio=aspect_ratio,
|
341 |
+
)
|
342 |
+
if random_horizontal_flip:
|
343 |
+
frames, _ = video_transforms.horizontal_flip(0.5, frames)
|
344 |
+
else:
|
345 |
+
# The testing is deterministic and no jitter should be performed.
|
346 |
+
# min_scale, max_scale, and crop_size are expected to be the same.
|
347 |
+
assert len({min_scale, max_scale, crop_size}) == 1
|
348 |
+
frames, _ = video_transforms.random_short_side_scale_jitter(
|
349 |
+
frames, min_scale, max_scale
|
350 |
+
)
|
351 |
+
frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
|
352 |
+
return frames
|
353 |
+
|
354 |
+
|
355 |
+
def tensor_normalize(tensor, mean, std):
|
356 |
+
"""
|
357 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
358 |
+
Args:
|
359 |
+
tensor (tensor): tensor to normalize.
|
360 |
+
mean (tensor or list): mean value to subtract.
|
361 |
+
std (tensor or list): std to divide.
|
362 |
+
"""
|
363 |
+
if tensor.dtype == torch.uint8:
|
364 |
+
tensor = tensor.float()
|
365 |
+
tensor = tensor / 255.0
|
366 |
+
if type(mean) == list:
|
367 |
+
mean = torch.tensor(mean)
|
368 |
+
if type(std) == list:
|
369 |
+
std = torch.tensor(std)
|
370 |
+
tensor = tensor - mean
|
371 |
+
tensor = tensor / std
|
372 |
+
return tensor
|
373 |
+
|
374 |
+
|
375 |
+
class VideoMAE(torch.utils.data.Dataset):
|
376 |
+
"""Load your own video classification dataset.
|
377 |
+
Parameters
|
378 |
+
----------
|
379 |
+
root : str, required.
|
380 |
+
Path to the root folder storing the dataset.
|
381 |
+
setting : str, required.
|
382 |
+
A text file describing the dataset, each line per video sample.
|
383 |
+
There are three items in each line: (1) video path; (2) video length and (3) video label.
|
384 |
+
train : bool, default True.
|
385 |
+
Whether to load the training or validation set.
|
386 |
+
test_mode : bool, default False.
|
387 |
+
Whether to perform evaluation on the test set.
|
388 |
+
Usually a three-crop or ten-crop evaluation strategy is involved.
|
389 |
+
name_pattern : str, default None.
|
390 |
+
The naming pattern of the decoded video frames.
|
391 |
+
For example, img_00012.jpg.
|
392 |
+
video_ext : str, default 'mp4'.
|
393 |
+
If video_loader is set to True, please specify the video format accordingly.
|
394 |
+
is_color : bool, default True.
|
395 |
+
Whether the loaded image is color or grayscale.
|
396 |
+
modality : str, default 'rgb'.
|
397 |
+
Input modalities, we support only rgb video frames for now.
|
398 |
+
Will add support for rgb difference image and optical flow image later.
|
399 |
+
num_segments : int, default 1.
|
400 |
+
Number of segments to evenly divide the video into clips.
|
401 |
+
A useful technique to obtain global video-level information.
|
402 |
+
Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
|
403 |
+
num_crop : int, default 1.
|
404 |
+
Number of crops for each image. default is 1.
|
405 |
+
Common choices are three crops and ten crops during evaluation.
|
406 |
+
new_length : int, default 1.
|
407 |
+
The length of input video clip. Default is a single image, but it can be multiple video frames.
|
408 |
+
For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
|
409 |
+
new_step : int, default 1.
|
410 |
+
Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
|
411 |
+
new_step=2 means we will extract a video clip of every other frame.
|
412 |
+
temporal_jitter : bool, default False.
|
413 |
+
Whether to temporally jitter if new_step > 1.
|
414 |
+
video_loader : bool, default False.
|
415 |
+
Whether to use video loader to load data.
|
416 |
+
use_decord : bool, default True.
|
417 |
+
Whether to use Decord video loader to load data. Otherwise use mmcv video loader.
|
418 |
+
transform : function, default None.
|
419 |
+
A function that takes data and label and transforms them.
|
420 |
+
data_aug : str, default 'v1'.
|
421 |
+
Different types of automatic data augmentation. Supports v1, v2, v3 and v4.
|
422 |
+
lazy_init : bool, default False.
|
423 |
+
If set to True, build a dataset instance without loading any dataset.
|
424 |
+
"""
|
425 |
+
def __init__(self,
|
426 |
+
root,
|
427 |
+
setting,
|
428 |
+
train=True,
|
429 |
+
test_mode=False,
|
430 |
+
name_pattern='img_%05d.jpg',
|
431 |
+
video_ext='mp4',
|
432 |
+
is_color=True,
|
433 |
+
modality='rgb',
|
434 |
+
num_segments=1,
|
435 |
+
num_crop=1,
|
436 |
+
new_length=1,
|
437 |
+
new_step=1,
|
438 |
+
transform=None,
|
439 |
+
temporal_jitter=False,
|
440 |
+
video_loader=False,
|
441 |
+
use_decord=False,
|
442 |
+
lazy_init=False):
|
443 |
+
|
444 |
+
super(VideoMAE, self).__init__()
|
445 |
+
self.root = root
|
446 |
+
self.setting = setting
|
447 |
+
self.train = train
|
448 |
+
self.test_mode = test_mode
|
449 |
+
self.is_color = is_color
|
450 |
+
self.modality = modality
|
451 |
+
self.num_segments = num_segments
|
452 |
+
self.num_crop = num_crop
|
453 |
+
self.new_length = new_length
|
454 |
+
self.new_step = new_step
|
455 |
+
self.skip_length = self.new_length * self.new_step
|
456 |
+
self.temporal_jitter = temporal_jitter
|
457 |
+
self.name_pattern = name_pattern
|
458 |
+
self.video_loader = video_loader
|
459 |
+
self.video_ext = video_ext
|
460 |
+
self.use_decord = use_decord
|
461 |
+
self.transform = transform
|
462 |
+
self.lazy_init = lazy_init
|
463 |
+
|
464 |
+
|
465 |
+
if not self.lazy_init:
|
466 |
+
self.clips = self._make_dataset(root, setting)
|
467 |
+
if len(self.clips) == 0:
|
468 |
+
raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
|
469 |
+
"Check your data directory (opt.data-dir)."))
|
470 |
+
|
471 |
+
def __getitem__(self, index):
|
472 |
+
try:
|
473 |
+
directory, target = self.clips[index]
|
474 |
+
if self.video_loader:
|
475 |
+
if '.' in directory.split('/')[-1]:
|
476 |
+
# data in the "setting" file already have extension, e.g., demo.mp4
|
477 |
+
video_name = directory
|
478 |
+
else:
|
479 |
+
# data in the "setting" file do not have extension, e.g., demo
|
480 |
+
# So we need to provide extension (i.e., .mp4) to complete the file name.
|
481 |
+
video_name = '{}.{}'.format(directory, self.video_ext)
|
482 |
+
|
483 |
+
decord_vr = decord.VideoReader(video_name, num_threads=1)
|
484 |
+
duration = len(decord_vr)
|
485 |
+
|
486 |
+
segment_indices, skip_offsets = self._sample_train_indices(duration)
|
487 |
+
|
488 |
+
images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
|
489 |
+
|
490 |
+
process_data, mask = self.transform((images, None)) # T*C,H,W
|
491 |
+
process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
|
492 |
+
return (process_data, mask)
|
493 |
+
except Exception as error:
|
494 |
+
print(error, " failed to load: ", video_name)
|
495 |
+
return self[(index+1) % len(self)]
|
496 |
+
|
497 |
+
|
498 |
+
def __len__(self):
|
499 |
+
return len(self.clips)
|
500 |
+
|
501 |
+
def _make_dataset(self, directory, setting):
|
502 |
+
if not os.path.exists(setting):
|
503 |
+
raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
|
504 |
+
clips = []
|
505 |
+
with open(setting) as split_f:
|
506 |
+
data = split_f.readlines()
|
507 |
+
for line in data:
|
508 |
+
line_info = line.split(' ')
|
509 |
+
# line format: video_path, video_duration, video_label
|
510 |
+
if len(line_info) < 2:
|
511 |
+
raise(RuntimeError('Video input format is not correct, missing one or more elements. %s' % line))
|
512 |
+
clip_path = os.path.join(line_info[0])
|
513 |
+
target = int(line_info[1])
|
514 |
+
item = (clip_path, target)
|
515 |
+
clips.append(item)
|
516 |
+
return clips
|
517 |
+
|
518 |
+
def _sample_train_indices(self, num_frames):
|
519 |
+
average_duration = (num_frames - self.skip_length + 1) // self.num_segments
|
520 |
+
if average_duration > 0:
|
521 |
+
offsets = np.multiply(list(range(self.num_segments)),
|
522 |
+
average_duration)
|
523 |
+
offsets = offsets + np.random.randint(average_duration,
|
524 |
+
size=self.num_segments)
|
525 |
+
elif num_frames > max(self.num_segments, self.skip_length):
|
526 |
+
offsets = np.sort(np.random.randint(
|
527 |
+
num_frames - self.skip_length + 1,
|
528 |
+
size=self.num_segments))
|
529 |
+
else:
|
530 |
+
offsets = np.zeros((self.num_segments,))
|
531 |
+
|
532 |
+
if self.temporal_jitter:
|
533 |
+
skip_offsets = np.random.randint(
|
534 |
+
self.new_step, size=self.skip_length // self.new_step)
|
535 |
+
else:
|
536 |
+
skip_offsets = np.zeros(
|
537 |
+
self.skip_length // self.new_step, dtype=int)
|
538 |
+
return offsets + 1, skip_offsets
|
539 |
+
|
540 |
+
|
541 |
+
def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
|
542 |
+
sampled_list = []
|
543 |
+
frame_id_list = []
|
544 |
+
for seg_ind in indices:
|
545 |
+
offset = int(seg_ind)
|
546 |
+
for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
|
547 |
+
if offset + skip_offsets[i] <= duration:
|
548 |
+
frame_id = offset + skip_offsets[i] - 1
|
549 |
+
else:
|
550 |
+
frame_id = offset - 1
|
551 |
+
frame_id_list.append(frame_id)
|
552 |
+
if offset + self.new_step < duration:
|
553 |
+
offset += self.new_step
|
554 |
+
try:
|
555 |
+
video_data = video_reader.get_batch(frame_id_list).asnumpy()
|
556 |
+
sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
|
557 |
+
except Exception:
|
558 |
+
raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration))
|
559 |
+
return sampled_list
|
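For reference, the core of `VideoClsDataset._aug_frame` is the `tensor_normalize` + `spatial_sampling` pair defined above. A minimal sketch on a random clip (it assumes this repo is on `PYTHONPATH` so that `video_transforms` resolves; the shapes and parameter values mirror the training path and are illustrative):

```python
# Illustrative only: the augmentation core used inside VideoClsDataset._aug_frame.
import torch
from kinetics import tensor_normalize, spatial_sampling

frames = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)  # T, H, W, C

frames = tensor_normalize(frames, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
frames = frames.permute(3, 0, 1, 2)  # T H W C -> C T H W

frames = spatial_sampling(
    frames,
    spatial_idx=-1,          # random scale / crop / flip, as in training
    min_scale=256,
    max_scale=320,
    crop_size=224,
    random_horizontal_flip=True,
    aspect_ratio=[0.75, 1.3333],
    scale=[0.25, 1.0],
)
print(frames.shape)  # expected torch.Size([3, 16, 224, 224])
```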
masking_generator.py
ADDED
@@ -0,0 +1,185 @@
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import ast
|
4 |
+
|
5 |
+
class TubeMaskingGenerator:
|
6 |
+
def __init__(self, input_size, mask_ratio):
|
7 |
+
self.frames, self.height, self.width = input_size
|
8 |
+
self.num_patches_per_frame = self.height * self.width
|
9 |
+
self.total_patches = self.frames * self.num_patches_per_frame
|
10 |
+
self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
|
11 |
+
self.total_masks = self.frames * self.num_masks_per_frame
|
12 |
+
|
13 |
+
def __repr__(self):
|
14 |
+
repr_str = "Maks: total patches {}, mask patches {}".format(
|
15 |
+
self.total_patches, self.total_masks
|
16 |
+
)
|
17 |
+
return repr_str
|
18 |
+
|
19 |
+
def __call__(self):
|
20 |
+
mask_per_frame = np.hstack([
|
21 |
+
np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
|
22 |
+
np.ones(self.num_masks_per_frame),
|
23 |
+
])
|
24 |
+
np.random.shuffle(mask_per_frame)
|
25 |
+
mask = np.tile(mask_per_frame, (self.frames,1)).flatten()
|
26 |
+
return mask
|
27 |
+
|
28 |
+
|
29 |
+
class TubeletMaskingGenerator:
|
30 |
+
def __init__(self, input_size, mask_ratio, visible_frames, mask_type="tube", traj_unmask_ratio=0.1):
|
31 |
+
self.tube_masking_generator = TubeMaskingGenerator(input_size, mask_ratio)
|
32 |
+
self.frames, self.height, self.width = input_size
|
33 |
+
self.num_patches_per_frame = self.height * self.width
|
34 |
+
self.total_patches = self.frames * self.num_patches_per_frame
|
35 |
+
self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
|
36 |
+
self.total_masks = self.frames * self.num_masks_per_frame
|
37 |
+
self.patch_size = 16
|
38 |
+
self.traj_unmask_ratio = traj_unmask_ratio
|
39 |
+
if visible_frames is not None:
|
40 |
+
visible_list = ast.literal_eval(visible_frames)
|
41 |
+
self.visible_frames = [int(element) for element in visible_list]
|
42 |
+
else:
|
43 |
+
self.visible_frames = None
|
44 |
+
|
45 |
+
self.mask_type = mask_type
|
46 |
+
|
47 |
+
def _balance_num_masks(self, combined_mask,
|
48 |
+
unmasked_object_patches_index,
|
49 |
+
unmasked_non_object_patches_index,
|
50 |
+
masked_object_patches_index,
|
51 |
+
tube_masked_index=None,
|
52 |
+
tube_unmasked_index=None):
|
53 |
+
current_masks = np.sum(combined_mask)
|
54 |
+
num_diff = np.abs(self.total_masks - current_masks)
|
55 |
+
|
56 |
+
if tube_masked_index is None or tube_unmasked_index is None:
|
57 |
+
# tubelet masking without tube mask
|
58 |
+
# if too many masked patches, we unmask some patches
|
59 |
+
if current_masks > self.total_masks:
|
60 |
+
picked_index = masked_object_patches_index[np.random.choice(masked_object_patches_index.size, size=int(num_diff), replace=False)]
|
61 |
+
combined_mask[picked_index] = 0.
|
62 |
+
# if too few masked patches, we first try to mask non-object patches, if not enough, then we mask protected patches
|
63 |
+
elif current_masks < self.total_masks:
|
64 |
+
if num_diff <= len(unmasked_non_object_patches_index):
|
65 |
+
picked_index = unmasked_non_object_patches_index[np.random.choice(unmasked_non_object_patches_index.size, size=int(num_diff), replace=False)]
|
66 |
+
combined_mask[picked_index] = 1.
|
67 |
+
else:
|
68 |
+
combined_mask[unmasked_non_object_patches_index] = 1.
|
69 |
+
picked_index = unmasked_object_patches_index[np.random.choice(unmasked_object_patches_index.size, size=int(num_diff - len(unmasked_non_object_patches_index)), replace=False)]
|
70 |
+
combined_mask[picked_index] = 1.
|
71 |
+
else:
|
72 |
+
# if too many masked patches, we first try to unmask tube masked patches, if not enough, then we unmask object patches
|
73 |
+
tube_masked_non_object_index = np.array(list(set(tube_masked_index) - set(masked_object_patches_index) - set(unmasked_object_patches_index)))
|
74 |
+
if current_masks > self.total_masks:
|
75 |
+
if num_diff <= len(tube_masked_non_object_index):
|
76 |
+
picked_index = tube_masked_non_object_index[np.random.choice(tube_masked_non_object_index.size, size=int(num_diff), replace=False)]
|
77 |
+
combined_mask[picked_index] = 0.
|
78 |
+
else:
|
79 |
+
combined_mask[tube_masked_non_object_index] = 0.
|
80 |
+
picked_index = masked_object_patches_index[np.random.choice(masked_object_patches_index.size, size=int(num_diff - len(tube_masked_non_object_index)), replace=False)]
|
81 |
+
combined_mask[picked_index] = 0.
|
82 |
+
# if too few masked patches, we first try to mask non-object patches, if not enough, then we mask protected patches
|
83 |
+
elif current_masks < self.total_masks:
|
84 |
+
tube_unmasked_non_object_index = np.array(list(set(tube_unmasked_index) - set(masked_object_patches_index) - set(unmasked_object_patches_index)))
|
85 |
+
if num_diff <= len(tube_unmasked_non_object_index):
|
86 |
+
picked_index = tube_unmasked_non_object_index[np.random.choice(tube_unmasked_non_object_index.size, size=int(num_diff), replace=False)]
|
87 |
+
combined_mask[picked_index] = 1.
|
88 |
+
else:
|
89 |
+
combined_mask[tube_unmasked_non_object_index] = 1.
|
90 |
+
picked_index = unmasked_object_patches_index[np.random.choice(unmasked_object_patches_index.size, size=int(num_diff - len(tube_unmasked_non_object_index)), replace=False)]
|
91 |
+
combined_mask[picked_index] = 1.
|
92 |
+
|
93 |
+
balanced_mask = combined_mask
|
94 |
+
return balanced_mask
|
95 |
+
|
96 |
+
def __repr__(self):
|
97 |
+
repr_str = "Maks: total patches {}, mask patches {}".format(
|
98 |
+
self.total_patches, self.total_masks
|
99 |
+
)
|
100 |
+
return repr_str
|
101 |
+
|
102 |
+
# 1 in mask array means masked, 0 means unmasked
|
103 |
+
def __call__(self, traj_rois):
|
104 |
+
# generate original VideoMAE tube mask and initialize the tube mask index
|
105 |
+
tube_mask = self.tube_masking_generator()
|
106 |
+
tube_masked_index = None
|
107 |
+
tube_unmasked_index = None
|
108 |
+
|
109 |
+
# initialize mask
|
110 |
+
num_tubelet, num_frame, box = traj_rois.shape
|
111 |
+
assert num_frame % 2 == 0 and self.frames == (num_frame // 2)
|
112 |
+
combined_mask = np.zeros((num_frame // 2, self.height, self.width))
|
113 |
+
# assume patch size is (2, 16, 16) so mask shape should be (8, 14, 14)
|
114 |
+
# we combine the traj_rois of two consecutive frames to one large traj_rois
|
115 |
+
|
116 |
+
# pick one tubelet that is not masked
|
117 |
+
if self.visible_frames is None:
|
118 |
+
picked_frame = np.random.randint(0, (num_frame // 2))
|
119 |
+
picked_list = [picked_frame]
|
120 |
+
else:
|
121 |
+
picked_list = self.visible_frames
|
122 |
+
|
123 |
+
# combined mask 1 means object patches that should be masked, 2 means object patches that should not be masked, 0 means non-object patches
|
124 |
+
for roi_idx, roi in enumerate(traj_rois):
|
125 |
+
for i in range(num_frame // 2):
|
126 |
+
min_x = min( (roi[2 * i][0], roi[2 * i + 1][0]) )
|
127 |
+
max_x = max( (roi[2 * i][2], roi[2 * i + 1][2]) )
|
128 |
+
min_y = min( (roi[2 * i][1], roi[2 * i + 1][1]) )
|
129 |
+
max_y = max( (roi[2 * i][3], roi[2 * i + 1][3]) )
|
130 |
+
|
131 |
+
patch_index_x_min = max( int(np.floor(min_x / self.patch_size)), 0)
|
132 |
+
patch_index_x_max = min( int(np.ceil(max_x / self.patch_size)) + 1, 14)
|
133 |
+
patch_index_y_min = max( int(np.floor(min_y / self.patch_size)), 0)
|
134 |
+
patch_index_y_max = min( int(np.ceil(max_y / self.patch_size)) + 1, 14)
|
135 |
+
|
136 |
+
if i in picked_list:
|
137 |
+
combined_mask[i][patch_index_y_min:patch_index_y_max, patch_index_x_min:patch_index_x_max] = 2.
|
138 |
+
else:
|
139 |
+
combined_mask[i][patch_index_y_min:patch_index_y_max, patch_index_x_min:patch_index_x_max] = 1.
|
140 |
+
|
141 |
+
combined_mask = combined_mask.flatten()
|
142 |
+
masked_object_patches_index = np.where(combined_mask == 1.)[0]
|
143 |
+
unmasked_non_object_patches_index = np.where(combined_mask == 0.)[0]
|
144 |
+
unmasked_object_patches_index = np.where(combined_mask == 2.)[0]
|
145 |
+
combined_mask[unmasked_object_patches_index] = 0.
|
146 |
+
|
147 |
+
tube_masked_index = np.where(tube_mask == 1.)[0]
|
148 |
+
tube_unmasked_index = np.where(tube_mask == 0.)[0]
|
149 |
+
|
150 |
+
# combine tubelet mask and tube mask
|
151 |
+
combined_mask = np.bitwise_or(combined_mask.astype(bool), tube_mask.astype(bool)).astype(np.float32)
|
152 |
+
|
153 |
+
if self.mask_type == "tube+picked_frame_visible":
|
154 |
+
# unmasked the protected patches
|
155 |
+
combined_mask[unmasked_object_patches_index] = 0.
|
156 |
+
|
157 |
+
elif self.mask_type == "tube+traj_mask":
|
158 |
+
# get index of unmasked traj patches
|
159 |
+
traj_unmask_ratio = self.traj_unmask_ratio
|
160 |
+
traj_patches_index = np.array(list(set(masked_object_patches_index) | set(unmasked_object_patches_index)))
|
161 |
+
unmasked_traj_patches_index = traj_patches_index[np.random.choice(traj_patches_index.size, size=int(traj_unmask_ratio * len(traj_patches_index)), replace=False)]
|
162 |
+
|
163 |
+
# mask the whole traj
|
164 |
+
combined_mask[traj_patches_index] = 1.
|
165 |
+
# unmask those selected patches
|
166 |
+
combined_mask[unmasked_traj_patches_index] = 0.
|
167 |
+
|
168 |
+
# update indexes
|
169 |
+
unmasked_object_patches_index = unmasked_traj_patches_index
|
170 |
+
masked_object_patches_index = np.array(list(set(traj_patches_index) - set(unmasked_traj_patches_index)))
|
171 |
+
|
172 |
+
|
173 |
+
# balance masked patch number
|
174 |
+
mask = self._balance_num_masks(combined_mask,
|
175 |
+
unmasked_object_patches_index,
|
176 |
+
unmasked_non_object_patches_index,
|
177 |
+
masked_object_patches_index,
|
178 |
+
tube_masked_index,
|
179 |
+
tube_unmasked_index)
|
180 |
+
|
181 |
+
|
182 |
+
assert np.sum(mask) == self.total_masks
|
183 |
+
return mask
|
184 |
+
|
185 |
+
|
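For reference, a minimal sketch of the plain tube masking defined above, assuming a 16-frame clip tokenized into an 8x14x14 grid (temporal patch size 2, spatial patch size 16) and a 90% mask ratio:

```python
# Illustrative only: generate a VideoMAE-style tube mask for an 8x14x14 token grid.
from masking_generator import TubeMaskingGenerator

gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
mask = gen()  # flat array of 8 * 14 * 14 = 1568 entries, where 1 = masked patch
print(mask.shape, int(mask.sum()))  # (1568,) 1408  -> int(0.9 * 196) = 176 masked per frame
```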
mixup.py
ADDED
@@ -0,0 +1,316 @@
1 |
+
""" Mixup and Cutmix
|
2 |
+
|
3 |
+
Papers:
|
4 |
+
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
|
5 |
+
|
6 |
+
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
|
7 |
+
|
8 |
+
Code Reference:
|
9 |
+
CutMix: https://github.com/clovaai/CutMix-PyTorch
|
10 |
+
|
11 |
+
Hacked together by / Copyright 2019, Ross Wightman
|
12 |
+
"""
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
|
16 |
+
|
17 |
+
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
|
18 |
+
x = x.long().view(-1, 1)
|
19 |
+
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
20 |
+
|
21 |
+
|
22 |
+
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
|
23 |
+
off_value = smoothing / num_classes
|
24 |
+
on_value = 1. - smoothing + off_value
|
25 |
+
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
|
26 |
+
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
|
27 |
+
return y1 * lam + y2 * (1. - lam)
|
28 |
+
|
29 |
+
|
30 |
+
def rand_bbox(img_shape, lam, margin=0., count=None):
|
31 |
+
""" Standard CutMix bounding-box
|
32 |
+
Generates a random square bbox based on lambda value. This impl includes
|
33 |
+
support for enforcing a border margin as percent of bbox dimensions.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
img_shape (tuple): Image shape as tuple
|
37 |
+
lam (float): Cutmix lambda value
|
38 |
+
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
|
39 |
+
count (int): Number of bbox to generate
|
40 |
+
"""
|
41 |
+
ratio = np.sqrt(1 - lam)
|
42 |
+
img_h, img_w = img_shape[-2:]
|
43 |
+
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
|
44 |
+
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
|
45 |
+
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
|
46 |
+
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
|
47 |
+
yl = np.clip(cy - cut_h // 2, 0, img_h)
|
48 |
+
yh = np.clip(cy + cut_h // 2, 0, img_h)
|
49 |
+
xl = np.clip(cx - cut_w // 2, 0, img_w)
|
50 |
+
xh = np.clip(cx + cut_w // 2, 0, img_w)
|
51 |
+
return yl, yh, xl, xh
|
52 |
+
|
53 |
+
|
54 |
+
def rand_bbox_minmax(img_shape, minmax, count=None):
|
55 |
+
""" Min-Max CutMix bounding-box
|
56 |
+
Inspired by Darknet cutmix impl, generates a random rectangular bbox
|
57 |
+
based on min/max percent values applied to each dimension of the input image.
|
58 |
+
|
59 |
+
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
img_shape (tuple): Image shape as tuple
|
63 |
+
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
|
64 |
+
count (int): Number of bbox to generate
|
65 |
+
"""
|
66 |
+
assert len(minmax) == 2
|
67 |
+
img_h, img_w = img_shape[-2:]
|
68 |
+
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
|
69 |
+
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
|
70 |
+
yl = np.random.randint(0, img_h - cut_h, size=count)
|
71 |
+
xl = np.random.randint(0, img_w - cut_w, size=count)
|
72 |
+
yu = yl + cut_h
|
73 |
+
xu = xl + cut_w
|
74 |
+
return yl, yu, xl, xu
|
75 |
+
|
76 |
+
|
77 |
+
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
|
78 |
+
""" Generate bbox and apply lambda correction.
|
79 |
+
"""
|
80 |
+
if ratio_minmax is not None:
|
81 |
+
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
|
82 |
+
else:
|
83 |
+
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
|
84 |
+
if correct_lam or ratio_minmax is not None:
|
85 |
+
bbox_area = (yu - yl) * (xu - xl)
|
86 |
+
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
|
87 |
+
return (yl, yu, xl, xu), lam
|
88 |
+
|
89 |
+
|
90 |
+
class Mixup:
|
91 |
+
""" Mixup/Cutmix that applies different params to each element or whole batch
|
92 |
+
|
93 |
+
Args:
|
94 |
+
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
|
95 |
+
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
|
96 |
+
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
|
97 |
+
prob (float): probability of applying mixup or cutmix per batch or element
|
98 |
+
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
|
99 |
+
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
|
100 |
+
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
|
101 |
+
label_smoothing (float): apply label smoothing to the mixed target tensor
|
102 |
+
num_classes (int): number of classes for target
|
103 |
+
"""
|
104 |
+
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
|
105 |
+
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
106 |
+
self.mixup_alpha = mixup_alpha
|
107 |
+
self.cutmix_alpha = cutmix_alpha
|
108 |
+
self.cutmix_minmax = cutmix_minmax
|
109 |
+
if self.cutmix_minmax is not None:
|
110 |
+
assert len(self.cutmix_minmax) == 2
|
111 |
+
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
|
112 |
+
self.cutmix_alpha = 1.0
|
113 |
+
self.mix_prob = prob
|
114 |
+
self.switch_prob = switch_prob
|
115 |
+
self.label_smoothing = label_smoothing
|
116 |
+
self.num_classes = num_classes
|
117 |
+
self.mode = mode
|
118 |
+
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
|
119 |
+
self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
|
120 |
+
|
121 |
+
def _params_per_elem(self, batch_size):
|
122 |
+
lam = np.ones(batch_size, dtype=np.float32)
|
123 |
+
use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool alias was removed in recent NumPy
|
124 |
+
if self.mixup_enabled:
|
125 |
+
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
|
126 |
+
use_cutmix = np.random.rand(batch_size) < self.switch_prob
|
127 |
+
lam_mix = np.where(
|
128 |
+
use_cutmix,
|
129 |
+
np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
|
130 |
+
np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
|
131 |
+
elif self.mixup_alpha > 0.:
|
132 |
+
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
|
133 |
+
elif self.cutmix_alpha > 0.:
|
134 |
+
use_cutmix = np.ones(batch_size, dtype=bool)
|
135 |
+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
|
136 |
+
else:
|
137 |
+
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
|
138 |
+
lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
|
139 |
+
return lam, use_cutmix
|
140 |
+
|
141 |
+
def _params_per_batch(self):
|
142 |
+
lam = 1.
|
143 |
+
use_cutmix = False
|
144 |
+
if self.mixup_enabled and np.random.rand() < self.mix_prob:
|
145 |
+
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
|
146 |
+
use_cutmix = np.random.rand() < self.switch_prob
|
147 |
+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
|
148 |
+
np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
149 |
+
elif self.mixup_alpha > 0.:
|
150 |
+
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
151 |
+
elif self.cutmix_alpha > 0.:
|
152 |
+
use_cutmix = True
|
153 |
+
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
|
154 |
+
else:
|
155 |
+
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
|
156 |
+
lam = float(lam_mix)
|
157 |
+
return lam, use_cutmix
|
158 |
+
|
159 |
+
def _mix_elem(self, x):
|
160 |
+
batch_size = len(x)
|
161 |
+
lam_batch, use_cutmix = self._params_per_elem(batch_size)
|
162 |
+
x_orig = x.clone() # need to keep an unmodified original for mixing source
|
163 |
+
for i in range(batch_size):
|
164 |
+
j = batch_size - i - 1
|
165 |
+
lam = lam_batch[i]
|
166 |
+
if lam != 1.:
|
167 |
+
if use_cutmix[i]:
|
168 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
169 |
+
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
170 |
+
x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
|
171 |
+
lam_batch[i] = lam
|
172 |
+
else:
|
173 |
+
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
174 |
+
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
175 |
+
|
176 |
+
def _mix_pair(self, x):
|
177 |
+
batch_size = len(x)
|
178 |
+
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
179 |
+
x_orig = x.clone() # need to keep an unmodified original for mixing source
|
180 |
+
for i in range(batch_size // 2):
|
181 |
+
j = batch_size - i - 1
|
182 |
+
lam = lam_batch[i]
|
183 |
+
if lam != 1.:
|
184 |
+
if use_cutmix[i]:
|
185 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
186 |
+
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
187 |
+
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
|
188 |
+
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
|
189 |
+
lam_batch[i] = lam
|
190 |
+
else:
|
191 |
+
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
192 |
+
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
|
193 |
+
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
194 |
+
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
195 |
+
|
196 |
+
def _mix_batch(self, x):
|
197 |
+
lam, use_cutmix = self._params_per_batch()
|
198 |
+
if lam == 1.:
|
199 |
+
return 1.
|
200 |
+
if use_cutmix:
|
201 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
202 |
+
x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
203 |
+
x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
|
204 |
+
else:
|
205 |
+
x_flipped = x.flip(0).mul_(1. - lam)
|
206 |
+
x.mul_(lam).add_(x_flipped)
|
207 |
+
return lam
|
208 |
+
|
209 |
+
def __call__(self, x, target):
|
210 |
+
assert len(x) % 2 == 0, 'Batch size should be even when using this'
|
211 |
+
if self.mode == 'elem':
|
212 |
+
lam = self._mix_elem(x)
|
213 |
+
elif self.mode == 'pair':
|
214 |
+
lam = self._mix_pair(x)
|
215 |
+
else:
|
216 |
+
lam = self._mix_batch(x)
|
217 |
+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
|
218 |
+
return x, target
|
219 |
+
|
220 |
+
|
221 |
+
class FastCollateMixup(Mixup):
|
222 |
+
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
|
223 |
+
|
224 |
+
A Mixup impl that's performed while collating the batches.
|
225 |
+
"""
|
226 |
+
|
227 |
+
def _mix_elem_collate(self, output, batch, half=False):
|
228 |
+
batch_size = len(batch)
|
229 |
+
num_elem = batch_size // 2 if half else batch_size
|
230 |
+
assert len(output) == num_elem
|
231 |
+
lam_batch, use_cutmix = self._params_per_elem(num_elem)
|
232 |
+
for i in range(num_elem):
|
233 |
+
j = batch_size - i - 1
|
234 |
+
lam = lam_batch[i]
|
235 |
+
mixed = batch[i][0]
|
236 |
+
if lam != 1.:
|
237 |
+
if use_cutmix[i]:
|
238 |
+
if not half:
|
239 |
+
mixed = mixed.copy()
|
240 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
241 |
+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
242 |
+
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
|
243 |
+
lam_batch[i] = lam
|
244 |
+
else:
|
245 |
+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
246 |
+
np.rint(mixed, out=mixed)
|
247 |
+
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
248 |
+
if half:
|
249 |
+
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
|
250 |
+
return torch.tensor(lam_batch).unsqueeze(1)
|
251 |
+
|
252 |
+
def _mix_pair_collate(self, output, batch):
|
253 |
+
batch_size = len(batch)
|
254 |
+
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
255 |
+
for i in range(batch_size // 2):
|
256 |
+
j = batch_size - i - 1
|
257 |
+
lam = lam_batch[i]
|
258 |
+
mixed_i = batch[i][0]
|
259 |
+
mixed_j = batch[j][0]
|
260 |
+
assert 0 <= lam <= 1.0
|
261 |
+
if lam < 1.:
|
262 |
+
if use_cutmix[i]:
|
263 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
264 |
+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
265 |
+
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
|
266 |
+
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
|
267 |
+
mixed_j[:, yl:yh, xl:xh] = patch_i
|
268 |
+
lam_batch[i] = lam
|
269 |
+
else:
|
270 |
+
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
|
271 |
+
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
|
272 |
+
mixed_i = mixed_temp
|
273 |
+
np.rint(mixed_j, out=mixed_j)
|
274 |
+
np.rint(mixed_i, out=mixed_i)
|
275 |
+
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
|
276 |
+
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
|
277 |
+
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
278 |
+
return torch.tensor(lam_batch).unsqueeze(1)
|
279 |
+
|
280 |
+
def _mix_batch_collate(self, output, batch):
|
281 |
+
batch_size = len(batch)
|
282 |
+
lam, use_cutmix = self._params_per_batch()
|
283 |
+
if use_cutmix:
|
284 |
+
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
285 |
+
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
286 |
+
for i in range(batch_size):
|
287 |
+
j = batch_size - i - 1
|
288 |
+
mixed = batch[i][0]
|
289 |
+
if lam != 1.:
|
290 |
+
if use_cutmix:
|
291 |
+
mixed = mixed.copy() # don't want to modify the original while iterating
|
292 |
+
mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
|
293 |
+
else:
|
294 |
+
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
295 |
+
np.rint(mixed, out=mixed)
|
296 |
+
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
297 |
+
return lam
|
298 |
+
|
299 |
+
def __call__(self, batch, _=None):
|
300 |
+
batch_size = len(batch)
|
301 |
+
assert batch_size % 2 == 0, 'Batch size should be even when using this'
|
302 |
+
half = 'half' in self.mode
|
303 |
+
if half:
|
304 |
+
batch_size //= 2
|
305 |
+
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
306 |
+
if self.mode == 'elem' or self.mode == 'half':
|
307 |
+
lam = self._mix_elem_collate(output, batch, half=half)
|
308 |
+
elif self.mode == 'pair':
|
309 |
+
lam = self._mix_pair_collate(output, batch)
|
310 |
+
else:
|
311 |
+
lam = self._mix_batch_collate(output, batch)
|
312 |
+
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
313 |
+
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
314 |
+
target = target[:batch_size]
|
315 |
+
return output, target
|
316 |
+
|
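The `Mixup` / `FastCollateMixup` classes above follow the timm implementation: `Mixup.__call__` mixes an already-collated batch and returns soft targets, while `FastCollateMixup.__call__` does the mixing inside the DataLoader collate step on uint8 numpy frames. Below is a minimal usage sketch of the batch-level path; the constructor arguments mirror the timm signature this file adapts, and the clip shape, alphas, and pairing with `SoftTargetCrossEntropy` are illustrative assumptions, not values taken from this diff.

```python
# Hypothetical usage sketch for the Mixup class added in mixup.py above.
# Hyper-parameters and tensor shapes are illustrative assumptions.
import torch
from timm.loss import SoftTargetCrossEntropy

from mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5,
                 mode='batch', label_smoothing=0.1, num_classes=400)
criterion = SoftTargetCrossEntropy()  # mixed targets are soft class distributions

videos = torch.randn(4, 3, 16, 224, 224)     # (B, C, T, H, W); B must be even
labels = torch.randint(0, 400, (4,))

videos, soft_targets = mixup_fn(videos, labels)   # soft_targets: (B, num_classes)
# logits = model(videos); loss = criterion(logits, soft_targets)
```

Because cutmix only indexes the last two dimensions (`x[..., yl:yh, xl:xh]`), the same code path works for both image and video tensors.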
modeling_finetune.py
ADDED
@@ -0,0 +1,351 @@
1 |
+
from functools import partial
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
7 |
+
from timm.models.registry import register_model
|
8 |
+
import torch.utils.checkpoint as checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
def _cfg(url='', **kwargs):
|
12 |
+
return {
|
13 |
+
'url': url,
|
14 |
+
'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
|
15 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
16 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
17 |
+
**kwargs
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
class DropPath(nn.Module):
|
22 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
23 |
+
"""
|
24 |
+
def __init__(self, drop_prob=None):
|
25 |
+
super(DropPath, self).__init__()
|
26 |
+
self.drop_prob = drop_prob
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return drop_path(x, self.drop_prob, self.training)
|
30 |
+
|
31 |
+
def extra_repr(self) -> str:
|
32 |
+
return 'p={}'.format(self.drop_prob)
|
33 |
+
|
34 |
+
|
35 |
+
class Mlp(nn.Module):
|
36 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
37 |
+
super().__init__()
|
38 |
+
out_features = out_features or in_features
|
39 |
+
hidden_features = hidden_features or in_features
|
40 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
41 |
+
self.act = act_layer()
|
42 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
43 |
+
self.drop = nn.Dropout(drop)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.fc1(x)
|
47 |
+
x = self.act(x)
|
48 |
+
# x = self.drop(x)
|
49 |
+
# commented out to match the original BERT implementation
|
50 |
+
x = self.fc2(x)
|
51 |
+
x = self.drop(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class Attention(nn.Module):
|
56 |
+
def __init__(
|
57 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
58 |
+
proj_drop=0., attn_head_dim=None):
|
59 |
+
super().__init__()
|
60 |
+
self.num_heads = num_heads
|
61 |
+
head_dim = dim // num_heads
|
62 |
+
if attn_head_dim is not None:
|
63 |
+
head_dim = attn_head_dim
|
64 |
+
all_head_dim = head_dim * self.num_heads
|
65 |
+
self.scale = qk_scale or head_dim ** -0.5
|
66 |
+
|
67 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
68 |
+
if qkv_bias:
|
69 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
70 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
71 |
+
else:
|
72 |
+
self.q_bias = None
|
73 |
+
self.v_bias = None
|
74 |
+
|
75 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
76 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
77 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
B, N, C = x.shape
|
81 |
+
qkv_bias = None
|
82 |
+
if self.q_bias is not None:
|
83 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
84 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
85 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
86 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
87 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
88 |
+
|
89 |
+
q = q * self.scale
|
90 |
+
attn = (q @ k.transpose(-2, -1))
|
91 |
+
|
92 |
+
|
93 |
+
attn = attn.softmax(dim=-1)
|
94 |
+
attn = self.attn_drop(attn)
|
95 |
+
|
96 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
97 |
+
x = self.proj(x)
|
98 |
+
x = self.proj_drop(x)
|
99 |
+
return x
|
100 |
+
|
101 |
+
|
102 |
+
class Block(nn.Module):
|
103 |
+
|
104 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
105 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
106 |
+
attn_head_dim=None):
|
107 |
+
super().__init__()
|
108 |
+
self.norm1 = norm_layer(dim)
|
109 |
+
self.attn = Attention(
|
110 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
111 |
+
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
|
112 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
113 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
114 |
+
self.norm2 = norm_layer(dim)
|
115 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
116 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
117 |
+
|
118 |
+
if init_values is not None and init_values > 0:
|
119 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
120 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
121 |
+
else:
|
122 |
+
self.gamma_1, self.gamma_2 = None, None
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
if self.gamma_1 is None:
|
126 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
127 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
128 |
+
else:
|
129 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
130 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
class PatchEmbed(nn.Module):
|
135 |
+
""" Image to Patch Embedding
|
136 |
+
"""
|
137 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
|
138 |
+
super().__init__()
|
139 |
+
img_size = to_2tuple(img_size)
|
140 |
+
patch_size = to_2tuple(patch_size)
|
141 |
+
self.tubelet_size = int(tubelet_size)
|
142 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
|
143 |
+
self.img_size = img_size
|
144 |
+
self.patch_size = patch_size
|
145 |
+
self.num_patches = num_patches
|
146 |
+
self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
|
147 |
+
kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
|
148 |
+
stride=(self.tubelet_size, patch_size[0], patch_size[1]))
|
149 |
+
|
150 |
+
def forward(self, x, **kwargs):
|
151 |
+
B, C, T, H, W = x.shape
|
152 |
+
# FIXME look at relaxing size constraints
|
153 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
154 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
155 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
156 |
+
return x
|
157 |
+
|
158 |
+
# sin-cos position encoding
|
159 |
+
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
|
160 |
+
def get_sinusoid_encoding_table(n_position, d_hid):
|
161 |
+
''' Sinusoid position encoding table '''
|
162 |
+
# TODO: make it with torch instead of numpy
|
163 |
+
def get_position_angle_vec(position):
|
164 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
165 |
+
|
166 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
167 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
168 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
169 |
+
|
170 |
+
return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0)
|
171 |
+
|
172 |
+
|
173 |
+
class VisionTransformer(nn.Module):
|
174 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
175 |
+
"""
|
176 |
+
def __init__(self,
|
177 |
+
img_size=224,
|
178 |
+
patch_size=16,
|
179 |
+
in_chans=3,
|
180 |
+
num_classes=1000,
|
181 |
+
embed_dim=768,
|
182 |
+
depth=12,
|
183 |
+
num_heads=12,
|
184 |
+
mlp_ratio=4.,
|
185 |
+
qkv_bias=False,
|
186 |
+
qk_scale=None,
|
187 |
+
fc_drop_rate=0.,
|
188 |
+
drop_rate=0.,
|
189 |
+
attn_drop_rate=0.,
|
190 |
+
drop_path_rate=0.,
|
191 |
+
norm_layer=nn.LayerNorm,
|
192 |
+
init_values=0.,
|
193 |
+
use_learnable_pos_emb=False,
|
194 |
+
init_scale=0.,
|
195 |
+
all_frames=16,
|
196 |
+
tubelet_size=2,
|
197 |
+
use_checkpoint=False,
|
198 |
+
use_mean_pooling=True,
|
199 |
+
pretrained_cfg=None,
|
200 |
+
pretrained_cfg_overlay = None
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
self.num_classes = num_classes
|
204 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
205 |
+
self.tubelet_size = tubelet_size
|
206 |
+
self.patch_embed = PatchEmbed(
|
207 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
|
208 |
+
num_patches = self.patch_embed.num_patches
|
209 |
+
self.use_checkpoint = use_checkpoint
|
210 |
+
|
211 |
+
if use_learnable_pos_emb:
|
212 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
213 |
+
else:
|
214 |
+
# fixed sine-cosine positional embeddings
|
215 |
+
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
|
216 |
+
|
217 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
218 |
+
|
219 |
+
|
220 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
221 |
+
self.blocks = nn.ModuleList([
|
222 |
+
Block(
|
223 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
224 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
225 |
+
init_values=init_values)
|
226 |
+
for i in range(depth)])
|
227 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
228 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
229 |
+
self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
|
230 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
231 |
+
|
232 |
+
if use_learnable_pos_emb:
|
233 |
+
trunc_normal_(self.pos_embed, std=.02)
|
234 |
+
|
235 |
+
trunc_normal_(self.head.weight, std=.02)
|
236 |
+
self.apply(self._init_weights)
|
237 |
+
|
238 |
+
self.head.weight.data.mul_(init_scale)
|
239 |
+
self.head.bias.data.mul_(init_scale)
|
240 |
+
|
241 |
+
def _init_weights(self, m):
|
242 |
+
if isinstance(m, nn.Linear):
|
243 |
+
trunc_normal_(m.weight, std=.02)
|
244 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
245 |
+
nn.init.constant_(m.bias, 0)
|
246 |
+
elif isinstance(m, nn.LayerNorm):
|
247 |
+
nn.init.constant_(m.bias, 0)
|
248 |
+
nn.init.constant_(m.weight, 1.0)
|
249 |
+
|
250 |
+
def get_num_layers(self):
|
251 |
+
return len(self.blocks)
|
252 |
+
|
253 |
+
@torch.jit.ignore
|
254 |
+
def no_weight_decay(self):
|
255 |
+
return {'pos_embed', 'cls_token'}
|
256 |
+
|
257 |
+
def get_classifier(self):
|
258 |
+
return self.head
|
259 |
+
|
260 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
261 |
+
self.num_classes = num_classes
|
262 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
263 |
+
|
264 |
+
def forward_features(self, x):
|
265 |
+
x = self.patch_embed(x)
|
266 |
+
B, _, _ = x.size()
|
267 |
+
|
268 |
+
if self.pos_embed is not None:
|
269 |
+
x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
|
270 |
+
x = self.pos_drop(x)
|
271 |
+
|
272 |
+
if self.use_checkpoint:
|
273 |
+
for blk in self.blocks:
|
274 |
+
x = checkpoint.checkpoint(blk, x)
|
275 |
+
else:
|
276 |
+
for blk in self.blocks:
|
277 |
+
x = blk(x)
|
278 |
+
|
279 |
+
x = self.norm(x)
|
280 |
+
if self.fc_norm is not None:
|
281 |
+
return self.fc_norm(x.mean(1))
|
282 |
+
else:
|
283 |
+
return x[:, 0]
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
x = self.forward_features(x)
|
287 |
+
x = self.head(self.fc_dropout(x))
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
@register_model
|
292 |
+
def vit_small_patch16_224(pretrained=False, **kwargs):
|
293 |
+
model = VisionTransformer(
|
294 |
+
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
295 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
296 |
+
model.default_cfg = _cfg()
|
297 |
+
return model
|
298 |
+
|
299 |
+
|
300 |
+
@register_model
|
301 |
+
def vit_base_patch16_224(pretrained=False, **kwargs):
|
302 |
+
model = VisionTransformer(
|
303 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
304 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
305 |
+
model.default_cfg = _cfg()
|
306 |
+
return model
|
307 |
+
|
308 |
+
|
309 |
+
@register_model
|
310 |
+
def vit_base_patch16_384(pretrained=False, **kwargs):
|
311 |
+
model = VisionTransformer(
|
312 |
+
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
313 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
314 |
+
model.default_cfg = _cfg()
|
315 |
+
return model
|
316 |
+
|
317 |
+
|
318 |
+
@register_model
|
319 |
+
def vit_large_patch16_224(pretrained=False, **kwargs):
|
320 |
+
model = VisionTransformer(
|
321 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
322 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
323 |
+
model.default_cfg = _cfg()
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
@register_model
|
328 |
+
def vit_large_patch16_384(pretrained=False, **kwargs):
|
329 |
+
model = VisionTransformer(
|
330 |
+
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
331 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
332 |
+
model.default_cfg = _cfg()
|
333 |
+
return model
|
334 |
+
|
335 |
+
|
336 |
+
@register_model
|
337 |
+
def vit_large_patch16_512(pretrained=False, **kwargs):
|
338 |
+
model = VisionTransformer(
|
339 |
+
img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
340 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
341 |
+
model.default_cfg = _cfg()
|
342 |
+
return model
|
343 |
+
|
344 |
+
|
345 |
+
@register_model
|
346 |
+
def vit_huge_patch16_224(pretrained=False, **kwargs):
|
347 |
+
model = VisionTransformer(
|
348 |
+
patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
349 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
350 |
+
model.default_cfg = _cfg()
|
351 |
+
return model
|
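`modeling_finetune.py` registers plain video ViT backbones with timm's model registry, so they can be built by name through `create_model`. A small smoke-test sketch follows; the clip shape, frame count, and class count are illustrative assumptions (note that importing the module re-registers names such as `vit_base_patch16_224`, overriding timm's built-in image variants in the registry).

```python
# Hypothetical smoke test for the fine-tuning ViT defined in modeling_finetune.py.
# Clip shape and class count are illustrative assumptions.
import torch
from timm.models import create_model

import modeling_finetune  # noqa: F401  (side effect: registers the vit_* entrypoints)

model = create_model(
    'vit_base_patch16_224',
    pretrained=False,
    num_classes=400,   # e.g. Kinetics-400
    all_frames=16,     # frames per clip
    tubelet_size=2,    # 2 frames per temporal patch -> 16/2 * 14 * 14 = 1568 tokens
)
model.eval()

clip = torch.randn(1, 3, 16, 224, 224)   # (B, C, T, H, W)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)   # torch.Size([1, 400])
```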
modeling_pretrain.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.utils.checkpoint as checkpoint
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
|
9 |
+
from timm.models.registry import register_model
|
10 |
+
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def trunc_normal_(tensor, mean=0., std=1.):
|
15 |
+
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
16 |
+
|
17 |
+
|
18 |
+
__all__ = [
|
19 |
+
'pretrain_videomae_small_patch16_224',
|
20 |
+
'pretrain_videomae_base_patch16_224',
|
21 |
+
'pretrain_videomae_large_patch16_224',
|
22 |
+
'pretrain_videomae_huge_patch16_224',
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
class PretrainVisionTransformerEncoder(nn.Module):
|
27 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
28 |
+
"""
|
29 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
30 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
31 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
|
32 |
+
use_learnable_pos_emb=False):
|
33 |
+
super().__init__()
|
34 |
+
self.num_classes = num_classes
|
35 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
36 |
+
self.patch_embed = PatchEmbed(
|
37 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size)
|
38 |
+
num_patches = self.patch_embed.num_patches
|
39 |
+
self.use_checkpoint = use_checkpoint
|
40 |
+
|
41 |
+
|
42 |
+
# TODO: Add the cls token
|
43 |
+
if use_learnable_pos_emb:
|
44 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
45 |
+
else:
|
46 |
+
# sine-cosine positional embeddings
|
47 |
+
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
|
48 |
+
|
49 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
50 |
+
self.blocks = nn.ModuleList([
|
51 |
+
Block(
|
52 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
53 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
54 |
+
init_values=init_values)
|
55 |
+
for i in range(depth)])
|
56 |
+
self.norm = norm_layer(embed_dim)
|
57 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
58 |
+
|
59 |
+
if use_learnable_pos_emb:
|
60 |
+
trunc_normal_(self.pos_embed, std=.02)
|
61 |
+
|
62 |
+
self.apply(self._init_weights)
|
63 |
+
|
64 |
+
|
65 |
+
def _init_weights(self, m):
|
66 |
+
if isinstance(m, nn.Linear):
|
67 |
+
nn.init.xavier_uniform_(m.weight)
|
68 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
69 |
+
nn.init.constant_(m.bias, 0)
|
70 |
+
elif isinstance(m, nn.LayerNorm):
|
71 |
+
nn.init.constant_(m.bias, 0)
|
72 |
+
nn.init.constant_(m.weight, 1.0)
|
73 |
+
|
74 |
+
def get_num_layers(self):
|
75 |
+
return len(self.blocks)
|
76 |
+
|
77 |
+
@torch.jit.ignore
|
78 |
+
def no_weight_decay(self):
|
79 |
+
return {'pos_embed', 'cls_token'}
|
80 |
+
|
81 |
+
def get_classifier(self):
|
82 |
+
return self.head
|
83 |
+
|
84 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
85 |
+
self.num_classes = num_classes
|
86 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
87 |
+
|
88 |
+
def forward_features(self, x, mask):
|
89 |
+
_, _, T, _, _ = x.shape
|
90 |
+
x = self.patch_embed(x)
|
91 |
+
|
92 |
+
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
|
93 |
+
|
94 |
+
B, _, C = x.shape
|
95 |
+
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
|
96 |
+
|
97 |
+
if self.use_checkpoint:
|
98 |
+
for blk in self.blocks:
|
99 |
+
x_vis = checkpoint.checkpoint(blk, x_vis)
|
100 |
+
else:
|
101 |
+
for blk in self.blocks:
|
102 |
+
x_vis = blk(x_vis)
|
103 |
+
|
104 |
+
x_vis = self.norm(x_vis)
|
105 |
+
return x_vis
|
106 |
+
|
107 |
+
def forward(self, x, mask):
|
108 |
+
x = self.forward_features(x, mask)
|
109 |
+
x = self.head(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
class PretrainVisionTransformerDecoder(nn.Module):
|
113 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
114 |
+
"""
|
115 |
+
def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
|
116 |
+
qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
117 |
+
norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.num_classes = num_classes
|
121 |
+
#assert num_classes == 3 * tubelet_size * patch_size ** 2
|
122 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
123 |
+
self.patch_size = patch_size
|
124 |
+
self.use_checkpoint = use_checkpoint
|
125 |
+
|
126 |
+
|
127 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
128 |
+
self.blocks = nn.ModuleList([
|
129 |
+
Block(
|
130 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
131 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
132 |
+
init_values=init_values)
|
133 |
+
for i in range(depth)])
|
134 |
+
self.norm = norm_layer(embed_dim)
|
135 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
136 |
+
|
137 |
+
self.apply(self._init_weights)
|
138 |
+
|
139 |
+
|
140 |
+
def _init_weights(self, m):
|
141 |
+
if isinstance(m, nn.Linear):
|
142 |
+
nn.init.xavier_uniform_(m.weight)
|
143 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
144 |
+
nn.init.constant_(m.bias, 0)
|
145 |
+
elif isinstance(m, nn.LayerNorm):
|
146 |
+
nn.init.constant_(m.bias, 0)
|
147 |
+
nn.init.constant_(m.weight, 1.0)
|
148 |
+
|
149 |
+
def get_num_layers(self):
|
150 |
+
return len(self.blocks)
|
151 |
+
|
152 |
+
@torch.jit.ignore
|
153 |
+
def no_weight_decay(self):
|
154 |
+
return {'pos_embed', 'cls_token'}
|
155 |
+
|
156 |
+
def get_classifier(self):
|
157 |
+
return self.head
|
158 |
+
|
159 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
160 |
+
self.num_classes = num_classes
|
161 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
162 |
+
|
163 |
+
def forward(self, x, return_token_num):
|
164 |
+
if self.use_checkpoint:
|
165 |
+
for blk in self.blocks:
|
166 |
+
x = checkpoint.checkpoint(blk, x)
|
167 |
+
else:
|
168 |
+
for blk in self.blocks:
|
169 |
+
x = blk(x)
|
170 |
+
|
171 |
+
if return_token_num > 0:
|
172 |
+
x = self.head(self.norm(x[:, -return_token_num:])) # only predict pixels for the masked tokens
|
173 |
+
else:
|
174 |
+
x = self.head(self.norm(x))
|
175 |
+
|
176 |
+
return x
|
177 |
+
|
178 |
+
class FeatureExtractor(torch.nn.Module):
|
179 |
+
def __init__(self, vit_model, input_size, patch_size):
|
180 |
+
super(FeatureExtractor, self).__init__()
|
181 |
+
self.vit_model = vit_model
|
182 |
+
self.input_size = input_size
|
183 |
+
self.patch_size = patch_size
|
184 |
+
self.spatial_resolution = input_size // patch_size
|
185 |
+
assert self.spatial_resolution * patch_size == input_size
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
if self.patch_size == 14:
|
189 |
+
features = self.vit_model.forward_features(x)[:, 5:]
|
190 |
+
bs, np, dim = features.shape
|
191 |
+
features = features.reshape(bs, self.spatial_resolution, self.spatial_resolution, dim).permute(0, 3, 1, 2)
|
192 |
+
features = F.interpolate(features, size=(14, 14), mode='bilinear')
|
193 |
+
features = features.flatten(2, -1).permute(0, 2, 1)
|
194 |
+
else:
|
195 |
+
features = self.vit_model.forward_features(x)[:, 1:]
|
196 |
+
return features
|
197 |
+
|
198 |
+
class PretrainVisionTransformer(nn.Module):
|
199 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
200 |
+
"""
|
201 |
+
def __init__(self,
|
202 |
+
img_size=224,
|
203 |
+
patch_size=16,
|
204 |
+
encoder_in_chans=3,
|
205 |
+
encoder_num_classes=0,
|
206 |
+
encoder_embed_dim=768,
|
207 |
+
encoder_depth=12,
|
208 |
+
encoder_num_heads=12,
|
209 |
+
decoder_num_classes=1536, # decoder_num_classes=768,
|
210 |
+
decoder_embed_dim=512,
|
211 |
+
decoder_depth=8,
|
212 |
+
decoder_num_heads=8,
|
213 |
+
mlp_ratio=4.,
|
214 |
+
qkv_bias=False,
|
215 |
+
qk_scale=None,
|
216 |
+
drop_rate=0.,
|
217 |
+
attn_drop_rate=0.,
|
218 |
+
drop_path_rate=0.,
|
219 |
+
norm_layer=nn.LayerNorm,
|
220 |
+
init_values=0.,
|
221 |
+
use_learnable_pos_emb=False,
|
222 |
+
use_checkpoint=False,
|
223 |
+
tubelet_size=2,
|
224 |
+
num_classes=0, # avoid the error from create_fn in timm
|
225 |
+
in_chans=0, # avoid the error from create_fn in timm
|
226 |
+
pretrained_cfg=None, # avoid the error from create_fn in timm
|
227 |
+
pretrained_cfg_overlay=None, # avoid the error from create_fn in timm
|
228 |
+
):
|
229 |
+
super().__init__()
|
230 |
+
self.encoder = PretrainVisionTransformerEncoder(
|
231 |
+
img_size=img_size,
|
232 |
+
patch_size=patch_size,
|
233 |
+
in_chans=encoder_in_chans,
|
234 |
+
num_classes=encoder_num_classes,
|
235 |
+
embed_dim=encoder_embed_dim,
|
236 |
+
depth=encoder_depth,
|
237 |
+
num_heads=encoder_num_heads,
|
238 |
+
mlp_ratio=mlp_ratio,
|
239 |
+
qkv_bias=qkv_bias,
|
240 |
+
qk_scale=qk_scale,
|
241 |
+
drop_rate=drop_rate,
|
242 |
+
attn_drop_rate=attn_drop_rate,
|
243 |
+
drop_path_rate=drop_path_rate,
|
244 |
+
norm_layer=norm_layer,
|
245 |
+
init_values=init_values,
|
246 |
+
tubelet_size=tubelet_size,
|
247 |
+
use_checkpoint=use_checkpoint,
|
248 |
+
use_learnable_pos_emb=use_learnable_pos_emb)
|
249 |
+
|
250 |
+
self.decoder = PretrainVisionTransformerDecoder(
|
251 |
+
patch_size=patch_size,
|
252 |
+
num_patches=self.encoder.patch_embed.num_patches,
|
253 |
+
num_classes=decoder_num_classes,
|
254 |
+
embed_dim=decoder_embed_dim,
|
255 |
+
depth=decoder_depth,
|
256 |
+
num_heads=decoder_num_heads,
|
257 |
+
mlp_ratio=mlp_ratio,
|
258 |
+
qkv_bias=qkv_bias,
|
259 |
+
qk_scale=qk_scale,
|
260 |
+
drop_rate=drop_rate,
|
261 |
+
attn_drop_rate=attn_drop_rate,
|
262 |
+
drop_path_rate=drop_path_rate,
|
263 |
+
norm_layer=norm_layer,
|
264 |
+
init_values=init_values,
|
265 |
+
tubelet_size=tubelet_size,
|
266 |
+
use_checkpoint=use_checkpoint)
|
267 |
+
|
268 |
+
self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
|
269 |
+
|
270 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
271 |
+
|
272 |
+
self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
|
273 |
+
|
274 |
+
trunc_normal_(self.mask_token, std=.02)
|
275 |
+
|
276 |
+
|
277 |
+
def _init_weights(self, m):
|
278 |
+
if isinstance(m, nn.Linear):
|
279 |
+
nn.init.xavier_uniform_(m.weight)
|
280 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
281 |
+
nn.init.constant_(m.bias, 0)
|
282 |
+
elif isinstance(m, nn.LayerNorm):
|
283 |
+
nn.init.constant_(m.bias, 0)
|
284 |
+
nn.init.constant_(m.weight, 1.0)
|
285 |
+
|
286 |
+
def get_num_layers(self):
|
287 |
+
return len(self.blocks)
|
288 |
+
|
289 |
+
@torch.jit.ignore
|
290 |
+
def no_weight_decay(self):
|
291 |
+
return {'pos_embed', 'cls_token', 'mask_token'}
|
292 |
+
|
293 |
+
def forward(self, x, mask):
|
294 |
+
_, _, T, _, _ = x.shape
|
295 |
+
x_encoder = self.encoder(x, mask) # [B, N_vis, C_e]
|
296 |
+
x_vis = self.encoder_to_decoder(x_encoder) # [B, N_vis, C_d]
|
297 |
+
B, N, C = x_vis.shape
|
298 |
+
# we don't unshuffle the visible tokens back into their original order,
|
299 |
+
# but shuffle the pos embedding accordingly.
|
300 |
+
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
|
301 |
+
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
|
302 |
+
pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
|
303 |
+
x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d]
|
304 |
+
x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
|
305 |
+
return x
|
306 |
+
|
307 |
+
@register_model
|
308 |
+
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
|
309 |
+
model = PretrainVisionTransformer(
|
310 |
+
img_size=224,
|
311 |
+
patch_size=16,
|
312 |
+
encoder_embed_dim=384,
|
313 |
+
encoder_depth=12,
|
314 |
+
encoder_num_heads=6,
|
315 |
+
encoder_num_classes=0,
|
316 |
+
decoder_embed_dim=192,
|
317 |
+
decoder_num_heads=3,
|
318 |
+
mlp_ratio=4,
|
319 |
+
qkv_bias=True,
|
320 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
321 |
+
**kwargs)
|
322 |
+
model.default_cfg = _cfg()
|
323 |
+
if pretrained:
|
324 |
+
checkpoint = torch.load(
|
325 |
+
kwargs["init_ckpt"], map_location="cpu"
|
326 |
+
)
|
327 |
+
model.load_state_dict(checkpoint["model"])
|
328 |
+
return model
|
329 |
+
|
330 |
+
@register_model
|
331 |
+
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
|
332 |
+
model = PretrainVisionTransformer(
|
333 |
+
img_size=224,
|
334 |
+
patch_size=16,
|
335 |
+
encoder_embed_dim=768,
|
336 |
+
encoder_depth=12,
|
337 |
+
encoder_num_heads=12,
|
338 |
+
encoder_num_classes=0,
|
339 |
+
decoder_embed_dim=384,
|
340 |
+
decoder_num_heads=6,
|
341 |
+
mlp_ratio=4,
|
342 |
+
qkv_bias=True,
|
343 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
344 |
+
**kwargs)
|
345 |
+
model.default_cfg = _cfg()
|
346 |
+
if pretrained:
|
347 |
+
checkpoint = torch.load(
|
348 |
+
kwargs["init_ckpt"], map_location="cpu"
|
349 |
+
)
|
350 |
+
model.load_state_dict(checkpoint["model"])
|
351 |
+
return model
|
352 |
+
|
353 |
+
@register_model
|
354 |
+
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
|
355 |
+
model = PretrainVisionTransformer(
|
356 |
+
img_size=224,
|
357 |
+
patch_size=16,
|
358 |
+
encoder_embed_dim=1024,
|
359 |
+
encoder_depth=24,
|
360 |
+
encoder_num_heads=16,
|
361 |
+
encoder_num_classes=0,
|
362 |
+
decoder_embed_dim=512,
|
363 |
+
decoder_num_heads=8,
|
364 |
+
mlp_ratio=4,
|
365 |
+
qkv_bias=True,
|
366 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
367 |
+
**kwargs)
|
368 |
+
model.default_cfg = _cfg()
|
369 |
+
if pretrained:
|
370 |
+
checkpoint = torch.load(
|
371 |
+
kwargs["init_ckpt"], map_location="cpu"
|
372 |
+
)
|
373 |
+
model.load_state_dict(checkpoint["model"])
|
374 |
+
return model
|
375 |
+
|
376 |
+
@register_model
|
377 |
+
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
|
378 |
+
model = PretrainVisionTransformer(
|
379 |
+
img_size=224,
|
380 |
+
patch_size=16,
|
381 |
+
encoder_embed_dim=1280,
|
382 |
+
encoder_depth=32,
|
383 |
+
encoder_num_heads=16,
|
384 |
+
encoder_num_classes=0,
|
385 |
+
decoder_num_classes=1536,
|
386 |
+
decoder_embed_dim=640,
|
387 |
+
decoder_num_heads=8,
|
388 |
+
mlp_ratio=4,
|
389 |
+
qkv_bias=True,
|
390 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
391 |
+
**kwargs)
|
392 |
+
model.default_cfg = _cfg()
|
393 |
+
if pretrained:
|
394 |
+
checkpoint = torch.load(
|
395 |
+
kwargs["init_ckpt"], map_location="cpu"
|
396 |
+
)
|
397 |
+
model.load_state_dict(checkpoint["model"])
|
398 |
+
return model
|
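`PretrainVisionTransformer` wires the encoder and decoder together for masked reconstruction: the encoder only sees the visible tokens (`x[~mask]`), and the decoder predicts pixel values for the masked positions, whose count sets `return_token_num`. A minimal forward-pass sketch under an assumed 90% masking ratio (the ratio and clip shape are illustrative, not values from this diff):

```python
# Hypothetical pretraining forward pass for the model in modeling_pretrain.py.
# The masking ratio and clip shape are illustrative assumptions.
import torch
from timm.models import create_model

import modeling_pretrain  # noqa: F401  (registers the pretrain_videomae_* entrypoints)

model = create_model('pretrain_videomae_base_patch16_224', pretrained=False)
model.eval()

B = 1
num_patches = model.encoder.patch_embed.num_patches   # 16/2 * 14 * 14 = 1568
clip = torch.randn(B, 3, 16, 224, 224)

# Boolean mask over patches: True = masked (to be reconstructed), False = visible.
num_masked = int(0.9 * num_patches)
mask = torch.zeros(B, num_patches, dtype=torch.bool)
mask[:, torch.randperm(num_patches)[:num_masked]] = True

with torch.no_grad():
    pred = model(clip, mask)
print(pred.shape)   # (B, num_masked, tubelet_size * patch_size**2 * 3) == (1, 1411, 1536)
```

The reconstruction target for the masked positions is built outside this module, in the pretraining engine.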
optim_factory.py
ADDED
@@ -0,0 +1,175 @@
1 |
+
import torch
|
2 |
+
from torch import optim as optim
|
3 |
+
|
4 |
+
from timm.optim.adafactor import Adafactor
|
5 |
+
from timm.optim.adahessian import Adahessian
|
6 |
+
from timm.optim.adamp import AdamP
|
7 |
+
from timm.optim.lookahead import Lookahead
|
8 |
+
from timm.optim.nadam import Nadam
|
9 |
+
#from timm.optim.novograd import NovoGrad
|
10 |
+
from timm.optim.nvnovograd import NvNovoGrad
|
11 |
+
from timm.optim.radam import RAdam
|
12 |
+
from timm.optim.rmsprop_tf import RMSpropTF
|
13 |
+
from timm.optim.sgdp import SGDP
|
14 |
+
|
15 |
+
import json
|
16 |
+
|
17 |
+
try:
|
18 |
+
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
19 |
+
has_apex = True
|
20 |
+
except ImportError:
|
21 |
+
has_apex = False
|
22 |
+
|
23 |
+
|
24 |
+
def get_num_layer_for_vit(var_name, num_max_layer):
|
25 |
+
if var_name in ("cls_token", "mask_token", "pos_embed"):
|
26 |
+
return 0
|
27 |
+
elif var_name.startswith("patch_embed"):
|
28 |
+
return 0
|
29 |
+
elif var_name.startswith("rel_pos_bias"):
|
30 |
+
return num_max_layer - 1
|
31 |
+
elif var_name.startswith("blocks"):
|
32 |
+
layer_id = int(var_name.split('.')[1])
|
33 |
+
return layer_id + 1
|
34 |
+
else:
|
35 |
+
return num_max_layer - 1
|
36 |
+
|
37 |
+
|
38 |
+
class LayerDecayValueAssigner(object):
|
39 |
+
def __init__(self, values):
|
40 |
+
self.values = values
|
41 |
+
|
42 |
+
def get_scale(self, layer_id):
|
43 |
+
return self.values[layer_id]
|
44 |
+
|
45 |
+
def get_layer_id(self, var_name):
|
46 |
+
return get_num_layer_for_vit(var_name, len(self.values))
|
47 |
+
|
48 |
+
|
49 |
+
def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
|
50 |
+
parameter_group_names = {}
|
51 |
+
parameter_group_vars = {}
|
52 |
+
|
53 |
+
for name, param in model.named_parameters():
|
54 |
+
if not param.requires_grad:
|
55 |
+
continue # frozen weights
|
56 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
|
57 |
+
group_name = "no_decay"
|
58 |
+
this_weight_decay = 0.
|
59 |
+
else:
|
60 |
+
group_name = "decay"
|
61 |
+
this_weight_decay = weight_decay
|
62 |
+
if get_num_layer is not None:
|
63 |
+
layer_id = get_num_layer(name)
|
64 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
65 |
+
else:
|
66 |
+
layer_id = None
|
67 |
+
|
68 |
+
if group_name not in parameter_group_names:
|
69 |
+
if get_layer_scale is not None:
|
70 |
+
scale = get_layer_scale(layer_id)
|
71 |
+
else:
|
72 |
+
scale = 1.
|
73 |
+
|
74 |
+
parameter_group_names[group_name] = {
|
75 |
+
"weight_decay": this_weight_decay,
|
76 |
+
"params": [],
|
77 |
+
"lr_scale": scale
|
78 |
+
}
|
79 |
+
parameter_group_vars[group_name] = {
|
80 |
+
"weight_decay": this_weight_decay,
|
81 |
+
"params": [],
|
82 |
+
"lr_scale": scale
|
83 |
+
}
|
84 |
+
|
85 |
+
parameter_group_vars[group_name]["params"].append(param)
|
86 |
+
parameter_group_names[group_name]["params"].append(name)
|
87 |
+
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
88 |
+
return list(parameter_group_vars.values())
|
89 |
+
|
90 |
+
|
91 |
+
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
|
92 |
+
opt_lower = args.opt.lower()
|
93 |
+
weight_decay = args.weight_decay
|
94 |
+
if weight_decay and filter_bias_and_bn:
|
95 |
+
skip = {}
|
96 |
+
if skip_list is not None:
|
97 |
+
skip = skip_list
|
98 |
+
elif hasattr(model, 'no_weight_decay'):
|
99 |
+
skip = model.no_weight_decay()
|
100 |
+
parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
|
101 |
+
weight_decay = 0.
|
102 |
+
else:
|
103 |
+
parameters = model.parameters()
|
104 |
+
|
105 |
+
if 'fused' in opt_lower:
|
106 |
+
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
107 |
+
|
108 |
+
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
|
109 |
+
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
|
110 |
+
opt_args['eps'] = args.opt_eps
|
111 |
+
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
|
112 |
+
opt_args['betas'] = args.opt_betas
|
113 |
+
|
114 |
+
print("optimizer settings:", opt_args)
|
115 |
+
|
116 |
+
opt_split = opt_lower.split('_')
|
117 |
+
opt_lower = opt_split[-1]
|
118 |
+
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
119 |
+
opt_args.pop('eps', None)
|
120 |
+
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
121 |
+
elif opt_lower == 'momentum':
|
122 |
+
opt_args.pop('eps', None)
|
123 |
+
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
124 |
+
elif opt_lower == 'adam':
|
125 |
+
optimizer = optim.Adam(parameters, **opt_args)
|
126 |
+
elif opt_lower == 'adamw':
|
127 |
+
optimizer = optim.AdamW(parameters, **opt_args)
|
128 |
+
elif opt_lower == 'nadam':
|
129 |
+
optimizer = Nadam(parameters, **opt_args)
|
130 |
+
elif opt_lower == 'radam':
|
131 |
+
optimizer = RAdam(parameters, **opt_args)
|
132 |
+
elif opt_lower == 'adamp':
|
133 |
+
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
|
134 |
+
elif opt_lower == 'sgdp':
|
135 |
+
optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
136 |
+
elif opt_lower == 'adadelta':
|
137 |
+
optimizer = optim.Adadelta(parameters, **opt_args)
|
138 |
+
elif opt_lower == 'adafactor':
|
139 |
+
if not args.lr:
|
140 |
+
opt_args['lr'] = None
|
141 |
+
optimizer = Adafactor(parameters, **opt_args)
|
142 |
+
elif opt_lower == 'adahessian':
|
143 |
+
optimizer = Adahessian(parameters, **opt_args)
|
144 |
+
elif opt_lower == 'rmsprop':
|
145 |
+
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
146 |
+
elif opt_lower == 'rmsproptf':
|
147 |
+
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
|
148 |
+
elif opt_lower == 'novograd':
|
149 |
+
raise ValueError("novograd is unavailable: the timm NovoGrad import above is commented out")
|
150 |
+
elif opt_lower == 'nvnovograd':
|
151 |
+
optimizer = NvNovoGrad(parameters, **opt_args)
|
152 |
+
elif opt_lower == 'fusedsgd':
|
153 |
+
opt_args.pop('eps', None)
|
154 |
+
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
155 |
+
elif opt_lower == 'fusedmomentum':
|
156 |
+
opt_args.pop('eps', None)
|
157 |
+
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
158 |
+
elif opt_lower == 'fusedadam':
|
159 |
+
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
160 |
+
elif opt_lower == 'fusedadamw':
|
161 |
+
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
|
162 |
+
elif opt_lower == 'fusedlamb':
|
163 |
+
optimizer = FusedLAMB(parameters, **opt_args)
|
164 |
+
elif opt_lower == 'fusednovograd':
|
165 |
+
opt_args.setdefault('betas', (0.95, 0.98))
|
166 |
+
optimizer = FusedNovoGrad(parameters, **opt_args)
|
167 |
+
else:
|
168 |
+
assert False, "Invalid optimizer"
|
169 |
+
raise ValueError
|
170 |
+
|
171 |
+
if len(opt_split) > 1:
|
172 |
+
if opt_split[0] == 'lookahead':
|
173 |
+
optimizer = Lookahead(optimizer)
|
174 |
+
|
175 |
+
return optimizer
|
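`create_optimizer` builds timm-style optimizers from an argparse namespace and, when `get_num_layer` / `get_layer_scale` are supplied, attaches a per-group `lr_scale` for layer-wise learning-rate decay (the scale itself is applied later by the training loop). A sketch with a hypothetical namespace standing in for the real command-line args; the decay factor 0.75 and the small ViT backbone are illustrative assumptions.

```python
# Hypothetical use of create_optimizer with layer-wise lr decay.
# The SimpleNamespace below stands in for the fine-tuning script's argparse args.
from types import SimpleNamespace

from timm.models import create_model

import modeling_finetune  # noqa: F401
from optim_factory import create_optimizer, LayerDecayValueAssigner

model = create_model('vit_small_patch16_224', num_classes=400)

args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05,
                       opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)

num_layers = model.get_num_layers()
# One decay value per layer id: 0 = embeddings, 1..num_layers = blocks, last = head.
assigner = LayerDecayValueAssigner(
    [0.75 ** (num_layers + 1 - i) for i in range(num_layers + 2)])

optimizer = create_optimizer(
    args, model,
    get_num_layer=assigner.get_layer_id,
    get_layer_scale=assigner.get_scale,
    skip_list=model.no_weight_decay())   # keep pos_embed / cls_token out of weight decay
```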
rand_augment.py
ADDED
@@ -0,0 +1,531 @@
1 |
+
"""
|
2 |
+
This implementation is based on
|
3 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
|
4 |
+
published under an Apache License 2.0.
|
5 |
+
|
6 |
+
COMMENT FROM ORIGINAL:
|
7 |
+
AutoAugment, RandAugment, and AugMix for PyTorch
|
8 |
+
This code implements the searched ImageNet policies with various tweaks and
|
9 |
+
improvements and does not include any of the search code. AA and RA
|
10 |
+
Implementation adapted from:
|
11 |
+
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
12 |
+
AugMix adapted from:
|
13 |
+
https://github.com/google-research/augmix
|
14 |
+
Papers:
|
15 |
+
AutoAugment: Learning Augmentation Policies from Data
|
16 |
+
https://arxiv.org/abs/1805.09501
|
17 |
+
Learning Data Augmentation Strategies for Object Detection
|
18 |
+
https://arxiv.org/abs/1906.11172
|
19 |
+
RandAugment: Practical automated data augmentation...
|
20 |
+
https://arxiv.org/abs/1909.13719
|
21 |
+
AugMix: A Simple Data Processing Method to Improve Robustness and
|
22 |
+
Uncertainty https://arxiv.org/abs/1912.02781
|
23 |
+
|
24 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
25 |
+
"""
|
26 |
+
|
27 |
+
import math
|
28 |
+
import numpy as np
|
29 |
+
import random
|
30 |
+
import re
|
31 |
+
import PIL
|
32 |
+
from PIL import Image, ImageEnhance, ImageOps
|
33 |
+
|
34 |
+
_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
|
35 |
+
|
36 |
+
_FILL = (128, 128, 128)
|
37 |
+
|
38 |
+
# This signifies the max integer that the controller RNN could predict for the
|
39 |
+
# augmentation scheme.
|
40 |
+
_MAX_LEVEL = 10.0
|
41 |
+
|
42 |
+
_HPARAMS_DEFAULT = {
|
43 |
+
"translate_const": 250,
|
44 |
+
"img_mean": _FILL,
|
45 |
+
}
|
46 |
+
|
47 |
+
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
48 |
+
|
49 |
+
|
50 |
+
def _interpolation(kwargs):
|
51 |
+
interpolation = kwargs.pop("resample", Image.BILINEAR)
|
52 |
+
if isinstance(interpolation, (list, tuple)):
|
53 |
+
return random.choice(interpolation)
|
54 |
+
else:
|
55 |
+
return interpolation
|
56 |
+
|
57 |
+
|
58 |
+
def _check_args_tf(kwargs):
|
59 |
+
if "fillcolor" in kwargs and _PIL_VER < (5, 0):
|
60 |
+
kwargs.pop("fillcolor")
|
61 |
+
kwargs["resample"] = _interpolation(kwargs)
|
62 |
+
|
63 |
+
|
64 |
+
def shear_x(img, factor, **kwargs):
|
65 |
+
_check_args_tf(kwargs)
|
66 |
+
return img.transform(
|
67 |
+
img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def shear_y(img, factor, **kwargs):
|
72 |
+
_check_args_tf(kwargs)
|
73 |
+
return img.transform(
|
74 |
+
img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def translate_x_rel(img, pct, **kwargs):
|
79 |
+
pixels = pct * img.size[0]
|
80 |
+
_check_args_tf(kwargs)
|
81 |
+
return img.transform(
|
82 |
+
img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def translate_y_rel(img, pct, **kwargs):
|
87 |
+
pixels = pct * img.size[1]
|
88 |
+
_check_args_tf(kwargs)
|
89 |
+
return img.transform(
|
90 |
+
img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
def translate_x_abs(img, pixels, **kwargs):
|
95 |
+
_check_args_tf(kwargs)
|
96 |
+
return img.transform(
|
97 |
+
img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
def translate_y_abs(img, pixels, **kwargs):
|
102 |
+
_check_args_tf(kwargs)
|
103 |
+
return img.transform(
|
104 |
+
img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def rotate(img, degrees, **kwargs):
|
109 |
+
_check_args_tf(kwargs)
|
110 |
+
if _PIL_VER >= (5, 2):
|
111 |
+
return img.rotate(degrees, **kwargs)
|
112 |
+
elif _PIL_VER >= (5, 0):
|
113 |
+
w, h = img.size
|
114 |
+
post_trans = (0, 0)
|
115 |
+
rotn_center = (w / 2.0, h / 2.0)
|
116 |
+
angle = -math.radians(degrees)
|
117 |
+
matrix = [
|
118 |
+
round(math.cos(angle), 15),
|
119 |
+
round(math.sin(angle), 15),
|
120 |
+
0.0,
|
121 |
+
round(-math.sin(angle), 15),
|
122 |
+
round(math.cos(angle), 15),
|
123 |
+
0.0,
|
124 |
+
]
|
125 |
+
|
126 |
+
def transform(x, y, matrix):
|
127 |
+
(a, b, c, d, e, f) = matrix
|
128 |
+
return a * x + b * y + c, d * x + e * y + f
|
129 |
+
|
130 |
+
matrix[2], matrix[5] = transform(
|
131 |
+
-rotn_center[0] - post_trans[0],
|
132 |
+
-rotn_center[1] - post_trans[1],
|
133 |
+
matrix,
|
134 |
+
)
|
135 |
+
matrix[2] += rotn_center[0]
|
136 |
+
matrix[5] += rotn_center[1]
|
137 |
+
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
138 |
+
else:
|
139 |
+
return img.rotate(degrees, resample=kwargs["resample"])
|
140 |
+
|
141 |
+
|
142 |
+
def auto_contrast(img, **__):
|
143 |
+
return ImageOps.autocontrast(img)
|
144 |
+
|
145 |
+
|
146 |
+
def invert(img, **__):
|
147 |
+
return ImageOps.invert(img)
|
148 |
+
|
149 |
+
|
150 |
+
def equalize(img, **__):
|
151 |
+
return ImageOps.equalize(img)
|
152 |
+
|
153 |
+
|
154 |
+
def solarize(img, thresh, **__):
|
155 |
+
return ImageOps.solarize(img, thresh)
|
156 |
+
|
157 |
+
|
158 |
+
def solarize_add(img, add, thresh=128, **__):
|
159 |
+
lut = []
|
160 |
+
for i in range(256):
|
161 |
+
if i < thresh:
|
162 |
+
lut.append(min(255, i + add))
|
163 |
+
else:
|
164 |
+
lut.append(i)
|
165 |
+
if img.mode in ("L", "RGB"):
|
166 |
+
if img.mode == "RGB" and len(lut) == 256:
|
167 |
+
lut = lut + lut + lut
|
168 |
+
return img.point(lut)
|
169 |
+
else:
|
170 |
+
return img
|
171 |
+
|
172 |
+
|
173 |
+
def posterize(img, bits_to_keep, **__):
|
174 |
+
if bits_to_keep >= 8:
|
175 |
+
return img
|
176 |
+
return ImageOps.posterize(img, bits_to_keep)
|
177 |
+
|
178 |
+
|
179 |
+
def contrast(img, factor, **__):
|
180 |
+
return ImageEnhance.Contrast(img).enhance(factor)
|
181 |
+
|
182 |
+
|
183 |
+
def color(img, factor, **__):
|
184 |
+
return ImageEnhance.Color(img).enhance(factor)
|
185 |
+
|
186 |
+
|
187 |
+
def brightness(img, factor, **__):
|
188 |
+
return ImageEnhance.Brightness(img).enhance(factor)
|
189 |
+
|
190 |
+
|
191 |
+
def sharpness(img, factor, **__):
|
192 |
+
return ImageEnhance.Sharpness(img).enhance(factor)
|
193 |
+
|
194 |
+
|
195 |
+
def _randomly_negate(v):
|
196 |
+
"""With 50% prob, negate the value"""
|
197 |
+
return -v if random.random() > 0.5 else v
|
198 |
+
|
199 |
+
|
200 |
+
def _rotate_level_to_arg(level, _hparams):
|
201 |
+
# range [-30, 30]
|
202 |
+
level = (level / _MAX_LEVEL) * 30.0
|
203 |
+
level = _randomly_negate(level)
|
204 |
+
return (level,)
|
205 |
+
|
206 |
+
|
207 |
+
def _enhance_level_to_arg(level, _hparams):
|
208 |
+
# range [0.1, 1.9]
|
209 |
+
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
210 |
+
|
211 |
+
|
212 |
+
def _enhance_increasing_level_to_arg(level, _hparams):
|
213 |
+
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
|
214 |
+
# range [0.1, 1.9]
|
215 |
+
level = (level / _MAX_LEVEL) * 0.9
|
216 |
+
level = 1.0 + _randomly_negate(level)
|
217 |
+
return (level,)
|
218 |
+
|
219 |
+
|
220 |
+
def _shear_level_to_arg(level, _hparams):
|
221 |
+
# range [-0.3, 0.3]
|
222 |
+
level = (level / _MAX_LEVEL) * 0.3
|
223 |
+
level = _randomly_negate(level)
|
224 |
+
return (level,)
|
225 |
+
|
226 |
+
|
227 |
+
def _translate_abs_level_to_arg(level, hparams):
|
228 |
+
translate_const = hparams["translate_const"]
|
229 |
+
level = (level / _MAX_LEVEL) * float(translate_const)
|
230 |
+
level = _randomly_negate(level)
|
231 |
+
return (level,)
|
232 |
+
|
233 |
+
|
234 |
+
def _translate_rel_level_to_arg(level, hparams):
|
235 |
+
# default range [-0.45, 0.45]
|
236 |
+
translate_pct = hparams.get("translate_pct", 0.45)
|
237 |
+
level = (level / _MAX_LEVEL) * translate_pct
|
238 |
+
level = _randomly_negate(level)
|
239 |
+
return (level,)
|
240 |
+
|
241 |
+
|
242 |
+
def _posterize_level_to_arg(level, _hparams):
|
243 |
+
# As per Tensorflow TPU EfficientNet impl
|
244 |
+
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
245 |
+
# intensity/severity of augmentation decreases with level
|
246 |
+
return (int((level / _MAX_LEVEL) * 4),)
|
247 |
+
|
248 |
+
|
249 |
+
def _posterize_increasing_level_to_arg(level, hparams):
|
250 |
+
# As per Tensorflow models research and UDA impl
|
251 |
+
# range [4, 0], 'keep 4 down to 0 MSB of original image',
|
252 |
+
# intensity/severity of augmentation increases with level
|
253 |
+
return (4 - _posterize_level_to_arg(level, hparams)[0],)
|
254 |
+
|
255 |
+
|
256 |
+
def _posterize_original_level_to_arg(level, _hparams):
|
257 |
+
# As per original AutoAugment paper description
|
258 |
+
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
259 |
+
# intensity/severity of augmentation decreases with level
|
260 |
+
return (int((level / _MAX_LEVEL) * 4) + 4,)
|
261 |
+
|
262 |
+
|
263 |
+
def _solarize_level_to_arg(level, _hparams):
|
264 |
+
# range [0, 256]
|
265 |
+
# intensity/severity of augmentation decreases with level
|
266 |
+
return (int((level / _MAX_LEVEL) * 256),)
|
267 |
+
|
268 |
+
|
269 |
+
def _solarize_increasing_level_to_arg(level, _hparams):
|
270 |
+
# range [0, 256]
|
271 |
+
# intensity/severity of augmentation increases with level
|
272 |
+
return (256 - _solarize_level_to_arg(level, _hparams)[0],)
|
273 |
+
|
274 |
+
|
275 |
+
def _solarize_add_level_to_arg(level, _hparams):
|
276 |
+
# range [0, 110]
|
277 |
+
return (int((level / _MAX_LEVEL) * 110),)
|
278 |
+
|
279 |
+
|
280 |
+
LEVEL_TO_ARG = {
|
281 |
+
"AutoContrast": None,
|
282 |
+
"Equalize": None,
|
283 |
+
"Invert": None,
|
284 |
+
"Rotate": _rotate_level_to_arg,
|
285 |
+
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
286 |
+
"Posterize": _posterize_level_to_arg,
|
287 |
+
"PosterizeIncreasing": _posterize_increasing_level_to_arg,
|
288 |
+
"PosterizeOriginal": _posterize_original_level_to_arg,
|
289 |
+
"Solarize": _solarize_level_to_arg,
|
290 |
+
"SolarizeIncreasing": _solarize_increasing_level_to_arg,
|
291 |
+
"SolarizeAdd": _solarize_add_level_to_arg,
|
292 |
+
"Color": _enhance_level_to_arg,
|
293 |
+
"ColorIncreasing": _enhance_increasing_level_to_arg,
|
294 |
+
"Contrast": _enhance_level_to_arg,
|
295 |
+
"ContrastIncreasing": _enhance_increasing_level_to_arg,
|
296 |
+
"Brightness": _enhance_level_to_arg,
|
297 |
+
"BrightnessIncreasing": _enhance_increasing_level_to_arg,
|
298 |
+
"Sharpness": _enhance_level_to_arg,
|
299 |
+
"SharpnessIncreasing": _enhance_increasing_level_to_arg,
|
300 |
+
"ShearX": _shear_level_to_arg,
|
301 |
+
"ShearY": _shear_level_to_arg,
|
302 |
+
"TranslateX": _translate_abs_level_to_arg,
|
303 |
+
"TranslateY": _translate_abs_level_to_arg,
|
304 |
+
"TranslateXRel": _translate_rel_level_to_arg,
|
305 |
+
"TranslateYRel": _translate_rel_level_to_arg,
|
306 |
+
}
|
307 |
+
|
308 |
+
|
309 |
+
NAME_TO_OP = {
|
310 |
+
"AutoContrast": auto_contrast,
|
311 |
+
"Equalize": equalize,
|
312 |
+
"Invert": invert,
|
313 |
+
"Rotate": rotate,
|
314 |
+
"Posterize": posterize,
|
315 |
+
"PosterizeIncreasing": posterize,
|
316 |
+
"PosterizeOriginal": posterize,
|
317 |
+
"Solarize": solarize,
|
318 |
+
"SolarizeIncreasing": solarize,
|
319 |
+
"SolarizeAdd": solarize_add,
|
320 |
+
"Color": color,
|
321 |
+
"ColorIncreasing": color,
|
322 |
+
"Contrast": contrast,
|
323 |
+
"ContrastIncreasing": contrast,
|
324 |
+
"Brightness": brightness,
|
325 |
+
"BrightnessIncreasing": brightness,
|
326 |
+
"Sharpness": sharpness,
|
327 |
+
"SharpnessIncreasing": sharpness,
|
328 |
+
"ShearX": shear_x,
|
329 |
+
"ShearY": shear_y,
|
330 |
+
"TranslateX": translate_x_abs,
|
331 |
+
"TranslateY": translate_y_abs,
|
332 |
+
"TranslateXRel": translate_x_rel,
|
333 |
+
"TranslateYRel": translate_y_rel,
|
334 |
+
}
|
335 |
+
|
336 |
+
|
337 |
+
class AugmentOp:
|
338 |
+
"""
|
339 |
+
Apply for video.
|
340 |
+
"""
|
341 |
+
|
342 |
+
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
343 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
344 |
+
self.aug_fn = NAME_TO_OP[name]
|
345 |
+
self.level_fn = LEVEL_TO_ARG[name]
|
346 |
+
self.prob = prob
|
347 |
+
self.magnitude = magnitude
|
348 |
+
self.hparams = hparams.copy()
|
349 |
+
self.kwargs = {
|
350 |
+
"fillcolor": hparams["img_mean"]
|
351 |
+
if "img_mean" in hparams
|
352 |
+
else _FILL,
|
353 |
+
"resample": hparams["interpolation"]
|
354 |
+
if "interpolation" in hparams
|
355 |
+
else _RANDOM_INTERPOLATION,
|
356 |
+
}
|
357 |
+
|
358 |
+
# If magnitude_std is > 0, we introduce some randomness
|
359 |
+
# in the usually fixed policy and sample magnitude from a normal distribution
|
360 |
+
# with mean `magnitude` and std-dev of `magnitude_std`.
|
361 |
+
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
362 |
+
self.magnitude_std = self.hparams.get("magnitude_std", 0)
|
363 |
+
|
364 |
+
def __call__(self, img_list):
|
365 |
+
if self.prob < 1.0 and random.random() > self.prob:
|
366 |
+
return img_list
|
367 |
+
magnitude = self.magnitude
|
368 |
+
if self.magnitude_std and self.magnitude_std > 0:
|
369 |
+
magnitude = random.gauss(magnitude, self.magnitude_std)
|
370 |
+
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
371 |
+
level_args = (
|
372 |
+
self.level_fn(magnitude, self.hparams)
|
373 |
+
if self.level_fn is not None
|
374 |
+
else ()
|
375 |
+
)
|
376 |
+
|
377 |
+
if isinstance(img_list, list):
|
378 |
+
return [
|
379 |
+
self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
|
380 |
+
]
|
381 |
+
else:
|
382 |
+
return self.aug_fn(img_list, *level_args, **self.kwargs)
|
383 |
+
|
384 |
+
|
385 |
+
_RAND_TRANSFORMS = [
|
386 |
+
"AutoContrast",
|
387 |
+
"Equalize",
|
388 |
+
"Invert",
|
389 |
+
"Rotate",
|
390 |
+
"Posterize",
|
391 |
+
"Solarize",
|
392 |
+
"SolarizeAdd",
|
393 |
+
"Color",
|
394 |
+
"Contrast",
|
395 |
+
"Brightness",
|
396 |
+
"Sharpness",
|
397 |
+
"ShearX",
|
398 |
+
"ShearY",
|
399 |
+
"TranslateXRel",
|
400 |
+
"TranslateYRel",
|
401 |
+
]
|
402 |
+
|
403 |
+
|
404 |
+
_RAND_INCREASING_TRANSFORMS = [
|
405 |
+
"AutoContrast",
|
406 |
+
"Equalize",
|
407 |
+
"Invert",
|
408 |
+
"Rotate",
|
409 |
+
"PosterizeIncreasing",
|
410 |
+
"SolarizeIncreasing",
|
411 |
+
"SolarizeAdd",
|
412 |
+
"ColorIncreasing",
|
413 |
+
"ContrastIncreasing",
|
414 |
+
"BrightnessIncreasing",
|
415 |
+
"SharpnessIncreasing",
|
416 |
+
"ShearX",
|
417 |
+
"ShearY",
|
418 |
+
"TranslateXRel",
|
419 |
+
"TranslateYRel",
|
420 |
+
]
|
421 |
+
|
422 |
+
|
423 |
+
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
424 |
+
# They may not result in increased performance, but could likely be tuned to so.
|
425 |
+
_RAND_CHOICE_WEIGHTS_0 = {
|
426 |
+
"Rotate": 0.3,
|
427 |
+
"ShearX": 0.2,
|
428 |
+
"ShearY": 0.2,
|
429 |
+
"TranslateXRel": 0.1,
|
430 |
+
"TranslateYRel": 0.1,
|
431 |
+
"Color": 0.025,
|
432 |
+
"Sharpness": 0.025,
|
433 |
+
"AutoContrast": 0.025,
|
434 |
+
"Solarize": 0.005,
|
435 |
+
"SolarizeAdd": 0.005,
|
436 |
+
"Contrast": 0.005,
|
437 |
+
"Brightness": 0.005,
|
438 |
+
"Equalize": 0.005,
|
439 |
+
"Posterize": 0,
|
440 |
+
"Invert": 0,
|
441 |
+
}
|
442 |
+
|
443 |
+
|
444 |
+
def _select_rand_weights(weight_idx=0, transforms=None):
|
445 |
+
transforms = transforms or _RAND_TRANSFORMS
|
446 |
+
assert weight_idx == 0 # only one set of weights currently
|
447 |
+
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
448 |
+
probs = [rand_weights[k] for k in transforms]
|
449 |
+
probs /= np.sum(probs)
|
450 |
+
return probs
|
451 |
+
|
452 |
+
|
453 |
+
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
454 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
455 |
+
transforms = transforms or _RAND_TRANSFORMS
|
456 |
+
return [
|
457 |
+
AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
|
458 |
+
for name in transforms
|
459 |
+
]
|
460 |
+
|
461 |
+
|
462 |
+
class RandAugment:
|
463 |
+
def __init__(self, ops, num_layers=2, choice_weights=None):
|
464 |
+
self.ops = ops
|
465 |
+
self.num_layers = num_layers
|
466 |
+
self.choice_weights = choice_weights
|
467 |
+
|
468 |
+
def __call__(self, img):
|
469 |
+
# no replacement when using weighted choice
|
470 |
+
ops = np.random.choice(
|
471 |
+
self.ops,
|
472 |
+
self.num_layers,
|
473 |
+
replace=self.choice_weights is None,
|
474 |
+
p=self.choice_weights,
|
475 |
+
)
|
476 |
+
for op in ops:
|
477 |
+
img = op(img)
|
478 |
+
return img
|
479 |
+
|
480 |
+
|
481 |
+
def rand_augment_transform(config_str, hparams):
|
482 |
+
"""
|
483 |
+
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
484 |
+
|
485 |
+
Create a RandAugment transform
|
486 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
487 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
488 |
+
sections, not order sepecific determine
|
489 |
+
'm' - integer magnitude of rand augment
|
490 |
+
'n' - integer num layers (number of transform ops selected per image)
|
491 |
+
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
492 |
+
'mstd' - float std deviation of magnitude noise applied
|
493 |
+
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
494 |
+
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
495 |
+
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
496 |
+
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
497 |
+
:return: A PyTorch compatible Transform
|
498 |
+
"""
|
499 |
+
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
500 |
+
num_layers = 2 # default to 2 ops per image
|
501 |
+
weight_idx = None # default to no probability weights for op choice
|
502 |
+
transforms = _RAND_TRANSFORMS
|
503 |
+
config = config_str.split("-")
|
504 |
+
assert config[0] == "rand"
|
505 |
+
config = config[1:]
|
506 |
+
for c in config:
|
507 |
+
cs = re.split(r"(\d.*)", c)
|
508 |
+
if len(cs) < 2:
|
509 |
+
continue
|
510 |
+
key, val = cs[:2]
|
511 |
+
if key == "mstd":
|
512 |
+
# noise param injected via hparams for now
|
513 |
+
hparams.setdefault("magnitude_std", float(val))
|
514 |
+
elif key == "inc":
|
515 |
+
if bool(val):
|
516 |
+
transforms = _RAND_INCREASING_TRANSFORMS
|
517 |
+
elif key == "m":
|
518 |
+
magnitude = int(val)
|
519 |
+
elif key == "n":
|
520 |
+
num_layers = int(val)
|
521 |
+
elif key == "w":
|
522 |
+
weight_idx = int(val)
|
523 |
+
else:
|
524 |
+
assert NotImplementedError
|
525 |
+
ra_ops = rand_augment_ops(
|
526 |
+
magnitude=magnitude, hparams=hparams, transforms=transforms
|
527 |
+
)
|
528 |
+
choice_weights = (
|
529 |
+
None if weight_idx is None else _select_rand_weights(weight_idx)
|
530 |
+
)
|
531 |
+
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
random_erasing.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This implementation is based on
|
3 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
|
4 |
+
pulished under an Apache License 2.0.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def _get_pixels(
|
12 |
+
per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
|
13 |
+
):
|
14 |
+
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
|
15 |
+
# paths, flip the order so normal is run on CPU if this becomes a problem
|
16 |
+
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
|
17 |
+
if per_pixel:
|
18 |
+
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
|
19 |
+
elif rand_color:
|
20 |
+
return torch.empty(
|
21 |
+
(patch_size[0], 1, 1), dtype=dtype, device=device
|
22 |
+
).normal_()
|
23 |
+
else:
|
24 |
+
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
|
25 |
+
|
26 |
+
|
27 |
+
class RandomErasing:
|
28 |
+
"""Randomly selects a rectangle region in an image and erases its pixels.
|
29 |
+
'Random Erasing Data Augmentation' by Zhong et al.
|
30 |
+
See https://arxiv.org/pdf/1708.04896.pdf
|
31 |
+
This variant of RandomErasing is intended to be applied to either a batch
|
32 |
+
or single image tensor after it has been normalized by dataset mean and std.
|
33 |
+
Args:
|
34 |
+
probability: Probability that the Random Erasing operation will be performed.
|
35 |
+
min_area: Minimum percentage of erased area wrt input image area.
|
36 |
+
max_area: Maximum percentage of erased area wrt input image area.
|
37 |
+
min_aspect: Minimum aspect ratio of erased area.
|
38 |
+
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
|
39 |
+
'const' - erase block is constant color of 0 for all channels
|
40 |
+
'rand' - erase block is same per-channel random (normal) color
|
41 |
+
'pixel' - erase block is per-pixel random (normal) color
|
42 |
+
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
|
43 |
+
per-image count is randomly chosen between 1 and this value.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
probability=0.5,
|
49 |
+
min_area=0.02,
|
50 |
+
max_area=1 / 3,
|
51 |
+
min_aspect=0.3,
|
52 |
+
max_aspect=None,
|
53 |
+
mode="const",
|
54 |
+
min_count=1,
|
55 |
+
max_count=None,
|
56 |
+
num_splits=0,
|
57 |
+
device="cuda",
|
58 |
+
cube=True,
|
59 |
+
):
|
60 |
+
self.probability = probability
|
61 |
+
self.min_area = min_area
|
62 |
+
self.max_area = max_area
|
63 |
+
max_aspect = max_aspect or 1 / min_aspect
|
64 |
+
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
65 |
+
self.min_count = min_count
|
66 |
+
self.max_count = max_count or min_count
|
67 |
+
self.num_splits = num_splits
|
68 |
+
mode = mode.lower()
|
69 |
+
self.rand_color = False
|
70 |
+
self.per_pixel = False
|
71 |
+
self.cube = cube
|
72 |
+
if mode == "rand":
|
73 |
+
self.rand_color = True # per block random normal
|
74 |
+
elif mode == "pixel":
|
75 |
+
self.per_pixel = True # per pixel random normal
|
76 |
+
else:
|
77 |
+
assert not mode or mode == "const"
|
78 |
+
self.device = device
|
79 |
+
|
80 |
+
def _erase(self, img, chan, img_h, img_w, dtype):
|
81 |
+
if random.random() > self.probability:
|
82 |
+
return
|
83 |
+
area = img_h * img_w
|
84 |
+
count = (
|
85 |
+
self.min_count
|
86 |
+
if self.min_count == self.max_count
|
87 |
+
else random.randint(self.min_count, self.max_count)
|
88 |
+
)
|
89 |
+
for _ in range(count):
|
90 |
+
for _ in range(10):
|
91 |
+
target_area = (
|
92 |
+
random.uniform(self.min_area, self.max_area) * area / count
|
93 |
+
)
|
94 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
95 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
96 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
97 |
+
if w < img_w and h < img_h:
|
98 |
+
top = random.randint(0, img_h - h)
|
99 |
+
left = random.randint(0, img_w - w)
|
100 |
+
img[:, top : top + h, left : left + w] = _get_pixels(
|
101 |
+
self.per_pixel,
|
102 |
+
self.rand_color,
|
103 |
+
(chan, h, w),
|
104 |
+
dtype=dtype,
|
105 |
+
device=self.device,
|
106 |
+
)
|
107 |
+
break
|
108 |
+
|
109 |
+
def _erase_cube(
|
110 |
+
self,
|
111 |
+
img,
|
112 |
+
batch_start,
|
113 |
+
batch_size,
|
114 |
+
chan,
|
115 |
+
img_h,
|
116 |
+
img_w,
|
117 |
+
dtype,
|
118 |
+
):
|
119 |
+
if random.random() > self.probability:
|
120 |
+
return
|
121 |
+
area = img_h * img_w
|
122 |
+
count = (
|
123 |
+
self.min_count
|
124 |
+
if self.min_count == self.max_count
|
125 |
+
else random.randint(self.min_count, self.max_count)
|
126 |
+
)
|
127 |
+
for _ in range(count):
|
128 |
+
for _ in range(100):
|
129 |
+
target_area = (
|
130 |
+
random.uniform(self.min_area, self.max_area) * area / count
|
131 |
+
)
|
132 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
133 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
134 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
135 |
+
if w < img_w and h < img_h:
|
136 |
+
top = random.randint(0, img_h - h)
|
137 |
+
left = random.randint(0, img_w - w)
|
138 |
+
for i in range(batch_start, batch_size):
|
139 |
+
img_instance = img[i]
|
140 |
+
img_instance[
|
141 |
+
:, top : top + h, left : left + w
|
142 |
+
] = _get_pixels(
|
143 |
+
self.per_pixel,
|
144 |
+
self.rand_color,
|
145 |
+
(chan, h, w),
|
146 |
+
dtype=dtype,
|
147 |
+
device=self.device,
|
148 |
+
)
|
149 |
+
break
|
150 |
+
|
151 |
+
def __call__(self, input):
|
152 |
+
if len(input.size()) == 3:
|
153 |
+
self._erase(input, *input.size(), input.dtype)
|
154 |
+
else:
|
155 |
+
batch_size, chan, img_h, img_w = input.size()
|
156 |
+
# skip first slice of batch if num_splits is set (for clean portion of samples)
|
157 |
+
batch_start = (
|
158 |
+
batch_size // self.num_splits if self.num_splits > 1 else 0
|
159 |
+
)
|
160 |
+
if self.cube:
|
161 |
+
self._erase_cube(
|
162 |
+
input,
|
163 |
+
batch_start,
|
164 |
+
batch_size,
|
165 |
+
chan,
|
166 |
+
img_h,
|
167 |
+
img_w,
|
168 |
+
input.dtype,
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
for i in range(batch_start, batch_size):
|
172 |
+
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
173 |
+
return input
|
run_class_finetuning.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import torch.backends.cudnn as cudnn
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from functools import partial
|
10 |
+
from pathlib import Path
|
11 |
+
from collections import OrderedDict
|
12 |
+
|
13 |
+
from mixup import Mixup
|
14 |
+
from timm.models import create_model
|
15 |
+
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
16 |
+
from timm.utils import ModelEma
|
17 |
+
from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
|
18 |
+
|
19 |
+
from datasets import build_dataset
|
20 |
+
from engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge, merge_mean_per_class
|
21 |
+
from utils_mae import NativeScalerWithGradNormCount as NativeScaler
|
22 |
+
from utils_mae import multiple_samples_collate
|
23 |
+
import utils_mae as utils
|
24 |
+
import modeling_finetune
|
25 |
+
|
26 |
+
|
27 |
+
def get_args():
|
28 |
+
parser = argparse.ArgumentParser('VideoMAE fine-tuning and evaluation script for video classification', add_help=False)
|
29 |
+
parser.add_argument('--batch_size', default=64, type=int)
|
30 |
+
parser.add_argument('--epochs', default=30, type=int)
|
31 |
+
parser.add_argument('--update_freq', default=1, type=int)
|
32 |
+
parser.add_argument('--save_ckpt_freq', default=100, type=int)
|
33 |
+
parser.add_argument('--val_freq', default=1, type=int)
|
34 |
+
|
35 |
+
# Model parameters
|
36 |
+
parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
|
37 |
+
help='Name of model to train')
|
38 |
+
parser.add_argument('--tubelet_size', type=int, default= 2)
|
39 |
+
parser.add_argument('--input_size', default=224, type=int,
|
40 |
+
help='videos input size')
|
41 |
+
|
42 |
+
parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT',
|
43 |
+
help='Dropout rate (default: 0.)')
|
44 |
+
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
|
45 |
+
help='Dropout rate (default: 0.)')
|
46 |
+
parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
|
47 |
+
help='Attention dropout rate (default: 0.)')
|
48 |
+
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
|
49 |
+
help='Drop path rate (default: 0.1)')
|
50 |
+
|
51 |
+
parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
|
52 |
+
parser.add_argument('--model_ema', action='store_true', default=False)
|
53 |
+
parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
|
54 |
+
parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
|
55 |
+
|
56 |
+
# Optimizer parameters
|
57 |
+
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
|
58 |
+
help='Optimizer (default: "adamw"')
|
59 |
+
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
|
60 |
+
help='Optimizer Epsilon (default: 1e-8)')
|
61 |
+
parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
|
62 |
+
help='Optimizer Betas (default: None, use opt default)')
|
63 |
+
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
64 |
+
help='Clip gradient norm (default: None, no clipping)')
|
65 |
+
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
66 |
+
help='SGD momentum (default: 0.9)')
|
67 |
+
parser.add_argument('--weight_decay', type=float, default=0.05,
|
68 |
+
help='weight decay (default: 0.05)')
|
69 |
+
parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
|
70 |
+
weight decay. We use a cosine schedule for WD and using a larger decay by
|
71 |
+
the end of training improves performance for ViTs.""")
|
72 |
+
|
73 |
+
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
|
74 |
+
help='learning rate (default: 1e-3)')
|
75 |
+
parser.add_argument('--layer_decay', type=float, default=0.75)
|
76 |
+
|
77 |
+
parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
|
78 |
+
help='warmup learning rate (default: 1e-6)')
|
79 |
+
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
80 |
+
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
|
81 |
+
|
82 |
+
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
|
83 |
+
help='epochs to warmup LR, if scheduler supports')
|
84 |
+
parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
|
85 |
+
help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
|
86 |
+
|
87 |
+
# Augmentation parameters
|
88 |
+
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
|
89 |
+
help='Color jitter factor (default: 0.4)')
|
90 |
+
parser.add_argument('--num_sample', type=int, default=2,
|
91 |
+
help='Repeated_aug (default: 2)')
|
92 |
+
parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
|
93 |
+
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m7-n4-mstd0.5-inc1)'),
|
94 |
+
parser.add_argument('--smoothing', type=float, default=0.1,
|
95 |
+
help='Label smoothing (default: 0.1)')
|
96 |
+
parser.add_argument('--train_interpolation', type=str, default='bicubic',
|
97 |
+
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
|
98 |
+
|
99 |
+
# Evaluation parameters
|
100 |
+
parser.add_argument('--crop_pct', type=float, default=None)
|
101 |
+
parser.add_argument('--short_side_size', type=int, default=224)
|
102 |
+
parser.add_argument('--test_num_segment', type=int, default=5)
|
103 |
+
parser.add_argument('--test_num_crop', type=int, default=3)
|
104 |
+
|
105 |
+
# Random Erase params
|
106 |
+
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
107 |
+
help='Random erase prob (default: 0.25)')
|
108 |
+
parser.add_argument('--remode', type=str, default='pixel',
|
109 |
+
help='Random erase mode (default: "pixel")')
|
110 |
+
parser.add_argument('--recount', type=int, default=1,
|
111 |
+
help='Random erase count (default: 1)')
|
112 |
+
parser.add_argument('--resplit', action='store_true', default=False,
|
113 |
+
help='Do not random erase first (clean) augmentation split')
|
114 |
+
|
115 |
+
# Mixup params
|
116 |
+
parser.add_argument('--mixup', type=float, default=0.8,
|
117 |
+
help='mixup alpha, mixup enabled if > 0.')
|
118 |
+
parser.add_argument('--cutmix', type=float, default=1.0,
|
119 |
+
help='cutmix alpha, cutmix enabled if > 0.')
|
120 |
+
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
121 |
+
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
122 |
+
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
123 |
+
help='Probability of performing mixup or cutmix when either/both is enabled')
|
124 |
+
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
125 |
+
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
126 |
+
parser.add_argument('--mixup_mode', type=str, default='batch',
|
127 |
+
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
128 |
+
|
129 |
+
# Finetuning params
|
130 |
+
parser.add_argument('--finetune', default='', help='finetune from checkpoint')
|
131 |
+
parser.add_argument('--model_key', default='model|module', type=str)
|
132 |
+
parser.add_argument('--model_prefix', default='', type=str)
|
133 |
+
parser.add_argument('--init_scale', default=0.001, type=float)
|
134 |
+
parser.add_argument('--use_checkpoint', action='store_true')
|
135 |
+
parser.set_defaults(use_checkpoint=False)
|
136 |
+
parser.add_argument('--use_mean_pooling', action='store_true')
|
137 |
+
parser.set_defaults(use_mean_pooling=True)
|
138 |
+
parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
|
139 |
+
|
140 |
+
# Dataset parameters
|
141 |
+
parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
|
142 |
+
help='dataset path')
|
143 |
+
parser.add_argument('--eval_data_path', default=None, type=str,
|
144 |
+
help='dataset path for evaluation')
|
145 |
+
parser.add_argument('--nb_classes', default=400, type=int,
|
146 |
+
help='number of the classification types')
|
147 |
+
parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
|
148 |
+
parser.add_argument('--num_segments', type=int, default= 1)
|
149 |
+
parser.add_argument('--num_frames', type=int, default= 16)
|
150 |
+
parser.add_argument('--sampling_rate', type=int, default= 4)
|
151 |
+
parser.add_argument('--data_set', default='Kinetics-400', choices=['Kinetics-400', 'SSV2', 'UCF101', 'HMDB51','image_folder','SSV2-Mini', 'Mini-Kinetics'],
|
152 |
+
type=str, help='dataset')
|
153 |
+
parser.add_argument('--output_dir', default='',
|
154 |
+
help='path where to save, empty for no saving')
|
155 |
+
parser.add_argument('--log_dir', default=None,
|
156 |
+
help='path where to tensorboard log')
|
157 |
+
parser.add_argument('--device', default='cuda',
|
158 |
+
help='device to use for training / testing')
|
159 |
+
parser.add_argument('--seed', default=0, type=int)
|
160 |
+
parser.add_argument('--resume', default='',
|
161 |
+
help='resume from checkpoint')
|
162 |
+
parser.add_argument('--auto_resume', action='store_true')
|
163 |
+
parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
|
164 |
+
parser.set_defaults(auto_resume=True)
|
165 |
+
|
166 |
+
parser.add_argument('--save_ckpt', action='store_true')
|
167 |
+
parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
|
168 |
+
parser.set_defaults(save_ckpt=True)
|
169 |
+
|
170 |
+
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
171 |
+
help='start epoch')
|
172 |
+
parser.add_argument('--eval', action='store_true',
|
173 |
+
help='Perform evaluation only')
|
174 |
+
parser.add_argument('--dist_eval', action='store_true', default=False,
|
175 |
+
help='Enabling distributed evaluation')
|
176 |
+
parser.add_argument('--num_workers', default=10, type=int)
|
177 |
+
parser.add_argument('--pin_mem', action='store_true',
|
178 |
+
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
179 |
+
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
|
180 |
+
parser.set_defaults(pin_mem=True)
|
181 |
+
|
182 |
+
# distributed training parameters
|
183 |
+
parser.add_argument('--world_size', default=1, type=int,
|
184 |
+
help='number of distributed processes')
|
185 |
+
parser.add_argument('--local_rank', default=-1, type=int)
|
186 |
+
parser.add_argument('--dist_on_itp', action='store_true')
|
187 |
+
parser.add_argument('--dist_url', default='env://',
|
188 |
+
help='url used to set up distributed training')
|
189 |
+
|
190 |
+
parser.add_argument('--enable_deepspeed', action='store_true', default=False)
|
191 |
+
|
192 |
+
# debug mode
|
193 |
+
parser.add_argument('--not_dist', action='store_true', default=False)
|
194 |
+
parser.add_argument('--num_outputs', default=8, type=int)
|
195 |
+
|
196 |
+
known_args, _ = parser.parse_known_args()
|
197 |
+
|
198 |
+
if known_args.enable_deepspeed:
|
199 |
+
try:
|
200 |
+
import deepspeed
|
201 |
+
from deepspeed import DeepSpeedConfig
|
202 |
+
parser = deepspeed.add_config_arguments(parser)
|
203 |
+
ds_init = deepspeed.initialize
|
204 |
+
except:
|
205 |
+
print("Please 'pip install deepspeed'")
|
206 |
+
exit(0)
|
207 |
+
else:
|
208 |
+
ds_init = None
|
209 |
+
|
210 |
+
return parser.parse_args(), ds_init
|
211 |
+
|
212 |
+
|
213 |
+
def main(args, ds_init):
|
214 |
+
if args.not_dist:
|
215 |
+
args.distributed = False
|
216 |
+
else:
|
217 |
+
utils.init_distributed_mode(args)
|
218 |
+
|
219 |
+
if ds_init is not None:
|
220 |
+
utils.create_ds_config(args)
|
221 |
+
|
222 |
+
print(args)
|
223 |
+
|
224 |
+
device = torch.device(args.device)
|
225 |
+
|
226 |
+
# fix the seed for reproducibility
|
227 |
+
seed = args.seed + utils.get_rank()
|
228 |
+
torch.manual_seed(seed)
|
229 |
+
np.random.seed(seed)
|
230 |
+
# random.seed(seed)
|
231 |
+
|
232 |
+
cudnn.benchmark = True
|
233 |
+
|
234 |
+
dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
|
235 |
+
if args.disable_eval_during_finetuning:
|
236 |
+
dataset_val = None
|
237 |
+
else:
|
238 |
+
dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args)
|
239 |
+
dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args)
|
240 |
+
|
241 |
+
|
242 |
+
num_tasks = utils.get_world_size()
|
243 |
+
global_rank = utils.get_rank()
|
244 |
+
sampler_train = torch.utils.data.DistributedSampler(
|
245 |
+
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
246 |
+
)
|
247 |
+
print("Sampler_train = %s" % str(sampler_train))
|
248 |
+
if args.dist_eval:
|
249 |
+
if len(dataset_val) % num_tasks != 0:
|
250 |
+
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
251 |
+
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
252 |
+
'equal num of samples per-process.')
|
253 |
+
sampler_val = torch.utils.data.DistributedSampler(
|
254 |
+
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
|
255 |
+
sampler_test = torch.utils.data.DistributedSampler(
|
256 |
+
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
|
257 |
+
else:
|
258 |
+
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
259 |
+
|
260 |
+
if global_rank == 0 and args.log_dir is not None:
|
261 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
262 |
+
log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
|
263 |
+
else:
|
264 |
+
log_writer = None
|
265 |
+
|
266 |
+
if args.num_sample > 1:
|
267 |
+
collate_func = partial(multiple_samples_collate, fold=False)
|
268 |
+
else:
|
269 |
+
collate_func = None
|
270 |
+
|
271 |
+
data_loader_train = torch.utils.data.DataLoader(
|
272 |
+
dataset_train, sampler=sampler_train,
|
273 |
+
batch_size=args.batch_size,
|
274 |
+
num_workers=args.num_workers,
|
275 |
+
pin_memory=args.pin_mem,
|
276 |
+
drop_last=True,
|
277 |
+
collate_fn=collate_func,
|
278 |
+
)
|
279 |
+
|
280 |
+
if dataset_val is not None:
|
281 |
+
data_loader_val = torch.utils.data.DataLoader(
|
282 |
+
dataset_val, sampler=sampler_val,
|
283 |
+
batch_size=int(1.5 * args.batch_size),
|
284 |
+
num_workers=args.num_workers,
|
285 |
+
pin_memory=args.pin_mem,
|
286 |
+
drop_last=False
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
data_loader_val = None
|
290 |
+
|
291 |
+
if dataset_test is not None:
|
292 |
+
data_loader_test = torch.utils.data.DataLoader(
|
293 |
+
dataset_test, sampler=sampler_test,
|
294 |
+
batch_size=args.batch_size,
|
295 |
+
num_workers=args.num_workers,
|
296 |
+
pin_memory=args.pin_mem,
|
297 |
+
drop_last=False
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
data_loader_test = None
|
301 |
+
|
302 |
+
mixup_fn = None
|
303 |
+
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
304 |
+
if mixup_active:
|
305 |
+
print("Mixup is activated!")
|
306 |
+
mixup_fn = Mixup(
|
307 |
+
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
308 |
+
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
309 |
+
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
310 |
+
|
311 |
+
model = create_model(
|
312 |
+
args.model,
|
313 |
+
pretrained=False,
|
314 |
+
num_classes=args.nb_classes,
|
315 |
+
all_frames=args.num_frames * args.num_segments,
|
316 |
+
tubelet_size=args.tubelet_size,
|
317 |
+
fc_drop_rate=args.fc_drop_rate,
|
318 |
+
drop_rate=args.drop,
|
319 |
+
drop_path_rate=args.drop_path,
|
320 |
+
attn_drop_rate=args.attn_drop_rate,
|
321 |
+
drop_block_rate=None,
|
322 |
+
use_checkpoint=args.use_checkpoint,
|
323 |
+
use_mean_pooling=args.use_mean_pooling,
|
324 |
+
init_scale=args.init_scale,
|
325 |
+
)
|
326 |
+
|
327 |
+
patch_size = model.patch_embed.patch_size
|
328 |
+
print("Patch size = %s" % str(patch_size))
|
329 |
+
args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
|
330 |
+
args.patch_size = patch_size
|
331 |
+
|
332 |
+
if args.finetune:
|
333 |
+
if args.finetune.startswith('https'):
|
334 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
335 |
+
args.finetune, map_location='cpu', check_hash=True)
|
336 |
+
else:
|
337 |
+
checkpoint = torch.load(args.finetune, map_location='cpu')
|
338 |
+
|
339 |
+
print("Load ckpt from %s" % args.finetune)
|
340 |
+
checkpoint_model = None
|
341 |
+
for model_key in args.model_key.split('|'):
|
342 |
+
if model_key in checkpoint:
|
343 |
+
checkpoint_model = checkpoint[model_key]
|
344 |
+
print("Load state_dict by model_key = %s" % model_key)
|
345 |
+
break
|
346 |
+
if checkpoint_model is None:
|
347 |
+
checkpoint_model = checkpoint
|
348 |
+
state_dict = model.state_dict()
|
349 |
+
for k in ['head.weight', 'head.bias']:
|
350 |
+
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
351 |
+
print(f"Removing key {k} from pretrained checkpoint")
|
352 |
+
del checkpoint_model[k]
|
353 |
+
|
354 |
+
all_keys = list(checkpoint_model.keys())
|
355 |
+
new_dict = OrderedDict()
|
356 |
+
for key in all_keys:
|
357 |
+
if key.startswith('backbone.'):
|
358 |
+
new_dict[key[9:]] = checkpoint_model[key]
|
359 |
+
elif key.startswith('encoder.'):
|
360 |
+
new_dict[key[8:]] = checkpoint_model[key]
|
361 |
+
else:
|
362 |
+
new_dict[key] = checkpoint_model[key]
|
363 |
+
checkpoint_model = new_dict
|
364 |
+
|
365 |
+
# interpolate position embedding
|
366 |
+
if 'pos_embed' in checkpoint_model:
|
367 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
368 |
+
embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
|
369 |
+
num_patches = model.patch_embed.num_patches #
|
370 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
|
371 |
+
|
372 |
+
# height (== width) for the checkpoint position embedding
|
373 |
+
orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(args.num_frames // model.patch_embed.tubelet_size)) ** 0.5)
|
374 |
+
# height (== width) for the new position embedding
|
375 |
+
new_size = int((num_patches // (args.num_frames // model.patch_embed.tubelet_size) )** 0.5)
|
376 |
+
# class_token and dist_token are kept unchanged
|
377 |
+
if orig_size != new_size:
|
378 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
379 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
380 |
+
# only the position tokens are interpolated
|
381 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
382 |
+
# B, L, C -> BT, H, W, C -> BT, C, H, W
|
383 |
+
pos_tokens = pos_tokens.reshape(-1, args.num_frames // model.patch_embed.tubelet_size, orig_size, orig_size, embedding_size)
|
384 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
385 |
+
pos_tokens = torch.nn.functional.interpolate(
|
386 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
387 |
+
# BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
|
388 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, args.num_frames // model.patch_embed.tubelet_size, new_size, new_size, embedding_size)
|
389 |
+
pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
|
390 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
391 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
392 |
+
|
393 |
+
utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
|
394 |
+
|
395 |
+
model.to(device)
|
396 |
+
|
397 |
+
model_ema = None
|
398 |
+
if args.model_ema:
|
399 |
+
model_ema = ModelEma(
|
400 |
+
model,
|
401 |
+
decay=args.model_ema_decay,
|
402 |
+
device='cpu' if args.model_ema_force_cpu else '',
|
403 |
+
resume='')
|
404 |
+
print("Using EMA with decay = %.8f" % args.model_ema_decay)
|
405 |
+
|
406 |
+
model_without_ddp = model
|
407 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
408 |
+
|
409 |
+
print("Model = %s" % str(model_without_ddp))
|
410 |
+
print('number of params:', n_parameters)
|
411 |
+
|
412 |
+
total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
|
413 |
+
num_training_steps_per_epoch = len(dataset_train) // total_batch_size
|
414 |
+
args.lr = args.lr * total_batch_size / 256
|
415 |
+
args.min_lr = args.min_lr * total_batch_size / 256
|
416 |
+
args.warmup_lr = args.warmup_lr * total_batch_size / 256
|
417 |
+
print("LR = %.8f" % args.lr)
|
418 |
+
print("Batch size = %d" % total_batch_size)
|
419 |
+
print("Update frequent = %d" % args.update_freq)
|
420 |
+
print("Number of training examples = %d" % len(dataset_train))
|
421 |
+
print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
|
422 |
+
|
423 |
+
num_layers = model_without_ddp.get_num_layers()
|
424 |
+
if args.layer_decay < 1.0:
|
425 |
+
assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
|
426 |
+
else:
|
427 |
+
assigner = None
|
428 |
+
|
429 |
+
if assigner is not None:
|
430 |
+
print("Assigned values = %s" % str(assigner.values))
|
431 |
+
|
432 |
+
skip_weight_decay_list = model.no_weight_decay()
|
433 |
+
print("Skip weight decay list: ", skip_weight_decay_list)
|
434 |
+
|
435 |
+
if args.enable_deepspeed:
|
436 |
+
loss_scaler = None
|
437 |
+
optimizer_params = get_parameter_groups(
|
438 |
+
model, args.weight_decay, skip_weight_decay_list,
|
439 |
+
assigner.get_layer_id if assigner is not None else None,
|
440 |
+
assigner.get_scale if assigner is not None else None)
|
441 |
+
model, optimizer, _, _ = ds_init(
|
442 |
+
args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
|
443 |
+
)
|
444 |
+
|
445 |
+
print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
|
446 |
+
assert model.gradient_accumulation_steps() == args.update_freq
|
447 |
+
else:
|
448 |
+
if args.distributed:
|
449 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
|
450 |
+
model_without_ddp = model.module
|
451 |
+
|
452 |
+
optimizer = create_optimizer(
|
453 |
+
args, model_without_ddp, skip_list=skip_weight_decay_list,
|
454 |
+
get_num_layer=assigner.get_layer_id if assigner is not None else None,
|
455 |
+
get_layer_scale=assigner.get_scale if assigner is not None else None)
|
456 |
+
loss_scaler = NativeScaler()
|
457 |
+
|
458 |
+
print("Use step level LR scheduler!")
|
459 |
+
lr_schedule_values = utils.cosine_scheduler(
|
460 |
+
args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
|
461 |
+
warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
|
462 |
+
)
|
463 |
+
if args.weight_decay_end is None:
|
464 |
+
args.weight_decay_end = args.weight_decay
|
465 |
+
wd_schedule_values = utils.cosine_scheduler(
|
466 |
+
args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
|
467 |
+
print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
|
468 |
+
|
469 |
+
if mixup_fn is not None:
|
470 |
+
# smoothing is handled with mixup label transform
|
471 |
+
criterion = SoftTargetCrossEntropy()
|
472 |
+
elif args.smoothing > 0.:
|
473 |
+
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
474 |
+
else:
|
475 |
+
criterion = torch.nn.CrossEntropyLoss()
|
476 |
+
|
477 |
+
print("criterion = %s" % str(criterion))
|
478 |
+
|
479 |
+
utils.auto_load_model(
|
480 |
+
args=args, model=model, model_without_ddp=model_without_ddp,
|
481 |
+
optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
|
482 |
+
|
483 |
+
if args.eval:
|
484 |
+
if not args.not_dist:
|
485 |
+
preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
|
486 |
+
test_stats = final_test(data_loader_test, model, device, preds_file)
|
487 |
+
torch.distributed.barrier()
|
488 |
+
else:
|
489 |
+
num_tasks = args.num_outputs
|
490 |
+
|
491 |
+
if global_rank == 0:
|
492 |
+
print("Start merging results...")
|
493 |
+
final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
|
494 |
+
print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
|
495 |
+
log_stats = {'Final top-1': final_top1,
|
496 |
+
'Final Top-5': final_top5}
|
497 |
+
|
498 |
+
final_top1_per_class ,final_top5_per_class = merge_mean_per_class(args.output_dir, num_tasks,args.nb_classes)
|
499 |
+
print(f"Accuracy of the network on the {len(dataset_test)} test videos: Mean-Top-1: {final_top1_per_class:.2f}%, Mean-Top-5: {final_top5_per_class:.2f}%")
|
500 |
+
log_stats["Class-Mean-Top-1"] = final_top1_per_class
|
501 |
+
log_stats["Class-Mean-Top-5"] = final_top5_per_class
|
502 |
+
|
503 |
+
if args.output_dir and utils.is_main_process():
|
504 |
+
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
505 |
+
f.write(json.dumps(log_stats) + "\n")
|
506 |
+
exit(0)
|
507 |
+
|
508 |
+
|
509 |
+
print(f"Start training for {args.epochs} epochs")
|
510 |
+
start_time = time.time()
|
511 |
+
max_accuracy = 0.0
|
512 |
+
for epoch in range(args.start_epoch, args.epochs):
|
513 |
+
if args.distributed:
|
514 |
+
data_loader_train.sampler.set_epoch(epoch)
|
515 |
+
if log_writer is not None:
|
516 |
+
log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
|
517 |
+
train_stats = train_one_epoch(
|
518 |
+
model, criterion, data_loader_train, optimizer,
|
519 |
+
device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
|
520 |
+
log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
|
521 |
+
lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
|
522 |
+
num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
|
523 |
+
)
|
524 |
+
if args.output_dir and args.save_ckpt:
|
525 |
+
if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
|
526 |
+
utils.save_model(
|
527 |
+
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
528 |
+
loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
|
529 |
+
if data_loader_val is not None and (epoch + 1) % args.val_freq == 0:
|
530 |
+
test_stats = validation_one_epoch(data_loader_val, model, device)
|
531 |
+
print(f"Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
|
532 |
+
if max_accuracy < test_stats["acc1"]:
|
533 |
+
max_accuracy = test_stats["acc1"]
|
534 |
+
if args.output_dir and args.save_ckpt:
|
535 |
+
utils.save_model(
|
536 |
+
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
537 |
+
loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
|
538 |
+
|
539 |
+
print(f'Max accuracy: {max_accuracy:.2f}%')
|
540 |
+
if log_writer is not None:
|
541 |
+
log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch)
|
542 |
+
log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch)
|
543 |
+
log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch)
|
544 |
+
|
545 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
546 |
+
**{f'val_{k}': v for k, v in test_stats.items()},
|
547 |
+
'epoch': epoch,
|
548 |
+
'n_parameters': n_parameters}
|
549 |
+
else:
|
550 |
+
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
551 |
+
'epoch': epoch,
|
552 |
+
'n_parameters': n_parameters}
|
553 |
+
if args.output_dir and utils.is_main_process():
|
554 |
+
if log_writer is not None:
|
555 |
+
log_writer.flush()
|
556 |
+
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
557 |
+
f.write(json.dumps(log_stats) + "\n")
|
558 |
+
|
559 |
+
preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
|
560 |
+
test_stats = final_test(data_loader_test, model, device, preds_file)
|
561 |
+
torch.distributed.barrier()
|
562 |
+
if global_rank == 0:
|
563 |
+
print("Start merging results...")
|
564 |
+
final_top1 ,final_top5 = merge(args.output_dir, num_tasks)
|
565 |
+
print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
|
566 |
+
log_stats = {'Final top-1': final_top1,
|
567 |
+
'Final Top-5': final_top5}
|
568 |
+
if args.output_dir and utils.is_main_process():
|
569 |
+
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
570 |
+
f.write(json.dumps(log_stats) + "\n")
|
571 |
+
|
572 |
+
|
573 |
+
total_time = time.time() - start_time
|
574 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
575 |
+
print('Training time {}'.format(total_time_str))
|
576 |
+
|
577 |
+
|
578 |
+
if __name__ == '__main__':
|
579 |
+
opts, ds_init = get_args()
|
580 |
+
if opts.output_dir:
|
581 |
+
Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
|
582 |
+
main(opts, ds_init)
|
run_mae_pretraining.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import torch.backends.cudnn as cudnn
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from pathlib import Path
|
10 |
+
from timm.models import create_model
|
11 |
+
from optim_factory import create_optimizer
|
12 |
+
from datasets import build_pretraining_dataset
|
13 |
+
from engine_for_pretraining import train_one_epoch
|
14 |
+
from utils_mae import NativeScalerWithGradNormCount as NativeScaler
|
15 |
+
import utils_mae as utils
|
16 |
+
import modeling_pretrain
|
17 |
+
from timm.models.vision_transformer import vit_small_patch16_224, vit_base_patch16_224, vit_large_patch16_224
|
18 |
+
from modeling_pretrain import FeatureExtractor
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False)
|
23 |
+
parser.add_argument('--batch_size', default=64, type=int)
|
24 |
+
parser.add_argument('--epochs', default=800, type=int)
|
25 |
+
parser.add_argument('--save_ckpt_freq', default=50, type=int)
|
26 |
+
|
27 |
+
# Model parameters
|
28 |
+
parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
|
29 |
+
help='Name of model to train')
|
30 |
+
|
31 |
+
parser.add_argument('--decoder_depth', default=4, type=int,
|
32 |
+
help='depth of decoder')
|
33 |
+
|
34 |
+
parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'tubelet'],
|
35 |
+
type=str, help='masked strategy of video tokens/patches')
|
36 |
+
|
37 |
+
parser.add_argument('--sub_mask_type', default='tube+picked_frame_visible', choices=['tube', 'tube+picked_frame_visible', 'tube+traj_mask'],
|
38 |
+
type=str, help='sub masked strategy of tubelet masking')
|
39 |
+
|
40 |
+
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='ratio of the visual tokens/patches to be masked')

    parser.add_argument('--input_size', default=224, type=int,
                        help='videos input size for backbone')

    parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')

    parser.add_argument('--normlize_target', default=True, type=bool,
                        help='normalize the target patch pixels')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
        weight decay. We use a cosine schedule for WD.
        (Set the same value as args.weight_decay to keep the weight decay unchanged.)""")

    parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
                        help='learning rate (default: 1.5e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='steps to warmup LR; overrides warmup_epochs when set > 0')
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.set_defaults(use_checkpoint=False)

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
                        help='Color jitter factor (default: 0.0)')
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')

    # Dataset parameters
    parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
                        help='dataset path')
    parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
    parser.add_argument('--num_frames', type=int, default=16)
    parser.add_argument('--sampling_rate', type=int, default=4)
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to write tensorboard logs')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # Tubelet params
    parser.add_argument('--add_tubelets', action='store_true')
    parser.set_defaults(add_tubelets=False)
    parser.add_argument('--use_objects', action='store_true')
    parser.set_defaults(use_objects=False)
    parser.add_argument('--objects_path', type=str, default=None)
    parser.add_argument('--motion_type', type=str, default='gaussian')
    parser.add_argument('--scales', type=str, default='[32, 48, 56, 64, 96, 128]')
    parser.add_argument('--visible_frames', type=str, default=None)  # not used
    parser.add_argument('--traj_unmask_ratio', type=float, default=0.1)

    # DINO / CLIP distillation params
    parser.add_argument('--target_type', default='pixel', choices=['pixel', 'dino_v1', 'clip'], type=str,
                        help='define target type for loss')
    parser.add_argument('--distillation_teacher', default="clip_b", type=str, choices=['dino_s', 'dino_b', 'clip_b'],
                        help='distillation teacher model')

    # Multiple sampling
    parser.add_argument('--multiple_sampling', action='store_true')
    # For 2nd-stage training
    parser.add_argument('--first_stage_path', type=str, default=None)

    return parser.parse_args()
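With the defaults above (16 frames, 224x224 input, 16x16 patches, 2-frame tubelets), the encoder works on an 8x14x14 token grid and `--mask_ratio` decides how many of those tokens are hidden. A minimal sketch of that arithmetic, independent of the training code (the window computation mirrors the `num_frames // 2, input_size // patch_size` formula in `main()` below):

```python
# Tube-masking arithmetic implied by the defaults above (illustrative only).
num_frames, input_size, patch_size, tubelet_size = 16, 224, 16, 2
mask_ratio = 0.75

window = (num_frames // tubelet_size, input_size // patch_size, input_size // patch_size)  # (8, 14, 14)
total_tokens = window[0] * window[1] * window[2]   # 1568 tokens per clip
num_masked = int(mask_ratio * total_tokens)        # 1176 tokens are masked
num_visible = total_tokens - num_masked            # 392 tokens reach the encoder
print(window, total_tokens, num_masked, num_visible)
```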
def get_teacher_student_models(args):
    print(f"Creating model: {args.model}")
    if args.target_type == 'pixel':
        dec_dim = 1536
    elif 'dino' in args.target_type or 'clip' in args.target_type:
        if args.distillation_teacher == 'dino_s':
            dec_dim = 384
        elif args.distillation_teacher == 'dino_b' or args.distillation_teacher == 'clip_b':
            dec_dim = 768

    student_model = create_model(
        args.model,
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        decoder_depth=args.decoder_depth,
        use_checkpoint=args.use_checkpoint,
        decoder_num_classes=dec_dim,
    )

    if args.target_type == 'dino_v1':
        # load dino
        if args.distillation_teacher == 'dino_s':
            pretraining = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
            teacher_model = vit_small_patch16_224(pretrained=False)
        elif args.distillation_teacher == 'dino_b':
            pretraining = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
            teacher_model = vit_base_patch16_224(pretrained=False)

        msg = teacher_model.load_state_dict(pretraining.state_dict(), strict=False)
        teacher_model = FeatureExtractor(teacher_model, args.input_size, 16)
        print(msg)
        teacher_model.eval()

    elif args.target_type == 'clip':
        # load clip
        from utils_viclip.config import Config
        from utils_viclip.config_utils import setup_viclip
        from tasks.shared_utils import setup_model
        from models_viclip.viclip import ViCLIP

        config = setup_viclip('configs/config.py')
        model_cls = eval(config.model.get('model_cls', 'ViCLIP'))
        teacher_model = setup_model(
            config,
            model_cls=model_cls,
            has_decoder=False,
            pretrain=False,
            find_unused_parameters=False,
        )
        teacher_model.eval()
    else:
        teacher_model = None

    return student_model, teacher_model


def load_first_stage(model, args):
    if args.first_stage_path is not None:
        checkpoint = torch.load(args.first_stage_path, map_location='cpu')
        print("loading first stage from ", args.first_stage_path)
        checkpoint_model = checkpoint['model']
        utils.load_state_dict(model, checkpoint_model)
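The decoder output width above tracks the chosen reconstruction target: 384 and 768 match the ViT-S/16 and ViT-B/16 teacher feature widths, and 1536 presumably corresponds to a flattened pixel tubelet (2 frames x 16 x 16 patch x 3 channels), which matches the `p0=2, p1=p2=16, c=3` rearrange patterns used elsewhere in the repo. A small hedged sketch of that mapping, not part of the training code:

```python
# Hedged sketch of how decoder_num_classes relates to the target type.
# The 1536-d pixel target is assumed to be a flattened 2x16x16x3 tubelet.
def decoder_output_dim(target_type: str, teacher: str) -> int:
    if target_type == 'pixel':
        return 2 * 16 * 16 * 3   # 1536 normalized pixel values per masked token
    if teacher == 'dino_s':
        return 384               # ViT-S/16 feature width
    return 768                   # ViT-B/16 (DINO or CLIP) feature width

assert decoder_output_dim('pixel', 'clip_b') == 1536
assert decoder_output_dim('clip', 'clip_b') == 768
```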
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    student_model, teacher_model = get_teacher_student_models(args)

    patch_size = student_model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])  # [8, 14, 14]
    print(f"Window Size = {args.window_size}")
    args.patch_size = patch_size

    # Start from a pretrained first-stage model
    if args.first_stage_path is not None:
        load_first_stage(student_model, args)

    # get dataset
    dataset_train = build_pretraining_dataset(args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_rank = global_rank

    total_batch_size = args.batch_size * num_tasks
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size if not args.multiple_sampling else int(args.batch_size / 2),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        worker_init_fn=utils.seed_worker
    )

    student_model.to(device)
    if teacher_model is not None:
        teacher_model.to(device)
    model_without_ddp = student_model
    n_parameters = sum(p.numel() for p in student_model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params: {} M'.format(n_parameters / 1e6))

    args.lr = args.lr * total_batch_size / 256
    args.min_lr = args.min_lr * total_batch_size / 256
    args.warmup_lr = args.warmup_lr * total_batch_size / 256
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))

    if args.distributed:
        student_model = torch.nn.parallel.DistributedDataParallel(student_model, device_ids=[args.gpu], find_unused_parameters=False)
        model_without_ddp = student_model.module

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    print("Use step level LR & WD scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    utils.auto_load_model(
        args=args, model=student_model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
    torch.cuda.empty_cache()
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            student_model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            patch_size=patch_size[0],
            normlize_target=args.normlize_target,
            teacher_model=teacher_model,
            target_type=args.target_type,
            multiple_sampling=args.multiple_sampling,
        )
        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=student_model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, 'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")
        # if (epoch + 1) % 2 == 0:
        #     exit(0)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    opts = get_args()
    if opts.output_dir:
        Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
    main(opts)
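Note that `main()` rescales the learning rates by the global batch size (the linear scaling rule), so the effective LR depends on how many processes you launch. A small worked example with illustrative numbers; the per-GPU batch size and world size here are assumptions, not values from this repo:

```python
# Illustrative only: base LR 1.5e-4 is scaled by (global batch / 256),
# mirroring the `args.lr * total_batch_size / 256` line in main().
base_lr = 1.5e-4
batch_size_per_gpu = 32        # assumed per-process batch size
num_gpus = 8                   # assumed world size
total_batch_size = batch_size_per_gpu * num_gpus   # 256
effective_lr = base_lr * total_batch_size / 256    # 1.5e-4 (unchanged at a global batch of 256)
print(total_batch_size, effective_lr)
```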
run_videomae_vis.py
ADDED
@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from pathlib import Path
from timm.models import create_model
import utils
import modeling_pretrain
# from datasets import DataAugmentationForVideoMAE  # shadowed by the local class below
from torchvision.transforms import ToPILImage
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from decord import VideoReader, cpu
from torchvision import transforms
from transforms import *
from masking_generator import TubeMaskingGenerator


class DataAugmentationForVideoMAE(object):
    def __init__(self, args):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]   # IMAGENET_DEFAULT_STD
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupCenterCrop(args.input_size)
        self.transform = transforms.Compose([
            self.train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ])
        if args.mask_type == 'tube':
            self.masked_position_generator = TubeMaskingGenerator(
                args.window_size, args.mask_ratio
            )

    def __call__(self, images):
        process_data, _ = self.transform(images)
        return process_data, self.masked_position_generator()

    def __repr__(self):
        repr = "(DataAugmentationForVideoMAE,\n"
        repr += "  transform = %s,\n" % str(self.transform)
        repr += "  Masked position generator = %s,\n" % str(self.masked_position_generator)
        repr += ")"
        return repr


def get_args():
    parser = argparse.ArgumentParser('VideoMAE visualization reconstruction script', add_help=False)
    parser.add_argument('img_path', type=str, help='input video path')
    parser.add_argument('save_path', type=str, help='save video path')
    parser.add_argument('model_path', type=str, help='checkpoint path of model')
    parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'tubelet'],
                        type=str, help='masked strategy of video tokens/patches')
    parser.add_argument('--num_frames', type=int, default=16)
    parser.add_argument('--sampling_rate', type=int, default=4)
    parser.add_argument('--decoder_depth', default=4, type=int,
                        help='depth of decoder')
    parser.add_argument('--input_size', default=224, type=int,
                        help='videos input size for backbone')
    parser.add_argument('--device', default='cuda:0',
                        help='device to use for training / testing')
    parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='ratio of the visual tokens/patches to be masked')
    # Model parameters
    parser.add_argument('--model', default='pretrain_videomae_small_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to vis')
    parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')

    # Tubelet params
    parser.add_argument('--add_tubelets', action='store_true')
    parser.set_defaults(add_tubelets=True)
    parser.add_argument('--use_objects', action='store_true')
    parser.set_defaults(use_objects=True)
    parser.add_argument('--motion_type', type=str, default='gaussian')
    parser.add_argument('--scales', type=str, default='[32, 48, 56, 64, 96, 128]')
    parser.add_argument('--loc_velocity', type=int, default=12)
    parser.add_argument('--mixed_tubelet', action='store_true')
    parser.set_defaults(mixed_tubelet=False)
    parser.add_argument('--visible_frames', type=str, default=None)

    return parser.parse_args()


def get_model(args):
    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        decoder_depth=args.decoder_depth
    )

    return model


def main(args):
    print(args)

    device = torch.device(args.device)
    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    model.to(device)
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()

    if args.save_path:
        Path(args.save_path).mkdir(parents=True, exist_ok=True)

    with open(args.img_path, 'rb') as f:
        vr = VideoReader(f, ctx=cpu(0))
    duration = len(vr)
    new_length = 1
    new_step = 1
    skip_length = new_length * new_step
    # frame_id_list = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61]

    tmp = np.arange(0, 32, 2) + 60
    frame_id_list = tmp.tolist()
    # average_duration = (duration - skip_length + 1) // args.num_frames
    # if average_duration > 0:
    #     frame_id_list = np.multiply(list(range(args.num_frames)), average_duration)
    #     frame_id_list = frame_id_list + np.random.randint(average_duration, size=args.num_frames)

    video_data = vr.get_batch(frame_id_list).asnumpy()
    print(video_data.shape)
    img = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]

    # use a name that does not shadow the torchvision `transforms` module imported above
    data_transform = DataAugmentationForVideoMAE(args)
    img, bool_masked_pos = data_transform((img, None))  # T*C,H,W
    # print(img.shape)
    img = img.view((args.num_frames, 3) + img.size()[-2:]).transpose(0, 1)  # T*C,H,W -> T,C,H,W -> C,T,H,W
    # img = img.view((-1, args.num_frames) + img.size()[-2:])
    bool_masked_pos = torch.from_numpy(bool_masked_pos)

    with torch.no_grad():
        # img = img[None, :]
        # bool_masked_pos = bool_masked_pos[None, :]
        img = img.unsqueeze(0)
        print(img.shape)
        bool_masked_pos = bool_masked_pos.unsqueeze(0)

        img = img.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
        outputs = model(img, bool_masked_pos)

        # save original video
        mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
        std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
        ori_img = img * std + mean  # in [0, 1]
        imgs = [ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)]
        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/ori_img{id}.jpg")

        img_squeeze = rearrange(ori_img, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size[0], p2=patch_size[0])
        img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
        img_patch[bool_masked_pos] = outputs

        # make mask
        mask = torch.ones_like(img_patch)
        mask[bool_masked_pos] = 0
        mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
        mask = rearrange(mask, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)

        # save reconstruction video
        rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
        # Notice: to visualize the reconstructed video, we add back the original mean and variance of each patch to the prediction.
        rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean(dim=-2, keepdim=True)
        rec_img = rearrange(rec_img, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
        imgs = [ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996)) for vid, _ in enumerate(frame_id_list)]

        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/rec_img{id}.jpg")

        # save masked video
        img_mask = rec_img * mask
        imgs = [ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)]
        for id, im in enumerate(imgs):
            im.save(f"{args.save_path}/mask_img{id}.jpg")


if __name__ == '__main__':
    opts = get_args()
    main(opts)
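The visualization above predicts pixels in a per-patch normalized space and then restores each patch's original mean and variance before saving. A small standalone sketch of that round trip, with random tensors standing in for a real video (shapes follow the rearrange patterns in the script):

```python
# Standalone sketch of the per-patch normalize / de-normalize round trip used above.
# p0=2 frames and a 16x16 patch give 2*16*16*3 = 1536 values per token.
import torch
from einops import rearrange

video = torch.rand(1, 3, 16, 224, 224)  # B, C, T, H, W in [0, 1]
patches = rearrange(video, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=16, p2=16)

mean = patches.mean(dim=-2, keepdim=True)
std = patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
normed = (patches - mean) / std      # the space in which the decoder predicts
restored = normed * std + mean       # what the script does before saving rec_img*.jpg

print(torch.allclose(restored, patches, atol=1e-5))  # True: the round trip is lossless
```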
ssv2.py
ADDED
@@ -0,0 +1,363 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
from random_erasing import RandomErasing
|
6 |
+
import warnings
|
7 |
+
from decord import VideoReader, cpu
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import video_transforms as video_transforms
|
10 |
+
import volume_transforms as volume_transforms
|
11 |
+
|
12 |
+
|
13 |
+
class SSVideoClsDataset(Dataset):
|
14 |
+
"""Load your own video classification dataset."""
|
15 |
+
|
16 |
+
def __init__(self, anno_path, data_path, mode='train', clip_len=8,
|
17 |
+
crop_size=224, short_side_size=256, new_height=256,
|
18 |
+
new_width=340, keep_aspect_ratio=True, num_segment=1,
|
19 |
+
num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
|
20 |
+
self.anno_path = anno_path
|
21 |
+
self.data_path = data_path
|
22 |
+
self.mode = mode
|
23 |
+
self.clip_len = clip_len
|
24 |
+
self.crop_size = crop_size
|
25 |
+
self.short_side_size = short_side_size
|
26 |
+
self.new_height = new_height
|
27 |
+
self.new_width = new_width
|
28 |
+
self.keep_aspect_ratio = keep_aspect_ratio
|
29 |
+
self.num_segment = num_segment
|
30 |
+
self.test_num_segment = test_num_segment
|
31 |
+
self.num_crop = num_crop
|
32 |
+
self.test_num_crop = test_num_crop
|
33 |
+
self.args = args
|
34 |
+
self.aug = False
|
35 |
+
self.rand_erase = False
|
36 |
+
if self.mode in ['train']:
|
37 |
+
self.aug = True
|
38 |
+
if self.args.reprob > 0:
|
39 |
+
self.rand_erase = True
|
40 |
+
if VideoReader is None:
|
41 |
+
raise ImportError("Unable to import `decord` which is required to read videos.")
|
42 |
+
|
43 |
+
import pandas as pd
|
44 |
+
cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
|
45 |
+
self.dataset_samples = list(cleaned.values[:, 0])
|
46 |
+
self.label_array = list(cleaned.values[:, 1])
|
47 |
+
|
48 |
+
if (mode == 'train'):
|
49 |
+
pass
|
50 |
+
|
51 |
+
elif (mode == 'validation'):
|
52 |
+
self.data_transform = video_transforms.Compose([
|
53 |
+
video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
|
54 |
+
video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
|
55 |
+
volume_transforms.ClipToTensor(),
|
56 |
+
video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
57 |
+
std=[0.229, 0.224, 0.225])
|
58 |
+
])
|
59 |
+
elif mode == 'test':
|
60 |
+
self.data_resize = video_transforms.Compose([
|
61 |
+
video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
|
62 |
+
])
|
63 |
+
self.data_transform = video_transforms.Compose([
|
64 |
+
volume_transforms.ClipToTensor(),
|
65 |
+
video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
66 |
+
std=[0.229, 0.224, 0.225])
|
67 |
+
])
|
68 |
+
self.test_seg = []
|
69 |
+
self.test_dataset = []
|
70 |
+
self.test_label_array = []
|
71 |
+
for ck in range(self.test_num_segment):
|
72 |
+
for cp in range(self.test_num_crop):
|
73 |
+
for idx in range(len(self.label_array)):
|
74 |
+
sample_label = self.label_array[idx]
|
75 |
+
self.test_label_array.append(sample_label)
|
76 |
+
self.test_dataset.append(self.dataset_samples[idx])
|
77 |
+
self.test_seg.append((ck, cp))
|
78 |
+
|
79 |
+
def __getitem__(self, index):
|
80 |
+
if self.mode == 'train':
|
81 |
+
args = self.args
|
82 |
+
scale_t = 1
|
83 |
+
|
84 |
+
sample = self.dataset_samples[index]
|
85 |
+
buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
|
86 |
+
if len(buffer) == 0:
|
87 |
+
while len(buffer) == 0:
|
88 |
+
warnings.warn("video {} not correctly loaded during training".format(sample))
|
89 |
+
index = np.random.randint(self.__len__())
|
90 |
+
sample = self.dataset_samples[index]
|
91 |
+
buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
|
92 |
+
|
93 |
+
if args.num_sample > 1:
|
94 |
+
frame_list = []
|
95 |
+
label_list = []
|
96 |
+
index_list = []
|
97 |
+
for _ in range(args.num_sample):
|
98 |
+
new_frames = self._aug_frame(buffer, args)
|
99 |
+
label = self.label_array[index]
|
100 |
+
frame_list.append(new_frames)
|
101 |
+
label_list.append(label)
|
102 |
+
index_list.append(index)
|
103 |
+
return frame_list, label_list, index_list, {}
|
104 |
+
else:
|
105 |
+
buffer = self._aug_frame(buffer, args)
|
106 |
+
|
107 |
+
return buffer, self.label_array[index], index, {}
|
108 |
+
|
109 |
+
elif self.mode == 'validation':
|
110 |
+
sample = self.dataset_samples[index]
|
111 |
+
buffer = self.loadvideo_decord(sample)
|
112 |
+
if len(buffer) == 0:
|
113 |
+
while len(buffer) == 0:
|
114 |
+
warnings.warn("video {} not correctly loaded during validation".format(sample))
|
115 |
+
index = np.random.randint(self.__len__())
|
116 |
+
sample = self.dataset_samples[index]
|
117 |
+
buffer = self.loadvideo_decord(sample)
|
118 |
+
buffer = self.data_transform(buffer)
|
119 |
+
return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
|
120 |
+
|
121 |
+
elif self.mode == 'test':
|
122 |
+
sample = self.test_dataset[index]
|
123 |
+
chunk_nb, split_nb = self.test_seg[index]
|
124 |
+
buffer = self.loadvideo_decord(sample)
|
125 |
+
|
126 |
+
while len(buffer) == 0:
|
127 |
+
warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
|
128 |
+
str(self.test_dataset[index]), chunk_nb, split_nb))
|
129 |
+
index = np.random.randint(self.__len__())
|
130 |
+
sample = self.test_dataset[index]
|
131 |
+
chunk_nb, split_nb = self.test_seg[index]
|
132 |
+
buffer = self.loadvideo_decord(sample)
|
133 |
+
|
134 |
+
buffer = self.data_resize(buffer)
|
135 |
+
if isinstance(buffer, list):
|
136 |
+
buffer = np.stack(buffer, 0)
|
137 |
+
|
138 |
+
spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
|
139 |
+
/ (self.test_num_crop - 1)
|
140 |
+
temporal_start = chunk_nb # 0/1
|
141 |
+
spatial_start = int(split_nb * spatial_step)
|
142 |
+
if buffer.shape[1] >= buffer.shape[2]:
|
143 |
+
buffer = buffer[temporal_start::2, \
|
144 |
+
spatial_start:spatial_start + self.short_side_size, :, :]
|
145 |
+
else:
|
146 |
+
buffer = buffer[temporal_start::2, \
|
147 |
+
:, spatial_start:spatial_start + self.short_side_size, :]
|
148 |
+
|
149 |
+
buffer = self.data_transform(buffer)
|
150 |
+
return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
|
151 |
+
chunk_nb, split_nb
|
152 |
+
else:
|
153 |
+
raise NameError('mode {} unknown'.format(self.mode))
|
154 |
+
|
155 |
+
def _aug_frame(
|
156 |
+
self,
|
157 |
+
buffer,
|
158 |
+
args,
|
159 |
+
):
|
160 |
+
|
161 |
+
aug_transform = video_transforms.create_random_augment(
|
162 |
+
input_size=(self.crop_size, self.crop_size),
|
163 |
+
auto_augment=args.aa,
|
164 |
+
interpolation=args.train_interpolation,
|
165 |
+
)
|
166 |
+
|
167 |
+
buffer = [
|
168 |
+
transforms.ToPILImage()(frame) for frame in buffer
|
169 |
+
]
|
170 |
+
|
171 |
+
buffer = aug_transform(buffer)
|
172 |
+
|
173 |
+
buffer = [transforms.ToTensor()(img) for img in buffer]
|
174 |
+
buffer = torch.stack(buffer) # T C H W
|
175 |
+
buffer = buffer.permute(0, 2, 3, 1) # T H W C
|
176 |
+
|
177 |
+
# T H W C
|
178 |
+
buffer = tensor_normalize(
|
179 |
+
buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
180 |
+
)
|
181 |
+
# T H W C -> C T H W.
|
182 |
+
buffer = buffer.permute(3, 0, 1, 2)
|
183 |
+
# Perform data augmentation.
|
184 |
+
scl, asp = (
|
185 |
+
[0.25, 1.0],
|
186 |
+
[0.75, 1.3333],
|
187 |
+
)
|
188 |
+
|
189 |
+
buffer = spatial_sampling(
|
190 |
+
buffer,
|
191 |
+
spatial_idx=-1,
|
192 |
+
min_scale=256,
|
193 |
+
max_scale=320,
|
194 |
+
crop_size=self.crop_size,
|
195 |
+
random_horizontal_flip=False if args.data_set == 'SSV2' else True,
|
196 |
+
inverse_uniform_sampling=False,
|
197 |
+
aspect_ratio=asp,
|
198 |
+
scale=scl,
|
199 |
+
motion_shift=False
|
200 |
+
)
|
201 |
+
|
202 |
+
if self.rand_erase:
|
203 |
+
erase_transform = RandomErasing(
|
204 |
+
args.reprob,
|
205 |
+
mode=args.remode,
|
206 |
+
max_count=args.recount,
|
207 |
+
num_splits=args.recount,
|
208 |
+
device="cpu",
|
209 |
+
)
|
210 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
211 |
+
buffer = erase_transform(buffer)
|
212 |
+
buffer = buffer.permute(1, 0, 2, 3)
|
213 |
+
|
214 |
+
return buffer
|
215 |
+
|
216 |
+
|
217 |
+
def loadvideo_decord(self, sample, sample_rate_scale=1):
|
218 |
+
"""Load video content using Decord"""
|
219 |
+
fname = sample
|
220 |
+
|
221 |
+
if not (os.path.exists(fname)):
|
222 |
+
return []
|
223 |
+
|
224 |
+
# avoid hanging issue
|
225 |
+
if os.path.getsize(fname) < 1 * 1024:
|
226 |
+
print('SKIP: ', fname, " - ", os.path.getsize(fname))
|
227 |
+
return []
|
228 |
+
try:
|
229 |
+
if self.keep_aspect_ratio:
|
230 |
+
vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
|
231 |
+
else:
|
232 |
+
vr = VideoReader(fname, width=self.new_width, height=self.new_height,
|
233 |
+
num_threads=1, ctx=cpu(0))
|
234 |
+
except:
|
235 |
+
print("video cannot be loaded by decord: ", fname)
|
236 |
+
return []
|
237 |
+
|
238 |
+
if self.mode == 'test':
|
239 |
+
all_index = []
|
240 |
+
tick = len(vr) / float(self.num_segment)
|
241 |
+
all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
|
242 |
+
[int(tick * x) for x in range(self.num_segment)]))
|
243 |
+
while len(all_index) < (self.num_segment * self.test_num_segment):
|
244 |
+
all_index.append(all_index[-1])
|
245 |
+
all_index = list(np.sort(np.array(all_index)))
|
246 |
+
vr.seek(0)
|
247 |
+
buffer = vr.get_batch(all_index).asnumpy()
|
248 |
+
return buffer
|
249 |
+
|
250 |
+
# handle temporal segments
|
251 |
+
average_duration = len(vr) // self.num_segment
|
252 |
+
all_index = []
|
253 |
+
if average_duration > 0:
|
254 |
+
all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
|
255 |
+
size=self.num_segment))
|
256 |
+
elif len(vr) > self.num_segment:
|
257 |
+
all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
|
258 |
+
else:
|
259 |
+
all_index += list(np.zeros((self.num_segment,)))
|
260 |
+
all_index = list(np.array(all_index))
|
261 |
+
vr.seek(0)
|
262 |
+
buffer = vr.get_batch(all_index).asnumpy()
|
263 |
+
return buffer
|
264 |
+
|
265 |
+
def __len__(self):
|
266 |
+
if self.mode != 'test':
|
267 |
+
return len(self.dataset_samples)
|
268 |
+
else:
|
269 |
+
return len(self.test_dataset)
|
270 |
+
|
271 |
+
|
272 |
+
def spatial_sampling(
|
273 |
+
frames,
|
274 |
+
spatial_idx=-1,
|
275 |
+
min_scale=256,
|
276 |
+
max_scale=320,
|
277 |
+
crop_size=224,
|
278 |
+
random_horizontal_flip=True,
|
279 |
+
inverse_uniform_sampling=False,
|
280 |
+
aspect_ratio=None,
|
281 |
+
scale=None,
|
282 |
+
motion_shift=False,
|
283 |
+
):
|
284 |
+
"""
|
285 |
+
Perform spatial sampling on the given video frames. If spatial_idx is
|
286 |
+
-1, perform random scale, random crop, and random flip on the given
|
287 |
+
frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
|
288 |
+
with the given spatial_idx.
|
289 |
+
Args:
|
290 |
+
frames (tensor): frames of images sampled from the video. The
|
291 |
+
dimension is `num frames` x `height` x `width` x `channel`.
|
292 |
+
spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
|
293 |
+
or 2, perform left, center, right crop if width is larger than
|
294 |
+
height, and perform top, center, bottom crop if height is larger
|
295 |
+
than width.
|
296 |
+
min_scale (int): the minimal size of scaling.
|
297 |
+
max_scale (int): the maximal size of scaling.
|
298 |
+
crop_size (int): the size of height and width used to crop the
|
299 |
+
frames.
|
300 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
301 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
302 |
+
scale. If False, take a uniform sample from [min_scale,
|
303 |
+
max_scale].
|
304 |
+
aspect_ratio (list): Aspect ratio range for resizing.
|
305 |
+
scale (list): Scale range for resizing.
|
306 |
+
motion_shift (bool): Whether to apply motion shift for resizing.
|
307 |
+
Returns:
|
308 |
+
frames (tensor): spatially sampled frames.
|
309 |
+
"""
|
310 |
+
assert spatial_idx in [-1, 0, 1, 2]
|
311 |
+
if spatial_idx == -1:
|
312 |
+
if aspect_ratio is None and scale is None:
|
313 |
+
frames, _ = video_transforms.random_short_side_scale_jitter(
|
314 |
+
images=frames,
|
315 |
+
min_size=min_scale,
|
316 |
+
max_size=max_scale,
|
317 |
+
inverse_uniform_sampling=inverse_uniform_sampling,
|
318 |
+
)
|
319 |
+
frames, _ = video_transforms.random_crop(frames, crop_size)
|
320 |
+
else:
|
321 |
+
transform_func = (
|
322 |
+
video_transforms.random_resized_crop_with_shift
|
323 |
+
if motion_shift
|
324 |
+
else video_transforms.random_resized_crop
|
325 |
+
)
|
326 |
+
frames = transform_func(
|
327 |
+
images=frames,
|
328 |
+
target_height=crop_size,
|
329 |
+
target_width=crop_size,
|
330 |
+
scale=scale,
|
331 |
+
ratio=aspect_ratio,
|
332 |
+
)
|
333 |
+
if random_horizontal_flip:
|
334 |
+
frames, _ = video_transforms.horizontal_flip(0.5, frames)
|
335 |
+
else:
|
336 |
+
# The testing is deterministic and no jitter should be performed.
|
337 |
+
# min_scale, max_scale, and crop_size are expected to be the same.
|
338 |
+
assert len({min_scale, max_scale, crop_size}) == 1
|
339 |
+
frames, _ = video_transforms.random_short_side_scale_jitter(
|
340 |
+
frames, min_scale, max_scale
|
341 |
+
)
|
342 |
+
frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
|
343 |
+
return frames
|
344 |
+
|
345 |
+
|
346 |
+
def tensor_normalize(tensor, mean, std):
|
347 |
+
"""
|
348 |
+
Normalize a given tensor by subtracting the mean and dividing the std.
|
349 |
+
Args:
|
350 |
+
tensor (tensor): tensor to normalize.
|
351 |
+
mean (tensor or list): mean value to subtract.
|
352 |
+
std (tensor or list): std to divide.
|
353 |
+
"""
|
354 |
+
if tensor.dtype == torch.uint8:
|
355 |
+
tensor = tensor.float()
|
356 |
+
tensor = tensor / 255.0
|
357 |
+
if type(mean) == list:
|
358 |
+
mean = torch.tensor(mean)
|
359 |
+
if type(std) == list:
|
360 |
+
std = torch.tensor(std)
|
361 |
+
tensor = tensor - mean
|
362 |
+
tensor = tensor / std
|
363 |
+
return tensor
|
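In test mode the dataset above enumerates every (temporal chunk, spatial crop) pair per video, so the number of test samples is `len(videos) * test_num_segment * test_num_crop`; predictions for the same video id are typically averaged afterwards. A small sketch of that bookkeeping with illustrative numbers (the class defaults are 10 segments and 3 crops):

```python
# Illustrative sketch of how SSVideoClsDataset expands the test set:
# each video is paired with every (chunk_nb, split_nb) combination.
num_videos = 4
test_num_segment, test_num_crop = 2, 3   # small values for readability

test_entries = [(vid, ck, cp)
                for ck in range(test_num_segment)
                for cp in range(test_num_crop)
                for vid in range(num_videos)]

assert len(test_entries) == num_videos * test_num_segment * test_num_crop  # 24 test views
```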
synthetic_tubelets.py
ADDED
@@ -0,0 +1,785 @@
1 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
from typing import List
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from dynamic_utils import (extend_key_frame_to_all,
|
11 |
+
sample_key_frames)
|
12 |
+
import imutils
|
13 |
+
import math
|
14 |
+
from scipy.ndimage import gaussian_filter1d
|
15 |
+
from glob import glob
|
16 |
+
|
17 |
+
|
18 |
+
class RandomRegionSampler(object):
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
num_rois: int,
|
22 |
+
scales: tuple,
|
23 |
+
ratios: tuple,
|
24 |
+
scale_jitter: float):
|
25 |
+
""" Randomly sample several RoIs
|
26 |
+
|
27 |
+
Args:
|
28 |
+
num_rois (int): number of sampled RoIs per image
|
29 |
+
scales (tuple): scales of candidate bounding boxes
|
30 |
+
ratios (tuple): aspect ratios of candidate bounding boxes
|
31 |
+
scale_jitter (float): scale jitter factor, positive number
|
32 |
+
"""
|
33 |
+
|
34 |
+
self.num_rois = num_rois
|
35 |
+
self.scale_jitter = scale_jitter
|
36 |
+
|
37 |
+
scales = np.array(scales, np.float32)
|
38 |
+
ratios = np.array(ratios, np.float32)
|
39 |
+
widths = scales.reshape(1, -1) * np.sqrt(ratios).reshape(-1, 1)
|
40 |
+
heights = scales.reshape(1, -1) / np.sqrt(ratios).reshape(-1, 1)
|
41 |
+
self.anchors = np.concatenate((widths.reshape(-1, 1),
|
42 |
+
heights.reshape(-1, 1)), axis=-1)
|
43 |
+
|
44 |
+
def sample(self, data: List[np.ndarray]) -> np.ndarray:
|
45 |
+
""" Sample boxes.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
data (list): image list, each element is a numpy.ndarray
|
49 |
+
in shape of [H, W, 3]
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
boxes (np.ndarray): the sampled bounding boxes. in shape of
|
53 |
+
[self.num_rois, 4], represented in (x1, y1, x2, y2).
|
54 |
+
|
55 |
+
"""
|
56 |
+
h, w = data[0].shape[0:2]
|
57 |
+
|
58 |
+
# random sample box shapes
|
59 |
+
anchor_inds = np.random.randint(0, len(self.anchors),
|
60 |
+
size=(self.num_rois, ))
|
61 |
+
box_shapes = self.anchors[anchor_inds].copy()
|
62 |
+
if self.scale_jitter is not None:
|
63 |
+
scale_factors = np.random.uniform(-self.scale_jitter,
|
64 |
+
self.scale_jitter,
|
65 |
+
size=(self.num_rois, 2))
|
66 |
+
box_shapes = box_shapes * np.exp(scale_factors)
|
67 |
+
box_shapes[:, 0] = np.clip(box_shapes[:, 0], 1, w - 1)
|
68 |
+
box_shapes[:, 1] = np.clip(box_shapes[:, 1], 1, h - 1)
|
69 |
+
|
70 |
+
#print("box shapes",box_shapes,box_shapes.shape)
|
71 |
+
# random sample box x1, y1
|
72 |
+
x1 = np.random.uniform(0, w - box_shapes[:, 0])
|
73 |
+
y1 = np.random.uniform(0, h - box_shapes[:, 1])
|
74 |
+
#print("x1, y1",x1,y1)
|
75 |
+
boxes = np.concatenate((x1.reshape(-1, 1),
|
76 |
+
y1.reshape(-1, 1),
|
77 |
+
(x1 + box_shapes[:, 0]).reshape(-1, 1),
|
78 |
+
(y1 + box_shapes[:, 1]).reshape(-1, 1)),
|
79 |
+
axis=1)
|
80 |
+
#print("sampled initial boxes",boxes)
|
81 |
+
|
82 |
+
return boxes
|
83 |
+
|
84 |
+
def sample_box_shapes(self, data: List[np.ndarray]) -> np.ndarray:
|
85 |
+
""" Sample boxes.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
data (list): image list, each element is a numpy.ndarray
|
89 |
+
in shape of [H, W, 3]
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
boxes (np.ndarray): the sampled bounding boxes. in shape of
|
93 |
+
[self.num_rois, 4], represented in (x1, y1, x2, y2).
|
94 |
+
|
95 |
+
"""
|
96 |
+
h, w = data[0].shape[0:2]
|
97 |
+
|
98 |
+
# random sample box shapes
|
99 |
+
anchor_inds = np.random.randint(0, len(self.anchors),
|
100 |
+
size=(self.num_rois, ))
|
101 |
+
box_shapes = self.anchors[anchor_inds].copy()
|
102 |
+
if self.scale_jitter is not None:
|
103 |
+
scale_factors = np.random.uniform(-self.scale_jitter,
|
104 |
+
self.scale_jitter,
|
105 |
+
size=(self.num_rois, 2))
|
106 |
+
box_shapes = box_shapes * np.exp(scale_factors)
|
107 |
+
box_shapes[:, 0] = np.clip(box_shapes[:, 0], 1, w - 1)
|
108 |
+
box_shapes[:, 1] = np.clip(box_shapes[:, 1], 1, h - 1)
|
109 |
+
|
110 |
+
#print(" gaussian box shapes",box_shapes)
|
111 |
+
|
112 |
+
return box_shapes
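The anchor set built in `__init__` above pairs every scale with every aspect ratio while keeping the box area roughly `scale**2` (width = scale * sqrt(ratio), height = scale / sqrt(ratio)). A hedged sketch of that construction with an illustrative subset of scales and assumed ratio values:

```python
# Sketch of how RandomRegionSampler builds its anchor grid; scales/ratios here are examples only.
import numpy as np

scales = np.array([32, 64, 128], np.float32)      # illustrative subset of the --scales list
ratios = np.array([0.5, 1.0, 2.0], np.float32)    # assumed aspect-ratio choices

widths = scales.reshape(1, -1) * np.sqrt(ratios).reshape(-1, 1)
heights = scales.reshape(1, -1) / np.sqrt(ratios).reshape(-1, 1)
anchors = np.concatenate((widths.reshape(-1, 1), heights.reshape(-1, 1)), axis=-1)

print(anchors.shape)                             # (9, 2): one (w, h) per scale/ratio pair
print(np.round(anchors[:, 0] * anchors[:, 1]))   # areas stay at scale**2: 1024, 4096, 16384
```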
|
113 |
+
|
114 |
+
|
115 |
+
class PatchMask(object):
|
116 |
+
|
117 |
+
def __init__(self,
|
118 |
+
use_objects: bool,
|
119 |
+
objects_path: str,
|
120 |
+
region_sampler: dict,
|
121 |
+
key_frame_probs: list,
|
122 |
+
loc_velocity: float,
|
123 |
+
rot_velocity: float,
|
124 |
+
size_velocity: float,
|
125 |
+
label_prob: float,
|
126 |
+
patch_transformation: str,
|
127 |
+
motion_type: str):
|
128 |
+
|
129 |
+
""" Core transformation in Catch-the-Patch.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
region_sampler (dict): region sampler setting, it will be used to
|
133 |
+
construct a RandomRegionSampler object.
|
134 |
+
key_frame_probs (list): probabilities of sampling how many key
|
135 |
+
frames. The sum of this list should be 1.
|
136 |
+
loc_velocity (float): the maximum patch movement speed. (pix per
|
137 |
+
frame).
|
138 |
+
size_velocity (float): the maximum size change ratios between two
|
139 |
+
neighbouring frames.
|
140 |
+
label_prob (float): how many percentages of frames will be
|
141 |
+
modified. Note that even the frame is not modified, we still
|
142 |
+
force the model to infer the patch positions. (see MRM module
|
143 |
+
in the paper).
|
144 |
+
"""
|
145 |
+
self.region_sampler = RandomRegionSampler(**region_sampler)
|
146 |
+
self.key_frame_probs = key_frame_probs
|
147 |
+
self.loc_velocity = loc_velocity
|
148 |
+
self.rot_velocity = rot_velocity
|
149 |
+
self.size_velocity = size_velocity
|
150 |
+
self.label_prob = label_prob
|
151 |
+
if motion_type is not None:
|
152 |
+
self.motion_type = motion_type
|
153 |
+
self.patch_transformation = patch_transformation
|
154 |
+
self.use_objects = use_objects
|
155 |
+
|
156 |
+
if self.use_objects:
|
157 |
+
#self.object_list = glob("/ibex/user/jianl0b/Dataset/Fida_file_1/video_images/micheal_objects/cleaned/images/*/*")
|
158 |
+
self.object_list = glob(objects_path+"/*/*")
|
159 |
+
|
160 |
+
#self.object_list = glob("/ibex/project/c2134/Fida/micheal_objects_big/cleaned_big/images/*/*")
|
161 |
+
print(self.object_list[0:10],len(self.object_list))
|
162 |
+
|
163 |
+
def paste_objects(self, data, traj_rois, boxes):
|
164 |
+
|
165 |
+
objects_list = []
|
166 |
+
label_list = []
|
167 |
+
|
168 |
+
for i in range(len(boxes)):
|
169 |
+
objects, crop_index = self.pick_objects(data, traj_rois[i])
|
170 |
+
labels = np.random.uniform(0, 1, size=(len(data), ))
|
171 |
+
labels[crop_index] = 0.0
|
172 |
+
labels[0] = 0.0
|
173 |
+
labels = labels <= self.label_prob
|
174 |
+
objects_list.append(objects)
|
175 |
+
label_list.append(labels)
|
176 |
+
|
177 |
+
return objects_list, None, label_list
|
178 |
+
|
179 |
+
def paste_patches(self, data, traj_rois, boxes):
|
180 |
+
|
181 |
+
patches_list = []
|
182 |
+
alphas_list = []
|
183 |
+
label_list = []
|
184 |
+
|
185 |
+
for i in range(len(boxes)):
|
186 |
+
patches, crop_index = self.pick_patches(data, traj_rois[i])
|
187 |
+
alphas = self.pick_alphas(data, traj_rois[i], crop_index)
|
188 |
+
labels = np.random.uniform(0, 1, size=(len(data), ))
|
189 |
+
labels[crop_index] = 0.0
|
190 |
+
labels[0] = 0.0
|
191 |
+
labels = labels <= self.label_prob
|
192 |
+
patches_list.append(patches)
|
193 |
+
alphas_list.append(alphas)
|
194 |
+
label_list.append(labels)
|
195 |
+
|
196 |
+
return patches_list, alphas_list, label_list
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
def pick_patches(self,
|
203 |
+
data: List[np.ndarray],
|
204 |
+
traj_rois: np.ndarray) -> tuple:
|
205 |
+
""" Pick image patches from the raw video frame.
|
206 |
+
|
207 |
+
We just randomly select a frame index, and crop the frame according to
|
208 |
+
the trajectory rois. This cropped patch will be resized into the
|
209 |
+
suitable size specified by the traj_rois.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
data (List[np.ndarray]): list of images, each element is in shape
|
213 |
+
of [H, W, 3]
|
214 |
+
traj_rois (np.ndarray): the generated trajectories, in shape of
|
215 |
+
[N_frames, 4]. (x1, y1, x2, y2)
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
patches (List[np.ndarray]): the cropped patches
|
219 |
+
select_idx (int): the frame index which the source patch
|
220 |
+
cropped from.
|
221 |
+
"""
|
222 |
+
traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
|
223 |
+
num = len(traj_sizes)
|
224 |
+
select_idx = random.randint(0, num - 1)
|
225 |
+
x1, y1, x2, y2 = traj_rois[select_idx]
|
226 |
+
traj_rois_H = y2 - y1
|
227 |
+
traj_rois_W = x2 - x1
|
228 |
+
|
229 |
+
img = data[select_idx]
|
230 |
+
img_H, img_W, _ = img.shape
|
231 |
+
|
232 |
+
if img_W - traj_rois_W - 1 >= 0 and img_H - traj_rois_H - 1 >= 0:
|
233 |
+
new_x1 = random.randint(0, img_W - traj_rois_W - 1)
|
234 |
+
new_y1 = random.randint(0, img_H - traj_rois_H - 1)
|
235 |
+
new_x2 = new_x1 + traj_rois_W
|
236 |
+
new_y2 = new_y1 + traj_rois_H
|
237 |
+
img = img[new_y1:new_y2, new_x1:new_x2, :]
|
238 |
+
else:
|
239 |
+
img = img
|
240 |
+
patches = [cv2.resize(img, (traj_sizes[i, 0], traj_sizes[i, 1]))
|
241 |
+
for i in range(traj_rois.shape[0])]
|
242 |
+
return patches, select_idx
|
243 |
+
|
244 |
+
def pick_objects(self,
|
245 |
+
data: List[np.ndarray],
|
246 |
+
traj_rois: np.ndarray) -> tuple:
|
247 |
+
""" Pick image patches from the raw video frame.
|
248 |
+
|
249 |
+
We just randomly select a frame index, and crop the frame according to
|
250 |
+
the trajectory rois. This cropped patch will be resized into the
|
251 |
+
suitable size specified by the traj_rois.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
data (List[np.ndarray]): list of images, each element is in shape
|
255 |
+
of [H, W, 3]
|
256 |
+
traj_rois (np.ndarray): the generated trajectories, in shape of
|
257 |
+
[N_frames, 4]. (x1, y1, x2, y2)
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
patches (List[np.ndarray]): the cropped patches
|
261 |
+
select_idx (int): the frame index which the source patch
|
262 |
+
cropped from.
|
263 |
+
"""
|
264 |
+
traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
|
265 |
+
num = len(traj_sizes)
|
266 |
+
select_idx = random.randint(0, num - 1)
|
267 |
+
#print(len(data),traj_rois.shape)
|
268 |
+
x1, y1, x2, y2 = traj_rois[select_idx]
|
269 |
+
#print(x1, y1, x2, y2)
|
270 |
+
|
271 |
+
object_ind = random.randint(0, len(self.object_list)- 1)
|
272 |
+
object_img = Image.open(self.object_list[object_ind])
|
273 |
+
object_img = object_img.resize((x2-x1,y2-y1))
|
274 |
+
|
275 |
+
objects = [object_img.resize((traj_sizes[i, 0], traj_sizes[i, 1]))
|
276 |
+
for i in range(traj_rois.shape[0])]
|
277 |
+
|
278 |
+
return objects, select_idx
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
def pick_alphas(self,
|
283 |
+
data,
|
284 |
+
traj_rois: np.ndarray,
|
285 |
+
crop_index: int):
|
286 |
+
""" Generate the alpha masks for merging the patches into the raw
|
287 |
+
frames:
|
288 |
+
out_frame = raw_frame * (1 - alpha) + patch * alpha.
|
289 |
+
Despite the transparency, the alpha values are also used to mask the
|
290 |
+
patches into some predefined shapes, like ellipse or rhombus.
|
291 |
+
There are many strange constants in this function. But we do not
|
292 |
+
conduct any ablation analysis on these constants. They should have
|
293 |
+
little impact to the final performances.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
data (List[np.ndarray]): list of images, each element is in shape
|
297 |
+
of [H, W, 3]
|
298 |
+
traj_rois (np.ndarray): the generated trajectories, in shape of
|
299 |
+
[N_frames, 4]. (x1, y1, x2, y2)
|
300 |
+
crop_index (int): the frame index which the source patch
|
301 |
+
cropped from.
|
302 |
+
|
303 |
+
Returns:
|
304 |
+
alphas (List[np.ndarray]): the generated alpha values
|
305 |
+
|
306 |
+
"""
|
307 |
+
traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
|
308 |
+
num_frames = traj_sizes.shape[0]
|
309 |
+
|
310 |
+
base_w, base_h = traj_sizes[crop_index]
|
311 |
+
|
312 |
+
base_x_grids, base_y_grids = np.meshgrid(
|
313 |
+
np.arange(base_w).astype(np.float32),
|
314 |
+
np.arange(base_h).astype(np.float32)
|
315 |
+
|
316 |
+
)
|
317 |
+
ctr_w = (base_w - 1) // 2
|
318 |
+
ctr_h = (base_h - 1) // 2
|
319 |
+
|
320 |
+
dist_to_ctr_x = np.abs(base_x_grids - ctr_w) / base_w
|
321 |
+
dist_to_ctr_y = np.abs(base_y_grids - ctr_h) / base_h
|
322 |
+
|
323 |
+
mask_type = int(np.random.choice(3, p=[0.5, 0.35, 0.15]))
|
324 |
+
if mask_type == 0:
|
325 |
+
dist_to_ctr = np.maximum(dist_to_ctr_x, dist_to_ctr_y)
|
326 |
+
base_alpha = np.ones((base_h, base_w), np.float32)
|
327 |
+
elif mask_type == 1:
|
328 |
+
dist_to_ctr = np.sqrt(dist_to_ctr_x ** 2 + dist_to_ctr_y ** 2)
|
329 |
+
base_alpha = np.where(dist_to_ctr < 0.5,
|
330 |
+
np.ones((base_h, base_w), np.float32),
|
331 |
+
np.zeros((base_h, base_w), np.float32))
|
332 |
+
elif mask_type == 2:
|
333 |
+
dist_to_ctr = (dist_to_ctr_x + dist_to_ctr_y)
|
334 |
+
base_alpha = np.where(dist_to_ctr < 0.5,
|
335 |
+
np.ones((base_h, base_w), np.float32),
|
336 |
+
np.zeros((base_h, base_w), np.float32))
|
337 |
+
else:
|
338 |
+
raise NotImplementedError
|
339 |
+
|
340 |
+
use_smooth_edge = random.uniform(0, 1) < 0.5
|
341 |
+
if use_smooth_edge:
|
342 |
+
turning_point = random.uniform(0.30, 0.45)
|
343 |
+
k = -1 / (0.5 - turning_point)
|
344 |
+
alpha_mul = k * dist_to_ctr - 0.5 * k
|
345 |
+
alpha_mul = np.clip(alpha_mul, 0, 1)
|
346 |
+
base_alpha = base_alpha * alpha_mul
|
347 |
+
|
348 |
+
# sample key frames
|
349 |
+
key_inds = sample_key_frames(num_frames, self.key_frame_probs)
|
350 |
+
frame_alphas = np.random.uniform(0.8, 1.0, size=(len(key_inds), 1))
|
351 |
+
frame_alphas = extend_key_frame_to_all(frame_alphas, key_inds)
|
352 |
+
|
353 |
+
alphas = []
|
354 |
+
for frame_idx in range(num_frames):
|
355 |
+
w, h = traj_sizes[frame_idx]
|
356 |
+
i_alpha = cv2.resize(base_alpha, (w, h))
|
357 |
+
i_alpha = i_alpha * frame_alphas[frame_idx]
|
358 |
+
alphas.append(i_alpha)
|
359 |
+
return alphas
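The alpha masks returned here drive the blend described in the docstring, `out_frame = raw_frame * (1 - alpha) + patch * alpha`. A tiny self-contained sketch of that compositing step on dummy arrays:

```python
# Tiny compositing sketch mirroring the blend described in the pick_alphas docstring.
import numpy as np

frame = np.full((64, 64, 3), 100, dtype=np.float32)   # dummy background region
patch = np.full((64, 64, 3), 200, dtype=np.float32)   # dummy synthetic patch
alpha = np.full((64, 64, 1), 0.9, dtype=np.float32)   # high alpha keeps the patch mostly opaque

out = frame * (1 - alpha) + patch * alpha
print(out[0, 0])   # [190. 190. 190.]: 90% patch, 10% original frame
```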
|
360 |
+
|
361 |
+
def get_rotation_angles(self,
|
362 |
+
num_frames,
|
363 |
+
transform_param: dict):
|
364 |
+
key_frame_probs = transform_param['key_frame_probs']
|
365 |
+
loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
|
366 |
+
|
367 |
+
rot_velocity = transform_param['rot_velocity']
|
368 |
+
rot_angles = np.zeros((transform_param['traj_rois'].shape[0],1))
|
369 |
+
|
370 |
+
#print("rotation angles original",rot_angles.shape,loc_key_inds)
|
371 |
+
rot_angles_list= [np.expand_dims(rot_angles, axis=0)]
|
372 |
+
for i in range(len(loc_key_inds) - 1):
|
373 |
+
if rot_velocity > 0:
|
374 |
+
index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
|
375 |
+
shifts = np.random.uniform(low=-rot_velocity* index_diff,
|
376 |
+
high=rot_velocity* index_diff,
|
377 |
+
size=rot_angles.shape)
|
378 |
+
rot_angles = rot_angles + shifts
|
379 |
+
rot_angles_list.append(np.expand_dims(rot_angles, axis=0))
|
380 |
+
rot_angles = np.concatenate(rot_angles_list, axis=0)
|
381 |
+
rot_angles = extend_key_frame_to_all(rot_angles, loc_key_inds, 'random')
|
382 |
+
rot_angles = rot_angles.transpose((1, 0, 2))
|
383 |
+
|
384 |
+
|
385 |
+
return rot_angles
|
386 |
+
|
387 |
+
def get_shear_factors(self,
|
388 |
+
num_frames,
|
389 |
+
transform_param: dict):
|
390 |
+
key_frame_probs = transform_param['key_frame_probs']
|
391 |
+
loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
|
392 |
+
|
393 |
+
#print("Loc key inds shear",loc_key_inds)
|
394 |
+
|
395 |
+
rot_velocity = transform_param['rot_velocity']
|
396 |
+
rot_angles = np.zeros((transform_param['traj_rois'].shape[0],1))
|
397 |
+
|
398 |
+
#print("rotation angles original",rot_angles.shape,loc_key_inds)
|
399 |
+
rot_angles_list= [np.expand_dims(rot_angles, axis=0)]
|
400 |
+
for i in range(len(loc_key_inds) - 1):
|
401 |
+
if rot_velocity > 0:
|
402 |
+
index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
|
403 |
+
shifts = np.random.uniform(low=-rot_velocity* index_diff,
|
404 |
+
high=rot_velocity* index_diff,
|
405 |
+
size=rot_angles.shape)
|
406 |
+
#scales = np.exp(shifts)
|
407 |
+
#print("shifts shear", shifts)
|
408 |
+
#rot_angles = scales
|
409 |
+
rot_angles = rot_angles + shifts
|
410 |
+
rot_angles_list.append(np.expand_dims(rot_angles, axis=0))
|
411 |
+
rot_angles = np.concatenate(rot_angles_list, axis=0)
|
412 |
+
rot_angles = extend_key_frame_to_all(rot_angles, loc_key_inds, 'random')
|
413 |
+
rot_angles = rot_angles.transpose((1, 0, 2))
|
414 |
+
|
415 |
+
return rot_angles
|
416 |
+
|
417 |
+
|
418 |
+
def _apply_image(self,
|
419 |
+
data: List[np.ndarray],
|
420 |
+
transform_param: dict):
|
421 |
+
|
422 |
+
data_1 = data
|
423 |
+
|
424 |
+
# Sort the patches by average size and paste the largest first:
# if a small patch were pasted first, it could end up completely
# covered by a larger one.
|
427 |
+
sizes = transform_param['traj_rois'][..., 2:4] - \
|
428 |
+
transform_param['traj_rois'][..., 0:2]
|
429 |
+
avg_sizes = np.prod(np.mean(sizes, axis=1), axis=1)
|
430 |
+
arg_rank = np.argsort(avg_sizes)[::-1]
|
431 |
+
|
432 |
+
# note: ndarray.shape is (H, W, C), so `width` here actually holds the frame
# height; it is only used for the in-bounds check further below.
width, height, _ = data_1[0].shape
|
433 |
+
#print(width,height)
|
434 |
+
|
435 |
+
|
436 |
+
if self.use_objects:
|
437 |
+
|
438 |
+
if transform_param['patch_transformation'] == 'rotation':
|
439 |
+
rot_angles = self.get_rotation_angles(len(data_1),transform_param)
|
440 |
+
transformed_data_1 = []
|
441 |
+
for frame_idx in range(len(data_1)):
|
442 |
+
i_rois = transform_param['traj_rois'][:, frame_idx, :]
|
443 |
+
img = data_1[frame_idx].copy()
|
444 |
+
for patch_idx in arg_rank:
|
445 |
+
if not transform_param['traj_labels'][patch_idx][frame_idx]:
|
446 |
+
continue
|
447 |
+
i_object = transform_param['patches'][patch_idx][frame_idx] # here patches are objects
|
448 |
+
i_object = np.array(i_object)
|
449 |
+
angle = int(rot_angles[patch_idx][frame_idx])
|
450 |
+
rotated_i_object = imutils.rotate_bound(i_object, angle)
|
451 |
+
|
452 |
+
rotated_i_alpha = rotated_i_object[..., -1]
|
453 |
+
rotated_i_alpha = rotated_i_alpha / 255.0
|
454 |
+
rotated_i_object = rotated_i_object[..., :3]
|
455 |
+
|
456 |
+
h_prime, w_prime, channels = rotated_i_object.shape
|
457 |
+
x1, y1, x2, y2 = i_rois[patch_idx]
|
458 |
+
h, w = y2 - y1, x2 - x1
|
459 |
+
if ((h_prime - h) % 2) == 0:
|
460 |
+
delta_h1 = delta_h2 = math.ceil((h_prime - h) / 2)
|
461 |
+
else:
|
462 |
+
delta_h1 = math.ceil((h_prime - h) / 2)
|
463 |
+
delta_h2 = math.floor((h_prime - h) / 2)
|
464 |
+
if ((w_prime - w) % 2) == 0:
|
465 |
+
delta_w1 = delta_w2 = math.ceil((w_prime - w) / 2)
|
466 |
+
else:
|
467 |
+
delta_w1 = math.ceil((w_prime - w) / 2)
|
468 |
+
delta_w2 = math.floor((w_prime - w) / 2)
|
469 |
+
|
470 |
+
x1_new, y1_new, x2_new, y2_new = x1 - delta_w1, y1 - delta_h1, x2 + delta_w2, y2 + delta_h2
|
471 |
+
if all(i >= 0 for i in [x1_new, y1_new, x2_new, y2_new]) and all(
|
472 |
+
i < width for i in [x1_new, y1_new, x2_new, y2_new]):
|
473 |
+
# in bound
|
474 |
+
i_patch = rotated_i_object
|
475 |
+
i_alpha = rotated_i_alpha[..., np.newaxis]
|
476 |
+
img[y1_new:y2_new, x1_new:x2_new, :] = img[y1_new:y2_new, x1_new:x2_new, :] * (1 - i_alpha) + i_patch * i_alpha
|
477 |
+
else:
|
478 |
+
# out of bound
|
479 |
+
img_H, img_W, C = img.shape
|
480 |
+
patch_H, patch_W, _ = rotated_i_object.shape
|
481 |
+
extended_img = np.zeros((img_H + 2 * patch_H, img_W + 2 * patch_W, C), dtype=img.dtype)
|
482 |
+
extended_img[patch_H:(img_H + patch_H), patch_W:(img_W + patch_W), :] = img
|
483 |
+
|
484 |
+
x1_new += patch_W
|
485 |
+
x2_new += patch_W
|
486 |
+
y1_new += patch_H
|
487 |
+
y2_new += patch_H
|
488 |
+
i_alpha = rotated_i_alpha[..., np.newaxis]
|
489 |
+
extended_img[y1_new:y2_new, x1_new:x2_new, :] = extended_img[y1_new:y2_new, x1_new:x2_new, :] * (1 - i_alpha) + rotated_i_object * i_alpha
|
490 |
+
img = extended_img[patch_H:(img_H + patch_H), patch_W:(img_W + patch_W), :]
|
491 |
+
|
492 |
+
img = np.array(img)
|
493 |
+
transformed_data_1.append(img)
|
494 |
+
|
495 |
+
return transformed_data_1
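
The rotation branch above composites an RGBA object back onto the frame with `img * (1 - alpha) + patch * alpha` after `imutils.rotate_bound` has rotated it on an enlarged canvas. A small self-contained sketch of just that compositing step, on synthetic data and without any trajectory or bounds bookkeeping:

```python
import numpy as np
import imutils  # rotate_bound grows the canvas so the rotated patch is not clipped

frame = np.zeros((224, 224, 3), dtype=np.float32)
patch = np.zeros((40, 40, 4), dtype=np.float32)   # RGBA object crop
patch[..., :3] = 255.0                            # white square
patch[10:30, 10:30, 3] = 255.0                    # opaque only in the middle

rotated = imutils.rotate_bound(patch, 30)         # output can be larger than 40x40
alpha = rotated[..., -1:] / 255.0                 # keep a trailing axis for broadcasting
rgb = rotated[..., :3]

h, w = rotated.shape[:2]
y1, x1 = 50, 60                                   # paste position (top-left corner)
frame[y1:y1 + h, x1:x1 + w] = (
    frame[y1:y1 + h, x1:x1 + w] * (1 - alpha) + rgb * alpha
)
```
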
|
496 |
+
|
497 |
+
|
498 |
+
@staticmethod
|
499 |
+
def rectangle_movement(boxes: np.ndarray,
|
500 |
+
img_wh: tuple,
|
501 |
+
loc_velocity: float,
|
502 |
+
size_velocity: float,
|
503 |
+
num_frames: int,
|
504 |
+
key_frame_probs: List[float]) -> np.ndarray:
|
505 |
+
""" Simulate the object movement.
|
506 |
+
|
507 |
+
Args:
|
508 |
+
boxes (np.ndarray): in shpae of [N_boxes, 4]
|
509 |
+
img_wh (tuple): image width and image height
|
510 |
+
loc_velocity (float): max speed of the center point movement
|
511 |
+
size_velocity (float): max speed of size changes
|
512 |
+
num_frames (int): number of frames
|
513 |
+
key_frame_probs (float): probability distribution of how many key
|
514 |
+
frames will be sampled.
|
515 |
+
|
516 |
+
Returns
|
517 |
+
all_boxes (np.ndarray): the generated box trajectory, in shpae
|
518 |
+
of [N_traj, N_frame, 4].
|
519 |
+
|
520 |
+
"""
|
521 |
+
# Step 1, sample key frames for location changes
|
522 |
+
loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
|
523 |
+
# Step 2, decide box locations in key frames
|
524 |
+
ctr_pts = (boxes[:, 0:2] + boxes[:, 2:4]) * 0.5
|
525 |
+
#print("center points original",ctr_pts)
|
526 |
+
box_sizes = (boxes[:, 2:4] - boxes[:, 0:2])
|
527 |
+
#print("box sizes = ",box_sizes,box_sizes.shape)
|
528 |
+
|
529 |
+
min_ctr_pts = box_sizes * 0.5
|
530 |
+
max_ctr_pts = np.array(img_wh[0:2]).reshape(1, 2) - box_sizes * 0.5
|
531 |
+
|
532 |
+
#print("initial center points ",ctr_pts,loc_key_inds)
|
533 |
+
ctr_pts_list = [np.expand_dims(ctr_pts, axis=0)]
|
534 |
+
#print("ctr pts list",ctr_pts_list)
|
535 |
+
for i in range(len(loc_key_inds) - 1):
|
536 |
+
if loc_velocity > 0:
|
537 |
+
index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
|
538 |
+
shifts = np.random.uniform(low=-loc_velocity * index_diff,
|
539 |
+
high=loc_velocity * index_diff,
|
540 |
+
size=ctr_pts.shape)
|
541 |
+
#print("shifts",shifts)
|
542 |
+
ctr_pts = ctr_pts + shifts
|
543 |
+
ctr_pts = np.clip(ctr_pts, min_ctr_pts, max_ctr_pts)
|
544 |
+
ctr_pts_list.append(np.expand_dims(ctr_pts, axis=0))
|
545 |
+
ctr_pts = np.concatenate(ctr_pts_list, axis=0)
|
546 |
+
|
547 |
+
ctr_pts = extend_key_frame_to_all(ctr_pts, loc_key_inds, 'random')
|
548 |
+
#print("all center points ",ctr_pts,ctr_pts.shape)
|
549 |
+
|
550 |
+
# Step 3, sample key frames for shape changes
|
551 |
+
size_key_inds = sample_key_frames(num_frames, key_frame_probs)
|
552 |
+
|
553 |
+
# Step 4, setup shape in different key frames
|
554 |
+
box_sizes_list = [np.expand_dims(box_sizes, axis=0)]
|
555 |
+
for i in range(len(size_key_inds) - 1):
|
556 |
+
if size_velocity > 0:
|
557 |
+
index_diff = size_key_inds[i + 1] - size_key_inds[i]
|
558 |
+
scales = np.random.uniform(low=-size_velocity * index_diff,
|
559 |
+
high=size_velocity * index_diff,
|
560 |
+
size=box_sizes.shape)
|
561 |
+
scales = np.exp(scales)
|
562 |
+
box_sizes = box_sizes * scales
|
563 |
+
box_sizes_list.append(np.expand_dims(box_sizes, axis=0))
|
564 |
+
box_sizes = np.concatenate(box_sizes_list, axis=0)
|
565 |
+
# print("box sizes before interpolation",box_sizes,size_key_inds)
|
566 |
+
box_sizes = extend_key_frame_to_all(box_sizes, size_key_inds, 'random')
|
567 |
+
#print("box sizes after interpolation",box_sizes)
|
568 |
+
|
569 |
+
# Step 5, construct boxes in key frames
|
570 |
+
all_boxes = np.concatenate((ctr_pts - box_sizes * 0.5,
|
571 |
+
ctr_pts + box_sizes * 0.5), axis=2)
|
572 |
+
# all_boxes[..., 0::2] = np.clip(all_boxes[..., 0::2], 0, img_wh[0])
|
573 |
+
# all_boxes[..., 1::2] = np.clip(all_boxes[..., 1::2], 0, img_wh[1])
|
574 |
+
all_boxes = all_boxes.transpose((1, 0, 2))
|
575 |
+
return all_boxes
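
`rectangle_movement` only draws centre points and sizes at a few key frames and leaves the in-between frames to `extend_key_frame_to_all` (defined elsewhere in this file). The underlying idea is a clipped random walk over key frames followed by interpolation; the toy sketch below reproduces it with `np.interp`, and the key-frame indices are hard-coded instead of coming from `sample_key_frames`:

```python
import numpy as np

num_frames, loc_velocity, img_w = 16, 12.0, 224
key_inds = [0, 5, 11, 15]          # stand-in for what sample_key_frames would return

# random-walk the x coordinate of one box centre across the key frames
ctr_x = [100.0]
for i in range(len(key_inds) - 1):
    step = loc_velocity * (key_inds[i + 1] - key_inds[i])
    ctr_x.append(float(np.clip(ctr_x[-1] + np.random.uniform(-step, step), 0, img_w - 1)))

# fill every frame between the key frames by linear interpolation
ctr_x_all = np.interp(np.arange(num_frames), key_inds, ctr_x)
assert ctr_x_all.shape == (num_frames,)
```
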
|
576 |
+
|
577 |
+
@staticmethod
|
578 |
+
def gaussian_movement(box_shapes: np.ndarray,
|
579 |
+
img_wh: tuple,
|
580 |
+
num_trajs: int,
|
581 |
+
size_velocity: float,
|
582 |
+
num_frames: int,
|
583 |
+
key_frame_probs: List[float]) -> np.ndarray:
|
584 |
+
""" Simulate the object movement.
|
585 |
+
|
586 |
+
Args:
|
587 |
+
|
588 |
+
Returns
|
589 |
+
all_boxes (np.ndarray): the generated box trajectory, in shpae
|
590 |
+
of [N_traj, N_frame, 4].
|
591 |
+
|
592 |
+
"""
|
593 |
+
|
594 |
+
def create_traj(box_shapes):
|
595 |
+
w = img_wh[0]
|
596 |
+
h = img_wh[1]
|
597 |
+
#print("gaussian",w,h)
|
598 |
+
|
599 |
+
n_points = 48 # how many points to create trajectory
|
600 |
+
sigma = 8 # bigger sigma -> smoother trajectory
|
601 |
+
|
602 |
+
# simulate trajectory points
|
603 |
+
#x = np.random.uniform(0,112,n_points)
|
604 |
+
#y = np.random.uniform(0,112,n_points)
|
605 |
+
|
606 |
+
# for 112 x 112
|
607 |
+
x = np.random.uniform(1+box_shapes[0]/2,w-1-box_shapes[0]/2,n_points)
|
608 |
+
y = np.random.uniform(1+box_shapes[1]/2,h-1-box_shapes[1]/2,n_points)
|
609 |
+
|
610 |
+
# for 224x 224
|
611 |
+
# x = np.random.uniform(0,112,n_points)
|
612 |
+
# y = np.random.uniform(0,112,n_points)
|
613 |
+
|
614 |
+
# smooth trajectory
|
615 |
+
xk = gaussian_filter1d(x, sigma=sigma, mode='reflect')
|
616 |
+
yk = gaussian_filter1d(y, sigma=sigma, mode='reflect')
|
617 |
+
|
618 |
+
# normalize and random scale
|
619 |
+
xkk = (xk -xk.min())
|
620 |
+
xkk /= xkk.max()
|
621 |
+
ykk = (yk -yk.min())
|
622 |
+
ykk /= ykk.max()
|
623 |
+
|
624 |
+
#scaling_factor = np.random.randint(20,90)
|
625 |
+
scaling_factor = np.random.randint(40,180)
|
626 |
+
xkk*=scaling_factor # randomize
|
627 |
+
ykk*=scaling_factor # randomize
|
628 |
+
|
629 |
+
|
630 |
+
# random translate and clip
|
631 |
+
translation_factor_x = np.random.randint(0,w-scaling_factor)
|
632 |
+
translation_factor_y = np.random.randint(0,h-scaling_factor)
|
633 |
+
tr_x = xkk + translation_factor_x
|
634 |
+
tr_y = ykk + translation_factor_y
|
635 |
+
|
636 |
+
tr_x = np.clip(tr_x,0,w-1)
|
637 |
+
tr_y = np.clip(tr_y,0,h-1)
|
638 |
+
|
639 |
+
# sample 16 points from trajectory with linear spacing
|
640 |
+
idxs = np.round(np.linspace(0, tr_x.shape[0]-1, num=16)).astype(int)
|
641 |
+
x_f = tr_x[idxs].astype(int)
|
642 |
+
y_f = tr_y[idxs].astype(int)
|
643 |
+
#print(x_f.shape,y_f.shape)
|
644 |
+
traj = np.column_stack((x_f,y_f))
|
645 |
+
traj = np.expand_dims(traj, axis=1)
|
646 |
+
return traj
|
647 |
+
|
648 |
+
# Step 1 create a non-linear trajectory
|
649 |
+
#print(" number of rois",num_trajs,box_shapes.shape)
|
650 |
+
ctr_pts_list = []
|
651 |
+
for i in range(num_trajs):
|
652 |
+
ctr_pts_list.append(create_traj(box_shapes[i]))
|
653 |
+
ctr_pts = np.concatenate(ctr_pts_list, axis=1)
|
654 |
+
#print("all center points guassian ",ctr_pts,ctr_pts.shape)
|
655 |
+
|
656 |
+
# Step 2 create box shapes for the starting location
|
657 |
+
|
658 |
+
boxes_list = []
|
659 |
+
for i in range(num_trajs):
|
660 |
+
x1, y1 = ctr_pts[0][i][0], ctr_pts[0][i][1]
|
661 |
+
box = np.concatenate((
|
662 |
+
(x1 - box_shapes[i, 0]/2).reshape(-1, 1),
|
663 |
+
(y1 - box_shapes[i, 1]/2).reshape(-1, 1),
|
664 |
+
(x1 + box_shapes[i, 0]/2).reshape(-1, 1),
|
665 |
+
(y1 + box_shapes[i, 1]/2).reshape(-1, 1)),
|
666 |
+
axis=1)
|
667 |
+
boxes_list.append(box)
|
668 |
+
|
669 |
+
boxes= np.concatenate(boxes_list, axis=0)
|
670 |
+
box_sizes = (boxes[:, 2:4] - boxes[:, 0:2])
|
671 |
+
#print("bboxes guassian ",boxes,boxes.shape)
|
672 |
+
#print("guassian box sizes = ",box_sizes,box_sizes.shape)
|
673 |
+
|
674 |
+
# Step 3, sample key frames for shape changes
|
675 |
+
size_key_inds = sample_key_frames(num_frames, key_frame_probs)
|
676 |
+
# Step 4, setup shape in different key frames
|
677 |
+
box_sizes_list = [np.expand_dims(box_sizes, axis=0)]
|
678 |
+
for i in range(len(size_key_inds) - 1):
|
679 |
+
if size_velocity > 0:
|
680 |
+
index_diff = size_key_inds[i + 1] - size_key_inds[i]
|
681 |
+
scales = np.random.uniform(low=-size_velocity * index_diff,
|
682 |
+
high=size_velocity * index_diff,
|
683 |
+
size=box_sizes.shape)
|
684 |
+
scales = np.exp(scales)
|
685 |
+
box_sizes = box_sizes * scales
|
686 |
+
box_sizes_list.append(np.expand_dims(box_sizes, axis=0))
|
687 |
+
box_sizes = np.concatenate(box_sizes_list, axis=0)
|
688 |
+
# print("box sizes before interpolation",box_sizes)
|
689 |
+
box_sizes = extend_key_frame_to_all(box_sizes, size_key_inds, 'random')
|
690 |
+
#print("box sizes after interpolation",box_sizes)
|
691 |
+
|
692 |
+
# Step 5, construct boxes in key frames
|
693 |
+
all_boxes = np.concatenate((ctr_pts - box_sizes * 0.5,
|
694 |
+
ctr_pts + box_sizes * 0.5), axis=2)
|
695 |
+
# all_boxes[..., 0::2] = np.clip(all_boxes[..., 0::2], 0, img_wh[0])
|
696 |
+
# all_boxes[..., 1::2] = np.clip(all_boxes[..., 1::2], 0, img_wh[1])
|
697 |
+
all_boxes = all_boxes.transpose((1, 0, 2))
|
698 |
+
return all_boxes,boxes
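
The 'gaussian' motion type builds a smooth non-linear centre trajectory by strongly low-pass filtering uniform random points with `gaussian_filter1d`, rescaling and translating the curve, and finally subsampling 16 per-frame positions. A stripped-down sketch of just that trajectory construction (fixed 224x224 frame, no box bookkeeping):

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

w = h = 224
n_points, sigma = 48, 8            # larger sigma -> smoother trajectory

x = gaussian_filter1d(np.random.uniform(0, w, n_points), sigma=sigma, mode='reflect')
y = gaussian_filter1d(np.random.uniform(0, h, n_points), sigma=sigma, mode='reflect')

# normalise to [0, 1], then randomly rescale and translate inside the frame
scale = np.random.randint(40, 180)
x = (x - x.min()) / (x.max() - x.min()) * scale + np.random.randint(0, w - scale)
y = (y - y.min()) / (y.max() - y.min()) * scale + np.random.randint(0, h - scale)

# sample 16 evenly spaced points as the per-frame centre positions
idxs = np.round(np.linspace(0, n_points - 1, num=16)).astype(int)
traj = np.column_stack((x[idxs], y[idxs]))   # shape (16, 2)
```
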
|
699 |
+
|
700 |
+
def __call__(self,img_tuple):
|
701 |
+
#def get_transform_param(self, data: List[np.ndarray], *args, **kwargs):
|
702 |
+
""" Generate the transformation parameters.
|
703 |
+
|
704 |
+
Args:
|
705 |
+
data (List[np.ndarray]): list of image array, each element is in
|
706 |
+
a shape of [H, W, 3]
|
707 |
+
|
708 |
+
Returns:
|
709 |
+
params (dict): a dict that contains necessary transformation
|
710 |
+
params, which include:
|
711 |
+
'patches': list of image patches (np.ndarray)
|
712 |
+
'alphas': list of alpha mask, same size and shape as patches.
|
713 |
+
'traj_rois': the trajectory position, in shape of
|
714 |
+
[N_traj, N_frame, 4]
|
715 |
+
'traj_labels': whether the patches have been pasted on some
|
716 |
+
specific frames, in shape of [N_traj, N_frame]
|
717 |
+
"""
|
718 |
+
|
719 |
+
#print("with tubelets")
|
720 |
+
|
721 |
+
img_group, label = img_tuple
|
722 |
+
|
723 |
+
#print("before length data",len(img_group),img_group[0].size)
|
724 |
+
|
725 |
+
new_data = [np.array(img) for img in img_group]
|
726 |
+
|
727 |
+
#print("after length data",len(new_data),new_data[0].shape)
|
728 |
+
|
729 |
+
data_1 = new_data # Step 1, generate the trajectories.
|
730 |
+
|
731 |
+
h, w = data_1[0].shape[0:2]
|
732 |
+
|
733 |
+
#print("motion type and size_velocity", self.motion_type,self.size_velocity)
|
734 |
+
#print(" patch transformation and rotation velocity =",self.patch_transformation,self.rot_velocity)
|
735 |
+
if self.motion_type == 'linear' :
|
736 |
+
|
737 |
+
boxes = self.region_sampler.sample(data_1)
|
738 |
+
|
739 |
+
traj_rois = self.rectangle_movement(boxes, (w, h),
|
740 |
+
self.loc_velocity,
|
741 |
+
self.size_velocity,
|
742 |
+
len(data_1),
|
743 |
+
self.key_frame_probs)
|
744 |
+
# gaussian
|
745 |
+
elif self.motion_type == 'gaussian' :
|
746 |
+
|
747 |
+
box_shapes = self.region_sampler.sample_box_shapes(data_1)
|
748 |
+
|
749 |
+
traj_rois,boxes = self.gaussian_movement(box_shapes, (w, h),
|
750 |
+
self.region_sampler.num_rois,
|
751 |
+
self.size_velocity,
|
752 |
+
len(data_1),
|
753 |
+
self.key_frame_probs)
|
754 |
+
|
755 |
+
#print("gaussian rois",traj_rois.shape)
|
756 |
+
traj_rois = np.round(traj_rois).astype(int)
|
757 |
+
# traj_rois[..., 0::2] = np.clip(traj_rois[..., 0::2], 0, w)
|
758 |
+
# traj_rois[..., 1::2] = np.clip(traj_rois[..., 1::2], 0, h)
|
759 |
+
|
760 |
+
# Step 2, crop the patches and prepare the alpha masks.
|
761 |
+
if not self.use_objects:
|
762 |
+
|
763 |
+
#print(" pasting patches")
|
764 |
+
patches_list, alphas_list, label_list = self.paste_patches(data_1,traj_rois,boxes)
|
765 |
+
else:
|
766 |
+
#print(" pasting objects")
|
767 |
+
patches_list, alphas_list, label_list = self.paste_objects(data_1,traj_rois,boxes)
|
768 |
+
|
769 |
+
|
770 |
+
|
771 |
+
transforms_dict = dict(
|
772 |
+
traj_rois=traj_rois,
|
773 |
+
patches=patches_list,
|
774 |
+
alphas=alphas_list,
|
775 |
+
traj_labels=label_list,
|
776 |
+
rot_velocity = self.rot_velocity,
|
777 |
+
patch_transformation = self.patch_transformation,
|
778 |
+
key_frame_probs = self.key_frame_probs
|
779 |
+
)
|
780 |
+
|
781 |
+
output_data = self._apply_image( new_data,transforms_dict)
|
782 |
+
|
783 |
+
ret_data = [Image.fromarray(img) for img in output_data]
|
784 |
+
|
785 |
+
return ret_data, label, traj_rois
|
transforms.py
ADDED
@@ -0,0 +1,206 @@
import torch
import torchvision.transforms.functional as F
import warnings
import random
import numpy as np
import torchvision
from PIL import Image, ImageOps
import numbers


class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        w, h = img_group[0].size
        th, tw = self.size

        out_images = list()

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        for img in img_group:
            assert(img.size[0] == w and img.size[1] == h)
            if w == tw and h == th:
                out_images.append(img)
            else:
                out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return (out_images, label)


class GroupCenterCrop(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor_tuple):
        tensor, label = tensor_tuple
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))

        # TODO: make efficient
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)

        return (tensor, label)


class GroupGrayScale(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.Grayscale(size)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_tuple):
        img_group, label = img_tuple
        return ([self.worker(img) for img in img_group], label)


class GroupMultiScaleCrop(object):

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
        return (ret_img_group, label)

    def _sample_crop_size(self, im_size):
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter
        return ret


class Stack(object):

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_tuple):
        img_group, label = img_tuple

        if img_group[0].mode == 'L':
            return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
        elif img_group[0].mode == 'RGB':
            if self.roll:
                return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
            else:
                return (np.concatenate(img_group, axis=2), label)


class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic_tuple):
        pic, label = pic_tuple

        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        return (img.float().div(255.) if self.div else img.float(), label)


class IdentityTransform(object):

    def __call__(self, data):
        return data
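
All of the Group* transforms pass an `(img_group, label)` tuple through, so they chain directly with `torchvision.transforms.Compose`. A minimal training-style pipeline for a list of PIL frames, assuming this module is importable as `transforms` from the repo root (the normalisation statistics below are just illustrative ImageNet values):

```python
import torchvision
from PIL import Image
from transforms import (GroupMultiScaleCrop, Stack, ToTorchFormatTensor,
                        GroupNormalize)

pipeline = torchvision.transforms.Compose([
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),
    Stack(roll=False),
    ToTorchFormatTensor(div=True),
    GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

frames = [Image.new('RGB', (320, 256)) for _ in range(16)]   # 16-frame dummy clip
clip, label = pipeline((frames, 0))
print(clip.shape)   # torch.Size([48, 224, 224]): 16 frames stacked along channels
```
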
utils_mae.py
ADDED
@@ -0,0 +1,536 @@
import io
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
from collections import defaultdict, deque
|
7 |
+
import datetime
|
8 |
+
import numpy as np
|
9 |
+
from timm.utils import get_state_dict
|
10 |
+
from torch.utils.data._utils.collate import default_collate
|
11 |
+
from pathlib import Path
|
12 |
+
import subprocess
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
#from torch._six import inf
|
16 |
+
from torch import inf
|
17 |
+
import random
|
18 |
+
|
19 |
+
from tensorboardX import SummaryWriter
|
20 |
+
|
21 |
+
|
22 |
+
class SmoothedValue(object):
|
23 |
+
"""Track a series of values and provide access to smoothed values over a
|
24 |
+
window or the global series average.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, window_size=20, fmt=None):
|
28 |
+
if fmt is None:
|
29 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
30 |
+
self.deque = deque(maxlen=window_size)
|
31 |
+
self.total = 0.0
|
32 |
+
self.count = 0
|
33 |
+
self.fmt = fmt
|
34 |
+
|
35 |
+
def update(self, value, n=1):
|
36 |
+
self.deque.append(value)
|
37 |
+
self.count += n
|
38 |
+
self.total += value * n
|
39 |
+
|
40 |
+
def synchronize_between_processes(self):
|
41 |
+
"""
|
42 |
+
Warning: does not synchronize the deque!
|
43 |
+
"""
|
44 |
+
if not is_dist_avail_and_initialized():
|
45 |
+
return
|
46 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
47 |
+
dist.barrier()
|
48 |
+
dist.all_reduce(t)
|
49 |
+
t = t.tolist()
|
50 |
+
self.count = int(t[0])
|
51 |
+
self.total = t[1]
|
52 |
+
|
53 |
+
@property
|
54 |
+
def median(self):
|
55 |
+
d = torch.tensor(list(self.deque))
|
56 |
+
return d.median().item()
|
57 |
+
|
58 |
+
@property
|
59 |
+
def avg(self):
|
60 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
61 |
+
return d.mean().item()
|
62 |
+
|
63 |
+
@property
|
64 |
+
def global_avg(self):
|
65 |
+
return self.total / self.count
|
66 |
+
|
67 |
+
@property
|
68 |
+
def max(self):
|
69 |
+
return max(self.deque)
|
70 |
+
|
71 |
+
@property
|
72 |
+
def value(self):
|
73 |
+
return self.deque[-1]
|
74 |
+
|
75 |
+
def __str__(self):
|
76 |
+
return self.fmt.format(
|
77 |
+
median=self.median,
|
78 |
+
avg=self.avg,
|
79 |
+
global_avg=self.global_avg,
|
80 |
+
max=self.max,
|
81 |
+
value=self.value)
|
82 |
+
|
83 |
+
|
84 |
+
class MetricLogger(object):
|
85 |
+
def __init__(self, delimiter="\t"):
|
86 |
+
self.meters = defaultdict(SmoothedValue)
|
87 |
+
self.delimiter = delimiter
|
88 |
+
|
89 |
+
def update(self, **kwargs):
|
90 |
+
for k, v in kwargs.items():
|
91 |
+
if v is None:
|
92 |
+
continue
|
93 |
+
if isinstance(v, torch.Tensor):
|
94 |
+
v = v.item()
|
95 |
+
assert isinstance(v, (float, int))
|
96 |
+
self.meters[k].update(v)
|
97 |
+
|
98 |
+
def __getattr__(self, attr):
|
99 |
+
if attr in self.meters:
|
100 |
+
return self.meters[attr]
|
101 |
+
if attr in self.__dict__:
|
102 |
+
return self.__dict__[attr]
|
103 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
104 |
+
type(self).__name__, attr))
|
105 |
+
|
106 |
+
def __str__(self):
|
107 |
+
loss_str = []
|
108 |
+
for name, meter in self.meters.items():
|
109 |
+
loss_str.append(
|
110 |
+
"{}: {}".format(name, str(meter))
|
111 |
+
)
|
112 |
+
return self.delimiter.join(loss_str)
|
113 |
+
|
114 |
+
def synchronize_between_processes(self):
|
115 |
+
for meter in self.meters.values():
|
116 |
+
meter.synchronize_between_processes()
|
117 |
+
|
118 |
+
def add_meter(self, name, meter):
|
119 |
+
self.meters[name] = meter
|
120 |
+
|
121 |
+
def log_every(self, iterable, print_freq, header=None):
|
122 |
+
i = 0
|
123 |
+
if not header:
|
124 |
+
header = ''
|
125 |
+
start_time = time.time()
|
126 |
+
end = time.time()
|
127 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
128 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
129 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
130 |
+
log_msg = [
|
131 |
+
header,
|
132 |
+
'[{0' + space_fmt + '}/{1}]',
|
133 |
+
'eta: {eta}',
|
134 |
+
'{meters}',
|
135 |
+
'time: {time}',
|
136 |
+
'data: {data}'
|
137 |
+
]
|
138 |
+
if torch.cuda.is_available():
|
139 |
+
log_msg.append('max mem: {memory:.0f}')
|
140 |
+
log_msg = self.delimiter.join(log_msg)
|
141 |
+
MB = 1024.0 * 1024.0
|
142 |
+
for obj in iterable:
|
143 |
+
data_time.update(time.time() - end)
|
144 |
+
yield obj
|
145 |
+
iter_time.update(time.time() - end)
|
146 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
147 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
148 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
149 |
+
if torch.cuda.is_available():
|
150 |
+
print(log_msg.format(
|
151 |
+
i, len(iterable), eta=eta_string,
|
152 |
+
meters=str(self),
|
153 |
+
time=str(iter_time), data=str(data_time),
|
154 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
155 |
+
else:
|
156 |
+
print(log_msg.format(
|
157 |
+
i, len(iterable), eta=eta_string,
|
158 |
+
meters=str(self),
|
159 |
+
time=str(iter_time), data=str(data_time)))
|
160 |
+
i += 1
|
161 |
+
end = time.time()
|
162 |
+
total_time = time.time() - start_time
|
163 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
164 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
165 |
+
header, total_time_str, total_time / len(iterable)))
|
166 |
+
|
167 |
+
|
168 |
+
class TensorboardLogger(object):
|
169 |
+
def __init__(self, log_dir):
|
170 |
+
self.writer = SummaryWriter(logdir=log_dir)
|
171 |
+
self.step = 0
|
172 |
+
|
173 |
+
def set_step(self, step=None):
|
174 |
+
if step is not None:
|
175 |
+
self.step = step
|
176 |
+
else:
|
177 |
+
self.step += 1
|
178 |
+
|
179 |
+
def update(self, head='scalar', step=None, **kwargs):
|
180 |
+
for k, v in kwargs.items():
|
181 |
+
if v is None:
|
182 |
+
continue
|
183 |
+
if isinstance(v, torch.Tensor):
|
184 |
+
v = v.item()
|
185 |
+
assert isinstance(v, (float, int))
|
186 |
+
self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
|
187 |
+
|
188 |
+
def flush(self):
|
189 |
+
self.writer.flush()
|
190 |
+
|
191 |
+
def seed_worker(worker_id):
|
192 |
+
worker_seed = torch.initial_seed() % 2**32
|
193 |
+
np.random.seed(worker_seed)
|
194 |
+
random.seed(worker_seed)
|
195 |
+
|
196 |
+
def _load_checkpoint_for_ema(model_ema, checkpoint):
|
197 |
+
"""
|
198 |
+
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
|
199 |
+
"""
|
200 |
+
mem_file = io.BytesIO()
|
201 |
+
torch.save(checkpoint, mem_file)
|
202 |
+
mem_file.seek(0)
|
203 |
+
model_ema._load_checkpoint(mem_file)
|
204 |
+
|
205 |
+
|
206 |
+
def setup_for_distributed(is_master):
|
207 |
+
"""
|
208 |
+
This function disables printing when not in master process
|
209 |
+
"""
|
210 |
+
import builtins as __builtin__
|
211 |
+
builtin_print = __builtin__.print
|
212 |
+
|
213 |
+
def print(*args, **kwargs):
|
214 |
+
force = kwargs.pop('force', False)
|
215 |
+
if is_master or force:
|
216 |
+
builtin_print(*args, **kwargs)
|
217 |
+
|
218 |
+
__builtin__.print = print
|
219 |
+
|
220 |
+
|
221 |
+
def is_dist_avail_and_initialized():
|
222 |
+
if not dist.is_available():
|
223 |
+
return False
|
224 |
+
if not dist.is_initialized():
|
225 |
+
return False
|
226 |
+
return True
|
227 |
+
|
228 |
+
|
229 |
+
def get_world_size():
|
230 |
+
if not is_dist_avail_and_initialized():
|
231 |
+
return 1
|
232 |
+
return dist.get_world_size()
|
233 |
+
|
234 |
+
|
235 |
+
def get_rank():
|
236 |
+
if not is_dist_avail_and_initialized():
|
237 |
+
return 0
|
238 |
+
return dist.get_rank()
|
239 |
+
|
240 |
+
|
241 |
+
def is_main_process():
|
242 |
+
return get_rank() == 0
|
243 |
+
|
244 |
+
|
245 |
+
def save_on_master(*args, **kwargs):
|
246 |
+
if is_main_process():
|
247 |
+
torch.save(*args, **kwargs)
|
248 |
+
|
249 |
+
|
250 |
+
def init_distributed_mode(args):
|
251 |
+
if args.dist_on_itp:
|
252 |
+
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
253 |
+
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
254 |
+
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
255 |
+
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
256 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
257 |
+
os.environ['RANK'] = str(args.rank)
|
258 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
259 |
+
elif 'SLURM_PROCID' in os.environ:
|
260 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
261 |
+
args.gpu = int(os.environ['SLURM_LOCALID'])
|
262 |
+
args.world_size = int(os.environ['SLURM_NTASKS'])
|
263 |
+
os.environ['RANK'] = str(args.rank)
|
264 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
265 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
266 |
+
|
267 |
+
node_list = os.environ['SLURM_NODELIST']
|
268 |
+
addr = subprocess.getoutput(
|
269 |
+
f'scontrol show hostname {node_list} | head -n1')
|
270 |
+
if 'MASTER_ADDR' not in os.environ:
|
271 |
+
os.environ['MASTER_ADDR'] = addr
|
272 |
+
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
273 |
+
args.rank = int(os.environ["RANK"])
|
274 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
275 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
276 |
+
else:
|
277 |
+
print('Not using distributed mode')
|
278 |
+
args.distributed = False
|
279 |
+
return
|
280 |
+
|
281 |
+
args.distributed = True
|
282 |
+
|
283 |
+
torch.cuda.set_device(args.gpu)
|
284 |
+
args.dist_backend = 'nccl'
|
285 |
+
print('| distributed init (rank {}): {}, gpu {}'.format(
|
286 |
+
args.rank, args.dist_url, args.gpu), flush=True)
|
287 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
288 |
+
world_size=args.world_size, rank=args.rank)
|
289 |
+
torch.distributed.barrier()
|
290 |
+
# assert torch.distributed.is_initialized()
|
291 |
+
setup_for_distributed(args.rank == 0)
|
292 |
+
|
293 |
+
|
294 |
+
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
|
295 |
+
missing_keys = []
|
296 |
+
unexpected_keys = []
|
297 |
+
error_msgs = []
|
298 |
+
metadata = getattr(state_dict, '_metadata', None)
|
299 |
+
state_dict = state_dict.copy()
|
300 |
+
if metadata is not None:
|
301 |
+
state_dict._metadata = metadata
|
302 |
+
|
303 |
+
def load(module, prefix=''):
|
304 |
+
local_metadata = {} if metadata is None else metadata.get(
|
305 |
+
prefix[:-1], {})
|
306 |
+
module._load_from_state_dict(
|
307 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
308 |
+
for name, child in module._modules.items():
|
309 |
+
if child is not None:
|
310 |
+
load(child, prefix + name + '.')
|
311 |
+
|
312 |
+
load(model, prefix=prefix)
|
313 |
+
|
314 |
+
warn_missing_keys = []
|
315 |
+
ignore_missing_keys = []
|
316 |
+
for key in missing_keys:
|
317 |
+
keep_flag = True
|
318 |
+
for ignore_key in ignore_missing.split('|'):
|
319 |
+
if ignore_key in key:
|
320 |
+
keep_flag = False
|
321 |
+
break
|
322 |
+
if keep_flag:
|
323 |
+
warn_missing_keys.append(key)
|
324 |
+
else:
|
325 |
+
ignore_missing_keys.append(key)
|
326 |
+
|
327 |
+
missing_keys = warn_missing_keys
|
328 |
+
|
329 |
+
if len(missing_keys) > 0:
|
330 |
+
print("Weights of {} not initialized from pretrained model: {}".format(
|
331 |
+
model.__class__.__name__, missing_keys))
|
332 |
+
if len(unexpected_keys) > 0:
|
333 |
+
print("Weights from pretrained model not used in {}: {}".format(
|
334 |
+
model.__class__.__name__, unexpected_keys))
|
335 |
+
if len(ignore_missing_keys) > 0:
|
336 |
+
print("Ignored weights of {} not initialized from pretrained model: {}".format(
|
337 |
+
model.__class__.__name__, ignore_missing_keys))
|
338 |
+
if len(error_msgs) > 0:
|
339 |
+
print('\n'.join(error_msgs))
|
340 |
+
|
341 |
+
|
342 |
+
class NativeScalerWithGradNormCount:
|
343 |
+
state_dict_key = "amp_scaler"
|
344 |
+
|
345 |
+
def __init__(self):
|
346 |
+
self._scaler = torch.cuda.amp.GradScaler()
|
347 |
+
|
348 |
+
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
|
349 |
+
self._scaler.scale(loss).backward(create_graph=create_graph)
|
350 |
+
if update_grad:
|
351 |
+
if clip_grad is not None:
|
352 |
+
assert parameters is not None
|
353 |
+
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
354 |
+
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
355 |
+
else:
|
356 |
+
self._scaler.unscale_(optimizer)
|
357 |
+
norm = get_grad_norm_(parameters)
|
358 |
+
self._scaler.step(optimizer)
|
359 |
+
self._scaler.update()
|
360 |
+
else:
|
361 |
+
norm = None
|
362 |
+
return norm
|
363 |
+
|
364 |
+
def state_dict(self):
|
365 |
+
return self._scaler.state_dict()
|
366 |
+
|
367 |
+
def load_state_dict(self, state_dict):
|
368 |
+
self._scaler.load_state_dict(state_dict)
|
369 |
+
|
370 |
+
|
371 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
372 |
+
if isinstance(parameters, torch.Tensor):
|
373 |
+
parameters = [parameters]
|
374 |
+
parameters = [p for p in parameters if p.grad is not None]
|
375 |
+
norm_type = float(norm_type)
|
376 |
+
if len(parameters) == 0:
|
377 |
+
return torch.tensor(0.)
|
378 |
+
device = parameters[0].grad.device
|
379 |
+
if norm_type == inf:
|
380 |
+
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
381 |
+
else:
|
382 |
+
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
383 |
+
return total_norm
|
384 |
+
|
385 |
+
|
386 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
|
387 |
+
start_warmup_value=0, warmup_steps=-1):
|
388 |
+
warmup_schedule = np.array([])
|
389 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
390 |
+
if warmup_steps > 0:
|
391 |
+
warmup_iters = warmup_steps
|
392 |
+
print("Set warmup steps = %d" % warmup_iters)
|
393 |
+
if warmup_epochs > 0:
|
394 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
395 |
+
|
396 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
397 |
+
schedule = np.array(
|
398 |
+
[final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
|
399 |
+
|
400 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
401 |
+
|
402 |
+
assert len(schedule) == epochs * niter_per_ep
|
403 |
+
return schedule
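
`cosine_scheduler` above returns one value per optimisation step: an optional linear warmup followed by a cosine decay from `base_value` to `final_value`. The training scripts index the returned array with the global step; a small sketch with made-up hyper-parameters, assuming this module is importable as `utils_mae`:

```python
from utils_mae import cosine_scheduler

epochs, niter_per_ep = 100, 250                    # 25,000 steps in total
lr_schedule = cosine_scheduler(
    base_value=1.5e-4, final_value=1e-6,
    epochs=epochs, niter_per_ep=niter_per_ep,
    warmup_epochs=5, start_warmup_value=0.0,
)
assert len(lr_schedule) == epochs * niter_per_ep

global_step = 7 * niter_per_ep + 42                # e.g. epoch 7, iteration 42
lr_for_this_step = lr_schedule[global_step]
```
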
|
404 |
+
|
405 |
+
|
406 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
|
407 |
+
output_dir = Path(args.output_dir)
|
408 |
+
epoch_name = str(epoch)
|
409 |
+
if loss_scaler is not None:
|
410 |
+
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
|
411 |
+
for checkpoint_path in checkpoint_paths:
|
412 |
+
to_save = {
|
413 |
+
'model': model_without_ddp.state_dict(),
|
414 |
+
'optimizer': optimizer.state_dict(),
|
415 |
+
'epoch': epoch,
|
416 |
+
'scaler': loss_scaler.state_dict(),
|
417 |
+
'args': args,
|
418 |
+
}
|
419 |
+
|
420 |
+
if model_ema is not None:
|
421 |
+
to_save['model_ema'] = get_state_dict(model_ema)
|
422 |
+
|
423 |
+
save_on_master(to_save, checkpoint_path)
|
424 |
+
else:
|
425 |
+
client_state = {'epoch': epoch}
|
426 |
+
if model_ema is not None:
|
427 |
+
client_state['model_ema'] = get_state_dict(model_ema)
|
428 |
+
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
|
429 |
+
|
430 |
+
|
431 |
+
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
|
432 |
+
output_dir = Path(args.output_dir)
|
433 |
+
if loss_scaler is not None:
|
434 |
+
# torch.amp
|
435 |
+
if args.auto_resume and len(args.resume) == 0:
|
436 |
+
import glob
|
437 |
+
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
|
438 |
+
latest_ckpt = -1
|
439 |
+
for ckpt in all_checkpoints:
|
440 |
+
t = ckpt.split('-')[-1].split('.')[0]
|
441 |
+
if t.isdigit():
|
442 |
+
latest_ckpt = max(int(t), latest_ckpt)
|
443 |
+
if latest_ckpt >= 0:
|
444 |
+
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
|
445 |
+
print("Auto resume checkpoint: %s" % args.resume)
|
446 |
+
|
447 |
+
if args.resume:
|
448 |
+
if args.resume.startswith('https'):
|
449 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
450 |
+
args.resume, map_location='cpu', check_hash=True)
|
451 |
+
else:
|
452 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
453 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
454 |
+
print("Resume checkpoint %s" % args.resume)
|
455 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
456 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
457 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
458 |
+
if hasattr(args, 'model_ema') and args.model_ema:
|
459 |
+
_load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
|
460 |
+
if 'scaler' in checkpoint:
|
461 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
462 |
+
print("With optim & sched!")
|
463 |
+
else:
|
464 |
+
# deepspeed, only support '--auto_resume'.
|
465 |
+
if args.auto_resume:
|
466 |
+
import glob
|
467 |
+
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
|
468 |
+
latest_ckpt = -1
|
469 |
+
for ckpt in all_checkpoints:
|
470 |
+
t = ckpt.split('-')[-1].split('.')[0]
|
471 |
+
if t.isdigit():
|
472 |
+
latest_ckpt = max(int(t), latest_ckpt)
|
473 |
+
if latest_ckpt >= 0:
|
474 |
+
args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
|
475 |
+
print("Auto resume checkpoint: %d" % latest_ckpt)
|
476 |
+
_, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
|
477 |
+
args.start_epoch = client_states['epoch'] + 1
|
478 |
+
if model_ema is not None:
|
479 |
+
if args.model_ema:
|
480 |
+
_load_checkpoint_for_ema(model_ema, client_states['model_ema'])
|
481 |
+
|
482 |
+
|
483 |
+
def create_ds_config(args):
|
484 |
+
args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
|
485 |
+
with open(args.deepspeed_config, mode="w") as writer:
|
486 |
+
ds_config = {
|
487 |
+
"train_batch_size": args.batch_size * args.update_freq * get_world_size(),
|
488 |
+
"train_micro_batch_size_per_gpu": args.batch_size,
|
489 |
+
"steps_per_print": 1000,
|
490 |
+
"optimizer": {
|
491 |
+
"type": "Adam",
|
492 |
+
"adam_w_mode": True,
|
493 |
+
"params": {
|
494 |
+
"lr": args.lr,
|
495 |
+
"weight_decay": args.weight_decay,
|
496 |
+
"bias_correction": True,
|
497 |
+
"betas": [
|
498 |
+
0.9,
|
499 |
+
0.999
|
500 |
+
],
|
501 |
+
"eps": 1e-8
|
502 |
+
}
|
503 |
+
},
|
504 |
+
"fp16": {
|
505 |
+
"enabled": True,
|
506 |
+
"loss_scale": 0,
|
507 |
+
"initial_scale_power": 7,
|
508 |
+
"loss_scale_window": 128
|
509 |
+
}
|
510 |
+
}
|
511 |
+
|
512 |
+
writer.write(json.dumps(ds_config, indent=2))
|
513 |
+
|
514 |
+
def multiple_samples_collate(batch, fold=False):
|
515 |
+
"""
|
516 |
+
Collate function for repeated augmentation. Each instance in the batch has
|
517 |
+
more than one sample.
|
518 |
+
Args:
|
519 |
+
batch (tuple or list): data batch to collate.
|
520 |
+
Returns:
|
521 |
+
(tuple): collated data batch.
|
522 |
+
"""
|
523 |
+
inputs, labels, video_idx, extra_data = zip(*batch)
|
524 |
+
inputs = [item for sublist in inputs for item in sublist]
|
525 |
+
labels = [item for sublist in labels for item in sublist]
|
526 |
+
video_idx = [item for sublist in video_idx for item in sublist]
|
527 |
+
inputs, labels, video_idx, extra_data = (
|
528 |
+
default_collate(inputs),
|
529 |
+
default_collate(labels),
|
530 |
+
default_collate(video_idx),
|
531 |
+
default_collate(extra_data),
|
532 |
+
)
|
533 |
+
if fold:
|
534 |
+
return [inputs], labels, video_idx, extra_data
|
535 |
+
else:
|
536 |
+
return inputs, labels, video_idx, extra_data
|
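
`multiple_samples_collate` flattens a 'repeated augmentation' batch, where every dataset item carries several augmented samples of the same video, before handing it to PyTorch's `default_collate`. A toy sketch of the expected shapes (the dataset item structure here is made up purely for illustration):

```python
import torch
from utils_mae import multiple_samples_collate

def fake_item(video_idx):
    # two augmented samples of the same video per dataset item
    inputs = [torch.zeros(3, 16, 224, 224) for _ in range(2)]
    labels = [5, 5]
    indices = [video_idx, video_idx]
    extra = [{}, {}]
    return inputs, labels, indices, extra

batch = [fake_item(0), fake_item(1)]
inputs, labels, video_idx, extra = multiple_samples_collate(batch)
print(inputs.shape)   # torch.Size([4, 3, 16, 224, 224]): 2 items x 2 samples each
```
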
video_transforms.py
ADDED
@@ -0,0 +1,1281 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision import transforms
|
9 |
+
|
10 |
+
from rand_augment import rand_augment_transform
|
11 |
+
from random_erasing import RandomErasing
|
12 |
+
|
13 |
+
|
14 |
+
import numbers
|
15 |
+
import PIL
|
16 |
+
import torchvision
|
17 |
+
|
18 |
+
import functional as FF
|
19 |
+
|
20 |
+
_pil_interpolation_to_str = {
|
21 |
+
Image.NEAREST: "PIL.Image.NEAREST",
|
22 |
+
Image.BILINEAR: "PIL.Image.BILINEAR",
|
23 |
+
Image.BICUBIC: "PIL.Image.BICUBIC",
|
24 |
+
Image.LANCZOS: "PIL.Image.LANCZOS",
|
25 |
+
Image.HAMMING: "PIL.Image.HAMMING",
|
26 |
+
Image.BOX: "PIL.Image.BOX",
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
31 |
+
|
32 |
+
|
33 |
+
def _pil_interp(method):
|
34 |
+
if method == "bicubic":
|
35 |
+
return Image.BICUBIC
|
36 |
+
elif method == "lanczos":
|
37 |
+
return Image.LANCZOS
|
38 |
+
elif method == "hamming":
|
39 |
+
return Image.HAMMING
|
40 |
+
else:
|
41 |
+
return Image.BILINEAR
|
42 |
+
|
43 |
+
|
44 |
+
def random_short_side_scale_jitter(
|
45 |
+
images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
|
46 |
+
):
|
47 |
+
"""
|
48 |
+
Perform a spatial short scale jittering on the given images and
|
49 |
+
corresponding boxes.
|
50 |
+
Args:
|
51 |
+
images (tensor): images to perform scale jitter. Dimension is
|
52 |
+
`num frames` x `channel` x `height` x `width`.
|
53 |
+
min_size (int): the minimal size to scale the frames.
|
54 |
+
max_size (int): the maximal size to scale the frames.
|
55 |
+
boxes (ndarray): optional. Corresponding boxes to images.
|
56 |
+
Dimension is `num boxes` x 4.
|
57 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
58 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
59 |
+
scale. If False, take a uniform sample from [min_scale, max_scale].
|
60 |
+
Returns:
|
61 |
+
(tensor): the scaled images with dimension of
|
62 |
+
`num frames` x `channel` x `new height` x `new width`.
|
63 |
+
(ndarray or None): the scaled boxes with dimension of
|
64 |
+
`num boxes` x 4.
|
65 |
+
"""
|
66 |
+
if inverse_uniform_sampling:
|
67 |
+
size = int(
|
68 |
+
round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
size = int(round(np.random.uniform(min_size, max_size)))
|
72 |
+
|
73 |
+
height = images.shape[2]
|
74 |
+
width = images.shape[3]
|
75 |
+
if (width <= height and width == size) or (
|
76 |
+
height <= width and height == size
|
77 |
+
):
|
78 |
+
return images, boxes
|
79 |
+
new_width = size
|
80 |
+
new_height = size
|
81 |
+
if width < height:
|
82 |
+
new_height = int(math.floor((float(height) / width) * size))
|
83 |
+
if boxes is not None:
|
84 |
+
boxes = boxes * float(new_height) / height
|
85 |
+
else:
|
86 |
+
new_width = int(math.floor((float(width) / height) * size))
|
87 |
+
if boxes is not None:
|
88 |
+
boxes = boxes * float(new_width) / width
|
89 |
+
|
90 |
+
return (
|
91 |
+
torch.nn.functional.interpolate(
|
92 |
+
images,
|
93 |
+
size=(new_height, new_width),
|
94 |
+
mode="bilinear",
|
95 |
+
align_corners=False,
|
96 |
+
),
|
97 |
+
boxes,
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
102 |
+
"""
|
103 |
+
Peform crop on the bounding boxes given the offsets.
|
104 |
+
Args:
|
105 |
+
boxes (ndarray or None): bounding boxes to peform crop. The dimension
|
106 |
+
is `num boxes` x 4.
|
107 |
+
x_offset (int): cropping offset in the x axis.
|
108 |
+
y_offset (int): cropping offset in the y axis.
|
109 |
+
Returns:
|
110 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
111 |
+
`num boxes` x 4.
|
112 |
+
"""
|
113 |
+
cropped_boxes = boxes.copy()
|
114 |
+
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
115 |
+
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
116 |
+
|
117 |
+
return cropped_boxes
|
118 |
+
|
119 |
+
|
120 |
+
def random_crop(images, size, boxes=None):
|
121 |
+
"""
|
122 |
+
Perform random spatial crop on the given images and corresponding boxes.
|
123 |
+
Args:
|
124 |
+
images (tensor): images to perform random crop. The dimension is
|
125 |
+
`num frames` x `channel` x `height` x `width`.
|
126 |
+
size (int): the size of height and width to crop on the image.
|
127 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
128 |
+
Dimension is `num boxes` x 4.
|
129 |
+
Returns:
|
130 |
+
cropped (tensor): cropped images with dimension of
|
131 |
+
`num frames` x `channel` x `size` x `size`.
|
132 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
133 |
+
`num boxes` x 4.
|
134 |
+
"""
|
135 |
+
if images.shape[2] == size and images.shape[3] == size:
|
136 |
+
return images
|
137 |
+
height = images.shape[2]
|
138 |
+
width = images.shape[3]
|
139 |
+
y_offset = 0
|
140 |
+
if height > size:
|
141 |
+
y_offset = int(np.random.randint(0, height - size))
|
142 |
+
x_offset = 0
|
143 |
+
if width > size:
|
144 |
+
x_offset = int(np.random.randint(0, width - size))
|
145 |
+
cropped = images[
|
146 |
+
:, :, y_offset : y_offset + size, x_offset : x_offset + size
|
147 |
+
]
|
148 |
+
|
149 |
+
cropped_boxes = (
|
150 |
+
crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
151 |
+
)
|
152 |
+
|
153 |
+
return cropped, cropped_boxes
|
154 |
+
|
155 |
+
|
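A minimal usage sketch for `random_crop` (the clip shape, the dummy box, and the `video_transforms` import path are illustrative assumptions): the same spatial window is cut from every frame, and `crop_boxes` shifts the box coordinates by the sampled offsets so they stay aligned with the crop.

```python
import numpy as np
import torch

from video_transforms import random_crop  # module name assumed from this repo's layout

frames = torch.rand(16, 3, 256, 320)            # num frames x channel x height x width
boxes = np.array([[10.0, 20.0, 200.0, 180.0]])  # one box as x1, y1, x2, y2

cropped, cropped_boxes = random_crop(frames, size=224, boxes=boxes)
print(cropped.shape)     # torch.Size([16, 3, 224, 224])
print(cropped_boxes[0])  # the original box minus the sampled (x_offset, y_offset)
```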
156 |
+
def horizontal_flip(prob, images, boxes=None):
|
157 |
+
"""
|
158 |
+
Perform horizontal flip on the given images and corresponding boxes.
|
159 |
+
Args:
|
160 |
+
prob (float): probability to flip the images.
|
161 |
+
images (tensor): images to perform horizontal flip, the dimension is
|
162 |
+
`num frames` x `channel` x `height` x `width`.
|
163 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
164 |
+
Dimension is `num boxes` x 4.
|
165 |
+
Returns:
|
166 |
+
images (tensor): images with dimension of
|
167 |
+
`num frames` x `channel` x `height` x `width`.
|
168 |
+
flipped_boxes (ndarray or None): the flipped boxes with dimension of
|
169 |
+
`num boxes` x 4.
|
170 |
+
"""
|
171 |
+
if boxes is None:
|
172 |
+
flipped_boxes = None
|
173 |
+
else:
|
174 |
+
flipped_boxes = boxes.copy()
|
175 |
+
|
176 |
+
if np.random.uniform() < prob:
|
177 |
+
images = images.flip((-1))
|
178 |
+
|
179 |
+
if len(images.shape) == 3:
|
180 |
+
width = images.shape[2]
|
181 |
+
elif len(images.shape) == 4:
|
182 |
+
width = images.shape[3]
|
183 |
+
else:
|
184 |
+
raise NotImplementedError("Dimension does not supported")
|
185 |
+
if boxes is not None:
|
186 |
+
flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
|
187 |
+
|
188 |
+
return images, flipped_boxes
|
189 |
+
|
190 |
+
|
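A short sketch of `horizontal_flip` under the same assumptions (illustrative shapes and import path): with `prob=1.0` the flip always happens, and box x-coordinates are mirrored as `x' = width - x - 1` while y-coordinates are left untouched.

```python
import numpy as np
import torch

from video_transforms import horizontal_flip  # module name assumed from this repo's layout

frames = torch.rand(16, 3, 224, 224)
boxes = np.array([[10.0, 20.0, 100.0, 120.0]])

flipped, flipped_boxes = horizontal_flip(1.0, frames, boxes)
print(flipped_boxes)  # [[123.  20. 213. 120.]] for a width of 224
```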
191 |
+
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
192 |
+
"""
|
193 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
194 |
+
Args:
|
195 |
+
images (tensor): images to perform uniform crop. The dimension is
|
196 |
+
`num frames` x `channel` x `height` x `width`.
|
197 |
+
size (int): size of height and width to crop the images.
|
198 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
199 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
200 |
+
crop if height is larger than width.
|
201 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
202 |
+
Dimension is `num boxes` x 4.
|
203 |
+
scale_size (int): optional. If not None, resize the images to scale_size before
|
204 |
+
performing any crop.
|
205 |
+
Returns:
|
206 |
+
cropped (tensor): images with dimension of
|
207 |
+
`num frames` x `channel` x `size` x `size`.
|
208 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
209 |
+
`num boxes` x 4.
|
210 |
+
"""
|
211 |
+
assert spatial_idx in [0, 1, 2]
|
212 |
+
ndim = len(images.shape)
|
213 |
+
if ndim == 3:
|
214 |
+
images = images.unsqueeze(0)
|
215 |
+
height = images.shape[2]
|
216 |
+
width = images.shape[3]
|
217 |
+
|
218 |
+
if scale_size is not None:
|
219 |
+
if width <= height:
|
220 |
+
width, height = scale_size, int(height / width * scale_size)
|
221 |
+
else:
|
222 |
+
width, height = int(width / height * scale_size), scale_size
|
223 |
+
images = torch.nn.functional.interpolate(
|
224 |
+
images,
|
225 |
+
size=(height, width),
|
226 |
+
mode="bilinear",
|
227 |
+
align_corners=False,
|
228 |
+
)
|
229 |
+
|
230 |
+
y_offset = int(math.ceil((height - size) / 2))
|
231 |
+
x_offset = int(math.ceil((width - size) / 2))
|
232 |
+
|
233 |
+
if height > width:
|
234 |
+
if spatial_idx == 0:
|
235 |
+
y_offset = 0
|
236 |
+
elif spatial_idx == 2:
|
237 |
+
y_offset = height - size
|
238 |
+
else:
|
239 |
+
if spatial_idx == 0:
|
240 |
+
x_offset = 0
|
241 |
+
elif spatial_idx == 2:
|
242 |
+
x_offset = width - size
|
243 |
+
cropped = images[
|
244 |
+
:, :, y_offset : y_offset + size, x_offset : x_offset + size
|
245 |
+
]
|
246 |
+
cropped_boxes = (
|
247 |
+
crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
248 |
+
)
|
249 |
+
if ndim == 3:
|
250 |
+
cropped = cropped.squeeze(0)
|
251 |
+
return cropped, cropped_boxes
|
252 |
+
|
253 |
+
|
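`uniform_crop` is the usual test-time counterpart of `random_crop`: `spatial_idx` picks one of three deterministic windows along the longer side. A sketch of the common three-crop evaluation pattern (shapes and import path are assumptions):

```python
import torch

from video_transforms import uniform_crop  # module name assumed from this repo's layout

frames = torch.rand(16, 3, 256, 320)  # short side already scaled to the crop size

# spatial_idx 0/1/2 -> left/centre/right crops here, since width > height.
crops = [uniform_crop(frames, size=256, spatial_idx=i)[0] for i in range(3)]
print([tuple(c.shape) for c in crops])  # three clips of shape (16, 3, 256, 256)
```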
254 |
+
def clip_boxes_to_image(boxes, height, width):
|
255 |
+
"""
|
256 |
+
Clip an array of boxes to an image with the given height and width.
|
257 |
+
Args:
|
258 |
+
boxes (ndarray): bounding boxes to perform clipping.
|
259 |
+
Dimension is `num boxes` x 4.
|
260 |
+
height (int): given image height.
|
261 |
+
width (int): given image width.
|
262 |
+
Returns:
|
263 |
+
clipped_boxes (ndarray): the clipped boxes with dimension of
|
264 |
+
`num boxes` x 4.
|
265 |
+
"""
|
266 |
+
clipped_boxes = boxes.copy()
|
267 |
+
clipped_boxes[:, [0, 2]] = np.minimum(
|
268 |
+
width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
|
269 |
+
)
|
270 |
+
clipped_boxes[:, [1, 3]] = np.minimum(
|
271 |
+
height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
|
272 |
+
)
|
273 |
+
return clipped_boxes
|
274 |
+
|
275 |
+
|
276 |
+
def blend(images1, images2, alpha):
|
277 |
+
"""
|
278 |
+
Blend two images with a given weight alpha.
|
279 |
+
Args:
|
280 |
+
images1 (tensor): the first images to be blended, the dimension is
|
281 |
+
`num frames` x `channel` x `height` x `width`.
|
282 |
+
images2 (tensor): the second images to be blended, the dimension is
|
283 |
+
`num frames` x `channel` x `height` x `width`.
|
284 |
+
alpha (float): the blending weight.
|
285 |
+
Returns:
|
286 |
+
(tensor): blended images, the dimension is
|
287 |
+
`num frames` x `channel` x `height` x `width`.
|
288 |
+
"""
|
289 |
+
return images1 * alpha + images2 * (1 - alpha)
|
290 |
+
|
291 |
+
|
292 |
+
def grayscale(images):
|
293 |
+
"""
|
294 |
+
Get the grayscale for the input images. The channels of images should be
|
295 |
+
in order BGR.
|
296 |
+
Args:
|
297 |
+
images (tensor): the input images for getting grayscale. Dimension is
|
298 |
+
`num frames` x `channel` x `height` x `width`.
|
299 |
+
Returns:
|
300 |
+
img_gray (tensor): grayscale images, the dimension is
|
301 |
+
`num frames` x `channel` x `height` x `width`.
|
302 |
+
"""
|
303 |
+
# R -> 0.299, G -> 0.587, B -> 0.114.
|
304 |
+
img_gray = torch.tensor(images)
|
305 |
+
gray_channel = (
|
306 |
+
0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
|
307 |
+
)
|
308 |
+
img_gray[:, 0] = gray_channel
|
309 |
+
img_gray[:, 1] = gray_channel
|
310 |
+
img_gray[:, 2] = gray_channel
|
311 |
+
return img_gray
|
312 |
+
|
313 |
+
|
314 |
+
def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
|
315 |
+
"""
|
316 |
+
Perform color jittering on the input images. The channels of images
|
317 |
+
should be in order BGR.
|
318 |
+
Args:
|
319 |
+
images (tensor): images to perform color jitter. Dimension is
|
320 |
+
`num frames` x `channel` x `height` x `width`.
|
321 |
+
img_brightness (float): jitter ratio for brightness.
|
322 |
+
img_contrast (float): jitter ratio for contrast.
|
323 |
+
img_saturation (float): jitter ratio for saturation.
|
324 |
+
Returns:
|
325 |
+
images (tensor): the jittered images, the dimension is
|
326 |
+
`num frames` x `channel` x `height` x `width`.
|
327 |
+
"""
|
328 |
+
|
329 |
+
jitter = []
|
330 |
+
if img_brightness != 0:
|
331 |
+
jitter.append("brightness")
|
332 |
+
if img_contrast != 0:
|
333 |
+
jitter.append("contrast")
|
334 |
+
if img_saturation != 0:
|
335 |
+
jitter.append("saturation")
|
336 |
+
|
337 |
+
if len(jitter) > 0:
|
338 |
+
order = np.random.permutation(np.arange(len(jitter)))
|
339 |
+
for idx in range(0, len(jitter)):
|
340 |
+
if jitter[order[idx]] == "brightness":
|
341 |
+
images = brightness_jitter(img_brightness, images)
|
342 |
+
elif jitter[order[idx]] == "contrast":
|
343 |
+
images = contrast_jitter(img_contrast, images)
|
344 |
+
elif jitter[order[idx]] == "saturation":
|
345 |
+
images = saturation_jitter(img_saturation, images)
|
346 |
+
return images
|
347 |
+
|
348 |
+
|
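A usage sketch for `color_jitter` (illustrative shapes and import path): the enabled jitter types are applied in a random order, and the helpers assume BGR channel order, so RGB clips would need their channel dimension reversed first (e.g. `frames.flip(1)` for a T x C x H x W tensor).

```python
import torch

from video_transforms import color_jitter  # module name assumed from this repo's layout

frames_bgr = torch.rand(16, 3, 224, 224)  # already in BGR order for these helpers

jittered = color_jitter(
    frames_bgr, img_brightness=0.4, img_contrast=0.4, img_saturation=0.4
)
print(jittered.shape)  # torch.Size([16, 3, 224, 224])
```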
349 |
+
def brightness_jitter(var, images):
|
350 |
+
"""
|
351 |
+
Perform brightness jittering on the input images. The channels of images
|
352 |
+
should be in order BGR.
|
353 |
+
Args:
|
354 |
+
var (float): jitter ratio for brightness.
|
355 |
+
images (tensor): images to perform color jitter. Dimension is
|
356 |
+
`num frames` x `channel` x `height` x `width`.
|
357 |
+
Returns:
|
358 |
+
images (tensor): the jittered images, the dimension is
|
359 |
+
`num frames` x `channel` x `height` x `width`.
|
360 |
+
"""
|
361 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
362 |
+
|
363 |
+
img_bright = torch.zeros(images.shape)
|
364 |
+
images = blend(images, img_bright, alpha)
|
365 |
+
return images
|
366 |
+
|
367 |
+
|
368 |
+
def contrast_jitter(var, images):
|
369 |
+
"""
|
370 |
+
Perform contrast jittering on the input images. The channels of images
|
371 |
+
should be in order BGR.
|
372 |
+
Args:
|
373 |
+
var (float): jitter ratio for contrast.
|
374 |
+
images (tensor): images to perform color jitter. Dimension is
|
375 |
+
`num frames` x `channel` x `height` x `width`.
|
376 |
+
Returns:
|
377 |
+
images (tensor): the jittered images, the dimension is
|
378 |
+
`num frames` x `channel` x `height` x `width`.
|
379 |
+
"""
|
380 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
381 |
+
|
382 |
+
img_gray = grayscale(images)
|
383 |
+
img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
|
384 |
+
images = blend(images, img_gray, alpha)
|
385 |
+
return images
|
386 |
+
|
387 |
+
|
388 |
+
def saturation_jitter(var, images):
|
389 |
+
"""
|
390 |
+
Perform saturation jittering on the input images. The channels of images
|
391 |
+
should be in order BGR.
|
392 |
+
Args:
|
393 |
+
var (float): jitter ratio for saturation.
|
394 |
+
images (tensor): images to perform color jitter. Dimension is
|
395 |
+
`num frames` x `channel` x `height` x `width`.
|
396 |
+
Returns:
|
397 |
+
images (tensor): the jittered images, the dimension is
|
398 |
+
`num frames` x `channel` x `height` x `width`.
|
399 |
+
"""
|
400 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
401 |
+
img_gray = grayscale(images)
|
402 |
+
images = blend(images, img_gray, alpha)
|
403 |
+
|
404 |
+
return images
|
405 |
+
|
406 |
+
|
407 |
+
def lighting_jitter(images, alphastd, eigval, eigvec):
|
408 |
+
"""
|
409 |
+
Perform AlexNet-style PCA jitter on the given images.
|
410 |
+
Args:
|
411 |
+
images (tensor): images to perform lighting jitter. Dimension is
|
412 |
+
`num frames` x `channel` x `height` x `width`.
|
413 |
+
alphastd (float): jitter ratio for PCA jitter.
|
414 |
+
eigval (list): eigenvalues for PCA jitter.
|
415 |
+
eigvec (list[list]): eigenvectors for PCA jitter.
|
416 |
+
Returns:
|
417 |
+
out_images (tensor): the jittered images, the dimension is
|
418 |
+
`num frames` x `channel` x `height` x `width`.
|
419 |
+
"""
|
420 |
+
if alphastd == 0:
|
421 |
+
return images
|
422 |
+
# generate alpha1, alpha2, alpha3.
|
423 |
+
alpha = np.random.normal(0, alphastd, size=(1, 3))
|
424 |
+
eig_vec = np.array(eigvec)
|
425 |
+
eig_val = np.reshape(eigval, (1, 3))
|
426 |
+
rgb = np.sum(
|
427 |
+
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
|
428 |
+
axis=1,
|
429 |
+
)
|
430 |
+
out_images = torch.zeros_like(images)
|
431 |
+
if len(images.shape) == 3:
|
432 |
+
# C H W
|
433 |
+
channel_dim = 0
|
434 |
+
elif len(images.shape) == 4:
|
435 |
+
# T C H W
|
436 |
+
channel_dim = 1
|
437 |
+
else:
|
438 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
439 |
+
|
440 |
+
for idx in range(images.shape[channel_dim]):
|
441 |
+
# C H W
|
442 |
+
if len(images.shape) == 3:
|
443 |
+
out_images[idx] = images[idx] + rgb[2 - idx]
|
444 |
+
# T C H W
|
445 |
+
elif len(images.shape) == 4:
|
446 |
+
out_images[:, idx] = images[:, idx] + rgb[2 - idx]
|
447 |
+
else:
|
448 |
+
raise NotImplementedError(
|
449 |
+
f"Unsupported dimension {len(images.shape)}"
|
450 |
+
)
|
451 |
+
|
452 |
+
return out_images
|
453 |
+
|
454 |
+
|
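A sketch of `lighting_jitter` with the classic AlexNet-style PCA statistics for ImageNet RGB data (these particular eigenvalues and eigenvectors are conventional values, not defined in this file):

```python
import torch

from video_transforms import lighting_jitter  # module name assumed from this repo's layout

# Conventional AlexNet-style PCA statistics (illustrative, not from this file).
eigval = [0.2175, 0.0188, 0.0045]
eigvec = [
    [-0.5675, 0.7192, 0.4009],
    [-0.5808, -0.0045, -0.8140],
    [-0.5836, -0.6948, 0.4203],
]

frames = torch.rand(16, 3, 224, 224)
jittered = lighting_jitter(frames, alphastd=0.1, eigval=eigval, eigvec=eigvec)
print(jittered.shape)  # torch.Size([16, 3, 224, 224]); one colour shift shared by the clip
```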
455 |
+
def color_normalization(images, mean, stddev):
|
456 |
+
"""
|
457 |
+
Perform color normalization on the given images.
|
458 |
+
Args:
|
459 |
+
images (tensor): images to perform color normalization. Dimension is
|
460 |
+
`num frames` x `channel` x `height` x `width`.
|
461 |
+
mean (list): mean values for normalization.
|
462 |
+
stddev (list): standard deviations for normalization.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
out_images (tensor): the normalized images, the dimension is
|
466 |
+
`num frames` x `channel` x `height` x `width`.
|
467 |
+
"""
|
468 |
+
if len(images.shape) == 3:
|
469 |
+
assert (
|
470 |
+
len(mean) == images.shape[0]
|
471 |
+
), "channel mean not computed properly"
|
472 |
+
assert (
|
473 |
+
len(stddev) == images.shape[0]
|
474 |
+
), "channel stddev not computed properly"
|
475 |
+
elif len(images.shape) == 4:
|
476 |
+
assert (
|
477 |
+
len(mean) == images.shape[1]
|
478 |
+
), "channel mean not computed properly"
|
479 |
+
assert (
|
480 |
+
len(stddev) == images.shape[1]
|
481 |
+
), "channel stddev not computed properly"
|
482 |
+
else:
|
483 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
484 |
+
|
485 |
+
out_images = torch.zeros_like(images)
|
486 |
+
for idx in range(len(mean)):
|
487 |
+
# C H W
|
488 |
+
if len(images.shape) == 3:
|
489 |
+
out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
|
490 |
+
elif len(images.shape) == 4:
|
491 |
+
out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
|
492 |
+
else:
|
493 |
+
raise NotImplementedError(
|
494 |
+
f"Unsupported dimension {len(images.shape)}"
|
495 |
+
)
|
496 |
+
return out_images
|
497 |
+
|
498 |
+
|
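A usage sketch for `color_normalization` with the standard ImageNet statistics (the same values used as defaults further down in `transforms_imagenet_train`); shapes and import path are assumptions:

```python
import torch

from video_transforms import color_normalization  # module name assumed from this repo's layout

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

frames = torch.rand(16, 3, 224, 224)             # values in [0, 1]
normalized = color_normalization(frames, mean, std)
print(normalized.shape)                          # torch.Size([16, 3, 224, 224])
```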
499 |
+
def _get_param_spatial_crop(
|
500 |
+
scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
|
501 |
+
):
|
502 |
+
"""
|
503 |
+
Given scale, ratio, height and width, return sampled coordinates of the videos.
|
504 |
+
"""
|
505 |
+
for _ in range(num_repeat):
|
506 |
+
area = height * width
|
507 |
+
target_area = random.uniform(*scale) * area
|
508 |
+
if log_scale:
|
509 |
+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
510 |
+
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
511 |
+
else:
|
512 |
+
aspect_ratio = random.uniform(*ratio)
|
513 |
+
|
514 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
515 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
516 |
+
|
517 |
+
if np.random.uniform() < 0.5 and switch_hw:
|
518 |
+
w, h = h, w
|
519 |
+
|
520 |
+
if 0 < w <= width and 0 < h <= height:
|
521 |
+
i = random.randint(0, height - h)
|
522 |
+
j = random.randint(0, width - w)
|
523 |
+
return i, j, h, w
|
524 |
+
|
525 |
+
# Fallback to central crop
|
526 |
+
in_ratio = float(width) / float(height)
|
527 |
+
if in_ratio < min(ratio):
|
528 |
+
w = width
|
529 |
+
h = int(round(w / min(ratio)))
|
530 |
+
elif in_ratio > max(ratio):
|
531 |
+
h = height
|
532 |
+
w = int(round(h * max(ratio)))
|
533 |
+
else: # whole image
|
534 |
+
w = width
|
535 |
+
h = height
|
536 |
+
i = (height - h) // 2
|
537 |
+
j = (width - w) // 2
|
538 |
+
return i, j, h, w
|
539 |
+
|
540 |
+
|
541 |
+
def random_resized_crop(
|
542 |
+
images,
|
543 |
+
target_height,
|
544 |
+
target_width,
|
545 |
+
scale=(0.8, 1.0),
|
546 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
547 |
+
):
|
548 |
+
"""
|
549 |
+
Crop the given images to random size and aspect ratio. A crop of random
|
550 |
+
size (default: of 0.8 to 1.0) of the original size and a random aspect
|
551 |
+
ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
|
552 |
+
crop is finally resized to given size. This is popularly used to train the
|
553 |
+
Inception networks.
|
554 |
+
|
555 |
+
Args:
|
556 |
+
images: Images to perform resizing and cropping.
|
557 |
+
target_height: Desired height after cropping.
|
558 |
+
target_width: Desired width after cropping.
|
559 |
+
scale: Scale range of Inception-style area based random resizing.
|
560 |
+
ratio: Aspect ratio range of Inception-style area based random resizing.
|
561 |
+
"""
|
562 |
+
|
563 |
+
height = images.shape[2]
|
564 |
+
width = images.shape[3]
|
565 |
+
|
566 |
+
i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
|
567 |
+
cropped = images[:, :, i : i + h, j : j + w]
|
568 |
+
return torch.nn.functional.interpolate(
|
569 |
+
cropped,
|
570 |
+
size=(target_height, target_width),
|
571 |
+
mode="bilinear",
|
572 |
+
align_corners=False,
|
573 |
+
)
|
574 |
+
|
575 |
+
|
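A sketch of `random_resized_crop` (illustrative shapes and import path): one Inception-style box is sampled per clip and applied to every frame, then the crop is resized bilinearly to the target resolution.

```python
import torch

from video_transforms import random_resized_crop  # module name assumed from this repo's layout

frames = torch.rand(16, 3, 256, 320)  # num frames x channel x height x width

out = random_resized_crop(frames, target_height=224, target_width=224)
print(out.shape)  # torch.Size([16, 3, 224, 224])
```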
576 |
+
def random_resized_crop_with_shift(
|
577 |
+
images,
|
578 |
+
target_height,
|
579 |
+
target_width,
|
580 |
+
scale=(0.8, 1.0),
|
581 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
582 |
+
):
|
583 |
+
"""
|
584 |
+
This is similar to random_resized_crop. However, it samples two different
|
585 |
+
boxes (for cropping) for the first and last frame. It then linearly
|
586 |
+
interpolates the two boxes for other frames.
|
587 |
+
|
588 |
+
Args:
|
589 |
+
images: Images to perform resizing and cropping.
|
590 |
+
target_height: Desired height after cropping.
|
591 |
+
target_width: Desired width after cropping.
|
592 |
+
scale: Scale range of Inception-style area based random resizing.
|
593 |
+
ratio: Aspect ratio range of Inception-style area based random resizing.
|
594 |
+
"""
|
595 |
+
t = images.shape[1]
|
596 |
+
height = images.shape[2]
|
597 |
+
width = images.shape[3]
|
598 |
+
|
599 |
+
i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
|
600 |
+
i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
|
601 |
+
i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
|
602 |
+
j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
|
603 |
+
h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
|
604 |
+
w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
|
605 |
+
out = torch.zeros((3, t, target_height, target_width))
|
606 |
+
for ind in range(t):
|
607 |
+
out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
|
608 |
+
images[
|
609 |
+
:,
|
610 |
+
ind : ind + 1,
|
611 |
+
i_s[ind] : i_s[ind] + h_s[ind],
|
612 |
+
j_s[ind] : j_s[ind] + w_s[ind],
|
613 |
+
],
|
614 |
+
size=(target_height, target_width),
|
615 |
+
mode="bilinear",
|
616 |
+
align_corners=False,
|
617 |
+
)
|
618 |
+
return out
|
619 |
+
|
620 |
+
|
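Note the layout expected by `random_resized_crop_with_shift`: unlike most helpers above it reads the temporal length from dim 1 and returns a `3 x T x H x W` tensor, so the clip has to be passed channel-first. A sketch (shapes and import path assumed):

```python
import torch

from video_transforms import random_resized_crop_with_shift  # module name assumed

clip = torch.rand(3, 16, 256, 320)  # channel x num frames x height x width

# Two crop boxes are sampled (first and last frame) and linearly interpolated in between.
out = random_resized_crop_with_shift(clip, target_height=224, target_width=224)
print(out.shape)  # torch.Size([3, 16, 224, 224])
```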
621 |
+
def create_random_augment(
|
622 |
+
input_size,
|
623 |
+
auto_augment=None,
|
624 |
+
interpolation="bilinear",
|
625 |
+
):
|
626 |
+
"""
|
627 |
+
Get video randaug transform.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
input_size: The size of the input video in tuple.
|
631 |
+
auto_augment: Parameters for randaug. An example:
|
632 |
+
"rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
|
633 |
+
of operations to apply).
|
634 |
+
interpolation: Interpolation method.
|
635 |
+
"""
|
636 |
+
if isinstance(input_size, tuple):
|
637 |
+
img_size = input_size[-2:]
|
638 |
+
else:
|
639 |
+
img_size = input_size
|
640 |
+
|
641 |
+
if auto_augment:
|
642 |
+
assert isinstance(auto_augment, str)
|
643 |
+
if isinstance(img_size, tuple):
|
644 |
+
img_size_min = min(img_size)
|
645 |
+
else:
|
646 |
+
img_size_min = img_size
|
647 |
+
aa_params = {"translate_const": int(img_size_min * 0.45)}
|
648 |
+
if interpolation and interpolation != "random":
|
649 |
+
aa_params["interpolation"] = _pil_interp(interpolation)
|
650 |
+
if auto_augment.startswith("rand"):
|
651 |
+
return transforms.Compose(
|
652 |
+
[rand_augment_transform(auto_augment, aa_params)]
|
653 |
+
)
|
654 |
+
raise NotImplementedError
|
655 |
+
|
656 |
+
|
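A sketch of `create_random_augment` (the RandAugment string follows the example in the docstring; the import path and the list-of-PIL-frames calling convention are assumptions based on how dataset loaders in this repo typically use it):

```python
from PIL import Image

from video_transforms import create_random_augment  # module name assumed

aug = create_random_augment(
    input_size=(224, 224),
    auto_augment="rand-m7-n4-mstd0.5-inc1",  # magnitude 7, 4 ops
    interpolation="bilinear",
)

# Applied to a clip as a list of PIL frames (assumed convention of this codebase).
frames = [Image.new("RGB", (224, 224)) for _ in range(16)]
augmented = aug(frames)
print(len(augmented))  # 16 augmented PIL frames
```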
657 |
+
def random_sized_crop_img(
|
658 |
+
im,
|
659 |
+
size,
|
660 |
+
jitter_scale=(0.08, 1.0),
|
661 |
+
jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
|
662 |
+
max_iter=10,
|
663 |
+
):
|
664 |
+
"""
|
665 |
+
Performs Inception-style cropping (used for training).
|
666 |
+
"""
|
667 |
+
assert (
|
668 |
+
len(im.shape) == 3
|
669 |
+
), "Currently only support image for random_sized_crop"
|
670 |
+
h, w = im.shape[1:3]
|
671 |
+
i, j, h, w = _get_param_spatial_crop(
|
672 |
+
scale=jitter_scale,
|
673 |
+
ratio=jitter_aspect,
|
674 |
+
height=h,
|
675 |
+
width=w,
|
676 |
+
num_repeat=max_iter,
|
677 |
+
log_scale=False,
|
678 |
+
switch_hw=True,
|
679 |
+
)
|
680 |
+
cropped = im[:, i : i + h, j : j + w]
|
681 |
+
return torch.nn.functional.interpolate(
|
682 |
+
cropped.unsqueeze(0),
|
683 |
+
size=(size, size),
|
684 |
+
mode="bilinear",
|
685 |
+
align_corners=False,
|
686 |
+
).squeeze(0)
|
687 |
+
|
688 |
+
|
689 |
+
# The following code are modified based on timm lib, we will replace the following
|
690 |
+
# contents with dependency from PyTorchVideo.
|
691 |
+
# https://github.com/facebookresearch/pytorchvideo
|
692 |
+
class RandomResizedCropAndInterpolation:
|
693 |
+
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
|
694 |
+
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
695 |
+
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
|
696 |
+
is finally resized to given size.
|
697 |
+
This is popularly used to train the Inception networks.
|
698 |
+
Args:
|
699 |
+
size: expected output size of each edge
|
700 |
+
scale: range of size of the origin size cropped
|
701 |
+
ratio: range of aspect ratio of the origin aspect ratio cropped
|
702 |
+
interpolation: Default: PIL.Image.BILINEAR
|
703 |
+
"""
|
704 |
+
|
705 |
+
def __init__(
|
706 |
+
self,
|
707 |
+
size,
|
708 |
+
scale=(0.08, 1.0),
|
709 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
710 |
+
interpolation="bilinear",
|
711 |
+
):
|
712 |
+
if isinstance(size, tuple):
|
713 |
+
self.size = size
|
714 |
+
else:
|
715 |
+
self.size = (size, size)
|
716 |
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
717 |
+
print("range should be of kind (min, max)")
|
718 |
+
|
719 |
+
if interpolation == "random":
|
720 |
+
self.interpolation = _RANDOM_INTERPOLATION
|
721 |
+
else:
|
722 |
+
self.interpolation = _pil_interp(interpolation)
|
723 |
+
self.scale = scale
|
724 |
+
self.ratio = ratio
|
725 |
+
|
726 |
+
@staticmethod
|
727 |
+
def get_params(img, scale, ratio):
|
728 |
+
"""Get parameters for ``crop`` for a random sized crop.
|
729 |
+
Args:
|
730 |
+
img (PIL Image): Image to be cropped.
|
731 |
+
scale (tuple): range of size of the origin size cropped
|
732 |
+
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
733 |
+
Returns:
|
734 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
735 |
+
sized crop.
|
736 |
+
"""
|
737 |
+
area = img.size[0] * img.size[1]
|
738 |
+
|
739 |
+
for _ in range(10):
|
740 |
+
target_area = random.uniform(*scale) * area
|
741 |
+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
742 |
+
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
743 |
+
|
744 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
745 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
746 |
+
|
747 |
+
if w <= img.size[0] and h <= img.size[1]:
|
748 |
+
i = random.randint(0, img.size[1] - h)
|
749 |
+
j = random.randint(0, img.size[0] - w)
|
750 |
+
return i, j, h, w
|
751 |
+
|
752 |
+
# Fallback to central crop
|
753 |
+
in_ratio = img.size[0] / img.size[1]
|
754 |
+
if in_ratio < min(ratio):
|
755 |
+
w = img.size[0]
|
756 |
+
h = int(round(w / min(ratio)))
|
757 |
+
elif in_ratio > max(ratio):
|
758 |
+
h = img.size[1]
|
759 |
+
w = int(round(h * max(ratio)))
|
760 |
+
else: # whole image
|
761 |
+
w = img.size[0]
|
762 |
+
h = img.size[1]
|
763 |
+
i = (img.size[1] - h) // 2
|
764 |
+
j = (img.size[0] - w) // 2
|
765 |
+
return i, j, h, w
|
766 |
+
|
767 |
+
def __call__(self, img):
|
768 |
+
"""
|
769 |
+
Args:
|
770 |
+
img (PIL Image): Image to be cropped and resized.
|
771 |
+
Returns:
|
772 |
+
PIL Image: Randomly cropped and resized image.
|
773 |
+
"""
|
774 |
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
775 |
+
if isinstance(self.interpolation, (tuple, list)):
|
776 |
+
interpolation = random.choice(self.interpolation)
|
777 |
+
else:
|
778 |
+
interpolation = self.interpolation
|
779 |
+
return F.resized_crop(img, i, j, h, w, self.size, interpolation)
|
780 |
+
|
781 |
+
def __repr__(self):
|
782 |
+
if isinstance(self.interpolation, (tuple, list)):
|
783 |
+
interpolate_str = " ".join(
|
784 |
+
[_pil_interpolation_to_str[x] for x in self.interpolation]
|
785 |
+
)
|
786 |
+
else:
|
787 |
+
interpolate_str = _pil_interpolation_to_str[self.interpolation]
|
788 |
+
format_string = self.__class__.__name__ + "(size={0}".format(self.size)
|
789 |
+
format_string += ", scale={0}".format(
|
790 |
+
tuple(round(s, 4) for s in self.scale)
|
791 |
+
)
|
792 |
+
format_string += ", ratio={0}".format(
|
793 |
+
tuple(round(r, 4) for r in self.ratio)
|
794 |
+
)
|
795 |
+
format_string += ", interpolation={0})".format(interpolate_str)
|
796 |
+
return format_string
|
797 |
+
|
798 |
+
|
799 |
+
def transforms_imagenet_train(
|
800 |
+
img_size=224,
|
801 |
+
scale=None,
|
802 |
+
ratio=None,
|
803 |
+
hflip=0.5,
|
804 |
+
vflip=0.0,
|
805 |
+
color_jitter=0.4,
|
806 |
+
auto_augment=None,
|
807 |
+
interpolation="random",
|
808 |
+
use_prefetcher=False,
|
809 |
+
mean=(0.485, 0.456, 0.406),
|
810 |
+
std=(0.229, 0.224, 0.225),
|
811 |
+
re_prob=0.0,
|
812 |
+
re_mode="const",
|
813 |
+
re_count=1,
|
814 |
+
re_num_splits=0,
|
815 |
+
separate=False,
|
816 |
+
):
|
817 |
+
"""
|
818 |
+
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
819 |
+
for use in a mixing dataset that passes
|
820 |
+
* all data through the first (primary) transform, called the 'clean' data
|
821 |
+
* a portion of the data through the secondary transform
|
822 |
+
* normalizes and converts the branches above with the third, final transform
|
823 |
+
"""
|
824 |
+
if isinstance(img_size, tuple):
|
825 |
+
img_size = img_size[-2:]
|
826 |
+
else:
|
827 |
+
img_size = img_size
|
828 |
+
|
829 |
+
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
830 |
+
ratio = tuple(
|
831 |
+
ratio or (3.0 / 4.0, 4.0 / 3.0)
|
832 |
+
) # default imagenet ratio range
|
833 |
+
primary_tfl = [
|
834 |
+
RandomResizedCropAndInterpolation(
|
835 |
+
img_size, scale=scale, ratio=ratio, interpolation=interpolation
|
836 |
+
)
|
837 |
+
]
|
838 |
+
if hflip > 0.0:
|
839 |
+
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
|
840 |
+
if vflip > 0.0:
|
841 |
+
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
|
842 |
+
|
843 |
+
secondary_tfl = []
|
844 |
+
if auto_augment:
|
845 |
+
assert isinstance(auto_augment, str)
|
846 |
+
if isinstance(img_size, tuple):
|
847 |
+
img_size_min = min(img_size)
|
848 |
+
else:
|
849 |
+
img_size_min = img_size
|
850 |
+
aa_params = dict(
|
851 |
+
translate_const=int(img_size_min * 0.45),
|
852 |
+
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
853 |
+
)
|
854 |
+
if interpolation and interpolation != "random":
|
855 |
+
aa_params["interpolation"] = _pil_interp(interpolation)
|
856 |
+
if auto_augment.startswith("rand"):
|
857 |
+
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
858 |
+
elif auto_augment.startswith("augmix"):
|
859 |
+
raise NotImplementedError("Augmix not implemented")
|
860 |
+
else:
|
861 |
+
raise NotImplementedError("Auto aug not implemented")
|
862 |
+
elif color_jitter is not None:
|
863 |
+
# color jitter is enabled when not using AA
|
864 |
+
if isinstance(color_jitter, (list, tuple)):
|
865 |
+
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
866 |
+
# or 4 if also augmenting hue
|
867 |
+
assert len(color_jitter) in (3, 4)
|
868 |
+
else:
|
869 |
+
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
870 |
+
color_jitter = (float(color_jitter),) * 3
|
871 |
+
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
|
872 |
+
|
873 |
+
final_tfl = []
|
874 |
+
final_tfl += [
|
875 |
+
transforms.ToTensor(),
|
876 |
+
transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
|
877 |
+
]
|
878 |
+
if re_prob > 0.0:
|
879 |
+
final_tfl.append(
|
880 |
+
RandomErasing(
|
881 |
+
re_prob,
|
882 |
+
mode=re_mode,
|
883 |
+
max_count=re_count,
|
884 |
+
num_splits=re_num_splits,
|
885 |
+
device="cpu",
|
886 |
+
cube=False,
|
887 |
+
)
|
888 |
+
)
|
889 |
+
|
890 |
+
if separate:
|
891 |
+
return (
|
892 |
+
transforms.Compose(primary_tfl),
|
893 |
+
transforms.Compose(secondary_tfl),
|
894 |
+
transforms.Compose(final_tfl),
|
895 |
+
)
|
896 |
+
else:
|
897 |
+
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
898 |
+
|
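A sketch of `transforms_imagenet_train` on a single PIL image, using the colour-jitter path rather than RandAugment (parameter values are illustrative; the random-erasing settings assume the timm-style modes supported by this repo's `random_erasing.py`):

```python
from PIL import Image

from video_transforms import transforms_imagenet_train  # module name assumed

train_tf = transforms_imagenet_train(
    img_size=224,
    color_jitter=0.4,        # scalar -> same jitter for brightness/contrast/saturation
    interpolation="bilinear",
    re_prob=0.25,            # appends RandomErasing as the final step
    re_mode="pixel",
    re_count=1,
)

img = Image.new("RGB", (320, 256))
x = train_tf(img)
print(x.shape)  # torch.Size([3, 224, 224]), normalized with the ImageNet mean/std
```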
899 |
+
############################################################################################################
|
900 |
+
############################################################################################################
|
901 |
+
|
902 |
+
class Compose(object):
|
903 |
+
"""Composes several transforms
|
904 |
+
Args:
|
905 |
+
transforms (list of ``Transform`` objects): list of transforms
|
906 |
+
to compose
|
907 |
+
"""
|
908 |
+
|
909 |
+
def __init__(self, transforms):
|
910 |
+
self.transforms = transforms
|
911 |
+
|
912 |
+
def __call__(self, clip):
|
913 |
+
for t in self.transforms:
|
914 |
+
clip = t(clip)
|
915 |
+
return clip
|
916 |
+
|
917 |
+
|
918 |
+
class RandomHorizontalFlip(object):
|
919 |
+
"""Horizontally flip the list of given images randomly
|
920 |
+
with a probability 0.5
|
921 |
+
"""
|
922 |
+
|
923 |
+
def __call__(self, clip):
|
924 |
+
"""
|
925 |
+
Args:
|
926 |
+
img (PIL.Image or numpy.ndarray): List of images to be flipped
|
927 |
+
in format (h, w, c) in numpy.ndarray
|
928 |
+
Returns:
|
929 |
+
PIL.Image or numpy.ndarray: Randomly flipped clip
|
930 |
+
"""
|
931 |
+
if random.random() < 0.5:
|
932 |
+
if isinstance(clip[0], np.ndarray):
|
933 |
+
return [np.fliplr(img) for img in clip]
|
934 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
935 |
+
return [
|
936 |
+
img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip
|
937 |
+
]
|
938 |
+
else:
|
939 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
940 |
+
' but got list of {0}'.format(type(clip[0])))
|
941 |
+
return clip
|
942 |
+
|
943 |
+
|
944 |
+
class RandomResize(object):
|
945 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
946 |
+
The larger the original image is, the longer it takes to
|
947 |
+
interpolate
|
948 |
+
Args:
|
949 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
950 |
+
defaults to nearest
|
951 |
+
ratio (tuple): range of the random scaling factor applied to both sides
|
952 |
+
"""
|
953 |
+
|
954 |
+
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
|
955 |
+
self.ratio = ratio
|
956 |
+
self.interpolation = interpolation
|
957 |
+
|
958 |
+
def __call__(self, clip):
|
959 |
+
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
|
960 |
+
|
961 |
+
if isinstance(clip[0], np.ndarray):
|
962 |
+
im_h, im_w, im_c = clip[0].shape
|
963 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
964 |
+
im_w, im_h = clip[0].size
|
965 |
+
|
966 |
+
new_w = int(im_w * scaling_factor)
|
967 |
+
new_h = int(im_h * scaling_factor)
|
968 |
+
new_size = (new_w, new_h)
|
969 |
+
resized = FF.resize_clip(
|
970 |
+
clip, new_size, interpolation=self.interpolation)
|
971 |
+
return resized
|
972 |
+
|
973 |
+
|
974 |
+
class Resize(object):
|
975 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
976 |
+
The larger the original image is, the longer it takes to
|
977 |
+
interpolate
|
978 |
+
Args:
|
979 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
980 |
+
defaults to nearest
|
981 |
+
size (tuple): (width, height)
|
982 |
+
"""
|
983 |
+
|
984 |
+
def __init__(self, size, interpolation='nearest'):
|
985 |
+
self.size = size
|
986 |
+
self.interpolation = interpolation
|
987 |
+
|
988 |
+
def __call__(self, clip):
|
989 |
+
resized = FF.resize_clip(
|
990 |
+
clip, self.size, interpolation=self.interpolation)
|
991 |
+
return resized
|
992 |
+
|
993 |
+
|
994 |
+
class RandomCrop(object):
|
995 |
+
"""Extract random crop at the same location for a list of images
|
996 |
+
Args:
|
997 |
+
size (sequence or int): Desired output size for the
|
998 |
+
crop in format (h, w)
|
999 |
+
"""
|
1000 |
+
|
1001 |
+
def __init__(self, size):
|
1002 |
+
if isinstance(size, numbers.Number):
|
1003 |
+
size = (size, size)
|
1004 |
+
|
1005 |
+
self.size = size
|
1006 |
+
|
1007 |
+
def __call__(self, clip):
|
1008 |
+
"""
|
1009 |
+
Args:
|
1010 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
1011 |
+
in format (h, w, c) in numpy.ndarray
|
1012 |
+
Returns:
|
1013 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
1014 |
+
"""
|
1015 |
+
h, w = self.size
|
1016 |
+
if isinstance(clip[0], np.ndarray):
|
1017 |
+
im_h, im_w, im_c = clip[0].shape
|
1018 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1019 |
+
im_w, im_h = clip[0].size
|
1020 |
+
else:
|
1021 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
1022 |
+
' but got list of {0}'.format(type(clip[0])))
|
1023 |
+
if w > im_w or h > im_h:
|
1024 |
+
error_msg = (
|
1025 |
+
'Initial image size should be larger than '
|
1026 |
+
'cropped size but got cropped sizes : ({w}, {h}) while '
|
1027 |
+
'initial image is ({im_w}, {im_h})'.format(
|
1028 |
+
im_w=im_w, im_h=im_h, w=w, h=h))
|
1029 |
+
raise ValueError(error_msg)
|
1030 |
+
|
1031 |
+
x1 = random.randint(0, im_w - w)
|
1032 |
+
y1 = random.randint(0, im_h - h)
|
1033 |
+
cropped = FF.crop_clip(clip, y1, x1, h, w)
|
1034 |
+
|
1035 |
+
return cropped
|
1036 |
+
|
1037 |
+
|
1038 |
+
class ThreeCrop(object):
|
1039 |
+
"""Extract random crop at the same location for a list of images
|
1040 |
+
Args:
|
1041 |
+
size (sequence or int): Desired output size for the
|
1042 |
+
crop in format (h, w)
|
1043 |
+
"""
|
1044 |
+
|
1045 |
+
def __init__(self, size):
|
1046 |
+
if isinstance(size, numbers.Number):
|
1047 |
+
size = (size, size)
|
1048 |
+
|
1049 |
+
self.size = size
|
1050 |
+
|
1051 |
+
def __call__(self, clip):
|
1052 |
+
"""
|
1053 |
+
Args:
|
1054 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
1055 |
+
in format (h, w, c) in numpy.ndarray
|
1056 |
+
Returns:
|
1057 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
1058 |
+
"""
|
1059 |
+
h, w = self.size
|
1060 |
+
if isinstance(clip[0], np.ndarray):
|
1061 |
+
im_h, im_w, im_c = clip[0].shape
|
1062 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1063 |
+
im_w, im_h = clip[0].size
|
1064 |
+
else:
|
1065 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
1066 |
+
' but got list of {0}'.format(type(clip[0])))
|
1067 |
+
if w != im_w and h != im_h:
|
1068 |
+
clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
|
1069 |
+
im_h, im_w, im_c = clip[0].shape
|
1070 |
+
|
1071 |
+
step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
|
1072 |
+
cropped = []
|
1073 |
+
for i in range(3):
|
1074 |
+
if (im_h > self.size[0]):
|
1075 |
+
x1 = 0
|
1076 |
+
y1 = i * step
|
1077 |
+
cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
|
1078 |
+
else:
|
1079 |
+
x1 = i * step
|
1080 |
+
y1 = 0
|
1081 |
+
cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
|
1082 |
+
return cropped
|
1083 |
+
|
1084 |
+
|
1085 |
+
class RandomRotation(object):
|
1086 |
+
"""Rotate entire clip randomly by a random angle within
|
1087 |
+
given bounds
|
1088 |
+
Args:
|
1089 |
+
degrees (sequence or int): Range of degrees to select from
|
1090 |
+
If degrees is a number instead of sequence like (min, max),
|
1091 |
+
the range of degrees will be (-degrees, +degrees).
|
1092 |
+
"""
|
1093 |
+
|
1094 |
+
def __init__(self, degrees):
|
1095 |
+
if isinstance(degrees, numbers.Number):
|
1096 |
+
if degrees < 0:
|
1097 |
+
raise ValueError('If degrees is a single number,'
|
1098 |
+
' it must be positive')
|
1099 |
+
degrees = (-degrees, degrees)
|
1100 |
+
else:
|
1101 |
+
if len(degrees) != 2:
|
1102 |
+
raise ValueError('If degrees is a sequence,'
|
1103 |
+
' it must be of len 2.')
|
1104 |
+
|
1105 |
+
self.degrees = degrees
|
1106 |
+
|
1107 |
+
def __call__(self, clip):
|
1108 |
+
"""
|
1109 |
+
Args:
|
1110 |
+
img (PIL.Image or numpy.ndarray): List of images to be rotated
|
1111 |
+
in format (h, w, c) in numpy.ndarray
|
1112 |
+
Returns:
|
1113 |
+
PIL.Image or numpy.ndarray: Rotated list of images
|
1114 |
+
"""
|
1115 |
+
import skimage
|
1116 |
+
angle = random.uniform(self.degrees[0], self.degrees[1])
|
1117 |
+
if isinstance(clip[0], np.ndarray):
|
1118 |
+
rotated = [skimage.transform.rotate(img, angle) for img in clip]
|
1119 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1120 |
+
rotated = [img.rotate(angle) for img in clip]
|
1121 |
+
else:
|
1122 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
1123 |
+
' but got list of {0}'.format(type(clip[0])))
|
1124 |
+
|
1125 |
+
return rotated
|
1126 |
+
|
1127 |
+
|
1128 |
+
class CenterCrop(object):
|
1129 |
+
"""Extract center crop at the same location for a list of images
|
1130 |
+
Args:
|
1131 |
+
size (sequence or int): Desired output size for the
|
1132 |
+
crop in format (h, w)
|
1133 |
+
"""
|
1134 |
+
|
1135 |
+
def __init__(self, size):
|
1136 |
+
if isinstance(size, numbers.Number):
|
1137 |
+
size = (size, size)
|
1138 |
+
|
1139 |
+
self.size = size
|
1140 |
+
|
1141 |
+
def __call__(self, clip):
|
1142 |
+
"""
|
1143 |
+
Args:
|
1144 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
1145 |
+
in format (h, w, c) in numpy.ndarray
|
1146 |
+
Returns:
|
1147 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
1148 |
+
"""
|
1149 |
+
h, w = self.size
|
1150 |
+
if isinstance(clip[0], np.ndarray):
|
1151 |
+
im_h, im_w, im_c = clip[0].shape
|
1152 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1153 |
+
im_w, im_h = clip[0].size
|
1154 |
+
else:
|
1155 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
1156 |
+
' but got list of {0}'.format(type(clip[0])))
|
1157 |
+
if w > im_w or h > im_h:
|
1158 |
+
error_msg = (
|
1159 |
+
'Initial image size should be larger than '
|
1160 |
+
'cropped size but got cropped sizes : ({w}, {h}) while '
|
1161 |
+
'initial image is ({im_w}, {im_h})'.format(
|
1162 |
+
im_w=im_w, im_h=im_h, w=w, h=h))
|
1163 |
+
raise ValueError(error_msg)
|
1164 |
+
|
1165 |
+
x1 = int(round((im_w - w) / 2.))
|
1166 |
+
y1 = int(round((im_h - h) / 2.))
|
1167 |
+
cropped = FF.crop_clip(clip, y1, x1, h, w)
|
1168 |
+
|
1169 |
+
return cropped
|
1170 |
+
|
1171 |
+
|
1172 |
+
class ColorJitter(object):
|
1173 |
+
"""Randomly change the brightness, contrast and saturation and hue of the clip
|
1174 |
+
Args:
|
1175 |
+
brightness (float): How much to jitter brightness. brightness_factor
|
1176 |
+
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
1177 |
+
contrast (float): How much to jitter contrast. contrast_factor
|
1178 |
+
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
1179 |
+
saturation (float): How much to jitter saturation. saturation_factor
|
1180 |
+
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
1181 |
+
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
1182 |
+
[-hue, hue]. Should be >=0 and <= 0.5.
|
1183 |
+
"""
|
1184 |
+
|
1185 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
1186 |
+
self.brightness = brightness
|
1187 |
+
self.contrast = contrast
|
1188 |
+
self.saturation = saturation
|
1189 |
+
self.hue = hue
|
1190 |
+
|
1191 |
+
def get_params(self, brightness, contrast, saturation, hue):
|
1192 |
+
if brightness > 0:
|
1193 |
+
brightness_factor = random.uniform(
|
1194 |
+
max(0, 1 - brightness), 1 + brightness)
|
1195 |
+
else:
|
1196 |
+
brightness_factor = None
|
1197 |
+
|
1198 |
+
if contrast > 0:
|
1199 |
+
contrast_factor = random.uniform(
|
1200 |
+
max(0, 1 - contrast), 1 + contrast)
|
1201 |
+
else:
|
1202 |
+
contrast_factor = None
|
1203 |
+
|
1204 |
+
if saturation > 0:
|
1205 |
+
saturation_factor = random.uniform(
|
1206 |
+
max(0, 1 - saturation), 1 + saturation)
|
1207 |
+
else:
|
1208 |
+
saturation_factor = None
|
1209 |
+
|
1210 |
+
if hue > 0:
|
1211 |
+
hue_factor = random.uniform(-hue, hue)
|
1212 |
+
else:
|
1213 |
+
hue_factor = None
|
1214 |
+
return brightness_factor, contrast_factor, saturation_factor, hue_factor
|
1215 |
+
|
1216 |
+
def __call__(self, clip):
|
1217 |
+
"""
|
1218 |
+
Args:
|
1219 |
+
clip (list): list of PIL.Image
|
1220 |
+
Returns:
|
1221 |
+
list PIL.Image : list of transformed PIL.Image
|
1222 |
+
"""
|
1223 |
+
if isinstance(clip[0], np.ndarray):
|
1224 |
+
raise TypeError(
|
1225 |
+
'Color jitter not yet implemented for numpy arrays')
|
1226 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1227 |
+
brightness, contrast, saturation, hue = self.get_params(
|
1228 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
1229 |
+
|
1230 |
+
# Create img transform function sequence
|
1231 |
+
img_transforms = []
|
1232 |
+
if brightness is not None:
|
1233 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
1234 |
+
if saturation is not None:
|
1235 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
1236 |
+
if hue is not None:
|
1237 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
1238 |
+
if contrast is not None:
|
1239 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
1240 |
+
random.shuffle(img_transforms)
|
1241 |
+
|
1242 |
+
# Apply to all images
|
1243 |
+
jittered_clip = []
|
1244 |
+
for img in clip:
|
1245 |
+
for func in img_transforms:
|
1246 |
+
jittered_img = func(img)
|
1247 |
+
jittered_clip.append(jittered_img)
|
1248 |
+
|
1249 |
+
else:
|
1250 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
1251 |
+
' but got list of {0}'.format(type(clip[0])))
|
1252 |
+
return jittered_clip
|
1253 |
+
|
1254 |
+
|
1255 |
+
class Normalize(object):
|
1256 |
+
"""Normalize a clip with mean and standard deviation.
|
1257 |
+
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
|
1258 |
+
will normalize each channel of the input ``torch.*Tensor`` i.e.
|
1259 |
+
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
|
1260 |
+
.. note::
|
1261 |
+
This transform acts out of place, i.e., it does not mutate the input tensor.
|
1262 |
+
Args:
|
1263 |
+
mean (sequence): Sequence of means for each channel.
|
1264 |
+
std (sequence): Sequence of standard deviations for each channel.
|
1265 |
+
"""
|
1266 |
+
|
1267 |
+
def __init__(self, mean, std):
|
1268 |
+
self.mean = mean
|
1269 |
+
self.std = std
|
1270 |
+
|
1271 |
+
def __call__(self, clip):
|
1272 |
+
"""
|
1273 |
+
Args:
|
1274 |
+
clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
|
1275 |
+
Returns:
|
1276 |
+
Tensor: Normalized Tensor clip.
|
1277 |
+
"""
|
1278 |
+
return FF.normalize(clip, self.mean, self.std)
|
1279 |
+
|
1280 |
+
def __repr__(self):
|
1281 |
+
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
|
volume_transforms.py
ADDED
@@ -0,0 +1,131 @@
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def convert_img(img):
|
7 |
+
"""Converts (H, W, C) numpy.ndarray to (C, W, H) format
|
8 |
+
"""
|
9 |
+
if len(img.shape) == 3:
|
10 |
+
img = img.transpose(2, 0, 1)
|
11 |
+
if len(img.shape) == 2:
|
12 |
+
img = np.expand_dims(img, 0)
|
13 |
+
return img
|
14 |
+
|
15 |
+
|
16 |
+
class ClipToTensor(object):
|
17 |
+
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
|
18 |
+
to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, channel_nb=3, div_255=True, numpy=False):
|
22 |
+
self.channel_nb = channel_nb
|
23 |
+
self.div_255 = div_255
|
24 |
+
self.numpy = numpy
|
25 |
+
|
26 |
+
def __call__(self, clip):
|
27 |
+
"""
|
28 |
+
Args: clip (list of numpy.ndarray): clip (list of images)
|
29 |
+
to be converted to tensor.
|
30 |
+
"""
|
31 |
+
# Retrieve shape
|
32 |
+
if isinstance(clip[0], np.ndarray):
|
33 |
+
h, w, ch = clip[0].shape
|
34 |
+
assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
|
35 |
+
ch)
|
36 |
+
elif isinstance(clip[0], Image.Image):
|
37 |
+
w, h = clip[0].size
|
38 |
+
else:
|
39 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image\
|
40 |
+
but got list of {0}'.format(type(clip[0])))
|
41 |
+
|
42 |
+
np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
|
43 |
+
|
44 |
+
# Convert
|
45 |
+
for img_idx, img in enumerate(clip):
|
46 |
+
if isinstance(img, np.ndarray):
|
47 |
+
pass
|
48 |
+
elif isinstance(img, Image.Image):
|
49 |
+
img = np.array(img, copy=False)
|
50 |
+
else:
|
51 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image\
|
52 |
+
but got list of {0}'.format(type(clip[0])))
|
53 |
+
img = convert_img(img)
|
54 |
+
np_clip[:, img_idx, :, :] = img
|
55 |
+
if self.numpy:
|
56 |
+
if self.div_255:
|
57 |
+
np_clip = np_clip / 255.0
|
58 |
+
return np_clip
|
59 |
+
|
60 |
+
else:
|
61 |
+
tensor_clip = torch.from_numpy(np_clip)
|
62 |
+
|
63 |
+
if not isinstance(tensor_clip, torch.FloatTensor):
|
64 |
+
tensor_clip = tensor_clip.float()
|
65 |
+
if self.div_255:
|
66 |
+
tensor_clip = torch.div(tensor_clip, 255)
|
67 |
+
return tensor_clip
|
68 |
+
|
69 |
+
|
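A usage sketch for `ClipToTensor` (frame sizes and the import path are illustrative): it stacks a list of `H x W x C` frames into a channel-first `C x T x H x W` float tensor scaled to [0, 1]; `ClipToTensor_K` below is the same conversion but mapped to [-1, 1].

```python
import numpy as np

from volume_transforms import ClipToTensor  # module name assumed from this repo's layout

# A clip as a list of 16 decoded frames, each H x W x C uint8 in [0, 255].
clip = [np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8) for _ in range(16)]

to_tensor = ClipToTensor()   # channel_nb=3, div_255=True by default
tensor_clip = to_tensor(clip)

print(tensor_clip.shape)                                   # torch.Size([3, 16, 224, 224])
print(float(tensor_clip.min()), float(tensor_clip.max()))  # values lie in [0, 1]
```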
70 |
+
# Note this norms data to -1/1
|
71 |
+
class ClipToTensor_K(object):
|
72 |
+
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
|
73 |
+
to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, channel_nb=3, div_255=True, numpy=False):
|
77 |
+
self.channel_nb = channel_nb
|
78 |
+
self.div_255 = div_255
|
79 |
+
self.numpy = numpy
|
80 |
+
|
81 |
+
def __call__(self, clip):
|
82 |
+
"""
|
83 |
+
Args: clip (list of numpy.ndarray): clip (list of images)
|
84 |
+
to be converted to tensor.
|
85 |
+
"""
|
86 |
+
# Retrieve shape
|
87 |
+
if isinstance(clip[0], np.ndarray):
|
88 |
+
h, w, ch = clip[0].shape
|
89 |
+
assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
|
90 |
+
ch)
|
91 |
+
elif isinstance(clip[0], Image.Image):
|
92 |
+
w, h = clip[0].size
|
93 |
+
else:
|
94 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image\
|
95 |
+
but got list of {0}'.format(type(clip[0])))
|
96 |
+
|
97 |
+
np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
|
98 |
+
|
99 |
+
# Convert
|
100 |
+
for img_idx, img in enumerate(clip):
|
101 |
+
if isinstance(img, np.ndarray):
|
102 |
+
pass
|
103 |
+
elif isinstance(img, Image.Image):
|
104 |
+
img = np.array(img, copy=False)
|
105 |
+
else:
|
106 |
+
raise TypeError('Expected numpy.ndarray or PIL.Image\
|
107 |
+
but got list of {0}'.format(type(clip[0])))
|
108 |
+
img = convert_img(img)
|
109 |
+
np_clip[:, img_idx, :, :] = img
|
110 |
+
if self.numpy:
|
111 |
+
if self.div_255:
|
112 |
+
np_clip = (np_clip - 127.5) / 127.5
|
113 |
+
return np_clip
|
114 |
+
|
115 |
+
else:
|
116 |
+
tensor_clip = torch.from_numpy(np_clip)
|
117 |
+
|
118 |
+
if not isinstance(tensor_clip, torch.FloatTensor):
|
119 |
+
tensor_clip = tensor_clip.float()
|
120 |
+
if self.div_255:
|
121 |
+
tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
|
122 |
+
return tensor_clip
|
123 |
+
|
124 |
+
|
125 |
+
class ToTensor(object):
|
126 |
+
"""Converts numpy array to tensor
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __call__(self, array):
|
130 |
+
tensor = torch.from_numpy(array)
|
131 |
+
return tensor
|