fmthoker committed on
Commit 4940c8b · verified · 1 Parent(s): 401fa20

Upload 26 files
INSTALL.md ADDED
@@ -0,0 +1,24 @@
1
+ # SMILE Installation
2
+
3
+ This project relies on several open-source libraries. We recommend using **`conda`** to manage your Python environment and installing dependencies via the provided `environment.yml` file.
4
+
5
+ ## Installation Steps
6
+ 1. **Clone the repository**
7
+ ```bash
8
+ git clone https://github.com/fmthoker/SMILE.git
9
+ cd SMILE
10
+ ```
11
+ 2. **Create a conda environment**
12
+ ```bash
13
+ conda env create -f environment.yml
14
+ ```
15
+ 3. **Activate the environment**
16
+ ```bash
17
+ conda activate smile
18
+ ```
19
+ 4. **Download CLIP weights (Optional, only required for pretraining)**
20
+ ```bash
21
+ mkdir clip_weights
22
+ ```
23
+ For pretraining, please download the [CLIP weights](https://huggingface.co/fmthoker/SMILE/resolve/main/clip_weights/ViT-B-16.pt) and place them in the `clip_weights` folder created above.
24
+
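+ For convenience, the download in step 4 can be scripted. The snippet below is a minimal sketch that uses `wget` (any downloader works) together with the URL given above:
+ ```bash
+ # Fetch the CLIP ViT-B/16 weights into the clip_weights folder created in step 4
+ mkdir -p clip_weights
+ wget -O clip_weights/ViT-B-16.pt https://huggingface.co/fmthoker/SMILE/resolve/main/clip_weights/ViT-B-16.pt
+ ```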
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Fida Mohammad Thoker
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,113 @@
1
- ---
2
- license: mit
3
- ---
1
+ # Official PyTorch Implementation of SMILE (CVPR 2025).
2
+
3
+ ![SMILE Framework](figs/smile.jpg)
4
+
5
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)<br>
6
+ [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/fmthoker/SMILE/tree/main/SMILE_MODELS)
7
+
8
+
9
+ > [**SMILE: Infusing Spatial and Motion Semantics in Masked Video Learning**](https://arxiv.org/abs/2504.00527)<br>
10
+ > [Fida Mohammad Thoker](https://fmthoker.github.io/), [Letian Jiang](https://tonnew5418.github.io/), [Chen Zhao](https://zhao-chen.com/), [Bernard Ghanem](https://cemse.kaust.edu.sa/profiles/bernard-ghanem)<br>King Abdullah University of Science and Technology (KAUST)
11
+
12
+ ## 📰 News
13
+ **[2025.6.2]** Code and pre-trained models are available now! <br>
14
+ **[2025.5.28]** Code and pre-trained models will be released here. Feel free to **watch** this repository for the latest updates.
15
+
16
+ ## ✨ Highlights
17
+
18
+ ### 🔥 State-of-the-art on SSv2 and K400
19
+
20
+ Our method achieves state-of-the-art performance on **SSv2** and **K400** benchmarks with a ViT-B backbone, surpassing prior self-supervised video models by up to **2.5%**, thanks to efficient *CLIP-based semantic supervision*.
21
+
22
+ ### ⚡️ Leading Results Across Generalization Challenges
23
+
24
+ We evaluate our method on the [**SEVERE benchmark**](https://bpiyush.github.io/SEVERE-website/), covering domain shift, low-shot learning, fine-grained actions, and task adaptability. Our model consistently outperforms prior methods and achieves a **3.0% average gain** over strong baselines, demonstrating superior generalization in diverse video understanding tasks.
25
+
26
+ ### 😮 Superior Motion Representation Without Video-Text Alignment
27
+
28
+ Compared to CLIP-based methods such as [**ViCLIP**](https://github.com/OpenGVLab/InternVideo/tree/main/Data/InternVid) and [**UMT**](https://github.com/OpenGVLab/unmasked_teacher), our model achieves higher accuracy on motion-sensitive datasets, particularly under *linear probing*. This indicates stronger video representations learned with less data and without relying on video-text alignment.
29
+
30
+ ## 🚀 Main Results and Models
31
+
32
+ ### ✨ Something-Something V2
33
+
34
+ | Method | Pretrain Dataset | Pretrain Epochs | Backbone | Top-1 | Finetune |
35
+ | :------: | :--------------: | :-------------: | :------: | :---: | :------: |
36
+ | SMILE | K400 | 800 | ViT-S | 69.1 | TODO |
37
+ | SMILE | K400 | 600 | ViT-B | 72.1 | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_600_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_600_EPOCHS/ssv2_finetuned_after_k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) |
38
+ | SMILE | K400 | 1200 | ViT-B | 72.4 | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_1200_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/ssv2/VIT_B_1200_EPOCHS/ssv2_finetuned_after_k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth)
39
+ | SMILE | SSv2 | 800 | ViT-B | 72.5 | TODO |
40
+
41
+ ### ✨ Kinetics-400
42
+
43
+ | Method | Pretrain Dataset | Pretrain Epochs | Backbone | Top-1 | Pretrain | Finetune |
44
+ | :------: | :--------------: | :-------------: | :------: | :---: | :------: | :------: |
45
+ | SMILE | K400 | 800 | ViT-S | 79.5 | TODO | TODO |
46
+ | SMILE | K400 | 600 | ViT-B | 83.1 | [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/pretrain/k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_600_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_600_EPOCHS/k400_finetuned_after_k400_pretraining_first_stage_300_epochs_2nd_stage_300_epochs.pth) |
47
+ | SMILE | K400 | 1200 | ViT-B | 83.4 | [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/pretrain/k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth) | [log](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_1200_EPOCHS/log.txt) / [checkpoint](https://huggingface.co/fmthoker/SMILE/resolve/main/SMILE_MODELS/finetune/k400/VIT_B_1200_EPOCHS/k400_finetuned_after_k400_pretraining_first_stage_800_epochs_2nd_stage_400_epochs.pth) |
48
+
49
+ ## 🔨 Installation
50
+
51
+ Please follow the instructions in [INSTALL.md](INSTALL.md).
52
+
53
+ ## ➡️ Data Preparation
54
+
55
+ We follow the [VideoMAE data preparation guide](https://github.com/MCG-NJU/VideoMAE/blob/main/DATASET.md) to prepare our datasets (K400 and SSv2). We provide our annotation files for both datasets in [annotation_files](annotation_files). For pretraining, we use the training sets (`train.csv`).
56
+
57
+ We provide the list of segmented object images used for pretraining in [object_instances.txt](annotation_files/object_instances.txt). The images will be released later.
58
+
59
+
60
+ ## 🔄 Pre-training
61
+
62
+ Following the [VideoMAE pre-training guide](https://github.com/MCG-NJU/VideoMAE/blob/main/PRETRAIN.md), we provide scripts for pre-training on the Kinetics-400 (K400) dataset using the ViT-Base model: [scripts/pretrain/](./scripts/pretrain/)
63
+
64
+ As described in the paper, we adopt a two-stage training strategy. Please refer to the script names to identify which stage to run.
65
+
66
+ If you wish to perform your own pre-training, make sure to update the following parameters in the scripts:
67
+
68
+ - `DATA_PATH`: Path to your dataset
69
+ - `OUTPUT_DIR`: Directory to save output results
70
+ - `OBJECTS_PATH`: Path to the overlaying objects image dataset (image data to be released)
71
+ - `FIRST_STAGE_CKPT`: Path to the ckpt from first stage pretraining ( for second stage training)
72
+
73
+ > **Note:** Our pre-training experiments were conducted using 8 V100(32 GB) GPUs.
74
+ ---
75
+
76
+ ## ⤴️ Fine-tuning with Pre-trained Models
77
+
78
+ Following the [VideoMAE finetuning guide](https://github.com/MCG-NJU/VideoMAE/blob/main/FINETUNE.md), we provide scripts for fine-tuning on the Something-Something v2 (SSv2) and Kinetics-400 (K400) datasets using the ViT-Base model: [scripts/finetune/](./scripts/finetune)
79
+
80
+
81
+ To perform your own fine-tuning, please update the following parameters in the script:
82
+
83
+ - `DATA_PATH`: Path to your dataset
84
+ - `MODEL_PATH`: Path to the pre-trained model
85
+ - `OUTPUT_DIR`: Directory to save output results
86
+
87
+ > **Note:** Our finetuning experiments were conducted using 4 V100(32 GB) GPUs.
88
+
89
+ ## ☎️ Contact
90
+
91
+ Fida Mohammad Thoker: [email protected]
92
+
93
+ ## 👍 Acknowledgements
94
+
95
+ We sincerely thank [Michael Dorkenwald](https://mdorkenwald.com/) for providing the object image dataset that supports this work.<br>
96
+ This project is built upon [VideoMAE](https://github.com/MCG-NJU/VideoMAE) and [tubelet-contrast](https://github.com/fmthoker/tubelet-contrast). Thanks to the contributors of these great codebases.
97
+
98
+ ## 🔒 License
99
+
100
+ This project is released under the MIT license. For more details, please refer to the [LICENSE](https://github.com/fmthoker/SMILE/blob/main/LICENSE) file.
101
+
102
+ ## ✏️ Citation
103
+
104
+ If you find this project helpful, please consider leaving a star ⭐️ and citing our paper:
105
+
106
+ ```bibtex
107
+ @inproceedings{thoker2025smile,
108
+ author = {Thoker, Fida Mohammad and Jiang, Letian and Zhao, Chen and Ghanem, Bernard},
109
+ title = {SMILE: Infusing Spatial and Motion Semantics in Masked Video Learning},
110
+ booktitle = {CVPR},
111
+ year = {2025},
112
+ }
113
+ ```
datasets.py ADDED
@@ -0,0 +1,271 @@
1
+ import os
2
+ from torchvision import transforms
3
+ from transforms import *
4
+ from masking_generator import TubeMaskingGenerator, TubeletMaskingGenerator
5
+ from kinetics import VideoClsDataset, VideoMAE
6
+ from ssv2 import SSVideoClsDataset
7
+ import synthetic_tubelets as synthetic_tubelets
8
+ import ast
9
+ import random
10
+
11
+ class DataAugmentationForVideoMAE(object):
12
+ def __init__(self, args):
13
+ self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
14
+ self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
15
+ normalize = GroupNormalize(self.input_mean, self.input_std)
16
+ self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
17
+ self.add_tubelets = args.add_tubelets
18
+ self.mask_type = args.mask_type
19
+
20
+ # original transform without adding tubelets
21
+ self.transform_original = transforms.Compose([
22
+ self.train_augmentation,
23
+ Stack(roll=False),
24
+ ToTorchFormatTensor(div=True),
25
+ normalize,
26
+ ])
27
+
28
+ # tubelet transform
29
+ if args.add_tubelets:
30
+ scales = ast.literal_eval(args.scales)
31
+
32
+ self.tubelets = synthetic_tubelets.PatchMask(
33
+ use_objects=args.use_objects,
34
+ objects_path=args.objects_path,
35
+ region_sampler=dict(
36
+ scales=scales,
37
+ ratios=[0.5, 0.67, 0.75, 1.0, 1.33, 1.50, 2.0],
38
+ scale_jitter=0.18,
39
+ num_rois=2,
40
+ ),
41
+ key_frame_probs=[0.5, 0.3, 0.2],
42
+ loc_velocity=12,
43
+ rot_velocity=6,
44
+ size_velocity=0.025,
45
+ label_prob=1.0,
46
+ motion_type=args.motion_type,
47
+ patch_transformation='rotation',)
48
+
49
+
50
+ self.transform1 = transforms.Compose([
51
+ self.train_augmentation,
52
+ self.tubelets,
53
+ ])
54
+ self.transform2 = transforms.Compose([Stack(roll=False),
55
+ ToTorchFormatTensor(div=True),
56
+ normalize,
57
+ ])
58
+ else:
59
+ self.transform = self.transform_original
60
+
61
+ self.original_masked_position_generator = TubeMaskingGenerator(
62
+ args.window_size, args.mask_ratio
63
+ )
64
+
65
+ if args.mask_type == 'tube':
66
+ self.masked_position_generator = self.original_masked_position_generator
67
+ elif args.mask_type == 'tubelet':
68
+ self.masked_position_generator = TubeletMaskingGenerator(
69
+ args.window_size, args.mask_ratio, args.visible_frames, args.sub_mask_type
70
+ )
71
+ else:
72
+ raise NotImplementedError
73
+
74
+
75
+ def __call__(self, images):
76
+ process_data, _, traj_rois = self.ComposedTransform(images)
77
+
78
+ if self.mask_type == 'tubelet' and traj_rois is not None:
79
+ return process_data, self.masked_position_generator(traj_rois)
80
+ else:
81
+ return process_data, self.masked_position_generator()
82
+
83
+ def ComposedTransform(self, images):
84
+ traj_rois = None
85
+
86
+ if self.add_tubelets:
87
+ data = self.transform1(images)
88
+ process_data, traj_rois = data[:-1], data[-1]
89
+ process_data, _ = self.transform2(process_data)
90
+ else:
91
+ process_data, _ = self.transform(images)
92
+
93
+ return process_data, _, traj_rois
94
+
95
+ def __repr__(self):
96
+ repr = "(DataAugmentationForVideoMAE,\n"
97
+ try:
98
+ self.transform
99
+ except:
100
+ repr += " transform = %s,\n" % (str(self.transform1) + str(self.transform2))
101
+ else:
102
+ repr += " transform = %s,\n" % str(self.transform)
103
+
104
+ repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
105
+ repr += ")"
106
+ return repr
107
+
108
+
109
+ def build_pretraining_dataset(args):
110
+ transform = DataAugmentationForVideoMAE(args)
111
+ dataset = VideoMAE(
112
+ root=None,
113
+ setting=args.data_path,
114
+ video_ext='mp4',
115
+ is_color=True,
116
+ modality='rgb',
117
+ new_length=args.num_frames,
118
+ new_step=args.sampling_rate,
119
+ transform=transform,
120
+ temporal_jitter=False,
121
+ video_loader=True,
122
+ use_decord=True,
123
+ lazy_init=False)
124
+ print("Data Aug = %s" % str(transform))
125
+ return dataset
126
+
127
+
128
+ def build_dataset(is_train, test_mode, args):
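+ # Build a supervised video-classification dataset (Kinetics-400 / SSv2 / UCF101 / HMDB51,
+ # plus the 'Mini' variants). args.data_path is the directory holding the train/val/test csv
+ # annotation files; returns (dataset, nb_classes) and asserts nb_classes == args.nb_classes.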
129
+ if args.data_set == 'Kinetics-400' or args.data_set == "Mini-Kinetics":
130
+ mode = None
131
+ anno_path = None
132
+ if is_train is True:
133
+ mode = 'train'
134
+ if 'Mini' in args.data_set:
135
+ anno_path = os.path.join(args.data_path, 'train_mini_kinetics.csv')
136
+ else:
137
+ anno_path = os.path.join(args.data_path, 'train.csv')
138
+ elif test_mode is True:
139
+ mode = 'test'
140
+ if 'Mini' in args.data_set:
141
+ anno_path = os.path.join(args.data_path, 'test_mini_kinetics.csv')
142
+ else:
143
+ anno_path = os.path.join(args.data_path, 'test.csv')
144
+ else:
145
+ mode = 'validation'
146
+ if 'Mini' in args.data_set:
147
+ anno_path = os.path.join(args.data_path, 'val_mini_kinetics.csv')
148
+ else:
149
+ anno_path = os.path.join(args.data_path, 'val.csv')
150
+
151
+ dataset = VideoClsDataset(
152
+ anno_path=anno_path,
153
+ data_path='/',
154
+ mode=mode,
155
+ clip_len=args.num_frames,
156
+ frame_sample_rate=args.sampling_rate,
157
+ num_segment=1,
158
+ test_num_segment=args.test_num_segment,
159
+ test_num_crop=args.test_num_crop,
160
+ num_crop=1 if not test_mode else 3,
161
+ keep_aspect_ratio=True,
162
+ crop_size=args.input_size,
163
+ short_side_size=args.short_side_size,
164
+ new_height=256,
165
+ new_width=320,
166
+ args=args)
167
+ if 'Mini' in args.data_set:
168
+ nb_classes = 200
169
+ else:
170
+ nb_classes = 400
171
+
172
+ elif args.data_set == 'SSV2' or args.data_set == 'SSV2-Mini':
173
+ mode = None
174
+ anno_path = None
175
+ if is_train is True:
176
+ mode = 'train'
177
+ if 'Mini' in args.data_set:
178
+ anno_path = os.path.join(args.data_path, 'train_mini.csv')
179
+ else:
180
+ anno_path = os.path.join(args.data_path, 'train.csv')
181
+ elif test_mode is True:
182
+ mode = 'test'
183
+ anno_path = os.path.join(args.data_path, 'test.csv')
184
+ else:
185
+ mode = 'validation'
186
+ anno_path = os.path.join(args.data_path, 'val.csv')
187
+
188
+ dataset = SSVideoClsDataset(
189
+ anno_path=anno_path,
190
+ data_path='/',
191
+ mode=mode,
192
+ clip_len=1,
193
+ num_segment=args.num_frames,
194
+ test_num_segment=args.test_num_segment,
195
+ test_num_crop=args.test_num_crop,
196
+ num_crop=1 if not test_mode else 3,
197
+ keep_aspect_ratio=True,
198
+ crop_size=args.input_size,
199
+ short_side_size=args.short_side_size,
200
+ new_height=256,
201
+ new_width=320,
202
+ args=args)
203
+ nb_classes = 174
204
+
205
+ elif args.data_set == 'UCF101':
206
+ mode = None
207
+ anno_path = None
208
+ if is_train is True:
209
+ mode = 'train'
210
+ anno_path = os.path.join(args.data_path, 'train.csv')
211
+ elif test_mode is True:
212
+ mode = 'test'
213
+ anno_path = os.path.join(args.data_path, 'test.csv')
214
+ else:
215
+ mode = 'validation'
216
+ anno_path = os.path.join(args.data_path, 'val.csv')
217
+
218
+ dataset = VideoClsDataset(
219
+ anno_path=anno_path,
220
+ data_path='/',
221
+ mode=mode,
222
+ clip_len=args.num_frames,
223
+ frame_sample_rate=args.sampling_rate,
224
+ num_segment=1,
225
+ test_num_segment=args.test_num_segment,
226
+ test_num_crop=args.test_num_crop,
227
+ num_crop=1 if not test_mode else 3,
228
+ keep_aspect_ratio=True,
229
+ crop_size=args.input_size,
230
+ short_side_size=args.short_side_size,
231
+ new_height=256,
232
+ new_width=320,
233
+ args=args)
234
+ nb_classes = 101
235
+
236
+ elif args.data_set == 'HMDB51':
237
+ mode = None
238
+ anno_path = None
239
+ if is_train is True:
240
+ mode = 'train'
241
+ anno_path = os.path.join(args.data_path, 'train.csv')
242
+ elif test_mode is True:
243
+ mode = 'test'
244
+ anno_path = os.path.join(args.data_path, 'test.csv')
245
+ else:
246
+ mode = 'validation'
247
+ anno_path = os.path.join(args.data_path, 'val.csv')
248
+
249
+ dataset = VideoClsDataset(
250
+ anno_path=anno_path,
251
+ data_path='/',
252
+ mode=mode,
253
+ clip_len=args.num_frames,
254
+ frame_sample_rate=args.sampling_rate,
255
+ num_segment=1,
256
+ test_num_segment=args.test_num_segment,
257
+ test_num_crop=args.test_num_crop,
258
+ num_crop=1 if not test_mode else 3,
259
+ keep_aspect_ratio=True,
260
+ crop_size=args.input_size,
261
+ short_side_size=args.short_side_size,
262
+ new_height=256,
263
+ new_width=320,
264
+ args=args)
265
+ nb_classes = 51
266
+ else:
267
+ raise NotImplementedError()
268
+ assert nb_classes == args.nb_classes
269
+ print("Number of the class = %d" % args.nb_classes)
270
+
271
+ return dataset, nb_classes
dynamic_utils.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import numpy as np
5
+ from typing import List
6
+
7
+
8
+ def sample_key_frames(num_frames: int,
9
+ key_frame_probs: List[float]) -> np.ndarray:
10
+ """ Sample the indices of key frames.
11
+
12
+ Args:
13
+ num_frames (int): number of frames in whole video
14
+ key_frame_probs (List[float]): the sampling probability of how many
15
+ key frames will be sampled. The sum of this array should be 1.0.
16
+
17
+ Returns:
18
+ frame_inds (np.ndarray): key frame index, in range
19
+ of [0, num_frames - 1]. Note that the first frame and the
20
+ last frame will always be key frames.
21
+
22
+ Examples:
23
+ >>> sample_key_frames(16, [1.0, ])
24
+ np.ndarray([0, 15])
25
+ >>> sample_key_frames(16, [0.5, 0.5])
26
+ np.ndarray([0, 15])
27
+ np.ndarray([0, 7, 15])
28
+ np.ndarray([0, 8, 15])
29
+ np.ndarray([0, 15])
30
+ """
31
+ # how many key frames
32
+ num_key_frames = np.random.choice(len(key_frame_probs), p=key_frame_probs)
33
+ # if there is no inner key frame, we will directly
34
+ # sample the first frame and the last frame.
35
+ if num_key_frames == 0:
36
+ return np.array([0, num_frames - 1], dtype=np.int32)
37
+ avg_duration = num_frames / (num_key_frames + 1)
38
+ ticks = np.array([int(avg_duration * i)
39
+ for i in range(1, num_key_frames + 1)], dtype=np.int32)
40
+
41
+ # add random jitter
42
+ jitter_range = int(avg_duration / 3)
43
+ if jitter_range > 0:
44
+ jitter = np.random.randint(-jitter_range,
45
+ jitter_range, size=len(ticks))
46
+ else:
47
+ jitter = np.zeros((len(ticks),), np.int32)
48
+
49
+ ticks = ticks + jitter
50
+ # add the first frame and last frame
51
+ ticks = np.concatenate((ticks, np.array([0, num_frames - 1])), axis=0)
52
+ # remove duplication and sort array
53
+ ticks = np.sort(np.unique(ticks))
54
+ return ticks
55
+
56
+
57
+ def extend_key_frame_to_all(array: np.ndarray,
58
+ key_frame_inds: np.ndarray,
59
+ interpolate: str = 'uniform') -> np.ndarray:
60
+ """ Interpolate the values between key frames.
61
+
62
+ This function is used in some data augmentations for video clips. For
63
+ example, we first decide the color distortion values in some key frames,
64
+ then we can interpolate the values in the rest of frames. This strategy
65
+ will make the data augmentations more smooth over the entire video clip.
66
+
67
+ Args:
68
+ array (np.ndarray): The values in the key frames, in shape of [K, *]
69
+ key_frame_inds (np.ndarray): the frame index list of key frames, in
70
+ shape of [K, ]
71
+ interpolate (str): interpolation type. 'uniform' means the linear
72
+ interpolation; 'accelerate' means the constant acceleration.
73
+ 'decelerate' means the reverse order of 'accelerate'.
74
+
75
+ Returns:
76
+ out_array (np.ndarray): the interpolated values, in shape of [N, *].
77
+ N denotes the value of key_frame_inds[-1].
78
+
79
+ Examples:
80
+ >>> values = np.array([0.0, 5.0])
81
+ >>> inds = np.array([0, 10])
82
+ >>> extend_key_frame_to_all(values, inds)
83
+ array([0. , 0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. ])
84
+ >>> extend_key_frame_to_all(values, inds, 'accelerate')
85
+ array([0. , 0.05, 0.2 , 0.45, 0.8 , 1.25, 1.8 , 2.45, 3.2 , 4.05, 5.])
86
+ """
87
+
88
+ def _uniform_interpolate(start_state, end_state, index_delta):
89
+ delta_state = (end_state - start_state) * (1.0 / index_delta)
90
+ return np.concatenate([start_state + _ * delta_state
91
+ for _ in range(index_delta+1)], axis=0)
92
+
93
+ def _accelerate_interpolate(start_state, end_state, index_delta):
94
+ a = 2 * (end_state - start_state) / (index_delta ** 2)
95
+ return np.concatenate([start_state + 0.5 * a * (_**2)
96
+ for _ in range(index_delta+1)], axis=0)
97
+
98
+ def _decelerate_interpolate(start_state, end_state, index_delta):
99
+ a = 2 * (start_state - end_state) / (index_delta ** 2)
100
+ return np.concatenate([end_state + 0.5 * a * ((index_delta-_)**2)
101
+ for _ in range(index_delta+1)], axis=0)
102
+
103
+ assert key_frame_inds[0] == 0 and key_frame_inds[-1] > 0
104
+ num_key_frames = len(key_frame_inds)
105
+ assert num_key_frames == len(array)
106
+ num_frames = key_frame_inds[-1] + 1
107
+
108
+ out_array = np.zeros((num_frames, ) + array.shape[1:], dtype=array.dtype)
109
+ for i in range(num_key_frames - 1):
110
+ # fill the values between i -> i+1
111
+ st_idx, end_idx = key_frame_inds[i:i+2]
112
+ if interpolate == 'uniform':
113
+ inter_func = _uniform_interpolate
114
+ elif interpolate == 'accelerate':
115
+ inter_func = _accelerate_interpolate
116
+ elif interpolate == 'decelerate':
117
+ inter_func = _decelerate_interpolate
118
+ elif interpolate == 'random':
119
+ inter_index = np.random.choice(3, p=[0.7, 0.15, 0.15])
120
+ if inter_index == 0:
121
+ inter_func = _uniform_interpolate
122
+ elif inter_index == 1:
123
+ inter_func = _accelerate_interpolate
124
+ else:
125
+ inter_func = _decelerate_interpolate
126
+ else:
127
+ raise NotImplementedError
128
+ i_out = inter_func(array[i:i+1],
129
+ array[i+1:i+2],
130
+ end_idx - st_idx)
131
+ out_array[st_idx:end_idx+1] = i_out
132
+
133
+ return out_array
engine_for_finetuning.py ADDED
@@ -0,0 +1,375 @@
1
+ import os
2
+ import numpy as np
3
+ import math
4
+ import sys
5
+ from typing import Iterable, Optional
6
+ import torch
7
+ from mixup import Mixup
8
+ from timm.utils import accuracy, ModelEma
9
+ import utils_mae as utils
10
+ from scipy.special import softmax
11
+ import gc
12
+ import pickle
13
+
14
+ def train_class_batch(model, samples, target, criterion):
15
+ outputs = model(samples)
16
+ loss = criterion(outputs, target)
17
+ return loss, outputs
18
+
19
+
20
+ def get_loss_scale_for_deepspeed(model):
21
+ optimizer = model.optimizer
22
+ return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
23
+
24
+
25
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
26
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
27
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
28
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
29
+ start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
30
+ num_training_steps_per_epoch=None, update_freq=None):
31
+ model.train(True)
32
+ metric_logger = utils.MetricLogger(delimiter=" ")
33
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
35
+ header = 'Epoch: [{}]'.format(epoch)
36
+ print_freq = 10
37
+
38
+ if loss_scaler is None:
39
+ model.zero_grad()
40
+ model.micro_steps = 0
41
+ else:
42
+ optimizer.zero_grad()
43
+
44
+ for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45
+ step = data_iter_step // update_freq
46
+ if step >= num_training_steps_per_epoch:
47
+ continue
48
+ it = start_steps + step # global training iteration
49
+ # Update LR & WD for the first acc
50
+ if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
51
+ for i, param_group in enumerate(optimizer.param_groups):
52
+ if lr_schedule_values is not None:
53
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
54
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
55
+ param_group["weight_decay"] = wd_schedule_values[it]
56
+
57
+ samples = samples.to(device, non_blocking=True)
58
+ targets = targets.to(device, non_blocking=True)
59
+
60
+ if mixup_fn is not None:
61
+ samples, targets = mixup_fn(samples, targets)
62
+
63
+ if loss_scaler is None:
64
+ samples = samples.half()
65
+ loss, output = train_class_batch(
66
+ model, samples, targets, criterion)
67
+ else:
68
+ with torch.cuda.amp.autocast():
69
+ loss, output = train_class_batch(
70
+ model, samples, targets, criterion)
71
+
72
+ loss_value = loss.item()
73
+
74
+ if not math.isfinite(loss_value):
75
+ print("Loss is {}, stopping training".format(loss_value))
76
+ sys.exit(1)
77
+
78
+ if loss_scaler is None:
79
+ loss /= update_freq
80
+ model.backward(loss)
81
+ model.step()
82
+
83
+ if (data_iter_step + 1) % update_freq == 0:
84
+ # model.zero_grad()
85
+ # Deepspeed will call step() & model.zero_grad() automatic
86
+ if model_ema is not None:
87
+ model_ema.update(model)
88
+ grad_norm = None
89
+ loss_scale_value = get_loss_scale_for_deepspeed(model)
90
+ else:
91
+ # this attribute is added by timm on one optimizer (adahessian)
92
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
93
+ loss /= update_freq
94
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
95
+ parameters=model.parameters(), create_graph=is_second_order,
96
+ update_grad=(data_iter_step + 1) % update_freq == 0)
97
+ if (data_iter_step + 1) % update_freq == 0:
98
+ optimizer.zero_grad()
99
+ if model_ema is not None:
100
+ model_ema.update(model)
101
+ loss_scale_value = loss_scaler.state_dict()["scale"]
102
+
103
+ torch.cuda.synchronize()
104
+
105
+ if mixup_fn is None:
106
+ class_acc = (output.max(-1)[-1] == targets).float().mean()
107
+ else:
108
+ class_acc = None
109
+ metric_logger.update(loss=loss_value)
110
+ metric_logger.update(class_acc=class_acc)
111
+ metric_logger.update(loss_scale=loss_scale_value)
112
+ min_lr = 10.
113
+ max_lr = 0.
114
+ for group in optimizer.param_groups:
115
+ min_lr = min(min_lr, group["lr"])
116
+ max_lr = max(max_lr, group["lr"])
117
+
118
+ metric_logger.update(lr=max_lr)
119
+ metric_logger.update(min_lr=min_lr)
120
+ weight_decay_value = None
121
+ for group in optimizer.param_groups:
122
+ if group["weight_decay"] > 0:
123
+ weight_decay_value = group["weight_decay"]
124
+ metric_logger.update(weight_decay=weight_decay_value)
125
+ metric_logger.update(grad_norm=grad_norm)
126
+
127
+ if log_writer is not None:
128
+ log_writer.update(loss=loss_value, head="loss")
129
+ log_writer.update(class_acc=class_acc, head="loss")
130
+ log_writer.update(loss_scale=loss_scale_value, head="opt")
131
+ log_writer.update(lr=max_lr, head="opt")
132
+ log_writer.update(min_lr=min_lr, head="opt")
133
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
134
+ log_writer.update(grad_norm=grad_norm, head="opt")
135
+
136
+ log_writer.set_step()
137
+
138
+ # gather the stats from all processes
139
+ metric_logger.synchronize_between_processes()
140
+ print("Averaged stats:", metric_logger)
141
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
142
+
143
+
144
+ @torch.no_grad()
145
+ def validation_one_epoch(data_loader, model, device):
146
+ criterion = torch.nn.CrossEntropyLoss()
147
+
148
+ metric_logger = utils.MetricLogger(delimiter=" ")
149
+ header = 'Val:'
150
+
151
+ # switch to evaluation mode
152
+ model.eval()
153
+
154
+ for batch in metric_logger.log_every(data_loader, 10, header):
155
+ videos = batch[0]
156
+ target = batch[1]
157
+ videos = videos.to(device, non_blocking=True)
158
+ target = target.to(device, non_blocking=True)
159
+
160
+ # compute output
161
+ with torch.cuda.amp.autocast():
162
+ output = model(videos)
163
+ loss = criterion(output, target)
164
+
165
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
166
+
167
+ batch_size = videos.shape[0]
168
+ metric_logger.update(loss=loss.item())
169
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
170
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
171
+ # gather the stats from all processes
172
+ metric_logger.synchronize_between_processes()
173
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
174
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
175
+
176
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
177
+
178
+
179
+
180
+ @torch.no_grad()
181
+ def final_test(data_loader, model, device, file):
182
+ criterion = torch.nn.CrossEntropyLoss()
183
+
184
+ metric_logger = utils.MetricLogger(delimiter=" ")
185
+ header = 'Test:'
186
+
187
+ # switch to evaluation mode
188
+ model.eval()
189
+ final_result = []
190
+
191
+ for batch in metric_logger.log_every(data_loader, 10, header):
192
+ videos = batch[0]
193
+ target = batch[1]
194
+ ids = batch[2]
195
+ chunk_nb = batch[3]
196
+ split_nb = batch[4]
197
+ videos = videos.to(device, non_blocking=True)
198
+ target = target.to(device, non_blocking=True)
199
+
200
+ # compute output
201
+ with torch.cuda.amp.autocast():
202
+ output = model(videos)
203
+ loss = criterion(output, target)
204
+
205
+ for i in range(output.size(0)):
206
+ string = "{} {} {} {} {}\n".format(ids[i], \
207
+ str(output.data[i].cpu().numpy().tolist()), \
208
+ str(int(target[i].cpu().numpy())), \
209
+ str(int(chunk_nb[i].cpu().numpy())), \
210
+ str(int(split_nb[i].cpu().numpy())))
211
+ final_result.append(string)
212
+
213
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
214
+
215
+ batch_size = videos.shape[0]
216
+ metric_logger.update(loss=loss.item())
217
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
218
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
219
+
220
+ if not os.path.exists(file):
221
+ os.mknod(file)
222
+ with open(file, 'w') as f:
223
+ f.write("{}, {}\n".format(acc1, acc5))
224
+ for line in final_result:
225
+ f.write(line)
226
+ # gather the stats from all processes
227
+ metric_logger.synchronize_between_processes()
228
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
229
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
230
+
231
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
232
+
233
+
234
+ def merge(eval_path, num_tasks):
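+ # Merge the per-rank prediction files written by final_test, average the softmaxed scores
+ # over all temporal/spatial views of each video, and return video-level top-1 / top-5
+ # accuracy (in percent).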
235
+ dict_feats = {}
236
+ dict_label = {}
237
+ dict_pos = {}
238
+ print("Reading individual output files")
239
+
240
+ for x in range(num_tasks):
241
+ file = os.path.join(eval_path, str(x) + '.txt')
242
+ lines = open(file, 'r').readlines()[1:]
243
+ for line in lines:
244
+ line = line.strip()
245
+ name = line.split('[')[0]
246
+ label = line.split(']')[1].split(' ')[1]
247
+ chunk_nb = line.split(']')[1].split(' ')[2]
248
+ split_nb = line.split(']')[1].split(' ')[3]
249
+ data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
250
+ data = softmax(data)
251
+ if not name in dict_feats:
252
+ dict_feats[name] = []
253
+ dict_label[name] = 0
254
+ dict_pos[name] = []
255
+ if chunk_nb + split_nb in dict_pos[name]:
256
+ continue
257
+ dict_feats[name].append(data)
258
+ dict_pos[name].append(chunk_nb + split_nb)
259
+ dict_label[name] = label
260
+ print("Computing final results")
261
+
262
+ input_lst = []
263
+ print(len(dict_feats))
264
+ for i, item in enumerate(dict_feats):
265
+ input_lst.append([i, item, dict_feats[item], dict_label[item]])
266
+ from multiprocessing import Pool
267
+ p = Pool(64)
268
+ ans = p.map(compute_video, input_lst)
269
+ top1 = [x[1] for x in ans]
270
+ top5 = [x[2] for x in ans]
271
+ pred = [x[0] for x in ans]
272
+ label = [x[3] for x in ans]
273
+ final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
274
+ return final_top1*100 ,final_top5*100
275
+
276
+ def compute_video(lst):
277
+ i, video_id, data, label = lst
278
+ feat = [x for x in data]
279
+ feat = np.mean(feat, axis=0)
280
+ pred = np.argmax(feat)
281
+ top1 = (int(pred) == int(label)) * 1.0
282
+ top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
283
+ return [pred, top1, top5, int(label)]
284
+
285
+ def merge_mean_per_class(eval_path, num_tasks,nb_classes):
286
+ dict_feats = {}
287
+ dict_label = {}
288
+ dict_pos = {}
289
+ #print("Reading individual output files")
290
+
291
+ for x in range(num_tasks):
292
+ file = os.path.join(eval_path, str(x) + '.txt')
293
+ lines = open(file, 'r').readlines()[1:]
294
+ for line in lines:
295
+ line = line.strip()
296
+ name = line.split('[')[0]
297
+ label = line.split(']')[1].split(' ')[1]
298
+ chunk_nb = line.split(']')[1].split(' ')[2]
299
+ split_nb = line.split(']')[1].split(' ')[3]
300
+ data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
301
+ data = softmax(data)
302
+ if not name in dict_feats:
303
+ dict_feats[name] = []
304
+ dict_label[name] = 0
305
+ dict_pos[name] = []
306
+ if chunk_nb + split_nb in dict_pos[name]:
307
+ continue
308
+ dict_feats[name].append(data)
309
+ dict_pos[name].append(chunk_nb + split_nb)
310
+ dict_label[name] = label
311
+ print("Computing mean per class results")
312
+
313
+ input_lst = []
314
+ all_pred = []
315
+ all_label = []
316
+
317
+ classes = torch.arange(nb_classes)
318
+ classwise_top1 = [0 for c in classes]
319
+ classwise_top5 = [0 for c in classes]
320
+ actual_nb_classes = nb_classes
321
+ cnt = 0
322
+
323
+ for c in classes:
324
+ input_lst = []
325
+ for i, item in enumerate(dict_feats):
326
+ if int(dict_label[item]) == c:
327
+ input_lst.append([i, item, dict_feats[item], dict_label[item]])
328
+ cnt += len(input_lst)
329
+
330
+ # p = Pool(4)
331
+ # ans = p.map(compute_video, input_lst)
332
+ if len(input_lst) == 0:
333
+ actual_nb_classes -= 1
334
+ print(f"Class {c} is not present in test set, skip")
335
+ continue
336
+
337
+ ans = []
338
+ for i in input_lst:
339
+ ans.append(compute_video(i))
340
+ top1 = [x[1] for x in ans]
341
+ top5 = [x[2] for x in ans]
342
+ pred = [x[0] for x in ans]
343
+ label = [x[3] for x in ans]
344
+
345
+ # for i in pred:
346
+ # all_pred.append(i)
347
+ # for j in label:
348
+ # all_label.append(j)
349
+ final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
350
+
351
+ classwise_top1[c] = final_top1*100
352
+ classwise_top5[c] = final_top5*100
353
+
354
+ del input_lst
355
+ del ans
356
+ del top1
357
+ del top5
358
+ del pred
359
+ del label
360
+ gc.collect()
361
+
362
+ assert cnt == len(dict_feats)
363
+ # pred_cnt = 0
364
+ # for idx, p in enumerate(all_pred):
365
+ # if int(p) == int(all_label[idx]):
366
+ # pred_cnt += 1
367
+ # print(pred_cnt/len(all_pred))
368
+ classwise_top1_path = os.path.join(eval_path, "classwise_top1.pkl")
369
+ with open(classwise_top1_path, 'wb') as file:
370
+ pickle.dump(classwise_top1, file)
371
+
372
+ classwise_top1 = np.sum(classwise_top1) / actual_nb_classes
373
+ classwise_top5 = np.sum(classwise_top5) / actual_nb_classes
374
+
375
+ return classwise_top1,classwise_top5
engine_for_pretraining.py ADDED
@@ -0,0 +1,152 @@
1
+ import math
2
+ import sys
3
+ from typing import Iterable
4
+ import torch
5
+ import torch.nn as nn
6
+ import utils_mae as utils
7
+ from einops import rearrange
8
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
+
10
+ def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
11
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16,
12
+ normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None,
13
+ lr_schedule_values=None, wd_schedule_values=None,teacher_model=None,target_type='pixel', multiple_sampling=False):
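+ # One pre-training epoch. Reconstruction targets are either (optionally normalised) pixel
+ # patches ('pixel') or frozen teacher features ('dino' / 'clip'); when multiple_sampling is
+ # set, each clip is masked twice and both masked views are reconstructed in the same pass.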
14
+
15
+ model.train()
16
+ metric_logger = utils.MetricLogger(delimiter=" ")
17
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
18
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19
+ header = 'Epoch: [{}]'.format(epoch)
20
+ print_freq = 10
21
+
22
+ loss_func = nn.MSELoss()
23
+
24
+ for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
25
+ # assign learning rate & weight decay for each step
26
+ it = start_steps + step # global training iteration
27
+ if lr_schedule_values is not None or wd_schedule_values is not None:
28
+ for i, param_group in enumerate(optimizer.param_groups):
29
+ if lr_schedule_values is not None:
30
+ param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
31
+ if wd_schedule_values is not None and param_group["weight_decay"] > 0:
32
+ param_group["weight_decay"] = wd_schedule_values[it]
33
+
34
+ videos, bool_masked_pos = batch
35
+ videos = videos.to(device, non_blocking=True)
36
+ bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
37
+ #print("input_1",videos.size(),bool_masked_pos.size())
38
+ bs, _, nf, h, w = videos.shape
39
+
40
+ idx = torch.randperm(bool_masked_pos.size(0))
41
+ shuffled_bool_masked_pos = bool_masked_pos[idx,:]
42
+
43
+ if 'pixel' in target_type:
44
+
45
+ with torch.no_grad():
46
+ # calculate the predict label
47
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
48
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
49
+ unnorm_videos = videos * std + mean # in [0, 1]
50
+
51
+ if normlize_target:
52
+ videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size)
53
+ videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True)
54
+ ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
55
+ # we find that the mean is about 0.48 and standard deviation is about 0.08.
56
+ videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)')
57
+ else:
58
+ videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size)
59
+
60
+ B, _, C = videos_patch.shape
61
+ if not multiple_sampling:
62
+ labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
63
+ else:
64
+ labels_1 = videos_patch[bool_masked_pos].reshape(B, -1, C)
65
+ labels_2 = videos_patch[shuffled_bool_masked_pos].reshape(B, -1, C)
66
+
67
+ elif 'dino' in target_type or 'clip' in target_type:
68
+
69
+ with torch.no_grad():
70
+ permuted_video = videos.permute(0, 2, 1, 3, 4)
71
+ bs, nf, _, h, w = permuted_video.shape
72
+ permuted_video = permuted_video[:, ::2].flatten(0, 1)
73
+ permuted_video = permuted_video.to(device, non_blocking=True)
74
+ features = teacher_model(permuted_video)
75
+ _, num_patches, dim = features.shape # renamed from `np` to avoid shadowing the numpy alias
76
+ features = features.reshape(bs, nf//2, num_patches, dim)
77
+ features.requires_grad = False
78
+
79
+ features = features.to(device, non_blocking=True)
80
+ with torch.no_grad():
81
+ features_squeeze = rearrange(features, 'b n o c -> b (n o) c')
82
+ if normlize_target:
83
+ labels = (features_squeeze - features_squeeze.mean(dim=-2, keepdim=True)
84
+ ) / (features_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
85
+ else:
86
+ labels = features_squeeze
87
+ B, _, C = labels.shape
88
+ if not multiple_sampling:
89
+ labels = labels[bool_masked_pos].reshape(B, -1, C)
90
+ else:
91
+ labels_1 = labels[bool_masked_pos].reshape(B, -1, C)
92
+ labels_2 = labels[shuffled_bool_masked_pos].reshape(B, -1, C)
93
+
94
+
95
+ with torch.cuda.amp.autocast():
96
+ if not multiple_sampling:
97
+ outputs = model(videos, bool_masked_pos)
98
+ else:
99
+ outputs_1 = model(videos, bool_masked_pos)
100
+ outputs_2 = model(videos,shuffled_bool_masked_pos)
101
+
102
+ labels = torch.cat((labels_1,labels_2),dim=0)
103
+ outputs = torch.cat((outputs_1,outputs_2),dim=0)
104
+
105
+ loss = loss_func(input=outputs, target=labels)
106
+
107
+ loss_value = loss.item()
108
+ if not math.isfinite(loss_value):
109
+ print("Loss is {}, stopping training".format(loss_value))
110
+ sys.exit(1)
111
+
112
+ optimizer.zero_grad()
113
+ # this attribute is added by timm on one optimizer (adahessian)
114
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
115
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
116
+ parameters=model.parameters(), create_graph=is_second_order)
117
+ loss_scale_value = loss_scaler.state_dict()["scale"]
118
+
119
+ torch.cuda.synchronize()
120
+
121
+ metric_logger.update(loss=loss_value)
122
+ metric_logger.update(loss_scale=loss_scale_value)
123
+ min_lr = 10.
124
+ max_lr = 0.
125
+ for group in optimizer.param_groups:
126
+ min_lr = min(min_lr, group["lr"])
127
+ max_lr = max(max_lr, group["lr"])
128
+
129
+ metric_logger.update(lr=max_lr)
130
+ metric_logger.update(min_lr=min_lr)
131
+ weight_decay_value = None
132
+ for group in optimizer.param_groups:
133
+ if group["weight_decay"] > 0:
134
+ weight_decay_value = group["weight_decay"]
135
+ metric_logger.update(weight_decay=weight_decay_value)
136
+ metric_logger.update(grad_norm=grad_norm)
137
+
138
+ if log_writer is not None:
139
+ log_writer.update(loss=loss_value, head="loss")
140
+ log_writer.update(loss_scale=loss_scale_value, head="opt")
141
+ log_writer.update(lr=max_lr, head="opt")
142
+ log_writer.update(min_lr=min_lr, head="opt")
143
+ log_writer.update(weight_decay=weight_decay_value, head="opt")
144
+ log_writer.update(grad_norm=grad_norm, head="opt")
145
+ log_writer.set_step()
146
+
147
+ if lr_scheduler is not None:
148
+ lr_scheduler.step_update(start_steps + step)
149
+ # gather the stats from all processes
150
+ metric_logger.synchronize_between_processes()
151
+ print("Averaged stats:", metric_logger)
152
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
environment.yml ADDED
@@ -0,0 +1,259 @@
1
+ name: smile
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - anaconda
6
+ - conda-forge
7
+ - defaults
8
+ dependencies:
9
+ - _libgcc_mutex=0.1=conda_forge
10
+ - _openmp_mutex=4.5=2_gnu
11
+ - alsa-lib=1.2.8=h166bdaf_0
12
+ - aom=3.5.0=h27087fc_0
13
+ - appdirs=1.4.4=pyh9f0ad1d_0
14
+ - attr=2.5.1=h166bdaf_1
15
+ - blas=1.0=mkl
16
+ - bottleneck=1.3.5=py310ha9d4c09_0
17
+ - brotli-python=1.1.0=py310hc6cd4ac_1
18
+ - bzip2=1.0.8=hd590300_5
19
+ - c-ares=1.23.0=hd590300_0
20
+ - ca-certificates=2023.11.17=hbcca054_0
21
+ - cairo=1.16.0=ha61ee94_1014
22
+ - certifi=2023.11.17=py310h06a4308_0
23
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
24
+ - click=8.1.7=unix_pyh707e725_0
25
+ - colorama=0.4.6=pyhd8ed1ab_0
26
+ - cuda-cudart=11.8.89=0
27
+ - cuda-cupti=11.8.87=0
28
+ - cuda-libraries=11.8.0=0
29
+ - cuda-nvrtc=11.8.89=0
30
+ - cuda-nvtx=11.8.86=0
31
+ - cuda-runtime=11.8.0=0
32
+ - dbus=1.13.6=h5008d03_3
33
+ - docker-pycreds=0.4.0=py_0
34
+ - einops=0.7.0=pyhd8ed1ab_1
35
+ - expat=2.5.0=hcb278e6_1
36
+ - ffmpeg=5.1.2=gpl_h8dda1f0_106
37
+ - fftw=3.3.10=nompi_hc118613_108
38
+ - filelock=3.13.1=pyhd8ed1ab_0
39
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
40
+ - font-ttf-inconsolata=3.000=h77eed37_0
41
+ - font-ttf-source-code-pro=2.038=h77eed37_0
42
+ - font-ttf-ubuntu=0.83=h77eed37_1
43
+ - fontconfig=2.14.2=h14ed4e7_0
44
+ - fonts-conda-ecosystem=1=0
45
+ - fonts-conda-forge=1=0
46
+ - freeglut=3.2.2=h9c3ff4c_1
47
+ - freetype=2.12.1=h267a509_2
48
+ - fsspec=2023.12.0=pyhca7485f_0
49
+ - gettext=0.21.1=h27087fc_0
50
+ - gitdb=4.0.11=pyhd8ed1ab_0
51
+ - gitpython=3.1.40=pyhd8ed1ab_0
52
+ - glib=2.78.1=hfc55251_1
53
+ - glib-tools=2.78.1=hfc55251_1
54
+ - gmp=6.3.0=h59595ed_0
55
+ - gmpy2=2.1.2=py310h3ec546c_1
56
+ - gnutls=3.7.9=hb077bed_0
57
+ - graphite2=1.3.13=h58526e2_1001
58
+ - gst-plugins-base=1.22.0=h4243ec0_2
59
+ - gstreamer=1.22.0=h25f0c4b_2
60
+ - gstreamer-orc=0.4.34=hd590300_0
61
+ - harfbuzz=6.0.0=h8e241bc_0
62
+ - hdf5=1.14.0=nompi_hb72d44e_103
63
+ - huggingface_hub=0.19.4=pyhd8ed1ab_0
64
+ - icu=70.1=h27087fc_0
65
+ - idna=3.6=pyhd8ed1ab_0
66
+ - intel-openmp=2023.1.0=hdb19cb5_46306
67
+ - jack=1.9.22=h11f4161_0
68
+ - jasper=2.0.33=h0ff4b12_1
69
+ - jinja2=3.1.2=pyhd8ed1ab_1
70
+ - jpeg=9e=h166bdaf_2
71
+ - keyutils=1.6.1=h166bdaf_0
72
+ - krb5=1.20.1=h81ceb04_0
73
+ - lame=3.100=h166bdaf_1003
74
+ - lcms2=2.15=hfd0df8a_0
75
+ - ld_impl_linux-64=2.40=h41732ed_0
76
+ - lerc=4.0.0=h27087fc_0
77
+ - libabseil=20230802.1=cxx17_h59595ed_0
78
+ - libaec=1.1.2=h59595ed_1
79
+ - libblas=3.9.0=1_h86c2bf4_netlib
80
+ - libcap=2.67=he9d0100_0
81
+ - libcblas=3.9.0=5_h92ddd45_netlib
82
+ - libclang=15.0.7=default_hb11cfb5_4
83
+ - libclang13=15.0.7=default_ha2b6cf4_4
84
+ - libcublas=11.11.3.6=0
85
+ - libcufft=10.9.0.58=0
86
+ - libcufile=1.8.1.2=0
87
+ - libcups=2.3.3=h36d4200_3
88
+ - libcurand=10.3.4.101=0
89
+ - libcurl=8.1.2=h409715c_0
90
+ - libcusolver=11.4.1.48=0
91
+ - libcusparse=11.7.5.86=0
92
+ - libdb=6.2.32=h9c3ff4c_0
93
+ - libdeflate=1.17=h0b41bf4_0
94
+ - libdrm=2.4.114=h166bdaf_0
95
+ - libedit=3.1.20191231=he28a2e2_2
96
+ - libev=4.33=h516909a_1
97
+ - libevent=2.1.10=h28343ad_4
98
+ - libexpat=2.5.0=hcb278e6_1
99
+ - libffi=3.4.2=h7f98852_5
100
+ - libflac=1.4.3=h59595ed_0
101
+ - libgcc-ng=13.2.0=h807b86a_3
102
+ - libgcrypt=1.10.3=hd590300_0
103
+ - libgfortran-ng=13.2.0=h69a702a_3
104
+ - libgfortran5=13.2.0=ha4646dd_3
105
+ - libglib=2.78.1=h783c2da_1
106
+ - libglu=9.0.0=he1b5a44_1001
107
+ - libgomp=13.2.0=h807b86a_3
108
+ - libgpg-error=1.47=h71f35ed_0
109
+ - libhwloc=2.9.1=hd6dc26d_0
110
+ - libiconv=1.17=h166bdaf_0
111
+ - libidn2=2.3.4=h166bdaf_0
112
+ - libjpeg-turbo=2.0.0=h9bf148f_0
113
+ - liblapack=3.9.0=5_h92ddd45_netlib
114
+ - liblapacke=3.9.0=5_h92ddd45_netlib
115
+ - libllvm15=15.0.7=hadd5161_1
116
+ - libnghttp2=1.58.0=h47da74e_0
117
+ - libnpp=11.8.0.86=0
118
+ - libnsl=2.0.1=hd590300_0
119
+ - libnvjpeg=11.9.0.86=0
120
+ - libogg=1.3.4=h7f98852_1
121
+ - libopencv=4.7.0=py310hb48cf42_1
122
+ - libopus=1.3.1=h7f98852_1
123
+ - libpciaccess=0.17=h166bdaf_0
124
+ - libpng=1.6.39=h753d276_0
125
+ - libpq=15.3=hbcd7760_1
126
+ - libprotobuf=3.21.12=hfc55251_2
127
+ - libsndfile=1.2.2=hc60ed4a_1
128
+ - libsqlite=3.44.2=h2797004_0
129
+ - libssh2=1.11.0=h0841786_0
130
+ - libstdcxx-ng=13.2.0=h7e041cc_3
131
+ - libsystemd0=253=h8c4010b_1
132
+ - libtasn1=4.19.0=h166bdaf_0
133
+ - libtiff=4.5.0=h6adf6a1_2
134
+ - libtool=2.4.7=h27087fc_0
135
+ - libudev1=253=h0b41bf4_1
136
+ - libunistring=0.9.10=h7f98852_0
137
+ - libuuid=2.38.1=h0b41bf4_0
138
+ - libva=2.18.0=h0b41bf4_0
139
+ - libvorbis=1.3.7=h9c3ff4c_0
140
+ - libvpx=1.11.0=h9c3ff4c_3
141
+ - libwebp-base=1.3.2=hd590300_0
142
+ - libxcb=1.13=h7f98852_1004
143
+ - libxkbcommon=1.5.0=h79f4944_1
144
+ - libxml2=2.10.3=hca2bb57_4
145
+ - libzlib=1.2.13=hd590300_5
146
+ - llvm-openmp=15.0.7=h0cdce71_0
147
+ - lz4-c=1.9.4=hcb278e6_0
148
+ - markupsafe=2.1.3=py310h2372a71_1
149
+ - mkl=2023.1.0=h213fc3f_46344
150
+ - mkl-service=2.4.0=py310h5eee18b_1
151
+ - mpc=1.3.1=hfe3b2da_0
152
+ - mpfr=4.2.1=h9458935_0
153
+ - mpg123=1.32.3=h59595ed_0
154
+ - mpmath=1.3.0=pyhd8ed1ab_0
155
+ - mysql-common=8.0.33=hf1915f5_6
156
+ - mysql-libs=8.0.33=hca2cd23_6
157
+ - ncurses=6.4=h59595ed_2
158
+ - nettle=3.9.1=h7ab15ed_0
159
+ - networkx=3.2.1=pyhd8ed1ab_0
160
+ - nspr=4.35=h27087fc_0
161
+ - nss=3.95=h1d7d5a4_0
162
+ - numexpr=2.8.7=py310h85018f9_0
163
+ - numpy=1.26.2=py310hb13e2d6_0
164
+ - opencv=4.7.0=py310hff52083_1
165
+ - openh264=2.3.1=hcb278e6_2
166
+ - openjpeg=2.5.0=hfec8fc6_2
167
+ - openssl=3.1.4=hd590300_0
168
+ - p11-kit=0.24.1=hc5aa10d_0
169
+ - packaging=23.2=pyhd8ed1ab_0
170
+ - pandas=2.1.1=py310h1128e8f_0
171
+ - pathtools=0.1.2=py_1
172
+ - pcre2=10.42=hcad00b1_0
173
+ - pillow=9.4.0=py310h023d228_1
174
+ - pip=23.3.1=pyhd8ed1ab_0
175
+ - pixman=0.42.2=h59595ed_0
176
+ - protobuf=4.21.12=py310heca2aa9_0
177
+ - pthread-stubs=0.4=h36c2ea0_1001
178
+ - pulseaudio=16.1=hcb278e6_3
179
+ - pulseaudio-client=16.1=h5195f5e_3
180
+ - pulseaudio-daemon=16.1=ha8d29e2_3
181
+ - py-opencv=4.7.0=py310hfdc917e_1
182
+ - pysocks=1.7.1=pyha2e5f31_6
183
+ - python=3.10.13=hd12c33a_0_cpython
184
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
185
+ - python-tzdata=2023.3=pyhd3eb1b0_0
186
+ - python_abi=3.10=4_cp310
187
+ - pytorch=2.1.1=py3.10_cuda11.8_cudnn8.7.0_0
188
+ - pytorch-cuda=11.8=h7e8668a_5
189
+ - pytorch-mutex=1.0=cuda
190
+ - pytz=2023.3.post1=py310h06a4308_0
191
+ - pyyaml=6.0.1=py310h2372a71_1
192
+ - qt-main=5.15.8=h5d23da1_6
193
+ - readline=8.2=h8228510_1
194
+ - requests=2.31.0=pyhd8ed1ab_0
195
+ - safetensors=0.3.3=py310hcb5633a_1
196
+ - scipy=1.11.3=py310h5f9d8c6_0
197
+ - sentry-sdk=1.38.0=pyhd8ed1ab_0
198
+ - setproctitle=1.3.3=py310h2372a71_0
199
+ - setuptools=68.2.2=pyhd8ed1ab_0
200
+ - six=1.16.0=pyh6c4a22f_0
201
+ - smmap=5.0.0=pyhd8ed1ab_0
202
+ - svt-av1=1.4.1=hcb278e6_0
203
+ - sympy=1.12=pypyh9d50eac_103
204
+ - tbb=2021.9.0=hf52228f_0
205
+ - tensorboardx=2.6.2.2=pyhd8ed1ab_0
206
+ - timm=0.9.12=pyhd8ed1ab_0
207
+ - tk=8.6.13=noxft_h4845f30_101
208
+ - torchaudio=2.1.1=py310_cu118
209
+ - torchtriton=2.1.0=py310
210
+ - torchvision=0.16.1=py310_cu118
211
+ - tqdm=4.66.1=pyhd8ed1ab_0
212
+ - typing-extensions=4.8.0=hd8ed1ab_0
213
+ - typing_extensions=4.8.0=pyha770c72_0
214
+ - tzdata=2023c=h71feb2d_0
215
+ - urllib3=2.1.0=pyhd8ed1ab_0
216
+ - wandb=0.15.12=pyhd8ed1ab_0
217
+ - wheel=0.42.0=pyhd8ed1ab_0
218
+ - x264=1!164.3095=h166bdaf_2
219
+ - x265=3.5=h924138e_3
220
+ - xcb-util=0.4.0=h516909a_0
221
+ - xcb-util-image=0.4.0=h166bdaf_0
222
+ - xcb-util-keysyms=0.4.0=h516909a_0
223
+ - xcb-util-renderutil=0.3.9=h166bdaf_0
224
+ - xcb-util-wm=0.4.1=h516909a_0
225
+ - xkeyboard-config=2.38=h0b41bf4_0
226
+ - xorg-fixesproto=5.0=h7f98852_1002
227
+ - xorg-inputproto=2.3.2=h7f98852_1002
228
+ - xorg-kbproto=1.0.7=h7f98852_1002
229
+ - xorg-libice=1.1.1=hd590300_0
230
+ - xorg-libsm=1.2.4=h7391055_0
231
+ - xorg-libx11=1.8.4=h0b41bf4_0
232
+ - xorg-libxau=1.0.11=hd590300_0
233
+ - xorg-libxdmcp=1.1.3=h7f98852_0
234
+ - xorg-libxext=1.3.4=h0b41bf4_2
235
+ - xorg-libxfixes=5.0.3=h7f98852_1004
236
+ - xorg-libxi=1.7.10=h7f98852_0
237
+ - xorg-libxrender=0.9.10=h7f98852_1003
238
+ - xorg-renderproto=0.11.1=h7f98852_1002
239
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
240
+ - xorg-xproto=7.0.31=h7f98852_1007
241
+ - xz=5.2.6=h166bdaf_0
242
+ - yaml=0.2.5=h7f98852_2
243
+ - zlib=1.2.13=hd590300_5
244
+ - zstd=1.5.5=hfc55251_0
245
+ - pip:
246
+ - annotated-types==0.6.0
247
+ - decord==0.6.0
248
+ - hjson==3.1.0
249
+ - ninja==1.11.1.1
250
+ - psutil==5.9.6
251
+ - py-cpuinfo==9.0.0
252
+ - pydantic==2.5.2
253
+ - pydantic-core==2.14.5
254
+ - pynvml==11.5.0
255
+ - imutils==0.5.4
256
+ - transformers==4.31.0
257
+ - ftfy
258
+ - easydict
259
+ - matplotlib==3.10.0
functional.py ADDED
@@ -0,0 +1,89 @@
1
+ import numbers
2
+ import cv2
3
+ import numpy as np
4
+ import PIL
5
+ import torch
6
+
7
+
8
+ def _is_tensor_clip(clip):
9
+ return torch.is_tensor(clip) and clip.ndimension() == 4
10
+
11
+
12
+ def crop_clip(clip, min_h, min_w, h, w):
13
+ if isinstance(clip[0], np.ndarray):
14
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15
+
16
+ elif isinstance(clip[0], PIL.Image.Image):
17
+ cropped = [
18
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19
+ ]
20
+ else:
21
+ raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22
+ 'but got list of {0}'.format(type(clip[0])))
23
+ return cropped
24
+
25
+
26
+ def resize_clip(clip, size, interpolation='bilinear'):
27
+ if isinstance(clip[0], np.ndarray):
28
+ if isinstance(size, numbers.Number):
29
+ im_h, im_w, im_c = clip[0].shape
30
+ # Min spatial dim already matches minimal size
31
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
32
+ and im_h == size):
33
+ return clip
34
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
35
+ size = (new_w, new_h)
36
+ else:
37
+ size = size[0], size[1]
38
+ if interpolation == 'bilinear':
39
+ np_inter = cv2.INTER_LINEAR
40
+ else:
41
+ np_inter = cv2.INTER_NEAREST
42
+ scaled = [
43
+ cv2.resize(img, size, interpolation=np_inter) for img in clip
44
+ ]
45
+ elif isinstance(clip[0], PIL.Image.Image):
46
+ if isinstance(size, numbers.Number):
47
+ im_w, im_h = clip[0].size
48
+ # Min spatial dim already matches minimal size
49
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
50
+ and im_h == size):
51
+ return clip
52
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
53
+ size = (new_w, new_h)
54
+ else:
55
+ size = size[1], size[0]
56
+ if interpolation == 'bilinear':
57
+ pil_inter = PIL.Image.BILINEAR
58
+ else:
59
+ pil_inter = PIL.Image.NEAREST
60
+ scaled = [img.resize(size, pil_inter) for img in clip]
61
+ else:
62
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
63
+ ' but got list of {0}'.format(type(clip[0])))
64
+ return scaled
65
+
66
+
67
+ def get_resize_sizes(im_h, im_w, size):
68
+ if im_w < im_h:
69
+ ow = size
70
+ oh = int(size * im_h / im_w)
71
+ else:
72
+ oh = size
73
+ ow = int(size * im_w / im_h)
74
+ return oh, ow
75
+
76
+
77
+ def normalize(clip, mean, std, inplace=False):
78
+ if not _is_tensor_clip(clip):
79
+ raise TypeError('tensor is not a torch clip.')
80
+
81
+ if not inplace:
82
+ clip = clip.clone()
83
+
84
+ dtype = clip.dtype
85
+ mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86
+ std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88
+
89
+ return clip
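For orientation, the helpers above compose as in the sketch below (hypothetical usage, assuming the file is importable as `functional`): resize the short side, crop, convert to a `(C, T, H, W)` tensor, then normalize.

```python
# Hypothetical usage of the clip helpers defined in functional.py above.
import numpy as np
import torch
import functional as F  # the module added in this commit

# A toy clip: 8 RGB frames of shape (H, W, C).
clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]

clip = F.resize_clip(clip, size=256, interpolation='bilinear')   # short side -> 256
clip = F.crop_clip(clip, min_h=16, min_w=16, h=224, w=224)       # 224x224 window

# normalize() expects a 4-D float tensor shaped (C, T, H, W).
tensor = torch.from_numpy(np.stack(clip)).float() / 255.0        # (T, H, W, C)
tensor = tensor.permute(3, 0, 1, 2)                              # (C, T, H, W)
tensor = F.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(tensor.shape)  # torch.Size([3, 8, 224, 224])
```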
kinetics.py ADDED
@@ -0,0 +1,559 @@
1
+ import os
2
+ import numpy as np
3
+ from numpy.lib.function_base import disp
4
+ import torch
5
+ import decord
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from random_erasing import RandomErasing
9
+ import warnings
10
+ from decord import VideoReader, cpu
11
+ from torch.utils.data import Dataset
12
+ import video_transforms as video_transforms
13
+ import volume_transforms as volume_transforms
14
+
15
+ class VideoClsDataset(Dataset):
16
+ """Load your own video classification dataset."""
17
+
18
+ def __init__(self, anno_path, data_path, mode='train', clip_len=8,
19
+ frame_sample_rate=2, crop_size=224, short_side_size=256,
20
+ new_height=256, new_width=340, keep_aspect_ratio=True,
21
+ num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,args=None):
22
+ self.anno_path = anno_path
23
+ self.data_path = data_path
24
+ self.mode = mode
25
+ self.clip_len = clip_len
26
+ self.frame_sample_rate = frame_sample_rate
27
+ self.crop_size = crop_size
28
+ self.short_side_size = short_side_size
29
+ self.new_height = new_height
30
+ self.new_width = new_width
31
+ self.keep_aspect_ratio = keep_aspect_ratio
32
+ self.num_segment = num_segment
33
+ self.test_num_segment = test_num_segment
34
+ self.num_crop = num_crop
35
+ self.test_num_crop = test_num_crop
36
+ self.args = args
37
+ self.aug = False
38
+ self.rand_erase = False
39
+ if self.mode in ['train']:
40
+ self.aug = True
41
+ if self.args.reprob > 0:
42
+ self.rand_erase = True
43
+ if VideoReader is None:
44
+ raise ImportError("Unable to import `decord` which is required to read videos.")
45
+
46
+ import pandas as pd
47
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
48
+ self.dataset_samples = list(cleaned.values[:, 0])
49
+ self.label_array = list(cleaned.values[:, 1])
50
+
51
+ if (mode == 'train'):
52
+ pass
53
+
54
+ elif (mode == 'validation'):
55
+ self.data_transform = video_transforms.Compose([
56
+ video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
57
+ video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
58
+ volume_transforms.ClipToTensor(),
59
+ video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
60
+ std=[0.229, 0.224, 0.225])
61
+ ])
62
+ elif mode == 'test':
63
+ self.data_resize = video_transforms.Compose([
64
+ video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
65
+ ])
66
+ self.data_transform = video_transforms.Compose([
67
+ volume_transforms.ClipToTensor(),
68
+ video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
69
+ std=[0.229, 0.224, 0.225])
70
+ ])
71
+ self.test_seg = []
72
+ self.test_dataset = []
73
+ self.test_label_array = []
74
+ for ck in range(self.test_num_segment):
75
+ for cp in range(self.test_num_crop):
76
+ for idx in range(len(self.label_array)):
77
+ sample_label = self.label_array[idx]
78
+ self.test_label_array.append(sample_label)
79
+ self.test_dataset.append(self.dataset_samples[idx])
80
+ self.test_seg.append((ck, cp))
81
+
82
+ def __getitem__(self, index):
83
+ if self.mode == 'train':
84
+ args = self.args
85
+ scale_t = 1
86
+
87
+ sample = self.dataset_samples[index]
88
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
89
+ if len(buffer) == 0:
90
+ while len(buffer) == 0:
91
+ warnings.warn("video {} not correctly loaded during training".format(sample))
92
+ index = np.random.randint(self.__len__())
93
+ sample = self.dataset_samples[index]
94
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
95
+
96
+ if args.num_sample > 1:
97
+ frame_list = []
98
+ label_list = []
99
+ index_list = []
100
+ for _ in range(args.num_sample):
101
+ new_frames = self._aug_frame(buffer, args)
102
+ label = self.label_array[index]
103
+ frame_list.append(new_frames)
104
+ label_list.append(label)
105
+ index_list.append(index)
106
+ return frame_list, label_list, index_list, {}
107
+ else:
108
+ buffer = self._aug_frame(buffer, args)
109
+
110
+ return buffer, self.label_array[index], index, {}
111
+
112
+ elif self.mode == 'validation':
113
+ sample = self.dataset_samples[index]
114
+ buffer = self.loadvideo_decord(sample)
115
+ if len(buffer) == 0:
116
+ while len(buffer) == 0:
117
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
118
+ index = np.random.randint(self.__len__())
119
+ sample = self.dataset_samples[index]
120
+ buffer = self.loadvideo_decord(sample)
121
+ buffer = self.data_transform(buffer)
122
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
123
+
124
+ elif self.mode == 'test':
125
+ sample = self.test_dataset[index]
126
+ chunk_nb, split_nb = self.test_seg[index]
127
+ buffer = self.loadvideo_decord(sample)
128
+
129
+ while len(buffer) == 0:
130
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
131
+ str(self.test_dataset[index]), chunk_nb, split_nb))
132
+ index = np.random.randint(self.__len__())
133
+ sample = self.test_dataset[index]
134
+ chunk_nb, split_nb = self.test_seg[index]
135
+ buffer = self.loadvideo_decord(sample)
136
+
137
+ buffer = self.data_resize(buffer)
138
+ if isinstance(buffer, list):
139
+ buffer = np.stack(buffer, 0)
140
+
141
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
142
+ / (self.test_num_crop - 1)
143
+ temporal_step = max(1.0 * (buffer.shape[0] - self.clip_len) \
144
+ / (self.test_num_segment - 1), 0)
145
+ temporal_start = int(chunk_nb * temporal_step)
146
+ spatial_start = int(split_nb * spatial_step)
147
+ if buffer.shape[1] >= buffer.shape[2]:
148
+ buffer = buffer[temporal_start:temporal_start + self.clip_len, \
149
+ spatial_start:spatial_start + self.short_side_size, :, :]
150
+ else:
151
+ buffer = buffer[temporal_start:temporal_start + self.clip_len, \
152
+ :, spatial_start:spatial_start + self.short_side_size, :]
153
+
154
+ buffer = self.data_transform(buffer)
155
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
156
+ chunk_nb, split_nb
157
+ else:
158
+ raise NameError('mode {} unknown'.format(self.mode))
159
+
160
+ def _aug_frame(
161
+ self,
162
+ buffer,
163
+ args,
164
+ ):
165
+
166
+ aug_transform = video_transforms.create_random_augment(
167
+ input_size=(self.crop_size, self.crop_size),
168
+ auto_augment=args.aa,
169
+ interpolation=args.train_interpolation,
170
+ )
171
+
172
+ buffer = [
173
+ transforms.ToPILImage()(frame) for frame in buffer
174
+ ]
175
+
176
+ buffer = aug_transform(buffer)
177
+
178
+ buffer = [transforms.ToTensor()(img) for img in buffer]
179
+ buffer = torch.stack(buffer) # T C H W
180
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
181
+
182
+ # T H W C
183
+ buffer = tensor_normalize(
184
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
185
+ )
186
+ # T H W C -> C T H W.
187
+ buffer = buffer.permute(3, 0, 1, 2)
188
+ # Perform data augmentation.
189
+ scl, asp = (
190
+ [0.25, 1.0],
191
+ [0.75, 1.3333],
192
+ )
193
+
194
+ buffer = spatial_sampling(
195
+ buffer,
196
+ spatial_idx=-1,
197
+ min_scale=256,
198
+ max_scale=320,
199
+ crop_size=self.crop_size,
200
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
201
+ inverse_uniform_sampling=False,
202
+ aspect_ratio=asp,
203
+ scale=scl,
204
+ motion_shift=False
205
+ )
206
+
207
+ if self.rand_erase:
208
+ erase_transform = RandomErasing(
209
+ args.reprob,
210
+ mode=args.remode,
211
+ max_count=args.recount,
212
+ num_splits=args.recount,
213
+ device="cpu",
214
+ )
215
+ buffer = buffer.permute(1, 0, 2, 3)
216
+ buffer = erase_transform(buffer)
217
+ buffer = buffer.permute(1, 0, 2, 3)
218
+
219
+ return buffer
220
+
221
+
222
+ def loadvideo_decord(self, sample, sample_rate_scale=1):
223
+ """Load video content using Decord"""
224
+ fname = sample
225
+
226
+ if not (os.path.exists(fname)):
227
+ return []
228
+
229
+ # avoid hanging issue
230
+ if os.path.getsize(fname) < 1 * 1024:
231
+ print('SKIP: ', fname, " - ", os.path.getsize(fname))
232
+ return []
233
+ try:
234
+ if self.keep_aspect_ratio:
235
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
236
+ else:
237
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
238
+ num_threads=1, ctx=cpu(0))
239
+ except:
240
+ print("video cannot be loaded by decord: ", fname)
241
+ return []
242
+
243
+ if self.mode == 'test':
244
+ all_index = [x for x in range(0, len(vr), self.frame_sample_rate)]
245
+ while len(all_index) < self.clip_len:
246
+ all_index.append(all_index[-1])
247
+ vr.seek(0)
248
+ buffer = vr.get_batch(all_index).asnumpy()
249
+ return buffer
250
+
251
+ # handle temporal segments
252
+ converted_len = int(self.clip_len * self.frame_sample_rate)
253
+ seg_len = len(vr) // self.num_segment
254
+
255
+ all_index = []
256
+ for i in range(self.num_segment):
257
+ if seg_len <= converted_len:
258
+ index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate)
259
+ index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len))
260
+ index = np.clip(index, 0, seg_len - 1).astype(np.int64)
261
+ else:
262
+ end_idx = np.random.randint(converted_len, seg_len)
263
+ str_idx = end_idx - converted_len
264
+ index = np.linspace(str_idx, end_idx, num=self.clip_len)
265
+ index = np.clip(index, str_idx, end_idx - 1).astype(np.int64)
266
+ index = index + i*seg_len
267
+ all_index.extend(list(index))
268
+
269
+ all_index = all_index[::int(sample_rate_scale)]
270
+ vr.seek(0)
271
+ buffer = vr.get_batch(all_index).asnumpy()
272
+ return buffer
273
+
274
+ def __len__(self):
275
+ if self.mode != 'test':
276
+ return len(self.dataset_samples)
277
+ else:
278
+ return len(self.test_dataset)
279
+
280
+
281
+ def spatial_sampling(
282
+ frames,
283
+ spatial_idx=-1,
284
+ min_scale=256,
285
+ max_scale=320,
286
+ crop_size=224,
287
+ random_horizontal_flip=True,
288
+ inverse_uniform_sampling=False,
289
+ aspect_ratio=None,
290
+ scale=None,
291
+ motion_shift=False,
292
+ ):
293
+ """
294
+ Perform spatial sampling on the given video frames. If spatial_idx is
295
+ -1, perform random scale, random crop, and random flip on the given
296
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
297
+ with the given spatial_idx.
298
+ Args:
299
+ frames (tensor): frames of images sampled from the video. The
300
+ dimension is `num frames` x `height` x `width` x `channel`.
301
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
302
+ or 2, perform left, center, right crop if width is larger than
303
+ height, and perform top, center, bottom crop if height is larger
304
+ than width.
305
+ min_scale (int): the minimal size of scaling.
306
+ max_scale (int): the maximal size of scaling.
307
+ crop_size (int): the size of height and width used to crop the
308
+ frames.
309
+ inverse_uniform_sampling (bool): if True, sample uniformly in
310
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
311
+ scale. If False, take a uniform sample from [min_scale,
312
+ max_scale].
313
+ aspect_ratio (list): Aspect ratio range for resizing.
314
+ scale (list): Scale range for resizing.
315
+ motion_shift (bool): Whether to apply motion shift for resizing.
316
+ Returns:
317
+ frames (tensor): spatially sampled frames.
318
+ """
319
+ assert spatial_idx in [-1, 0, 1, 2]
320
+ if spatial_idx == -1:
321
+ if aspect_ratio is None and scale is None:
322
+ frames, _ = video_transforms.random_short_side_scale_jitter(
323
+ images=frames,
324
+ min_size=min_scale,
325
+ max_size=max_scale,
326
+ inverse_uniform_sampling=inverse_uniform_sampling,
327
+ )
328
+ frames, _ = video_transforms.random_crop(frames, crop_size)
329
+ else:
330
+ transform_func = (
331
+ video_transforms.random_resized_crop_with_shift
332
+ if motion_shift
333
+ else video_transforms.random_resized_crop
334
+ )
335
+ frames = transform_func(
336
+ images=frames,
337
+ target_height=crop_size,
338
+ target_width=crop_size,
339
+ scale=scale,
340
+ ratio=aspect_ratio,
341
+ )
342
+ if random_horizontal_flip:
343
+ frames, _ = video_transforms.horizontal_flip(0.5, frames)
344
+ else:
345
+ # The testing is deterministic and no jitter should be performed.
346
+ # min_scale, max_scale, and crop_size are expected to be the same.
347
+ assert len({min_scale, max_scale, crop_size}) == 1
348
+ frames, _ = video_transforms.random_short_side_scale_jitter(
349
+ frames, min_scale, max_scale
350
+ )
351
+ frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
352
+ return frames
353
+
354
+
355
+ def tensor_normalize(tensor, mean, std):
356
+ """
357
+ Normalize a given tensor by subtracting the mean and dividing the std.
358
+ Args:
359
+ tensor (tensor): tensor to normalize.
360
+ mean (tensor or list): mean value to subtract.
361
+ std (tensor or list): std to divide.
362
+ """
363
+ if tensor.dtype == torch.uint8:
364
+ tensor = tensor.float()
365
+ tensor = tensor / 255.0
366
+ if type(mean) == list:
367
+ mean = torch.tensor(mean)
368
+ if type(std) == list:
369
+ std = torch.tensor(std)
370
+ tensor = tensor - mean
371
+ tensor = tensor / std
372
+ return tensor
373
+
374
+
375
+ class VideoMAE(torch.utils.data.Dataset):
376
+ """Load your own video classification dataset.
377
+ Parameters
378
+ ----------
379
+ root : str, required.
380
+ Path to the root folder storing the dataset.
381
+ setting : str, required.
382
+ A text file describing the dataset, each line per video sample.
383
+ There are three items in each line: (1) video path; (2) video length and (3) video label.
384
+ train : bool, default True.
385
+ Whether to load the training or validation set.
386
+ test_mode : bool, default False.
387
+ Whether to perform evaluation on the test set.
388
+ Usually a three-crop or ten-crop evaluation strategy is involved.
389
+ name_pattern : str, default None.
390
+ The naming pattern of the decoded video frames.
391
+ For example, img_00012.jpg.
392
+ video_ext : str, default 'mp4'.
393
+ If video_loader is set to True, please specify the video format accordingly.
394
+ is_color : bool, default True.
395
+ Whether the loaded image is color or grayscale.
396
+ modality : str, default 'rgb'.
397
+ Input modalities, we support only rgb video frames for now.
398
+ Will add support for rgb difference image and optical flow image later.
399
+ num_segments : int, default 1.
400
+ Number of segments to evenly divide the video into clips.
401
+ A useful technique to obtain global video-level information.
402
+ Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
403
+ num_crop : int, default 1.
404
+ Number of crops for each image. default is 1.
405
+ Common choices are three crops and ten crops during evaluation.
406
+ new_length : int, default 1.
407
+ The length of input video clip. Default is a single image, but it can be multiple video frames.
408
+ For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
409
+ new_step : int, default 1.
410
+ Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
411
+ new_step=2 means we will extract a video clip of every other frame.
412
+ temporal_jitter : bool, default False.
413
+ Whether to temporally jitter if new_step > 1.
414
+ video_loader : bool, default False.
415
+ Whether to use video loader to load data.
416
+ use_decord : bool, default True.
417
+ Whether to use Decord video loader to load data. Otherwise use mmcv video loader.
418
+ transform : function, default None.
419
+ A function that takes data and label and transforms them.
420
+ data_aug : str, default 'v1'.
421
+ Different types of automatic data augmentation. Supports v1, v2, v3 and v4.
422
+ lazy_init : bool, default False.
423
+ If set to True, build a dataset instance without loading any dataset.
424
+ """
425
+ def __init__(self,
426
+ root,
427
+ setting,
428
+ train=True,
429
+ test_mode=False,
430
+ name_pattern='img_%05d.jpg',
431
+ video_ext='mp4',
432
+ is_color=True,
433
+ modality='rgb',
434
+ num_segments=1,
435
+ num_crop=1,
436
+ new_length=1,
437
+ new_step=1,
438
+ transform=None,
439
+ temporal_jitter=False,
440
+ video_loader=False,
441
+ use_decord=False,
442
+ lazy_init=False):
443
+
444
+ super(VideoMAE, self).__init__()
445
+ self.root = root
446
+ self.setting = setting
447
+ self.train = train
448
+ self.test_mode = test_mode
449
+ self.is_color = is_color
450
+ self.modality = modality
451
+ self.num_segments = num_segments
452
+ self.num_crop = num_crop
453
+ self.new_length = new_length
454
+ self.new_step = new_step
455
+ self.skip_length = self.new_length * self.new_step
456
+ self.temporal_jitter = temporal_jitter
457
+ self.name_pattern = name_pattern
458
+ self.video_loader = video_loader
459
+ self.video_ext = video_ext
460
+ self.use_decord = use_decord
461
+ self.transform = transform
462
+ self.lazy_init = lazy_init
463
+
464
+
465
+ if not self.lazy_init:
466
+ self.clips = self._make_dataset(root, setting)
467
+ if len(self.clips) == 0:
468
+ raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
469
+ "Check your data directory (opt.data-dir)."))
470
+
471
+ def __getitem__(self, index):
472
+ try:
473
+ directory, target = self.clips[index]
474
+ if self.video_loader:
475
+ if '.' in directory.split('/')[-1]:
476
+ # data in the "setting" file already have extension, e.g., demo.mp4
477
+ video_name = directory
478
+ else:
479
+ # data in the "setting" file do not have extension, e.g., demo
480
+ # So we need to provide extension (i.e., .mp4) to complete the file name.
481
+ video_name = '{}.{}'.format(directory, self.video_ext)
482
+
483
+ decord_vr = decord.VideoReader(video_name, num_threads=1)
484
+ duration = len(decord_vr)
485
+
486
+ segment_indices, skip_offsets = self._sample_train_indices(duration)
487
+
488
+ images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
489
+
490
+ process_data, mask = self.transform((images, None)) # T*C,H,W
491
+ process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
492
+ return (process_data, mask)
493
+ except Exception as error:
494
+ print(error, " failed to load: ", video_name)
495
+ return self[(index+1) % len(self)]
496
+
497
+
498
+ def __len__(self):
499
+ return len(self.clips)
500
+
501
+ def _make_dataset(self, directory, setting):
502
+ if not os.path.exists(setting):
503
+ raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
504
+ clips = []
505
+ with open(setting) as split_f:
506
+ data = split_f.readlines()
507
+ for line in data:
508
+ line_info = line.split(' ')
509
+ # line format: video_path, video_duration, video_label
510
+ if len(line_info) < 2:
511
+ raise(RuntimeError('Video input format is not correct, missing one or more elements. %s' % line))
512
+ clip_path = os.path.join(line_info[0])
513
+ target = int(line_info[1])
514
+ item = (clip_path, target)
515
+ clips.append(item)
516
+ return clips
517
+
518
+ def _sample_train_indices(self, num_frames):
519
+ average_duration = (num_frames - self.skip_length + 1) // self.num_segments
520
+ if average_duration > 0:
521
+ offsets = np.multiply(list(range(self.num_segments)),
522
+ average_duration)
523
+ offsets = offsets + np.random.randint(average_duration,
524
+ size=self.num_segments)
525
+ elif num_frames > max(self.num_segments, self.skip_length):
526
+ offsets = np.sort(np.random.randint(
527
+ num_frames - self.skip_length + 1,
528
+ size=self.num_segments))
529
+ else:
530
+ offsets = np.zeros((self.num_segments,))
531
+
532
+ if self.temporal_jitter:
533
+ skip_offsets = np.random.randint(
534
+ self.new_step, size=self.skip_length // self.new_step)
535
+ else:
536
+ skip_offsets = np.zeros(
537
+ self.skip_length // self.new_step, dtype=int)
538
+ return offsets + 1, skip_offsets
539
+
540
+
541
+ def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
542
+ sampled_list = []
543
+ frame_id_list = []
544
+ for seg_ind in indices:
545
+ offset = int(seg_ind)
546
+ for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
547
+ if offset + skip_offsets[i] <= duration:
548
+ frame_id = offset + skip_offsets[i] - 1
549
+ else:
550
+ frame_id = offset - 1
551
+ frame_id_list.append(frame_id)
552
+ if offset + self.new_step < duration:
553
+ offset += self.new_step
554
+ try:
555
+ video_data = video_reader.get_batch(frame_id_list).asnumpy()
556
+ sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
557
+ except:
558
+ raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration))
559
+ return sampled_list
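As a rough usage sketch (hypothetical paths and argument values; the `args` namespace only stubs the fields that `VideoClsDataset` and `_aug_frame` actually read), the fine-tuning dataset above would be constructed roughly like this:

```python
# Hypothetical construction of the fine-tuning dataset defined in kinetics.py.
# The annotation file is space-separated: <video_path> <integer_label> per line.
from types import SimpleNamespace
from kinetics import VideoClsDataset

args = SimpleNamespace(
    reprob=0.25, remode='pixel', recount=1,   # RandomErasing settings
    aa='rand-m7-n4-mstd0.5-inc1',             # RandAugment policy string
    train_interpolation='bicubic',
    num_sample=1,                             # one augmented view per clip
    data_set='Kinetics-400',                  # anything but 'SSV2' keeps h-flip on
)

dataset = VideoClsDataset(
    anno_path='labels/kinetics400_train.csv',  # placeholder path
    data_path='',                              # unused placeholder
    mode='train',
    clip_len=16, frame_sample_rate=4,
    crop_size=224, short_side_size=256,
    args=args,
)

frames, label, idx, _ = dataset[0]   # frames: (C, T, H, W) tensor after _aug_frame
```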
masking_generator.py ADDED
@@ -0,0 +1,185 @@
1
+ import numpy as np
2
+ import random
3
+ import ast
4
+
5
+ class TubeMaskingGenerator:
6
+ def __init__(self, input_size, mask_ratio):
7
+ self.frames, self.height, self.width = input_size
8
+ self.num_patches_per_frame = self.height * self.width
9
+ self.total_patches = self.frames * self.num_patches_per_frame
10
+ self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
11
+ self.total_masks = self.frames * self.num_masks_per_frame
12
+
13
+ def __repr__(self):
14
+ repr_str = "Masks: total patches {}, mask patches {}".format(
15
+ self.total_patches, self.total_masks
16
+ )
17
+ return repr_str
18
+
19
+ def __call__(self):
20
+ mask_per_frame = np.hstack([
21
+ np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
22
+ np.ones(self.num_masks_per_frame),
23
+ ])
24
+ np.random.shuffle(mask_per_frame)
25
+ mask = np.tile(mask_per_frame, (self.frames,1)).flatten()
26
+ return mask
27
+
28
+
29
+ class TubeletMaskingGenerator:
30
+ def __init__(self, input_size, mask_ratio, visible_frames, mask_type="tube", traj_unmask_ratio=0.1):
31
+ self.tube_masking_generator = TubeMaskingGenerator(input_size, mask_ratio)
32
+ self.frames, self.height, self.width = input_size
33
+ self.num_patches_per_frame = self.height * self.width
34
+ self.total_patches = self.frames * self.num_patches_per_frame
35
+ self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
36
+ self.total_masks = self.frames * self.num_masks_per_frame
37
+ self.patch_size = 16
38
+ self.traj_unmask_ratio = traj_unmask_ratio
39
+ if visible_frames is not None:
40
+ visible_list = ast.literal_eval(visible_frames)
41
+ self.visible_frames = [int(element) for element in visible_list]
42
+ else:
43
+ self.visible_frames = None
44
+
45
+ self.mask_type = mask_type
46
+
47
+ def _balance_num_masks(self, combined_mask,
48
+ unmasked_object_patches_index,
49
+ unmasked_non_object_patches_index,
50
+ masked_object_patches_index,
51
+ tube_masked_index=None,
52
+ tube_unmasked_index=None):
53
+ current_masks = np.sum(combined_mask)
54
+ num_diff = np.abs(self.total_masks - current_masks)
55
+
56
+ if tube_masked_index is None or tube_unmasked_index is None:
57
+ # tubelet masking without tube mask
58
+ # if too many masked patches, we unmask some patches
59
+ if current_masks > self.total_masks:
60
+ picked_index = masked_object_patches_index[np.random.choice(masked_object_patches_index.size, size=int(num_diff), replace=False)]
61
+ combined_mask[picked_index] = 0.
62
+ # if too few masked patches, we first try to mask non-object patches, if not enough, then we mask protected patches
63
+ elif current_masks < self.total_masks:
64
+ if num_diff <= len(unmasked_non_object_patches_index):
65
+ picked_index = unmasked_non_object_patches_index[np.random.choice(unmasked_non_object_patches_index.size, size=int(num_diff), replace=False)]
66
+ combined_mask[picked_index] = 1.
67
+ else:
68
+ combined_mask[unmasked_non_object_patches_index] = 1.
69
+ picked_index = unmasked_object_patches_index[np.random.choice(unmasked_object_patches_index.size, size=int(num_diff - len(unmasked_non_object_patches_index)), replace=False)]
70
+ combined_mask[picked_index] = 1.
71
+ else:
72
+ # if too many masked patches, we first try to unmask tube masked patches, if not enough, then we unmask object patches
73
+ tube_masked_non_object_index = np.array(list(set(tube_masked_index) - set(masked_object_patches_index) - set(unmasked_object_patches_index)))
74
+ if current_masks > self.total_masks:
75
+ if num_diff <= len(tube_masked_non_object_index):
76
+ picked_index = tube_masked_non_object_index[np.random.choice(tube_masked_non_object_index.size, size=int(num_diff), replace=False)]
77
+ combined_mask[picked_index] = 0.
78
+ else:
79
+ combined_mask[tube_masked_non_object_index] = 0.
80
+ picked_index = masked_object_patches_index[np.random.choice(masked_object_patches_index.size, size=int(num_diff - len(tube_masked_non_object_index)), replace=False)]
81
+ combined_mask[picked_index] = 0.
82
+ # if too few masked patches, we first try to mask non-object patches, if not enough, then we mask protected patches
83
+ elif current_masks < self.total_masks:
84
+ tube_unmasked_non_object_index = np.array(list(set(tube_unmasked_index) - set(masked_object_patches_index) - set(unmasked_object_patches_index)))
85
+ if num_diff <= len(tube_unmasked_non_object_index):
86
+ picked_index = tube_unmasked_non_object_index[np.random.choice(tube_unmasked_non_object_index.size, size=int(num_diff), replace=False)]
87
+ combined_mask[picked_index] = 1.
88
+ else:
89
+ combined_mask[tube_unmasked_non_object_index] = 1.
90
+ picked_index = unmasked_object_patches_index[np.random.choice(unmasked_object_patches_index.size, size=int(num_diff - len(tube_unmasked_non_object_index)), replace=False)]
91
+ combined_mask[picked_index] = 1.
92
+
93
+ balanced_mask = combined_mask
94
+ return balanced_mask
95
+
96
+ def __repr__(self):
97
+ repr_str = "Masks: total patches {}, mask patches {}".format(
98
+ self.total_patches, self.total_masks
99
+ )
100
+ return repr_str
101
+
102
+ # 1 in mask array means masked, 0 means unmasked
103
+ def __call__(self, traj_rois):
104
+ # generate the original VideoMAE tube mask and initialize the tube mask index
105
+ tube_mask = self.tube_masking_generator()
106
+ tube_masked_index = None
107
+ tube_unmasked_index = None
108
+
109
+ # initialize mask
110
+ num_tubelet, num_frame, box = traj_rois.shape
111
+ assert num_frame % 2 == 0 and self.frames == (num_frame // 2)
112
+ combined_mask = np.zeros((num_frame // 2, self.height, self.width))
113
+ # assume patch size is (2, 16, 16) so mask shape should be (8, 14, 14)
114
+ # we combine the traj_rois of two consecutive frames to one large traj_rois
115
+
116
+ # pick one tubelet that is not masked
117
+ if self.visible_frames is None:
118
+ picked_frame = np.random.randint(0, (num_frame // 2))
119
+ picked_list = [picked_frame]
120
+ else:
121
+ picked_list = self.visible_frames
122
+
123
+ # combined mask 1 means object patches that should be masked, 2 means object patches that should not be masked, 0 means non-object patches
124
+ for roi_idx, roi in enumerate(traj_rois):
125
+ for i in range(num_frame // 2):
126
+ min_x = min( (roi[2 * i][0], roi[2 * i + 1][0]) )
127
+ max_x = max( (roi[2 * i][2], roi[2 * i + 1][2]) )
128
+ min_y = min( (roi[2 * i][1], roi[2 * i + 1][1]) )
129
+ max_y = max( (roi[2 * i][3], roi[2 * i + 1][3]) )
130
+
131
+ patch_index_x_min = max( int(np.floor(min_x / self.patch_size)), 0)
132
+ patch_index_x_max = min( int(np.ceil(max_x / self.patch_size)) + 1, 14)
133
+ patch_index_y_min = max( int(np.floor(min_y / self.patch_size)), 0)
134
+ patch_index_y_max = min( int(np.ceil(max_y / self.patch_size)) + 1, 14)
135
+
136
+ if i in picked_list:
137
+ combined_mask[i][patch_index_y_min:patch_index_y_max, patch_index_x_min:patch_index_x_max] = 2.
138
+ else:
139
+ combined_mask[i][patch_index_y_min:patch_index_y_max, patch_index_x_min:patch_index_x_max] = 1.
140
+
141
+ combined_mask = combined_mask.flatten()
142
+ masked_object_patches_index = np.where(combined_mask == 1.)[0]
143
+ unmasked_non_object_patches_index = np.where(combined_mask == 0.)[0]
144
+ unmasked_object_patches_index = np.where(combined_mask == 2.)[0]
145
+ combined_mask[unmasked_object_patches_index] = 0.
146
+
147
+ tube_masked_index = np.where(tube_mask == 1.)[0]
148
+ tube_unmasked_index = np.where(tube_mask == 0.)[0]
149
+
150
+ # combine tubelet mask and tube mask
151
+ combined_mask = np.bitwise_or(combined_mask.astype(bool), tube_mask.astype(bool)).astype(np.float32)
152
+
153
+ if self.mask_type == "tube+picked_frame_visible":
154
+ # unmasked the protected patches
155
+ combined_mask[unmasked_object_patches_index] = 0.
156
+
157
+ elif self.mask_type == "tube+traj_mask":
158
+ # get index of unmasked traj patches
159
+ traj_unmask_ratio = self.traj_unmask_ratio
160
+ traj_patches_index = np.array(list(set(masked_object_patches_index) | set(unmasked_object_patches_index)))
161
+ unmasked_traj_patches_index = traj_patches_index[np.random.choice(traj_patches_index.size, size=int(traj_unmask_ratio * len(traj_patches_index)), replace=False)]
162
+
163
+ # mask the whole traj
164
+ combined_mask[traj_patches_index] = 1.
165
+ # unmask those selected patches
166
+ combined_mask[unmasked_traj_patches_index] = 0.
167
+
168
+ # update indexes
169
+ unmasked_object_patches_index = unmasked_traj_patches_index
170
+ masked_object_patches_index = np.array(list(set(traj_patches_index) - set(unmasked_traj_patches_index)))
171
+
172
+
173
+ # balance masked patch number
174
+ mask = self._balance_num_masks(combined_mask,
175
+ unmasked_object_patches_index,
176
+ unmasked_non_object_patches_index,
177
+ masked_object_patches_index,
178
+ tube_masked_index,
179
+ tube_unmasked_index)
180
+
181
+
182
+ assert np.sum(mask) == self.total_masks
183
+ return mask
184
+
185
+
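The plain tube mask can be exercised on its own. The sketch below uses the 8 x 14 x 14 token grid implied by the code above (16 frames with tubelet size 2, a 224 x 224 input, and 16 x 16 patches):

```python
# Sketch: generate a tube mask for an 8 x 14 x 14 token grid at 90% masking.
import numpy as np
from masking_generator import TubeMaskingGenerator

gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
mask = gen()                                    # flat array: 1 = masked, 0 = visible
print(mask.shape)                               # (1568,) = 8 * 14 * 14
print(int(mask.sum()) == gen.total_masks)       # True: exactly total_masks ones

# The same spatial positions are masked in every frame (a "tube" along time).
per_frame = mask.reshape(8, -1)
print(bool(np.all(per_frame == per_frame[0])))  # True
```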
mixup.py ADDED
@@ -0,0 +1,316 @@
1
+ """ Mixup and Cutmix
2
+
3
+ Papers:
4
+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5
+
6
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7
+
8
+ Code Reference:
9
+ CutMix: https://github.com/clovaai/CutMix-PyTorch
10
+
11
+ Hacked together by / Copyright 2019, Ross Wightman
12
+ """
13
+ import numpy as np
14
+ import torch
15
+
16
+
17
+ def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18
+ x = x.long().view(-1, 1)
19
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20
+
21
+
22
+ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23
+ off_value = smoothing / num_classes
24
+ on_value = 1. - smoothing + off_value
25
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27
+ return y1 * lam + y2 * (1. - lam)
28
+
29
+
30
+ def rand_bbox(img_shape, lam, margin=0., count=None):
31
+ """ Standard CutMix bounding-box
32
+ Generates a random square bbox based on lambda value. This impl includes
33
+ support for enforcing a border margin as percent of bbox dimensions.
34
+
35
+ Args:
36
+ img_shape (tuple): Image shape as tuple
37
+ lam (float): Cutmix lambda value
38
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39
+ count (int): Number of bbox to generate
40
+ """
41
+ ratio = np.sqrt(1 - lam)
42
+ img_h, img_w = img_shape[-2:]
43
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
48
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
49
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
50
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
51
+ return yl, yh, xl, xh
52
+
53
+
54
+ def rand_bbox_minmax(img_shape, minmax, count=None):
55
+ """ Min-Max CutMix bounding-box
56
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
57
+ based on min/max percent values applied to each dimension of the input image.
58
+
59
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60
+
61
+ Args:
62
+ img_shape (tuple): Image shape as tuple
63
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64
+ count (int): Number of bbox to generate
65
+ """
66
+ assert len(minmax) == 2
67
+ img_h, img_w = img_shape[-2:]
68
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70
+ yl = np.random.randint(0, img_h - cut_h, size=count)
71
+ xl = np.random.randint(0, img_w - cut_w, size=count)
72
+ yu = yl + cut_h
73
+ xu = xl + cut_w
74
+ return yl, yu, xl, xu
75
+
76
+
77
+ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78
+ """ Generate bbox and apply lambda correction.
79
+ """
80
+ if ratio_minmax is not None:
81
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82
+ else:
83
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84
+ if correct_lam or ratio_minmax is not None:
85
+ bbox_area = (yu - yl) * (xu - xl)
86
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87
+ return (yl, yu, xl, xu), lam
88
+
89
+
90
+ class Mixup:
91
+ """ Mixup/Cutmix that applies different params to each element or whole batch
92
+
93
+ Args:
94
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97
+ prob (float): probability of applying mixup or cutmix per batch or element
98
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
100
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101
+ label_smoothing (float): apply label smoothing to the mixed target tensor
102
+ num_classes (int): number of classes for target
103
+ """
104
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106
+ self.mixup_alpha = mixup_alpha
107
+ self.cutmix_alpha = cutmix_alpha
108
+ self.cutmix_minmax = cutmix_minmax
109
+ if self.cutmix_minmax is not None:
110
+ assert len(self.cutmix_minmax) == 2
111
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112
+ self.cutmix_alpha = 1.0
113
+ self.mix_prob = prob
114
+ self.switch_prob = switch_prob
115
+ self.label_smoothing = label_smoothing
116
+ self.num_classes = num_classes
117
+ self.mode = mode
118
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119
+ self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
120
+
121
+ def _params_per_elem(self, batch_size):
122
+ lam = np.ones(batch_size, dtype=np.float32)
123
+ use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool was removed in recent NumPy
124
+ if self.mixup_enabled:
125
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
127
+ lam_mix = np.where(
128
+ use_cutmix,
129
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131
+ elif self.mixup_alpha > 0.:
132
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133
+ elif self.cutmix_alpha > 0.:
134
+ use_cutmix = np.ones(batch_size, dtype=bool)
135
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136
+ else:
137
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139
+ return lam, use_cutmix
140
+
141
+ def _params_per_batch(self):
142
+ lam = 1.
143
+ use_cutmix = False
144
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
145
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146
+ use_cutmix = np.random.rand() < self.switch_prob
147
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
149
+ elif self.mixup_alpha > 0.:
150
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151
+ elif self.cutmix_alpha > 0.:
152
+ use_cutmix = True
153
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154
+ else:
155
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156
+ lam = float(lam_mix)
157
+ return lam, use_cutmix
158
+
159
+ def _mix_elem(self, x):
160
+ batch_size = len(x)
161
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
162
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
163
+ for i in range(batch_size):
164
+ j = batch_size - i - 1
165
+ lam = lam_batch[i]
166
+ if lam != 1.:
167
+ if use_cutmix[i]:
168
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170
+ x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171
+ lam_batch[i] = lam
172
+ else:
173
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175
+
176
+ def _mix_pair(self, x):
177
+ batch_size = len(x)
178
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
180
+ for i in range(batch_size // 2):
181
+ j = batch_size - i - 1
182
+ lam = lam_batch[i]
183
+ if lam != 1.:
184
+ if use_cutmix[i]:
185
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189
+ lam_batch[i] = lam
190
+ else:
191
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195
+
196
+ def _mix_batch(self, x):
197
+ lam, use_cutmix = self._params_per_batch()
198
+ if lam == 1.:
199
+ return 1.
200
+ if use_cutmix:
201
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203
+ x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204
+ else:
205
+ x_flipped = x.flip(0).mul_(1. - lam)
206
+ x.mul_(lam).add_(x_flipped)
207
+ return lam
208
+
209
+ def __call__(self, x, target):
210
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
211
+ if self.mode == 'elem':
212
+ lam = self._mix_elem(x)
213
+ elif self.mode == 'pair':
214
+ lam = self._mix_pair(x)
215
+ else:
216
+ lam = self._mix_batch(x)
217
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218
+ return x, target
219
+
220
+
221
+ class FastCollateMixup(Mixup):
222
+ """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223
+
224
+ A Mixup impl that's performed while collating the batches.
225
+ """
226
+
227
+ def _mix_elem_collate(self, output, batch, half=False):
228
+ batch_size = len(batch)
229
+ num_elem = batch_size // 2 if half else batch_size
230
+ assert len(output) == num_elem
231
+ lam_batch, use_cutmix = self._params_per_elem(num_elem)
232
+ for i in range(num_elem):
233
+ j = batch_size - i - 1
234
+ lam = lam_batch[i]
235
+ mixed = batch[i][0]
236
+ if lam != 1.:
237
+ if use_cutmix[i]:
238
+ if not half:
239
+ mixed = mixed.copy()
240
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243
+ lam_batch[i] = lam
244
+ else:
245
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246
+ np.rint(mixed, out=mixed)
247
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
248
+ if half:
249
+ lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250
+ return torch.tensor(lam_batch).unsqueeze(1)
251
+
252
+ def _mix_pair_collate(self, output, batch):
253
+ batch_size = len(batch)
254
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255
+ for i in range(batch_size // 2):
256
+ j = batch_size - i - 1
257
+ lam = lam_batch[i]
258
+ mixed_i = batch[i][0]
259
+ mixed_j = batch[j][0]
260
+ assert 0 <= lam <= 1.0
261
+ if lam < 1.:
262
+ if use_cutmix[i]:
263
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265
+ patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266
+ mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267
+ mixed_j[:, yl:yh, xl:xh] = patch_i
268
+ lam_batch[i] = lam
269
+ else:
270
+ mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271
+ mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272
+ mixed_i = mixed_temp
273
+ np.rint(mixed_j, out=mixed_j)
274
+ np.rint(mixed_i, out=mixed_i)
275
+ output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276
+ output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278
+ return torch.tensor(lam_batch).unsqueeze(1)
279
+
280
+ def _mix_batch_collate(self, output, batch):
281
+ batch_size = len(batch)
282
+ lam, use_cutmix = self._params_per_batch()
283
+ if use_cutmix:
284
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286
+ for i in range(batch_size):
287
+ j = batch_size - i - 1
288
+ mixed = batch[i][0]
289
+ if lam != 1.:
290
+ if use_cutmix:
291
+ mixed = mixed.copy() # don't want to modify the original while iterating
292
+ mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293
+ else:
294
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295
+ np.rint(mixed, out=mixed)
296
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
297
+ return lam
298
+
299
+ def __call__(self, batch, _=None):
300
+ batch_size = len(batch)
301
+ assert batch_size % 2 == 0, 'Batch size should be even when using this'
302
+ half = 'half' in self.mode
303
+ if half:
304
+ batch_size //= 2
305
+ output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306
+ if self.mode == 'elem' or self.mode == 'half':
307
+ lam = self._mix_elem_collate(output, batch, half=half)
308
+ elif self.mode == 'pair':
309
+ lam = self._mix_pair_collate(output, batch)
310
+ else:
311
+ lam = self._mix_batch_collate(output, batch)
312
+ target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314
+ target = target[:batch_size]
315
+ return output, target
316
+
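A minimal driving sketch for the batch-level path of `Mixup` (toy tensors standing in for a real video batch; the shapes and class count are illustrative only):

```python
# Sketch: batch-level Mixup/CutMix on a toy video batch (B, C, T, H, W).
import torch
from mixup import Mixup

mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0,
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=400,
)

videos = torch.randn(4, 3, 16, 224, 224)        # batch size must be even
targets = torch.randint(0, 400, (4,))

videos, soft_targets = mixup_fn(videos, targets)
print(soft_targets.shape)                        # torch.Size([4, 400])
print(float(soft_targets.sum(dim=1)[0]))         # each row sums to ~1 (smoothed one-hot mix)
```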
modeling_finetune.py ADDED
@@ -0,0 +1,351 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
7
+ from timm.models.registry import register_model
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+
11
+ def _cfg(url='', **kwargs):
12
+ return {
13
+ 'url': url,
14
+ 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
15
+ 'crop_pct': .9, 'interpolation': 'bicubic',
16
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
17
+ **kwargs
18
+ }
19
+
20
+
21
+ class DropPath(nn.Module):
22
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
23
+ """
24
+ def __init__(self, drop_prob=None):
25
+ super(DropPath, self).__init__()
26
+ self.drop_prob = drop_prob
27
+
28
+ def forward(self, x):
29
+ return drop_path(x, self.drop_prob, self.training)
30
+
31
+ def extra_repr(self) -> str:
32
+ return 'p={}'.format(self.drop_prob)
33
+
34
+
35
+ class Mlp(nn.Module):
36
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
37
+ super().__init__()
38
+ out_features = out_features or in_features
39
+ hidden_features = hidden_features or in_features
40
+ self.fc1 = nn.Linear(in_features, hidden_features)
41
+ self.act = act_layer()
42
+ self.fc2 = nn.Linear(hidden_features, out_features)
43
+ self.drop = nn.Dropout(drop)
44
+
45
+ def forward(self, x):
46
+ x = self.fc1(x)
47
+ x = self.act(x)
48
+ # x = self.drop(x)
49
+ # commented out to follow the original BERT implementation
50
+ x = self.fc2(x)
51
+ x = self.drop(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(
57
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
58
+ proj_drop=0., attn_head_dim=None):
59
+ super().__init__()
60
+ self.num_heads = num_heads
61
+ head_dim = dim // num_heads
62
+ if attn_head_dim is not None:
63
+ head_dim = attn_head_dim
64
+ all_head_dim = head_dim * self.num_heads
65
+ self.scale = qk_scale or head_dim ** -0.5
66
+
67
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
68
+ if qkv_bias:
69
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
70
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
71
+ else:
72
+ self.q_bias = None
73
+ self.v_bias = None
74
+
75
+ self.attn_drop = nn.Dropout(attn_drop)
76
+ self.proj = nn.Linear(all_head_dim, dim)
77
+ self.proj_drop = nn.Dropout(proj_drop)
78
+
79
+ def forward(self, x):
80
+ B, N, C = x.shape
81
+ qkv_bias = None
82
+ if self.q_bias is not None:
83
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
84
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
85
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
86
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
87
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
88
+
89
+ q = q * self.scale
90
+ attn = (q @ k.transpose(-2, -1))
91
+
92
+
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+
96
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+ return x
100
+
101
+
102
+ class Block(nn.Module):
103
+
104
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
105
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
106
+ attn_head_dim=None):
107
+ super().__init__()
108
+ self.norm1 = norm_layer(dim)
109
+ self.attn = Attention(
110
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
111
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
112
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114
+ self.norm2 = norm_layer(dim)
115
+ mlp_hidden_dim = int(dim * mlp_ratio)
116
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117
+
118
+ if init_values > 0:
119
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
120
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
121
+ else:
122
+ self.gamma_1, self.gamma_2 = None, None
123
+
124
+ def forward(self, x):
125
+ if self.gamma_1 is None:
126
+ x = x + self.drop_path(self.attn(self.norm1(x)))
127
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
128
+ else:
129
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
130
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
131
+ return x
132
+
133
+
134
+ class PatchEmbed(nn.Module):
135
+ """ Image to Patch Embedding
136
+ """
137
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
138
+ super().__init__()
139
+ img_size = to_2tuple(img_size)
140
+ patch_size = to_2tuple(patch_size)
141
+ self.tubelet_size = int(tubelet_size)
142
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
143
+ self.img_size = img_size
144
+ self.patch_size = patch_size
145
+ self.num_patches = num_patches
146
+ self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
147
+ kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
148
+ stride=(self.tubelet_size, patch_size[0], patch_size[1]))
149
+
150
+ def forward(self, x, **kwargs):
151
+ B, C, T, H, W = x.shape
152
+ # FIXME look at relaxing size constraints
153
+ assert H == self.img_size[0] and W == self.img_size[1], \
154
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
155
+ x = self.proj(x).flatten(2).transpose(1, 2)
156
+ return x
157
+
158
+ # sin-cos position encoding
159
+ # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
160
+ def get_sinusoid_encoding_table(n_position, d_hid):
161
+ ''' Sinusoid position encoding table '''
162
+ # TODO: make it with torch instead of numpy
163
+ def get_position_angle_vec(position):
164
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
165
+
166
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
167
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
168
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
169
+
170
+ return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0)
171
+
172
+
173
+ class VisionTransformer(nn.Module):
174
+ """ Vision Transformer with support for patch or hybrid CNN input stage
175
+ """
176
+ def __init__(self,
177
+ img_size=224,
178
+ patch_size=16,
179
+ in_chans=3,
180
+ num_classes=1000,
181
+ embed_dim=768,
182
+ depth=12,
183
+ num_heads=12,
184
+ mlp_ratio=4.,
185
+ qkv_bias=False,
186
+ qk_scale=None,
187
+ fc_drop_rate=0.,
188
+ drop_rate=0.,
189
+ attn_drop_rate=0.,
190
+ drop_path_rate=0.,
191
+ norm_layer=nn.LayerNorm,
192
+ init_values=0.,
193
+ use_learnable_pos_emb=False,
194
+ init_scale=0.,
195
+ all_frames=16,
196
+ tubelet_size=2,
197
+ use_checkpoint=False,
198
+ use_mean_pooling=True,
199
+ pretrained_cfg=None,
200
+ pretrained_cfg_overlay = None
201
+ ):
202
+ super().__init__()
203
+ self.num_classes = num_classes
204
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
205
+ self.tubelet_size = tubelet_size
206
+ self.patch_embed = PatchEmbed(
207
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
208
+ num_patches = self.patch_embed.num_patches
209
+ self.use_checkpoint = use_checkpoint
210
+
211
+ if use_learnable_pos_emb:
212
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
213
+ else:
214
+ # fixed sine-cosine positional embeddings
215
+ self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
216
+
217
+ self.pos_drop = nn.Dropout(p=drop_rate)
218
+
219
+
220
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
221
+ self.blocks = nn.ModuleList([
222
+ Block(
223
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
224
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
225
+ init_values=init_values)
226
+ for i in range(depth)])
227
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
228
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
229
+ self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
230
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
231
+
232
+ if use_learnable_pos_emb:
233
+ trunc_normal_(self.pos_embed, std=.02)
234
+
235
+ trunc_normal_(self.head.weight, std=.02)
236
+ self.apply(self._init_weights)
237
+
238
+ self.head.weight.data.mul_(init_scale)
239
+ self.head.bias.data.mul_(init_scale)
240
+
241
+ def _init_weights(self, m):
242
+ if isinstance(m, nn.Linear):
243
+ trunc_normal_(m.weight, std=.02)
244
+ if isinstance(m, nn.Linear) and m.bias is not None:
245
+ nn.init.constant_(m.bias, 0)
246
+ elif isinstance(m, nn.LayerNorm):
247
+ nn.init.constant_(m.bias, 0)
248
+ nn.init.constant_(m.weight, 1.0)
249
+
250
+ def get_num_layers(self):
251
+ return len(self.blocks)
252
+
253
+ @torch.jit.ignore
254
+ def no_weight_decay(self):
255
+ return {'pos_embed', 'cls_token'}
256
+
257
+ def get_classifier(self):
258
+ return self.head
259
+
260
+ def reset_classifier(self, num_classes, global_pool=''):
261
+ self.num_classes = num_classes
262
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
263
+
264
+ def forward_features(self, x):
265
+ x = self.patch_embed(x)
266
+ B, _, _ = x.size()
267
+
268
+ if self.pos_embed is not None:
269
+ x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
270
+ x = self.pos_drop(x)
271
+
272
+ if self.use_checkpoint:
273
+ for blk in self.blocks:
274
+ x = checkpoint.checkpoint(blk, x)
275
+ else:
276
+ for blk in self.blocks:
277
+ x = blk(x)
278
+
279
+ x = self.norm(x)
280
+ if self.fc_norm is not None:
281
+ return self.fc_norm(x.mean(1))
282
+ else:
283
+ return x[:, 0]
284
+
285
+ def forward(self, x):
286
+ x = self.forward_features(x)
287
+ x = self.head(self.fc_dropout(x))
288
+ return x
289
+
290
+
291
+ @register_model
292
+ def vit_small_patch16_224(pretrained=False, **kwargs):
293
+ model = VisionTransformer(
294
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
295
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
296
+ model.default_cfg = _cfg()
297
+ return model
298
+
299
+
300
+ @register_model
301
+ def vit_base_patch16_224(pretrained=False, **kwargs):
302
+ model = VisionTransformer(
303
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
304
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
305
+ model.default_cfg = _cfg()
306
+ return model
307
+
308
+
309
+ @register_model
310
+ def vit_base_patch16_384(pretrained=False, **kwargs):
311
+ model = VisionTransformer(
312
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
313
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
314
+ model.default_cfg = _cfg()
315
+ return model
316
+
317
+
318
+ @register_model
319
+ def vit_large_patch16_224(pretrained=False, **kwargs):
320
+ model = VisionTransformer(
321
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
322
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
323
+ model.default_cfg = _cfg()
324
+ return model
325
+
326
+
327
+ @register_model
328
+ def vit_large_patch16_384(pretrained=False, **kwargs):
329
+ model = VisionTransformer(
330
+ img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
331
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
332
+ model.default_cfg = _cfg()
333
+ return model
334
+
335
+
336
+ @register_model
337
+ def vit_large_patch16_512(pretrained=False, **kwargs):
338
+ model = VisionTransformer(
339
+ img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
340
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
341
+ model.default_cfg = _cfg()
342
+ return model
343
+
344
+
345
+ @register_model
346
+ def vit_huge_patch16_224(pretrained=False, **kwargs):
347
+ model = VisionTransformer(
348
+ patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
349
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
350
+ model.default_cfg = _cfg()
351
+ return model
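A hedged usage sketch for the fine-tuning backbones registered above (argument values are illustrative; `create_model` forwards the keyword arguments to the `VisionTransformer` constructor):

```python
import torch
from timm.models import create_model
import modeling_finetune  # noqa: F401  (registers the vit_* video variants above)

model = create_model(
    "vit_base_patch16_224",
    pretrained=False,
    num_classes=400,      # e.g. Kinetics-400
    all_frames=16,
    tubelet_size=2,
    drop_path_rate=0.1,
)
clip = torch.randn(2, 3, 16, 224, 224)   # [B, C, T, H, W]
logits = model(clip)
print(logits.shape)                      # torch.Size([2, 400])
```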
modeling_pretrain.py ADDED
@@ -0,0 +1,398 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from functools import partial
7
+
8
+ from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
9
+ from timm.models.registry import register_model
10
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
11
+
12
+
13
+
14
+ def trunc_normal_(tensor, mean=0., std=1.):
15
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
16
+
17
+
18
+ __all__ = [
19
+ 'pretrain_videomae_small_patch16_224',
20
+ 'pretrain_videomae_base_patch16_224',
21
+ 'pretrain_videomae_large_patch16_224',
22
+ 'pretrain_videomae_huge_patch16_224',
23
+ ]
24
+
25
+
26
+ class PretrainVisionTransformerEncoder(nn.Module):
27
+ """ Vision Transformer with support for patch or hybrid CNN input stage
28
+ """
29
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
30
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
31
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
32
+ use_learnable_pos_emb=False):
33
+ super().__init__()
34
+ self.num_classes = num_classes
35
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
36
+ self.patch_embed = PatchEmbed(
37
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size)
38
+ num_patches = self.patch_embed.num_patches
39
+ self.use_checkpoint = use_checkpoint
40
+
41
+
42
+ # TODO: Add the cls token
43
+ if use_learnable_pos_emb:
44
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
45
+ else:
46
+ # sine-cosine positional embeddings
47
+ self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
48
+
49
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
50
+ self.blocks = nn.ModuleList([
51
+ Block(
52
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
53
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
54
+ init_values=init_values)
55
+ for i in range(depth)])
56
+ self.norm = norm_layer(embed_dim)
57
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
58
+
59
+ if use_learnable_pos_emb:
60
+ trunc_normal_(self.pos_embed, std=.02)
61
+
62
+ self.apply(self._init_weights)
63
+
64
+
65
+ def _init_weights(self, m):
66
+ if isinstance(m, nn.Linear):
67
+ nn.init.xavier_uniform_(m.weight)
68
+ if isinstance(m, nn.Linear) and m.bias is not None:
69
+ nn.init.constant_(m.bias, 0)
70
+ elif isinstance(m, nn.LayerNorm):
71
+ nn.init.constant_(m.bias, 0)
72
+ nn.init.constant_(m.weight, 1.0)
73
+
74
+ def get_num_layers(self):
75
+ return len(self.blocks)
76
+
77
+ @torch.jit.ignore
78
+ def no_weight_decay(self):
79
+ return {'pos_embed', 'cls_token'}
80
+
81
+ def get_classifier(self):
82
+ return self.head
83
+
84
+ def reset_classifier(self, num_classes, global_pool=''):
85
+ self.num_classes = num_classes
86
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
87
+
88
+ def forward_features(self, x, mask):
89
+ _, _, T, _, _ = x.shape
90
+ x = self.patch_embed(x)
91
+
92
+ x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
93
+
94
+ B, _, C = x.shape
95
+ x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
96
+
97
+ if self.use_checkpoint:
98
+ for blk in self.blocks:
99
+ x_vis = checkpoint.checkpoint(blk, x_vis)
100
+ else:
101
+ for blk in self.blocks:
102
+ x_vis = blk(x_vis)
103
+
104
+ x_vis = self.norm(x_vis)
105
+ return x_vis
106
+
107
+ def forward(self, x, mask):
108
+ x = self.forward_features(x, mask)
109
+ x = self.head(x)
110
+ return x
111
+
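The boolean-mask indexing in `forward_features` above relies on every sample keeping the same number of visible tokens; a small sketch of that selection (toy sizes, purely illustrative):

```python
import torch

B, N, C = 2, 8, 4                        # batch, tokens, channels (toy sizes)
x = torch.randn(B, N, C)
mask = torch.zeros(B, N, dtype=torch.bool)
mask[:, ::2] = True                      # True = masked, False = visible (same count per sample)
x_vis = x[~mask].reshape(B, -1, C)       # gather only the visible tokens
print(x_vis.shape)                       # torch.Size([2, 4, 4])
```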
112
+ class PretrainVisionTransformerDecoder(nn.Module):
113
+ """ Vision Transformer with support for patch or hybrid CNN input stage
114
+ """
115
+ def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
116
+ qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
117
+ norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
118
+ ):
119
+ super().__init__()
120
+ self.num_classes = num_classes
121
+ #assert num_classes == 3 * tubelet_size * patch_size ** 2
122
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
123
+ self.patch_size = patch_size
124
+ self.use_checkpoint = use_checkpoint
125
+
126
+
127
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
128
+ self.blocks = nn.ModuleList([
129
+ Block(
130
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
131
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
132
+ init_values=init_values)
133
+ for i in range(depth)])
134
+ self.norm = norm_layer(embed_dim)
135
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
136
+
137
+ self.apply(self._init_weights)
138
+
139
+
140
+ def _init_weights(self, m):
141
+ if isinstance(m, nn.Linear):
142
+ nn.init.xavier_uniform_(m.weight)
143
+ if isinstance(m, nn.Linear) and m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.LayerNorm):
146
+ nn.init.constant_(m.bias, 0)
147
+ nn.init.constant_(m.weight, 1.0)
148
+
149
+ def get_num_layers(self):
150
+ return len(self.blocks)
151
+
152
+ @torch.jit.ignore
153
+ def no_weight_decay(self):
154
+ return {'pos_embed', 'cls_token'}
155
+
156
+ def get_classifier(self):
157
+ return self.head
158
+
159
+ def reset_classifier(self, num_classes, global_pool=''):
160
+ self.num_classes = num_classes
161
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
162
+
163
+ def forward(self, x, return_token_num):
164
+ if self.use_checkpoint:
165
+ for blk in self.blocks:
166
+ x = checkpoint.checkpoint(blk, x)
167
+ else:
168
+ for blk in self.blocks:
169
+ x = blk(x)
170
+
171
+ if return_token_num > 0:
172
+ x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels
173
+ else:
174
+ x = self.head(self.norm(x))
175
+
176
+ return x
177
+
178
+ class FeatureExtractor(torch.nn.Module):
179
+ def __init__(self, vit_model, input_size, patch_size):
180
+ super(FeatureExtractor, self).__init__()
181
+ self.vit_model = vit_model
182
+ self.input_size = input_size
183
+ self.patch_size = patch_size
184
+ self.spatial_resolution = input_size // patch_size
185
+ assert self.spatial_resolution * patch_size == input_size
186
+
187
+ def forward(self, x):
188
+ if self.patch_size == 14:
189
+ features = self.vit_model.forward_features(x)[:, 5:]
190
+ bs, np, dim = features.shape
191
+ features = features.reshape(bs, self.spatial_resolution, self.spatial_resolution, dim).permute(0, 3, 1, 2)
192
+ features = F.interpolate(features, size=(14, 14), mode='bilinear')
193
+ features = features.flatten(2, -1).permute(0, 2, 1)
194
+ else:
195
+ features = self.vit_model.forward_features(x)[:, 1:]
196
+ return features
197
+
198
+ class PretrainVisionTransformer(nn.Module):
199
+ """ Vision Transformer with support for patch or hybrid CNN input stage
200
+ """
201
+ def __init__(self,
202
+ img_size=224,
203
+ patch_size=16,
204
+ encoder_in_chans=3,
205
+ encoder_num_classes=0,
206
+ encoder_embed_dim=768,
207
+ encoder_depth=12,
208
+ encoder_num_heads=12,
209
+ decoder_num_classes=1536, # decoder_num_classes=768,
210
+ decoder_embed_dim=512,
211
+ decoder_depth=8,
212
+ decoder_num_heads=8,
213
+ mlp_ratio=4.,
214
+ qkv_bias=False,
215
+ qk_scale=None,
216
+ drop_rate=0.,
217
+ attn_drop_rate=0.,
218
+ drop_path_rate=0.,
219
+ norm_layer=nn.LayerNorm,
220
+ init_values=0.,
221
+ use_learnable_pos_emb=False,
222
+ use_checkpoint=False,
223
+ tubelet_size=2,
224
+ num_classes=0, # avoid the error from create_fn in timm
225
+ in_chans=0, # avoid the error from create_fn in timm
226
+ pretrained_cfg=None, # avoid the error from create_fn in timm
227
+ pretrained_cfg_overlay=None, # avoid the error from create_fn in timm
228
+ ):
229
+ super().__init__()
230
+ self.encoder = PretrainVisionTransformerEncoder(
231
+ img_size=img_size,
232
+ patch_size=patch_size,
233
+ in_chans=encoder_in_chans,
234
+ num_classes=encoder_num_classes,
235
+ embed_dim=encoder_embed_dim,
236
+ depth=encoder_depth,
237
+ num_heads=encoder_num_heads,
238
+ mlp_ratio=mlp_ratio,
239
+ qkv_bias=qkv_bias,
240
+ qk_scale=qk_scale,
241
+ drop_rate=drop_rate,
242
+ attn_drop_rate=attn_drop_rate,
243
+ drop_path_rate=drop_path_rate,
244
+ norm_layer=norm_layer,
245
+ init_values=init_values,
246
+ tubelet_size=tubelet_size,
247
+ use_checkpoint=use_checkpoint,
248
+ use_learnable_pos_emb=use_learnable_pos_emb)
249
+
250
+ self.decoder = PretrainVisionTransformerDecoder(
251
+ patch_size=patch_size,
252
+ num_patches=self.encoder.patch_embed.num_patches,
253
+ num_classes=decoder_num_classes,
254
+ embed_dim=decoder_embed_dim,
255
+ depth=decoder_depth,
256
+ num_heads=decoder_num_heads,
257
+ mlp_ratio=mlp_ratio,
258
+ qkv_bias=qkv_bias,
259
+ qk_scale=qk_scale,
260
+ drop_rate=drop_rate,
261
+ attn_drop_rate=attn_drop_rate,
262
+ drop_path_rate=drop_path_rate,
263
+ norm_layer=norm_layer,
264
+ init_values=init_values,
265
+ tubelet_size=tubelet_size,
266
+ use_checkpoint=use_checkpoint)
267
+
268
+ self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
269
+
270
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
271
+
272
+ self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
273
+
274
+ trunc_normal_(self.mask_token, std=.02)
275
+
276
+
277
+ def _init_weights(self, m):
278
+ if isinstance(m, nn.Linear):
279
+ nn.init.xavier_uniform_(m.weight)
280
+ if isinstance(m, nn.Linear) and m.bias is not None:
281
+ nn.init.constant_(m.bias, 0)
282
+ elif isinstance(m, nn.LayerNorm):
283
+ nn.init.constant_(m.bias, 0)
284
+ nn.init.constant_(m.weight, 1.0)
285
+
286
+ def get_num_layers(self):
287
+ # this wrapper has no .blocks attribute of its own; report the encoder depth instead
+ return self.encoder.get_num_layers()
288
+
289
+ @torch.jit.ignore
290
+ def no_weight_decay(self):
291
+ return {'pos_embed', 'cls_token', 'mask_token'}
292
+
293
+ def forward(self, x, mask):
294
+ _, _, T, _, _ = x.shape
295
+ x_encoder = self.encoder(x, mask) # [B, N_vis, C_e]
296
+ x_vis = self.encoder_to_decoder(x_encoder) # [B, N_vis, C_d]
297
+ B, N, C = x_vis.shape
298
+ # we don't unshuffle the correct visible token order,
299
+ # but shuffle the pos embedding accordingly.
300
+ expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
301
+ pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
302
+ pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
303
+ x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d]
304
+ x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
305
+ return x
306
+
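A hedged end-to-end sketch of one masked-autoencoding forward pass with the base model registered below; the random 90% mask here is a simplification of the tube masking used in the actual pretraining pipeline:

```python
import torch
from timm.models import create_model
import modeling_pretrain  # noqa: F401  (registers the pretrain_videomae_* variants)

model = create_model("pretrain_videomae_base_patch16_224", pretrained=False, decoder_depth=4)
clip = torch.randn(1, 3, 16, 224, 224)                 # [B, C, T, H, W]
num_patches = model.encoder.patch_embed.num_patches    # 1568 with these defaults
mask = torch.zeros(1, num_patches, dtype=torch.bool)
mask[:, torch.randperm(num_patches)[: int(0.9 * num_patches)]] = True
pred = model(clip, mask)       # predictions are produced for the masked tokens only
print(pred.shape)              # [1, 1411, 1536]  (1536 = 3 * 2 * 16 * 16 pixels per tubelet)
```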
307
+ @register_model
308
+ def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
309
+ model = PretrainVisionTransformer(
310
+ img_size=224,
311
+ patch_size=16,
312
+ encoder_embed_dim=384,
313
+ encoder_depth=12,
314
+ encoder_num_heads=6,
315
+ encoder_num_classes=0,
316
+ decoder_embed_dim=192,
317
+ decoder_num_heads=3,
318
+ mlp_ratio=4,
319
+ qkv_bias=True,
320
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
321
+ **kwargs)
322
+ model.default_cfg = _cfg()
323
+ if pretrained:
324
+ checkpoint = torch.load(
325
+ kwargs["init_ckpt"], map_location="cpu"
326
+ )
327
+ model.load_state_dict(checkpoint["model"])
328
+ return model
329
+
330
+ @register_model
331
+ def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
332
+ model = PretrainVisionTransformer(
333
+ img_size=224,
334
+ patch_size=16,
335
+ encoder_embed_dim=768,
336
+ encoder_depth=12,
337
+ encoder_num_heads=12,
338
+ encoder_num_classes=0,
339
+ decoder_embed_dim=384,
340
+ decoder_num_heads=6,
341
+ mlp_ratio=4,
342
+ qkv_bias=True,
343
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
344
+ **kwargs)
345
+ model.default_cfg = _cfg()
346
+ if pretrained:
347
+ checkpoint = torch.load(
348
+ kwargs["init_ckpt"], map_location="cpu"
349
+ )
350
+ model.load_state_dict(checkpoint["model"])
351
+ return model
352
+
353
+ @register_model
354
+ def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
355
+ model = PretrainVisionTransformer(
356
+ img_size=224,
357
+ patch_size=16,
358
+ encoder_embed_dim=1024,
359
+ encoder_depth=24,
360
+ encoder_num_heads=16,
361
+ encoder_num_classes=0,
362
+ decoder_embed_dim=512,
363
+ decoder_num_heads=8,
364
+ mlp_ratio=4,
365
+ qkv_bias=True,
366
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
367
+ **kwargs)
368
+ model.default_cfg = _cfg()
369
+ if pretrained:
370
+ checkpoint = torch.load(
371
+ kwargs["init_ckpt"], map_location="cpu"
372
+ )
373
+ model.load_state_dict(checkpoint["model"])
374
+ return model
375
+
376
+ @register_model
377
+ def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
378
+ model = PretrainVisionTransformer(
379
+ img_size=224,
380
+ patch_size=16,
381
+ encoder_embed_dim=1280,
382
+ encoder_depth=32,
383
+ encoder_num_heads=16,
384
+ encoder_num_classes=0,
385
+ decoder_num_classes=1536,
386
+ decoder_embed_dim=640,
387
+ decoder_num_heads=8,
388
+ mlp_ratio=4,
389
+ qkv_bias=True,
390
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
391
+ **kwargs)
392
+ model.default_cfg = _cfg()
393
+ if pretrained:
394
+ checkpoint = torch.load(
395
+ kwargs["init_ckpt"], map_location="cpu"
396
+ )
397
+ model.load_state_dict(checkpoint["model"])
398
+ return model
optim_factory.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ from torch import optim as optim
3
+
4
+ from timm.optim.adafactor import Adafactor
5
+ from timm.optim.adahessian import Adahessian
6
+ from timm.optim.adamp import AdamP
7
+ from timm.optim.lookahead import Lookahead
8
+ from timm.optim.nadam import Nadam
9
+ #from timm.optim.novograd import NovoGrad
10
+ from timm.optim.nvnovograd import NvNovoGrad
11
+ from timm.optim.radam import RAdam
12
+ from timm.optim.rmsprop_tf import RMSpropTF
13
+ from timm.optim.sgdp import SGDP
14
+
15
+ import json
16
+
17
+ try:
18
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19
+ has_apex = True
20
+ except ImportError:
21
+ has_apex = False
22
+
23
+
24
+ def get_num_layer_for_vit(var_name, num_max_layer):
25
+ if var_name in ("cls_token", "mask_token", "pos_embed"):
26
+ return 0
27
+ elif var_name.startswith("patch_embed"):
28
+ return 0
29
+ elif var_name.startswith("rel_pos_bias"):
30
+ return num_max_layer - 1
31
+ elif var_name.startswith("blocks"):
32
+ layer_id = int(var_name.split('.')[1])
33
+ return layer_id + 1
34
+ else:
35
+ return num_max_layer - 1
36
+
37
+
38
+ class LayerDecayValueAssigner(object):
39
+ def __init__(self, values):
40
+ self.values = values
41
+
42
+ def get_scale(self, layer_id):
43
+ return self.values[layer_id]
44
+
45
+ def get_layer_id(self, var_name):
46
+ return get_num_layer_for_vit(var_name, len(self.values))
47
+
48
+
49
+ def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
50
+ parameter_group_names = {}
51
+ parameter_group_vars = {}
52
+
53
+ for name, param in model.named_parameters():
54
+ if not param.requires_grad:
55
+ continue # frozen weights
56
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
57
+ group_name = "no_decay"
58
+ this_weight_decay = 0.
59
+ else:
60
+ group_name = "decay"
61
+ this_weight_decay = weight_decay
62
+ if get_num_layer is not None:
63
+ layer_id = get_num_layer(name)
64
+ group_name = "layer_%d_%s" % (layer_id, group_name)
65
+ else:
66
+ layer_id = None
67
+
68
+ if group_name not in parameter_group_names:
69
+ if get_layer_scale is not None:
70
+ scale = get_layer_scale(layer_id)
71
+ else:
72
+ scale = 1.
73
+
74
+ parameter_group_names[group_name] = {
75
+ "weight_decay": this_weight_decay,
76
+ "params": [],
77
+ "lr_scale": scale
78
+ }
79
+ parameter_group_vars[group_name] = {
80
+ "weight_decay": this_weight_decay,
81
+ "params": [],
82
+ "lr_scale": scale
83
+ }
84
+
85
+ parameter_group_vars[group_name]["params"].append(param)
86
+ parameter_group_names[group_name]["params"].append(name)
87
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
88
+ return list(parameter_group_vars.values())
89
+
90
+
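A hedged sketch of how these helpers are typically combined for layer-wise learning-rate decay on a 12-block ViT (the decay value 0.75 matches the fine-tuning default used elsewhere in this upload):

```python
from optim_factory import LayerDecayValueAssigner

num_layers = 12          # transformer blocks
layer_decay = 0.75
# One scale per "layer id": 0 = patch embed / tokens, 1..12 = blocks, 13 = head and the rest.
values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
assigner = LayerDecayValueAssigner(values)

print(assigner.get_layer_id("patch_embed.proj.weight"))    # 0  -> smallest lr scale
print(assigner.get_layer_id("blocks.11.mlp.fc1.weight"))   # 12
print(assigner.get_scale(12))                              # 0.75
```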
91
+ def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
92
+ opt_lower = args.opt.lower()
93
+ weight_decay = args.weight_decay
94
+ if weight_decay and filter_bias_and_bn:
95
+ skip = {}
96
+ if skip_list is not None:
97
+ skip = skip_list
98
+ elif hasattr(model, 'no_weight_decay'):
99
+ skip = model.no_weight_decay()
100
+ parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
101
+ weight_decay = 0.
102
+ else:
103
+ parameters = model.parameters()
104
+
105
+ if 'fused' in opt_lower:
106
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
107
+
108
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
109
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
110
+ opt_args['eps'] = args.opt_eps
111
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
112
+ opt_args['betas'] = args.opt_betas
113
+
114
+ print("optimizer settings:", opt_args)
115
+
116
+ opt_split = opt_lower.split('_')
117
+ opt_lower = opt_split[-1]
118
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
119
+ opt_args.pop('eps', None)
120
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
121
+ elif opt_lower == 'momentum':
122
+ opt_args.pop('eps', None)
123
+ optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
124
+ elif opt_lower == 'adam':
125
+ optimizer = optim.Adam(parameters, **opt_args)
126
+ elif opt_lower == 'adamw':
127
+ optimizer = optim.AdamW(parameters, **opt_args)
128
+ elif opt_lower == 'nadam':
129
+ optimizer = Nadam(parameters, **opt_args)
130
+ elif opt_lower == 'radam':
131
+ optimizer = RAdam(parameters, **opt_args)
132
+ elif opt_lower == 'adamp':
133
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
134
+ elif opt_lower == 'sgdp':
135
+ optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136
+ elif opt_lower == 'adadelta':
137
+ optimizer = optim.Adadelta(parameters, **opt_args)
138
+ elif opt_lower == 'adafactor':
139
+ if not args.lr:
140
+ opt_args['lr'] = None
141
+ optimizer = Adafactor(parameters, **opt_args)
142
+ elif opt_lower == 'adahessian':
143
+ optimizer = Adahessian(parameters, **opt_args)
144
+ elif opt_lower == 'rmsprop':
145
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
146
+ elif opt_lower == 'rmsproptf':
147
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
148
+ elif opt_lower == 'novograd':
149
+ # NovoGrad was dropped from recent timm releases (its import above is commented out),
+ # so fall back to the equivalent NvNovoGrad implementation to avoid a NameError.
+ optimizer = NvNovoGrad(parameters, **opt_args)
150
+ elif opt_lower == 'nvnovograd':
151
+ optimizer = NvNovoGrad(parameters, **opt_args)
152
+ elif opt_lower == 'fusedsgd':
153
+ opt_args.pop('eps', None)
154
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
155
+ elif opt_lower == 'fusedmomentum':
156
+ opt_args.pop('eps', None)
157
+ optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
158
+ elif opt_lower == 'fusedadam':
159
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
160
+ elif opt_lower == 'fusedadamw':
161
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
162
+ elif opt_lower == 'fusedlamb':
163
+ optimizer = FusedLAMB(parameters, **opt_args)
164
+ elif opt_lower == 'fusednovograd':
165
+ opt_args.setdefault('betas', (0.95, 0.98))
166
+ optimizer = FusedNovoGrad(parameters, **opt_args)
167
+ else:
168
+ assert False and "Invalid optimizer"
169
+ raise ValueError
170
+
171
+ if len(opt_split) > 1:
172
+ if opt_split[0] == 'lookahead':
173
+ optimizer = Lookahead(optimizer)
174
+
175
+ return optimizer
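A hedged sketch of calling `create_optimizer` outside the training script; it only needs an `argparse`-style namespace with the fields it reads (the values below are illustrative):

```python
import argparse
import torch
from optim_factory import create_optimizer

args = argparse.Namespace(opt="adamw", lr=1e-3, weight_decay=0.05,
                          opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)
model = torch.nn.Linear(10, 2)           # stand-in for the video transformer
optimizer = create_optimizer(args, model)
print(type(optimizer).__name__)          # AdamW, with decay / no-decay parameter groups
```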
rand_augment.py ADDED
@@ -0,0 +1,531 @@
1
+ """
2
+ This implementation is based on
3
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4
+ published under an Apache License 2.0.
5
+
6
+ COMMENT FROM ORIGINAL:
7
+ AutoAugment, RandAugment, and AugMix for PyTorch
8
+ This code implements the searched ImageNet policies with various tweaks and
9
+ improvements and does not include any of the search code. AA and RA
10
+ Implementation adapted from:
11
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12
+ AugMix adapted from:
13
+ https://github.com/google-research/augmix
14
+ Papers:
15
+ AutoAugment: Learning Augmentation Policies from Data
16
+ https://arxiv.org/abs/1805.09501
17
+ Learning Data Augmentation Strategies for Object Detection
18
+ https://arxiv.org/abs/1906.11172
19
+ RandAugment: Practical automated data augmentation...
20
+ https://arxiv.org/abs/1909.13719
21
+ AugMix: A Simple Data Processing Method to Improve Robustness and
22
+ Uncertainty https://arxiv.org/abs/1912.02781
23
+
24
+ Hacked together by / Copyright 2020 Ross Wightman
25
+ """
26
+
27
+ import math
28
+ import numpy as np
29
+ import random
30
+ import re
31
+ import PIL
32
+ from PIL import Image, ImageEnhance, ImageOps
33
+
34
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35
+
36
+ _FILL = (128, 128, 128)
37
+
38
+ # This signifies the max integer that the controller RNN could predict for the
39
+ # augmentation scheme.
40
+ _MAX_LEVEL = 10.0
41
+
42
+ _HPARAMS_DEFAULT = {
43
+ "translate_const": 250,
44
+ "img_mean": _FILL,
45
+ }
46
+
47
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
48
+
49
+
50
+ def _interpolation(kwargs):
51
+ interpolation = kwargs.pop("resample", Image.BILINEAR)
52
+ if isinstance(interpolation, (list, tuple)):
53
+ return random.choice(interpolation)
54
+ else:
55
+ return interpolation
56
+
57
+
58
+ def _check_args_tf(kwargs):
59
+ if "fillcolor" in kwargs and _PIL_VER < (5, 0):
60
+ kwargs.pop("fillcolor")
61
+ kwargs["resample"] = _interpolation(kwargs)
62
+
63
+
64
+ def shear_x(img, factor, **kwargs):
65
+ _check_args_tf(kwargs)
66
+ return img.transform(
67
+ img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
68
+ )
69
+
70
+
71
+ def shear_y(img, factor, **kwargs):
72
+ _check_args_tf(kwargs)
73
+ return img.transform(
74
+ img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
75
+ )
76
+
77
+
78
+ def translate_x_rel(img, pct, **kwargs):
79
+ pixels = pct * img.size[0]
80
+ _check_args_tf(kwargs)
81
+ return img.transform(
82
+ img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
83
+ )
84
+
85
+
86
+ def translate_y_rel(img, pct, **kwargs):
87
+ pixels = pct * img.size[1]
88
+ _check_args_tf(kwargs)
89
+ return img.transform(
90
+ img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
91
+ )
92
+
93
+
94
+ def translate_x_abs(img, pixels, **kwargs):
95
+ _check_args_tf(kwargs)
96
+ return img.transform(
97
+ img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
98
+ )
99
+
100
+
101
+ def translate_y_abs(img, pixels, **kwargs):
102
+ _check_args_tf(kwargs)
103
+ return img.transform(
104
+ img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
105
+ )
106
+
107
+
108
+ def rotate(img, degrees, **kwargs):
109
+ _check_args_tf(kwargs)
110
+ if _PIL_VER >= (5, 2):
111
+ return img.rotate(degrees, **kwargs)
112
+ elif _PIL_VER >= (5, 0):
113
+ w, h = img.size
114
+ post_trans = (0, 0)
115
+ rotn_center = (w / 2.0, h / 2.0)
116
+ angle = -math.radians(degrees)
117
+ matrix = [
118
+ round(math.cos(angle), 15),
119
+ round(math.sin(angle), 15),
120
+ 0.0,
121
+ round(-math.sin(angle), 15),
122
+ round(math.cos(angle), 15),
123
+ 0.0,
124
+ ]
125
+
126
+ def transform(x, y, matrix):
127
+ (a, b, c, d, e, f) = matrix
128
+ return a * x + b * y + c, d * x + e * y + f
129
+
130
+ matrix[2], matrix[5] = transform(
131
+ -rotn_center[0] - post_trans[0],
132
+ -rotn_center[1] - post_trans[1],
133
+ matrix,
134
+ )
135
+ matrix[2] += rotn_center[0]
136
+ matrix[5] += rotn_center[1]
137
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
138
+ else:
139
+ return img.rotate(degrees, resample=kwargs["resample"])
140
+
141
+
142
+ def auto_contrast(img, **__):
143
+ return ImageOps.autocontrast(img)
144
+
145
+
146
+ def invert(img, **__):
147
+ return ImageOps.invert(img)
148
+
149
+
150
+ def equalize(img, **__):
151
+ return ImageOps.equalize(img)
152
+
153
+
154
+ def solarize(img, thresh, **__):
155
+ return ImageOps.solarize(img, thresh)
156
+
157
+
158
+ def solarize_add(img, add, thresh=128, **__):
159
+ lut = []
160
+ for i in range(256):
161
+ if i < thresh:
162
+ lut.append(min(255, i + add))
163
+ else:
164
+ lut.append(i)
165
+ if img.mode in ("L", "RGB"):
166
+ if img.mode == "RGB" and len(lut) == 256:
167
+ lut = lut + lut + lut
168
+ return img.point(lut)
169
+ else:
170
+ return img
171
+
172
+
173
+ def posterize(img, bits_to_keep, **__):
174
+ if bits_to_keep >= 8:
175
+ return img
176
+ return ImageOps.posterize(img, bits_to_keep)
177
+
178
+
179
+ def contrast(img, factor, **__):
180
+ return ImageEnhance.Contrast(img).enhance(factor)
181
+
182
+
183
+ def color(img, factor, **__):
184
+ return ImageEnhance.Color(img).enhance(factor)
185
+
186
+
187
+ def brightness(img, factor, **__):
188
+ return ImageEnhance.Brightness(img).enhance(factor)
189
+
190
+
191
+ def sharpness(img, factor, **__):
192
+ return ImageEnhance.Sharpness(img).enhance(factor)
193
+
194
+
195
+ def _randomly_negate(v):
196
+ """With 50% prob, negate the value"""
197
+ return -v if random.random() > 0.5 else v
198
+
199
+
200
+ def _rotate_level_to_arg(level, _hparams):
201
+ # range [-30, 30]
202
+ level = (level / _MAX_LEVEL) * 30.0
203
+ level = _randomly_negate(level)
204
+ return (level,)
205
+
206
+
207
+ def _enhance_level_to_arg(level, _hparams):
208
+ # range [0.1, 1.9]
209
+ return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
210
+
211
+
212
+ def _enhance_increasing_level_to_arg(level, _hparams):
213
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
214
+ # range [0.1, 1.9]
215
+ level = (level / _MAX_LEVEL) * 0.9
216
+ level = 1.0 + _randomly_negate(level)
217
+ return (level,)
218
+
219
+
220
+ def _shear_level_to_arg(level, _hparams):
221
+ # range [-0.3, 0.3]
222
+ level = (level / _MAX_LEVEL) * 0.3
223
+ level = _randomly_negate(level)
224
+ return (level,)
225
+
226
+
227
+ def _translate_abs_level_to_arg(level, hparams):
228
+ translate_const = hparams["translate_const"]
229
+ level = (level / _MAX_LEVEL) * float(translate_const)
230
+ level = _randomly_negate(level)
231
+ return (level,)
232
+
233
+
234
+ def _translate_rel_level_to_arg(level, hparams):
235
+ # default range [-0.45, 0.45]
236
+ translate_pct = hparams.get("translate_pct", 0.45)
237
+ level = (level / _MAX_LEVEL) * translate_pct
238
+ level = _randomly_negate(level)
239
+ return (level,)
240
+
241
+
242
+ def _posterize_level_to_arg(level, _hparams):
243
+ # As per Tensorflow TPU EfficientNet impl
244
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
245
+ # intensity/severity of augmentation decreases with level
246
+ return (int((level / _MAX_LEVEL) * 4),)
247
+
248
+
249
+ def _posterize_increasing_level_to_arg(level, hparams):
250
+ # As per Tensorflow models research and UDA impl
251
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
252
+ # intensity/severity of augmentation increases with level
253
+ return (4 - _posterize_level_to_arg(level, hparams)[0],)
254
+
255
+
256
+ def _posterize_original_level_to_arg(level, _hparams):
257
+ # As per original AutoAugment paper description
258
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
259
+ # intensity/severity of augmentation decreases with level
260
+ return (int((level / _MAX_LEVEL) * 4) + 4,)
261
+
262
+
263
+ def _solarize_level_to_arg(level, _hparams):
264
+ # range [0, 256]
265
+ # intensity/severity of augmentation decreases with level
266
+ return (int((level / _MAX_LEVEL) * 256),)
267
+
268
+
269
+ def _solarize_increasing_level_to_arg(level, _hparams):
270
+ # range [0, 256]
271
+ # intensity/severity of augmentation increases with level
272
+ return (256 - _solarize_level_to_arg(level, _hparams)[0],)
273
+
274
+
275
+ def _solarize_add_level_to_arg(level, _hparams):
276
+ # range [0, 110]
277
+ return (int((level / _MAX_LEVEL) * 110),)
278
+
279
+
280
+ LEVEL_TO_ARG = {
281
+ "AutoContrast": None,
282
+ "Equalize": None,
283
+ "Invert": None,
284
+ "Rotate": _rotate_level_to_arg,
285
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
286
+ "Posterize": _posterize_level_to_arg,
287
+ "PosterizeIncreasing": _posterize_increasing_level_to_arg,
288
+ "PosterizeOriginal": _posterize_original_level_to_arg,
289
+ "Solarize": _solarize_level_to_arg,
290
+ "SolarizeIncreasing": _solarize_increasing_level_to_arg,
291
+ "SolarizeAdd": _solarize_add_level_to_arg,
292
+ "Color": _enhance_level_to_arg,
293
+ "ColorIncreasing": _enhance_increasing_level_to_arg,
294
+ "Contrast": _enhance_level_to_arg,
295
+ "ContrastIncreasing": _enhance_increasing_level_to_arg,
296
+ "Brightness": _enhance_level_to_arg,
297
+ "BrightnessIncreasing": _enhance_increasing_level_to_arg,
298
+ "Sharpness": _enhance_level_to_arg,
299
+ "SharpnessIncreasing": _enhance_increasing_level_to_arg,
300
+ "ShearX": _shear_level_to_arg,
301
+ "ShearY": _shear_level_to_arg,
302
+ "TranslateX": _translate_abs_level_to_arg,
303
+ "TranslateY": _translate_abs_level_to_arg,
304
+ "TranslateXRel": _translate_rel_level_to_arg,
305
+ "TranslateYRel": _translate_rel_level_to_arg,
306
+ }
307
+
308
+
309
+ NAME_TO_OP = {
310
+ "AutoContrast": auto_contrast,
311
+ "Equalize": equalize,
312
+ "Invert": invert,
313
+ "Rotate": rotate,
314
+ "Posterize": posterize,
315
+ "PosterizeIncreasing": posterize,
316
+ "PosterizeOriginal": posterize,
317
+ "Solarize": solarize,
318
+ "SolarizeIncreasing": solarize,
319
+ "SolarizeAdd": solarize_add,
320
+ "Color": color,
321
+ "ColorIncreasing": color,
322
+ "Contrast": contrast,
323
+ "ContrastIncreasing": contrast,
324
+ "Brightness": brightness,
325
+ "BrightnessIncreasing": brightness,
326
+ "Sharpness": sharpness,
327
+ "SharpnessIncreasing": sharpness,
328
+ "ShearX": shear_x,
329
+ "ShearY": shear_y,
330
+ "TranslateX": translate_x_abs,
331
+ "TranslateY": translate_y_abs,
332
+ "TranslateXRel": translate_x_rel,
333
+ "TranslateYRel": translate_y_rel,
334
+ }
335
+
336
+
337
+ class AugmentOp:
338
+ """
339
+ Apply for video.
340
+ """
341
+
342
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
343
+ hparams = hparams or _HPARAMS_DEFAULT
344
+ self.aug_fn = NAME_TO_OP[name]
345
+ self.level_fn = LEVEL_TO_ARG[name]
346
+ self.prob = prob
347
+ self.magnitude = magnitude
348
+ self.hparams = hparams.copy()
349
+ self.kwargs = {
350
+ "fillcolor": hparams["img_mean"]
351
+ if "img_mean" in hparams
352
+ else _FILL,
353
+ "resample": hparams["interpolation"]
354
+ if "interpolation" in hparams
355
+ else _RANDOM_INTERPOLATION,
356
+ }
357
+
358
+ # If magnitude_std is > 0, we introduce some randomness
359
+ # in the usually fixed policy and sample magnitude from a normal distribution
360
+ # with mean `magnitude` and std-dev of `magnitude_std`.
361
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
362
+ self.magnitude_std = self.hparams.get("magnitude_std", 0)
363
+
364
+ def __call__(self, img_list):
365
+ if self.prob < 1.0 and random.random() > self.prob:
366
+ return img_list
367
+ magnitude = self.magnitude
368
+ if self.magnitude_std and self.magnitude_std > 0:
369
+ magnitude = random.gauss(magnitude, self.magnitude_std)
370
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
371
+ level_args = (
372
+ self.level_fn(magnitude, self.hparams)
373
+ if self.level_fn is not None
374
+ else ()
375
+ )
376
+
377
+ if isinstance(img_list, list):
378
+ return [
379
+ self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
380
+ ]
381
+ else:
382
+ return self.aug_fn(img_list, *level_args, **self.kwargs)
383
+
384
+
385
+ _RAND_TRANSFORMS = [
386
+ "AutoContrast",
387
+ "Equalize",
388
+ "Invert",
389
+ "Rotate",
390
+ "Posterize",
391
+ "Solarize",
392
+ "SolarizeAdd",
393
+ "Color",
394
+ "Contrast",
395
+ "Brightness",
396
+ "Sharpness",
397
+ "ShearX",
398
+ "ShearY",
399
+ "TranslateXRel",
400
+ "TranslateYRel",
401
+ ]
402
+
403
+
404
+ _RAND_INCREASING_TRANSFORMS = [
405
+ "AutoContrast",
406
+ "Equalize",
407
+ "Invert",
408
+ "Rotate",
409
+ "PosterizeIncreasing",
410
+ "SolarizeIncreasing",
411
+ "SolarizeAdd",
412
+ "ColorIncreasing",
413
+ "ContrastIncreasing",
414
+ "BrightnessIncreasing",
415
+ "SharpnessIncreasing",
416
+ "ShearX",
417
+ "ShearY",
418
+ "TranslateXRel",
419
+ "TranslateYRel",
420
+ ]
421
+
422
+
423
+ # These experimental weights are based loosely on the relative improvements mentioned in paper.
424
+ # They may not result in increased performance, but could likely be tuned to do so.
425
+ _RAND_CHOICE_WEIGHTS_0 = {
426
+ "Rotate": 0.3,
427
+ "ShearX": 0.2,
428
+ "ShearY": 0.2,
429
+ "TranslateXRel": 0.1,
430
+ "TranslateYRel": 0.1,
431
+ "Color": 0.025,
432
+ "Sharpness": 0.025,
433
+ "AutoContrast": 0.025,
434
+ "Solarize": 0.005,
435
+ "SolarizeAdd": 0.005,
436
+ "Contrast": 0.005,
437
+ "Brightness": 0.005,
438
+ "Equalize": 0.005,
439
+ "Posterize": 0,
440
+ "Invert": 0,
441
+ }
442
+
443
+
444
+ def _select_rand_weights(weight_idx=0, transforms=None):
445
+ transforms = transforms or _RAND_TRANSFORMS
446
+ assert weight_idx == 0 # only one set of weights currently
447
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
448
+ probs = [rand_weights[k] for k in transforms]
449
+ probs /= np.sum(probs)
450
+ return probs
451
+
452
+
453
+ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
454
+ hparams = hparams or _HPARAMS_DEFAULT
455
+ transforms = transforms or _RAND_TRANSFORMS
456
+ return [
457
+ AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
458
+ for name in transforms
459
+ ]
460
+
461
+
462
+ class RandAugment:
463
+ def __init__(self, ops, num_layers=2, choice_weights=None):
464
+ self.ops = ops
465
+ self.num_layers = num_layers
466
+ self.choice_weights = choice_weights
467
+
468
+ def __call__(self, img):
469
+ # no replacement when using weighted choice
470
+ ops = np.random.choice(
471
+ self.ops,
472
+ self.num_layers,
473
+ replace=self.choice_weights is None,
474
+ p=self.choice_weights,
475
+ )
476
+ for op in ops:
477
+ img = op(img)
478
+ return img
479
+
480
+
481
+ def rand_augment_transform(config_str, hparams):
482
+ """
483
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
484
+
485
+ Create a RandAugment transform
486
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
487
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
488
+ sections, which are not order specific, determine
489
+ 'm' - integer magnitude of rand augment
490
+ 'n' - integer num layers (number of transform ops selected per image)
491
+ 'w' - integer probability weight index (index of a set of weights to influence choice of op)
492
+ 'mstd' - float std deviation of magnitude noise applied
493
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
494
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
495
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
496
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
497
+ :return: A PyTorch compatible Transform
498
+ """
499
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
500
+ num_layers = 2 # default to 2 ops per image
501
+ weight_idx = None # default to no probability weights for op choice
502
+ transforms = _RAND_TRANSFORMS
503
+ config = config_str.split("-")
504
+ assert config[0] == "rand"
505
+ config = config[1:]
506
+ for c in config:
507
+ cs = re.split(r"(\d.*)", c)
508
+ if len(cs) < 2:
509
+ continue
510
+ key, val = cs[:2]
511
+ if key == "mstd":
512
+ # noise param injected via hparams for now
513
+ hparams.setdefault("magnitude_std", float(val))
514
+ elif key == "inc":
515
+ if bool(val):
516
+ transforms = _RAND_INCREASING_TRANSFORMS
517
+ elif key == "m":
518
+ magnitude = int(val)
519
+ elif key == "n":
520
+ num_layers = int(val)
521
+ elif key == "w":
522
+ weight_idx = int(val)
523
+ else:
524
+ raise NotImplementedError(f"Unknown RandAugment config section: {key}")
525
+ ra_ops = rand_augment_ops(
526
+ magnitude=magnitude, hparams=hparams, transforms=transforms
527
+ )
528
+ choice_weights = (
529
+ None if weight_idx is None else _select_rand_weights(weight_idx)
530
+ )
531
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
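A hedged usage sketch of the config-string interface above, using the policy string the fine-tuning script passes by default (`rand-m7-n4-mstd0.5-inc1`); when given a list of frames, each sampled op is applied identically to every frame:

```python
from PIL import Image
from rand_augment import rand_augment_transform

aug = rand_augment_transform("rand-m7-n4-mstd0.5-inc1", {"translate_const": 100})
frames = [Image.new("RGB", (224, 224), (128, 128, 128)) for _ in range(16)]  # dummy clip
augmented = aug(frames)
print(len(augmented), augmented[0].size)   # 16 (224, 224)
```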
random_erasing.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ This implementation is based on
3
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4
+ published under an Apache License 2.0.
5
+ """
6
+ import math
7
+ import random
8
+ import torch
9
+
10
+
11
+ def _get_pixels(
12
+ per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13
+ ):
14
+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15
+ # paths, flip the order so normal is run on CPU if this becomes a problem
16
+ # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17
+ if per_pixel:
18
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19
+ elif rand_color:
20
+ return torch.empty(
21
+ (patch_size[0], 1, 1), dtype=dtype, device=device
22
+ ).normal_()
23
+ else:
24
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25
+
26
+
27
+ class RandomErasing:
28
+ """Randomly selects a rectangle region in an image and erases its pixels.
29
+ 'Random Erasing Data Augmentation' by Zhong et al.
30
+ See https://arxiv.org/pdf/1708.04896.pdf
31
+ This variant of RandomErasing is intended to be applied to either a batch
32
+ or single image tensor after it has been normalized by dataset mean and std.
33
+ Args:
34
+ probability: Probability that the Random Erasing operation will be performed.
35
+ min_area: Minimum percentage of erased area wrt input image area.
36
+ max_area: Maximum percentage of erased area wrt input image area.
37
+ min_aspect: Minimum aspect ratio of erased area.
38
+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39
+ 'const' - erase block is constant color of 0 for all channels
40
+ 'rand' - erase block is same per-channel random (normal) color
41
+ 'pixel' - erase block is per-pixel random (normal) color
42
+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43
+ per-image count is randomly chosen between 1 and this value.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ probability=0.5,
49
+ min_area=0.02,
50
+ max_area=1 / 3,
51
+ min_aspect=0.3,
52
+ max_aspect=None,
53
+ mode="const",
54
+ min_count=1,
55
+ max_count=None,
56
+ num_splits=0,
57
+ device="cuda",
58
+ cube=True,
59
+ ):
60
+ self.probability = probability
61
+ self.min_area = min_area
62
+ self.max_area = max_area
63
+ max_aspect = max_aspect or 1 / min_aspect
64
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65
+ self.min_count = min_count
66
+ self.max_count = max_count or min_count
67
+ self.num_splits = num_splits
68
+ mode = mode.lower()
69
+ self.rand_color = False
70
+ self.per_pixel = False
71
+ self.cube = cube
72
+ if mode == "rand":
73
+ self.rand_color = True # per block random normal
74
+ elif mode == "pixel":
75
+ self.per_pixel = True # per pixel random normal
76
+ else:
77
+ assert not mode or mode == "const"
78
+ self.device = device
79
+
80
+ def _erase(self, img, chan, img_h, img_w, dtype):
81
+ if random.random() > self.probability:
82
+ return
83
+ area = img_h * img_w
84
+ count = (
85
+ self.min_count
86
+ if self.min_count == self.max_count
87
+ else random.randint(self.min_count, self.max_count)
88
+ )
89
+ for _ in range(count):
90
+ for _ in range(10):
91
+ target_area = (
92
+ random.uniform(self.min_area, self.max_area) * area / count
93
+ )
94
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
96
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
97
+ if w < img_w and h < img_h:
98
+ top = random.randint(0, img_h - h)
99
+ left = random.randint(0, img_w - w)
100
+ img[:, top : top + h, left : left + w] = _get_pixels(
101
+ self.per_pixel,
102
+ self.rand_color,
103
+ (chan, h, w),
104
+ dtype=dtype,
105
+ device=self.device,
106
+ )
107
+ break
108
+
109
+ def _erase_cube(
110
+ self,
111
+ img,
112
+ batch_start,
113
+ batch_size,
114
+ chan,
115
+ img_h,
116
+ img_w,
117
+ dtype,
118
+ ):
119
+ if random.random() > self.probability:
120
+ return
121
+ area = img_h * img_w
122
+ count = (
123
+ self.min_count
124
+ if self.min_count == self.max_count
125
+ else random.randint(self.min_count, self.max_count)
126
+ )
127
+ for _ in range(count):
128
+ for _ in range(100):
129
+ target_area = (
130
+ random.uniform(self.min_area, self.max_area) * area / count
131
+ )
132
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
134
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
135
+ if w < img_w and h < img_h:
136
+ top = random.randint(0, img_h - h)
137
+ left = random.randint(0, img_w - w)
138
+ for i in range(batch_start, batch_size):
139
+ img_instance = img[i]
140
+ img_instance[
141
+ :, top : top + h, left : left + w
142
+ ] = _get_pixels(
143
+ self.per_pixel,
144
+ self.rand_color,
145
+ (chan, h, w),
146
+ dtype=dtype,
147
+ device=self.device,
148
+ )
149
+ break
150
+
151
+ def __call__(self, input):
152
+ if len(input.size()) == 3:
153
+ self._erase(input, *input.size(), input.dtype)
154
+ else:
155
+ batch_size, chan, img_h, img_w = input.size()
156
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
157
+ batch_start = (
158
+ batch_size // self.num_splits if self.num_splits > 1 else 0
159
+ )
160
+ if self.cube:
161
+ self._erase_cube(
162
+ input,
163
+ batch_start,
164
+ batch_size,
165
+ chan,
166
+ img_h,
167
+ img_w,
168
+ input.dtype,
169
+ )
170
+ else:
171
+ for i in range(batch_start, batch_size):
172
+ self._erase(input[i], chan, img_h, img_w, input.dtype)
173
+ return input
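A hedged usage sketch: with `cube=True` (the default) the same erased region is applied to every frame of a normalized clip shaped `[T, C, H, W]`:

```python
import torch
from random_erasing import RandomErasing

erase = RandomErasing(probability=0.25, mode="pixel", max_count=1, device="cpu", cube=True)
clip = torch.randn(16, 3, 224, 224)   # [T, C, H, W], already mean/std normalized
clip = erase(clip)                    # same spatio-temporal cube erased in every frame (when triggered)
print(clip.shape)                     # torch.Size([16, 3, 224, 224])
```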
run_class_finetuning.py ADDED
@@ -0,0 +1,582 @@
1
+ import argparse
2
+ import datetime
3
+ import numpy as np
4
+ import time
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ import json
8
+ import os
9
+ from functools import partial
10
+ from pathlib import Path
11
+ from collections import OrderedDict
12
+
13
+ from mixup import Mixup
14
+ from timm.models import create_model
15
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16
+ from timm.utils import ModelEma
17
+ from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
18
+
19
+ from datasets import build_dataset
20
+ from engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge, merge_mean_per_class
21
+ from utils_mae import NativeScalerWithGradNormCount as NativeScaler
22
+ from utils_mae import multiple_samples_collate
23
+ import utils_mae as utils
24
+ import modeling_finetune
25
+
26
+
27
+ def get_args():
28
+ parser = argparse.ArgumentParser('VideoMAE fine-tuning and evaluation script for video classification', add_help=False)
29
+ parser.add_argument('--batch_size', default=64, type=int)
30
+ parser.add_argument('--epochs', default=30, type=int)
31
+ parser.add_argument('--update_freq', default=1, type=int)
32
+ parser.add_argument('--save_ckpt_freq', default=100, type=int)
33
+ parser.add_argument('--val_freq', default=1, type=int)
34
+
35
+ # Model parameters
36
+ parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
37
+ help='Name of model to train')
38
+ parser.add_argument('--tubelet_size', type=int, default= 2)
39
+ parser.add_argument('--input_size', default=224, type=int,
40
+ help='videos input size')
41
+
42
+ parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT',
43
+ help='Dropout rate (default: 0.)')
44
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
45
+ help='Dropout rate (default: 0.)')
46
+ parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
47
+ help='Attention dropout rate (default: 0.)')
48
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
49
+ help='Drop path rate (default: 0.1)')
50
+
51
+ parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
52
+ parser.add_argument('--model_ema', action='store_true', default=False)
53
+ parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
54
+ parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
55
+
56
+ # Optimizer parameters
57
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
58
+ help='Optimizer (default: "adamw")')
59
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
60
+ help='Optimizer Epsilon (default: 1e-8)')
61
+ parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
62
+ help='Optimizer Betas (default: None, use opt default)')
63
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
64
+ help='Clip gradient norm (default: None, no clipping)')
65
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
66
+ help='SGD momentum (default: 0.9)')
67
+ parser.add_argument('--weight_decay', type=float, default=0.05,
68
+ help='weight decay (default: 0.05)')
69
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
70
+ weight decay. We use a cosine schedule for WD and using a larger decay by
71
+ the end of training improves performance for ViTs.""")
72
+
73
+ parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
74
+ help='learning rate (default: 1e-3)')
75
+ parser.add_argument('--layer_decay', type=float, default=0.75)
76
+
77
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
78
+ help='warmup learning rate (default: 1e-6)')
79
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
80
+ help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
81
+
82
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
83
+ help='epochs to warmup LR, if scheduler supports')
84
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
85
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
86
+
87
+ # Augmentation parameters
88
+ parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
89
+ help='Color jitter factor (default: 0.4)')
90
+ parser.add_argument('--num_sample', type=int, default=2,
91
+ help='Number of repeated-augmentation samples per clip (default: 2)')
92
+ parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
93
+ help='Use AutoAugment policy, e.g. "v0" or "original" (default: rand-m7-n4-mstd0.5-inc1)')
94
+ parser.add_argument('--smoothing', type=float, default=0.1,
95
+ help='Label smoothing (default: 0.1)')
96
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
97
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
98
+
99
+ # Evaluation parameters
100
+ parser.add_argument('--crop_pct', type=float, default=None)
101
+ parser.add_argument('--short_side_size', type=int, default=224)
102
+ parser.add_argument('--test_num_segment', type=int, default=5)
103
+ parser.add_argument('--test_num_crop', type=int, default=3)
104
+
105
+ # Random Erase params
106
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
107
+ help='Random erase prob (default: 0.25)')
108
+ parser.add_argument('--remode', type=str, default='pixel',
109
+ help='Random erase mode (default: "pixel")')
110
+ parser.add_argument('--recount', type=int, default=1,
111
+ help='Random erase count (default: 1)')
112
+ parser.add_argument('--resplit', action='store_true', default=False,
113
+ help='Do not random erase first (clean) augmentation split')
114
+
115
+ # Mixup params
116
+ parser.add_argument('--mixup', type=float, default=0.8,
117
+ help='mixup alpha, mixup enabled if > 0.')
118
+ parser.add_argument('--cutmix', type=float, default=1.0,
119
+ help='cutmix alpha, cutmix enabled if > 0.')
120
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
121
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
122
+ parser.add_argument('--mixup_prob', type=float, default=1.0,
123
+ help='Probability of performing mixup or cutmix when either/both is enabled')
124
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
125
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
126
+ parser.add_argument('--mixup_mode', type=str, default='batch',
127
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
128
+
129
+ # Finetuning params
130
+ parser.add_argument('--finetune', default='', help='finetune from checkpoint')
131
+ parser.add_argument('--model_key', default='model|module', type=str)
132
+ parser.add_argument('--model_prefix', default='', type=str)
133
+ parser.add_argument('--init_scale', default=0.001, type=float)
134
+ parser.add_argument('--use_checkpoint', action='store_true')
135
+ parser.set_defaults(use_checkpoint=False)
136
+ parser.add_argument('--use_mean_pooling', action='store_true')
137
+ parser.set_defaults(use_mean_pooling=True)
138
+ parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
139
+
140
+ # Dataset parameters
141
+ parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
142
+ help='dataset path')
143
+ parser.add_argument('--eval_data_path', default=None, type=str,
144
+ help='dataset path for evaluation')
145
+ parser.add_argument('--nb_classes', default=400, type=int,
146
+ help='number of the classification types')
147
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
148
+ parser.add_argument('--num_segments', type=int, default= 1)
149
+ parser.add_argument('--num_frames', type=int, default= 16)
150
+ parser.add_argument('--sampling_rate', type=int, default= 4)
151
+ parser.add_argument('--data_set', default='Kinetics-400', choices=['Kinetics-400', 'SSV2', 'UCF101', 'HMDB51','image_folder','SSV2-Mini', 'Mini-Kinetics'],
152
+ type=str, help='dataset')
153
+ parser.add_argument('--output_dir', default='',
154
+ help='path where to save, empty for no saving')
155
+ parser.add_argument('--log_dir', default=None,
156
+ help='path where to tensorboard log')
157
+ parser.add_argument('--device', default='cuda',
158
+ help='device to use for training / testing')
159
+ parser.add_argument('--seed', default=0, type=int)
160
+ parser.add_argument('--resume', default='',
161
+ help='resume from checkpoint')
162
+ parser.add_argument('--auto_resume', action='store_true')
163
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
164
+ parser.set_defaults(auto_resume=True)
165
+
166
+ parser.add_argument('--save_ckpt', action='store_true')
167
+ parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
168
+ parser.set_defaults(save_ckpt=True)
169
+
170
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
171
+ help='start epoch')
172
+ parser.add_argument('--eval', action='store_true',
173
+ help='Perform evaluation only')
174
+ parser.add_argument('--dist_eval', action='store_true', default=False,
175
+ help='Enabling distributed evaluation')
176
+ parser.add_argument('--num_workers', default=10, type=int)
177
+ parser.add_argument('--pin_mem', action='store_true',
178
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
179
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
180
+ parser.set_defaults(pin_mem=True)
181
+
182
+ # distributed training parameters
183
+ parser.add_argument('--world_size', default=1, type=int,
184
+ help='number of distributed processes')
185
+ parser.add_argument('--local_rank', default=-1, type=int)
186
+ parser.add_argument('--dist_on_itp', action='store_true')
187
+ parser.add_argument('--dist_url', default='env://',
188
+ help='url used to set up distributed training')
189
+
190
+ parser.add_argument('--enable_deepspeed', action='store_true', default=False)
191
+
192
+ # debug mode
193
+ parser.add_argument('--not_dist', action='store_true', default=False)
194
+ parser.add_argument('--num_outputs', default=8, type=int)
195
+
196
+ known_args, _ = parser.parse_known_args()
197
+
198
+ if known_args.enable_deepspeed:
199
+ try:
200
+ import deepspeed
201
+ from deepspeed import DeepSpeedConfig
202
+ parser = deepspeed.add_config_arguments(parser)
203
+ ds_init = deepspeed.initialize
204
+ except ImportError:
205
+ print("Please 'pip install deepspeed'")
206
+ exit(0)
207
+ else:
208
+ ds_init = None
209
+
210
+ return parser.parse_args(), ds_init
211
+
212
+
213
+ def main(args, ds_init):
214
+ if args.not_dist:
215
+ args.distributed = False
216
+ else:
217
+ utils.init_distributed_mode(args)
218
+
219
+ if ds_init is not None:
220
+ utils.create_ds_config(args)
221
+
222
+ print(args)
223
+
224
+ device = torch.device(args.device)
225
+
226
+ # fix the seed for reproducibility
227
+ seed = args.seed + utils.get_rank()
228
+ torch.manual_seed(seed)
229
+ np.random.seed(seed)
230
+ # random.seed(seed)
231
+
232
+ cudnn.benchmark = True
233
+
234
+ dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
235
+ if args.disable_eval_during_finetuning:
236
+ dataset_val = None
237
+ else:
238
+ dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args)
239
+ dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args)
240
+
241
+
242
+ num_tasks = utils.get_world_size()
243
+ global_rank = utils.get_rank()
244
+ sampler_train = torch.utils.data.DistributedSampler(
245
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
246
+ )
247
+ print("Sampler_train = %s" % str(sampler_train))
248
+ if args.dist_eval:
249
+ if len(dataset_val) % num_tasks != 0:
250
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
251
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
252
+ 'equal num of samples per-process.')
253
+ sampler_val = torch.utils.data.DistributedSampler(
254
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
255
+ sampler_test = torch.utils.data.DistributedSampler(
256
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
257
+ else:
258
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
259
+
260
+ if global_rank == 0 and args.log_dir is not None:
261
+ os.makedirs(args.log_dir, exist_ok=True)
262
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
263
+ else:
264
+ log_writer = None
265
+
266
+ if args.num_sample > 1:
267
+ collate_func = partial(multiple_samples_collate, fold=False)
268
+ else:
269
+ collate_func = None
270
+
271
+ data_loader_train = torch.utils.data.DataLoader(
272
+ dataset_train, sampler=sampler_train,
273
+ batch_size=args.batch_size,
274
+ num_workers=args.num_workers,
275
+ pin_memory=args.pin_mem,
276
+ drop_last=True,
277
+ collate_fn=collate_func,
278
+ )
279
+
280
+ if dataset_val is not None:
281
+ data_loader_val = torch.utils.data.DataLoader(
282
+ dataset_val, sampler=sampler_val,
283
+ batch_size=int(1.5 * args.batch_size),
284
+ num_workers=args.num_workers,
285
+ pin_memory=args.pin_mem,
286
+ drop_last=False
287
+ )
288
+ else:
289
+ data_loader_val = None
290
+
291
+ if dataset_test is not None:
292
+ data_loader_test = torch.utils.data.DataLoader(
293
+ dataset_test, sampler=sampler_test,
294
+ batch_size=args.batch_size,
295
+ num_workers=args.num_workers,
296
+ pin_memory=args.pin_mem,
297
+ drop_last=False
298
+ )
299
+ else:
300
+ data_loader_test = None
301
+
302
+ mixup_fn = None
303
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
304
+ if mixup_active:
305
+ print("Mixup is activated!")
306
+ mixup_fn = Mixup(
307
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
308
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
309
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
310
+
311
+ model = create_model(
312
+ args.model,
313
+ pretrained=False,
314
+ num_classes=args.nb_classes,
315
+ all_frames=args.num_frames * args.num_segments,
316
+ tubelet_size=args.tubelet_size,
317
+ fc_drop_rate=args.fc_drop_rate,
318
+ drop_rate=args.drop,
319
+ drop_path_rate=args.drop_path,
320
+ attn_drop_rate=args.attn_drop_rate,
321
+ drop_block_rate=None,
322
+ use_checkpoint=args.use_checkpoint,
323
+ use_mean_pooling=args.use_mean_pooling,
324
+ init_scale=args.init_scale,
325
+ )
326
+
327
+ patch_size = model.patch_embed.patch_size
328
+ print("Patch size = %s" % str(patch_size))
329
+ args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
330
+ args.patch_size = patch_size
331
+
332
+ if args.finetune:
333
+ if args.finetune.startswith('https'):
334
+ checkpoint = torch.hub.load_state_dict_from_url(
335
+ args.finetune, map_location='cpu', check_hash=True)
336
+ else:
337
+ checkpoint = torch.load(args.finetune, map_location='cpu')
338
+
339
+ print("Load ckpt from %s" % args.finetune)
340
+ checkpoint_model = None
341
+ for model_key in args.model_key.split('|'):
342
+ if model_key in checkpoint:
343
+ checkpoint_model = checkpoint[model_key]
344
+ print("Load state_dict by model_key = %s" % model_key)
345
+ break
346
+ if checkpoint_model is None:
347
+ checkpoint_model = checkpoint
348
+ state_dict = model.state_dict()
349
+ for k in ['head.weight', 'head.bias']:
350
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
351
+ print(f"Removing key {k} from pretrained checkpoint")
352
+ del checkpoint_model[k]
353
+
354
+ all_keys = list(checkpoint_model.keys())
355
+ new_dict = OrderedDict()
356
+ for key in all_keys:
357
+ if key.startswith('backbone.'):
358
+ new_dict[key[9:]] = checkpoint_model[key]
359
+ elif key.startswith('encoder.'):
360
+ new_dict[key[8:]] = checkpoint_model[key]
361
+ else:
362
+ new_dict[key] = checkpoint_model[key]
363
+ checkpoint_model = new_dict
364
+
365
+ # interpolate position embedding
366
+ if 'pos_embed' in checkpoint_model:
367
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
368
+ embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
369
+ num_patches = model.patch_embed.num_patches #
370
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
371
+
372
+ # height (== width) for the checkpoint position embedding
373
+ orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(args.num_frames // model.patch_embed.tubelet_size)) ** 0.5)
374
+ # height (== width) for the new position embedding
375
+ new_size = int((num_patches // (args.num_frames // model.patch_embed.tubelet_size) )** 0.5)
376
+ # class_token and dist_token are kept unchanged
377
+ if orig_size != new_size:
378
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
379
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
380
+ # only the position tokens are interpolated
381
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
382
+ # B, L, C -> BT, H, W, C -> BT, C, H, W
383
+ pos_tokens = pos_tokens.reshape(-1, args.num_frames // model.patch_embed.tubelet_size, orig_size, orig_size, embedding_size)
384
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
385
+ pos_tokens = torch.nn.functional.interpolate(
386
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
387
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
388
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, args.num_frames // model.patch_embed.tubelet_size, new_size, new_size, embedding_size)
389
+ pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
390
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
391
+ checkpoint_model['pos_embed'] = new_pos_embed
392
+
393
+ utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
394
+
395
+ model.to(device)
396
+
397
+ model_ema = None
398
+ if args.model_ema:
399
+ model_ema = ModelEma(
400
+ model,
401
+ decay=args.model_ema_decay,
402
+ device='cpu' if args.model_ema_force_cpu else '',
403
+ resume='')
404
+ print("Using EMA with decay = %.8f" % args.model_ema_decay)
405
+
406
+ model_without_ddp = model
407
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
408
+
409
+ print("Model = %s" % str(model_without_ddp))
410
+ print('number of params:', n_parameters)
411
+
412
+ total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
413
+ num_training_steps_per_epoch = len(dataset_train) // total_batch_size
414
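+ # Linear LR scaling rule: base, min and warmup LRs are scaled by (effective batch size / 256).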
+ args.lr = args.lr * total_batch_size / 256
415
+ args.min_lr = args.min_lr * total_batch_size / 256
416
+ args.warmup_lr = args.warmup_lr * total_batch_size / 256
417
+ print("LR = %.8f" % args.lr)
418
+ print("Batch size = %d" % total_batch_size)
419
+ print("Update frequent = %d" % args.update_freq)
420
+ print("Number of training examples = %d" % len(dataset_train))
421
+ print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
422
+
423
+ num_layers = model_without_ddp.get_num_layers()
424
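+ # Layer-wise LR decay: block i gets an LR multiplier of layer_decay ** (num_layers + 1 - i), so earlier layers are updated more conservatively.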
+ if args.layer_decay < 1.0:
425
+ assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
426
+ else:
427
+ assigner = None
428
+
429
+ if assigner is not None:
430
+ print("Assigned values = %s" % str(assigner.values))
431
+
432
+ skip_weight_decay_list = model.no_weight_decay()
433
+ print("Skip weight decay list: ", skip_weight_decay_list)
434
+
435
+ if args.enable_deepspeed:
436
+ loss_scaler = None
437
+ optimizer_params = get_parameter_groups(
438
+ model, args.weight_decay, skip_weight_decay_list,
439
+ assigner.get_layer_id if assigner is not None else None,
440
+ assigner.get_scale if assigner is not None else None)
441
+ model, optimizer, _, _ = ds_init(
442
+ args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
443
+ )
444
+
445
+ print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
446
+ assert model.gradient_accumulation_steps() == args.update_freq
447
+ else:
448
+ if args.distributed:
449
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
450
+ model_without_ddp = model.module
451
+
452
+ optimizer = create_optimizer(
453
+ args, model_without_ddp, skip_list=skip_weight_decay_list,
454
+ get_num_layer=assigner.get_layer_id if assigner is not None else None,
455
+ get_layer_scale=assigner.get_scale if assigner is not None else None)
456
+ loss_scaler = NativeScaler()
457
+
458
+ print("Use step level LR scheduler!")
459
+ lr_schedule_values = utils.cosine_scheduler(
460
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
461
+ warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
462
+ )
463
+ if args.weight_decay_end is None:
464
+ args.weight_decay_end = args.weight_decay
465
+ wd_schedule_values = utils.cosine_scheduler(
466
+ args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
467
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
468
+
469
+ if mixup_fn is not None:
470
+ # smoothing is handled with mixup label transform
471
+ criterion = SoftTargetCrossEntropy()
472
+ elif args.smoothing > 0.:
473
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
474
+ else:
475
+ criterion = torch.nn.CrossEntropyLoss()
476
+
477
+ print("criterion = %s" % str(criterion))
478
+
479
+ utils.auto_load_model(
480
+ args=args, model=model, model_without_ddp=model_without_ddp,
481
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
482
+
483
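+ # Evaluation-only mode: every rank writes its predictions to "<rank>.txt", and rank 0 merges them into overall and per-class top-1/top-5 accuracy.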
+ if args.eval:
484
+ if not args.not_dist:
485
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
486
+ test_stats = final_test(data_loader_test, model, device, preds_file)
487
+ torch.distributed.barrier()
488
+ else:
489
+ num_tasks = args.num_outputs
490
+
491
+ if global_rank == 0:
492
+ print("Start merging results...")
493
+ final_top1, final_top5 = merge(args.output_dir, num_tasks)
494
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
495
+ log_stats = {'Final top-1': final_top1,
496
+ 'Final Top-5': final_top5}
497
+
498
+ final_top1_per_class, final_top5_per_class = merge_mean_per_class(args.output_dir, num_tasks, args.nb_classes)
499
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Mean-Top-1: {final_top1_per_class:.2f}%, Mean-Top-5: {final_top5_per_class:.2f}%")
500
+ log_stats["Class-Mean-Top-1"] = final_top1_per_class
501
+ log_stats["Class-Mean-Top-5"] = final_top5_per_class
502
+
503
+ if args.output_dir and utils.is_main_process():
504
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
505
+ f.write(json.dumps(log_stats) + "\n")
506
+ exit(0)
507
+
508
+
509
+ print(f"Start training for {args.epochs} epochs")
510
+ start_time = time.time()
511
+ max_accuracy = 0.0
512
+ for epoch in range(args.start_epoch, args.epochs):
513
+ if args.distributed:
514
+ data_loader_train.sampler.set_epoch(epoch)
515
+ if log_writer is not None:
516
+ log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
517
+ train_stats = train_one_epoch(
518
+ model, criterion, data_loader_train, optimizer,
519
+ device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
520
+ log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
521
+ lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
522
+ num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
523
+ )
524
+ if args.output_dir and args.save_ckpt:
525
+ if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
526
+ utils.save_model(
527
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
528
+ loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
529
+ if data_loader_val is not None and (epoch + 1) % args.val_freq == 0:
530
+ test_stats = validation_one_epoch(data_loader_val, model, device)
531
+ print(f"Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
532
+ if max_accuracy < test_stats["acc1"]:
533
+ max_accuracy = test_stats["acc1"]
534
+ if args.output_dir and args.save_ckpt:
535
+ utils.save_model(
536
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
537
+ loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
538
+
539
+ print(f'Max accuracy: {max_accuracy:.2f}%')
540
+ if log_writer is not None:
541
+ log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch)
542
+ log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch)
543
+ log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch)
544
+
545
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
546
+ **{f'val_{k}': v for k, v in test_stats.items()},
547
+ 'epoch': epoch,
548
+ 'n_parameters': n_parameters}
549
+ else:
550
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
551
+ 'epoch': epoch,
552
+ 'n_parameters': n_parameters}
553
+ if args.output_dir and utils.is_main_process():
554
+ if log_writer is not None:
555
+ log_writer.flush()
556
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
557
+ f.write(json.dumps(log_stats) + "\n")
558
+
559
+ preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
560
+ test_stats = final_test(data_loader_test, model, device, preds_file)
561
+ torch.distributed.barrier()
562
+ if global_rank == 0:
563
+ print("Start merging results...")
564
+ final_top1, final_top5 = merge(args.output_dir, num_tasks)
565
+ print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
566
+ log_stats = {'Final top-1': final_top1,
567
+ 'Final Top-5': final_top5}
568
+ if args.output_dir and utils.is_main_process():
569
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
570
+ f.write(json.dumps(log_stats) + "\n")
571
+
572
+
573
+ total_time = time.time() - start_time
574
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
575
+ print('Training time {}'.format(total_time_str))
576
+
577
+
578
+ if __name__ == '__main__':
579
+ opts, ds_init = get_args()
580
+ if opts.output_dir:
581
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
582
+ main(opts, ds_init)
run_mae_pretraining.py ADDED
@@ -0,0 +1,359 @@
1
+ import argparse
2
+ import datetime
3
+ import numpy as np
4
+ import time
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from timm.models import create_model
11
+ from optim_factory import create_optimizer
12
+ from datasets import build_pretraining_dataset
13
+ from engine_for_pretraining import train_one_epoch
14
+ from utils_mae import NativeScalerWithGradNormCount as NativeScaler
15
+ import utils_mae as utils
16
+ import modeling_pretrain
17
+ from timm.models.vision_transformer import vit_small_patch16_224, vit_base_patch16_224, vit_large_patch16_224
18
+ from modeling_pretrain import FeatureExtractor
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False)
23
+ parser.add_argument('--batch_size', default=64, type=int)
24
+ parser.add_argument('--epochs', default=800, type=int)
25
+ parser.add_argument('--save_ckpt_freq', default=50, type=int)
26
+
27
+ # Model parameters
28
+ parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
29
+ help='Name of model to train')
30
+
31
+ parser.add_argument('--decoder_depth', default=4, type=int,
32
+ help='depth of decoder')
33
+
34
+ parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'tubelet'],
35
+ type=str, help='masked strategy of video tokens/patches')
36
+
37
+ parser.add_argument('--sub_mask_type', default='tube+picked_frame_visible', choices=['tube', 'tube+picked_frame_visible', 'tube+traj_mask'],
38
+ type=str, help='sub masked strategy of tubelet masking')
39
+
40
+ parser.add_argument('--mask_ratio', default=0.75, type=float,
41
+ help='ratio of the visual tokens/patches need be masked')
42
+
43
+ parser.add_argument('--input_size', default=224, type=int,
44
+ help='videos input size for backbone')
45
+
46
+ parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
47
+ help='Drop path rate (default: 0.0)')
48
+
49
+ parser.add_argument('--normlize_target', default=True, type=bool,
50
+ help='normalize the target patch pixels')
51
+
52
+ # Optimizer parameters
53
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
54
+ help='Optimizer (default: "adamw")')
55
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
56
+ help='Optimizer Epsilon (default: 1e-8)')
57
+ parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
58
+ help='Optimizer Betas (default: None, use opt default)')
59
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
60
+ help='Clip gradient norm (default: None, no clipping)')
61
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
62
+ help='SGD momentum (default: 0.9)')
63
+ parser.add_argument('--weight_decay', type=float, default=0.05,
64
+ help='weight decay (default: 0.05)')
65
+ parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
66
+ weight decay. We use a cosine schedule for WD.
67
+ (Set the same value with args.weight_decay to keep weight decay no change)""")
68
+
69
+ parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
70
+ help='learning rate (default: 1.5e-4)')
71
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
72
+ help='warmup learning rate (default: 1e-6)')
73
+ parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
74
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
75
+
76
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
77
+ help='epochs to warmup LR, if scheduler supports')
78
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
79
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
80
+ parser.add_argument('--use_checkpoint', action='store_true')
81
+ parser.set_defaults(use_checkpoint=False)
82
+
83
+ # Augmentation parameters
84
+ parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
85
+ help='Color jitter factor (default: 0.0)')
86
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
87
+ help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
88
+
89
+ # Dataset parameters
90
+ parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
91
+ help='dataset path')
92
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
93
+ parser.add_argument('--num_frames', type=int, default= 16)
94
+ parser.add_argument('--sampling_rate', type=int, default= 4)
95
+ parser.add_argument('--output_dir', default='',
96
+ help='path where to save, empty for no saving')
97
+ parser.add_argument('--log_dir', default=None,
98
+ help='path where to tensorboard log')
99
+ parser.add_argument('--device', default='cuda',
100
+ help='device to use for training / testing')
101
+ parser.add_argument('--seed', default=0, type=int)
102
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
103
+ parser.add_argument('--auto_resume', action='store_true')
104
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
105
+ parser.set_defaults(auto_resume=True)
106
+
107
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
108
+ help='start epoch')
109
+ parser.add_argument('--num_workers', default=10, type=int)
110
+ parser.add_argument('--pin_mem', action='store_true',
111
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
112
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
113
+ help='')
114
+ parser.set_defaults(pin_mem=True)
115
+
116
+ # distributed training parameters
117
+ parser.add_argument('--world_size', default=1, type=int,
118
+ help='number of distributed processes')
119
+ parser.add_argument('--local-rank', default=-1, type=int)
120
+ parser.add_argument('--dist_on_itp', action='store_true')
121
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
122
+
123
+ # Tubelet params
124
+ parser.add_argument('--add_tubelets', action='store_true')
125
+ parser.set_defaults(add_tubelets=False)
126
+ parser.add_argument('--use_objects', action='store_true')
127
+ parser.set_defaults(use_objects=False)
128
+ parser.add_argument('--objects_path', type=str, default=None)
129
+ parser.add_argument('--motion_type', type=str, default='gaussian')
130
+ parser.add_argument('--scales', type=str, default='[32, 48, 56, 64, 96, 128]')
131
+ parser.add_argument('--visible_frames', type=str, default=None) # not used
132
+ parser.add_argument('--traj_unmask_ratio', type=float, default=0.1)
133
+
134
+ #dino params
135
+ parser.add_argument('--target_type', default='pixel', choices=['pixel', 'dino_v1', 'clip'], type=str, help='define target type for loss')
136
+ parser.add_argument('--distillation_teacher', default="clip_b", type=str, choices=['dino_s', 'dino_b', 'clip_b'], help='distillation teacher model')
137
+
138
+ # multiple sampling
139
+ parser.add_argument('--multiple_sampling', action='store_true')
140
+ # for 2nd stage training
141
+ parser.add_argument('--first_stage_path', type=str, default=None)
142
+
143
+ return parser.parse_args()
144
+
145
+
146
+
147
+ def get_teacher_student_models(args):
148
+ print(f"Creating model: {args.model}")
149
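+ # Decoder output dim: 1536 (a flattened 2x16x16 RGB tubelet) for pixel reconstruction, otherwise the teacher's feature dim (384 for DINO ViT-S, 768 for DINO/CLIP ViT-B).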
+ if args.target_type=='pixel':
150
+ dec_dim = 1536
151
+ elif 'dino' in args.target_type or 'clip' in args.target_type:
152
+ if args.distillation_teacher == 'dino_s':
153
+ dec_dim = 384
154
+ elif args.distillation_teacher == 'dino_b' or args.distillation_teacher == 'clip_b':
155
+
156
+ dec_dim = 768
157
+
158
+ student_model = create_model(
159
+ args.model,
160
+ pretrained=False,
161
+ drop_path_rate=args.drop_path,
162
+ drop_block_rate=None,
163
+ decoder_depth=args.decoder_depth,
164
+ use_checkpoint=args.use_checkpoint,
165
+ decoder_num_classes=dec_dim,
166
+ )
167
+
168
+ if args.target_type == 'dino_v1':
169
+
170
+ # load dino
171
+ if args.distillation_teacher == 'dino_s':
172
+ pretraining = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
173
+ teacher_model = vit_small_patch16_224(pretrained=False)
174
+ elif args.distillation_teacher == 'dino_b':
175
+ pretraining = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
176
+ teacher_model = vit_base_patch16_224(pretrained=False)
177
+
178
+ msg = teacher_model.load_state_dict(pretraining.state_dict(), strict=False)
179
+ teacher_model = FeatureExtractor(teacher_model, args.input_size, 16)
180
+ print(msg)
181
+ teacher_model.eval()
182
+
183
+ elif args.target_type == 'clip':
184
+
185
+ # load clip
186
+ from utils_viclip.config import Config
187
+ from utils_viclip.config_utils import setup_viclip
188
+ from tasks.shared_utils import setup_model
189
+ from models_viclip.viclip import ViCLIP
190
+
191
+ config = setup_viclip('configs/config.py')
192
+ model_cls = eval(config.model.get('model_cls', 'ViCLIP'))
193
+ teacher_model = setup_model(
194
+ config,
195
+ model_cls=model_cls,
196
+ has_decoder=False,
197
+ pretrain=False,
198
+ find_unused_parameters=False,
199
+ )
200
+ teacher_model.eval()
201
+ else:
202
+ teacher_model = None
203
+
204
+
205
+ return student_model, teacher_model
206
+
207
+ def load_first_stage(model,args):
208
+ if args.first_stage_path is not None:
209
+ checkpoint = torch.load(args.first_stage_path, map_location='cpu')
210
+ print("loading first stage from ",args.first_stage_path)
211
+ checkpoint_model = checkpoint['model']
212
+ utils.load_state_dict(model, checkpoint_model)
213
+
214
+
215
+ def main(args):
216
+ utils.init_distributed_mode(args)
217
+
218
+ print(args)
219
+
220
+ device = torch.device(args.device)
221
+
222
+ # fix the seed for reproducibility
223
+ seed = args.seed + utils.get_rank()
224
+ torch.manual_seed(seed)
225
+ np.random.seed(seed)
226
+
227
+ cudnn.benchmark = True
228
+
229
+ student_model, teacher_model = get_teacher_student_models(args)
230
+
231
+ patch_size = student_model.encoder.patch_embed.patch_size
232
+ print("Patch size = %s" % str(patch_size))
233
+ args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) # [8, 14, 14]
234
+ print(f"Window Size = {args.window_size}")
235
+ args.patch_size = patch_size
236
+
237
+
238
+ # Start from pretrained first stage model
239
+ if args.first_stage_path is not None:
240
+ load_first_stage(student_model,args)
241
+
242
+ # get dataset
243
+ dataset_train = build_pretraining_dataset(args)
244
+
245
+
246
+ num_tasks = utils.get_world_size()
247
+ global_rank = utils.get_rank()
248
+ sampler_rank = global_rank
249
+
250
+ total_batch_size = args.batch_size * num_tasks
251
+ num_training_steps_per_epoch = len(dataset_train) // total_batch_size
252
+
253
+ sampler_train = torch.utils.data.DistributedSampler(
254
+ dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
255
+ )
256
+ print("Sampler_train = %s" % str(sampler_train))
257
+
258
+
259
+ if global_rank == 0 and args.log_dir is not None:
260
+ os.makedirs(args.log_dir, exist_ok=True)
261
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
262
+ else:
263
+ log_writer = None
264
+
265
+ data_loader_train = torch.utils.data.DataLoader(
266
+ dataset_train, sampler=sampler_train,
267
+ batch_size=args.batch_size if not args.multiple_sampling else int(args.batch_size/2),
268
+ num_workers=args.num_workers,
269
+ pin_memory=args.pin_mem,
270
+ drop_last=True,
271
+ worker_init_fn=utils.seed_worker
272
+ )
273
+
274
+ student_model.to(device)
275
+ if teacher_model is not None:
276
+ teacher_model.to(device)
277
+ model_without_ddp = student_model
278
+ n_parameters = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
279
+
280
+ print("Model = %s" % str(model_without_ddp))
281
+ print('number of params: {} M'.format(n_parameters / 1e6))
282
+
283
+ args.lr = args.lr * total_batch_size / 256
284
+ args.min_lr = args.min_lr * total_batch_size / 256
285
+ args.warmup_lr = args.warmup_lr * total_batch_size / 256
286
+ print("LR = %.8f" % args.lr)
287
+ print("Batch size = %d" % total_batch_size)
288
+ print("Number of training steps = %d" % num_training_steps_per_epoch)
289
+ print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
290
+
291
+ if args.distributed:
292
+ student_model = torch.nn.parallel.DistributedDataParallel(student_model, device_ids=[args.gpu], find_unused_parameters=False)
293
+ model_without_ddp = student_model.module
294
+
295
+ optimizer = create_optimizer(
296
+ args, model_without_ddp)
297
+ loss_scaler = NativeScaler()
298
+
299
+ print("Use step level LR & WD scheduler!")
300
+ lr_schedule_values = utils.cosine_scheduler(
301
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
302
+ warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
303
+ )
304
+ if args.weight_decay_end is None:
305
+ args.weight_decay_end = args.weight_decay
306
+ wd_schedule_values = utils.cosine_scheduler(
307
+ args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
308
+ print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
309
+
310
+ utils.auto_load_model(
311
+ args=args, model=student_model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
312
+ torch.cuda.empty_cache()
313
+ print(f"Start training for {args.epochs} epochs")
314
+ start_time = time.time()
315
+ for epoch in range(args.start_epoch, args.epochs):
316
+ if args.distributed:
317
+ data_loader_train.sampler.set_epoch(epoch)
318
+ if log_writer is not None:
319
+ log_writer.set_step(epoch * num_training_steps_per_epoch)
320
+ train_stats = train_one_epoch(
321
+ student_model, data_loader_train,
322
+ optimizer, device, epoch, loss_scaler,
323
+ args.clip_grad, log_writer=log_writer,
324
+ start_steps=epoch * num_training_steps_per_epoch,
325
+ lr_schedule_values=lr_schedule_values,
326
+ wd_schedule_values=wd_schedule_values,
327
+ patch_size=patch_size[0],
328
+ normlize_target=args.normlize_target,
329
+ teacher_model = teacher_model,
330
+ target_type=args.target_type,
331
+ multiple_sampling=args.multiple_sampling,
332
+ )
333
+ if args.output_dir:
334
+ if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
335
+ utils.save_model(
336
+ args=args, model=student_model, model_without_ddp=model_without_ddp, optimizer=optimizer,
337
+ loss_scaler=loss_scaler, epoch=epoch)
338
+
339
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
340
+ 'epoch': epoch, 'n_parameters': n_parameters}
341
+
342
+ if args.output_dir and utils.is_main_process():
343
+ if log_writer is not None:
344
+ log_writer.flush()
345
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
346
+ f.write(json.dumps(log_stats) + "\n")
347
+ #if (epoch + 1) % 2 == 0:
348
+ #exit(0)
349
+
350
+ total_time = time.time() - start_time
351
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
352
+ print('Training time {}'.format(total_time_str))
353
+
354
+
355
+ if __name__ == '__main__':
356
+ opts = get_args()
357
+ if opts.output_dir:
358
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
359
+ main(opts)
run_videomae_vis.py ADDED
@@ -0,0 +1,198 @@
1
+ # -*- coding: utf-8 -*-
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ from PIL import Image
7
+ from pathlib import Path
8
+ from timm.models import create_model
9
+ import utils
10
+ import modeling_pretrain
11
+ from datasets import DataAugmentationForVideoMAE
12
+ from torchvision.transforms import ToPILImage
13
+ from einops import rearrange
14
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15
+ from decord import VideoReader, cpu
16
+ from torchvision import transforms
17
+ from transforms import *
18
+ # from datasets import DataAugmentationForVideoMAE
19
+ from masking_generator import TubeMaskingGenerator
20
+ class DataAugmentationForVideoMAE(object):
21
+ def __init__(self, args):
22
+ self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
23
+ self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
24
+ normalize = GroupNormalize(self.input_mean, self.input_std)
25
+ self.train_augmentation = GroupCenterCrop(args.input_size)
26
+ self.transform = transforms.Compose([
27
+ self.train_augmentation,
28
+ Stack(roll=False),
29
+ ToTorchFormatTensor(div=True),
30
+ normalize,
31
+ ])
32
+ if args.mask_type == 'tube':
33
+ self.masked_position_generator = TubeMaskingGenerator(
34
+ args.window_size, args.mask_ratio
35
+ )
36
+
37
+ def __call__(self, images):
38
+ process_data , _ = self.transform(images)
39
+ return process_data, self.masked_position_generator()
40
+
41
+ def __repr__(self):
42
+ repr = "(DataAugmentationForVideoMAE,\n"
43
+ repr += " transform = %s,\n" % str(self.transform)
44
+ repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
45
+ repr += ")"
46
+ return repr
47
+
48
+ def get_args():
49
+ parser = argparse.ArgumentParser('VideoMAE visualization reconstruction script', add_help=False)
50
+ parser.add_argument('img_path', type=str, help='input video path')
51
+ parser.add_argument('save_path', type=str, help='save video path')
52
+ parser.add_argument('model_path', type=str, help='checkpoint path of model')
53
+ parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'tubelet'],
54
+ type=str, help='masked strategy of video tokens/patches')
55
+ parser.add_argument('--num_frames', type=int, default= 16)
56
+ parser.add_argument('--sampling_rate', type=int, default= 4)
57
+ parser.add_argument('--decoder_depth', default=4, type=int,
58
+ help='depth of decoder')
59
+ parser.add_argument('--input_size', default=224, type=int,
60
+ help='videos input size for backbone')
61
+ parser.add_argument('--device', default='cuda:0',
62
+ help='device to use for training / testing')
63
+ parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
64
+ parser.add_argument('--mask_ratio', default=0.75, type=float,
65
+ help='ratio of the visual tokens/patches need be masked')
66
+ # Model parameters
67
+ parser.add_argument('--model', default='pretrain_videomae_small_patch16_224', type=str, metavar='MODEL',
68
+ help='Name of model to vis')
69
+ parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
70
+ help='Drop path rate (default: 0.1)')
71
+
72
+ # Tubelet params
73
+ parser.add_argument('--add_tubelets', action='store_true')
74
+ parser.set_defaults(add_tubelets=True)
75
+ parser.add_argument('--use_objects', action='store_true')
76
+ parser.set_defaults(use_objects=True)
77
+ parser.add_argument('--motion_type', type=str, default='gaussian')
78
+ parser.add_argument('--scales', type=str, default='[32, 48, 56, 64, 96, 128]')
79
+ parser.add_argument('--loc_velocity', type=int, default=12)
80
+ parser.add_argument('--mixed_tubelet', action='store_true')
81
+ parser.set_defaults(mixed_tubelet=False)
82
+ parser.add_argument('--visible_frames', type=str, default=None)
83
+
84
+
85
+ return parser.parse_args()
86
+
87
+
88
+ def get_model(args):
89
+ print(f"Creating model: {args.model}")
90
+ model = create_model(
91
+ args.model,
92
+ pretrained=False,
93
+ drop_path_rate=args.drop_path,
94
+ drop_block_rate=None,
95
+ decoder_depth=args.decoder_depth
96
+ )
97
+
98
+ return model
99
+
100
+
101
+ def main(args):
102
+ print(args)
103
+
104
+ device = torch.device(args.device)
105
+ cudnn.benchmark = True
106
+
107
+ model = get_model(args)
108
+ patch_size = model.encoder.patch_embed.patch_size
109
+ print("Patch size = %s" % str(patch_size))
110
+ args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
111
+ args.patch_size = patch_size
112
+
113
+ model.to(device)
114
+ checkpoint = torch.load(args.model_path, map_location='cpu')
115
+ model.load_state_dict(checkpoint['model'])
116
+ model.eval()
117
+
118
+ if args.save_path:
119
+ Path(args.save_path).mkdir(parents=True, exist_ok=True)
120
+
121
+ with open(args.img_path, 'rb') as f:
122
+ vr = VideoReader(f, ctx=cpu(0))
123
+ duration = len(vr)
124
+ new_length = 1
125
+ new_step = 1
126
+ skip_length = new_length * new_step
127
+ # frame_id_list = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61]
128
+
129
+
130
+ tmp = np.arange(0,32, 2) + 60
131
+ frame_id_list = tmp.tolist()
132
+ # average_duration = (duration - skip_length + 1) // args.num_frames
133
+ # if average_duration > 0:
134
+ # frame_id_list = np.multiply(list(range(args.num_frames)),
135
+ # average_duration)
136
+ # frame_id_list = frame_id_list + np.random.randint(average_duration,
137
+ # size=args.num_frames)
138
+
139
+ video_data = vr.get_batch(frame_id_list).asnumpy()
140
+ print(video_data.shape)
141
+ img = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
142
+
143
+ transforms = DataAugmentationForVideoMAE(args)
144
+ img, bool_masked_pos = transforms((img, None)) # T*C,H,W
145
+ # print(img.shape)
146
+ img = img.view((args.num_frames , 3) + img.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
147
+ # img = img.view(( -1 , args.num_frames) + img.size()[-2:])
148
+ bool_masked_pos = torch.from_numpy(bool_masked_pos)
149
+
150
+ with torch.no_grad():
151
+ # img = img[None, :]
152
+ # bool_masked_pos = bool_masked_pos[None, :]
153
+ img = img.unsqueeze(0)
154
+ print(img.shape)
155
+ bool_masked_pos = bool_masked_pos.unsqueeze(0)
156
+
157
+ img = img.to(device, non_blocking=True)
158
+ bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
159
+ outputs = model(img, bool_masked_pos)
160
+
161
+ #save original video
162
+ mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
163
+ std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
164
+ ori_img = img * std + mean # in [0, 1]
165
+ imgs = [ToPILImage()(ori_img[0,:,vid,:,:].cpu()) for vid, _ in enumerate(frame_id_list) ]
166
+ for id, im in enumerate(imgs):
167
+ im.save(f"{args.save_path}/ori_img{id}.jpg")
168
+
169
+ img_squeeze = rearrange(ori_img, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size[0], p2=patch_size[0])
170
+ img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
171
+ img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
172
+ img_patch[bool_masked_pos] = outputs
173
+
174
+ #make mask
175
+ mask = torch.ones_like(img_patch)
176
+ mask[bool_masked_pos] = 0
177
+ mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
178
+ mask = rearrange(mask, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2) ', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
179
+
180
+ #save reconstruction video
181
+ rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
182
+ # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch.
183
+ rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean(dim=-2, keepdim=True)
184
+ rec_img = rearrange(rec_img, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
185
+ imgs = [ ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0,0.996)) for vid, _ in enumerate(frame_id_list) ]
186
+
187
+ for id, im in enumerate(imgs):
188
+ im.save(f"{args.save_path}/rec_img{id}.jpg")
189
+
190
+ #save masked video
191
+ img_mask = rec_img * mask
192
+ imgs = [ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)]
193
+ for id, im in enumerate(imgs):
194
+ im.save(f"{args.save_path}/mask_img{id}.jpg")
195
+
196
+ if __name__ == '__main__':
197
+ opts = get_args()
198
+ main(opts)
ssv2.py ADDED
@@ -0,0 +1,363 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torchvision import transforms
5
+ from random_erasing import RandomErasing
6
+ import warnings
7
+ from decord import VideoReader, cpu
8
+ from torch.utils.data import Dataset
9
+ import video_transforms as video_transforms
10
+ import volume_transforms as volume_transforms
11
+
12
+
13
+ class SSVideoClsDataset(Dataset):
14
+ """Load your own video classification dataset."""
15
+
16
+ def __init__(self, anno_path, data_path, mode='train', clip_len=8,
17
+ crop_size=224, short_side_size=256, new_height=256,
18
+ new_width=340, keep_aspect_ratio=True, num_segment=1,
19
+ num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
20
+ self.anno_path = anno_path
21
+ self.data_path = data_path
22
+ self.mode = mode
23
+ self.clip_len = clip_len
24
+ self.crop_size = crop_size
25
+ self.short_side_size = short_side_size
26
+ self.new_height = new_height
27
+ self.new_width = new_width
28
+ self.keep_aspect_ratio = keep_aspect_ratio
29
+ self.num_segment = num_segment
30
+ self.test_num_segment = test_num_segment
31
+ self.num_crop = num_crop
32
+ self.test_num_crop = test_num_crop
33
+ self.args = args
34
+ self.aug = False
35
+ self.rand_erase = False
36
+ if self.mode in ['train']:
37
+ self.aug = True
38
+ if self.args.reprob > 0:
39
+ self.rand_erase = True
40
+ if VideoReader is None:
41
+ raise ImportError("Unable to import `decord` which is required to read videos.")
42
+
43
+ import pandas as pd
44
+ cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
45
+ self.dataset_samples = list(cleaned.values[:, 0])
46
+ self.label_array = list(cleaned.values[:, 1])
47
+
48
+ if (mode == 'train'):
49
+ pass
50
+
51
+ elif (mode == 'validation'):
52
+ self.data_transform = video_transforms.Compose([
53
+ video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
54
+ video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
55
+ volume_transforms.ClipToTensor(),
56
+ video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225])
58
+ ])
59
+ elif mode == 'test':
60
+ self.data_resize = video_transforms.Compose([
61
+ video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
62
+ ])
63
+ self.data_transform = video_transforms.Compose([
64
+ volume_transforms.ClipToTensor(),
65
+ video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
66
+ std=[0.229, 0.224, 0.225])
67
+ ])
68
+ self.test_seg = []
69
+ self.test_dataset = []
70
+ self.test_label_array = []
71
+ for ck in range(self.test_num_segment):
72
+ for cp in range(self.test_num_crop):
73
+ for idx in range(len(self.label_array)):
74
+ sample_label = self.label_array[idx]
75
+ self.test_label_array.append(sample_label)
76
+ self.test_dataset.append(self.dataset_samples[idx])
77
+ self.test_seg.append((ck, cp))
78
+
79
+ def __getitem__(self, index):
80
+ if self.mode == 'train':
81
+ args = self.args
82
+ scale_t = 1
83
+
84
+ sample = self.dataset_samples[index]
85
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
86
+ if len(buffer) == 0:
87
+ while len(buffer) == 0:
88
+ warnings.warn("video {} not correctly loaded during training".format(sample))
89
+ index = np.random.randint(self.__len__())
90
+ sample = self.dataset_samples[index]
91
+ buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
92
+
93
+ if args.num_sample > 1:
94
+ frame_list = []
95
+ label_list = []
96
+ index_list = []
97
+ for _ in range(args.num_sample):
98
+ new_frames = self._aug_frame(buffer, args)
99
+ label = self.label_array[index]
100
+ frame_list.append(new_frames)
101
+ label_list.append(label)
102
+ index_list.append(index)
103
+ return frame_list, label_list, index_list, {}
104
+ else:
105
+ buffer = self._aug_frame(buffer, args)
106
+
107
+ return buffer, self.label_array[index], index, {}
108
+
109
+ elif self.mode == 'validation':
110
+ sample = self.dataset_samples[index]
111
+ buffer = self.loadvideo_decord(sample)
112
+ if len(buffer) == 0:
113
+ while len(buffer) == 0:
114
+ warnings.warn("video {} not correctly loaded during validation".format(sample))
115
+ index = np.random.randint(self.__len__())
116
+ sample = self.dataset_samples[index]
117
+ buffer = self.loadvideo_decord(sample)
118
+ buffer = self.data_transform(buffer)
119
+ return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
120
+
121
+ elif self.mode == 'test':
122
+ sample = self.test_dataset[index]
123
+ chunk_nb, split_nb = self.test_seg[index]
124
+ buffer = self.loadvideo_decord(sample)
125
+
126
+ while len(buffer) == 0:
127
+ warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
128
+ str(self.test_dataset[index]), chunk_nb, split_nb))
129
+ index = np.random.randint(self.__len__())
130
+ sample = self.test_dataset[index]
131
+ chunk_nb, split_nb = self.test_seg[index]
132
+ buffer = self.loadvideo_decord(sample)
133
+
134
+ buffer = self.data_resize(buffer)
135
+ if isinstance(buffer, list):
136
+ buffer = np.stack(buffer, 0)
137
+
138
+ spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
139
+ / (self.test_num_crop - 1)
140
+ temporal_start = chunk_nb # 0/1
141
+ spatial_start = int(split_nb * spatial_step)
142
+ if buffer.shape[1] >= buffer.shape[2]:
143
+ buffer = buffer[temporal_start::2, \
144
+ spatial_start:spatial_start + self.short_side_size, :, :]
145
+ else:
146
+ buffer = buffer[temporal_start::2, \
147
+ :, spatial_start:spatial_start + self.short_side_size, :]
148
+
149
+ buffer = self.data_transform(buffer)
150
+ return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
151
+ chunk_nb, split_nb
152
+ else:
153
+ raise NameError('mode {} unkown'.format(self.mode))
154
+
155
+ def _aug_frame(
156
+ self,
157
+ buffer,
158
+ args,
159
+ ):
160
+
161
+ aug_transform = video_transforms.create_random_augment(
162
+ input_size=(self.crop_size, self.crop_size),
163
+ auto_augment=args.aa,
164
+ interpolation=args.train_interpolation,
165
+ )
166
+
167
+ buffer = [
168
+ transforms.ToPILImage()(frame) for frame in buffer
169
+ ]
170
+
171
+ buffer = aug_transform(buffer)
172
+
173
+ buffer = [transforms.ToTensor()(img) for img in buffer]
174
+ buffer = torch.stack(buffer) # T C H W
175
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
176
+
177
+ # T H W C
178
+ buffer = tensor_normalize(
179
+ buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
180
+ )
181
+ # T H W C -> C T H W.
182
+ buffer = buffer.permute(3, 0, 1, 2)
183
+ # Perform data augmentation.
184
+ scl, asp = (
185
+ [0.25, 1.0],
186
+ [0.75, 1.3333],
187
+ )
188
+
189
+ buffer = spatial_sampling(
190
+ buffer,
191
+ spatial_idx=-1,
192
+ min_scale=256,
193
+ max_scale=320,
194
+ crop_size=self.crop_size,
195
+ random_horizontal_flip=False if args.data_set == 'SSV2' else True,
196
+ inverse_uniform_sampling=False,
197
+ aspect_ratio=asp,
198
+ scale=scl,
199
+ motion_shift=False
200
+ )
201
+
202
+ if self.rand_erase:
203
+ erase_transform = RandomErasing(
204
+ args.reprob,
205
+ mode=args.remode,
206
+ max_count=args.recount,
207
+ num_splits=args.recount,
208
+ device="cpu",
209
+ )
210
+ buffer = buffer.permute(1, 0, 2, 3)
211
+ buffer = erase_transform(buffer)
212
+ buffer = buffer.permute(1, 0, 2, 3)
213
+
214
+ return buffer
215
+
216
+
217
+ def loadvideo_decord(self, sample, sample_rate_scale=1):
218
+ """Load video content using Decord"""
219
+ fname = sample
220
+
221
+ if not (os.path.exists(fname)):
222
+ return []
223
+
224
+ # avoid hanging issue
225
+ if os.path.getsize(fname) < 1 * 1024:
226
+ print('SKIP: ', fname, " - ", os.path.getsize(fname))
227
+ return []
228
+ try:
229
+ if self.keep_aspect_ratio:
230
+ vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
231
+ else:
232
+ vr = VideoReader(fname, width=self.new_width, height=self.new_height,
233
+ num_threads=1, ctx=cpu(0))
234
+ except:
235
+ print("video cannot be loaded by decord: ", fname)
236
+ return []
237
+
238
+ if self.mode == 'test':
239
+ all_index = []
240
+ tick = len(vr) / float(self.num_segment)
241
+ all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
242
+ [int(tick * x) for x in range(self.num_segment)]))
243
+ while len(all_index) < (self.num_segment * self.test_num_segment):
244
+ all_index.append(all_index[-1])
245
+ all_index = list(np.sort(np.array(all_index)))
246
+ vr.seek(0)
247
+ buffer = vr.get_batch(all_index).asnumpy()
248
+ return buffer
249
+
250
+ # handle temporal segments
251
+ average_duration = len(vr) // self.num_segment
252
+ all_index = []
253
+ if average_duration > 0:
254
+ all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
255
+ size=self.num_segment))
256
+ elif len(vr) > self.num_segment:
257
+ all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
258
+ else:
259
+ all_index += list(np.zeros((self.num_segment,)))
260
+ all_index = list(np.array(all_index))
261
+ vr.seek(0)
262
+ buffer = vr.get_batch(all_index).asnumpy()
263
+ return buffer
264
+
265
+ def __len__(self):
266
+ if self.mode != 'test':
267
+ return len(self.dataset_samples)
268
+ else:
269
+ return len(self.test_dataset)
270
+
271
+
272
+ def spatial_sampling(
273
+ frames,
274
+ spatial_idx=-1,
275
+ min_scale=256,
276
+ max_scale=320,
277
+ crop_size=224,
278
+ random_horizontal_flip=True,
279
+ inverse_uniform_sampling=False,
280
+ aspect_ratio=None,
281
+ scale=None,
282
+ motion_shift=False,
283
+ ):
284
+ """
285
+ Perform spatial sampling on the given video frames. If spatial_idx is
286
+ -1, perform random scale, random crop, and random flip on the given
287
+ frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
288
+ with the given spatial_idx.
289
+ Args:
290
+ frames (tensor): frames of images sampled from the video. The
291
+ dimension is `num frames` x `height` x `width` x `channel`.
292
+ spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
293
+ or 2, perform left, center, right crop if width is larger than
294
+ height, and perform top, center, bottom crop if height is larger
295
+ than width.
296
+ min_scale (int): the minimal size of scaling.
297
+ max_scale (int): the maximal size of scaling.
298
+ crop_size (int): the size of height and width used to crop the
299
+ frames.
300
+ inverse_uniform_sampling (bool): if True, sample uniformly in
301
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
302
+ scale. If False, take a uniform sample from [min_scale,
303
+ max_scale].
304
+ aspect_ratio (list): Aspect ratio range for resizing.
305
+ scale (list): Scale range for resizing.
306
+ motion_shift (bool): Whether to apply motion shift for resizing.
307
+ Returns:
308
+ frames (tensor): spatially sampled frames.
309
+ """
310
+ assert spatial_idx in [-1, 0, 1, 2]
311
+ if spatial_idx == -1:
312
+ if aspect_ratio is None and scale is None:
313
+ frames, _ = video_transforms.random_short_side_scale_jitter(
314
+ images=frames,
315
+ min_size=min_scale,
316
+ max_size=max_scale,
317
+ inverse_uniform_sampling=inverse_uniform_sampling,
318
+ )
319
+ frames, _ = video_transforms.random_crop(frames, crop_size)
320
+ else:
321
+ transform_func = (
322
+ video_transforms.random_resized_crop_with_shift
323
+ if motion_shift
324
+ else video_transforms.random_resized_crop
325
+ )
326
+ frames = transform_func(
327
+ images=frames,
328
+ target_height=crop_size,
329
+ target_width=crop_size,
330
+ scale=scale,
331
+ ratio=aspect_ratio,
332
+ )
333
+ if random_horizontal_flip:
334
+ frames, _ = video_transforms.horizontal_flip(0.5, frames)
335
+ else:
336
+ # The testing is deterministic and no jitter should be performed.
337
+ # min_scale, max_scale, and crop_size are expected to be the same.
338
+ assert len({min_scale, max_scale, crop_size}) == 1
339
+ frames, _ = video_transforms.random_short_side_scale_jitter(
340
+ frames, min_scale, max_scale
341
+ )
342
+ frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
343
+ return frames
344
+
345
+
346
+ def tensor_normalize(tensor, mean, std):
347
+ """
348
+ Normalize a given tensor by subtracting the mean and dividing by the std.
349
+ Args:
350
+ tensor (tensor): tensor to normalize.
351
+ mean (tensor or list): mean value to subtract.
352
+ std (tensor or list): std to divide.
353
+ """
354
+ if tensor.dtype == torch.uint8:
355
+ tensor = tensor.float()
356
+ tensor = tensor / 255.0
357
+ if type(mean) == list:
358
+ mean = torch.tensor(mean)
359
+ if type(std) == list:
360
+ std = torch.tensor(std)
361
+ tensor = tensor - mean
362
+ tensor = tensor / std
363
+ return tensor
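For reference, the sketch below shows how the `tensor_normalize` / permute path above composes for a single clip. It is a minimal, self-contained illustration: the random uint8 tensor stands in for decoded frames, and the helper restates the normalization logic locally rather than importing it, since this dataset module's import path is not shown in the diff.

```python
import torch

def tensor_normalize(tensor, mean, std):
    # Mirrors the helper above: scale uint8 to [0, 1], then subtract the
    # per-channel mean and divide by the per-channel std.
    if tensor.dtype == torch.uint8:
        tensor = tensor.float() / 255.0
    return (tensor - torch.tensor(mean)) / torch.tensor(std)

# Stand-in for a decoded clip: 16 frames of 224x224 RGB, laid out T H W C.
clip = torch.randint(0, 256, (16, 224, 224, 3), dtype=torch.uint8)

clip = tensor_normalize(clip, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
clip = clip.permute(3, 0, 1, 2)  # T H W C -> C T H W, as in the code above

print(clip.shape)  # torch.Size([3, 16, 224, 224])
```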
synthetic_tubelets.py ADDED
@@ -0,0 +1,785 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+ import random
4
+ import numpy as np
5
+ import random
6
+ import cv2
7
+ from typing import List
8
+ from PIL import Image
9
+
10
+ from dynamic_utils import (extend_key_frame_to_all,
11
+ sample_key_frames)
12
+ import imutils
13
+ import math
14
+ from scipy.ndimage import gaussian_filter1d
15
+ from glob import glob
16
+
17
+
18
+ class RandomRegionSampler(object):
19
+
20
+ def __init__(self,
21
+ num_rois: int,
22
+ scales: tuple,
23
+ ratios: tuple,
24
+ scale_jitter: float):
25
+ """ Randomly sample several RoIs
26
+
27
+ Args:
28
+ num_rois (int): number of sampled RoIs per image
29
+ scales (tuple): scales of candidate bounding boxes
30
+ ratios (tuple): aspect ratios of candidate bounding boxes
31
+ scale_jitter (float): scale jitter factor, positive number
32
+ """
33
+
34
+ self.num_rois = num_rois
35
+ self.scale_jitter = scale_jitter
36
+
37
+ scales = np.array(scales, np.float32)
38
+ ratios = np.array(ratios, np.float32)
39
+ widths = scales.reshape(1, -1) * np.sqrt(ratios).reshape(-1, 1)
40
+ heights = scales.reshape(1, -1) / np.sqrt(ratios).reshape(-1, 1)
41
+ self.anchors = np.concatenate((widths.reshape(-1, 1),
42
+ heights.reshape(-1, 1)), axis=-1)
43
+
44
+ def sample(self, data: List[np.ndarray]) -> np.ndarray:
45
+ """ Sample boxes.
46
+
47
+ Args:
48
+ data (list): image list, each element is a numpy.ndarray
49
+ in shape of [H, W, 3]
50
+
51
+ Returns:
52
+ boxes (np.ndarray): the sampled bounding boxes. in shape of
53
+ [self.num_rois, 4], represented in (x1, y1, x2, y2).
54
+
55
+ """
56
+ h, w = data[0].shape[0:2]
57
+
58
+ # random sample box shapes
59
+ anchor_inds = np.random.randint(0, len(self.anchors),
60
+ size=(self.num_rois, ))
61
+ box_shapes = self.anchors[anchor_inds].copy()
62
+ if self.scale_jitter is not None:
63
+ scale_factors = np.random.uniform(-self.scale_jitter,
64
+ self.scale_jitter,
65
+ size=(self.num_rois, 2))
66
+ box_shapes = box_shapes * np.exp(scale_factors)
67
+ box_shapes[:, 0] = np.clip(box_shapes[:, 0], 1, w - 1)
68
+ box_shapes[:, 1] = np.clip(box_shapes[:, 1], 1, h - 1)
69
+
70
+ #print("box shapes",box_shapes,box_shapes.shape)
71
+ # random sample box x1, y1
72
+ x1 = np.random.uniform(0, w - box_shapes[:, 0])
73
+ y1 = np.random.uniform(0, h - box_shapes[:, 1])
74
+ #print("x1, y1",x1,y1)
75
+ boxes = np.concatenate((x1.reshape(-1, 1),
76
+ y1.reshape(-1, 1),
77
+ (x1 + box_shapes[:, 0]).reshape(-1, 1),
78
+ (y1 + box_shapes[:, 1]).reshape(-1, 1)),
79
+ axis=1)
80
+ #print("sampled initial boxes",boxes)
81
+
82
+ return boxes
83
+
84
+ def sample_box_shapes(self, data: List[np.ndarray]) -> np.ndarray:
85
+ """ Sample boxes.
86
+
87
+ Args:
88
+ data (list): image list, each element is a numpy.ndarray
89
+ in shape of [H, W, 3]
90
+
91
+ Returns:
92
+ boxes (np.ndarray): the sampled bounding boxes. in shape of
93
+ [self.num_rois, 4], represented in (x1, y1, x2, y2).
94
+
95
+ """
96
+ h, w = data[0].shape[0:2]
97
+
98
+ # random sample box shapes
99
+ anchor_inds = np.random.randint(0, len(self.anchors),
100
+ size=(self.num_rois, ))
101
+ box_shapes = self.anchors[anchor_inds].copy()
102
+ if self.scale_jitter is not None:
103
+ scale_factors = np.random.uniform(-self.scale_jitter,
104
+ self.scale_jitter,
105
+ size=(self.num_rois, 2))
106
+ box_shapes = box_shapes * np.exp(scale_factors)
107
+ box_shapes[:, 0] = np.clip(box_shapes[:, 0], 1, w - 1)
108
+ box_shapes[:, 1] = np.clip(box_shapes[:, 1], 1, h - 1)
109
+
110
+ #print(" gaussian box shapes",box_shapes)
111
+
112
+ return box_shapes
113
+
114
+
115
+ class PatchMask(object):
116
+
117
+ def __init__(self,
118
+ use_objects: bool,
119
+ objects_path: str,
120
+ region_sampler: dict,
121
+ key_frame_probs: list,
122
+ loc_velocity: float,
123
+ rot_velocity: float,
124
+ size_velocity: float,
125
+ label_prob: float,
126
+ patch_transformation: str,
127
+ motion_type: str):
128
+
129
+ """ Core transformation in Catch-the-Patch.
130
+
131
+ Args:
132
+ region_sampler (dict): region sampler setting, it will be used to
133
+ construct a RandomRegionSampler object.
134
+ key_frame_probs (list): probabilities of sampling how many key
135
+ frames. The sum of this list should be 1.
136
+ loc_velocity (float): the maximum patch movement speed. (pix per
137
+ frame).
138
+ size_velocity (float): the maximum size change ratios between two
139
+ neighbouring frames.
140
+ label_prob (float): the fraction of frames that will be
141
+ modified. Note that even if a frame is not modified, we still
142
+ force the model to infer the patch positions. (see MRM module
143
+ in the paper).
144
+ """
145
+ self.region_sampler = RandomRegionSampler(**region_sampler)
146
+ self.key_frame_probs = key_frame_probs
147
+ self.loc_velocity = loc_velocity
148
+ self.rot_velocity = rot_velocity
149
+ self.size_velocity = size_velocity
150
+ self.label_prob = label_prob
151
+ if motion_type is not None:
152
+ self.motion_type = motion_type
153
+ self.patch_transformation = patch_transformation
154
+ self.use_objects = use_objects
155
+
156
+ if self.use_objects:
157
+ #self.object_list = glob("/ibex/user/jianl0b/Dataset/Fida_file_1/video_images/micheal_objects/cleaned/images/*/*")
158
+ self.object_list = glob(objects_path+"/*/*")
159
+
160
+ #self.object_list = glob("/ibex/project/c2134/Fida/micheal_objects_big/cleaned_big/images/*/*")
161
+ print(self.object_list[0:10],len(self.object_list))
162
+
163
+ def paste_objects(self, data, traj_rois, boxes):
164
+
165
+ objects_list = []
166
+ label_list = []
167
+
168
+ for i in range(len(boxes)):
169
+ objects, crop_index = self.pick_objects(data, traj_rois[i])
170
+ labels = np.random.uniform(0, 1, size=(len(data), ))
171
+ labels[crop_index] = 0.0
172
+ labels[0] = 0.0
173
+ labels = labels <= self.label_prob
174
+ objects_list.append(objects)
175
+ label_list.append(labels)
176
+
177
+ return objects_list, None, label_list
178
+
179
+ def paste_patches(self, data, traj_rois, boxes):
180
+
181
+ patches_list = []
182
+ alphas_list = []
183
+ label_list = []
184
+
185
+ for i in range(len(boxes)):
186
+ patches, crop_index = self.pick_patches(data, traj_rois[i])
187
+ alphas = self.pick_alphas(data, traj_rois[i], crop_index)
188
+ labels = np.random.uniform(0, 1, size=(len(data), ))
189
+ labels[crop_index] = 0.0
190
+ labels[0] = 0.0
191
+ labels = labels <= self.label_prob
192
+ patches_list.append(patches)
193
+ alphas_list.append(alphas)
194
+ label_list.append(labels)
195
+
196
+ return patches_list, alphas_list, label_list
197
+
198
+
199
+
200
+
201
+
202
+ def pick_patches(self,
203
+ data: List[np.ndarray],
204
+ traj_rois: np.ndarray) -> tuple:
205
+ """ Pick image patches from the raw video frame.
206
+
207
+ We randomly select a frame index and crop that frame according to
208
+ the trajectory RoIs. The cropped patch is then resized to the
209
+ per-frame sizes specified by traj_rois.
210
+
211
+ Args:
212
+ data (List[np.ndarray]): list of images, each element is in shape
213
+ of [H, W, 3]
214
+ traj_rois (np.ndarray): the generated trajectories, in shape of
215
+ [N_frames, 4]. (x1, y1, x2, y2)
216
+
217
+ Returns:
218
+ patches (List[np.ndarray]): the cropped patches
219
+ select_idx (int): the frame index from which the source patch
220
+ was cropped.
221
+ """
222
+ traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
223
+ num = len(traj_sizes)
224
+ select_idx = random.randint(0, num - 1)
225
+ x1, y1, x2, y2 = traj_rois[select_idx]
226
+ traj_rois_H = y2 - y1
227
+ traj_rois_W = x2 - x1
228
+
229
+ img = data[select_idx]
230
+ img_H, img_W, _ = img.shape
231
+
232
+ if img_W - traj_rois_W - 1 >= 0 and img_H - traj_rois_H - 1 >= 0:
233
+ new_x1 = random.randint(0, img_W - traj_rois_W - 1)
234
+ new_y1 = random.randint(0, img_H - traj_rois_H - 1)
235
+ new_x2 = new_x1 + traj_rois_W
236
+ new_y2 = new_y1 + traj_rois_H
237
+ img = img[new_y1:new_y2, new_x1:new_x2, :]
238
+ else:
239
+ img = img
240
+ patches = [cv2.resize(img, (traj_sizes[i, 0], traj_sizes[i, 1]))
241
+ for i in range(traj_rois.shape[0])]
242
+ return patches, select_idx
243
+
244
+ def pick_objects(self,
245
+ data: List[np.ndarray],
246
+ traj_rois: np.ndarray) -> tuple:
247
+ """ Pick image patches from the raw video frame.
248
+
249
+ We just randomly select a frame index, and crop the frame according to
250
+ the trajectory rois. This cropped patch will be resized into the
251
+ suitable size specified by the traj_rois.
252
+
253
+ Args:
254
+ data (List[np.ndarray]): list of images, each element is in shape
255
+ of [H, W, 3]
256
+ traj_rois (np.ndarray): the generated trajectories, in shape of
257
+ [N_frames, 4]. (x1, y1, x2, y2)
258
+
259
+ Returns:
260
+ objects (List[PIL.Image]): the resized object images
261
+ select_idx (int): the frame index whose RoI sets the base
262
+ object size.
263
+ """
264
+ traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
265
+ num = len(traj_sizes)
266
+ select_idx = random.randint(0, num - 1)
267
+ #print(len(data),traj_rois.shape)
268
+ x1, y1, x2, y2 = traj_rois[select_idx]
269
+ #print(x1, y1, x2, y2)
270
+
271
+ object_ind = random.randint(0, len(self.object_list)- 1)
272
+ object_img = Image.open(self.object_list[object_ind])
273
+ object_img = object_img.resize((x2-x1,y2-y1))
274
+
275
+ objects = [object_img.resize((traj_sizes[i, 0], traj_sizes[i, 1]))
276
+ for i in range(traj_rois.shape[0])]
277
+
278
+ return objects, select_idx
279
+
280
+
281
+
282
+ def pick_alphas(self,
283
+ data,
284
+ traj_rois: np.ndarray,
285
+ crop_index: int):
286
+ """ Generate the alpha masks for merging the patches into the raw
287
+ frames:
288
+ out_frame = raw_frame * (1 - alpha) + patch * alpha.
289
+ Besides controlling transparency, the alpha values are also used to mask the
290
+ patches into some predefined shapes, like an ellipse or a rhombus.
291
+ There are several hard-coded constants in this function. We did not
292
+ conduct any ablation analysis on these constants; they should have
293
+ little impact on the final performance.
294
+
295
+ Args:
296
+ data (List[np.ndarray]): list of images, each element is in shape
297
+ of [H, W, 3]
298
+ traj_rois (np.ndarray): the generated trajectories, in shape of
299
+ [N_frames, 4]. (x1, y1, x2, y2)
300
+ crop_index (int): the frame index which the source patch
301
+ cropped from.
302
+
303
+ Returns:
304
+ alphas (List[np.ndarray]): the generated alpha values
305
+
306
+ """
307
+ traj_sizes = traj_rois[..., 2:4] - traj_rois[..., 0:2]
308
+ num_frames = traj_sizes.shape[0]
309
+
310
+ base_w, base_h = traj_sizes[crop_index]
311
+
312
+ base_x_grids, base_y_grids = np.meshgrid(
313
+ np.arange(base_w).astype(np.float32),
314
+ np.arange(base_h).astype(np.float32)
315
+
316
+ )
317
+ ctr_w = (base_w - 1) // 2
318
+ ctr_h = (base_h - 1) // 2
319
+
320
+ dist_to_ctr_x = np.abs(base_x_grids - ctr_w) / base_w
321
+ dist_to_ctr_y = np.abs(base_y_grids - ctr_h) / base_h
322
+
323
+ mask_type = int(np.random.choice(3, p=[0.5, 0.35, 0.15]))
324
+ if mask_type == 0:
325
+ dist_to_ctr = np.maximum(dist_to_ctr_x, dist_to_ctr_y)
326
+ base_alpha = np.ones((base_h, base_w), np.float32)
327
+ elif mask_type == 1:
328
+ dist_to_ctr = np.sqrt(dist_to_ctr_x ** 2 + dist_to_ctr_y ** 2)
329
+ base_alpha = np.where(dist_to_ctr < 0.5,
330
+ np.ones((base_h, base_w), np.float32),
331
+ np.zeros((base_h, base_w), np.float32))
332
+ elif mask_type == 2:
333
+ dist_to_ctr = (dist_to_ctr_x + dist_to_ctr_y)
334
+ base_alpha = np.where(dist_to_ctr < 0.5,
335
+ np.ones((base_h, base_w), np.float32),
336
+ np.zeros((base_h, base_w), np.float32))
337
+ else:
338
+ raise NotImplementedError
339
+
340
+ use_smooth_edge = random.uniform(0, 1) < 0.5
341
+ if use_smooth_edge:
342
+ turning_point = random.uniform(0.30, 0.45)
343
+ k = -1 / (0.5 - turning_point)
344
+ alpha_mul = k * dist_to_ctr - 0.5 * k
345
+ alpha_mul = np.clip(alpha_mul, 0, 1)
346
+ base_alpha = base_alpha * alpha_mul
347
+
348
+ # sample key frames
349
+ key_inds = sample_key_frames(num_frames, self.key_frame_probs)
350
+ frame_alphas = np.random.uniform(0.8, 1.0, size=(len(key_inds), 1))
351
+ frame_alphas = extend_key_frame_to_all(frame_alphas, key_inds)
352
+
353
+ alphas = []
354
+ for frame_idx in range(num_frames):
355
+ w, h = traj_sizes[frame_idx]
356
+ i_alpha = cv2.resize(base_alpha, (w, h))
357
+ i_alpha = i_alpha * frame_alphas[frame_idx]
358
+ alphas.append(i_alpha)
359
+ return alphas
360
+
361
+ def get_rotation_angles(self,
362
+ num_frames,
363
+ transform_param: dict):
364
+ key_frame_probs = transform_param['key_frame_probs']
365
+ loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
366
+
367
+ rot_velocity = transform_param['rot_velocity']
368
+ rot_angles = np.zeros((transform_param['traj_rois'].shape[0],1))
369
+
370
+ #print("rotation angles original",rot_angles.shape,loc_key_inds)
371
+ rot_angles_list= [np.expand_dims(rot_angles, axis=0)]
372
+ for i in range(len(loc_key_inds) - 1):
373
+ if rot_velocity > 0:
374
+ index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
375
+ shifts = np.random.uniform(low=-rot_velocity* index_diff,
376
+ high=rot_velocity* index_diff,
377
+ size=rot_angles.shape)
378
+ rot_angles = rot_angles + shifts
379
+ rot_angles_list.append(np.expand_dims(rot_angles, axis=0))
380
+ rot_angles = np.concatenate(rot_angles_list, axis=0)
381
+ rot_angles = extend_key_frame_to_all(rot_angles, loc_key_inds, 'random')
382
+ rot_angles = rot_angles.transpose((1, 0, 2))
383
+
384
+
385
+ return rot_angles
386
+
387
+ def get_shear_factors(self,
388
+ num_frames,
389
+ transform_param: dict):
390
+ key_frame_probs = transform_param['key_frame_probs']
391
+ loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
392
+
393
+ #print("Loc key inds shear",loc_key_inds)
394
+
395
+ rot_velocity = transform_param['rot_velocity']
396
+ rot_angles = np.zeros((transform_param['traj_rois'].shape[0],1))
397
+
398
+ #print("rotation angles original",rot_angles.shape,loc_key_inds)
399
+ rot_angles_list= [np.expand_dims(rot_angles, axis=0)]
400
+ for i in range(len(loc_key_inds) - 1):
401
+ if rot_velocity > 0:
402
+ index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
403
+ shifts = np.random.uniform(low=-rot_velocity* index_diff,
404
+ high=rot_velocity* index_diff,
405
+ size=rot_angles.shape)
406
+ #scales = np.exp(shifts)
407
+ #print("shifts shear", shifts)
408
+ #rot_angles = scales
409
+ rot_angles = rot_angles + shifts
410
+ rot_angles_list.append(np.expand_dims(rot_angles, axis=0))
411
+ rot_angles = np.concatenate(rot_angles_list, axis=0)
412
+ rot_angles = extend_key_frame_to_all(rot_angles, loc_key_inds, 'random')
413
+ rot_angles = rot_angles.transpose((1, 0, 2))
414
+
415
+ return rot_angles
416
+
417
+
418
+ def _apply_image(self,
419
+ data: List[np.ndarray],
420
+ transform_param: dict):
421
+
422
+ data_1 = data
423
+
424
+ # sort by patch size and paste the larger patches first;
425
+ # if a small patch were pasted first, it could be
426
+ # completely covered by a larger one.
427
+ sizes = transform_param['traj_rois'][..., 2:4] - \
428
+ transform_param['traj_rois'][..., 0:2]
429
+ avg_sizes = np.prod(np.mean(sizes, axis=1), axis=1)
430
+ arg_rank = np.argsort(avg_sizes)[::-1]
431
+
432
+ width, height,_ = data_1[0].shape
433
+ #print(width,height)
434
+
435
+
436
+ if self.use_objects:
437
+
438
+ if transform_param['patch_transformation'] == 'rotation':
439
+ rot_angles = self.get_rotation_angles(len(data_1),transform_param)
440
+ transformed_data_1 = []
441
+ for frame_idx in range(len(data_1)):
442
+ i_rois = transform_param['traj_rois'][:, frame_idx, :]
443
+ img = data_1[frame_idx].copy()
444
+ for patch_idx in arg_rank:
445
+ if not transform_param['traj_labels'][patch_idx][frame_idx]:
446
+ continue
447
+ i_object = transform_param['patches'][patch_idx][frame_idx] # here patches are objects
448
+ i_object = np.array(i_object)
449
+ angle = int(rot_angles[patch_idx][frame_idx])
450
+ rotated_i_object = imutils.rotate_bound(i_object, angle)
451
+
452
+ rotated_i_alpha = rotated_i_object[..., -1]
453
+ rotated_i_alpha = rotated_i_alpha / 255.0
454
+ rotated_i_object = rotated_i_object[..., :3]
455
+
456
+ h_prime, w_prime, channels = rotated_i_object.shape
457
+ x1, y1, x2, y2 = i_rois[patch_idx]
458
+ h, w = y2 - y1, x2 - x1
459
+ if ((h_prime - h) % 2) == 0:
460
+ delta_h1 = delta_h2 = math.ceil((h_prime - h) / 2)
461
+ else:
462
+ delta_h1 = math.ceil((h_prime - h) / 2)
463
+ delta_h2 = math.floor((h_prime - h) / 2)
464
+ if ((w_prime - w) % 2) == 0:
465
+ delta_w1 = delta_w2 = math.ceil((w_prime - w) / 2)
466
+ else:
467
+ delta_w1 = math.ceil((w_prime - w) / 2)
468
+ delta_w2 = math.floor((w_prime - w) / 2)
469
+
470
+ x1_new, y1_new, x2_new, y2_new = x1 - delta_w1, y1 - delta_h1, x2 + delta_w2, y2 + delta_h2
471
+ if all(i >= 0 for i in [x1_new, y1_new, x2_new, y2_new]) and all(
472
+ i < width for i in [x1_new, y1_new, x2_new, y2_new]):
473
+ # in bound
474
+ i_patch = rotated_i_object
475
+ i_alpha = rotated_i_alpha[..., np.newaxis]
476
+ img[y1_new:y2_new, x1_new:x2_new, :] = img[y1_new:y2_new, x1_new:x2_new, :] * (1 - i_alpha) + i_patch * i_alpha
477
+ else:
478
+ # out of bound
479
+ img_H, img_W, C = img.shape
480
+ patch_H, patch_W, _ = rotated_i_object.shape
481
+ extended_img = np.zeros((img_H + 2 * patch_H, img_W + 2 * patch_W, C), dtype=img.dtype)
482
+ extended_img[patch_H:(img_H + patch_H), patch_W:(img_W + patch_W), :] = img
483
+
484
+ x1_new += patch_W
485
+ x2_new += patch_W
486
+ y1_new += patch_H
487
+ y2_new += patch_H
488
+ i_alpha = rotated_i_alpha[..., np.newaxis]
489
+ extended_img[y1_new:y2_new, x1_new:x2_new, :] = extended_img[y1_new:y2_new, x1_new:x2_new, :] * (1 - i_alpha) + rotated_i_object * i_alpha
490
+ img = extended_img[patch_H:(img_H + patch_H), patch_W:(img_W + patch_W), :]
491
+
492
+ img = np.array(img)
493
+ transformed_data_1.append(img)
494
+
495
+ return transformed_data_1
496
+
497
+
498
+ @staticmethod
499
+ def rectangle_movement(boxes: np.ndarray,
500
+ img_wh: tuple,
501
+ loc_velocity: float,
502
+ size_velocity: float,
503
+ num_frames: int,
504
+ key_frame_probs: List[float]) -> np.ndarray:
505
+ """ Simulate the object movement.
506
+
507
+ Args:
508
+ boxes (np.ndarray): in shape of [N_boxes, 4]
509
+ img_wh (tuple): image width and image height
510
+ loc_velocity (float): max speed of the center point movement
511
+ size_velocity (float): max speed of size changes
512
+ num_frames (int): number of frames
513
+ key_frame_probs (List[float]): probability distribution of how many key
514
+ frames will be sampled.
515
+
516
+ Returns:
517
+ all_boxes (np.ndarray): the generated box trajectory, in shape
518
+ of [N_traj, N_frame, 4].
519
+
520
+ """
521
+ # Step 1, sample key frames for location changes
522
+ loc_key_inds = sample_key_frames(num_frames, key_frame_probs)
523
+ # Step 2, decide box locations in key frames
524
+ ctr_pts = (boxes[:, 0:2] + boxes[:, 2:4]) * 0.5
525
+ #print("center points original",ctr_pts)
526
+ box_sizes = (boxes[:, 2:4] - boxes[:, 0:2])
527
+ #print("box sizes = ",box_sizes,box_sizes.shape)
528
+
529
+ min_ctr_pts = box_sizes * 0.5
530
+ max_ctr_pts = np.array(img_wh[0:2]).reshape(1, 2) - box_sizes * 0.5
531
+
532
+ #print("initial center points ",ctr_pts,loc_key_inds)
533
+ ctr_pts_list = [np.expand_dims(ctr_pts, axis=0)]
534
+ #print("ctr pts list",ctr_pts_list)
535
+ for i in range(len(loc_key_inds) - 1):
536
+ if loc_velocity > 0:
537
+ index_diff = loc_key_inds[i + 1] - loc_key_inds[i]
538
+ shifts = np.random.uniform(low=-loc_velocity * index_diff,
539
+ high=loc_velocity * index_diff,
540
+ size=ctr_pts.shape)
541
+ #print("shifts",shifts)
542
+ ctr_pts = ctr_pts + shifts
543
+ ctr_pts = np.clip(ctr_pts, min_ctr_pts, max_ctr_pts)
544
+ ctr_pts_list.append(np.expand_dims(ctr_pts, axis=0))
545
+ ctr_pts = np.concatenate(ctr_pts_list, axis=0)
546
+
547
+ ctr_pts = extend_key_frame_to_all(ctr_pts, loc_key_inds, 'random')
548
+ #print("all center points ",ctr_pts,ctr_pts.shape)
549
+
550
+ # Step 3, sample key frames for shape changes
551
+ size_key_inds = sample_key_frames(num_frames, key_frame_probs)
552
+
553
+ # Step 4, setup shape in different key frames
554
+ box_sizes_list = [np.expand_dims(box_sizes, axis=0)]
555
+ for i in range(len(size_key_inds) - 1):
556
+ if size_velocity > 0:
557
+ index_diff = size_key_inds[i + 1] - size_key_inds[i]
558
+ scales = np.random.uniform(low=-size_velocity * index_diff,
559
+ high=size_velocity * index_diff,
560
+ size=box_sizes.shape)
561
+ scales = np.exp(scales)
562
+ box_sizes = box_sizes * scales
563
+ box_sizes_list.append(np.expand_dims(box_sizes, axis=0))
564
+ box_sizes = np.concatenate(box_sizes_list, axis=0)
565
+ # print("box sizes before interpolation",box_sizes,size_key_inds)
566
+ box_sizes = extend_key_frame_to_all(box_sizes, size_key_inds, 'random')
567
+ #print("box sizes after interpolation",box_sizes)
568
+
569
+ # Step 5, construct boxes in key frames
570
+ all_boxes = np.concatenate((ctr_pts - box_sizes * 0.5,
571
+ ctr_pts + box_sizes * 0.5), axis=2)
572
+ # all_boxes[..., 0::2] = np.clip(all_boxes[..., 0::2], 0, img_wh[0])
573
+ # all_boxes[..., 1::2] = np.clip(all_boxes[..., 1::2], 0, img_wh[1])
574
+ all_boxes = all_boxes.transpose((1, 0, 2))
575
+ return all_boxes
576
+
577
+ @staticmethod
578
+ def gaussian_movement(box_shapes: np.ndarray,
579
+ img_wh: tuple,
580
+ num_trajs: int,
581
+ size_velocity: float,
582
+ num_frames: int,
583
+ key_frame_probs: List[float]) -> np.ndarray:
584
+ """ Simulate the object movement.
585
+
586
+ Args:
587
+
588
+ Returns:
589
+ all_boxes (np.ndarray): the generated box trajectory, in shape
590
+ of [N_traj, N_frame, 4], plus the initial boxes in shape [N_traj, 4].
591
+
592
+ """
593
+
594
+ def create_traj(box_shapes):
595
+ w = img_wh[0]
596
+ h = img_wh[1]
597
+ #print("gaussian",w,h)
598
+
599
+ n_points = 48 # how many points to create trajectory
600
+ sigma = 8 # bigger sigma -> smoother trajectory
601
+
602
+ # simulate trajectory points
603
+ #x = np.random.uniform(0,112,n_points)
604
+ #y = np.random.uniform(0,112,n_points)
605
+
606
+ # for 112 x 112
607
+ x = np.random.uniform(1+box_shapes[0]/2,w-1-box_shapes[0]/2,n_points)
608
+ y = np.random.uniform(1+box_shapes[1]/2,h-1-box_shapes[1]/2,n_points)
609
+
610
+ # for 224x 224
611
+ # x = np.random.uniform(0,112,n_points)
612
+ # y = np.random.uniform(0,112,n_points)
613
+
614
+ # smooth trajectory
615
+ xk = gaussian_filter1d(x, sigma=sigma, mode='reflect')
616
+ yk = gaussian_filter1d(y, sigma=sigma, mode='reflect')
617
+
618
+ # normalize and random scale
619
+ xkk = (xk -xk.min())
620
+ xkk /= xkk.max()
621
+ ykk = (yk -yk.min())
622
+ ykk /= ykk.max()
623
+
624
+ #scaling_factor = np.random.randint(20,90)
625
+ scaling_factor = np.random.randint(40,180)
626
+ xkk*=scaling_factor # randomize
627
+ ykk*=scaling_factor # randomize
628
+
629
+
630
+ # random translate and clip
631
+ translation_factor_x = np.random.randint(0,w-scaling_factor)
632
+ translation_factor_y = np.random.randint(0,h-scaling_factor)
633
+ tr_x = xkk + translation_factor_x
634
+ tr_y = ykk + translation_factor_y
635
+
636
+ tr_x = np.clip(tr_x,0,w-1)
637
+ tr_y = np.clip(tr_y,0,h-1)
638
+
639
+ # sample 16 points from trajectory with linear spacing
640
+ idxs = np.round(np.linspace(0, tr_x.shape[0]-1, num=16)).astype(int)
641
+ x_f = tr_x[idxs].astype(int)
642
+ y_f = tr_y[idxs].astype(int)
643
+ #print(x_f.shape,y_f.shape)
644
+ traj = np.column_stack((x_f,y_f))
645
+ traj = np.expand_dims(traj, axis=1)
646
+ return traj
647
+
648
+ # Step 1 create a non-linear trajectory
649
+ #print(" number of rois",num_trajs,box_shapes.shape)
650
+ ctr_pts_list = []
651
+ for i in range(num_trajs):
652
+ ctr_pts_list.append(create_traj(box_shapes[i]))
653
+ ctr_pts = np.concatenate(ctr_pts_list, axis=1)
654
+ #print("all center points guassian ",ctr_pts,ctr_pts.shape)
655
+
656
+ # Step 2 create box shapes for the starting location
657
+
658
+ boxes_list = []
659
+ for i in range(num_trajs):
660
+ x1, y1 = ctr_pts[0][i][0], ctr_pts[0][i][1]
661
+ box = np.concatenate((
662
+ (x1 - box_shapes[i, 0]/2).reshape(-1, 1),
663
+ (y1 - box_shapes[i, 1]/2).reshape(-1, 1),
664
+ (x1 + box_shapes[i, 0]/2).reshape(-1, 1),
665
+ (y1 + box_shapes[i, 1]/2).reshape(-1, 1)),
666
+ axis=1)
667
+ boxes_list.append(box)
668
+
669
+ boxes= np.concatenate(boxes_list, axis=0)
670
+ box_sizes = (boxes[:, 2:4] - boxes[:, 0:2])
671
+ #print("bboxes guassian ",boxes,boxes.shape)
672
+ #print("guassian box sizes = ",box_sizes,box_sizes.shape)
673
+
674
+ # Step 3, sample key frames for shape changes
675
+ size_key_inds = sample_key_frames(num_frames, key_frame_probs)
676
+ # Step 4, setup shape in different key frames
677
+ box_sizes_list = [np.expand_dims(box_sizes, axis=0)]
678
+ for i in range(len(size_key_inds) - 1):
679
+ if size_velocity > 0:
680
+ index_diff = size_key_inds[i + 1] - size_key_inds[i]
681
+ scales = np.random.uniform(low=-size_velocity * index_diff,
682
+ high=size_velocity * index_diff,
683
+ size=box_sizes.shape)
684
+ scales = np.exp(scales)
685
+ box_sizes = box_sizes * scales
686
+ box_sizes_list.append(np.expand_dims(box_sizes, axis=0))
687
+ box_sizes = np.concatenate(box_sizes_list, axis=0)
688
+ # print("box sizes before interpolation",box_sizes)
689
+ box_sizes = extend_key_frame_to_all(box_sizes, size_key_inds, 'random')
690
+ #print("box sizes after interpolation",box_sizes)
691
+
692
+ # Step 5, construct boxes in key frames
693
+ all_boxes = np.concatenate((ctr_pts - box_sizes * 0.5,
694
+ ctr_pts + box_sizes * 0.5), axis=2)
695
+ # all_boxes[..., 0::2] = np.clip(all_boxes[..., 0::2], 0, img_wh[0])
696
+ # all_boxes[..., 1::2] = np.clip(all_boxes[..., 1::2], 0, img_wh[1])
697
+ all_boxes = all_boxes.transpose((1, 0, 2))
698
+ return all_boxes,boxes
699
+
700
+ def __call__(self,img_tuple):
701
+ #def get_transform_param(self, data: List[np.ndarray], *args, **kwargs):
702
+ """ Generate the transformation parameters.
703
+
704
+ Args:
705
+ data (List[np.ndarray]): list of image array, each element is in
706
+ a shape of [H, W, 3]
707
+
708
+ Returns:
709
+ params (dict): a dict that contains necessary transformation
710
+ params, which include:
711
+ 'patches': list of image patches (np.ndarray)
712
+ 'alphas': list of alpha mask, same size and shape as patches.
713
+ 'traj_rois': the trajectory position, in shape of
714
+ [N_traj, N_frame, 4]
715
+ 'traj_labels': whether the patches have been pasted on some
716
+ specific frames, in shape of [N_traj, N_frame]
717
+ """
718
+
719
+ #print("with tubelets")
720
+
721
+ img_group, label = img_tuple
722
+
723
+ #print("before length data",len(img_group),img_group[0].size)
724
+
725
+ new_data = [np.array(img) for img in img_group]
726
+
727
+ #print("after length data",len(new_data),new_data[0].shape)
728
+
729
+ data_1 = new_data # Step 1, generate the trajectories.
730
+
731
+ h, w = data_1[0].shape[0:2]
732
+
733
+ #print("motion type and size_velocity", self.motion_type,self.size_velocity)
734
+ #print(" patch transformation and rotation velocity =",self.patch_transformation,self.rot_velocity)
735
+ if self.motion_type == 'linear' :
736
+
737
+ boxes = self.region_sampler.sample(data_1)
738
+
739
+ traj_rois = self.rectangle_movement(boxes, (w, h),
740
+ self.loc_velocity,
741
+ self.size_velocity,
742
+ len(data_1),
743
+ self.key_frame_probs)
744
+ # gaussian
745
+ elif self.motion_type == 'gaussian' :
746
+
747
+ box_shapes = self.region_sampler.sample_box_shapes(data_1)
748
+
749
+ traj_rois,boxes = self.gaussian_movement(box_shapes, (w, h),
750
+ self.region_sampler.num_rois,
751
+ self.size_velocity,
752
+ len(data_1),
753
+ self.key_frame_probs)
754
+
755
+ #print("gaussian rois",traj_rois.shape)
756
+ traj_rois = np.round(traj_rois).astype(int)
757
+ # traj_rois[..., 0::2] = np.clip(traj_rois[..., 0::2], 0, w)
758
+ # traj_rois[..., 1::2] = np.clip(traj_rois[..., 1::2], 0, h)
759
+
760
+ # Step 2, crop the patches and prepare the alpha masks.
761
+ if not self.use_objects:
762
+
763
+ #print(" pasting patches")
764
+ patches_list, alphas_list, label_list = self.paste_patches(data_1,traj_rois,boxes)
765
+ else:
766
+ #print(" pasting objects")
767
+ patches_list, alphas_list, label_list = self.paste_objects(data_1,traj_rois,boxes)
768
+
769
+
770
+
771
+ transforms_dict = dict(
772
+ traj_rois=traj_rois,
773
+ patches=patches_list,
774
+ alphas=alphas_list,
775
+ traj_labels=label_list,
776
+ rot_velocity = self.rot_velocity,
777
+ patch_transformation = self.patch_transformation,
778
+ key_frame_probs = self.key_frame_probs
779
+ )
780
+
781
+ output_data = self._apply_image( new_data,transforms_dict)
782
+
783
+ ret_data = [Image.fromarray(img) for img in output_data]
784
+
785
+ return ret_data, label, traj_rois
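As a rough illustration of the trajectory generation in `PatchMask.gaussian_movement` above, the sketch below smooths random control points with `gaussian_filter1d` and subsamples one center per frame. It is a simplified stand-in (the scaling, translation and clipping steps of `create_traj` are omitted), and the frame size and point counts are illustrative defaults.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

def toy_trajectory(img_w=224, img_h=224, n_points=48, n_frames=16, sigma=8):
    # Random control points inside the frame.
    x = np.random.uniform(0, img_w, n_points)
    y = np.random.uniform(0, img_h, n_points)
    # Smooth them into a continuous, non-linear path (bigger sigma -> smoother).
    xs = gaussian_filter1d(x, sigma=sigma, mode='reflect')
    ys = gaussian_filter1d(y, sigma=sigma, mode='reflect')
    # Keep one (x, y) center per frame, linearly spaced along the path.
    idxs = np.round(np.linspace(0, n_points - 1, num=n_frames)).astype(int)
    return np.column_stack((xs[idxs], ys[idxs]))  # shape [n_frames, 2]

centers = toy_trajectory()
print(centers.shape)  # (16, 2) -- one patch center per frame
```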
transforms.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import torchvision.transforms.functional as F
3
+ import warnings
4
+ import random
5
+ import numpy as np
6
+ import torchvision
7
+ from PIL import Image, ImageOps
8
+ import numbers
9
+
10
+
11
+ class GroupRandomCrop(object):
12
+ def __init__(self, size):
13
+ if isinstance(size, numbers.Number):
14
+ self.size = (int(size), int(size))
15
+ else:
16
+ self.size = size
17
+
18
+ def __call__(self, img_tuple):
19
+ img_group, label = img_tuple
20
+
21
+ w, h = img_group[0].size
22
+ th, tw = self.size
23
+
24
+ out_images = list()
25
+
26
+ x1 = random.randint(0, w - tw)
27
+ y1 = random.randint(0, h - th)
28
+
29
+ for img in img_group:
30
+ assert(img.size[0] == w and img.size[1] == h)
31
+ if w == tw and h == th:
32
+ out_images.append(img)
33
+ else:
34
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35
+
36
+ return (out_images, label)
37
+
38
+
39
+ class GroupCenterCrop(object):
40
+ def __init__(self, size):
41
+ self.worker = torchvision.transforms.CenterCrop(size)
42
+
43
+ def __call__(self, img_tuple):
44
+ img_group, label = img_tuple
45
+ return ([self.worker(img) for img in img_group], label)
46
+
47
+
48
+ class GroupNormalize(object):
49
+ def __init__(self, mean, std):
50
+ self.mean = mean
51
+ self.std = std
52
+
53
+ def __call__(self, tensor_tuple):
54
+ tensor, label = tensor_tuple
55
+ rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56
+ rep_std = self.std * (tensor.size()[0]//len(self.std))
57
+
58
+ # TODO: make efficient
59
+ for t, m, s in zip(tensor, rep_mean, rep_std):
60
+ t.sub_(m).div_(s)
61
+
62
+ return (tensor,label)
63
+
64
+
65
+ class GroupGrayScale(object):
66
+ def __init__(self, size):
67
+ self.worker = torchvision.transforms.Grayscale(size)
68
+
69
+ def __call__(self, img_tuple):
70
+ img_group, label = img_tuple
71
+ return ([self.worker(img) for img in img_group], label)
72
+
73
+
74
+ class GroupScale(object):
75
+ """ Rescales the input PIL.Image to the given 'size'.
76
+ 'size' will be the size of the smaller edge.
77
+ For example, if height > width, then image will be
78
+ rescaled to (size * height / width, size)
79
+ size: size of the smaller edge
80
+ interpolation: Default: PIL.Image.BILINEAR
81
+ """
82
+
83
+ def __init__(self, size, interpolation=Image.BILINEAR):
84
+ self.worker = torchvision.transforms.Resize(size, interpolation)
85
+
86
+ def __call__(self, img_tuple):
87
+ img_group, label = img_tuple
88
+ return ([self.worker(img) for img in img_group], label)
89
+
90
+
91
+ class GroupMultiScaleCrop(object):
92
+
93
+ def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
94
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
95
+ self.max_distort = max_distort
96
+ self.fix_crop = fix_crop
97
+ self.more_fix_crop = more_fix_crop
98
+ self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
99
+ self.interpolation = Image.BILINEAR
100
+
101
+ def __call__(self, img_tuple):
102
+ img_group, label = img_tuple
103
+
104
+ im_size = img_group[0].size
105
+
106
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
107
+ crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
108
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
109
+ return (ret_img_group, label)
110
+
111
+ def _sample_crop_size(self, im_size):
112
+ image_w, image_h = im_size[0], im_size[1]
113
+
114
+ # find a crop size
115
+ base_size = min(image_w, image_h)
116
+ crop_sizes = [int(base_size * x) for x in self.scales]
117
+ crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
118
+ crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
119
+
120
+ pairs = []
121
+ for i, h in enumerate(crop_h):
122
+ for j, w in enumerate(crop_w):
123
+ if abs(i - j) <= self.max_distort:
124
+ pairs.append((w, h))
125
+
126
+ crop_pair = random.choice(pairs)
127
+ if not self.fix_crop:
128
+ w_offset = random.randint(0, image_w - crop_pair[0])
129
+ h_offset = random.randint(0, image_h - crop_pair[1])
130
+ else:
131
+ w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
132
+
133
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
134
+
135
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
136
+ offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
137
+ return random.choice(offsets)
138
+
139
+ @staticmethod
140
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
141
+ w_step = (image_w - crop_w) // 4
142
+ h_step = (image_h - crop_h) // 4
143
+
144
+ ret = list()
145
+ ret.append((0, 0)) # upper left
146
+ ret.append((4 * w_step, 0)) # upper right
147
+ ret.append((0, 4 * h_step)) # lower left
148
+ ret.append((4 * w_step, 4 * h_step)) # lower right
149
+ ret.append((2 * w_step, 2 * h_step)) # center
150
+
151
+ if more_fix_crop:
152
+ ret.append((0, 2 * h_step)) # center left
153
+ ret.append((4 * w_step, 2 * h_step)) # center right
154
+ ret.append((2 * w_step, 4 * h_step)) # lower center
155
+ ret.append((2 * w_step, 0 * h_step)) # upper center
156
+
157
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
158
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
159
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
160
+ ret.append((3 * w_step, 3 * h_step)) # lower right quarter
161
+ return ret
162
+
163
+
164
+ class Stack(object):
165
+
166
+ def __init__(self, roll=False):
167
+ self.roll = roll
168
+
169
+ def __call__(self, img_tuple):
170
+ img_group, label = img_tuple
171
+
172
+ if img_group[0].mode == 'L':
173
+ return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
174
+ elif img_group[0].mode == 'RGB':
175
+ if self.roll:
176
+ return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
177
+ else:
178
+ return (np.concatenate(img_group, axis=2), label)
179
+
180
+
181
+ class ToTorchFormatTensor(object):
182
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
183
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
184
+ def __init__(self, div=True):
185
+ self.div = div
186
+
187
+ def __call__(self, pic_tuple):
188
+ pic, label = pic_tuple
189
+
190
+ if isinstance(pic, np.ndarray):
191
+ # handle numpy array
192
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
193
+ else:
194
+ # handle PIL Image
195
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
196
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
197
+ # put it from HWC to CHW format
198
+ # yikes, this transpose takes 80% of the loading time/CPU
199
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
200
+ return (img.float().div(255.) if self.div else img.float(), label)
201
+
202
+
203
+ class IdentityTransform(object):
204
+
205
+ def __call__(self, data):
206
+ return data
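The group transforms above all pass `(img_group, label)` tuples through, so they can be chained with a plain `torchvision.transforms.Compose`. The sketch below assumes `transforms.py` is importable from the repo root and uses dummy PIL frames; the exact pipeline used for training is defined elsewhere in the codebase.

```python
from PIL import Image
from torchvision.transforms import Compose

# Assumes this file (transforms.py) sits on the Python path, e.g. the repo root.
from transforms import (GroupScale, GroupCenterCrop, Stack,
                        ToTorchFormatTensor, GroupNormalize)

pipeline = Compose([
    GroupScale(256),                  # resize shorter side to 256
    GroupCenterCrop(224),             # 224x224 center crop per frame
    Stack(roll=False),                # concatenate frames along channels (H, W, 3*T)
    ToTorchFormatTensor(div=True),    # HWC uint8 -> CHW float in [0, 1]
    GroupNormalize(mean=[0.485, 0.456, 0.406],
                   std=[0.229, 0.224, 0.225]),  # repeated over all 3*T channels
])

# Dummy clip: 16 RGB frames paired with a label, as the __call__ signatures expect.
frames = [Image.new('RGB', (320, 240)) for _ in range(16)]
clip_tensor, label = pipeline((frames, 0))
print(clip_tensor.shape)  # torch.Size([48, 224, 224])
```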
utils_mae.py ADDED
@@ -0,0 +1,536 @@
1
+ import io
2
+ import os
3
+ import math
4
+ import time
5
+ import json
6
+ from collections import defaultdict, deque
7
+ import datetime
8
+ import numpy as np
9
+ from timm.utils import get_state_dict
10
+ from torch.utils.data._utils.collate import default_collate
11
+ from pathlib import Path
12
+ import subprocess
13
+ import torch
14
+ import torch.distributed as dist
15
+ #from torch._six import inf
16
+ from torch import inf
17
+ import random
18
+
19
+ from tensorboardX import SummaryWriter
20
+
21
+
22
+ class SmoothedValue(object):
23
+ """Track a series of values and provide access to smoothed values over a
24
+ window or the global series average.
25
+ """
26
+
27
+ def __init__(self, window_size=20, fmt=None):
28
+ if fmt is None:
29
+ fmt = "{median:.4f} ({global_avg:.4f})"
30
+ self.deque = deque(maxlen=window_size)
31
+ self.total = 0.0
32
+ self.count = 0
33
+ self.fmt = fmt
34
+
35
+ def update(self, value, n=1):
36
+ self.deque.append(value)
37
+ self.count += n
38
+ self.total += value * n
39
+
40
+ def synchronize_between_processes(self):
41
+ """
42
+ Warning: does not synchronize the deque!
43
+ """
44
+ if not is_dist_avail_and_initialized():
45
+ return
46
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
47
+ dist.barrier()
48
+ dist.all_reduce(t)
49
+ t = t.tolist()
50
+ self.count = int(t[0])
51
+ self.total = t[1]
52
+
53
+ @property
54
+ def median(self):
55
+ d = torch.tensor(list(self.deque))
56
+ return d.median().item()
57
+
58
+ @property
59
+ def avg(self):
60
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
61
+ return d.mean().item()
62
+
63
+ @property
64
+ def global_avg(self):
65
+ return self.total / self.count
66
+
67
+ @property
68
+ def max(self):
69
+ return max(self.deque)
70
+
71
+ @property
72
+ def value(self):
73
+ return self.deque[-1]
74
+
75
+ def __str__(self):
76
+ return self.fmt.format(
77
+ median=self.median,
78
+ avg=self.avg,
79
+ global_avg=self.global_avg,
80
+ max=self.max,
81
+ value=self.value)
82
+
83
+
84
+ class MetricLogger(object):
85
+ def __init__(self, delimiter="\t"):
86
+ self.meters = defaultdict(SmoothedValue)
87
+ self.delimiter = delimiter
88
+
89
+ def update(self, **kwargs):
90
+ for k, v in kwargs.items():
91
+ if v is None:
92
+ continue
93
+ if isinstance(v, torch.Tensor):
94
+ v = v.item()
95
+ assert isinstance(v, (float, int))
96
+ self.meters[k].update(v)
97
+
98
+ def __getattr__(self, attr):
99
+ if attr in self.meters:
100
+ return self.meters[attr]
101
+ if attr in self.__dict__:
102
+ return self.__dict__[attr]
103
+ raise AttributeError("'{}' object has no attribute '{}'".format(
104
+ type(self).__name__, attr))
105
+
106
+ def __str__(self):
107
+ loss_str = []
108
+ for name, meter in self.meters.items():
109
+ loss_str.append(
110
+ "{}: {}".format(name, str(meter))
111
+ )
112
+ return self.delimiter.join(loss_str)
113
+
114
+ def synchronize_between_processes(self):
115
+ for meter in self.meters.values():
116
+ meter.synchronize_between_processes()
117
+
118
+ def add_meter(self, name, meter):
119
+ self.meters[name] = meter
120
+
121
+ def log_every(self, iterable, print_freq, header=None):
122
+ i = 0
123
+ if not header:
124
+ header = ''
125
+ start_time = time.time()
126
+ end = time.time()
127
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
128
+ data_time = SmoothedValue(fmt='{avg:.4f}')
129
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
130
+ log_msg = [
131
+ header,
132
+ '[{0' + space_fmt + '}/{1}]',
133
+ 'eta: {eta}',
134
+ '{meters}',
135
+ 'time: {time}',
136
+ 'data: {data}'
137
+ ]
138
+ if torch.cuda.is_available():
139
+ log_msg.append('max mem: {memory:.0f}')
140
+ log_msg = self.delimiter.join(log_msg)
141
+ MB = 1024.0 * 1024.0
142
+ for obj in iterable:
143
+ data_time.update(time.time() - end)
144
+ yield obj
145
+ iter_time.update(time.time() - end)
146
+ if i % print_freq == 0 or i == len(iterable) - 1:
147
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
+ if torch.cuda.is_available():
150
+ print(log_msg.format(
151
+ i, len(iterable), eta=eta_string,
152
+ meters=str(self),
153
+ time=str(iter_time), data=str(data_time),
154
+ memory=torch.cuda.max_memory_allocated() / MB))
155
+ else:
156
+ print(log_msg.format(
157
+ i, len(iterable), eta=eta_string,
158
+ meters=str(self),
159
+ time=str(iter_time), data=str(data_time)))
160
+ i += 1
161
+ end = time.time()
162
+ total_time = time.time() - start_time
163
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
164
+ print('{} Total time: {} ({:.4f} s / it)'.format(
165
+ header, total_time_str, total_time / len(iterable)))
166
+
167
+
168
+ class TensorboardLogger(object):
169
+ def __init__(self, log_dir):
170
+ self.writer = SummaryWriter(logdir=log_dir)
171
+ self.step = 0
172
+
173
+ def set_step(self, step=None):
174
+ if step is not None:
175
+ self.step = step
176
+ else:
177
+ self.step += 1
178
+
179
+ def update(self, head='scalar', step=None, **kwargs):
180
+ for k, v in kwargs.items():
181
+ if v is None:
182
+ continue
183
+ if isinstance(v, torch.Tensor):
184
+ v = v.item()
185
+ assert isinstance(v, (float, int))
186
+ self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
187
+
188
+ def flush(self):
189
+ self.writer.flush()
190
+
191
+ def seed_worker(worker_id):
192
+ worker_seed = torch.initial_seed() % 2**32
193
+ np.random.seed(worker_seed)
194
+ random.seed(worker_seed)
195
+
196
+ def _load_checkpoint_for_ema(model_ema, checkpoint):
197
+ """
198
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
199
+ """
200
+ mem_file = io.BytesIO()
201
+ torch.save(checkpoint, mem_file)
202
+ mem_file.seek(0)
203
+ model_ema._load_checkpoint(mem_file)
204
+
205
+
206
+ def setup_for_distributed(is_master):
207
+ """
208
+ This function disables printing when not in master process
209
+ """
210
+ import builtins as __builtin__
211
+ builtin_print = __builtin__.print
212
+
213
+ def print(*args, **kwargs):
214
+ force = kwargs.pop('force', False)
215
+ if is_master or force:
216
+ builtin_print(*args, **kwargs)
217
+
218
+ __builtin__.print = print
219
+
220
+
221
+ def is_dist_avail_and_initialized():
222
+ if not dist.is_available():
223
+ return False
224
+ if not dist.is_initialized():
225
+ return False
226
+ return True
227
+
228
+
229
+ def get_world_size():
230
+ if not is_dist_avail_and_initialized():
231
+ return 1
232
+ return dist.get_world_size()
233
+
234
+
235
+ def get_rank():
236
+ if not is_dist_avail_and_initialized():
237
+ return 0
238
+ return dist.get_rank()
239
+
240
+
241
+ def is_main_process():
242
+ return get_rank() == 0
243
+
244
+
245
+ def save_on_master(*args, **kwargs):
246
+ if is_main_process():
247
+ torch.save(*args, **kwargs)
248
+
249
+
250
+ def init_distributed_mode(args):
251
+ if args.dist_on_itp:
252
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
253
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
254
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
255
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
256
+ os.environ['LOCAL_RANK'] = str(args.gpu)
257
+ os.environ['RANK'] = str(args.rank)
258
+ os.environ['WORLD_SIZE'] = str(args.world_size)
259
+ elif 'SLURM_PROCID' in os.environ:
260
+ args.rank = int(os.environ['SLURM_PROCID'])
261
+ args.gpu = int(os.environ['SLURM_LOCALID'])
262
+ args.world_size = int(os.environ['SLURM_NTASKS'])
263
+ os.environ['RANK'] = str(args.rank)
264
+ os.environ['LOCAL_RANK'] = str(args.gpu)
265
+ os.environ['WORLD_SIZE'] = str(args.world_size)
266
+
267
+ node_list = os.environ['SLURM_NODELIST']
268
+ addr = subprocess.getoutput(
269
+ f'scontrol show hostname {node_list} | head -n1')
270
+ if 'MASTER_ADDR' not in os.environ:
271
+ os.environ['MASTER_ADDR'] = addr
272
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
273
+ args.rank = int(os.environ["RANK"])
274
+ args.world_size = int(os.environ['WORLD_SIZE'])
275
+ args.gpu = int(os.environ['LOCAL_RANK'])
276
+ else:
277
+ print('Not using distributed mode')
278
+ args.distributed = False
279
+ return
280
+
281
+ args.distributed = True
282
+
283
+ torch.cuda.set_device(args.gpu)
284
+ args.dist_backend = 'nccl'
285
+ print('| distributed init (rank {}): {}, gpu {}'.format(
286
+ args.rank, args.dist_url, args.gpu), flush=True)
287
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
288
+ world_size=args.world_size, rank=args.rank)
289
+ torch.distributed.barrier()
290
+ # assert torch.distributed.is_initialized()
291
+ setup_for_distributed(args.rank == 0)
292
+
293
+
294
+ def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
295
+ missing_keys = []
296
+ unexpected_keys = []
297
+ error_msgs = []
298
+ metadata = getattr(state_dict, '_metadata', None)
299
+ state_dict = state_dict.copy()
300
+ if metadata is not None:
301
+ state_dict._metadata = metadata
302
+
303
+ def load(module, prefix=''):
304
+ local_metadata = {} if metadata is None else metadata.get(
305
+ prefix[:-1], {})
306
+ module._load_from_state_dict(
307
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
308
+ for name, child in module._modules.items():
309
+ if child is not None:
310
+ load(child, prefix + name + '.')
311
+
312
+ load(model, prefix=prefix)
313
+
314
+ warn_missing_keys = []
315
+ ignore_missing_keys = []
316
+ for key in missing_keys:
317
+ keep_flag = True
318
+ for ignore_key in ignore_missing.split('|'):
319
+ if ignore_key in key:
320
+ keep_flag = False
321
+ break
322
+ if keep_flag:
323
+ warn_missing_keys.append(key)
324
+ else:
325
+ ignore_missing_keys.append(key)
326
+
327
+ missing_keys = warn_missing_keys
328
+
329
+ if len(missing_keys) > 0:
330
+ print("Weights of {} not initialized from pretrained model: {}".format(
331
+ model.__class__.__name__, missing_keys))
332
+ if len(unexpected_keys) > 0:
333
+ print("Weights from pretrained model not used in {}: {}".format(
334
+ model.__class__.__name__, unexpected_keys))
335
+ if len(ignore_missing_keys) > 0:
336
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
337
+ model.__class__.__name__, ignore_missing_keys))
338
+ if len(error_msgs) > 0:
339
+ print('\n'.join(error_msgs))
340
+
341
+
342
+ class NativeScalerWithGradNormCount:
343
+ state_dict_key = "amp_scaler"
344
+
345
+ def __init__(self):
346
+ self._scaler = torch.cuda.amp.GradScaler()
347
+
348
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
349
+ self._scaler.scale(loss).backward(create_graph=create_graph)
350
+ if update_grad:
351
+ if clip_grad is not None:
352
+ assert parameters is not None
353
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
354
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
355
+ else:
356
+ self._scaler.unscale_(optimizer)
357
+ norm = get_grad_norm_(parameters)
358
+ self._scaler.step(optimizer)
359
+ self._scaler.update()
360
+ else:
361
+ norm = None
362
+ return norm
363
+
364
+ def state_dict(self):
365
+ return self._scaler.state_dict()
366
+
367
+ def load_state_dict(self, state_dict):
368
+ self._scaler.load_state_dict(state_dict)
369
+
370
+
371
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
372
+ if isinstance(parameters, torch.Tensor):
373
+ parameters = [parameters]
374
+ parameters = [p for p in parameters if p.grad is not None]
375
+ norm_type = float(norm_type)
376
+ if len(parameters) == 0:
377
+ return torch.tensor(0.)
378
+ device = parameters[0].grad.device
379
+ if norm_type == inf:
380
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
381
+ else:
382
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
383
+ return total_norm
384
+
385
+
386
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
387
+ start_warmup_value=0, warmup_steps=-1):
388
+ warmup_schedule = np.array([])
389
+ warmup_iters = warmup_epochs * niter_per_ep
390
+ if warmup_steps > 0:
391
+ warmup_iters = warmup_steps
392
+ print("Set warmup steps = %d" % warmup_iters)
393
+ if warmup_epochs > 0:
394
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
395
+
396
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
397
+ schedule = np.array(
398
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
399
+
400
+ schedule = np.concatenate((warmup_schedule, schedule))
401
+
402
+ assert len(schedule) == epochs * niter_per_ep
403
+ return schedule
404
+
405
+
406
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
407
+ output_dir = Path(args.output_dir)
408
+ epoch_name = str(epoch)
409
+ if loss_scaler is not None:
410
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
411
+ for checkpoint_path in checkpoint_paths:
412
+ to_save = {
413
+ 'model': model_without_ddp.state_dict(),
414
+ 'optimizer': optimizer.state_dict(),
415
+ 'epoch': epoch,
416
+ 'scaler': loss_scaler.state_dict(),
417
+ 'args': args,
418
+ }
419
+
420
+ if model_ema is not None:
421
+ to_save['model_ema'] = get_state_dict(model_ema)
422
+
423
+ save_on_master(to_save, checkpoint_path)
424
+ else:
425
+ client_state = {'epoch': epoch}
426
+ if model_ema is not None:
427
+ client_state['model_ema'] = get_state_dict(model_ema)
428
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
429
+
430
+
431
+ def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
432
+ output_dir = Path(args.output_dir)
433
+ if loss_scaler is not None:
434
+ # torch.amp
435
+ if args.auto_resume and len(args.resume) == 0:
436
+ import glob
437
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
438
+ latest_ckpt = -1
439
+ for ckpt in all_checkpoints:
440
+ t = ckpt.split('-')[-1].split('.')[0]
441
+ if t.isdigit():
442
+ latest_ckpt = max(int(t), latest_ckpt)
443
+ if latest_ckpt >= 0:
444
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
445
+ print("Auto resume checkpoint: %s" % args.resume)
446
+
447
+ if args.resume:
448
+ if args.resume.startswith('https'):
449
+ checkpoint = torch.hub.load_state_dict_from_url(
450
+ args.resume, map_location='cpu', check_hash=True)
451
+ else:
452
+ checkpoint = torch.load(args.resume, map_location='cpu')
453
+ model_without_ddp.load_state_dict(checkpoint['model'])
454
+ print("Resume checkpoint %s" % args.resume)
455
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
456
+ optimizer.load_state_dict(checkpoint['optimizer'])
457
+ args.start_epoch = checkpoint['epoch'] + 1
458
+ if hasattr(args, 'model_ema') and args.model_ema:
459
+ _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
460
+ if 'scaler' in checkpoint:
461
+ loss_scaler.load_state_dict(checkpoint['scaler'])
462
+ print("With optim & sched!")
463
+ else:
464
+ # DeepSpeed: only '--auto_resume' is supported.
465
+ if args.auto_resume:
466
+ import glob
467
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
468
+ latest_ckpt = -1
469
+ for ckpt in all_checkpoints:
470
+ t = ckpt.split('-')[-1].split('.')[0]
471
+ if t.isdigit():
472
+ latest_ckpt = max(int(t), latest_ckpt)
473
+ if latest_ckpt >= 0:
474
+ args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
475
+ print("Auto resume checkpoint: %d" % latest_ckpt)
476
+ _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
477
+ args.start_epoch = client_states['epoch'] + 1
478
+ if model_ema is not None:
479
+ if args.model_ema:
480
+ _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
481
+
482
+
483
+ def create_ds_config(args):
484
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
485
+ with open(args.deepspeed_config, mode="w") as writer:
486
+ ds_config = {
487
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
488
+ "train_micro_batch_size_per_gpu": args.batch_size,
489
+ "steps_per_print": 1000,
490
+ "optimizer": {
491
+ "type": "Adam",
492
+ "adam_w_mode": True,
493
+ "params": {
494
+ "lr": args.lr,
495
+ "weight_decay": args.weight_decay,
496
+ "bias_correction": True,
497
+ "betas": [
498
+ 0.9,
499
+ 0.999
500
+ ],
501
+ "eps": 1e-8
502
+ }
503
+ },
504
+ "fp16": {
505
+ "enabled": True,
506
+ "loss_scale": 0,
507
+ "initial_scale_power": 7,
508
+ "loss_scale_window": 128
509
+ }
510
+ }
511
+
512
+ writer.write(json.dumps(ds_config, indent=2))
513
+
514
+ def multiple_samples_collate(batch, fold=False):
515
+ """
516
+ Collate function for repeated augmentation. Each instance in the batch has
517
+ more than one sample.
518
+ Args:
519
+ batch (tuple or list): data batch to collate.
520
+ Returns:
521
+ (tuple): collated data batch.
522
+ """
523
+ inputs, labels, video_idx, extra_data = zip(*batch)
524
+ inputs = [item for sublist in inputs for item in sublist]
525
+ labels = [item for sublist in labels for item in sublist]
526
+ video_idx = [item for sublist in video_idx for item in sublist]
527
+ inputs, labels, video_idx, extra_data = (
528
+ default_collate(inputs),
529
+ default_collate(labels),
530
+ default_collate(video_idx),
531
+ default_collate(extra_data),
532
+ )
533
+ if fold:
534
+ return [inputs], labels, video_idx, extra_data
535
+ else:
536
+ return inputs, labels, video_idx, extra_data
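
A minimal sketch of plugging `multiple_samples_collate` into a `DataLoader`, assuming a dataset whose items are tuples of lists (one entry per repeated augmentation of the same video); `my_video_dataset` is a placeholder name:

```python
from functools import partial
from torch.utils.data import DataLoader

# `my_video_dataset` is hypothetical: each item is expected to be
# (inputs, labels, video_idx, extra_data), each a list with one entry
# per repeated sample of the same clip.
loader = DataLoader(
    my_video_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=partial(multiple_samples_collate, fold=False),
)
```
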
video_transforms.py ADDED
@@ -0,0 +1,1281 @@
1
+ #!/usr/bin/env python3
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import torchvision.transforms.functional as F
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+
10
+ from rand_augment import rand_augment_transform
11
+ from random_erasing import RandomErasing
12
+
13
+
14
+ import numbers
15
+ import PIL
16
+ import torchvision
17
+
18
+ import functional as FF
19
+
20
+ _pil_interpolation_to_str = {
21
+ Image.NEAREST: "PIL.Image.NEAREST",
22
+ Image.BILINEAR: "PIL.Image.BILINEAR",
23
+ Image.BICUBIC: "PIL.Image.BICUBIC",
24
+ Image.LANCZOS: "PIL.Image.LANCZOS",
25
+ Image.HAMMING: "PIL.Image.HAMMING",
26
+ Image.BOX: "PIL.Image.BOX",
27
+ }
28
+
29
+
30
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
31
+
32
+
33
+ def _pil_interp(method):
34
+ if method == "bicubic":
35
+ return Image.BICUBIC
36
+ elif method == "lanczos":
37
+ return Image.LANCZOS
38
+ elif method == "hamming":
39
+ return Image.HAMMING
40
+ else:
41
+ return Image.BILINEAR
42
+
43
+
44
+ def random_short_side_scale_jitter(
45
+ images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
46
+ ):
47
+ """
48
+ Perform a spatial short scale jittering on the given images and
49
+ corresponding boxes.
50
+ Args:
51
+ images (tensor): images to perform scale jitter. Dimension is
52
+ `num frames` x `channel` x `height` x `width`.
53
+ min_size (int): the minimal size to scale the frames.
54
+ max_size (int): the maximal size to scale the frames.
55
+ boxes (ndarray): optional. Corresponding boxes to images.
56
+ Dimension is `num boxes` x 4.
57
+ inverse_uniform_sampling (bool): if True, sample uniformly in
58
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
59
+ scale. If False, take a uniform sample from [min_scale, max_scale].
60
+ Returns:
61
+ (tensor): the scaled images with dimension of
62
+ `num frames` x `channel` x `new height` x `new width`.
63
+ (ndarray or None): the scaled boxes with dimension of
64
+ `num boxes` x 4.
65
+ """
66
+ if inverse_uniform_sampling:
67
+ size = int(
68
+ round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
69
+ )
70
+ else:
71
+ size = int(round(np.random.uniform(min_size, max_size)))
72
+
73
+ height = images.shape[2]
74
+ width = images.shape[3]
75
+ if (width <= height and width == size) or (
76
+ height <= width and height == size
77
+ ):
78
+ return images, boxes
79
+ new_width = size
80
+ new_height = size
81
+ if width < height:
82
+ new_height = int(math.floor((float(height) / width) * size))
83
+ if boxes is not None:
84
+ boxes = boxes * float(new_height) / height
85
+ else:
86
+ new_width = int(math.floor((float(width) / height) * size))
87
+ if boxes is not None:
88
+ boxes = boxes * float(new_width) / width
89
+
90
+ return (
91
+ torch.nn.functional.interpolate(
92
+ images,
93
+ size=(new_height, new_width),
94
+ mode="bilinear",
95
+ align_corners=False,
96
+ ),
97
+ boxes,
98
+ )
99
+
100
+
101
+ def crop_boxes(boxes, x_offset, y_offset):
102
+ """
103
+ Perform crop on the bounding boxes given the offsets.
104
+ Args:
105
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
106
+ is `num boxes` x 4.
107
+ x_offset (int): cropping offset in the x axis.
108
+ y_offset (int): cropping offset in the y axis.
109
+ Returns:
110
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
111
+ `num boxes` x 4.
112
+ """
113
+ cropped_boxes = boxes.copy()
114
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
115
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
116
+
117
+ return cropped_boxes
118
+
119
+
120
+ def random_crop(images, size, boxes=None):
121
+ """
122
+ Perform random spatial crop on the given images and corresponding boxes.
123
+ Args:
124
+ images (tensor): images to perform random crop. The dimension is
125
+ `num frames` x `channel` x `height` x `width`.
126
+ size (int): the size of height and width to crop on the image.
127
+ boxes (ndarray or None): optional. Corresponding boxes to images.
128
+ Dimension is `num boxes` x 4.
129
+ Returns:
130
+ cropped (tensor): cropped images with dimension of
131
+ `num frames` x `channel` x `size` x `size`.
132
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
133
+ `num boxes` x 4.
134
+ """
135
+ if images.shape[2] == size and images.shape[3] == size:
136
+ return images, boxes
137
+ height = images.shape[2]
138
+ width = images.shape[3]
139
+ y_offset = 0
140
+ if height > size:
141
+ y_offset = int(np.random.randint(0, height - size))
142
+ x_offset = 0
143
+ if width > size:
144
+ x_offset = int(np.random.randint(0, width - size))
145
+ cropped = images[
146
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
147
+ ]
148
+
149
+ cropped_boxes = (
150
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
151
+ )
152
+
153
+ return cropped, cropped_boxes
154
+
155
+
156
+ def horizontal_flip(prob, images, boxes=None):
157
+ """
158
+ Perform horizontal flip on the given images and corresponding boxes.
159
+ Args:
160
+ prob (float): probability to flip the images.
161
+ images (tensor): images to perform horizontal flip, the dimension is
162
+ `num frames` x `channel` x `height` x `width`.
163
+ boxes (ndarray or None): optional. Corresponding boxes to images.
164
+ Dimension is `num boxes` x 4.
165
+ Returns:
166
+ images (tensor): images with dimension of
167
+ `num frames` x `channel` x `height` x `width`.
168
+ flipped_boxes (ndarray or None): the flipped boxes with dimension of
169
+ `num boxes` x 4.
170
+ """
171
+ if boxes is None:
172
+ flipped_boxes = None
173
+ else:
174
+ flipped_boxes = boxes.copy()
175
+
176
+ if np.random.uniform() < prob:
177
+ images = images.flip((-1))
178
+
179
+ if len(images.shape) == 3:
180
+ width = images.shape[2]
181
+ elif len(images.shape) == 4:
182
+ width = images.shape[3]
183
+ else:
184
+ raise NotImplementedError("Dimension does not supported")
185
+ if boxes is not None:
186
+ flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
187
+
188
+ return images, flipped_boxes
189
+
190
+
191
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
192
+ """
193
+ Perform uniform spatial sampling on the images and corresponding boxes.
194
+ Args:
195
+ images (tensor): images to perform uniform crop. The dimension is
196
+ `num frames` x `channel` x `height` x `width`.
197
+ size (int): size of height and width to crop the images.
198
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
199
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
200
+ crop if height is larger than width.
201
+ boxes (ndarray or None): optional. Corresponding boxes to images.
202
+ Dimension is `num boxes` x 4.
203
+ scale_size (int): optional. If not None, resize the images to scale_size before
204
+ performing any crop.
205
+ Returns:
206
+ cropped (tensor): images with dimension of
207
+ `num frames` x `channel` x `size` x `size`.
208
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
209
+ `num boxes` x 4.
210
+ """
211
+ assert spatial_idx in [0, 1, 2]
212
+ ndim = len(images.shape)
213
+ if ndim == 3:
214
+ images = images.unsqueeze(0)
215
+ height = images.shape[2]
216
+ width = images.shape[3]
217
+
218
+ if scale_size is not None:
219
+ if width <= height:
220
+ width, height = scale_size, int(height / width * scale_size)
221
+ else:
222
+ width, height = int(width / height * scale_size), scale_size
223
+ images = torch.nn.functional.interpolate(
224
+ images,
225
+ size=(height, width),
226
+ mode="bilinear",
227
+ align_corners=False,
228
+ )
229
+
230
+ y_offset = int(math.ceil((height - size) / 2))
231
+ x_offset = int(math.ceil((width - size) / 2))
232
+
233
+ if height > width:
234
+ if spatial_idx == 0:
235
+ y_offset = 0
236
+ elif spatial_idx == 2:
237
+ y_offset = height - size
238
+ else:
239
+ if spatial_idx == 0:
240
+ x_offset = 0
241
+ elif spatial_idx == 2:
242
+ x_offset = width - size
243
+ cropped = images[
244
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
245
+ ]
246
+ cropped_boxes = (
247
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
248
+ )
249
+ if ndim == 3:
250
+ cropped = cropped.squeeze(0)
251
+ return cropped, cropped_boxes
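
As a quick illustration, `uniform_crop` can be called three times to obtain the left/center/right (or top/center/bottom) test-time views of a clip; the tensor below is a dummy input:

```python
import torch

frames = torch.rand(16, 3, 256, 320)  # dummy clip: T x C x H x W
views = [uniform_crop(frames, size=224, spatial_idx=i)[0] for i in range(3)]
# each element of `views` has shape 16 x 3 x 224 x 224
```
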
252
+
253
+
254
+ def clip_boxes_to_image(boxes, height, width):
255
+ """
256
+ Clip an array of boxes to an image with the given height and width.
257
+ Args:
258
+ boxes (ndarray): bounding boxes to perform clipping.
259
+ Dimension is `num boxes` x 4.
260
+ height (int): given image height.
261
+ width (int): given image width.
262
+ Returns:
263
+ clipped_boxes (ndarray): the clipped boxes with dimension of
264
+ `num boxes` x 4.
265
+ """
266
+ clipped_boxes = boxes.copy()
267
+ clipped_boxes[:, [0, 2]] = np.minimum(
268
+ width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
269
+ )
270
+ clipped_boxes[:, [1, 3]] = np.minimum(
271
+ height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
272
+ )
273
+ return clipped_boxes
274
+
275
+
276
+ def blend(images1, images2, alpha):
277
+ """
278
+ Blend two images with a given weight alpha.
279
+ Args:
280
+ images1 (tensor): the first images to be blended, the dimension is
281
+ `num frames` x `channel` x `height` x `width`.
282
+ images2 (tensor): the second images to be blended, the dimension is
283
+ `num frames` x `channel` x `height` x `width`.
284
+ alpha (float): the blending weight.
285
+ Returns:
286
+ (tensor): blended images, the dimension is
287
+ `num frames` x `channel` x `height` x `width`.
288
+ """
289
+ return images1 * alpha + images2 * (1 - alpha)
290
+
291
+
292
+ def grayscale(images):
293
+ """
294
+ Get the grayscale for the input images. The channels of images should be
295
+ in order BGR.
296
+ Args:
297
+ images (tensor): the input images for getting grayscale. Dimension is
298
+ `num frames` x `channel` x `height` x `width`.
299
+ Returns:
300
+ img_gray (tensor): grayscale images, the dimension is
301
+ `num frames` x `channel` x `height` x `width`.
302
+ """
303
+ # R -> 0.299, G -> 0.587, B -> 0.114.
304
+ img_gray = torch.tensor(images)
305
+ gray_channel = (
306
+ 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
307
+ )
308
+ img_gray[:, 0] = gray_channel
309
+ img_gray[:, 1] = gray_channel
310
+ img_gray[:, 2] = gray_channel
311
+ return img_gray
312
+
313
+
314
+ def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
315
+ """
316
+ Perform color jittering on the input images. The channels of images
317
+ should be in order BGR.
318
+ Args:
319
+ images (tensor): images to perform color jitter. Dimension is
320
+ `num frames` x `channel` x `height` x `width`.
321
+ img_brightness (float): jitter ratio for brightness.
322
+ img_contrast (float): jitter ratio for contrast.
323
+ img_saturation (float): jitter ratio for saturation.
324
+ Returns:
325
+ images (tensor): the jittered images, the dimension is
326
+ `num frames` x `channel` x `height` x `width`.
327
+ """
328
+
329
+ jitter = []
330
+ if img_brightness != 0:
331
+ jitter.append("brightness")
332
+ if img_contrast != 0:
333
+ jitter.append("contrast")
334
+ if img_saturation != 0:
335
+ jitter.append("saturation")
336
+
337
+ if len(jitter) > 0:
338
+ order = np.random.permutation(np.arange(len(jitter)))
339
+ for idx in range(0, len(jitter)):
340
+ if jitter[order[idx]] == "brightness":
341
+ images = brightness_jitter(img_brightness, images)
342
+ elif jitter[order[idx]] == "contrast":
343
+ images = contrast_jitter(img_contrast, images)
344
+ elif jitter[order[idx]] == "saturation":
345
+ images = saturation_jitter(img_saturation, images)
346
+ return images
347
+
348
+
349
+ def brightness_jitter(var, images):
350
+ """
351
+ Perform brightness jittering on the input images. The channels of images
352
+ should be in order BGR.
353
+ Args:
354
+ var (float): jitter ratio for brightness.
355
+ images (tensor): images to perform color jitter. Dimension is
356
+ `num frames` x `channel` x `height` x `width`.
357
+ Returns:
358
+ images (tensor): the jittered images, the dimension is
359
+ `num frames` x `channel` x `height` x `width`.
360
+ """
361
+ alpha = 1.0 + np.random.uniform(-var, var)
362
+
363
+ img_bright = torch.zeros(images.shape)
364
+ images = blend(images, img_bright, alpha)
365
+ return images
366
+
367
+
368
+ def contrast_jitter(var, images):
369
+ """
370
+ Perform contrast jittering on the input images. The channels of images
371
+ should be in order BGR.
372
+ Args:
373
+ var (float): jitter ratio for contrast.
374
+ images (tensor): images to perform color jitter. Dimension is
375
+ `num frames` x `channel` x `height` x `width`.
376
+ Returns:
377
+ images (tensor): the jittered images, the dimension is
378
+ `num frames` x `channel` x `height` x `width`.
379
+ """
380
+ alpha = 1.0 + np.random.uniform(-var, var)
381
+
382
+ img_gray = grayscale(images)
383
+ img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
384
+ images = blend(images, img_gray, alpha)
385
+ return images
386
+
387
+
388
+ def saturation_jitter(var, images):
389
+ """
390
+ Perform saturation jittering on the input images. The channels of images
391
+ should be in order BGR.
392
+ Args:
393
+ var (float): jitter ratio for saturation.
394
+ images (tensor): images to perform color jitter. Dimension is
395
+ `num frames` x `channel` x `height` x `width`.
396
+ Returns:
397
+ images (tensor): the jittered images, the dimension is
398
+ `num frames` x `channel` x `height` x `width`.
399
+ """
400
+ alpha = 1.0 + np.random.uniform(-var, var)
401
+ img_gray = grayscale(images)
402
+ images = blend(images, img_gray, alpha)
403
+
404
+ return images
405
+
406
+
407
+ def lighting_jitter(images, alphastd, eigval, eigvec):
408
+ """
409
+ Perform AlexNet-style PCA jitter on the given images.
410
+ Args:
411
+ images (tensor): images to perform lighting jitter. Dimension is
412
+ `num frames` x `channel` x `height` x `width`.
413
+ alphastd (float): jitter ratio for PCA jitter.
414
+ eigval (list): eigenvalues for PCA jitter.
415
+ eigvec (list[list]): eigenvectors for PCA jitter.
416
+ Returns:
417
+ out_images (tensor): the jittered images, the dimension is
418
+ `num frames` x `channel` x `height` x `width`.
419
+ """
420
+ if alphastd == 0:
421
+ return images
422
+ # generate alpha1, alpha2, alpha3.
423
+ alpha = np.random.normal(0, alphastd, size=(1, 3))
424
+ eig_vec = np.array(eigvec)
425
+ eig_val = np.reshape(eigval, (1, 3))
426
+ rgb = np.sum(
427
+ eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
428
+ axis=1,
429
+ )
430
+ out_images = torch.zeros_like(images)
431
+ if len(images.shape) == 3:
432
+ # C H W
433
+ channel_dim = 0
434
+ elif len(images.shape) == 4:
435
+ # T C H W
436
+ channel_dim = 1
437
+ else:
438
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
439
+
440
+ for idx in range(images.shape[channel_dim]):
441
+ # C H W
442
+ if len(images.shape) == 3:
443
+ out_images[idx] = images[idx] + rgb[2 - idx]
444
+ # T C H W
445
+ elif len(images.shape) == 4:
446
+ out_images[:, idx] = images[:, idx] + rgb[2 - idx]
447
+ else:
448
+ raise NotImplementedError(
449
+ f"Unsupported dimension {len(images.shape)}"
450
+ )
451
+
452
+ return out_images
453
+
454
+
455
+ def color_normalization(images, mean, stddev):
456
+ """
457
+ Perform color normalization on the given images.
458
+ Args:
459
+ images (tensor): images to perform color normalization. Dimension is
460
+ `num frames` x `channel` x `height` x `width`.
461
+ mean (list): mean values for normalization.
462
+ stddev (list): standard deviations for normalization.
463
+
464
+ Returns:
465
+ out_images (tensor): the normalized images, the dimension is
466
+ `num frames` x `channel` x `height` x `width`.
467
+ """
468
+ if len(images.shape) == 3:
469
+ assert (
470
+ len(mean) == images.shape[0]
471
+ ), "channel mean not computed properly"
472
+ assert (
473
+ len(stddev) == images.shape[0]
474
+ ), "channel stddev not computed properly"
475
+ elif len(images.shape) == 4:
476
+ assert (
477
+ len(mean) == images.shape[1]
478
+ ), "channel mean not computed properly"
479
+ assert (
480
+ len(stddev) == images.shape[1]
481
+ ), "channel stddev not computed properly"
482
+ else:
483
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
484
+
485
+ out_images = torch.zeros_like(images)
486
+ for idx in range(len(mean)):
487
+ # C H W
488
+ if len(images.shape) == 3:
489
+ out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
490
+ elif len(images.shape) == 4:
491
+ out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
492
+ else:
493
+ raise NotImplementedError(
494
+ f"Unsupported dimension {len(images.shape)}"
495
+ )
496
+ return out_images
497
+
498
+
499
+ def _get_param_spatial_crop(
500
+ scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
501
+ ):
502
+ """
503
+ Given scale, ratio, height and width, return sampled coordinates of the videos.
504
+ """
505
+ for _ in range(num_repeat):
506
+ area = height * width
507
+ target_area = random.uniform(*scale) * area
508
+ if log_scale:
509
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
510
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
511
+ else:
512
+ aspect_ratio = random.uniform(*ratio)
513
+
514
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
515
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
516
+
517
+ if np.random.uniform() < 0.5 and switch_hw:
518
+ w, h = h, w
519
+
520
+ if 0 < w <= width and 0 < h <= height:
521
+ i = random.randint(0, height - h)
522
+ j = random.randint(0, width - w)
523
+ return i, j, h, w
524
+
525
+ # Fallback to central crop
526
+ in_ratio = float(width) / float(height)
527
+ if in_ratio < min(ratio):
528
+ w = width
529
+ h = int(round(w / min(ratio)))
530
+ elif in_ratio > max(ratio):
531
+ h = height
532
+ w = int(round(h * max(ratio)))
533
+ else: # whole image
534
+ w = width
535
+ h = height
536
+ i = (height - h) // 2
537
+ j = (width - w) // 2
538
+ return i, j, h, w
539
+
540
+
541
+ def random_resized_crop(
542
+ images,
543
+ target_height,
544
+ target_width,
545
+ scale=(0.8, 1.0),
546
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
547
+ ):
548
+ """
549
+ Crop the given images to random size and aspect ratio. A crop of random
550
+ size (default: 0.8 to 1.0) of the original size and a random aspect
551
+ ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
552
+ crop is finally resized to given size. This is popularly used to train the
553
+ Inception networks.
554
+
555
+ Args:
556
+ images: Images to perform resizing and cropping.
557
+ target_height: Desired height after cropping.
558
+ target_width: Desired width after cropping.
559
+ scale: Scale range of Inception-style area based random resizing.
560
+ ratio: Aspect ratio range of Inception-style area based random resizing.
561
+ """
562
+
563
+ height = images.shape[2]
564
+ width = images.shape[3]
565
+
566
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
567
+ cropped = images[:, :, i : i + h, j : j + w]
568
+ return torch.nn.functional.interpolate(
569
+ cropped,
570
+ size=(target_height, target_width),
571
+ mode="bilinear",
572
+ align_corners=False,
573
+ )
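
A small usage sketch on a dummy clip tensor (all values chosen only for illustration):

```python
import torch

clip = torch.rand(16, 3, 240, 320)  # T x C x H x W
out = random_resized_crop(clip, target_height=224, target_width=224,
                          scale=(0.5, 1.0))
print(out.shape)  # torch.Size([16, 3, 224, 224])
```
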
574
+
575
+
576
+ def random_resized_crop_with_shift(
577
+ images,
578
+ target_height,
579
+ target_width,
580
+ scale=(0.8, 1.0),
581
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
582
+ ):
583
+ """
584
+ This is similar to random_resized_crop. However, it samples two different
585
+ boxes (for cropping) for the first and last frame. It then linearly
586
+ interpolates the two boxes for other frames.
587
+
588
+ Args:
589
+ images: Images to perform resizing and cropping.
590
+ target_height: Desired height after cropping.
591
+ target_width: Desired width after cropping.
592
+ scale: Scale range of Inception-style area based random resizing.
593
+ ratio: Aspect ratio range of Inception-style area based random resizing.
594
+ """
595
+ t = images.shape[1]
596
+ height = images.shape[2]
597
+ width = images.shape[3]
598
+
599
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
600
+ i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
601
+ i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
602
+ j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
603
+ h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
604
+ w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
605
+ out = torch.zeros((3, t, target_height, target_width))
606
+ for ind in range(t):
607
+ out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
608
+ images[
609
+ :,
610
+ ind : ind + 1,
611
+ i_s[ind] : i_s[ind] + h_s[ind],
612
+ j_s[ind] : j_s[ind] + w_s[ind],
613
+ ],
614
+ size=(target_height, target_width),
615
+ mode="bilinear",
616
+ align_corners=False,
617
+ )
618
+ return out
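
And the shifted variant, which interpolates two crop boxes across time; note that it indexes dimension 1 as time and returns a C x T x H x W tensor (dummy input below):

```python
import torch

clip = torch.rand(3, 16, 240, 320)   # C x T x H x W
out = random_resized_crop_with_shift(clip, target_height=224, target_width=224)
print(out.shape)                     # torch.Size([3, 16, 224, 224])
```
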
619
+
620
+
621
+ def create_random_augment(
622
+ input_size,
623
+ auto_augment=None,
624
+ interpolation="bilinear",
625
+ ):
626
+ """
627
+ Get video randaug transform.
628
+
629
+ Args:
630
+ input_size: The size of the input video in tuple.
631
+ auto_augment: Parameters for randaug. An example:
632
+ "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
633
+ of operations to apply).
634
+ interpolation: Interpolation method.
635
+ """
636
+ if isinstance(input_size, tuple):
637
+ img_size = input_size[-2:]
638
+ else:
639
+ img_size = input_size
640
+
641
+ if auto_augment:
642
+ assert isinstance(auto_augment, str)
643
+ if isinstance(img_size, tuple):
644
+ img_size_min = min(img_size)
645
+ else:
646
+ img_size_min = img_size
647
+ aa_params = {"translate_const": int(img_size_min * 0.45)}
648
+ if interpolation and interpolation != "random":
649
+ aa_params["interpolation"] = _pil_interp(interpolation)
650
+ if auto_augment.startswith("rand"):
651
+ return transforms.Compose(
652
+ [rand_augment_transform(auto_augment, aa_params)]
653
+ )
654
+ raise NotImplementedError
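
A hedged usage sketch, assuming (as in VideoMAE-style pipelines) that the returned transform is applied to a list of PIL frames; the frames below are dummies:

```python
from PIL import Image

randaug = create_random_augment(
    input_size=(224, 224),
    auto_augment="rand-m7-n4-mstd0.5-inc1",
    interpolation="bicubic",
)
frames = [Image.new("RGB", (224, 224)) for _ in range(16)]  # dummy PIL frames
augmented = randaug(frames)
```
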
655
+
656
+
657
+ def random_sized_crop_img(
658
+ im,
659
+ size,
660
+ jitter_scale=(0.08, 1.0),
661
+ jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
662
+ max_iter=10,
663
+ ):
664
+ """
665
+ Performs Inception-style cropping (used for training).
666
+ """
667
+ assert (
668
+ len(im.shape) == 3
669
+ ), "Currently only support image for random_sized_crop"
670
+ h, w = im.shape[1:3]
671
+ i, j, h, w = _get_param_spatial_crop(
672
+ scale=jitter_scale,
673
+ ratio=jitter_aspect,
674
+ height=h,
675
+ width=w,
676
+ num_repeat=max_iter,
677
+ log_scale=False,
678
+ switch_hw=True,
679
+ )
680
+ cropped = im[:, i : i + h, j : j + w]
681
+ return torch.nn.functional.interpolate(
682
+ cropped.unsqueeze(0),
683
+ size=(size, size),
684
+ mode="bilinear",
685
+ align_corners=False,
686
+ ).squeeze(0)
687
+
688
+
689
+ # The following code is modified from the timm lib; we will replace the following
690
+ # contents with a dependency on PyTorchVideo.
691
+ # https://github.com/facebookresearch/pytorchvideo
692
+ class RandomResizedCropAndInterpolation:
693
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
694
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
695
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
696
+ is finally resized to given size.
697
+ This is popularly used to train the Inception networks.
698
+ Args:
699
+ size: expected output size of each edge
700
+ scale: range of size of the origin size cropped
701
+ ratio: range of aspect ratio of the origin aspect ratio cropped
702
+ interpolation: Default: PIL.Image.BILINEAR
703
+ """
704
+
705
+ def __init__(
706
+ self,
707
+ size,
708
+ scale=(0.08, 1.0),
709
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
710
+ interpolation="bilinear",
711
+ ):
712
+ if isinstance(size, tuple):
713
+ self.size = size
714
+ else:
715
+ self.size = (size, size)
716
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
717
+ print("range should be of kind (min, max)")
718
+
719
+ if interpolation == "random":
720
+ self.interpolation = _RANDOM_INTERPOLATION
721
+ else:
722
+ self.interpolation = _pil_interp(interpolation)
723
+ self.scale = scale
724
+ self.ratio = ratio
725
+
726
+ @staticmethod
727
+ def get_params(img, scale, ratio):
728
+ """Get parameters for ``crop`` for a random sized crop.
729
+ Args:
730
+ img (PIL Image): Image to be cropped.
731
+ scale (tuple): range of size of the origin size cropped
732
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
733
+ Returns:
734
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
735
+ sized crop.
736
+ """
737
+ area = img.size[0] * img.size[1]
738
+
739
+ for _ in range(10):
740
+ target_area = random.uniform(*scale) * area
741
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
742
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
743
+
744
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
745
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
746
+
747
+ if w <= img.size[0] and h <= img.size[1]:
748
+ i = random.randint(0, img.size[1] - h)
749
+ j = random.randint(0, img.size[0] - w)
750
+ return i, j, h, w
751
+
752
+ # Fallback to central crop
753
+ in_ratio = img.size[0] / img.size[1]
754
+ if in_ratio < min(ratio):
755
+ w = img.size[0]
756
+ h = int(round(w / min(ratio)))
757
+ elif in_ratio > max(ratio):
758
+ h = img.size[1]
759
+ w = int(round(h * max(ratio)))
760
+ else: # whole image
761
+ w = img.size[0]
762
+ h = img.size[1]
763
+ i = (img.size[1] - h) // 2
764
+ j = (img.size[0] - w) // 2
765
+ return i, j, h, w
766
+
767
+ def __call__(self, img):
768
+ """
769
+ Args:
770
+ img (PIL Image): Image to be cropped and resized.
771
+ Returns:
772
+ PIL Image: Randomly cropped and resized image.
773
+ """
774
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
775
+ if isinstance(self.interpolation, (tuple, list)):
776
+ interpolation = random.choice(self.interpolation)
777
+ else:
778
+ interpolation = self.interpolation
779
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
780
+
781
+ def __repr__(self):
782
+ if isinstance(self.interpolation, (tuple, list)):
783
+ interpolate_str = " ".join(
784
+ [_pil_interpolation_to_str[x] for x in self.interpolation]
785
+ )
786
+ else:
787
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
788
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
789
+ format_string += ", scale={0}".format(
790
+ tuple(round(s, 4) for s in self.scale)
791
+ )
792
+ format_string += ", ratio={0}".format(
793
+ tuple(round(r, 4) for r in self.ratio)
794
+ )
795
+ format_string += ", interpolation={0})".format(interpolate_str)
796
+ return format_string
797
+
798
+
799
+ def transforms_imagenet_train(
800
+ img_size=224,
801
+ scale=None,
802
+ ratio=None,
803
+ hflip=0.5,
804
+ vflip=0.0,
805
+ color_jitter=0.4,
806
+ auto_augment=None,
807
+ interpolation="random",
808
+ use_prefetcher=False,
809
+ mean=(0.485, 0.456, 0.406),
810
+ std=(0.229, 0.224, 0.225),
811
+ re_prob=0.0,
812
+ re_mode="const",
813
+ re_count=1,
814
+ re_num_splits=0,
815
+ separate=False,
816
+ ):
817
+ """
818
+ If separate==True, the transforms are returned as a tuple of 3 separate transforms
819
+ for use in a mixing dataset that passes
820
+ * all data through the first (primary) transform, called the 'clean' data
821
+ * a portion of the data through the secondary transform
822
+ * normalizes and converts the branches above with the third, final transform
823
+ """
824
+ if isinstance(img_size, tuple):
825
+ img_size = img_size[-2:]
826
+ else:
827
+ img_size = img_size
828
+
829
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
830
+ ratio = tuple(
831
+ ratio or (3.0 / 4.0, 4.0 / 3.0)
832
+ ) # default imagenet ratio range
833
+ primary_tfl = [
834
+ RandomResizedCropAndInterpolation(
835
+ img_size, scale=scale, ratio=ratio, interpolation=interpolation
836
+ )
837
+ ]
838
+ if hflip > 0.0:
839
+ primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
840
+ if vflip > 0.0:
841
+ primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
842
+
843
+ secondary_tfl = []
844
+ if auto_augment:
845
+ assert isinstance(auto_augment, str)
846
+ if isinstance(img_size, tuple):
847
+ img_size_min = min(img_size)
848
+ else:
849
+ img_size_min = img_size
850
+ aa_params = dict(
851
+ translate_const=int(img_size_min * 0.45),
852
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
853
+ )
854
+ if interpolation and interpolation != "random":
855
+ aa_params["interpolation"] = _pil_interp(interpolation)
856
+ if auto_augment.startswith("rand"):
857
+ secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
858
+ elif auto_augment.startswith("augmix"):
859
+ raise NotImplementedError("Augmix not implemented")
860
+ else:
861
+ raise NotImplementedError("Auto aug not implemented")
862
+ elif color_jitter is not None:
863
+ # color jitter is enabled when not using AA
864
+ if isinstance(color_jitter, (list, tuple)):
865
+ # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
866
+ # or 4 if also augmenting hue
867
+ assert len(color_jitter) in (3, 4)
868
+ else:
869
+ # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
870
+ color_jitter = (float(color_jitter),) * 3
871
+ secondary_tfl += [transforms.ColorJitter(*color_jitter)]
872
+
873
+ final_tfl = []
874
+ final_tfl += [
875
+ transforms.ToTensor(),
876
+ transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
877
+ ]
878
+ if re_prob > 0.0:
879
+ final_tfl.append(
880
+ RandomErasing(
881
+ re_prob,
882
+ mode=re_mode,
883
+ max_count=re_count,
884
+ num_splits=re_num_splits,
885
+ device="cpu",
886
+ cube=False,
887
+ )
888
+ )
889
+
890
+ if separate:
891
+ return (
892
+ transforms.Compose(primary_tfl),
893
+ transforms.Compose(secondary_tfl),
894
+ transforms.Compose(final_tfl),
895
+ )
896
+ else:
897
+ return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
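
A minimal sketch of the default (non-separate) path on a single dummy PIL image; auto-augment and random erasing are left at their defaults here to keep the example self-contained:

```python
from PIL import Image

train_tf = transforms_imagenet_train(img_size=224, interpolation="bicubic")
img = Image.new("RGB", (320, 240))  # dummy image
tensor = train_tf(img)              # 3 x 224 x 224, normalized tensor
```
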
898
+
899
+ ############################################################################################################
900
+ ############################################################################################################
901
+
902
+ class Compose(object):
903
+ """Composes several transforms
904
+ Args:
905
+ transforms (list of ``Transform`` objects): list of transforms
906
+ to compose
907
+ """
908
+
909
+ def __init__(self, transforms):
910
+ self.transforms = transforms
911
+
912
+ def __call__(self, clip):
913
+ for t in self.transforms:
914
+ clip = t(clip)
915
+ return clip
916
+
917
+
918
+ class RandomHorizontalFlip(object):
919
+ """Horizontally flip the list of given images randomly
920
+ with a probability 0.5
921
+ """
922
+
923
+ def __call__(self, clip):
924
+ """
925
+ Args:
926
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
927
+ in format (h, w, c) in numpy.ndarray
928
+ Returns:
929
+ PIL.Image or numpy.ndarray: Randomly flipped clip
930
+ """
931
+ if random.random() < 0.5:
932
+ if isinstance(clip[0], np.ndarray):
933
+ return [np.fliplr(img) for img in clip]
934
+ elif isinstance(clip[0], PIL.Image.Image):
935
+ return [
936
+ img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip
937
+ ]
938
+ else:
939
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
940
+ ' but got list of {0}'.format(type(clip[0])))
941
+ return clip
942
+
943
+
944
+ class RandomResize(object):
945
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
946
+ The larger the original image is, the longer it takes to
947
+ interpolate
948
+ Args:
949
+ interpolation (str): Can be one of 'nearest', 'bilinear'
950
+ defaults to nearest
951
+ size (tuple): (width, height)
952
+ """
953
+
954
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
955
+ self.ratio = ratio
956
+ self.interpolation = interpolation
957
+
958
+ def __call__(self, clip):
959
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
960
+
961
+ if isinstance(clip[0], np.ndarray):
962
+ im_h, im_w, im_c = clip[0].shape
963
+ elif isinstance(clip[0], PIL.Image.Image):
964
+ im_w, im_h = clip[0].size
965
+
966
+ new_w = int(im_w * scaling_factor)
967
+ new_h = int(im_h * scaling_factor)
968
+ new_size = (new_w, new_h)
969
+ resized = FF.resize_clip(
970
+ clip, new_size, interpolation=self.interpolation)
971
+ return resized
972
+
973
+
974
+ class Resize(object):
975
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
976
+ The larger the original image is, the longer it takes to
977
+ interpolate
978
+ Args:
979
+ interpolation (str): Can be one of 'nearest', 'bilinear'
980
+ defaults to nearest
981
+ size (tuple): (width, height)
982
+ """
983
+
984
+ def __init__(self, size, interpolation='nearest'):
985
+ self.size = size
986
+ self.interpolation = interpolation
987
+
988
+ def __call__(self, clip):
989
+ resized = FF.resize_clip(
990
+ clip, self.size, interpolation=self.interpolation)
991
+ return resized
992
+
993
+
994
+ class RandomCrop(object):
995
+ """Extract random crop at the same location for a list of images
996
+ Args:
997
+ size (sequence or int): Desired output size for the
998
+ crop in format (h, w)
999
+ """
1000
+
1001
+ def __init__(self, size):
1002
+ if isinstance(size, numbers.Number):
1003
+ size = (size, size)
1004
+
1005
+ self.size = size
1006
+
1007
+ def __call__(self, clip):
1008
+ """
1009
+ Args:
1010
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
1011
+ in format (h, w, c) in numpy.ndarray
1012
+ Returns:
1013
+ PIL.Image or numpy.ndarray: Cropped list of images
1014
+ """
1015
+ h, w = self.size
1016
+ if isinstance(clip[0], np.ndarray):
1017
+ im_h, im_w, im_c = clip[0].shape
1018
+ elif isinstance(clip[0], PIL.Image.Image):
1019
+ im_w, im_h = clip[0].size
1020
+ else:
1021
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
1022
+ ' but got list of {0}'.format(type(clip[0])))
1023
+ if w > im_w or h > im_h:
1024
+ error_msg = (
1025
+ 'Initial image size should be larger than '
1026
+ 'cropped size but got cropped sizes : ({w}, {h}) while '
1027
+ 'initial image is ({im_w}, {im_h})'.format(
1028
+ im_w=im_w, im_h=im_h, w=w, h=h))
1029
+ raise ValueError(error_msg)
1030
+
1031
+ x1 = random.randint(0, im_w - w)
1032
+ y1 = random.randint(0, im_h - h)
1033
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
1034
+
1035
+ return cropped
1036
+
1037
+
1038
+ class ThreeCrop(object):
1039
+ """Extract random crop at the same location for a list of images
1040
+ Args:
1041
+ size (sequence or int): Desired output size for the
1042
+ crop in format (h, w)
1043
+ """
1044
+
1045
+ def __init__(self, size):
1046
+ if isinstance(size, numbers.Number):
1047
+ size = (size, size)
1048
+
1049
+ self.size = size
1050
+
1051
+ def __call__(self, clip):
1052
+ """
1053
+ Args:
1054
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
1055
+ in format (h, w, c) in numpy.ndarray
1056
+ Returns:
1057
+ PIL.Image or numpy.ndarray: Cropped list of images
1058
+ """
1059
+ h, w = self.size
1060
+ if isinstance(clip[0], np.ndarray):
1061
+ im_h, im_w, im_c = clip[0].shape
1062
+ elif isinstance(clip[0], PIL.Image.Image):
1063
+ im_w, im_h = clip[0].size
1064
+ else:
1065
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
1066
+ ' but got list of {0}'.format(type(clip[0])))
1067
+ if w != im_w and h != im_h:
1068
+ clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
1069
+ im_h, im_w, im_c = clip[0].shape
1070
+
1071
+ step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
1072
+ cropped = []
1073
+ for i in range(3):
1074
+ if (im_h > self.size[0]):
1075
+ x1 = 0
1076
+ y1 = i * step
1077
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
1078
+ else:
1079
+ x1 = i * step
1080
+ y1 = 0
1081
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
1082
+ return cropped
1083
+
1084
+
1085
+ class RandomRotation(object):
1086
+ """Rotate entire clip randomly by a random angle within
1087
+ given bounds
1088
+ Args:
1089
+ degrees (sequence or int): Range of degrees to select from
1090
+ If degrees is a number instead of sequence like (min, max),
1091
+ the range of degrees, will be (-degrees, +degrees).
1092
+ """
1093
+
1094
+ def __init__(self, degrees):
1095
+ if isinstance(degrees, numbers.Number):
1096
+ if degrees < 0:
1097
+ raise ValueError('If degrees is a single number,'
1098
+ ' must be positive')
1099
+ degrees = (-degrees, degrees)
1100
+ else:
1101
+ if len(degrees) != 2:
1102
+ raise ValueError('If degrees is a sequence,'
1103
+ ' it must be of len 2.')
1104
+
1105
+ self.degrees = degrees
1106
+
1107
+ def __call__(self, clip):
1108
+ """
1109
+ Args:
1110
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
1111
+ in format (h, w, c) in numpy.ndarray
1112
+ Returns:
1113
+ PIL.Image or numpy.ndarray: Cropped list of images
1114
+ """
1115
+ import skimage
1116
+ angle = random.uniform(self.degrees[0], self.degrees[1])
1117
+ if isinstance(clip[0], np.ndarray):
1118
+ rotated = [skimage.transform.rotate(img, angle) for img in clip]
1119
+ elif isinstance(clip[0], PIL.Image.Image):
1120
+ rotated = [img.rotate(angle) for img in clip]
1121
+ else:
1122
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
1123
+ ' but got list of {0}'.format(type(clip[0])))
1124
+
1125
+ return rotated
1126
+
1127
+
1128
+ class CenterCrop(object):
1129
+ """Extract center crop at the same location for a list of images
1130
+ Args:
1131
+ size (sequence or int): Desired output size for the
1132
+ crop in format (h, w)
1133
+ """
1134
+
1135
+ def __init__(self, size):
1136
+ if isinstance(size, numbers.Number):
1137
+ size = (size, size)
1138
+
1139
+ self.size = size
1140
+
1141
+ def __call__(self, clip):
1142
+ """
1143
+ Args:
1144
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
1145
+ in format (h, w, c) in numpy.ndarray
1146
+ Returns:
1147
+ PIL.Image or numpy.ndarray: Cropped list of images
1148
+ """
1149
+ h, w = self.size
1150
+ if isinstance(clip[0], np.ndarray):
1151
+ im_h, im_w, im_c = clip[0].shape
1152
+ elif isinstance(clip[0], PIL.Image.Image):
1153
+ im_w, im_h = clip[0].size
1154
+ else:
1155
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
1156
+ ' but got list of {0}'.format(type(clip[0])))
1157
+ if w > im_w or h > im_h:
1158
+ error_msg = (
1159
+ 'Initial image size should be larger than '
1160
+ 'cropped size but got cropped sizes : ({w}, {h}) while '
1161
+ 'initial image is ({im_w}, {im_h})'.format(
1162
+ im_w=im_w, im_h=im_h, w=w, h=h))
1163
+ raise ValueError(error_msg)
1164
+
1165
+ x1 = int(round((im_w - w) / 2.))
1166
+ y1 = int(round((im_h - h) / 2.))
1167
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
1168
+
1169
+ return cropped
1170
+
1171
+
1172
+ class ColorJitter(object):
1173
+ """Randomly change the brightness, contrast and saturation and hue of the clip
1174
+ Args:
1175
+ brightness (float): How much to jitter brightness. brightness_factor
1176
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
1177
+ contrast (float): How much to jitter contrast. contrast_factor
1178
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
1179
+ saturation (float): How much to jitter saturation. saturation_factor
1180
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
1181
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
1182
+ [-hue, hue]. Should be >=0 and <= 0.5.
1183
+ """
1184
+
1185
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
1186
+ self.brightness = brightness
1187
+ self.contrast = contrast
1188
+ self.saturation = saturation
1189
+ self.hue = hue
1190
+
1191
+ def get_params(self, brightness, contrast, saturation, hue):
1192
+ if brightness > 0:
1193
+ brightness_factor = random.uniform(
1194
+ max(0, 1 - brightness), 1 + brightness)
1195
+ else:
1196
+ brightness_factor = None
1197
+
1198
+ if contrast > 0:
1199
+ contrast_factor = random.uniform(
1200
+ max(0, 1 - contrast), 1 + contrast)
1201
+ else:
1202
+ contrast_factor = None
1203
+
1204
+ if saturation > 0:
1205
+ saturation_factor = random.uniform(
1206
+ max(0, 1 - saturation), 1 + saturation)
1207
+ else:
1208
+ saturation_factor = None
1209
+
1210
+ if hue > 0:
1211
+ hue_factor = random.uniform(-hue, hue)
1212
+ else:
1213
+ hue_factor = None
1214
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
1215
+
1216
+ def __call__(self, clip):
1217
+ """
1218
+ Args:
1219
+ clip (list): list of PIL.Image
1220
+ Returns:
1221
+ list PIL.Image : list of transformed PIL.Image
1222
+ """
1223
+ if isinstance(clip[0], np.ndarray):
1224
+ raise TypeError(
1225
+ 'Color jitter not yet implemented for numpy arrays')
1226
+ elif isinstance(clip[0], PIL.Image.Image):
1227
+ brightness, contrast, saturation, hue = self.get_params(
1228
+ self.brightness, self.contrast, self.saturation, self.hue)
1229
+
1230
+ # Create img transform function sequence
1231
+ img_transforms = []
1232
+ if brightness is not None:
1233
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
1234
+ if saturation is not None:
1235
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
1236
+ if hue is not None:
1237
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
1238
+ if contrast is not None:
1239
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
1240
+ random.shuffle(img_transforms)
1241
+
1242
+ # Apply to all images
1243
+ jittered_clip = []
1244
+ for img in clip:
1245
+ for func in img_transforms:
1246
+ jittered_img = func(img)
1247
+ jittered_clip.append(jittered_img)
1248
+
1249
+ else:
1250
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
1251
+ ' but got list of {0}'.format(type(clip[0])))
1252
+ return jittered_clip
1253
+
1254
+
1255
+ class Normalize(object):
1256
+ """Normalize a clip with mean and standard deviation.
1257
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
1258
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
1259
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
1260
+ .. note::
1261
+ This transform acts out of place, i.e., it does not mutate the input tensor.
1262
+ Args:
1263
+ mean (sequence): Sequence of means for each channel.
1264
+ std (sequence): Sequence of standard deviations for each channel.
1265
+ """
1266
+
1267
+ def __init__(self, mean, std):
1268
+ self.mean = mean
1269
+ self.std = std
1270
+
1271
+ def __call__(self, clip):
1272
+ """
1273
+ Args:
1274
+ clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
1275
+ Returns:
1276
+ Tensor: Normalized Tensor clip.
1277
+ """
1278
+ return FF.normalize(clip, self.mean, self.std)
1279
+
1280
+ def __repr__(self):
1281
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
volume_transforms.py ADDED
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+
5
+
6
+ def convert_img(img):
7
+ """Converts (H, W, C) numpy.ndarray to (C, W, H) format
8
+ """
9
+ if len(img.shape) == 3:
10
+ img = img.transpose(2, 0, 1)
11
+ if len(img.shape) == 2:
12
+ img = np.expand_dims(img, 0)
13
+ return img
14
+
15
+
16
+ class ClipToTensor(object):
17
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19
+ """
20
+
21
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
22
+ self.channel_nb = channel_nb
23
+ self.div_255 = div_255
24
+ self.numpy = numpy
25
+
26
+ def __call__(self, clip):
27
+ """
28
+ Args: clip (list of numpy.ndarray): clip (list of images)
29
+ to be converted to tensor.
30
+ """
31
+ # Retrieve shape
32
+ if isinstance(clip[0], np.ndarray):
33
+ h, w, ch = clip[0].shape
34
+ assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
35
+ ch)
36
+ elif isinstance(clip[0], Image.Image):
37
+ w, h = clip[0].size
38
+ else:
39
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
40
+ but got list of {0}'.format(type(clip[0])))
41
+
42
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43
+
44
+ # Convert
45
+ for img_idx, img in enumerate(clip):
46
+ if isinstance(img, np.ndarray):
47
+ pass
48
+ elif isinstance(img, Image.Image):
49
+ img = np.array(img, copy=False)
50
+ else:
51
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
52
+ but got list of {0}'.format(type(clip[0])))
53
+ img = convert_img(img)
54
+ np_clip[:, img_idx, :, :] = img
55
+ if self.numpy:
56
+ if self.div_255:
57
+ np_clip = np_clip / 255.0
58
+ return np_clip
59
+
60
+ else:
61
+ tensor_clip = torch.from_numpy(np_clip)
62
+
63
+ if not isinstance(tensor_clip, torch.FloatTensor):
64
+ tensor_clip = tensor_clip.float()
65
+ if self.div_255:
66
+ tensor_clip = torch.div(tensor_clip, 255)
67
+ return tensor_clip
68
+
69
+
70
+ # Note: this normalizes data to [-1, 1]
71
+ class ClipToTensor_K(object):
72
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
74
+ """
75
+
76
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
77
+ self.channel_nb = channel_nb
78
+ self.div_255 = div_255
79
+ self.numpy = numpy
80
+
81
+ def __call__(self, clip):
82
+ """
83
+ Args: clip (list of numpy.ndarray): clip (list of images)
84
+ to be converted to tensor.
85
+ """
86
+ # Retrieve shape
87
+ if isinstance(clip[0], np.ndarray):
88
+ h, w, ch = clip[0].shape
89
+ assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
90
+ ch)
91
+ elif isinstance(clip[0], Image.Image):
92
+ w, h = clip[0].size
93
+ else:
94
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
95
+ but got list of {0}'.format(type(clip[0])))
96
+
97
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98
+
99
+ # Convert
100
+ for img_idx, img in enumerate(clip):
101
+ if isinstance(img, np.ndarray):
102
+ pass
103
+ elif isinstance(img, Image.Image):
104
+ img = np.array(img, copy=False)
105
+ else:
106
+ raise TypeError('Expected numpy.ndarray or PIL.Image\
107
+ but got list of {0}'.format(type(clip[0])))
108
+ img = convert_img(img)
109
+ np_clip[:, img_idx, :, :] = img
110
+ if self.numpy:
111
+ if self.div_255:
112
+ np_clip = (np_clip - 127.5) / 127.5
113
+ return np_clip
114
+
115
+ else:
116
+ tensor_clip = torch.from_numpy(np_clip)
117
+
118
+ if not isinstance(tensor_clip, torch.FloatTensor):
119
+ tensor_clip = tensor_clip.float()
120
+ if self.div_255:
121
+ tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122
+ return tensor_clip
123
+
124
+
125
+ class ToTensor(object):
126
+ """Converts numpy array to tensor
127
+ """
128
+
129
+ def __call__(self, array):
130
+ tensor = torch.from_numpy(array)
131
+ return tensor
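
Finally, a sketch of how the clip-level transforms above are typically composed for evaluation, assuming decoded frames arrive as H x W x C uint8 arrays and that the repo's `functional.resize_clip` handles numpy inputs (as in VideoMAE-style pipelines); the frame data is a dummy:

```python
import numpy as np

clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)
        for _ in range(16)]                    # dummy decoded frames

transform = Compose([                          # video_transforms.Compose
    Resize(256, interpolation="bilinear"),     # scale the short side to 256
    CenterCrop(224),                           # video_transforms.CenterCrop
    ClipToTensor(),                            # volume_transforms: C x T x H x W in [0, 1]
])
video = transform(clip)                        # torch.FloatTensor, 3 x 16 x 224 x 224
```
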