rezasalatin committed
Commit 0e4f45d · 1 Parent(s): b4cc03f

Add all files and directories
Training_Station/Duck_Rectified/labels.txt ADDED
@@ -0,0 +1,3 @@
+ __ignore__
+ _background_
+ water
image_module/__init__.py ADDED
File without changes
image_module/dataset_water.py ADDED
@@ -0,0 +1,161 @@
+ import os
+ import numpy as np
+ from glob import glob
+ import torchvision.transforms.functional as TF
+ from PIL import Image
+ from torch.utils import data
+
+ from . import transforms as my_tf
+ from myutils import load_image_in_PIL as load_img
+
+
+ def load_image_in_PIL(path, mode='RGB'):
+     img = Image.open(path)
+     img.load()  # Very important for loading large images
+     return img.convert(mode)
+
+
+ class WaterDataset(data.Dataset):
+
+     def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None):
+
+         super(WaterDataset, self).__init__()
+
+         self.mode = mode
+         self.input_size = input_size
+         self.test_case = test_case
+         self.img_list = []
+         self.label_list = []
+         self.verbose_flag = False
+         self.online_augmentation_per_epoch = 640
+         self.eval_size = eval_size
+
+         if mode == 'train_offline':
+             with open(os.path.join(dataset_path, 'train_imgs.txt')) as f:
+                 water_subdirs = f.readlines()
+             water_subdirs = [x.strip() for x in water_subdirs]
+
+             print('Initialize offline training dataset:')
+
+             for sub_folder in water_subdirs:
+                 label_list = glob(os.path.join(dataset_path, 'Annotations/', sub_folder, '*.png'))
+                 label_list.sort(key=lambda x: (len(x), x))
+                 self.label_list += label_list
+
+                 name_list = [os.path.basename(x)[:-4] for x in label_list]
+
+                 img_list = glob(os.path.join(dataset_path, 'JPEGImages/', sub_folder, '*.jpg'))
+                 img_list.sort(key=lambda x: (len(x), x))
+                 img_list_valid = []
+                 for img_path in img_list:
+                     if os.path.basename(img_path)[:-4] in name_list:
+                         img_list_valid.append(img_path)
+
+                 self.img_list += img_list_valid
+
+                 print('Add', sub_folder, len(img_list_valid), 'files.')
+
+         elif mode == 'eval':
+             if test_case is None:
+                 raise ValueError('test_case cannot be None.')
+
+             img_path = os.path.join(dataset_path, 'JPEGImages/', test_case)
+             img_list = os.listdir(img_path)
+             img_list.sort(key=lambda x: (len(x), x))
+             self.img_list = [os.path.join(img_path, name) for name in img_list]
+
+             first_frame_label_path = os.path.join(dataset_path, 'Annotations/', test_case, img_list[0])
+
+             # Detect label image format: png or jpg
+             first_frame_label_path = first_frame_label_path[:-3]
+             if os.path.exists(first_frame_label_path + 'png'):
+                 first_frame_label_path += 'png'
+             else:
+                 first_frame_label_path += 'jpg'
+
+             if not os.path.exists(first_frame_label_path):
+                 label_list = glob(os.path.join(dataset_path, 'Annotations/', test_case, '*.png'))
+                 label_list.sort(key=lambda x: (x, len(x)))
+                 first_frame_label_path = label_list[0]
+
+             self.first_frame = load_image_in_PIL(self.img_list[0], 'RGB')
+             self.img_list.pop(0)
+
+             self.first_frame_label = load_image_in_PIL(first_frame_label_path, 'P')
+
+             if self.eval_size:
+                 self.origin_size = self.first_frame.size
+                 self.first_frame = self.first_frame.resize(self.eval_size, Image.ANTIALIAS)
+                 self.first_frame_label = self.first_frame_label.resize(self.eval_size, Image.ANTIALIAS)
+
+         else:
+             raise ValueError('Mode %s is not supported; choose from [train_offline, train_online, eval].' % mode)
+
+     def __len__(self):
+         if self.mode == 'train_online':
+             return self.online_augmentation_per_epoch
+         else:
+             return len(self.img_list)
+
+     def get_first_frame(self):
+         img_tf = TF.to_tensor(self.first_frame)
+         img_tf = my_tf.imagenet_normalization(img_tf)
+         return img_tf
+
+     def get_first_frame_label(self):
+         return TF.to_tensor(self.first_frame_label)
+
+     def __getitem__(self, index):
+         raise NotImplementedError
+
+
+ class WaterDataset_RGB(WaterDataset):
+     def __init__(self, mode, dataset_path, input_size=None, test_case=None, eval_size=None):
+         super(WaterDataset_RGB, self).__init__(mode, dataset_path, input_size, test_case, eval_size)
+
+     def __getitem__(self, index):
+         if self.mode == 'train_offline' or self.mode == 'val_offline' or self.mode == 'test_offline':
+             img = load_img(self.img_list[index], 'RGB')
+             label = load_img(self.label_list[index], 'P')
+             return self.apply_transforms(img, label)
+         elif self.mode == 'train_online':
+             return self.apply_transforms(self.first_frame, self.first_frame_label)
+         elif self.mode == 'eval':
+             img = load_img(self.img_list[index], 'RGB')
+             if self.eval_size:
+                 img = img.resize(self.eval_size, Image.ANTIALIAS)
+             return self.apply_transforms(img)
+         else:
+             raise Exception("Error: Invalid dataset mode!")
+
+     def resize_to_origin(self, img):
+         return img.resize(self.origin_size)
+
+     def apply_transforms(self, img, label=None):
+         if self.mode == 'train_offline' or self.mode == 'train_online':
+             img = my_tf.random_adjust_color(img, self.verbose_flag)
+             img, label = my_tf.random_affine_transformation(img, None, label, self.verbose_flag)
+             img, label = my_tf.random_resized_crop(img, None, label, self.input_size, self.verbose_flag)
+         elif self.mode == 'test_offline' or self.mode == 'val_offline':
+             img = TF.resize(img, self.input_size)
+             label = TF.resize(label, self.input_size)
+         elif self.mode == 'eval':
+             pass
+
+         img_orig = TF.to_tensor(img)
+         img_norm = my_tf.imagenet_normalization(img_orig)
+
+         if self.mode == 'train_offline' or self.mode == 'train_online':
+             # label = TF.to_tensor(label)
+             label = np.expand_dims(np.array(label, np.float32), axis=0)
+             return img_norm, label
+         elif self.mode == 'val_offline':
+             label = np.expand_dims(np.array(label, np.float32), axis=0)
+             return img_norm, label
+         elif self.mode == 'test_offline':
+             label = np.expand_dims(np.array(label, np.float32), axis=0)
+             return img_norm, label, img_orig
+         else:
+             return None
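As a quick orientation (not part of the commit itself), a minimal sketch of how this dataset class is meant to be consumed, assuming the repository root is on the import path and a placeholder dataset directory containing JPEGImages/, Annotations/ and train_imgs.txt:

import torch
from torch.utils.data import DataLoader
from image_module.dataset_water import WaterDataset_RGB

# '/data/water_v1' and the 400x400 crop size are placeholders, not values from this commit.
dataset = WaterDataset_RGB(mode='train_offline',
                           dataset_path='/data/water_v1',
                           input_size=(400, 400))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for img_norm, label in loader:
    # img_norm: (B, 3, 400, 400) ImageNet-normalized image tensor
    # label:    (B, 1, 400, 400) float water mask
    pass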
image_module/transforms.py ADDED
@@ -0,0 +1,151 @@
+ import random
+ import math
+ import numpy as np
+ from PIL import Image, ImageFilter
+ from scipy.ndimage import binary_erosion, binary_dilation
+ import torchvision.transforms.functional as TF
+ from torchvision.transforms import RandomResizedCrop
+
+ random_thres = 0.8
+
+
+ def random_adjust_color(img, verbose=False):
+
+     if random.random() < random_thres:
+         brightness_factor = random.uniform(0.1, 1.2)
+         img = TF.adjust_brightness(img, brightness_factor)
+         if verbose:
+             print('Brightness:', brightness_factor)
+
+     if random.random() < random_thres:
+         contrast_factor = random.uniform(0.2, 1.8)
+         img = TF.adjust_contrast(img, contrast_factor)
+         if verbose:
+             print('Contrast:', contrast_factor)
+
+     if random.random() < random_thres:
+         # hue_factor = random.uniform(-0.1, 0.1)
+         hue_factor = 0.1
+         img = TF.adjust_hue(img, hue_factor)
+         if verbose:
+             print('Hue:', hue_factor)
+
+     return img
+
+
+ def random_affine_transformation(img, mask, label, verbose=False):
+
+     if random.random() < random_thres:
+         degrees = random.uniform(-20, 20)
+         translate_h = random.uniform(-0.2, 0.2)
+         translate_v = random.uniform(-0.2, 0.2)
+         scale = random.uniform(0.7, 1.3)
+         shear = random.uniform(-20, 20)
+         resample = TF.InterpolationMode.BICUBIC
+
+         img = TF.affine(img, degrees, (translate_h, translate_v), scale, shear, resample)
+         if mask:
+             mask = TF.affine(mask, degrees, (translate_h, translate_v), scale, shear, resample)
+         label = TF.affine(label, degrees, (translate_h, translate_v), scale, shear, resample)
+
+         if verbose:
+             print('Affine degrees: %.1f, T_h: %.1f, T_v: %.1f, Scale: %.1f, Shear: %.1f' %
+                   (degrees, translate_h, translate_v, scale, shear))
+
+     if random.random() < 0.5:
+
+         img = TF.hflip(img)
+         if mask:
+             mask = TF.hflip(mask)
+         label = TF.hflip(label)
+
+         if verbose:
+             print('Horizontal flip')
+
+     if mask:
+         return img, mask, label
+     else:
+         return img, label
+
+
+ def random_mask_perturbation(mask, verbose=False):
+
+     degrees = random.uniform(-10, 10)
+     translate_h = random.uniform(-0.1, 0.1)
+     translate_v = random.uniform(-0.1, 0.1)
+     scale = random.uniform(0.8, 1.2)
+     shear = random.uniform(-10, 10)
+     resample = TF.InterpolationMode.BICUBIC
+
+     mask = TF.affine(mask, degrees, (translate_h, translate_v), scale, shear, resample)
+
+     if verbose:
+         print('Mask perturbation degrees: %.1f, T_h: %.1f, T_v: %.1f, Scale: %.1f, Shear: %.1f' %
+               (degrees, translate_h, translate_v, scale, shear))
+
+     morphologic_times = int(random.random() * 10)
+     morphologic_thres = random.random()
+     filter_size = 7
+     for i in range(morphologic_times):
+         if random.random() < morphologic_thres:
+             mask = mask.filter(ImageFilter.MinFilter(filter_size))
+             if verbose:
+                 print(i, 'erosion')
+         else:
+             mask = mask.filter(ImageFilter.MaxFilter(filter_size))
+             if verbose:
+                 print(i, 'dilation')
+
+     mask = mask.convert('1')
+
+     return mask
+
+
+ def random_resized_crop(img, mask, label, size, verbose=False):
+
+     scale = (0.08, 1.0)
+     ratio = (0.75, 1.33333333)
+
+     sample_flag = False
+
+     for attempt in range(10):
+         area = img.size[0] * img.size[1]
+         target_area = random.uniform(*scale) * area
+         aspect_ratio = random.uniform(*ratio)
+
+         w = int(round(math.sqrt(target_area * aspect_ratio)))
+         h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+         if random.random() < 0.5:
+             w, h = h, w
+
+         if w <= img.size[0] and h <= img.size[1]:
+             y = random.randint(0, img.size[1] - h)
+             x = random.randint(0, img.size[0] - w)
+             sample_flag = True
+             break
+
+     # Fallback
+     if not sample_flag:
+         w = min(img.size[0], img.size[1])
+         y = (img.size[1] - w) // 2
+         x = (img.size[0] - w) // 2
+         h = w
+
+     img = TF.resized_crop(img, y, x, h, w, size, TF.InterpolationMode.BICUBIC)
+     if mask:
+         mask = TF.resized_crop(mask, y, x, h, w, size, TF.InterpolationMode.BICUBIC)
+     label = TF.resized_crop(label, y, x, h, w, size, TF.InterpolationMode.BICUBIC)
+
+     if verbose:
+         print('x: %d, y: %d, w: %d, h: %d' % (x, y, w, h))
+
+     if mask:
+         return img, mask, label
+     else:
+         return img, label
+
+
+ def imagenet_normalization(img_tensor):
+
+     mean = [0.485, 0.456, 0.406]
+     std = [0.229, 0.224, 0.225]
+     img_tensor = TF.normalize(img_tensor, mean, std)
+
+     return img_tensor
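These functions operate on paired PIL images and keep the image and its label geometrically aligned. A small illustrative sketch (file paths are placeholders, not from this commit; the repo root is assumed to be importable):

from PIL import Image
import torchvision.transforms.functional as TF
from image_module import transforms as my_tf

img = Image.open('frame.jpg').convert('RGB')        # placeholder input frame
label = Image.open('frame_mask.png').convert('P')   # placeholder palette mask

img = my_tf.random_adjust_color(img, verbose=True)
img, label = my_tf.random_affine_transformation(img, None, label, verbose=True)
img, label = my_tf.random_resized_crop(img, None, label, (400, 400), verbose=True)

# Same normalization the dataset applies before feeding the network.
img_tensor = my_tf.imagenet_normalization(TF.to_tensor(img))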
myutils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .data import *
+ from .system import *
myutils/data.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from scipy.ndimage import binary_dilation
5
+ import cv2
6
+
7
+ from numpy.linalg import norm
8
+
9
+ import torch
10
+ from torch.nn import functional as NF
11
+ from torchvision.transforms import functional as TF
12
+
13
+
14
+ color_palette = [0, 0, 0, 0, 0, 128, 0, 128, 0, 128, 0, 0] + [100, 100, 100] * 252
15
+
16
+
17
+ def postprocessing_pred(pred: np.ndarray) -> np.ndarray:
18
+
19
+ label_cnt, labels = cv2.connectedComponentsWithAlgorithm(pred, 8, cv2.CV_32S, cv2.CCL_GRANA)
20
+ if label_cnt == 2:
21
+ if labels[0, 0] == pred[0, 0]:
22
+ pred = labels
23
+ else:
24
+ pred = 1 - labels
25
+ else:
26
+ max_cnt, max_label = 0, 0
27
+ for i in range(label_cnt):
28
+ mask = labels == i
29
+ if pred[mask][0] == 0:
30
+ continue
31
+ cnt = len(mask.nonzero()[0])
32
+ if cnt > max_cnt:
33
+ max_cnt = cnt
34
+ max_label = i
35
+ pred = labels == max_label
36
+
37
+ return pred.astype(np.uint8)
38
+
39
+
40
+ def calc_uncertainty(score):
41
+
42
+ # seg shape: bs, obj_n, h, w
43
+ score_top, _ = score.topk(k=2, dim=1)
44
+ uncertainty = score_top[:, 0] / (score_top[:, 1] + 1e-8) # bs, h, w
45
+ uncertainty = torch.exp(1 - uncertainty).unsqueeze(1) # bs, 1, h, w
46
+ return uncertainty
47
+
48
+
49
+ def save_seg_mask(pred, seg_path, palette=color_palette):
50
+
51
+ seg_img = Image.fromarray(pred)
52
+ seg_img.putpalette(palette)
53
+ seg_img.save(seg_path)
54
+
55
+
56
+ def add_overlay(img, mask, colors=color_palette, alpha=0.4, cscale=1):
57
+
58
+ ids = np.unique(mask)
59
+ img_overlay = img.copy()
60
+ ones_np = np.ones(img.shape) * (1 - alpha)
61
+
62
+ colors = np.reshape(colors, (-1, 3))
63
+ colors = np.atleast_2d(colors) * cscale
64
+
65
+ for i in ids[1:]:
66
+
67
+ canvas = img * alpha + ones_np * np.array(colors[i])[::-1]
68
+
69
+ binary_mask = mask == i
70
+ img_overlay[binary_mask] = canvas[binary_mask]
71
+
72
+ contour = binary_dilation(binary_mask) ^ binary_mask
73
+ img_overlay[contour, :] = 0
74
+
75
+ return img_overlay
76
+
77
+
78
+ def save_overlay(img, mask, overlay_path, colors=[255, 0, 0], alpha=0.4, cscale=1):
79
+
80
+ img = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
81
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
82
+
83
+ img_overlay = add_overlay(img, mask, colors, alpha, cscale)
84
+ cv2.imwrite(overlay_path, img_overlay)
85
+
86
+
87
+ def load_image_in_PIL(path, mode='RGB'):
88
+ img = Image.open(path)
89
+ img.load() # Very important for loading large image
90
+ return img.convert(mode)
91
+
92
+
93
+ def normalize(x):
94
+ return x / norm(x, ord=2, axis=1, keepdims=True)
95
+
96
+
97
+ def dist(p0, p1, axis):
98
+ return norm(p0 - p1, ord=2, axis=axis)
99
+
100
+
101
+ def resize_img(img, out_size):
102
+ h, w = img.shape[:2]
103
+
104
+ if h > w:
105
+ w_new = int(out_size * w / h)
106
+ h_new = out_size
107
+ else:
108
+ h_new = int(out_size * h / w)
109
+ w_new = out_size
110
+
111
+ img = cv2.resize(img, (w_new, h_new))
112
+ return img
113
+
114
+
115
+ def unify_features(features):
116
+ output_size = features['f0'].shape[-2:]
117
+ feature_tuple = tuple()
118
+
119
+ for key, f in features.items():
120
+ if key != 'f0':
121
+ f = NF.interpolate(
122
+ f,
123
+ size=output_size, mode='bilinear', align_corners=False
124
+ )
125
+ feature_tuple += (f,)
126
+
127
+ unified_feature = torch.cat(feature_tuple, dim=1)
128
+
129
+ return unified_feature
130
+
131
+
132
+ def pad_divide_by(in_list, d, in_size):
133
+ out_list = []
134
+ h, w = in_size
135
+ if h % d > 0:
136
+ new_h = h + d - h % d
137
+ else:
138
+ new_h = h
139
+ if w % d > 0:
140
+ new_w = w + d - w % d
141
+ else:
142
+ new_w = w
143
+ lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
144
+ lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
145
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
146
+ for inp in in_list:
147
+ out_list.append(NF.pad(inp, pad_array))
148
+
149
+ return out_list, pad_array
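pad_divide_by is the helper the video model uses to make frame sizes divisible by the encoder stride; the returned pad tuple lets the caller crop the prediction back to the original size. A small sketch with an arbitrary input size (importing myutils.data assumes OpenCV and SciPy are installed):

import torch
from myutils.data import pad_divide_by

frame = torch.rand(1, 3, 241, 427)                            # arbitrary H, W
(frame_p,), pad = pad_divide_by([frame], 16, frame.shape[-2:])
print(frame_p.shape)                                          # padded to multiples of 16: (1, 3, 256, 432)

# Crop the padding back off, mirroring what AFB_URR.segment does after inference.
restored = frame_p[:, :, pad[2]:frame_p.shape[-2] - pad[3], pad[0]:frame_p.shape[-1] - pad[1]]
assert restored.shape == frame.shape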
myutils/system.py ADDED
@@ -0,0 +1,103 @@
+ import time
+ import os
+ import shutil
+ import numpy as np
+ from PIL import Image
+
+ import torch
+
+
+ class AvgMeter(object):
+
+     def __init__(self, window=-1):
+         self.window = window
+         self.reset()
+
+     def reset(self):
+         self.avg = 0
+         self.sum = 0
+         self.cnt = 0
+         self.max = -np.inf
+
+         if self.window > 0:
+             self.val_arr = np.zeros(self.window)
+             self.arr_idx = 0
+
+     def update(self, val, n=1):
+
+         self.cnt += n
+         self.max = max(self.max, val)
+
+         if self.window > 0:
+             self.val_arr[self.arr_idx] = val
+             self.arr_idx = (self.arr_idx + 1) % self.window
+             self.avg = self.val_arr.mean()
+         else:
+             self.sum += val * n
+             self.avg = self.sum / self.cnt
+
+
+ class FrameSecondMeter(object):
+
+     def __init__(self):
+         self.st = time.time()
+         self.fps = None
+         self.ed = None
+         self.frame_n = 0
+
+     def add_frame_n(self, frame_n):
+         self.frame_n += frame_n
+
+     def end(self):
+         self.ed = time.time()
+         self.fps = self.frame_n / (self.ed - self.st)
+
+
+ def gct(f='l'):
+     '''
+     get current time
+     :param f: 'l' for log, 'f' for file name
+     :return: formatted time
+     '''
+     if f == 'l':
+         return time.strftime('%m/%d %H:%M:%S', time.localtime(time.time()))
+     elif f == 'f':
+         return time.strftime('%m_%d_%H_%M', time.localtime(time.time()))
+
+
+ def save_scripts(path, scripts_to_save=None):
+     if not os.path.exists(os.path.join(path, 'scripts')):
+         os.makedirs(os.path.join(path, 'scripts'))
+
+     if scripts_to_save is not None:
+         for script in scripts_to_save:
+             dst_path = os.path.join(path, 'scripts', script)
+             try:
+                 shutil.copy(script, dst_path)
+             except IOError:
+                 os.makedirs(os.path.dirname(dst_path))
+                 shutil.copy(script, dst_path)
+
+
+ def count_model_size(model):
+     return sum(np.prod(v.size()) for name, v in model.named_parameters()) / 1e6
+
+
+ def load_image_in_PIL(path, mode='RGB'):
+     img = Image.open(path)
+     img.load()  # Very important for loading large images
+     return img.convert(mode)
+
+
+ def print_mem(info=None):
+     if info:
+         print(info, end=' ')
+     mem_allocated = round(torch.cuda.memory_allocated() / 1048576)
+     mem_cached = round(torch.cuda.memory_cached() / 1048576)
+     print(f'Mem allocated: {mem_allocated}MB, Mem cached: {mem_cached}MB')
+
+
+ def set_bn_eval(m):
+     classname = m.__class__.__name__
+     if classname.find('BatchNorm') != -1:
+         m.eval()
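A short, hypothetical sketch of how these meters are typically driven from a training or evaluation loop (the loss values below are made up):

from myutils.system import AvgMeter, FrameSecondMeter, gct

loss_meter = AvgMeter(window=50)            # windowed running average over a 50-slot buffer
for step, loss in enumerate([0.9, 0.7, 0.6]):
    loss_meter.update(loss)
print(gct(), f'avg loss: {loss_meter.avg:.3f}')

fps_meter = FrameSecondMeter()
fps_meter.add_frame_n(300)                  # frames processed since construction
fps_meter.end()
print(f'{fps_meter.fps:.1f} FPS')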
records/link_efficientb4_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab073b386a30375e0947246dbfcbab5b23197056d8918841f6b7e2764add7440
+ size 72294975
video_module/__init__.py ADDED
File without changes
video_module/dataset/Water_DS.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import numpy as np
+
+ from glob import glob
+
+ import torch
+ from torch.utils import data
+ import torchvision.transforms as TF
+
+ from video_module.dataset import transforms as mytrans
+ import myutils
+
+
+ class Water_Image_Train_DS(data.Dataset):
+
+     def __init__(self, root, output_size, clip_n, max_obj_n):
+         self.root = root
+         self.clip_n = clip_n
+         self.output_size = output_size
+         self.max_obj_n = max_obj_n
+
+         self.img_list = sorted(glob(os.path.join(root, 'JPEGImages', '*.jpg')) + glob(os.path.join(root, 'JPEGImages', '*.png')))
+         self.mask_list = sorted(glob(os.path.join(root, 'Annotations', '*.png')))
+
+         assert len(self.img_list) == len(self.mask_list), "The number of images and masks should be the same"
+
+         self.random_horizontal_flip = mytrans.RandomHorizontalFlip(0.3)
+         self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
+         self.random_affine = mytrans.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)
+         self.random_resize_crop = mytrans.RandomResizedCrop(output_size, (0.8, 1))
+         self.to_tensor = TF.ToTensor()
+         self.to_onehot = mytrans.ToOnehot(max_obj_n, shuffle=True)
+
+     def __len__(self):
+         return len(self.img_list)
+
+     def __getitem__(self, idx):
+
+         img_pil = myutils.load_image_in_PIL(self.img_list[idx], 'RGB')
+         mask_pil = myutils.load_image_in_PIL(self.mask_list[idx], 'P')
+
+         frames = torch.zeros((self.clip_n, 3, self.output_size, self.output_size), dtype=torch.float)
+         masks = torch.zeros((self.clip_n, self.max_obj_n, self.output_size, self.output_size), dtype=torch.float)
+
+         for i in range(self.clip_n):
+             img, mask = img_pil, mask_pil
+             if i > 0:
+                 img, mask = self.random_horizontal_flip(img, mask)
+                 img = self.color_jitter(img)
+                 img, mask = self.random_affine(img, mask)
+
+             img, mask = self.random_resize_crop(img, mask)
+             mask = np.array(mask, np.uint8)
+
+             if i == 0:
+                 mask, obj_list = self.to_onehot(mask)
+                 obj_n = len(obj_list) + 1
+             else:
+                 mask, _ = self.to_onehot(mask, obj_list)
+
+             frames[i] = self.to_tensor(img)
+             masks[i] = mask
+
+         info = {
+             'name': self.img_list[idx]
+         }
+         return frames, masks[:, :obj_n], obj_n, info
+
+
+ class Video_DS(data.Dataset):
+
+     def __init__(self, img_list, first_frame, first_mask):
+         self.img_list = img_list[1:]
+         self.video_len = len(self.img_list)
+
+         first_mask = np.array(first_mask, np.uint8) > 0
+         self.obj_n = first_mask.max() + 1
+
+         self.to_tensor = TF.ToTensor()
+         self.to_onehot = mytrans.ToOnehot(self.obj_n, shuffle=False)
+
+         mask, _ = self.to_onehot(first_mask)
+         self.first_mask = mask[:self.obj_n]
+         self.first_frame = self.to_tensor(first_frame)
+
+     def __len__(self):
+         return self.video_len
+
+     def __getitem__(self, idx):
+         img = myutils.load_image_in_PIL(self.img_list[idx], 'RGB')
+         frame = self.to_tensor(img)
+         img_name = os.path.basename(self.img_list[idx])[:-4]
+
+         return frame, img_name
video_module/dataset/__init__.py ADDED
@@ -0,0 +1 @@
+ from .Water_DS import *
video_module/dataset/transforms.py ADDED
@@ -0,0 +1,468 @@
1
+ import math
2
+ import warnings
3
+ import random
4
+ import numbers
5
+ import numpy as np
6
+ from PIL import Image
7
+ from collections.abc import Sequence
8
+
9
+ import torch
10
+ import torchvision.transforms.functional as TF
11
+
12
+ _pil_interpolation_to_str = {
13
+ Image.NEAREST: 'PIL.Image.NEAREST',
14
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
15
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
16
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
17
+ Image.HAMMING: 'PIL.Image.HAMMING',
18
+ Image.BOX: 'PIL.Image.BOX',
19
+ }
20
+
21
+
22
+ def _get_image_size(img):
23
+ if TF._is_pil_image(img):
24
+ return img.size
25
+ elif isinstance(img, torch.Tensor) and img.dim() > 2:
26
+ return img.shape[-2:][::-1]
27
+ else:
28
+ raise TypeError("Unexpected type {}".format(type(img)))
29
+
30
+
31
+ class RandomHorizontalFlip(object):
32
+ """Horizontal flip the given PIL Image randomly with a given probability.
33
+
34
+ Args:
35
+ p (float): probability of the image being flipped. Default value is 0.5
36
+ """
37
+
38
+ def __init__(self, p=0.5):
39
+ self.p = p
40
+
41
+ def __call__(self, img, mask):
42
+ """
43
+ Args:
44
+ img (PIL Image): Image to be flipped.
45
+
46
+ Returns:
47
+ PIL Image: Randomly flipped image.
48
+ """
49
+ if random.random() < self.p:
50
+ img = TF.hflip(img)
51
+ mask = TF.hflip(mask)
52
+ return img, mask
53
+
54
+ def __repr__(self):
55
+ return self.__class__.__name__ + '(p={})'.format(self.p)
56
+
57
+
58
+ class RandomAffine(object):
59
+ """Random affine transformation of the image keeping center invariant
60
+
61
+ Args:
62
+ degrees (sequence or float or int): Range of degrees to select from.
63
+ If degrees is a number instead of sequence like (min, max), the range of degrees
64
+ will be (-degrees, +degrees). Set to 0 to deactivate rotations.
65
+ translate (tuple, optional): tuple of maximum absolute fraction for horizontal
66
+ and vertical translations. For example translate=(a, b), then horizontal shift
67
+ is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
68
+ randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
69
+ scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
70
+ randomly sampled from the range a <= scale <= b. Will keep original scale by default.
71
+ shear (sequence or float or int, optional): Range of degrees to select from.
72
+ If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
73
+ will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
74
+ range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
75
+ a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
76
+ Will not apply shear by default
77
+ resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
78
+ An optional resampling filter. See `filters`_ for more information.
79
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
80
+ fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
81
+ outside the transform in the output image.(Pillow>=5.0.0)
82
+
83
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
84
+
85
+ """
86
+
87
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
88
+ if isinstance(degrees, numbers.Number):
89
+ if degrees < 0:
90
+ raise ValueError("If degrees is a single number, it must be positive.")
91
+ self.degrees = (-degrees, degrees)
92
+ else:
93
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
94
+ "degrees should be a list or tuple and it must be of length 2."
95
+ self.degrees = degrees
96
+
97
+ if translate is not None:
98
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
99
+ "translate should be a list or tuple and it must be of length 2."
100
+ for t in translate:
101
+ if not (0.0 <= t <= 1.0):
102
+ raise ValueError("translation values should be between 0 and 1")
103
+ self.translate = translate
104
+
105
+ if scale is not None:
106
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
107
+ "scale should be a list or tuple and it must be of length 2."
108
+ for s in scale:
109
+ if s <= 0:
110
+ raise ValueError("scale values should be positive")
111
+ self.scale = scale
112
+
113
+ if shear is not None:
114
+ if isinstance(shear, numbers.Number):
115
+ if shear < 0:
116
+ raise ValueError("If shear is a single number, it must be positive.")
117
+ self.shear = (-shear, shear)
118
+ else:
119
+ assert isinstance(shear, (tuple, list)) and \
120
+ (len(shear) == 2 or len(shear) == 4), \
121
+ "shear should be a list or tuple and it must be of length 2 or 4."
122
+ # X-Axis shear with [min, max]
123
+ if len(shear) == 2:
124
+ self.shear = [shear[0], shear[1], 0., 0.]
125
+ elif len(shear) == 4:
126
+ self.shear = [s for s in shear]
127
+ else:
128
+ self.shear = shear
129
+
130
+ self.resample = resample
131
+ self.fillcolor = fillcolor
132
+
133
+ @staticmethod
134
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
135
+ """Get parameters for affine transformation
136
+
137
+ Returns:
138
+ sequence: params to be passed to the affine transformation
139
+ """
140
+ angle = random.uniform(degrees[0], degrees[1])
141
+ if translate is not None:
142
+ max_dx = translate[0] * img_size[0]
143
+ max_dy = translate[1] * img_size[1]
144
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
145
+ np.round(random.uniform(-max_dy, max_dy)))
146
+ else:
147
+ translations = (0, 0)
148
+
149
+ if scale_ranges is not None:
150
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
151
+ else:
152
+ scale = 1.0
153
+
154
+ if shears is not None:
155
+ if len(shears) == 2:
156
+ shear = [random.uniform(shears[0], shears[1]), 0.]
157
+ elif len(shears) == 4:
158
+ shear = [random.uniform(shears[0], shears[1]),
159
+ random.uniform(shears[2], shears[3])]
160
+ else:
161
+ shear = 0.0
162
+
163
+ return angle, translations, scale, shear
164
+
165
+ def __call__(self, img, mask):
166
+ """
167
+ img (PIL Image): Image to be transformed.
168
+
169
+ Returns:
170
+ PIL Image: Affine transformed image.
171
+ """
172
+ ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
173
+ img = TF.affine(img, *ret, interpolation=TF.InterpolationMode.BICUBIC, fill=self.fillcolor)
174
+ mask = TF.affine(mask, *ret, interpolation=TF.InterpolationMode.NEAREST, fill=self.fillcolor)
175
+ return img, mask
176
+
177
+ def __repr__(self):
178
+ s = '{name}(degrees={degrees}'
179
+ if self.translate is not None:
180
+ s += ', translate={translate}'
181
+ if self.scale is not None:
182
+ s += ', scale={scale}'
183
+ if self.shear is not None:
184
+ s += ', shear={shear}'
185
+ if self.resample > 0:
186
+ s += ', resample={resample}'
187
+ if self.fillcolor != 0:
188
+ s += ', fillcolor={fillcolor}'
189
+ s += ')'
190
+ d = dict(self.__dict__)
191
+ d['resample'] = _pil_interpolation_to_str[d['resample']]
192
+ return s.format(name=self.__class__.__name__, **d)
193
+
194
+
195
+ class RandomCrop(object):
196
+ """Crop the given PIL Image at a random location.
197
+
198
+ Args:
199
+ size (sequence or int): Desired output size of the crop. If size is an
200
+ int instead of sequence like (h, w), a square crop (size, size) is
201
+ made.
202
+ padding (int or sequence, optional): Optional padding on each border
203
+ of the image. Default is None, i.e no padding. If a sequence of length
204
+ 4 is provided, it is used to pad left, top, right, bottom borders
205
+ respectively. If a sequence of length 2 is provided, it is used to
206
+ pad left/right, top/bottom borders, respectively.
207
+ pad_if_needed (boolean): It will pad the image if smaller than the
208
+ desired size to avoid raising an exception. Since cropping is done
209
+ after padding, the padding seems to be done at a random offset.
210
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
211
+ length 3, it is used to fill R, G, B channels respectively.
212
+ This value is only used when the padding_mode is constant
213
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
214
+
215
+ - constant: pads with a constant value, this value is specified with fill
216
+
217
+ - edge: pads with the last value on the edge of the image
218
+
219
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
220
+
221
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
222
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
223
+
224
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
225
+
226
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
227
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
228
+
229
+ """
230
+
231
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
232
+ if isinstance(size, numbers.Number):
233
+ self.size = (int(size), int(size))
234
+ else:
235
+ self.size = size
236
+ self.padding = padding
237
+ self.pad_if_needed = pad_if_needed
238
+ self.fill = fill
239
+ self.padding_mode = padding_mode
240
+
241
+ @staticmethod
242
+ def get_params(img, output_size):
243
+ """Get parameters for ``crop`` for a random crop.
244
+
245
+ Args:
246
+ img (PIL Image): Image to be cropped.
247
+ output_size (tuple): Expected output size of the crop.
248
+
249
+ Returns:
250
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
251
+ """
252
+ w, h = _get_image_size(img)
253
+ th, tw = output_size
254
+ if w == tw and h == th:
255
+ return 0, 0, h, w
256
+
257
+ i = random.randint(0, h - th)
258
+ j = random.randint(0, w - tw)
259
+ return i, j, th, tw
260
+
261
+ def __call__(self, img, mask):
262
+ """
263
+ Args:
264
+ img (PIL Image): Image to be cropped.
265
+
266
+ Returns:
267
+ PIL Image: Cropped image.
268
+ """
269
+ # if self.padding is not None:
270
+ # img = TF.pad(img, self.padding, self.fill, self.padding_mode)
271
+ #
272
+ # # pad the width if needed
273
+ # if self.pad_if_needed and img.size[0] < self.size[1]:
274
+ # img = TF.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
275
+ # # pad the height if needed
276
+ # if self.pad_if_needed and img.size[1] < self.size[0]:
277
+ # img = TF.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
278
+
279
+ i, j, h, w = self.get_params(img, self.size)
280
+ img = TF.crop(img, i, j, h, w)
281
+ mask = TF.crop(mask, i, j, h, w)
282
+
283
+ return img, mask
284
+
285
+ def __repr__(self):
286
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
287
+
288
+
289
+ class RandomResizedCrop(object):
290
+ """Crop the given PIL Image to random size and aspect ratio.
291
+
292
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
293
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
294
+ is finally resized to given size.
295
+ This is popularly used to train the Inception networks.
296
+
297
+ Args:
298
+ size: expected output size of each edge
299
+ scale: range of size of the origin size cropped
300
+ ratio: range of aspect ratio of the origin aspect ratio cropped
301
+ interpolation: Default: PIL.Image.BILINEAR
302
+ """
303
+
304
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
305
+ if isinstance(size, (tuple, list)):
306
+ self.size = size
307
+ else:
308
+ self.size = (size, size)
309
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
310
+ warnings.warn("range should be of kind (min, max)")
311
+
312
+ self.interpolation = interpolation
313
+ self.scale = scale
314
+ self.ratio = ratio
315
+
316
+ @staticmethod
317
+ def get_params(img, scale, ratio):
318
+ """Get parameters for ``crop`` for a random sized crop.
319
+
320
+ Args:
321
+ img (PIL Image): Image to be cropped.
322
+ scale (tuple): range of size of the origin size cropped
323
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
324
+
325
+ Returns:
326
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
327
+ sized crop.
328
+ """
329
+ width, height = _get_image_size(img)
330
+ area = height * width
331
+
332
+ for _ in range(10):
333
+ target_area = random.uniform(*scale) * area
334
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
335
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
336
+
337
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
338
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
339
+
340
+ if 0 < w <= width and 0 < h <= height:
341
+ i = random.randint(0, height - h)
342
+ j = random.randint(0, width - w)
343
+ return i, j, h, w
344
+
345
+ # Fallback to central crop
346
+ in_ratio = float(width) / float(height)
347
+ if (in_ratio < min(ratio)):
348
+ w = width
349
+ h = int(round(w / min(ratio)))
350
+ elif (in_ratio > max(ratio)):
351
+ h = height
352
+ w = int(round(h * max(ratio)))
353
+ else: # whole image
354
+ w = width
355
+ h = height
356
+ i = (height - h) // 2
357
+ j = (width - w) // 2
358
+ return i, j, h, w
359
+
360
+ def __call__(self, img, mask):
361
+ """
362
+ Args:
363
+ img (PIL Image): Image to be cropped and resized.
364
+
365
+ Returns:
366
+ PIL Image: Randomly cropped and resized image.
367
+ """
368
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
369
+ # print(i, j, h, w)
370
+ img = TF.resized_crop(img, i, j, h, w, self.size, TF.InterpolationMode.BICUBIC)
371
+ mask = TF.resized_crop(mask, i, j, h, w, self.size, TF.InterpolationMode.NEAREST)
372
+ return img, mask
373
+
374
+ def __repr__(self):
375
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
376
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
377
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
378
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
379
+ format_string += ', interpolation={0})'.format(interpolate_str)
380
+ return format_string
381
+
382
+
383
+ class ToOnehot(object):
384
+ """To onehot tensor
385
+
386
+ Args:
387
+ max_obj_n (float): Maximum number of the objects
388
+ """
389
+
390
+ def __init__(self, max_obj_n, shuffle):
391
+ self.max_obj_n = max_obj_n
392
+ self.shuffle = shuffle
393
+
394
+ def __call__(self, mask, obj_list=None):
395
+ """
396
+ Args:
397
+ mask (Mask in Numpy): Mask to be converted.
398
+
399
+ Returns:
400
+ Tensor: Converted mask in onehot format.
401
+ """
402
+
403
+ new_mask = np.zeros((self.max_obj_n, *mask.shape), np.uint8)
404
+
405
+ if not obj_list:
406
+ obj_list = list()
407
+ obj_max = mask.max() + 1
408
+ for i in range(1, obj_max):
409
+ tmp = (mask == i).astype(np.uint8)
410
+ if tmp.max() > 0:
411
+ obj_list.append(i)
412
+
413
+ if self.shuffle:
414
+ random.shuffle(obj_list)
415
+ obj_list = obj_list[:self.max_obj_n - 1]
416
+
417
+ for i in range(len(obj_list)):
418
+ new_mask[i + 1] = (mask == obj_list[i]).astype(np.uint8)
419
+ new_mask[0] = 1 - np.sum(new_mask, axis=0)
420
+
421
+ return torch.from_numpy(new_mask), obj_list
422
+
423
+ def __repr__(self):
424
+ return self.__class__.__name__ + '(max_obj_n={})'.format(self.max_obj_n)
425
+
426
+
427
+ class Resize(torch.nn.Module):
428
+ """Resize the input image to the given size.
429
+ The image can be a PIL Image or a torch Tensor, in which case it is expected
430
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
431
+
432
+ Args:
433
+ size (sequence or int): Desired output size. If size is a sequence like
434
+ (h, w), output size will be matched to this. If size is an int,
435
+ smaller edge of the image will be matched to this number.
436
+ i.e, if height > width, then image will be rescaled to
437
+ (size * height / width, size).
438
+ In torchscript mode padding as single int is not supported, use a tuple or
439
+ list of length 1: ``[size, ]``.
440
+ interpolation (int, optional): Desired interpolation enum defined by `filters`_.
441
+ Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
442
+ and ``PIL.Image.BICUBIC`` are supported.
443
+ """
444
+
445
+ def __init__(self, size, interpolation=Image.BILINEAR):
446
+ super().__init__()
447
+ if not isinstance(size, (int, Sequence)):
448
+ raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
449
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
450
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
451
+ self.size = size
452
+ self.interpolation = interpolation
453
+
454
+ def forward(self, img, mask):
455
+ """
456
+ Args:
457
+ img (PIL Image or Tensor): Image to be scaled.
458
+
459
+ Returns:
460
+ PIL Image or Tensor: Rescaled image.
461
+ """
462
+ img = TF.resize(img, self.size, self.interpolation)
463
+ mask = TF.resize(mask, self.size, Image.NEAREST)
464
+ return img, mask
465
+
466
+ def __repr__(self):
467
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
468
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
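ToOnehot is the transform that turns an integer/palette mask into per-object channels for the video dataset. A toy example with an invented two-object mask:

import numpy as np
from video_module.dataset.transforms import ToOnehot

mask = np.zeros((4, 4), np.uint8)
mask[:2, :2] = 1                       # object 1
mask[2:, 2:] = 2                       # object 2

to_onehot = ToOnehot(max_obj_n=4, shuffle=False)
onehot, obj_list = to_onehot(mask)
print(onehot.shape, obj_list)          # torch.Size([4, 4, 4]), [1, 2]; channel 0 is background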
video_module/model/AFB_URR.py ADDED
@@ -0,0 +1,319 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from torchvision.models import resnet50, ResNet50_Weights
6
+
7
+ import myutils
8
+
9
+
10
+ class ResBlock(nn.Module):
11
+ """A simple residual block component."""
12
+ def __init__(self, indim, outdim=None, stride=1):
13
+ super(ResBlock, self).__init__()
14
+ outdim = outdim or indim
15
+ self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
16
+ self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
17
+ self.downsample = nn.Conv2d(indim, outdim, kernel_size=1, stride=stride) if indim != outdim or stride != 1 else None
18
+
19
+ def forward(self, x):
20
+ identity = x
21
+ out = F.relu(self.conv1(x))
22
+ out = self.conv2(out)
23
+ if self.downsample:
24
+ identity = self.downsample(identity)
25
+ out += identity
26
+ return F.relu(out)
27
+
28
+
29
+ class EncoderM(nn.Module):
30
+ def __init__(self, load_imagenet_params):
31
+ super(EncoderM, self).__init__()
32
+ self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
33
+ self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
34
+
35
+ weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
36
+ resnet = resnet50(weights=weights)
37
+ self.conv1 = resnet.conv1
38
+ self.bn1 = resnet.bn1
39
+ self.relu = resnet.relu # 1/2, 64
40
+ self.maxpool = resnet.maxpool
41
+
42
+ self.res2 = resnet.layer1 # 1/4, 256
43
+ self.res3 = resnet.layer2 # 1/8, 512
44
+ self.res4 = resnet.layer3 # 1/16, 1024
45
+
46
+ self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
47
+ self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
48
+
49
+ def forward(self, in_f, in_m, in_o):
50
+ f = (in_f - self.mean) / self.std
51
+
52
+ x = self.conv1(f) + self.conv1_m(in_m) + self.conv1_o(in_o)
53
+ x = self.bn1(x)
54
+ r1 = self.relu(x) # 1/2, 64
55
+ x = self.maxpool(r1) # 1/4, 64
56
+ r2 = self.res2(x) # 1/4, 256
57
+ r3 = self.res3(r2) # 1/8, 512
58
+ r4 = self.res4(r3) # 1/16, 1024
59
+
60
+ return r4, r1
61
+
62
+
63
+ class EncoderQ(nn.Module):
64
+ def __init__(self, load_imagenet_params):
65
+ super(EncoderQ, self).__init__()
66
+ weights = ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None
67
+ resnet = resnet50(weights=weights)
68
+ self.conv1 = resnet.conv1
69
+ self.bn1 = resnet.bn1
70
+ self.relu = resnet.relu # 1/2, 64
71
+ self.maxpool = resnet.maxpool
72
+
73
+ self.res2 = resnet.layer1 # 1/4, 256
74
+ self.res3 = resnet.layer2 # 1/8, 512
75
+ self.res4 = resnet.layer3 # 1/8, 1024
76
+
77
+ self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
78
+ self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
79
+
80
+ def forward(self, in_f):
81
+ f = (in_f - self.mean) / self.std
82
+
83
+ x = self.conv1(f)
84
+ x = self.bn1(x)
85
+ r1 = self.relu(x) # 1/2, 64
86
+ x = self.maxpool(r1) # 1/4, 64
87
+ r2 = self.res2(x) # 1/4, 256
88
+ r3 = self.res3(r2) # 1/8, 512
89
+ r4 = self.res4(r3) # 1/8, 1024
90
+
91
+ return r4, r3, r2, r1
92
+
93
+
94
+ class KeyValue(nn.Module):
95
+
96
+ def __init__(self, indim, keydim, valdim):
97
+ super(KeyValue, self).__init__()
98
+ self.keydim = keydim
99
+ self.valdim = valdim
100
+ self.Key = nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
101
+ self.Value = nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)
102
+
103
+ def forward(self, x):
104
+ key = self.Key(x)
105
+ key = key.view(*key.shape[:2], -1) # obj_n, key_dim, pixel_n
106
+
107
+ val = self.Value(x)
108
+ val = val.view(*val.shape[:2], -1) # obj_n, key_dim, pixel_n
109
+ return key, val
110
+
111
+
112
+ class Refine(nn.Module):
113
+ def __init__(self, inplanes, planes):
114
+ super(Refine, self).__init__()
115
+ self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1), stride=1)
116
+ self.ResFS = ResBlock(planes, planes)
117
+ self.ResMM = ResBlock(planes, planes)
118
+ self.scale_factor = 2
119
+
120
+ def forward(self, f, pm):
121
+ s = self.ResFS(self.convFS(f))
122
+ m = s + F.interpolate(pm, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
123
+ m = self.ResMM(m)
124
+
125
+ return m
126
+
127
+
128
+ class Matcher(nn.Module):
129
+ def __init__(self, thres_valid=1e-3, update_bank=False):
130
+ super(Matcher, self).__init__()
131
+ self.thres_valid = thres_valid
132
+ self.update_bank = update_bank
133
+
134
+ def forward(self, feature_bank, q_in, q_out):
135
+
136
+ mem_out_list = []
137
+
138
+ for i in range(0, feature_bank.obj_n):
139
+ d_key, bank_n = feature_bank.keys[i].size()
140
+
141
+ try:
142
+ p = torch.matmul(feature_bank.keys[i].transpose(0, 1), q_in) / math.sqrt(d_key) # THW, HW
143
+ p = F.softmax(p, dim=1) # bs, bank_n, HW
144
+ mem = torch.matmul(feature_bank.values[i], p) # bs, D_o, HW
145
+ except RuntimeError as e:
146
+ device = feature_bank.keys[i].device
147
+ key_cpu = feature_bank.keys[i].cpu()
148
+ value_cpu = feature_bank.values[i].cpu()
149
+ q_in_cpu = q_in.cpu()
150
+
151
+ p = torch.matmul(key_cpu.transpose(0, 1), q_in_cpu) / math.sqrt(d_key) # THW, HW
152
+ p = F.softmax(p, dim=1) # bs, bank_n, HW
153
+ mem = torch.matmul(value_cpu, p).to(device) # bs, D_o, HW
154
+ p = p.to(device)
155
+ print('\tLine 158. GPU out of memory, use CPU', f'p size: {p.shape}')
156
+
157
+ mem_out_list.append(torch.cat([mem, q_out], dim=1))
158
+
159
+ if self.update_bank:
160
+ try:
161
+ ones = torch.ones_like(p)
162
+ zeros = torch.zeros_like(p)
163
+ bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]
164
+ except RuntimeError as e:
165
+ device = p.device
166
+ p = p.cpu()
167
+ ones = torch.ones_like(p)
168
+ zeros = torch.zeros_like(p)
169
+ bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)
170
+ print('\tLine 170. GPU out of memory, use CPU', f'p size: {p.shape}')
171
+
172
+ feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)
173
+
174
+ mem_out_tensor = torch.stack(mem_out_list, dim=0).transpose(0, 1) # bs, obj_n, dim, pixel_n
175
+
176
+ return mem_out_tensor
177
+
178
+
179
+ class Decoder(nn.Module):
180
+ def __init__(self, device): # mdim_global = 256
181
+ super(Decoder, self).__init__()
182
+
183
+ self.device = device
184
+ mdim_global = 256
185
+ mdim_local = 32
186
+ local_size = 7
187
+
188
+ # Patch-wise
189
+ self.convFM = nn.Conv2d(1024, mdim_global, kernel_size=3, padding=1, stride=1)
190
+ self.ResMM = ResBlock(mdim_global, mdim_global)
191
+ self.RF3 = Refine(512, mdim_global) # 1/8 -> 1/8
192
+ self.RF2 = Refine(256, mdim_global) # 1/8 -> 1/4
193
+ self.pred2 = nn.Conv2d(mdim_global, 2, kernel_size=3, padding=1, stride=1)
194
+
195
+ # Local
196
+ self.local_avg = nn.AvgPool2d(local_size, stride=1, padding=local_size // 2)
197
+ self.local_max = nn.MaxPool2d(local_size, stride=1, padding=local_size // 2)
198
+ self.local_convFM = nn.Conv2d(128, mdim_local, kernel_size=3, padding=1, stride=1)
199
+ self.local_ResMM = ResBlock(mdim_local, mdim_local)
200
+ self.local_pred2 = nn.Conv2d(mdim_local, 2, kernel_size=3, padding=1, stride=1)
201
+
202
+ for m in self.modules():
203
+ if isinstance(m, nn.Conv2d):
204
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
205
+
206
+ def forward(self, patch_match, r3, r2, r1=None, feature_shape=None):
207
+ p = self.ResMM(self.convFM(patch_match))
208
+ p = self.RF3(r3, p) # out: 1/8, 256
209
+ p = self.RF2(r2, p) # out: 1/4, 256
210
+ p = self.pred2(F.relu(p))
211
+
212
+ p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
213
+
214
+ bs, obj_n, h, w = feature_shape
215
+ rough_seg = F.softmax(p, dim=1)[:, 1]
216
+ rough_seg = rough_seg.view(bs, obj_n, h, w)
217
+ rough_seg = F.softmax(rough_seg, dim=1) # object-level normalization
218
+
219
+ # Local refinement
220
+ uncertainty = myutils.calc_uncertainty(rough_seg)
221
+ uncertainty = uncertainty.expand(-1, obj_n, -1, -1).reshape(bs * obj_n, 1, h, w)
222
+
223
+ rough_seg = rough_seg.view(bs * obj_n, 1, h, w) # bs*obj_n, 1, h, w
224
+ r1_weighted = r1 * rough_seg
225
+ r1_local = self.local_avg(r1_weighted) # bs*obj_n, 64, h, w
226
+ r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8) # neighborhood reference
227
+ r1_conf = self.local_max(rough_seg) # bs*obj_n, 1, h, w
228
+
229
+ local_match = torch.cat([r1, r1_local], dim=1)
230
+ q = self.local_ResMM(self.local_convFM(local_match))
231
+ q = r1_conf * self.local_pred2(F.relu(q))
232
+
233
+ p = p + uncertainty * q
234
+ p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
235
+ p = F.softmax(p, dim=1)[:, 1] # no, h, w
236
+
237
+ return p
238
+
239
+
240
+ class AFB_URR(nn.Module):
241
+ def __init__(self, device, update_bank, load_imagenet_params=False):
242
+ super(AFB_URR, self).__init__()
243
+
244
+ self.device = device
245
+ self.encoder_m = EncoderM(load_imagenet_params)
246
+ self.encoder_q = EncoderQ(load_imagenet_params)
247
+
248
+ self.keyval_r4 = KeyValue(1024, keydim=128, valdim=512)
249
+
250
+ self.global_matcher = Matcher(update_bank=update_bank)
251
+ self.decoder = Decoder(device)
252
+
253
+ def memorize(self, frame, mask):
254
+
255
+ _, K, H, W = mask.shape
256
+
257
+ (frame, mask), pad = myutils.pad_divide_by([frame, mask], 16, (frame.size()[2], frame.size()[3]))
258
+
259
+ frame = frame.expand(K, -1, -1, -1) # obj_n, 3, h, w
260
+ mask = mask[0].unsqueeze(1).float()
261
+ mask_ones = torch.ones_like(mask)
262
+ mask_inv = (mask_ones - mask).clamp(0, 1)
263
+
264
+ r4, r1 = self.encoder_m(frame, mask, mask_inv)
265
+
266
+ k4, v4 = self.keyval_r4(r4) # num_objects, 128 and 512, H/16, W/16
267
+ k4_list = [k4[i] for i in range(K)]
268
+ v4_list = [v4[i] for i in range(K)]
269
+
270
+ return k4_list, v4_list
271
+
272
+ def segment(self, frame, fb_global):
273
+
274
+ obj_n = fb_global.obj_n
275
+
276
+ if not self.training:
277
+ [frame], pad = myutils.pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))
278
+
279
+ r4, r3, r2, r1 = self.encoder_q(frame)
280
+ bs, _, global_match_h, global_match_w = r4.shape
281
+ _, _, local_match_h, local_match_w = r1.shape
282
+
283
+ k4, v4 = self.keyval_r4(r4) # 1, dim, H/16, W/16
284
+ res_global = self.global_matcher(fb_global, k4, v4)
285
+ res_global = res_global.reshape(bs * obj_n, v4.shape[1] * 2, global_match_h, global_match_w)
286
+
287
+ r3_size = r3.shape
288
+ r2_size = r2.shape
289
+ r3 = r3.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r3_size[1:])
290
+ r2 = r2.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r2_size[1:])
291
+
292
+ r1_size = r1.shape
293
+ r1 = r1.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r1_size[1:])
294
+ feature_size = (bs, obj_n, r1_size[2], r1_size[3])
295
+ score = self.decoder(res_global, r3, r2, r1, feature_size)
296
+
297
+ # score = score.view(obj_n, bs, *frame.shape[-2:]).permute(1, 0, 2, 3)
298
+ score = score.view(bs, obj_n, *frame.shape[-2:])
299
+
300
+ if self.training:
301
+ uncertainty = myutils.calc_uncertainty(F.softmax(score, dim=1))
302
+ uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1]) # [B,1,H,W]
303
+ uncertainty = uncertainty.mean()
304
+ else:
305
+ uncertainty = None
306
+
307
+ score = torch.clamp(score, 1e-7, 1 - 1e-7)
308
+ score = torch.log((score / (1 - score)))
309
+
310
+ if not self.training:
311
+ if pad[2] + pad[3] > 0:
312
+ score = score[:, :, pad[2]:-pad[3], :]
313
+ if pad[0] + pad[1] > 0:
314
+ score = score[:, :, :, pad[0]:-pad[1]]
315
+
316
+ return score, uncertainty
317
+
318
+ def forward(self, x):
319
+ pass
video_module/model/FeatureBank.py ADDED
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as NF
4
+
5
+ from torch_scatter import scatter_mean
6
+
7
+
8
+ class FeatureBank:
9
+
10
+ def __init__(self, obj_n, memory_budget, device, update_rate=0.1, thres_close=0.95):
11
+ self.obj_n = obj_n
12
+ self.update_rate = update_rate
13
+ self.thres_close = thres_close
14
+ self.device = device
15
+
16
+ self.info = [None for _ in range(obj_n)]
17
+ self.peak_n = np.zeros(obj_n)
18
+ self.replace_n = np.zeros(obj_n)
19
+
20
+ self.class_budget = memory_budget // obj_n
21
+ if obj_n == 2:
22
+ self.class_budget = 0.8 * self.class_budget
23
+
24
+ self.keys = None
25
+ self.values = None
26
+
27
+ def init_bank(self, keys, values, frame_idx=0):
28
+
29
+ self.keys = keys
30
+ self.values = values
31
+
32
+ for class_idx in range(self.obj_n):
33
+ _, bank_n = keys[class_idx].shape
34
+ self.info[class_idx] = torch.zeros((bank_n, 2), device=self.device)
35
+ self.info[class_idx][:, 0] = frame_idx
36
+ self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])
37
+
38
+ def append(self, keys, values, frame_idx=0):
39
+
40
+ if self.keys:
41
+ for class_idx in range(self.obj_n):
42
+ self.keys[class_idx] = torch.cat([self.keys[class_idx], keys[class_idx]], dim=1)
43
+ self.values[class_idx] = torch.cat([self.values[class_idx], values[class_idx]], dim=1)
44
+
45
+ _, bank_n = keys[class_idx].shape
46
+ new_info = torch.ones((bank_n, 2), device=self.device) * 20 # zeros
47
+ new_info[:, 0] = frame_idx
48
+ self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)
49
+ self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])
50
+ else:
51
+ self.init_bank(keys, values, frame_idx)
52
+
53
+ def update(self, prev_key, prev_value, frame_idx, update_rate=-1):
54
+
55
+ if update_rate == -1:
56
+ update_rate = self.update_rate
57
+
58
+ for class_idx in range(self.obj_n):
59
+
60
+ d_key, bank_n = self.keys[class_idx].shape
61
+ d_val, _ = self.values[class_idx].shape
62
+
63
+ normed_keys = NF.normalize(self.keys[class_idx], dim=0)
64
+ normed_prev_key = NF.normalize(prev_key[class_idx], dim=0)
65
+ mag_keys = self.keys[class_idx].norm(p=2, dim=0)
66
+ corr = torch.mm(normed_keys.transpose(0, 1), normed_prev_key) # bank_n, prev_n
67
+ related_bank_idx = corr.argmax(dim=0, keepdim=True) # 1, HW
68
+ related_bank_corr = torch.gather(corr, 0, related_bank_idx) # 1, HW
69
+
70
+ # greater than threshold, merge them
71
+ selected_idx = (related_bank_corr[0] > self.thres_close).nonzero(as_tuple=False)
72
+ class_related_bank_idx = related_bank_idx[0, selected_idx[:, 0]] # selected_HW
73
+ unique_related_bank_idx, cnt = class_related_bank_idx.unique(dim=0, return_counts=True) # selected_HW
74
+
75
+ # Update key
76
+ key_bank_update = torch.zeros((d_key, bank_n), dtype=torch.float, device=self.device) # d_key, THW
77
+ key_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_key, -1) # d_key, HW
78
+ scatter_mean(normed_prev_key[:, selected_idx[:, 0]], key_bank_idx, dim=1, out=key_bank_update)
79
+ # d_key, selected_HW
80
+
81
+ self.keys[class_idx][:, unique_related_bank_idx] = \
82
+ mag_keys[unique_related_bank_idx] * \
83
+ ((1 - update_rate) * normed_keys[:, unique_related_bank_idx] + \
84
+ update_rate * key_bank_update[:, unique_related_bank_idx])
85
+
86
+ # Update value
87
+ normed_values = NF.normalize(self.values[class_idx], dim=0)
88
+ normed_prev_value = NF.normalize(prev_value[class_idx], dim=0)
89
+ mag_values = self.values[class_idx].norm(p=2, dim=0)
90
+ val_bank_update = torch.zeros((d_val, bank_n), dtype=torch.float, device=self.device)
91
+ val_bank_idx = class_related_bank_idx.unsqueeze(0).expand(d_val, -1)
92
+ scatter_mean(normed_prev_value[:, selected_idx[:, 0]], val_bank_idx, dim=1, out=val_bank_update)
93
+
94
+ self.values[class_idx][:, unique_related_bank_idx] = \
95
+ mag_values[unique_related_bank_idx] * \
96
+ ((1 - update_rate) * normed_values[:, unique_related_bank_idx] + \
97
+ update_rate * val_bank_update[:, unique_related_bank_idx])
98
+
99
+ # less than the threshold, concat them
100
+ selected_idx = (related_bank_corr[0] <= self.thres_close).nonzero(as_tuple=False)
101
+
102
+ if self.class_budget < bank_n + selected_idx.shape[0]:
103
+ self.remove(class_idx, selected_idx.shape[0], frame_idx)
104
+
105
+ self.keys[class_idx] = torch.cat([self.keys[class_idx], prev_key[class_idx][:, selected_idx[:, 0]]], dim=1)
106
+ self.values[class_idx] = \
107
+ torch.cat([self.values[class_idx], prev_value[class_idx][:, selected_idx[:, 0]]], dim=1)
108
+
109
+ new_info = torch.zeros((selected_idx.shape[0], 2), device=self.device)
110
+ new_info[:, 0] = frame_idx
111
+ self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)
112
+
113
+ self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])
114
+
115
+ self.info[class_idx][:, 1] = torch.clamp(self.info[class_idx][:, 1], 0, 1e5) # Prevent inf
116
+
117
+ def remove(self, class_idx, request_n, frame_idx):
118
+
119
+ old_size = self.keys[class_idx].shape[1]
120
+
121
+ LFU = frame_idx - self.info[class_idx][:, 0] # time length
122
+ LFU = self.info[class_idx][:, 1] / LFU
123
+ thres_dynamic = int(LFU.min()) + 1
124
+ iter_cnt = 0
125
+
126
+ while True:
127
+ selected_idx = LFU > thres_dynamic
128
+ self.keys[class_idx] = self.keys[class_idx][:, selected_idx]
129
+ self.values[class_idx] = self.values[class_idx][:, selected_idx]
130
+ self.info[class_idx] = self.info[class_idx][selected_idx]
131
+ LFU = LFU[selected_idx]
132
+ iter_cnt += 1
133
+
134
+ balance = (self.class_budget - self.keys[class_idx].shape[1]) - request_n
135
+ if balance < 0:
136
+ thres_dynamic = int(LFU.min()) + 1
137
+ else:
138
+ break
139
+
140
+ new_size = self.keys[class_idx].shape[1]
141
+ self.replace_n[class_idx] += old_size - new_size
142
+
143
+ return balance
144
+
145
+ def print_peak_mem(self):
146
+
147
+ ur = self.peak_n / self.class_budget
148
+ rr = self.replace_n / self.class_budget
149
+ print(f'Obj num: {self.obj_n}.', f'Budget / obj: {self.class_budget}.', f'UR: {ur}.', f'Replace: {rr}.')
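To make the interplay between AFB_URR and FeatureBank concrete, here is a condensed, hypothetical inference sketch. The frame size, object count, and memory budget are placeholders, and it assumes torch_scatter and a recent torchvision (with ResNet50_Weights) are installed; the real evaluation script is not part of this excerpt:

import torch
from video_module.model import AFB_URR, FeatureBank

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AFB_URR(device, update_bank=True, load_imagenet_params=False).to(device).eval()

first_frame = torch.rand(1, 3, 480, 864, device=device)        # placeholder first frame
first_mask = torch.zeros(1, 2, 480, 864, device=device)        # one-hot: background + 1 object
first_mask[:, 1, 100:200, 200:300] = 1
first_mask[:, 0] = 1 - first_mask[:, 1]

with torch.no_grad():
    # Encode the annotated first frame and seed the feature bank with its keys/values.
    keys, values = model.memorize(first_frame, first_mask)
    fb = FeatureBank(obj_n=2, memory_budget=300000, device=device)
    fb.init_bank(keys, values)

    # Segment a later frame by matching its query features against the bank.
    next_frame = torch.rand(1, 3, 480, 864, device=device)
    score, _ = model.segment(next_frame, fb)                    # (1, 2, 480, 864) logits
    pred = score.argmax(dim=1)                                  # per-pixel object id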
video_module/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .AFB_URR import *
+ from .FeatureBank import FeatureBank