rezasalatin committed commit 0e4f45d
Parent(s): b4cc03f
Add all files and directories

Files changed:
- Training_Station/Duck_Rectified/labels.txt +3 -0
- image_module/__init__.py +0 -0
- image_module/dataset_water.py +161 -0
- image_module/transforms.py +151 -0
- myutils/__init__.py +2 -0
- myutils/data.py +149 -0
- myutils/system.py +103 -0
- records/link_efficientb4_model.pth +3 -0
- video_module/__init__.py +0 -0
- video_module/dataset/Water_DS.py +95 -0
- video_module/dataset/__init__.py +1 -0
- video_module/dataset/transforms.py +468 -0
- video_module/model/AFB_URR.py +319 -0
- video_module/model/FeatureBank.py +149 -0
- video_module/model/__init__.py +2 -0
Training_Station/Duck_Rectified/labels.txt
ADDED
@@ -0,0 +1,3 @@
__ignore__
_background_
water
image_module/__init__.py
ADDED
File without changes
image_module/dataset_water.py
ADDED
@@ -0,0 +1,161 @@
import os
import numpy as np
from glob import glob
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils import data

from . import transforms as my_tf
from myutils import load_image_in_PIL as load_img


def load_image_in_PIL(path, mode='RGB'):
    img = Image.open(path)
    img.load()  # Very important for loading large image
    return img.convert(mode)


class WaterDataset(data.Dataset):

    def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None):

        super(WaterDataset, self).__init__()

        self.mode = mode
        self.input_size = input_size
        self.test_case = test_case
        self.img_list = []
        self.label_list = []
        self.verbose_flag = False
        self.online_augmentation_per_epoch = 640
        self.eval_size = eval_size

        if mode == 'train_offline':
            with open(os.path.join(dataset_path, 'train_imgs.txt')) as f:
                water_subdirs = f.readlines()
            water_subdirs = [x.strip() for x in water_subdirs]

            print('Initialize offline training dataset:')

            for sub_folder in water_subdirs:
                label_list = glob(os.path.join(dataset_path, 'Annotations/', sub_folder, '*.png'))
                label_list.sort(key=lambda x: (len(x), x))
                self.label_list += label_list

                name_list = [os.path.basename(x)[:-4] for x in label_list]

                img_list = glob(os.path.join(dataset_path, 'JPEGImages/', sub_folder, '*.jpg'))
                img_list.sort(key=lambda x: (len(x), x))
                img_list_valid = []
                for img_path in img_list:
                    if os.path.basename(img_path)[:-4] in name_list:
                        img_list_valid.append(img_path)

                self.img_list += img_list_valid

                print('Add', sub_folder, len(img_list_valid), 'files.')

        elif mode == 'eval':
            if test_case is None:
                raise ValueError('test_case cannot be None.')

            img_path = os.path.join(dataset_path, 'JPEGImages/', test_case)
            img_list = os.listdir(img_path)
            img_list.sort(key=lambda x: (len(x), x))
            self.img_list = [os.path.join(img_path, name) for name in img_list]

            first_frame_label_path = os.path.join(dataset_path, 'Annotations/', test_case, img_list[0])

            # Detect label image format: png or jpg
            first_frame_label_path = first_frame_label_path[:-3]
            if os.path.exists(first_frame_label_path + 'png'):
                first_frame_label_path += 'png'
            else:
                first_frame_label_path += 'jpg'

            if not os.path.exists(first_frame_label_path):
                label_list = glob(os.path.join(dataset_path, 'Annotations/', test_case, '*.png'))
                label_list.sort(key=lambda x: (x, len(x)))
                first_frame_label_path = label_list[0]

            self.first_frame = load_image_in_PIL(self.img_list[0], 'RGB')
            self.img_list.pop(0)

            self.first_frame_label = load_image_in_PIL(first_frame_label_path, 'P')

            if self.eval_size:
                self.origin_size = self.first_frame.size
                self.first_frame = self.first_frame.resize(self.eval_size, Image.ANTIALIAS)
                self.first_frame_label = self.first_frame_label.resize(self.eval_size, Image.ANTIALIAS)

        else:
            raise ValueError('Mode %s is not supported; expected one of [train_offline, train_online, eval].' % mode)

    def __len__(self):
        if self.mode == 'train_online':
            return self.online_augmentation_per_epoch
        else:
            return len(self.img_list)

    def get_first_frame(self):
        img_tf = TF.to_tensor(self.first_frame)
        img_tf = my_tf.imagenet_normalization(img_tf)
        return img_tf

    def get_first_frame_label(self):
        return TF.to_tensor(self.first_frame_label)

    def __getitem__(self, index):
        raise NotImplementedError


class WaterDataset_RGB(WaterDataset):
    def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None):
        super(WaterDataset_RGB, self).__init__(mode, dataset_path, input_size, test_case, eval_size)

    def __getitem__(self, index):
        if self.mode == 'train_offline' or self.mode == 'val_offline' or self.mode == 'test_offline':
            img = load_img(self.img_list[index], 'RGB')
            label = load_img(self.label_list[index], 'P')
            return self.apply_transforms(img, label)
        elif self.mode == 'train_online':
            return self.apply_transforms(self.first_frame, self.first_frame_label)
        elif self.mode == 'eval':
            img = load_img(self.img_list[index], 'RGB')
            if self.eval_size:
                img = img.resize(self.eval_size, Image.ANTIALIAS)
            return self.apply_transforms(img)
        else:
            raise Exception("Error: Invalid dataset mode!")

    def resize_to_origin(self, img):
        return img.resize(self.origin_size)

    def apply_transforms(self, img, label=None):
        if self.mode == 'train_offline' or self.mode == 'train_online':
            img = my_tf.random_adjust_color(img, self.verbose_flag)
            img, label = my_tf.random_affine_transformation(img, None, label, self.verbose_flag)
            img, label = my_tf.random_resized_crop(img, None, label, self.input_size, self.verbose_flag)
        elif self.mode == 'test_offline' or self.mode == 'val_offline':
            img = TF.resize(img, self.input_size)
            label = TF.resize(label, self.input_size)
        elif self.mode == 'eval':
            pass

        img_orig = TF.to_tensor(img)
        img_norm = my_tf.imagenet_normalization(img_orig)

        if self.mode == 'train_offline' or self.mode == 'train_online':
            # label = TF.to_tensor(label)
            label = np.expand_dims(np.array(label, np.float32), axis=0)
            return img_norm, label
        elif self.mode == 'val_offline':
            label = np.expand_dims(np.array(label, np.float32), axis=0)
            return img_norm, label
        elif self.mode == 'test_offline':
            label = np.expand_dims(np.array(label, np.float32), axis=0)
            return img_norm, label, img_orig
        else:
            return None
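Not part of the commit: a minimal usage sketch of the offline-training dataset above. The dataset root, crop size, and loader settings are assumptions for illustration only.

from torch.utils.data import DataLoader
from image_module.dataset_water import WaterDataset_RGB

# Hypothetical root containing train_imgs.txt, JPEGImages/<clip>/*.jpg and Annotations/<clip>/*.png
dataset = WaterDataset_RGB(mode='train_offline',
                           dataset_path='/path/to/water_dataset',
                           input_size=(400, 400))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for img_norm, label in loader:
    # img_norm: (4, 3, 400, 400) ImageNet-normalized images; label: (4, 1, 400, 400) float masks
    break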
image_module/transforms.py
ADDED
@@ -0,0 +1,151 @@
import random
import math
import numpy as np
from PIL import Image, ImageFilter
from scipy.ndimage import binary_erosion, binary_dilation
import torchvision.transforms.functional as TF
from torchvision.transforms import RandomResizedCrop

random_thres = 0.8


def random_adjust_color(img, verbose=False):

    if random.random() < random_thres:
        brightness_factor = random.uniform(0.1, 1.2)
        img = TF.adjust_brightness(img, brightness_factor)
        if verbose:
            print('Brightness:', brightness_factor)

    if random.random() < random_thres:
        contrast_factor = random.uniform(0.2, 1.8)
        img = TF.adjust_contrast(img, contrast_factor)
        if verbose:
            print('Contrast:', contrast_factor)

    if random.random() < random_thres:
        # hue_factor = random.uniform(-0.1, 0.1)
        hue_factor = 0.1
        img = TF.adjust_hue(img, hue_factor)
        if verbose:
            print('Hue:', hue_factor)

    return img


def random_affine_transformation(img, mask, label, verbose=False):

    if random.random() < random_thres:
        degrees = random.uniform(-20, 20)
        translate_h = random.uniform(-0.2, 0.2)
        translate_v = random.uniform(-0.2, 0.2)
        scale = random.uniform(0.7, 1.3)
        shear = random.uniform(-20, 20)
        resample = TF.InterpolationMode.BICUBIC

        img = TF.affine(img, degrees, (translate_h, translate_v), scale, shear, resample)
        if mask:
            mask = TF.affine(mask, degrees, (translate_h, translate_v), scale, shear, resample)
        label = TF.affine(label, degrees, (translate_h, translate_v), scale, shear, resample)

        if verbose:
            print('Affine degrees: %.1f, T_h: %.1f, T_v: %.1f, Scale: %.1f, Shear: %.1f' %
                  (degrees, translate_h, translate_v, scale, shear))

    if random.random() < 0.5:

        img = TF.hflip(img)
        if mask:
            mask = TF.hflip(mask)
        label = TF.hflip(label)

        if verbose:
            print('Horizontal flip')

    if mask:
        return img, mask, label
    else:
        return img, label


def random_mask_perturbation(mask, verbose=False):

    degrees = random.uniform(-10, 10)
    translate_h = random.uniform(-0.1, 0.1)
    translate_v = random.uniform(-0.1, 0.1)
    scale = random.uniform(0.8, 1.2)
    shear = random.uniform(-10, 10)
    resample = TF.InterpolationMode.BICUBIC

    mask = TF.affine(mask, degrees, (translate_h, translate_v), scale, shear, resample)

    if verbose:
        print('Mask perturbation degrees: %.1f, T_h: %.1f, T_v: %.1f, Scale: %.1f, Shear: %.1f' %
              (degrees, translate_h, translate_v, scale, shear))

    morphologic_times = int(random.random() * 10)
    morphologic_thres = random.random()
    filter_size = 7
    for i in range(morphologic_times):
        if random.random() < morphologic_thres:
            mask = mask.filter(ImageFilter.MinFilter(filter_size))
            if verbose:
                print(i, 'erosion')
        else:
            mask = mask.filter(ImageFilter.MaxFilter(filter_size))
            if verbose:
                print(i, 'dilation')

    mask = mask.convert('1')

    return mask


def random_resized_crop(img, mask, label, size, verbose=False):

    scale = (0.08, 1.0)
    ratio = (0.75, 1.33333333)

    sample_flag = False

    for attempt in range(10):
        area = img.size[0] * img.size[1]
        target_area = random.uniform(*scale) * area
        aspect_ratio = random.uniform(*ratio)

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        if random.random() < 0.5:
            w, h = h, w

        if w <= img.size[0] and h <= img.size[1]:
            y = random.randint(0, img.size[1] - h)
            x = random.randint(0, img.size[0] - w)
            sample_flag = True
            break

    # Fallback
    if not sample_flag:
        w = min(img.size[0], img.size[1])
        y = (img.size[1] - w) // 2
        x = (img.size[0] - w) // 2
        h = w

    img = TF.resized_crop(img, y, x, h, w, size, TF.InterpolationMode.BICUBIC)
    if mask:
        mask = TF.resized_crop(mask, y, x, h, w, size, TF.InterpolationMode.BICUBIC)
    label = TF.resized_crop(label, y, x, h, w, size, TF.InterpolationMode.BICUBIC)

    if verbose:
        print('x: %d, y: %d, w: %d, h: %d' % (x, y, w, h))

    if mask:
        return img, mask, label
    else:
        return img, label


def imagenet_normalization(img_tensor):

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img_tensor = TF.normalize(img_tensor, mean, std)

    return img_tensor
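Not part of the commit: a short sketch showing how these helpers chain on one image/label pair; the file names are purely illustrative.

from PIL import Image
import torchvision.transforms.functional as TF
from image_module import transforms as my_tf

img = Image.open('frame_0001.jpg').convert('RGB')      # illustrative file names
label = Image.open('frame_0001.png').convert('P')

img = my_tf.random_adjust_color(img)                                  # photometric jitter
img, label = my_tf.random_affine_transformation(img, None, label)     # geometric jitter, no extra mask
img, label = my_tf.random_resized_crop(img, None, label, (400, 400))  # crop + resize to 400x400

img_tensor = my_tf.imagenet_normalization(TF.to_tensor(img))          # 3x400x400 normalized tensor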
myutils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .data import *
from .system import *
myutils/data.py
ADDED
@@ -0,0 +1,149 @@
import os
import numpy as np
from PIL import Image
from scipy.ndimage.morphology import binary_dilation
import cv2

from numpy.linalg import norm

import torch
from torch.nn import functional as NF
from torchvision.transforms import functional as TF


color_palette = [0, 0, 0, 0, 0, 128, 0, 128, 0, 128, 0, 0] + [100, 100, 100] * 252


def postprocessing_pred(pred: np.array) -> np.array:

    label_cnt, labels = cv2.connectedComponentsWithAlgorithm(pred, 8, cv2.CV_32S, cv2.CCL_GRANA)
    if label_cnt == 2:
        if labels[0, 0] == pred[0, 0]:
            pred = labels
        else:
            pred = 1 - labels
    else:
        max_cnt, max_label = 0, 0
        for i in range(label_cnt):
            mask = labels == i
            if pred[mask][0] == 0:
                continue
            cnt = len(mask.nonzero()[0])
            if cnt > max_cnt:
                max_cnt = cnt
                max_label = i
        pred = labels == max_label

    return pred.astype(np.uint8)


def calc_uncertainty(score):

    # seg shape: bs, obj_n, h, w
    score_top, _ = score.topk(k=2, dim=1)
    uncertainty = score_top[:, 0] / (score_top[:, 1] + 1e-8)  # bs, h, w
    uncertainty = torch.exp(1 - uncertainty).unsqueeze(1)  # bs, 1, h, w
    return uncertainty


def save_seg_mask(pred, seg_path, palette=color_palette):

    seg_img = Image.fromarray(pred)
    seg_img.putpalette(palette)
    seg_img.save(seg_path)


def add_overlay(img, mask, colors=color_palette, alpha=0.4, cscale=1):

    ids = np.unique(mask)
    img_overlay = img.copy()
    ones_np = np.ones(img.shape) * (1 - alpha)

    colors = np.reshape(colors, (-1, 3))
    colors = np.atleast_2d(colors) * cscale

    for i in ids[1:]:

        canvas = img * alpha + ones_np * np.array(colors[i])[::-1]

        binary_mask = mask == i
        img_overlay[binary_mask] = canvas[binary_mask]

        contour = binary_dilation(binary_mask) ^ binary_mask
        img_overlay[contour, :] = 0

    return img_overlay


def save_overlay(img, mask, overlay_path, colors=[255, 0, 0], alpha=0.4, cscale=1):

    img = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    img_overlay = add_overlay(img, mask, colors, alpha, cscale)
    cv2.imwrite(overlay_path, img_overlay)


def load_image_in_PIL(path, mode='RGB'):
    img = Image.open(path)
    img.load()  # Very important for loading large image
    return img.convert(mode)


def normalize(x):
    return x / norm(x, ord=2, axis=1, keepdims=True)


def dist(p0, p1, axis):
    return norm(p0 - p1, ord=2, axis=axis)


def resize_img(img, out_size):
    h, w = img.shape[:2]

    if h > w:
        w_new = int(out_size * w / h)
        h_new = out_size
    else:
        h_new = int(out_size * h / w)
        w_new = out_size

    img = cv2.resize(img, (w_new, h_new))
    return img


def unify_features(features):
    output_size = features['f0'].shape[-2:]
    feature_tuple = tuple()

    for key, f in features.items():
        if key != 'f0':
            f = NF.interpolate(
                f,
                size=output_size, mode='bilinear', align_corners=False
            )
        feature_tuple += (f,)

    unified_feature = torch.cat(feature_tuple, dim=1)

    return unified_feature


def pad_divide_by(in_list, d, in_size):
    out_list = []
    h, w = in_size
    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    for inp in in_list:
        out_list.append(NF.pad(inp, pad_array))

    return out_list, pad_array
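Not part of the commit: a quick sketch of pad_divide_by, which pads each tensor so that height and width become multiples of d (16 here, matching the encoder stride used elsewhere in this repo).

import torch
from myutils import pad_divide_by

frame = torch.rand(1, 3, 270, 480)
mask = torch.rand(1, 2, 270, 480)
(frame, mask), pad = pad_divide_by([frame, mask], 16, (270, 480))
print(frame.shape, pad)  # torch.Size([1, 3, 272, 480]) (0, 0, 1, 1)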
myutils/system.py
ADDED
@@ -0,0 +1,103 @@
import time
import os
import shutil
import numpy as np
from PIL import Image

import torch


class AvgMeter(object):

    def __init__(self, window=-1):
        self.window = window
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.max = -np.inf

        if self.window > 0:
            self.val_arr = np.zeros(self.window)
            self.arr_idx = 0

    def update(self, val, n=1):

        self.cnt += n
        self.max = max(self.max, val)

        if self.window > 0:
            self.val_arr[self.arr_idx] = val
            self.arr_idx = (self.arr_idx + 1) % self.window
            self.avg = self.val_arr.mean()
        else:
            self.sum += val * n
            self.avg = self.sum / self.cnt


class FrameSecondMeter(object):

    def __init__(self):
        self.st = time.time()
        self.fps = None
        self.ed = None
        self.frame_n = 0

    def add_frame_n(self, frame_n):
        self.frame_n += frame_n

    def end(self):
        self.ed = time.time()
        self.fps = self.frame_n / (self.ed - self.st)


def gct(f='l'):
    '''
    get current time
    :param f: 'l' for log, 'f' for file name
    :return: formatted time
    '''
    if f == 'l':
        return time.strftime('%m/%d %H:%M:%S', time.localtime(time.time()))
    elif f == 'f':
        return time.strftime('%m_%d_%H_%M', time.localtime(time.time()))


def save_scripts(path, scripts_to_save=None):
    if not os.path.exists(os.path.join(path, 'scripts')):
        os.makedirs(os.path.join(path, 'scripts'))

    if scripts_to_save is not None:
        for script in scripts_to_save:
            dst_path = os.path.join(path, 'scripts', script)
            try:
                shutil.copy(script, dst_path)
            except IOError:
                os.makedirs(os.path.dirname(dst_path))
                shutil.copy(script, dst_path)


def count_model_size(model):
    return np.sum(np.prod(v.size()) for name, v in model.named_parameters()) / 1e6


def load_image_in_PIL(path, mode='RGB'):
    img = Image.open(path)
    img.load()  # Very important for loading large image
    return img.convert(mode)


def print_mem(info=None):
    if info:
        print(info, end=' ')
    mem_allocated = round(torch.cuda.memory_allocated() / 1048576)
    mem_cached = round(torch.cuda.memory_cached() / 1048576)
    print(f'Mem allocated: {mem_allocated}MB, Mem cached: {mem_cached}MB')


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
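Not part of the commit: a tiny sketch of AvgMeter. With the default window=-1 it keeps a running mean over every update; with window > 0 it averages only the most recent window values (note the internal buffer starts zero-filled).

from myutils import AvgMeter

loss_meter = AvgMeter()            # window=-1: running mean over all updates
for loss in [0.9, 0.7, 0.6]:
    loss_meter.update(loss)
print(loss_meter.avg, loss_meter.max)  # ~0.7333, 0.9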
records/link_efficientb4_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab073b386a30375e0947246dbfcbab5b23197056d8918841f6b7e2764add7440
size 72294975
video_module/__init__.py
ADDED
File without changes
video_module/dataset/Water_DS.py
ADDED
@@ -0,0 +1,95 @@
import os
import numpy as np

from glob import glob

import torch
from torch.utils import data
import torchvision.transforms as TF

from video_module.dataset import transforms as mytrans
import myutils


class Water_Image_Train_DS(data.Dataset):

    def __init__(self, root, output_size, clip_n, max_obj_n):
        self.root = root
        self.clip_n = clip_n
        self.output_size = output_size
        self.max_obj_n = max_obj_n

        self.img_list = sorted(glob(os.path.join(root, 'JPEGImages', '*.jpg')) + glob(os.path.join(root, 'JPEGImages', '*.png')))
        self.mask_list = sorted(glob(os.path.join(root, 'Annotations', '*.png')))

        assert len(self.img_list) == len(self.mask_list), "The number of images and masks should be the same"

        self.random_horizontal_flip = mytrans.RandomHorizontalFlip(0.3)
        self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        self.random_affine = mytrans.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)
        self.random_resize_crop = mytrans.RandomResizedCrop(output_size, (0.8, 1))
        self.to_tensor = TF.ToTensor()
        self.to_onehot = mytrans.ToOnehot(max_obj_n, shuffle=True)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):

        img_pil = myutils.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = myutils.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = torch.zeros((self.clip_n, 3, self.output_size, self.output_size), dtype=torch.float)
        masks = torch.zeros((self.clip_n, self.max_obj_n, self.output_size, self.output_size), dtype=torch.float)

        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil
            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img = self.color_jitter(img)
                img, mask = self.random_affine(img, mask)

            img, mask = self.random_resize_crop(img, mask)
            mask = np.array(mask, np.uint8)

            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_n = len(obj_list) + 1
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            frames[i] = self.to_tensor(img)
            masks[i] = mask

        info = {
            'name': self.img_list[idx]
        }
        return frames, masks[:, :obj_n], obj_n, info


class Video_DS(data.Dataset):

    def __init__(self, img_list, first_frame, first_mask):
        self.img_list = img_list[1:]
        self.video_len = len(self.img_list)

        first_mask = np.array(first_mask, np.uint8) > 0
        self.obj_n = first_mask.max() + 1

        self.to_tensor = TF.ToTensor()
        self.to_onehot = mytrans.ToOnehot(self.obj_n, shuffle=False)

        mask, _ = self.to_onehot(first_mask)
        self.first_mask = mask[:self.obj_n]
        self.first_frame = self.to_tensor(first_frame)

    def __len__(self):
        return self.video_len

    def __getitem__(self, idx):
        img = myutils.load_image_in_PIL(self.img_list[idx], 'RGB')
        frame = self.to_tensor(img)
        img_name = os.path.basename(self.img_list[idx])[:-4]

        return frame, img_name
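Not part of the commit: a minimal sketch of the image-based pretraining dataset above, assuming a root with flat JPEGImages/ and Annotations/ folders; the path and sizes are placeholders.

from torch.utils.data import DataLoader
from video_module.dataset import Water_Image_Train_DS

train_ds = Water_Image_Train_DS(root='/path/to/water_train',  # hypothetical path
                                output_size=400, clip_n=3, max_obj_n=2)
loader = DataLoader(train_ds, batch_size=1, shuffle=True)

frames, masks, obj_n, info = next(iter(loader))
# frames: (1, 3, 3, 400, 400) pseudo-clip built from one augmented image; masks: (1, 3, obj_n, 400, 400)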
video_module/dataset/__init__.py
ADDED
@@ -0,0 +1 @@
from .Water_DS import *
video_module/dataset/transforms.py
ADDED
@@ -0,0 +1,468 @@
import math
import warnings
import random
import numbers
import numpy as np
from PIL import Image
from collections.abc import Sequence

import torch
import torchvision.transforms.functional as TF

_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


def _get_image_size(img):
    if TF._is_pil_image(img):
        return img.size
    elif isinstance(img, torch.Tensor) and img.dim() > 2:
        return img.shape[-2:][::-1]
    else:
        raise TypeError("Unexpected type {}".format(type(img)))


class RandomHorizontalFlip(object):
    """Horizontal flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            img = TF.hflip(img)
            mask = TF.hflip(mask)
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomAffine(object):
    """Random affine transformation of the image keeping center invariant

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
            outside the transform in the output image.(Pillow>=5.0.0)

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and \
                    (len(shear) == 2 or len(shear) == 4), \
                    "shear should be a list or tuple and it must be of length 2 or 4."
                # X-Axis shear with [min, max]
                if len(shear) == 2:
                    self.shear = [shear[0], shear[1], 0., 0.]
                elif len(shear) == 4:
                    self.shear = [s for s in shear]
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Get parameters for affine transformation

        Returns:
            sequence: params to be passed to the affine transformation
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = random.uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 2:
                shear = [random.uniform(shears[0], shears[1]), 0.]
            elif len(shears) == 4:
                shear = [random.uniform(shears[0], shears[1]),
                         random.uniform(shears[2], shears[3])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img, mask):
        """
            img (PIL Image): Image to be transformed.

        Returns:
            PIL Image: Affine transformed image.
        """
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        img = TF.affine(img, *ret, interpolation=TF.InterpolationMode.BICUBIC, fill=self.fillcolor)
        mask = TF.affine(mask, *ret, interpolation=TF.InterpolationMode.NEAREST, fill=self.fillcolor)
        return img, mask

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        d = dict(self.__dict__)
        d['resample'] = _pil_interpolation_to_str[d['resample']]
        return s.format(name=self.__class__.__name__, **d)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = _get_image_size(img)
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        # if self.padding is not None:
        #     img = TF.pad(img, self.padding, self.fill, self.padding_mode)
        #
        # # pad the width if needed
        # if self.pad_if_needed and img.size[0] < self.size[1]:
        #     img = TF.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
        # # pad the height if needed
        # if self.pad_if_needed and img.size[1] < self.size[0]:
        #     img = TF.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)
        img = TF.crop(img, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = _get_image_size(img)
        area = height * width

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if (in_ratio < min(ratio)):
            w = width
            h = int(round(w / min(ratio)))
        elif (in_ratio > max(ratio)):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img, mask):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        # print(i, j, h, w)
        img = TF.resized_crop(img, i, j, h, w, self.size, TF.InterpolationMode.BICUBIC)
        mask = TF.resized_crop(mask, i, j, h, w, self.size, TF.InterpolationMode.NEAREST)
        return img, mask

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


class ToOnehot(object):
    """To onehot tensor

    Args:
        max_obj_n (float): Maximum number of the objects
    """

    def __init__(self, max_obj_n, shuffle):
        self.max_obj_n = max_obj_n
        self.shuffle = shuffle

    def __call__(self, mask, obj_list=None):
        """
        Args:
            mask (Mask in Numpy): Mask to be converted.

        Returns:
            Tensor: Converted mask in onehot format.
        """

        new_mask = np.zeros((self.max_obj_n, *mask.shape), np.uint8)

        if not obj_list:
            obj_list = list()
            obj_max = mask.max() + 1
            for i in range(1, obj_max):
                tmp = (mask == i).astype(np.uint8)
                if tmp.max() > 0:
                    obj_list.append(i)

            if self.shuffle:
                random.shuffle(obj_list)
            obj_list = obj_list[:self.max_obj_n - 1]

        for i in range(len(obj_list)):
            new_mask[i + 1] = (mask == obj_list[i]).astype(np.uint8)
        new_mask[0] = 1 - np.sum(new_mask, axis=0)

        return torch.from_numpy(new_mask), obj_list

    def __repr__(self):
        return self.__class__.__name__ + '(max_obj_n={})'.format(self.max_obj_n)


class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).
            In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.interpolation = interpolation

    def forward(self, img, mask):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        img = TF.resize(img, self.size, self.interpolation)
        mask = TF.resize(mask, self.size, Image.NEAREST)
        return img, mask

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
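Not part of the commit: a toy example of ToOnehot, which turns an integer label mask into a background-plus-per-object one-hot tensor.

import numpy as np
from video_module.dataset.transforms import ToOnehot

mask = np.array([[0, 1],
                 [0, 1]], dtype=np.uint8)
onehot, obj_list = ToOnehot(max_obj_n=3, shuffle=False)(mask)
print(onehot.shape, obj_list)  # torch.Size([3, 2, 2]) [1]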
video_module/model/AFB_URR.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
6 |
+
|
7 |
+
import myutils
|
8 |
+
|
9 |
+
|
10 |
+
class ResBlock(nn.Module):
|
11 |
+
"""A simple residual block component."""
|
12 |
+
def __init__(self, indim, outdim=None, stride=1):
|
13 |
+
super(ResBlock, self).__init__()
|
14 |
+
outdim = outdim or indim
|
15 |
+
self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
|
16 |
+
self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
|
17 |
+
self.downsample = nn.Conv2d(indim, outdim, kernel_size=1, stride=stride) if indim != outdim or stride != 1 else None
|
18 |
+
|
19 |
+
def forward(self, x):
|
20 |
+
identity = x
|
21 |
+
out = F.relu(self.conv1(x))
|
22 |
+
out = self.conv2(out)
|
23 |
+
if self.downsample:
|
24 |
+
identity = self.downsample(identity)
|
25 |
+
out += identity
|
26 |
+
return F.relu(out)
|
27 |
+
|
28 |
+
|
29 |
+
class EncoderM(nn.Module):
|
30 |
+
def __init__(self, load_imagenet_params):
|
31 |
+
super(EncoderM, self).__init__()
|
32 |
+
self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
33 |
+
self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
34 |
+
|
35 |
+
weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
|
36 |
+
resnet = resnet50(weights=weights)
|
37 |
+
self.conv1 = resnet.conv1
|
38 |
+
self.bn1 = resnet.bn1
|
39 |
+
self.relu = resnet.relu # 1/2, 64
|
40 |
+
self.maxpool = resnet.maxpool
|
41 |
+
|
42 |
+
self.res2 = resnet.layer1 # 1/4, 256
|
43 |
+
self.res3 = resnet.layer2 # 1/8, 512
|
44 |
+
self.res4 = resnet.layer3 # 1/16, 1024
|
45 |
+
|
46 |
+
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
47 |
+
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
48 |
+
|
49 |
+
def forward(self, in_f, in_m, in_o):
|
50 |
+
f = (in_f - self.mean) / self.std
|
51 |
+
|
52 |
+
x = self.conv1(f) + self.conv1_m(in_m) + self.conv1_o(in_o)
|
53 |
+
x = self.bn1(x)
|
54 |
+
r1 = self.relu(x) # 1/2, 64
|
55 |
+
x = self.maxpool(r1) # 1/4, 64
|
56 |
+
r2 = self.res2(x) # 1/4, 256
|
57 |
+
r3 = self.res3(r2) # 1/8, 512
|
58 |
+
r4 = self.res4(r3) # 1/16, 1024
|
59 |
+
|
60 |
+
return r4, r1
|
61 |
+
|
62 |
+
|
63 |
+
class EncoderQ(nn.Module):
|
64 |
+
def __init__(self, load_imagenet_params):
|
65 |
+
super(EncoderQ, self).__init__()
|
66 |
+
weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
|
67 |
+
resnet = resnet50(weights=weights)
|
68 |
+
self.conv1 = resnet.conv1
|
69 |
+
self.bn1 = resnet.bn1
|
70 |
+
self.relu = resnet.relu # 1/2, 64
|
71 |
+
self.maxpool = resnet.maxpool
|
72 |
+
|
73 |
+
self.res2 = resnet.layer1 # 1/4, 256
|
74 |
+
self.res3 = resnet.layer2 # 1/8, 512
|
75 |
+
self.res4 = resnet.layer3 # 1/8, 1024
|
76 |
+
|
77 |
+
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
78 |
+
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
79 |
+
|
80 |
+
def forward(self, in_f):
|
81 |
+
f = (in_f - self.mean) / self.std
|
82 |
+
|
83 |
+
x = self.conv1(f)
|
84 |
+
x = self.bn1(x)
|
85 |
+
r1 = self.relu(x) # 1/2, 64
|
86 |
+
x = self.maxpool(r1) # 1/4, 64
|
87 |
+
r2 = self.res2(x) # 1/4, 256
|
88 |
+
r3 = self.res3(r2) # 1/8, 512
|
89 |
+
r4 = self.res4(r3) # 1/8, 1024
|
90 |
+
|
91 |
+
return r4, r3, r2, r1
|
92 |
+
|
93 |
+
|
94 |
+
class KeyValue(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self, indim, keydim, valdim):
|
97 |
+
super(KeyValue, self).__init__()
|
98 |
+
self.keydim = keydim
|
99 |
+
self.valdim = valdim
|
100 |
+
self.Key = nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
101 |
+
self.Value = nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
key = self.Key(x)
|
105 |
+
key = key.view(*key.shape[:2], -1) # obj_n, key_dim, pixel_n
|
106 |
+
|
107 |
+
val = self.Value(x)
|
108 |
+
val = val.view(*val.shape[:2], -1) # obj_n, key_dim, pixel_n
|
109 |
+
return key, val
|
110 |
+
|
111 |
+
|
112 |
+
class Refine(nn.Module):
|
113 |
+
def __init__(self, inplanes, planes):
|
114 |
+
super(Refine, self).__init__()
|
115 |
+
self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1), stride=1)
|
116 |
+
self.ResFS = ResBlock(planes, planes)
|
117 |
+
self.ResMM = ResBlock(planes, planes)
|
118 |
+
self.scale_factor = 2
|
119 |
+
|
120 |
+
def forward(self, f, pm):
|
121 |
+
s = self.ResFS(self.convFS(f))
|
122 |
+
m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
|
123 |
+
m = self.ResMM(m)
|
124 |
+
|
125 |
+
return m
|
126 |
+
|
127 |
+
|
128 |
+
class Matcher(nn.Module):
|
129 |
+
def __init__(self, thres_valid=1e-3, update_bank=False):
|
130 |
+
super(Matcher, self).__init__()
|
131 |
+
self.thres_valid = thres_valid
|
132 |
+
self.update_bank = update_bank
|
133 |
+
|
134 |
+
def forward(self, feature_bank, q_in, q_out):
|
135 |
+
|
136 |
+
mem_out_list = []
|
137 |
+
|
138 |
+
for i in range(0, feature_bank.obj_n):
|
139 |
+
d_key, bank_n = feature_bank.keys[i].size()
|
140 |
+
|
141 |
+
try:
|
142 |
+
p = torch.matmul(feature_bank.keys[i].transpose(0, 1), q_in) / math.sqrt(d_key) # THW, HW
|
143 |
+
p = F.softmax(p, dim=1) # bs, bank_n, HW
|
144 |
+
mem = torch.matmul(feature_bank.values[i], p) # bs, D_o, HW
|
145 |
+
except RuntimeError as e:
|
146 |
+
device = feature_bank.keys[i].device
|
147 |
+
key_cpu = feature_bank.keys[i].cpu()
|
148 |
+
value_cpu = feature_bank.values[i].cpu()
|
149 |
+
q_in_cpu = q_in.cpu()
|
150 |
+
|
151 |
+
p = torch.matmul(key_cpu.transpose(0, 1), q_in_cpu) / math.sqrt(d_key) # THW, HW
|
152 |
+
p = F.softmax(p, dim=1) # bs, bank_n, HW
|
153 |
+
mem = torch.matmul(value_cpu, p).to(device) # bs, D_o, HW
|
154 |
+
p = p.to(device)
|
155 |
+
print('\tLine 158. GPU out of memory, use CPU', f'p size: {p.shape}')
|
156 |
+
|
157 |
+
mem_out_list.append(torch.cat([mem, q_out], dim=1))
|
158 |
+
|
159 |
+
if self.update_bank:
|
160 |
+
try:
|
161 |
+
ones = torch.ones_like(p)
|
162 |
+
zeros = torch.zeros_like(p)
|
163 |
+
bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]
|
164 |
+
except RuntimeError as e:
|
165 |
+
device = p.device
|
166 |
+
p = p.cpu()
|
167 |
+
ones = torch.ones_like(p)
|
168 |
+
zeros = torch.zeros_like(p)
|
169 |
+
bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)
|
170 |
+
print('\tLine 170. GPU out of memory, use CPU', f'p size: {p.shape}')
|
171 |
+
|
172 |
+
feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)
|
173 |
+
|
174 |
+
mem_out_tensor = torch.stack(mem_out_list, dim=0).transpose(0, 1) # bs, obj_n, dim, pixel_n
|
175 |
+
|
176 |
+
return mem_out_tensor
|
177 |
+
|
178 |
+
|
179 |
+
class Decoder(nn.Module):
    def __init__(self, device):  # mdim_global = 256
        super(Decoder, self).__init__()

        self.device = device
        mdim_global = 256
        mdim_local = 32
        local_size = 7

        # Patch-wise
        self.convFM = nn.Conv2d(1024, mdim_global, kernel_size=3, padding=1, stride=1)
        self.ResMM = ResBlock(mdim_global, mdim_global)
        self.RF3 = Refine(512, mdim_global)  # 1/8 -> 1/8
        self.RF2 = Refine(256, mdim_global)  # 1/8 -> 1/4
        self.pred2 = nn.Conv2d(mdim_global, 2, kernel_size=3, padding=1, stride=1)

        # Local
        self.local_avg = nn.AvgPool2d(local_size, stride=1, padding=local_size // 2)
        self.local_max = nn.MaxPool2d(local_size, stride=1, padding=local_size // 2)
        self.local_convFM = nn.Conv2d(128, mdim_local, kernel_size=3, padding=1, stride=1)
        self.local_ResMM = ResBlock(mdim_local, mdim_local)
        self.local_pred2 = nn.Conv2d(mdim_local, 2, kernel_size=3, padding=1, stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, patch_match, r3, r2, r1=None, feature_shape=None):
        p = self.ResMM(self.convFM(patch_match))
        p = self.RF3(r3, p)  # out: 1/8, 256
        p = self.RF2(r2, p)  # out: 1/4, 256
        p = self.pred2(F.relu(p))

        p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)

        bs, obj_n, h, w = feature_shape
        rough_seg = F.softmax(p, dim=1)[:, 1]
        rough_seg = rough_seg.view(bs, obj_n, h, w)
        rough_seg = F.softmax(rough_seg, dim=1)  # object-level normalization

        # Local refinement
        uncertainty = myutils.calc_uncertainty(rough_seg)
        uncertainty = uncertainty.expand(-1, obj_n, -1, -1).reshape(bs * obj_n, 1, h, w)

        rough_seg = rough_seg.view(bs * obj_n, 1, h, w)  # bs*obj_n, 1, h, w
        r1_weighted = r1 * rough_seg
        r1_local = self.local_avg(r1_weighted)  # bs*obj_n, 64, h, w
        r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8)  # neighborhood reference
        r1_conf = self.local_max(rough_seg)  # bs*obj_n, 1, h, w

        local_match = torch.cat([r1, r1_local], dim=1)
        q = self.local_ResMM(self.local_convFM(local_match))
        q = r1_conf * self.local_pred2(F.relu(q))

        p = p + uncertainty * q
        p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
        p = F.softmax(p, dim=1)[:, 1]  # no, h, w

        return p

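# AFB_URR: top-level model. memorize() encodes a frame together with its object masks and
# inverse masks through EncoderM and turns the r4 feature into per-object key/value lists for
# the feature bank. segment() encodes a query frame with EncoderQ, matches it against the bank
# with the global Matcher, decodes per-object scores, and during training also returns a scalar
# uncertainty estimate of the prediction.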
class AFB_URR(nn.Module):
    def __init__(self, device, update_bank, load_imagenet_params=False):
        super(AFB_URR, self).__init__()

        self.device = device
        self.encoder_m = EncoderM(load_imagenet_params)
        self.encoder_q = EncoderQ(load_imagenet_params)

        self.keyval_r4 = KeyValue(1024, keydim=128, valdim=512)

        self.global_matcher = Matcher(update_bank=update_bank)
        self.decoder = Decoder(device)

    def memorize(self, frame, mask):

        _, K, H, W = mask.shape

        (frame, mask), pad = myutils.pad_divide_by([frame, mask], 16, (frame.size()[2], frame.size()[3]))

        frame = frame.expand(K, -1, -1, -1)  # obj_n, 3, h, w
        mask = mask[0].unsqueeze(1).float()
        mask_ones = torch.ones_like(mask)
        mask_inv = (mask_ones - mask).clamp(0, 1)

        r4, r1 = self.encoder_m(frame, mask, mask_inv)

        k4, v4 = self.keyval_r4(r4)  # num_objects, 128 and 512, H/16, W/16
        k4_list = [k4[i] for i in range(K)]
        v4_list = [v4[i] for i in range(K)]

        return k4_list, v4_list

    def segment(self, frame, fb_global):

        obj_n = fb_global.obj_n

        if not self.training:
            [frame], pad = myutils.pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))

        r4, r3, r2, r1 = self.encoder_q(frame)
        bs, _, global_match_h, global_match_w = r4.shape
        _, _, local_match_h, local_match_w = r1.shape

        k4, v4 = self.keyval_r4(r4)  # 1, dim, H/16, W/16
        res_global = self.global_matcher(fb_global, k4, v4)
        res_global = res_global.reshape(bs * obj_n, v4.shape[1] * 2, global_match_h, global_match_w)

        r3_size = r3.shape
        r2_size = r2.shape
        r3 = r3.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r3_size[1:])
        r2 = r2.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r2_size[1:])

        r1_size = r1.shape
        r1 = r1.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r1_size[1:])
        feature_size = (bs, obj_n, r1_size[2], r1_size[3])
        score = self.decoder(res_global, r3, r2, r1, feature_size)

        # score = score.view(obj_n, bs, *frame.shape[-2:]).permute(1, 0, 2, 3)
        score = score.view(bs, obj_n, *frame.shape[-2:])

        if self.training:
            uncertainty = myutils.calc_uncertainty(F.softmax(score, dim=1))
            uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1])
            uncertainty = uncertainty.mean()
        else:
            uncertainty = None

        score = torch.clamp(score, 1e-7, 1 - 1e-7)
        score = torch.log((score / (1 - score)))

        if not self.training:
            if pad[2] + pad[3] > 0:
                score = score[:, :, pad[2]:-pad[3], :]
            if pad[0] + pad[1] > 0:
                score = score[:, :, :, pad[0]:-pad[1]]

        return score, uncertainty

    def forward(self, x):
        pass
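For orientation, the sketch below shows one plausible way memorize(), segment() and the FeatureBank are wired together at inference time. It is a minimal illustration rather than code from this commit: device, obj_n, frames and first_mask are assumed inputs, and the memory budget and update interval are made-up values.

    import torch

    # Assumed to exist already: device, obj_n, frames (list of 1x3xHxW tensors) and
    # first_mask (a 1 x obj_n x H x W one-hot mask for the first frame).
    model = AFB_URR(device, update_bank=True, load_imagenet_params=False).to(device)
    model.eval()
    fb = FeatureBank(obj_n, memory_budget=300000, device=device)  # budget value is illustrative

    with torch.no_grad():
        k4_list, v4_list = model.memorize(frames[0], first_mask)   # per-object keys/values
        fb.init_bank(k4_list, v4_list, frame_idx=0)

        for t in range(1, len(frames)):
            score, _ = model.segment(frames[t], fb)                # bs x obj_n x H x W logits
            pred = torch.argmax(score, dim=1)                      # hard label map per pixel

            # Every few frames, absorb the current prediction into the adaptive bank.
            if t % 5 == 0:                                         # interval is an assumed choice
                pred_onehot = torch.nn.functional.one_hot(pred, obj_n).permute(0, 3, 1, 2).float()
                prev_k, prev_v = model.memorize(frames[t], pred_onehot)
                fb.update(prev_k, prev_v, frame_idx=t)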
video_module/model/FeatureBank.py
ADDED
@@ -0,0 +1,149 @@
import numpy as np
import torch
import torch.nn.functional as NF

from torch_scatter import scatter_mean


class FeatureBank:

    def __init__(self, obj_n, memory_budget, device, update_rate=0.1, thres_close=0.95):
        self.obj_n = obj_n
        self.update_rate = update_rate
        self.thres_close = thres_close
        self.device = device

        self.info = [None for _ in range(obj_n)]
        self.peak_n = np.zeros(obj_n)
        self.replace_n = np.zeros(obj_n)

        self.class_budget = memory_budget // obj_n
        if obj_n == 2:
            self.class_budget = 0.8 * self.class_budget

        self.keys = None
        self.values = None

    def init_bank(self, keys, values, frame_idx=0):

        self.keys = keys
        self.values = values

        for class_idx in range(self.obj_n):
            _, bank_n = keys[class_idx].shape
            self.info[class_idx] = torch.zeros((bank_n, 2), device=self.device)
            self.info[class_idx][:, 0] = frame_idx
            self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

    def append(self, keys, values, frame_idx=0):

        if self.keys:
            for class_idx in range(self.obj_n):
                self.keys[class_idx] = torch.cat([self.keys[class_idx], keys[class_idx]], dim=1)
                self.values[class_idx] = torch.cat([self.values[class_idx], values[class_idx]], dim=1)

                _, bank_n = keys[class_idx].shape
                new_info = torch.ones((bank_n, 2), device=self.device) * 20  # zeros
                new_info[:, 0] = frame_idx
                self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)
                self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])
        else:
            self.init_bank(keys, values, frame_idx)

    def update(self, prev_key, prev_value, frame_idx, update_rate=-1):

        if update_rate == -1:
            update_rate = self.update_rate

        for class_idx in range(self.obj_n):

            d_key, bank_n = self.keys[class_idx].shape
            d_val, _ = self.values[class_idx].shape

            normed_keys = NF.normalize(self.keys[class_idx], dim=0)
            normed_prev_key = NF.normalize(prev_key[class_idx], dim=0)
            mag_keys = self.keys[class_idx].norm(p=2, dim=0)
            corr = torch.mm(normed_keys.transpose(0, 1), normed_prev_key)  # bank_n, prev_n
            related_bank_idx = corr.argmax(dim=0, keepdim=True)  # 1, HW
            related_bank_corr = torch.gather(corr, 0, related_bank_idx)  # 1, HW

            # greater than threshold, merge them
            selected_idx = (related_bank_corr[0] > self.thres_close).nonzero(as_tuple=False)
            class_related_bank_idx = related_bank_idx[0, selected_idx[:, 0]]  # selected_HW
            unique_related_bank_idx, cnt = class_related_bank_idx.unique(dim=0, return_counts=True)  # selected_HW

            # Update key
            key_bank_update = torch.zeros((d_key, bank_n), dtype=torch.float, device=self.device)  # d_key, THW
            key_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_key, -1)  # d_key, HW
            scatter_mean(normed_prev_key[:, selected_idx[:, 0]], key_bank_idx, dim=1, out=key_bank_update)
            # d_key, selected_HW

            self.keys[class_idx][:, unique_related_bank_idx] = \
                mag_keys[unique_related_bank_idx] * \
                ((1 - update_rate) * normed_keys[:, unique_related_bank_idx] +
                 update_rate * key_bank_update[:, unique_related_bank_idx])

            # Update value
            normed_values = NF.normalize(self.values[class_idx], dim=0)
            normed_prev_value = NF.normalize(prev_value[class_idx], dim=0)
            mag_values = self.values[class_idx].norm(p=2, dim=0)
            val_bank_update = torch.zeros((d_val, bank_n), dtype=torch.float, device=self.device)
            val_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_val, -1)
            scatter_mean(normed_prev_value[:, selected_idx[:, 0]], val_bank_idx, dim=1, out=val_bank_update)

            self.values[class_idx][:, unique_related_bank_idx] = \
                mag_values[unique_related_bank_idx] * \
                ((1 - update_rate) * normed_values[:, unique_related_bank_idx] +
                 update_rate * val_bank_update[:, unique_related_bank_idx])

            # less than the threshold, concat them
            selected_idx = (related_bank_corr[0] <= self.thres_close).nonzero(as_tuple=False)

            if self.class_budget < bank_n + selected_idx.shape[0]:
                self.remove(class_idx, selected_idx.shape[0], frame_idx)

            self.keys[class_idx] = torch.cat([self.keys[class_idx], prev_key[class_idx][:, selected_idx[:, 0]]], dim=1)
            self.values[class_idx] = \
                torch.cat([self.values[class_idx], prev_value[class_idx][:, selected_idx[:, 0]]], dim=1)

            new_info = torch.zeros((selected_idx.shape[0], 2), device=self.device)
            new_info[:, 0] = frame_idx
            self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)

            self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

            self.info[class_idx][:, 1] = torch.clamp(self.info[class_idx][:, 1], 0, 1e5)  # Prevent inf

    def remove(self, class_idx, request_n, frame_idx):

        old_size = self.keys[class_idx].shape[1]

        LFU = frame_idx - self.info[class_idx][:, 0]  # time length
        LFU = self.info[class_idx][:, 1] / LFU
        thres_dynamic = int(LFU.min()) + 1
        iter_cnt = 0

        while True:
            selected_idx = LFU > thres_dynamic
            self.keys[class_idx] = self.keys[class_idx][:, selected_idx]
            self.values[class_idx] = self.values[class_idx][:, selected_idx]
            self.info[class_idx] = self.info[class_idx][selected_idx]
            LFU = LFU[selected_idx]
            iter_cnt += 1

            balance = (self.class_budget - self.keys[class_idx].shape[1]) - request_n
            if balance < 0:
                thres_dynamic = int(LFU.min()) + 1
            else:
                break

        new_size = self.keys[class_idx].shape[1]
        self.replace_n[class_idx] += old_size - new_size

        return balance

    def print_peak_mem(self):

        ur = self.peak_n / self.class_budget
        rr = self.replace_n / self.class_budget
        print(f'Obj num: {self.obj_n}.', f'Budget / obj: {self.class_budget}.', f'UR: {ur}.', f'Replace: {rr}.')
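A small self-contained sketch of the bank's lifecycle, to make the expected tensor layout concrete. The sizes, budget and frame indices are illustrative assumptions, and the random tensors stand in for the per-object key/value lists that AFB_URR.memorize() produces.

    import torch
    from video_module.model import FeatureBank

    device = torch.device('cpu')
    obj_n, d_key, d_val, n = 2, 128, 512, 100   # illustrative sizes

    keys = [torch.randn(d_key, n, device=device) for _ in range(obj_n)]     # per-object d_key x bank_n
    values = [torch.randn(d_val, n, device=device) for _ in range(obj_n)]   # per-object d_val x bank_n

    fb = FeatureBank(obj_n, memory_budget=1000, device=device)
    fb.init_bank(keys, values, frame_idx=0)

    # Features from a later frame: entries whose best correlation with the bank exceeds
    # thres_close are merged into the matched entry (a running average of normalized
    # features); the rest are appended, and remove() evicts the least-frequently-used
    # entries whenever the per-class budget would be exceeded.
    new_keys = [torch.randn(d_key, n, device=device) for _ in range(obj_n)]
    new_values = [torch.randn(d_val, n, device=device) for _ in range(obj_n)]
    fb.update(new_keys, new_values, frame_idx=5)
    fb.print_peak_mem()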
video_module/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .AFB_URR import *
from .FeatureBank import FeatureBank
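With these two re-exports in place, downstream scripts can import both the network and the memory bank directly from the package:

    from video_module.model import AFB_URR, FeatureBank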