zhiqing0205 committed on
Commit
74acc06
·
1 Parent(s): 4a80644

Add basic Python scripts and documentation

LogSAD技术详解.md ADDED
@@ -0,0 +1,621 @@
+ # LogSAD: A Technical Deep Dive into Training-free Anomaly Detection with Vision and Language Foundation Models
+
+ ## Project Overview
+
+ LogSAD (Towards Training-free Anomaly Detection with Vision and Language Foundation Models) is a training-free anomaly detection method published at CVPR 2025. By combining several pretrained vision and language foundation models, it detects both logical and structural anomalies on the MVTec LOCO dataset.
+
+ ## Overall Architecture and Pipeline
+
+ ### Core Idea
+ LogSAD exploits the strong representations of pretrained models and detects anomalies through multimodal feature fusion and logical reasoning, without any training on the target dataset.
+
+ ### System Architecture
+ ```
+ Input image (448x448)
+         ↓
+ ┌─────────────────────────────────────────────────┐
+ │ Multimodal feature extraction                   │
+ │  ├─ CLIP ViT-L-14 (image + text features)       │
+ │  ├─ DINOv2 ViT-L-14 (image features)            │
+ │  └─ SAM ViT-H (instance segmentation)           │
+ └─────────────────────────────────────────────────┘
+         ↓
+ ┌─────────────────────────────────────────────────┐
+ │ Feature processing and fusion                   │
+ │  ├─ K-means clustering segmentation             │
+ │  ├─ Text-guided semantic segmentation           │
+ │  └─ Multi-scale feature fusion                  │
+ └─────────────────────────────────────────────────┘
+         ↓
+ ┌─────────────────────────────────────────────────┐
+ │ Anomaly detection                               │
+ │  ├─ Structural anomaly detection (PatchCore)    │
+ │  ├─ Logical anomaly detection (histogram match) │
+ │  └─ Instance matching (Hungarian algorithm)     │
+ └─────────────────────────────────────────────────┘
+         ↓
+ Final anomaly score
+ ```
+
+ ## Pretrained Models in Detail
+
+ ### 1. CLIP ViT-L-14
+ **Role**: the core of vision-language understanding
+ - **Model**: `hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K`
+ - **Input size**: 448×448
+ - **Feature extraction layers**: [6, 12, 18, 24]
+ - **Feature dimension**: 1024
+ - **Output feature map**: 32×32 → 64×64 (interpolated)
+
+ **Implementation**:
+ ```python
+ # model_ensemble.py:96-97
+ self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
+ self.feature_list = [6, 12, 18, 24]
+ ```
+
+ **How it cooperates**:
+ - Provides semantic feature representations of the image
+ - Encodes the semantics of different objects through text prompts
+ - Used for semantic segmentation and anomaly classification
+
+ ### 2. DINOv2 ViT-L-14
+ **Role**: provides richer visual features
+ - **Model**: `dinov2_vitl14`
+ - **Feature extraction layers**: [6, 12, 18, 24]
+ - **Feature dimension**: 1024
+ - **Output feature map**: 32×32 → 64×64 (interpolated)
+
+ **Implementation**:
+ ```python
+ # model_ensemble.py:181-186
+ from dinov2.dinov2.hub.backbones import dinov2_vitl14
+ self.model_dinov2 = dinov2_vitl14()
+ self.feature_list_dinov2 = [6, 12, 18, 24]
+ ```
+
+ **How it cooperates**:
+ - Provides stronger visual features for some categories (splicing_connectors, breakfast_box, juice_bottle)
+ - Complements the CLIP features and improves detection accuracy
+
+ ### 3. SAM (Segment Anything Model)
+ **Role**: instance segmentation
+ - **Model**: ViT-H variant
+ - **Checkpoint**: `./checkpoint/sam_vit_h_4b8939.pth`
+ - **Function**: automatically generates object masks
+
+ **Implementation**:
+ ```python
+ # model_ensemble.py:102-103
+ self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth")
+ self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
+ ```
+
+ **How it cooperates**:
+ - Provides precise object boundaries
+ - Used for instance-level anomaly detection
+ - Fused with the semantic segmentation results
+
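+ Since SAM supplies the instances that are later compared against the normal memory, the instance-matching score mentioned in the architecture above boils down to an assignment problem. Below is a minimal, illustrative sketch of Hungarian matching between test and memorized instance features using `scipy.optimize.linear_sum_assignment`; the feature sizes and the exact cost definition are assumptions, not the repository's code:
+
+ ```python
+ import numpy as np
+ from scipy.optimize import linear_sum_assignment
+
+ # One L2-normalized feature vector per segmented instance (sizes are illustrative).
+ test_instances = np.random.rand(5, 256)
+ memory_instances = np.random.rand(6, 256)
+ test_instances /= np.linalg.norm(test_instances, axis=1, keepdims=True)
+ memory_instances /= np.linalg.norm(memory_instances, axis=1, keepdims=True)
+
+ cost = 1.0 - test_instances @ memory_instances.T  # cosine-distance cost matrix
+ rows, cols = linear_sum_assignment(cost)          # optimal one-to-one assignment
+ instance_match_score = cost[rows, cols].max()     # the worst matched pair drives the score
+ print(instance_match_score)
+ ```
+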
+ ## Data Processing and Resizing in Detail
+
+ ### Image Preprocessing Pipeline
+
+ 1. **Input size standardization**:
+ ```python
+ # evaluation.py:184
+ datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
+ ```
+
+ 2. **Normalization**:
+ ```python
+ # model_ensemble.py:88-92
+ self.transform = v2.Compose([
+     v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                  std=(0.26862954, 0.26130258, 0.27577711)),
+ ])
+ ```
+
+ 3. **Feature map resizing**:
+ ```python
+ # model_ensemble.py:155-156
+ self.feat_size = 64      # target feature map size
+ self.ori_feat_size = 32  # original feature map size
+ ```
+
+ ### Detailed Resize Flow
+
+ **CLIP feature processing**:
+ ```python
+ # model_ensemble.py:245-255
+ # 1. Interpolate from 32x32 to 64x64
+ patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
+ patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size),
+                                   mode=self.inter_mode, align_corners=self.align_corners)
+ patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
+ ```
+
+ **DINOv2 feature processing**:
+ ```python
+ # model_ensemble.py:253-263
+ # Same interpolation flow
+ patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size),
+                                     mode=self.inter_mode, align_corners=self.align_corners)
+ ```
+
+ **Interpolation parameters**:
+ - **Mode**: bilinear (`bilinear`)
+ - **Corner alignment**: `align_corners=True`
+ - **Anti-aliasing**: `antialias=True`
+
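+ A minimal, self-contained sketch of this reshape-and-interpolate step is shown below; the tensor contents are random and the shapes simply follow the settings above (4 CLIP layers, token width 1024):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ ori_feat_size, feat_size, width, n_layers = 32, 64, 1024, 4
+ patch_tokens = torch.randn(1, ori_feat_size * ori_feat_size, width * n_layers)
+
+ # (1, 1024, 4096) -> (1, 4096, 32, 32): restore the spatial grid before interpolation.
+ grid = patch_tokens.view(1, ori_feat_size, ori_feat_size, -1).permute(0, 3, 1, 2)
+ # Bilinear upsampling to 64x64, matching inter_mode / align_corners above.
+ grid = F.interpolate(grid, size=(feat_size, feat_size), mode="bilinear", align_corners=True)
+ # Back to a (64*64, 4096) token layout, then L2-normalize each token.
+ tokens_64 = grid.permute(0, 2, 3, 1).reshape(-1, width * n_layers)
+ tokens_64 = F.normalize(tokens_64, p=2, dim=-1)
+ print(tokens_64.shape)  # torch.Size([4096, 4096])
+ ```
+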
+ ## SAM Multi-Mask Processing
+
+ ### Handling the Many Masks Generated by SAM
+
+ **Mask generation**:
+ ```python
+ # model_ensemble.py:394
+ masks = self.mask_generator.generate(raw_image)
+ sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
+ ```
+
+ **Mask fusion strategy**:
+ ```python
+ # model_ensemble.py:347-367
+ def merge_segmentations(a, b, background_class):
+     """Fuse SAM masks with the semantic segmentation result."""
+     # A voting scheme decides the semantic label of each SAM region
+     for label_a in unique_labels_a:
+         mask_a = (a == label_a)
+         labels_b = b[mask_a]
+         if labels_b.size > 0:
+             count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
+             label_map[label_a] = np.argmax(count_b)  # majority vote
+ ```
+
+ **Multi-mask collaboration flow**:
+ 1. SAM generates all candidate instance masks
+ 2. K-means clustering produces a semantic segmentation mask (see the sketch after this list)
+ 3. Text guidance produces a patch-level semantic mask
+ 4. Masks from the different sources are fused by majority voting
+ 5. Small noisy regions are filtered out (threshold: 32 pixels)
+
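+ Step 2 can be sketched with `kmeans_pytorch`, which the repository lists as a dependency; the feature dimensionality and the number of clusters below are illustrative rather than the exact values used for every category:
+
+ ```python
+ import torch
+ from kmeans_pytorch import kmeans
+
+ # Patch features of one image on a 64x64 grid (random stand-in values).
+ feat_size, dim = 64, 2048
+ patch_feats = torch.randn(feat_size * feat_size, dim)
+
+ # Cluster patches into k prototype regions; the labels form a coarse segmentation mask.
+ cluster_ids, cluster_centers = kmeans(
+     X=patch_feats, num_clusters=10, distance='cosine', device=torch.device('cpu')
+ )
+ kmeans_mask = cluster_ids.view(feat_size, feat_size)
+ print(kmeans_mask.shape)  # torch.Size([64, 64])
+ ```
+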
+ ## Ground Truth Multi-Mask Handling
+
+ ### Mask Layout in the MVTec LOCO Dataset
+
+ **File structure**:
+ ```
+ dataset/
+ ├── test/category/image_filename.png          # test image
+ ├── ground_truth/category/image_filename/     # directory of GT masks for that image
+ │   ├── 000.png                               # mask of the first anomalous region
+ │   ├── 001.png                               # mask of the second anomalous region
+ │   ├── 002.png                               # mask of the third anomalous region
+ │   └── ...                                   # further anomalous regions
+ ```
+
+ **Aggregating multiple masks at load time**:
+ ```python
+ # anomalib/data/image/mvtec_loco.py:142-148
+ mask_samples = (
+     mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"]
+     .agg(list)  # aggregate the mask paths of one image into a list
+     .reset_index()
+     .rename(columns={"image_path": "mask_path"})
+ )
+ ```
+
+ ### Multi-Mask Fusion Strategy
+
+ **Step 1: mask path handling**:
+ ```python
+ # anomalib/data/image/mvtec_loco.py:279-280
+ if isinstance(mask_path, str):
+     mask_path = [mask_path]  # make sure mask_path is a list
+ ```
+
+ **Step 2: stacking the semantic masks**:
+ ```python
+ # anomalib/data/image/mvtec_loco.py:281-285
+ semantic_mask = (
+     Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)  # normal image: zero mask
+     if label_index == LabelName.NORMAL
+     else Mask(torch.stack([self._read_mask(path) for path in mask_path]))  # anomalous image: stack all masks
+ )
+ ```
+
+ **Step 3: producing the binary mask**:
+ ```python
+ # anomalib/data/image/mvtec_loco.py:287
+ binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8))
+ ```
+
+ ### How the Fusion Works
+
+ **Shape transformations**:
+ - Input: several masks, each of shape (H, W)
+ - After stacking: (N, H, W), where N is the number of masks
+ - `view(-1, H, W)`: reshape to (N, H, W)
+ - `any(dim=0)`: logical OR along the first dimension, giving (H, W)
+
+ **Fusion logic**:
+ ```python
+ # Illustrative example
+ import torch
+
+ mask1 = torch.tensor([[0, 1, 0],
+                       [1, 0, 1],
+                       [0, 1, 0]], dtype=torch.uint8)
+ mask2 = torch.tensor([[0, 0, 1],
+                       [0, 1, 0],
+                       [1, 0, 0]], dtype=torch.uint8)
+
+ # Stack: shape (2, 3, 3)
+ stacked = torch.stack([mask1, mask2])
+
+ # any(dim=0): pixel-wise OR across masks
+ result = stacked.int().any(dim=0).to(torch.uint8)
+ # result = [[0, 1, 1],
+ #           [1, 1, 1],
+ #           [1, 1, 0]]
+ ```
+
+ ### Complete Data Loading Flow
+
+ **MVTec LOCO data item structure**:
+ ```python
+ # Normal sample
+ item = {
+     "image_path": "/path/to/normal_image.png",
+     "label": 0,
+     "image": torch.Tensor(...),
+     "mask": torch.zeros(H, W),          # zero mask
+     "mask_path": [],                    # empty list
+     "semantic_mask": torch.zeros(H, W)  # zero mask
+ }
+
+ # Anomalous sample
+ item = {
+     "image_path": "/path/to/abnormal_image.png",
+     "label": 1,
+     "image": torch.Tensor(...),
+     "mask": torch.Tensor(...),          # fused binary mask
+     "mask_path": [                      # list of mask paths
+         "/path/to/ground_truth/image/000.png",
+         "/path/to/ground_truth/image/001.png",
+         "/path/to/ground_truth/image/002.png"
+     ],
+     "semantic_mask": torch.Tensor(...)  # original stacked masks, shape (N, H, W)
+ }
+ ```
+
+ ### Mask Usage During Evaluation
+
+ **Key property**: LogSAD does **not** use the ground truth masks at inference time; detection relies solely on the input image. The ground truth masks are only used for:
+
+ 1. **Performance evaluation**: computing AUROC, F1, and related metrics
+ 2. **Visual comparison**: contrasting predictions with the annotations
+ 3. **Metric computation**: pixel-level and semantic-level detection performance
+
+ **Consistency check**:
+ ```python
+ # anomalib/data/image/mvtec_loco.py:158-174
+ # Verify that mask files correspond to their image files
+ image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem)
+ mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply(
+     lambda x: {Path(mask_path).parent.stem for mask_path in x},
+ )
+ # Ensure image '005.png' corresponds to masks '005/000.png', '005/001.png', etc.
+ ```
+
+ ### Multi-Mask Scenarios in Practice
+
+ **Typical cases**:
+ 1. **Splicing Connectors**: connectors, cables, and clamps may be annotated separately
+ 2. **Juice Bottle**: liquid, label, and bottle defects may be annotated separately
+ 3. **Breakfast Box**: different missing food items may be annotated separately
+ 4. **Screw Bag**: anomalies of individual screws, nuts, and washers are annotated separately
+
+ **Advantages of this handling**:
+ - Detailed information about each anomalous region is preserved
+ - Joint evaluation of multiple anomaly types is supported
+ - Fine-grained performance analysis is straightforward
+ - It remains compatible with conventional binary anomaly-detection evaluation
+
+ ## Category-Specific Rules in Detail
+
+ The code contains **five main category-specific branches**, one per dataset category:
+
+ ### 1. Pushpins
+
+ **Location**: `model_ensemble.py:432-479`
+
+ **Logic**:
+ ```python
+ if self.class_name == 'pushpins':
+     # 1. Object counting
+     pushpins_count = num_labels - 1
+     if self.few_shot_inited and pushpins_count != self.pushpins_count:
+         self.anomaly_flag = True
+
+     # 2. Patch histogram matching
+     clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
+     patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
+     score = 1 - patch_hist_similarity.max()
+ ```
+
+ **Anomaly types detected**:
+ - Abnormal pushpin count (expected count: 15)
+ - Abnormal color distribution
+
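+ The patch-histogram matching used above can be illustrated with a small NumPy sketch. It assumes the patch-level label histograms are L2-normalized so that the dot product is a cosine similarity in [0, 1]; the sizes and values are made up for illustration:
+
+ ```python
+ import numpy as np
+
+ n_classes = 4                                                  # number of patch-level query words (illustrative)
+ patch_mask = np.random.randint(0, n_classes, size=(64, 64))    # semantic label per patch
+
+ # Reference histograms collected from the normal (few-shot) images, one per image.
+ ref_hist = np.random.rand(4, n_classes)
+ ref_hist /= np.linalg.norm(ref_hist, axis=-1, keepdims=True)
+
+ test_hist = np.bincount(patch_mask.reshape(-1), minlength=n_classes).astype(np.float64)
+ test_hist /= np.linalg.norm(test_hist)
+
+ hist_score = 1.0 - float((test_hist @ ref_hist.T).max())       # low similarity -> high anomaly score
+ print(hist_score)
+ ```
+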
+ ### 2. Splicing Connectors
+
+ **Location**: `model_ensemble.py:481-615`
+
+ **Logic** (the most involved branch):
+ ```python
+ elif self.class_name == 'splicing_connectors':
+     # 1. Connected-component check
+     if count != 1:
+         self.anomaly_flag = True
+
+     # 2. Cable color vs. clamp count consistency
+     foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
+     ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
+     if ratio > 1.2 or ratio < 0.8:
+         self.anomaly_flag = True
+
+     # 3. Left/right symmetry check
+     ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
+     if ratio > 1.2 or ratio < 0.8:
+         self.anomaly_flag = True
+
+     # 4. Distance check
+     distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
+     ratio = distance / self.splicing_connectors_distance
+     if ratio < 0.6 or ratio > 1.4:
+         self.anomaly_flag = True
+ ```
+
+ **Anomaly types detected**:
+ - Broken or missing cable
+ - Cable color not matching the clamp count (yellow: 2 clamps, blue: 3 clamps, red: 5 clamps)
+ - Left/right clamp asymmetry
+ - Abnormal cable length
+
+ ### 3. Screw Bag
+
+ **Location**: `model_ensemble.py:617-670`
+
+ **Logic**:
+ ```python
+ elif self.class_name == 'screw_bag':
+     # Foreground pixel statistics
+     foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])])
+     ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
+     if ratio < 0.94 or ratio > 1.06:
+         self.anomaly_flag = True
+ ```
+
+ **Anomaly types detected**:
+ - Abnormal counts of screws, nuts, or washers
+ - Abnormal foreground pixel ratio (threshold: ±6%)
+
+ ### 4. Juice Bottle
+
+ **Location**: `model_ensemble.py:715-771`
+
+ **Logic**:
+ ```python
+ elif self.class_name == 'juice_bottle':
+     # Liquid vs. fruit-label consistency
+     liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
+     fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
+     if liquid_idx != fruit_idx:
+         self.anomaly_flag = True
+ ```
+
+ **Anomaly types detected**:
+ - Liquid color not matching the fruit on the label
+ - Misplaced label
+
+ ### 5. Breakfast Box
+
+ **Location**: `model_ensemble.py:672-713`
+
+ **Logic**:
+ ```python
+ elif self.class_name == 'breakfast_box':
+     # Relies mainly on patch histogram matching
+     sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
+     patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
+     score = 1 - patch_hist_similarity.max()
+ ```
+
+ **Anomaly types detected**:
+ - Abnormal food arrangement
+ - Missing or extra items
+
+ ## Few-shot vs. Full-data Modes
+
+ ### Data Handling Differences
+
+ **Few-shot mode** (`model_ensemble_few_shot.py`):
+ ```python
+ # Use all few-shot samples directly
+ FEW_SHOT_SAMPLES = [0, 1, 2, 3]  # fixed set of 4 samples
+ self.k_shot = few_shot_samples.size(0)
+ ```
+
+ **Full-data mode** (`model_ensemble.py`):
+ ```python
+ # Build a coreset from the full training set
+ FEW_SHOT_SAMPLES = range(len(datamodule.train_data))  # all training samples
+ self.k_shot = 4 if self.total_size > 4 else self.total_size
+ ```
+
+ ### Coreset Subsampling
+
+ **Few-shot mode**: no coreset; the raw features are used directly
+ ```python
+ # model_ensemble_few_shot.py:852
+ self.mem_patch_feature_clip_coreset = patch_tokens_clip
+ self.mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
+ ```
+
+ **Full-data mode**: the memory bank is subsampled with the K-Center Greedy algorithm
+ ```python
+ # model_ensemble.py:892-896
+ clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
+ mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()
+
+ dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
+ mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
+ ```
+
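+ For intuition, k-center greedy selection is essentially farthest-point sampling: keep adding the sample that is farthest from everything selected so far. The sketch below is a simplified stand-in for anomalib's `KCenterGreedy` (which typically also applies a random projection to the embeddings first); shapes are illustrative:
+
+ ```python
+ import torch
+
+ def kcenter_greedy(embeddings: torch.Tensor, ratio: float) -> torch.Tensor:
+     """Greedy farthest-point selection of a representative subset."""
+     n_select = max(1, int(len(embeddings) * ratio))
+     selected = [int(torch.randint(len(embeddings), (1,)))]
+     min_dist = torch.cdist(embeddings, embeddings[selected]).squeeze(1)
+     for _ in range(n_select - 1):
+         idx = int(min_dist.argmax())          # farthest point from the current coreset
+         selected.append(idx)
+         new_dist = torch.cdist(embeddings, embeddings[idx:idx + 1]).squeeze(1)
+         min_dist = torch.minimum(min_dist, new_dist)
+     return embeddings[selected]
+
+ memory = torch.randn(4096, 512)               # pooled patch features from normal images
+ coreset = kcenter_greedy(memory, ratio=0.25)
+ print(coreset.shape)                          # torch.Size([1024, 512])
+ ```
+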
+ ### Statistics Differences
+
+ **Few-shot mode**:
+ ```python
+ # model_ensemble_few_shot.py:185
+ self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
+ ```
+
+ **Full-data mode**:
+ ```python
+ # model_ensemble.py:188
+ self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
+ ```
+
+ ### Processing Flow Differences
+
+ **Few-shot mode**:
+ 1. Compute features for the 4 reference samples directly
+ 2. No coreset computation is needed
+ 3. Run anomaly detection immediately
+
+ **Full-data mode**:
+ 1. Compute features for all training samples (`compute_coreset.py`)
+ 2. Select representative features with the K-Center Greedy algorithm
+ 3. Save the coreset to the `memory_bank/` directory
+ 4. Load the precomputed coreset for anomaly detection
+
+ ## Implementation Details and Optimizations
+
+ ### Memory Optimization
+
+ **Batched processing**:
+ ```python
+ # model_ensemble.py:926-928
+ for i in range(self.total_size//self.k_shot):
+     self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)],
+                  few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
+ ```
+
+ **Feature caching**:
+ - Precomputed coreset features are stored in the `memory_bank/` directory
+ - Score statistics are precomputed and cached as well
+
+ ### Multimodal Feature Fusion
+
+ **Feature layer selection**:
+ - **Clustering features**: the first two extracted CLIP feature maps (`cluster_feature_id = [0, 1]`)
+ - **Detection features**: the full features from layers 6, 12, 18, and 24
+
+ **Model choice per category**:
+ ```python
+ # model_ensemble.py:290-310
+ if self.class_name in ['pushpins', 'screw_bag']:
+     # PatchCore-style detection on CLIP features
+     len_feature_list = len(self.feature_list)
+     for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1),
+                                                  mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
+
+ if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']:
+     # PatchCore-style detection on DINOv2 features
+     len_feature_list = len(self.feature_list_dinov2)
+     for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1),
+                                                  mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
+ ```
+
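+ The PatchCore-style structural score behind both branches reduces to a nearest-neighbour lookup of test patches in the memory bank. A minimal sketch follows; it is illustrative only and ignores the per-layer chunking and the k-nearest-neighbour details used in the repository:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def structural_score(patch_feats: torch.Tensor, memory: torch.Tensor) -> float:
+     """Distance of each test patch to its closest memorized normal patch;
+     the image-level score is the maximum over patches."""
+     patch_feats = F.normalize(patch_feats, dim=-1)
+     memory = F.normalize(memory, dim=-1)
+     dist = 1.0 - patch_feats @ memory.T      # cosine distance, (n_patches, n_memory)
+     per_patch = dist.min(dim=1).values       # nearest neighbour per patch
+     return float(per_patch.max())
+
+ test_patches = torch.randn(64 * 64, 1024)    # patch tokens of one layer for one image
+ memory_bank = torch.randn(2048, 1024)        # coreset features from normal images
+ print(structural_score(test_patches, memory_bank))
+ ```
+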
+ ## Text Prompt Engineering
+
+ ### Semantic Query Dictionaries
+
+ **Object-level queries**:
+ ```python
+ # model_ensemble.py:123-136
+ self.query_words_dict = {
+     "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
+     "juice_bottle": ['bottle', ['black background', 'background']],
+     "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
+     "screw_bag": [['screw'], 'plastic bag', 'background'],
+     "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
+ }
+ ```
+
+ **Patch-level queries**:
+ ```python
+ # model_ensemble.py:138-145
+ self.patch_query_words_dict = {
+     "juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
+     "screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
+     # ...
+ }
+ ```
+
+ ### Text Encoding Strategy
+
+ **Multi-template encoding**:
+ ```python
+ # prompt_ensemble.py:98-120
+ def encode_obj_text(model, query_words, tokenizer, device):
+     for qw in query_words:
+         if type(qw) == list:
+             for qw2 in qw:
+                 token_input.extend([temp(qw2) for temp in openai_imagenet_template])
+         else:
+             token_input = [temp(qw) for temp in openai_imagenet_template]
+ ```
+
+ Each query word is augmented with the 82 ImageNet prompt templates, which makes the text features more robust.
+
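+ Below is a sketch of how such a template ensemble is typically collapsed into a single text embedding with OpenCLIP. It reuses `openai_imagenet_template` from this repository; the smaller ViT-B-32 checkpoint and the mean-pooling are illustrative choices, not necessarily the exact pooling used in `prompt_ensemble.py`:
+
+ ```python
+ import torch
+ import open_clip
+ from imagenet_template import openai_imagenet_template
+
+ model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k")
+ tokenizer = open_clip.get_tokenizer("ViT-B-32")
+
+ def encode_with_templates(word: str) -> torch.Tensor:
+     prompts = [template(word) for template in openai_imagenet_template]  # 82 phrasings
+     with torch.no_grad():
+         feats = model.encode_text(tokenizer(prompts))                    # (82, D)
+     feats = feats / feats.norm(dim=-1, keepdim=True)
+     mean_feat = feats.mean(dim=0)                                        # average the ensemble
+     return mean_feat / mean_feat.norm()
+
+ pushpin_text = encode_with_templates("pushpin")
+ print(pushpin_text.shape)
+ ```
+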
+ ## Performance Evaluation
+
+ ### Metrics
+
+ **Image-level metrics**:
+ - F1-Max (image)
+ - AUROC (image)
+
+ **Per anomaly type**:
+ - F1-Max (logical): logical anomalies
+ - AUROC (logical): logical anomalies
+ - F1-Max (structural): structural anomalies
+ - AUROC (structural): structural anomalies
+
+ ### Evaluation Flow
+
+ **Splitting the test data**:
+ ```python
+ # evaluation.py:222-227
+ if 'logical' not in image_path[0]:
+     image_metric_structure.update(output["pred_score"].cpu(), data["label"])
+ if 'structural' not in image_path[0]:
+     image_metric_logical.update(output["pred_score"].cpu(), data["label"])
+ ```
+
+ **Score fusion**:
+ ```python
+ # model_ensemble.py:227-231
+ standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
+ standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
+
+ pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
+ pred_score = sigmoid(pred_score)
+ ```
+
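+ A compact, runnable sketch of this standardize-then-fuse step is given below; the statistics are made up, whereas in the repository they come from the cached validation-score pickle:
+
+ ```python
+ import numpy as np
+
+ def sigmoid(z):
+     return 1.0 / (1.0 + np.exp(-z))
+
+ # Per-category score statistics precomputed on validation images (values made up).
+ stats = {"structural": {"mean": 0.21, "std": 0.05},
+          "instance":   {"mean": 0.35, "std": 0.08}}
+
+ structural_score, instance_score = 0.34, 0.38
+ z_structural = (structural_score - stats["structural"]["mean"]) / stats["structural"]["std"]
+ z_instance = (instance_score - stats["instance"]["mean"]) / stats["instance"]["std"]
+
+ # The final image-level score is the sigmoid of the larger standardized score;
+ # hard rule violations (anomaly_flag) override it with 1.0.
+ pred_score = sigmoid(max(z_structural, z_instance))
+ print(round(pred_score, 4))
+ ```
+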
+ ## Summary
+
+ LogSAD combines the strengths of several pretrained models to detect anomalies without any training:
+
+ 1. **Multimodal collaboration**: CLIP contributes semantic understanding, DINOv2 contributes visual features, and SAM contributes precise segmentation
+ 2. **Logical reasoning**: category-specific rules encoding domain knowledge catch complex logical anomalies
+ 3. **Feature fusion**: multi-scale feature extraction and fusion improve detection accuracy
+ 4. **Efficiency**: coreset subsampling and feature caching keep the method practical
+
+ The method achieves strong results on the MVTec LOCO dataset and demonstrates the potential of pretrained foundation models for anomaly detection.
README.md ADDED
@@ -0,0 +1,102 @@
1
+ # Towards Training-free Anomaly Detection with Vision and Language Foundation Models (CVPR 2025)
2
+
3
+ <div>
4
+ <a href="https://arxiv.org/abs/2503.18325"><img src="https://img.shields.io/static/v1?label=Arxiv&message=LogSAD&color=red&logo=arxiv"></a> &ensp;
5
+ </div>
6
+
7
+ ## System Requirements
8
+
9
+ **Hardware Requirements:**
10
+ - **GPU Memory:** 32GB VRAM (for running complete experiments)
11
+ - **Storage:** 70GB free disk space (for models, datasets, and results)
12
+ - **CUDA:** Compatible GPU with CUDA 12.1 support
13
+
14
+ **Software Requirements:**
15
+ - Python 3.10
16
+ - Conda (recommended for environment management)
17
+ - CUDA 12.1 runtime
18
+
19
+ > **Note:** The memory and storage requirements are for running the full experimental pipeline on all categories with visualization enabled. Smaller experiments on individual categories may require fewer resources.
20
+
21
+ ## Installation
22
+
23
+ ### Automated Setup (Recommended)
24
+
25
+ Run the setup script to automatically configure the complete environment:
26
+
27
+ ```bash
28
+ bash scripts/setup_environment.sh
29
+ ```
30
+
31
+ This script will:
32
+ - Create a conda environment named `logsad` with Python 3.10
33
+ - Install PyTorch with CUDA 12.1 support
34
+ - Install all required dependencies from `requirements.txt`
35
+ - Configure numpy compatibility
36
+
37
+ ### Manual Setup
38
+
39
+ If you prefer manual setup, download the checkpoint for [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and put it in the checkpoint folder.
40
+
41
+ After installation, activate the environment:
42
+ ```bash
43
+ conda activate logsad
44
+ ```
45
+
46
+
47
+ ## Instructions for MVTEC LOCO dataset
48
+
49
+ ### Quick Start (Recommended)
50
+
51
+ Run evaluation for all categories using the provided shell scripts:
52
+
53
+ **Few-shot Protocol:**
54
+ ```bash
55
+ bash scripts/run_few_shot.sh
56
+ ```
57
+
58
+ **Full-data Protocol:**
59
+ ```bash
60
+ bash scripts/run_full_data.sh
61
+ ```
62
+
63
+ ### Manual Execution
64
+
65
+ #### Few-shot Protocol
66
+ Run the script for few-shot protocal:
67
+
68
+ ```
69
+ python evaluation.py --module_path model_ensemble_few_shot --category CATEGORY --dataset_path DATASET_PATH
70
+ ```
71
+
72
+ #### Full-data Protocol
73
+ Run the script to compute coreset for full-data scenarios:
74
+
75
+ ```
76
+ python compute_coreset.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
77
+ ```
78
+
79
+ Run the script for full-data protocol:
80
+
81
+ ```
82
+ python evaluation.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
83
+ ```
84
+
85
+ **Available categories:** breakfast_box, juice_bottle, pushpins, screw_bag, splicing_connectors
86
+
87
+
88
+ ## Acknowledgement
89
+ We are grateful for the following awesome projects when implementing LogSAD:
90
+ * [SAM](https://github.com/facebookresearch/segment-anything), [OpenCLIP](https://github.com/mlfoundations/open_clip), [DINOv2](https://github.com/facebookresearch/dinov2) and [NACLIP](https://github.com/sinahmr/NACLIP).
91
+
92
+
93
+ ## Citation
94
+ If you find our paper helpful in your research or applications, please cite it with
95
+ ```
96
+ @inproceedings{zhang2025logsad,
97
+ title={Towards Training-free Anomaly Detection with Vision and Language Foundation Models},
98
+ author={Jinjin Zhang and Guodong Wang and Yizhou Jin and Di Huang},
99
+ year={2025},
100
+ booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
101
+ }
102
+ ```
compute_coreset.py ADDED
@@ -0,0 +1,121 @@
1
+ """Sample evaluation script for track 2."""
2
+
3
+ import os
4
+
5
+ # Set cache directories to use checkpoint folder for model downloads
6
+ os.environ['TORCH_HOME'] = './checkpoint'
7
+ os.environ['HF_HOME'] = './checkpoint/huggingface'
8
+ os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
9
+ os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
10
+
11
+ # Create checkpoint subdirectories if they don't exist
12
+ os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
13
+ os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
14
+
15
+ import argparse
16
+ import importlib
17
+ import importlib.util
18
+
19
+ import torch
20
+ import logging
21
+ from torch import nn
22
+
23
+ # NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
24
+ # It will be available in v1.1.0 which will be released on April 29th, 2024.
25
+ # If you are using an earlier version of anomalib, you could install anomalib
26
+ # from the anomalib source code from the following branch:
27
+ # https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
28
+ from anomalib.data import MVTecLoco
29
+ from anomalib.metrics.f1_max import F1Max
30
+ from anomalib.metrics.auroc import AUROC
31
+ from tabulate import tabulate
32
+ import numpy as np
33
+
34
+ # FEW_SHOT_SAMPLES = [0, 1, 2, 3]
35
+
36
+ def parse_args() -> argparse.Namespace:
37
+ """Parse command line arguments.
38
+
39
+ Returns:
40
+ argparse.Namespace: Parsed arguments.
41
+ """
42
+ parser = argparse.ArgumentParser()
43
+ parser.add_argument("--module_path", type=str, required=True)
44
+ parser.add_argument("--class_name", default='MyModel', type=str, required=False)
45
+ parser.add_argument("--weights_path", type=str, required=False)
46
+ parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
47
+ parser.add_argument("--category", type=str, required=True)
48
+ parser.add_argument("--viz", action='store_true', default=False)
49
+ return parser.parse_args()
50
+
51
+
52
+ def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
53
+ """Load model.
54
+
55
+ Args:
56
+ module_path (str): Path to the module containing the model class.
57
+ class_name (str): Name of the model class.
58
+ weights_path (str): Path to the model weights.
59
+
60
+ Returns:
61
+ nn.Module: Loaded model.
62
+ """
63
+ # get model class
64
+ model_class = getattr(importlib.import_module(module_path), class_name)
65
+ # instantiate model
66
+ model = model_class()
67
+ # load weights
68
+ if weights_path:
69
+ model.load_state_dict(torch.load(weights_path))
70
+ return model
71
+
72
+
73
+ def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
74
+ """Run the evaluation script.
75
+
76
+ Args:
77
+ module_path (str): Path to the module containing the model class.
78
+ class_name (str): Name of the model class.
79
+ weights_path (str): Path to the model weights.
80
+ dataset_path (str): Path to the dataset.
81
+ category (str): Category of the dataset.
82
+ """
83
+ # Disable verbose logging from all libraries
84
+ logging.getLogger().setLevel(logging.ERROR)
85
+ logging.getLogger('anomalib').setLevel(logging.ERROR)
86
+ logging.getLogger('lightning').setLevel(logging.ERROR)
87
+ logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
88
+
89
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
90
+
91
+ # Instantiate model class here
92
+ # Load the model here from checkpoint.
93
+ model = load_model(module_path, class_name, weights_path)
94
+ model.to(device)
95
+
96
+ # Create the dataset
97
+ datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
98
+ datamodule.setup()
99
+
100
+ model.set_viz(viz)
101
+ model.set_save_coreset_features(True)
102
+
103
+
104
+ FEW_SHOT_SAMPLES = range(len(datamodule.train_data)) # traverse all dataset to build coreset
105
+
106
+ # pass few-shot images and dataset category to model
107
+ setup_data = {
108
+ "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
109
+ "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
110
+ "dataset_category": category,
111
+ }
112
+ model.setup(setup_data)
113
+
114
+ print(f"✓ Coreset computation completed for {category}")
115
+ print(f" Memory bank features saved to memory_bank/ directory")
116
+
117
+
118
+
119
+ if __name__ == "__main__":
120
+ args = parse_args()
121
+ run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
environment.yml ADDED
@@ -0,0 +1,116 @@
1
+ name: logsad
2
+ channels:
3
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
4
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro/
5
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
6
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
7
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
8
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
9
+ - defaults
10
+ dependencies:
11
+ - _libgcc_mutex=0.1=conda_forge
12
+ - _openmp_mutex=4.5=2_gnu
13
+ - bzip2=1.0.8=h4bc722e_7
14
+ - ca-certificates=2025.8.3=hbd8a1cb_0
15
+ - ld_impl_linux-64=2.44=h1423503_1
16
+ - libexpat=2.7.1=hecca717_0
17
+ - libffi=3.4.6=h2dba641_1
18
+ - libgcc=15.1.0=h767d61c_4
19
+ - libgcc-ng=15.1.0=h69a702a_4
20
+ - libgomp=15.1.0=h767d61c_4
21
+ - liblzma=5.8.1=hb9d3cd8_2
22
+ - libnsl=2.0.1=hb9d3cd8_1
23
+ - libsqlite=3.50.4=h0c1763c_0
24
+ - libuuid=2.38.1=h0b41bf4_0
25
+ - libxcrypt=4.4.36=hd590300_1
26
+ - libzlib=1.3.1=hb9d3cd8_2
27
+ - ncurses=6.5=h2d0b736_3
28
+ - openssl=3.5.2=h26f9b46_0
29
+ - pip=25.2=pyh8b19718_0
30
+ - python=3.10.18=hd6af730_0_cpython
31
+ - readline=8.2=h8c095d6_2
32
+ - setuptools=80.9.0=pyhff2d567_0
33
+ - tk=8.6.13=noxft_hd72426e_102
34
+ - wheel=0.45.1=pyhd8ed1ab_1
35
+ - pip:
36
+ - aiohappyeyeballs==2.6.1
37
+ - aiohttp==3.12.11
38
+ - aiosignal==1.3.2
39
+ - antlr4-python3-runtime==4.9.3
40
+ - async-timeout==5.0.1
41
+ - attrs==25.3.0
42
+ - certifi==2025.4.26
43
+ - charset-normalizer==3.4.2
44
+ - contourpy==1.3.2
45
+ - cycler==0.12.1
46
+ - einops==0.6.1
47
+ - faiss-cpu==1.8.0
48
+ - filelock==3.18.0
49
+ - fonttools==4.58.2
50
+ - freia==0.2
51
+ - frozenlist==1.6.2
52
+ - fsspec==2024.12.0
53
+ - ftfy==6.3.1
54
+ - hf-xet==1.1.3
55
+ - huggingface-hub==0.32.4
56
+ - idna==3.10
57
+ - imageio==2.37.0
58
+ - imgaug==0.4.0
59
+ - jinja2==3.1.6
60
+ - joblib==1.5.1
61
+ - jsonargparse==4.29.0
62
+ - kiwisolver==1.4.8
63
+ - kmeans-pytorch==0.3
64
+ - kornia==0.7.0
65
+ - lazy-loader==0.4
66
+ - lightning==2.2.5
67
+ - lightning-utilities==0.14.3
68
+ - markdown-it-py==3.0.0
69
+ - markupsafe==3.0.2
70
+ - matplotlib==3.10.3
71
+ - mdurl==0.1.2
72
+ - mpmath==1.3.0
73
+ - multidict==6.4.4
74
+ - networkx==3.4.2
75
+ - numpy==1.23.1
76
+ - omegaconf==2.3.0
77
+ - open-clip-torch==2.24.0
78
+ - opencv-python==4.8.1.78
79
+ - packaging==24.2
80
+ - pandas==2.0.3
81
+ - pillow==11.2.1
82
+ - propcache==0.3.1
83
+ - protobuf==6.31.1
84
+ - pygments==2.19.1
85
+ - pyparsing==3.2.3
86
+ - python-dateutil==2.9.0.post0
87
+ - pytorch-lightning==2.5.1.post0
88
+ - pytz==2025.2
89
+ - pyyaml==6.0.2
90
+ - regex==2024.11.6
91
+ - requests==2.32.3
92
+ - rich==13.7.1
93
+ - safetensors==0.5.3
94
+ - scikit-image==0.25.2
95
+ - scikit-learn==1.2.2
96
+ - scipy==1.15.3
97
+ - segment-anything==1.0
98
+ - sentencepiece==0.2.0
99
+ - shapely==2.1.1
100
+ - six==1.17.0
101
+ - sympy==1.14.0
102
+ - tabulate==0.9.0
103
+ - threadpoolctl==3.6.0
104
+ - tifffile==2025.5.10
105
+ - timm==1.0.15
106
+ - torch==2.1.2+cu121
107
+ - torchmetrics==1.7.2
108
+ - torchvision==0.16.2+cu121
109
+ - tqdm==4.67.1
110
+ - triton==2.1.0
111
+ - typing-extensions==4.14.0
112
+ - tzdata==2025.2
113
+ - urllib3==2.4.0
114
+ - wcwidth==0.2.13
115
+ - yarl==1.20.0
116
+ prefix: /opt/conda/envs/logsad
evaluation.py ADDED
@@ -0,0 +1,257 @@
1
+ """Sample evaluation script for track 2."""
2
+
3
+ import os
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ # Set cache directories to use checkpoint folder for model downloads
8
+ os.environ['TORCH_HOME'] = './checkpoint'
9
+ os.environ['HF_HOME'] = './checkpoint/huggingface'
10
+ os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
11
+ os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
12
+
13
+ # Create checkpoint subdirectories if they don't exist
14
+ os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
15
+ os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
16
+
17
+ import argparse
18
+ import importlib
19
+ import importlib.util
20
+
21
+ import torch
22
+ import logging
23
+ from torch import nn
24
+
25
+ # NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
26
+ # It will be available in v1.1.0 which will be released on April 29th, 2024.
27
+ # If you are using an earlier version of anomalib, you could install anomalib
28
+ # from the anomalib source code from the following branch:
29
+ # https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
30
+ from anomalib.data import MVTecLoco
31
+ from anomalib.metrics.f1_max import F1Max
32
+ from anomalib.metrics.auroc import AUROC
33
+ from tabulate import tabulate
34
+ import numpy as np
35
+
36
+ FEW_SHOT_SAMPLES = [0, 1, 2, 3]
37
+
38
+ def write_results_to_markdown(category, results_data, module_path):
39
+ """Write evaluation results to markdown file.
40
+
41
+ Args:
42
+ category (str): Dataset category name
43
+ results_data (dict): Dictionary containing all metrics
44
+ module_path (str): Model module path (for protocol identification)
45
+ """
46
+ # Determine protocol type from module path
47
+ protocol = "Few-shot" if "few_shot" in module_path else "Full-data"
48
+
49
+ # Create results directory
50
+ results_dir = Path("results")
51
+ results_dir.mkdir(exist_ok=True)
52
+
53
+ # Combined results file with simple protocol name
54
+ protocol_suffix = "few_shot" if "few_shot" in module_path else "full_data"
55
+ combined_file = results_dir / f"{protocol_suffix}_results.md"
56
+
57
+ # Read existing results if file exists
58
+ existing_results = {}
59
+ if combined_file.exists():
60
+ with open(combined_file, 'r') as f:
61
+ content = f.read()
62
+ # Parse existing results (basic parsing)
63
+ lines = content.split('\n')
64
+ for line in lines:
65
+ if '|' in line and line.count('|') >= 8:
66
+ parts = [p.strip() for p in line.split('|')]
67
+ if len(parts) >= 8 and parts[1] != 'Category' and not parts[1].startswith('-'):
68
+ existing_results[parts[1]] = {
69
+ 'k_shots': parts[2],
70
+ 'f1_image': parts[3],
71
+ 'auroc_image': parts[4],
72
+ 'f1_logical': parts[5],
73
+ 'auroc_logical': parts[6],
74
+ 'f1_structural': parts[7],
75
+ 'auroc_structural': parts[8]
76
+ }
77
+
78
+ # Add current results
79
+ existing_results[category] = {
80
+ 'k_shots': str(len(FEW_SHOT_SAMPLES)),
81
+ 'f1_image': f"{results_data['f1_image']:.2f}",
82
+ 'auroc_image': f"{results_data['auroc_image']:.2f}",
83
+ 'f1_logical': f"{results_data['f1_logical']:.2f}",
84
+ 'auroc_logical': f"{results_data['auroc_logical']:.2f}",
85
+ 'f1_structural': f"{results_data['f1_structural']:.2f}",
86
+ 'auroc_structural': f"{results_data['auroc_structural']:.2f}"
87
+ }
88
+
89
+ # Write combined results
90
+ with open(combined_file, 'w') as f:
91
+ f.write(f"# All Categories - {protocol} Protocol Results\n\n")
92
+ f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
93
+ f.write(f"**Protocol:** {protocol}\n")
94
+ f.write(f"**Available Categories:** {', '.join(sorted(existing_results.keys()))}\n\n")
95
+
96
+ f.write("## Summary Table\n\n")
97
+ f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
98
+ f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")
99
+
100
+ # Sort categories alphabetically
101
+ for cat in sorted(existing_results.keys()):
102
+ data = existing_results[cat]
103
+ f.write(f"| {cat} | {data['k_shots']} | {data['f1_image']} | {data['auroc_image']} | {data['f1_logical']} | {data['auroc_logical']} | {data['f1_structural']} | {data['auroc_structural']} |\n")
104
+
105
+ print(f"\n✓ Results saved to:")
106
+ print(f" - Combined: {combined_file}")
107
+
108
+ def parse_args() -> argparse.Namespace:
109
+ """Parse command line arguments.
110
+
111
+ Returns:
112
+ argparse.Namespace: Parsed arguments.
113
+ """
114
+ parser = argparse.ArgumentParser()
115
+ parser.add_argument("--module_path", type=str, required=True)
116
+ parser.add_argument("--class_name", default='MyModel', type=str, required=False)
117
+ parser.add_argument("--weights_path", type=str, required=False)
118
+ parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
119
+ parser.add_argument("--category", type=str, required=True)
120
+ parser.add_argument("--viz", action='store_true', default=False)
121
+ return parser.parse_args()
122
+
123
+
124
+ def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
125
+ """Load model.
126
+
127
+ Args:
128
+ module_path (str): Path to the module containing the model class.
129
+ class_name (str): Name of the model class.
130
+ weights_path (str): Path to the model weights.
131
+
132
+ Returns:
133
+ nn.Module: Loaded model.
134
+ """
135
+ # get model class
136
+ model_class = getattr(importlib.import_module(module_path), class_name)
137
+ # instantiate model
138
+ model = model_class()
139
+ # load weights
140
+ if weights_path:
141
+ model.load_state_dict(torch.load(weights_path))
142
+ return model
143
+
144
+
145
+ def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
146
+ """Run the evaluation script.
147
+
148
+ Args:
149
+ module_path (str): Path to the module containing the model class.
150
+ class_name (str): Name of the model class.
151
+ weights_path (str): Path to the model weights.
152
+ dataset_path (str): Path to the dataset.
153
+ category (str): Category of the dataset.
154
+ """
155
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
156
+
157
+ # Instantiate model class here
158
+ # Load the model here from checkpoint.
159
+ model = load_model(module_path, class_name, weights_path)
160
+ model.to(device)
161
+
162
+ #
163
+ # Create the dataset
164
+ datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
165
+ datamodule.setup()
166
+
167
+ model.set_viz(viz)
168
+
169
+ #
170
+ # Create the metrics
171
+ image_metric = F1Max()
172
+ pixel_metric = F1Max()
173
+
174
+ image_metric_logical = F1Max()
175
+ image_metric_structure = F1Max()
176
+
177
+ image_metric_auroc = AUROC()
178
+ pixel_metric_auroc = AUROC()
179
+
180
+ image_metric_auroc_logical = AUROC()
181
+ image_metric_auroc_structure = AUROC()
182
+
183
+
184
+ #
185
+ # pass few-shot images and dataset category to model
186
+ setup_data = {
187
+ "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
188
+ "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
189
+ "dataset_category": category,
190
+ }
191
+ model.setup(setup_data)
192
+
193
+ # Loop over the test set and compute the metrics
194
+ for data in datamodule.test_dataloader():
195
+ with torch.no_grad():
196
+ image_path = data['image_path']
197
+ output = model(data["image"].to(device), data['image_path'])
198
+
199
+ image_metric.update(output["pred_score"].cpu(), data["label"])
200
+ image_metric_auroc.update(output["pred_score"].cpu(), data["label"])
201
+
202
+ if 'logical' not in image_path[0]:
203
+ image_metric_structure.update(output["pred_score"].cpu(), data["label"])
204
+ image_metric_auroc_structure.update(output["pred_score"].cpu(), data["label"])
205
+ if 'structural' not in image_path[0]:
206
+ image_metric_logical.update(output["pred_score"].cpu(), data["label"])
207
+ image_metric_auroc_logical.update(output["pred_score"].cpu(), data["label"])
208
+
209
+
210
+
211
+ # Disable verbose logging from all libraries
212
+ logging.getLogger().setLevel(logging.ERROR)
213
+ logging.getLogger('anomalib').setLevel(logging.ERROR)
214
+ logging.getLogger('lightning').setLevel(logging.ERROR)
215
+ logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
216
+
217
+ # Set up our own logger for results only
218
+ logger = logging.getLogger('evaluation')
219
+ logger.handlers.clear()
220
+ logger.setLevel(logging.INFO)
221
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
222
+ console_handler = logging.StreamHandler()
223
+ console_handler.setFormatter(formatter)
224
+ logger.addHandler(console_handler)
225
+
226
+ table_ls = [[category,
227
+ str(len(FEW_SHOT_SAMPLES)),
228
+ str(np.round(image_metric.compute().item() * 100, decimals=2)),
229
+ str(np.round(image_metric_auroc.compute().item() * 100, decimals=2)),
230
+ # str(np.round(pixel_metric.compute().item() * 100, decimals=2)),
231
+ # str(np.round(pixel_metric_auroc.compute().item() * 100, decimals=2)),
232
+ str(np.round(image_metric_logical.compute().item() * 100, decimals=2)),
233
+ str(np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2)),
234
+ str(np.round(image_metric_structure.compute().item() * 100, decimals=2)),
235
+ str(np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)),
236
+ ]]
237
+
238
+ results = tabulate(table_ls, headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)', 'F1-Max (logical)', 'AUROC (logical)', 'F1-Max (structural)', 'AUROC (structural)'], tablefmt="pipe")
239
+
240
+ logger.info("\n%s", results)
241
+
242
+ # Save results to markdown
243
+ results_data = {
244
+ 'f1_image': np.round(image_metric.compute().item() * 100, decimals=2),
245
+ 'auroc_image': np.round(image_metric_auroc.compute().item() * 100, decimals=2),
246
+ 'f1_logical': np.round(image_metric_logical.compute().item() * 100, decimals=2),
247
+ 'auroc_logical': np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2),
248
+ 'f1_structural': np.round(image_metric_structure.compute().item() * 100, decimals=2),
249
+ 'auroc_structural': np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)
250
+ }
251
+ write_results_to_markdown(category, results_data, module_path)
252
+
253
+
254
+
255
+ if __name__ == "__main__":
256
+ args = parse_args()
257
+ run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
imagenet_template.py ADDED
@@ -0,0 +1,82 @@
1
+ openai_imagenet_template = [
2
+ lambda c: f'a bad photo of a {c}.',
3
+ lambda c: f'a photo of many {c}.',
4
+ lambda c: f'a sculpture of a {c}.',
5
+ lambda c: f'a photo of the hard to see {c}.',
6
+ lambda c: f'a low resolution photo of the {c}.',
7
+ lambda c: f'a rendering of a {c}.',
8
+ lambda c: f'graffiti of a {c}.',
9
+ lambda c: f'a bad photo of the {c}.',
10
+ lambda c: f'a cropped photo of the {c}.',
11
+ lambda c: f'a tattoo of a {c}.',
12
+ lambda c: f'the embroidered {c}.',
13
+ lambda c: f'a photo of a hard to see {c}.',
14
+ lambda c: f'a bright photo of a {c}.',
15
+ lambda c: f'a photo of a clean {c}.',
16
+ lambda c: f'a photo of a dirty {c}.',
17
+ lambda c: f'a dark photo of the {c}.',
18
+ lambda c: f'a drawing of a {c}.',
19
+ lambda c: f'a photo of my {c}.',
20
+ lambda c: f'the plastic {c}.',
21
+ lambda c: f'a photo of the cool {c}.',
22
+ lambda c: f'a close-up photo of a {c}.',
23
+ lambda c: f'a black and white photo of the {c}.',
24
+ lambda c: f'a painting of the {c}.',
25
+ lambda c: f'a painting of a {c}.',
26
+ lambda c: f'a pixelated photo of the {c}.',
27
+ lambda c: f'a sculpture of the {c}.',
28
+ lambda c: f'a bright photo of the {c}.',
29
+ lambda c: f'a cropped photo of a {c}.',
30
+ lambda c: f'a plastic {c}.',
31
+ lambda c: f'a photo of the dirty {c}.',
32
+ lambda c: f'a jpeg corrupted photo of a {c}.',
33
+ lambda c: f'a blurry photo of the {c}.',
34
+ lambda c: f'a photo of the {c}.',
35
+ lambda c: f'a good photo of the {c}.',
36
+ lambda c: f'a rendering of the {c}.',
37
+ lambda c: f'a {c} in a video game.',
38
+ lambda c: f'a photo of one {c}.',
39
+ lambda c: f'a doodle of a {c}.',
40
+ lambda c: f'a close-up photo of the {c}.',
41
+ lambda c: f'a photo of a {c}.',
42
+ lambda c: f'the origami {c}.',
43
+ lambda c: f'the {c} in a video game.',
44
+ lambda c: f'a sketch of a {c}.',
45
+ lambda c: f'a doodle of the {c}.',
46
+ lambda c: f'a origami {c}.',
47
+ lambda c: f'a low resolution photo of a {c}.',
48
+ lambda c: f'the toy {c}.',
49
+ lambda c: f'a rendition of the {c}.',
50
+ lambda c: f'a photo of the clean {c}.',
51
+ lambda c: f'a photo of a large {c}.',
52
+ lambda c: f'a rendition of a {c}.',
53
+ lambda c: f'a photo of a nice {c}.',
54
+ lambda c: f'a photo of a weird {c}.',
55
+ lambda c: f'a blurry photo of a {c}.',
56
+ lambda c: f'a cartoon {c}.',
57
+ lambda c: f'art of a {c}.',
58
+ lambda c: f'a sketch of the {c}.',
59
+ lambda c: f'a embroidered {c}.',
60
+ lambda c: f'a pixelated photo of a {c}.',
61
+ lambda c: f'itap of the {c}.',
62
+ lambda c: f'a jpeg corrupted photo of the {c}.',
63
+ lambda c: f'a good photo of a {c}.',
64
+ lambda c: f'a plushie {c}.',
65
+ lambda c: f'a photo of the nice {c}.',
66
+ lambda c: f'a photo of the small {c}.',
67
+ lambda c: f'a photo of the weird {c}.',
68
+ lambda c: f'the cartoon {c}.',
69
+ lambda c: f'art of the {c}.',
70
+ lambda c: f'a drawing of the {c}.',
71
+ lambda c: f'a photo of the large {c}.',
72
+ lambda c: f'a black and white photo of a {c}.',
73
+ lambda c: f'the plushie {c}.',
74
+ lambda c: f'a dark photo of a {c}.',
75
+ lambda c: f'itap of a {c}.',
76
+ lambda c: f'graffiti of the {c}.',
77
+ lambda c: f'a toy {c}.',
78
+ lambda c: f'itap of my {c}.',
79
+ lambda c: f'a photo of a cool {c}.',
80
+ lambda c: f'a photo of a small {c}.',
81
+ lambda c: f'a tattoo of the {c}.',
82
+ ]
model_ensemble.py ADDED
@@ -0,0 +1,1034 @@
1
+ import os
2
+
3
+ # Set cache directories to use checkpoint folder for model downloads
4
+ os.environ['TORCH_HOME'] = './checkpoint'
5
+ os.environ['HF_HOME'] = './checkpoint/huggingface'
6
+ os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
7
+ os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
8
+
9
+ # Create checkpoint subdirectories if they don't exist
10
+ os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
11
+ os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torchvision.transforms import v2
16
+ from torchvision.transforms.v2.functional import resize
17
+ import cv2
18
+ import json
19
+ import torch
20
+ import random
21
+ import logging
22
+ import argparse
23
+ import numpy as np
24
+ from PIL import Image
25
+ from skimage import measure
26
+ from tabulate import tabulate
27
+ from torchvision.ops.focal_loss import sigmoid_focal_loss
28
+ import torch.nn.functional as F
29
+ import torchvision.transforms as transforms
30
+ import torchvision.transforms.functional as TF
31
+ from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
32
+ from sklearn.mixture import GaussianMixture
33
+ import faiss
34
+ import open_clip_local as open_clip
35
+
36
+ from torch.utils.data.dataset import ConcatDataset
37
+ from scipy.optimize import linear_sum_assignment
38
+ from sklearn.random_projection import SparseRandomProjection
39
+ import cv2
40
+ from torchvision.transforms import InterpolationMode
41
+ from PIL import Image
42
+ import string
43
+
44
+ from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
45
+ from kmeans_pytorch import kmeans, kmeans_predict
46
+ from scipy.optimize import linear_sum_assignment
47
+ from scipy.stats import norm
48
+
49
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
50
+ from matplotlib import pyplot as plt
51
+
52
+ import pickle
53
+ from scipy.stats import norm
54
+
55
+ from open_clip_local.pos_embed import get_2d_sincos_pos_embed
56
+
57
+ from anomalib.models.components import KCenterGreedy
58
+
59
+ def to_np_img(m):
60
+ m = m.permute(1, 2, 0).cpu().numpy()
61
+ mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
62
+ std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
63
+ m = m * std + mean
64
+ return np.clip((m * 255.), 0, 255).astype(np.uint8)
65
+
66
+
67
+ def setup_seed(seed):
68
+ torch.manual_seed(seed)
69
+ torch.cuda.manual_seed_all(seed)
70
+ np.random.seed(seed)
71
+ random.seed(seed)
72
+ torch.backends.cudnn.deterministic = True
73
+ torch.backends.cudnn.benchmark = False
74
+
75
+
76
+ class MyModel(nn.Module):
77
+ """Example model class for track 2.
78
+
79
+ This class applies few-shot anomaly detection using the WinClip model from Anomalib.
80
+ """
81
+
82
+ def __init__(self) -> None:
83
+ super().__init__()
84
+
85
+ setup_seed(42)
86
+ # NOTE: Create your transformation pipeline (if needed).
87
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
88
+ self.transform = v2.Compose(
89
+ [
90
+ v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
91
+ ],
92
+ )
93
+
94
+ # NOTE: Create your model.
95
+
96
+ self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
97
+ self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
98
+ self.feature_list = [6, 12, 18, 24]
99
+ self.embed_dim = 768
100
+ self.vision_width = 1024
101
+
102
+ self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
103
+ self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
104
+
105
+ self.memory_size = 2048
106
+ self.n_neighbors = 2
107
+
108
+ self.model_clip.eval()
109
+ self.test_args = None
110
+ self.align_corners = True # False
111
+ self.antialias = True # False
112
+ self.inter_mode = 'bilinear' # bilinear/bicubic
113
+
114
+ self.cluster_feature_id = [0, 1]
115
+
116
+ self.cluster_num_dict = {
117
+ "breakfast_box": 3, # unused
118
+ "juice_bottle": 8, # unused
119
+ "splicing_connectors": 10, # unused
120
+ "pushpins": 10,
121
+ "screw_bag": 10,
122
+ }
123
+ self.query_words_dict = {
124
+ "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
125
+ "juice_bottle": ['bottle', ['black background', 'background']],
126
+ "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
127
+ "screw_bag": [['screw'], 'plastic bag', 'background'],
128
+ "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
129
+ }
130
+ self.foreground_label_idx = { # for query_words_dict
131
+ "breakfast_box": [0, 1, 2, 3, 4, 5],
132
+ "juice_bottle": [0],
133
+ "pushpins": [0],
134
+ "screw_bag": [0],
135
+ "splicing_connectors":[0, 1]
136
+ }
137
+
138
+ self.patch_query_words_dict = {
139
+ "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
140
+ "juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
141
+ "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
142
+ "screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']], # 79.71
143
+ "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
144
+ }
145
+
146
+
147
+ self.query_threshold_dict = {
148
+ "breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
149
+ "juice_bottle": [0., 0., 0.], # unused
150
+ "splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
151
+ "pushpins": [0.2, 0., 0., 0.],
152
+ "screw_bag": [0., 0., 0.,],
153
+ }
154
+
155
+ self.feat_size = 64
156
+ self.ori_feat_size = 32
157
+
158
+ self.visualization = False #False # True #False
159
+
160
+ self.pushpins_count = 15
161
+
162
+ self.splicing_connectors_count = [2, 3, 5] # corresponding to yellow, blue, and red
163
+ self.splicing_connectors_distance = 0
164
+ self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
165
+
166
+ self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
167
+ self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
168
+
169
+ # query words
170
+ self.foreground_pixel_hist = 0
171
+ self.foreground_pixel_hist_screw_bag = 366.0 # 4-shot statistics
172
+ self.foreground_pixel_hist_splicing_connectors = 4249.666666666667 # 4-shot statistics
173
+ # patch query words
174
+ self.patch_token_hist = []
175
+
176
+ self.few_shot_inited = False
177
+
178
+ self.save_coreset_features = False
179
+
180
+
181
+ from dinov2.dinov2.hub.backbones import dinov2_vitl14
182
+ self.model_dinov2 = dinov2_vitl14()
183
+ self.model_dinov2.to(self.device)
184
+ self.model_dinov2.eval()
185
+ self.feature_list_dinov2 = [6, 12, 18, 24]
186
+ self.vision_width_dinov2 = 1024
187
+
188
+ self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
189
+
190
+ self.mem_instance_masks = None
191
+
192
+ self.anomaly_flag = False
193
+ self.validation = False
194
+
195
+ def set_save_coreset_features(self, save_coreset_features):
196
+ self.save_coreset_features = save_coreset_features
197
+
198
+ def set_viz(self, viz):
199
+ self.visualization = viz
200
+
201
+ def set_val(self, val):
202
+ self.validation = val
203
+
204
+ def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
205
+ """Transform the input batch and pass it through the model.
206
+
207
+ This model returns a dictionary with the following keys
208
+ - ``anomaly_map`` - Anomaly map.
209
+ - ``pred_score`` - Predicted anomaly score.
210
+ """
211
+ self.anomaly_flag = False
212
+ batch = self.transform(batch).to(self.device)
213
+ results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])
214
+
215
+ hist_score = results['hist_score']
216
+ structural_score = results['structural_score']
217
+ instance_hungarian_match_score = results['instance_hungarian_match_score']
218
+
219
+
220
+ if self.validation:
221
+ return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
222
+
223
+ def sigmoid(z):
224
+ return 1/(1 + np.exp(-z))
225
+
226
+ # standardization
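+ # z-score each raw score with the mean / unbiased std precomputed offline (self.stats,
+ # loaded from memory_bank/statistic_scores_model_ensemble_val.pkl) so the structural and
+ # instance-matching scores are comparable; the larger z-score is squashed to (0, 1) by a
+ # sigmoid, and a triggered logical rule (self.anomaly_flag) overrides the result to 1 below.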
227
+ standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
228
+ standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
229
+
230
+ pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
231
+ pred_score = sigmoid(pred_score)
232
+
233
+
234
+ if self.anomaly_flag:
235
+ pred_score = 1.
236
+ self.anomaly_flag = False
237
+
238
+
239
+ return {"pred_score": torch.tensor(pred_score), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
240
+
241
+
242
+ def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
243
+
244
+ with torch.no_grad():
245
+ image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
246
+ # image_features /= image_features.norm(dim=-1, keepdim=True)
247
+ patch_tokens = [p[:, 1:, :] for p in patch_tokens]
248
+ patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
249
+
250
+ patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
251
+ # patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
252
+ patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
253
+ patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
254
+ patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
255
+ patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
256
+
257
+ with torch.no_grad():
258
+ patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
259
+ patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
260
+ patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
261
+ patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
262
+ patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
263
+ patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
264
+
265
+ '''features for k-means segmentation'''
266
+ if self.feat_size != self.ori_feat_size:
267
+ proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
268
+ proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
269
+ proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
270
+ proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
271
+
272
+ mid_features = None
273
+ for layer in self.cluster_feature_id:
274
+ temp_feat = patch_tokens[layer]
275
+ mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
276
+
277
+ if self.feat_size != self.ori_feat_size:
278
+ mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
279
+ mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
280
+ mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
281
+ mid_features = F.normalize(mid_features, p=2, dim=-1)
282
+
283
+ results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
284
+
285
+ hist_score = results['score']
286
+
287
+ '''calculate patchcore'''
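+ # PatchCore-style structural score: for each feature stage, every patch token is compared
+ # (cosine similarity) against the few-shot coreset memory bank; the per-patch anomaly value
+ # is 1 - max similarity, and the image-level score is the max of the stage-averaged map.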
288
+ anomaly_maps_patchcore = []
289
+
290
+ if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
291
+ len_feature_list = len(self.feature_list)
292
+ for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
293
+ patch_feature = F.normalize(patch_feature, dim=-1)
294
+ mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
295
+ normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
296
+ normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
297
+ anomaly_map_patchcore = 1 - normal_map_patchcore
298
+
299
+ anomaly_maps_patchcore.append(anomaly_map_patchcore)
300
+
301
+ if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
302
+ len_feature_list = len(self.feature_list_dinov2)
303
+ for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
304
+ patch_feature = F.normalize(patch_feature, dim=-1)
305
+ mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
306
+ normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
307
+ normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
308
+ anomaly_map_patchcore = 1 - normal_map_patchcore
309
+
310
+ anomaly_maps_patchcore.append(anomaly_map_patchcore)
311
+
312
+ structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
313
+ # anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
314
+
315
+ instance_masks = results["instance_masks"]
316
+ anomaly_instances_hungarian = []
317
+ instance_hungarian_match_score = 1.
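+ # Instance matching: average-pool the patch features inside each detected instance mask,
+ # build a cosine-distance cost matrix against the few-shot instance memory, and solve the
+ # optimal assignment with scipy's linear_sum_assignment (Hungarian algorithm); the matching
+ # cost, normalized by the smaller side of the cost matrix, is averaged over feature stages.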
318
+ if self.mem_instance_masks is not None and len(instance_masks) != 0:
319
+ for patch_feature, mem_instance_features_single_stage in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), self.mem_instance_features_multi_stage.chunk(len_feature_list, dim=1)):
320
+ instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
321
+ instance_features = torch.cat(instance_features, dim=0)
322
+ instance_features = F.normalize(instance_features, dim=-1)
323
+
324
+ normal_instance_hungarian = (instance_features @ mem_instance_features_single_stage.T)
325
+ cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
326
+
327
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
328
+ cost = cost_matrix[row_ind, col_ind].sum()
329
+ cost = cost / min(cost_matrix.shape)
330
+ anomaly_instances_hungarian.append(cost)
331
+
332
+ instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
333
+
334
+ results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score}
335
+
336
+ return results
337
+
338
+
339
+ def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
340
+ def plot_results_only(sorted_anns):
341
+ cur = 1
342
+ img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
343
+ for ann in sorted_anns:
344
+ m = ann['segmentation']
345
+ img_color[m] = cur
346
+ cur += 1
347
+ return img_color
348
+
349
+ def merge_segmentations(a, b, background_class):
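+ # Majority vote: relabel every SAM segment in `a` with the most frequent class it overlaps
+ # in the k-means/text mask `b`; segments with no overlap fall back to `background_class`.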
350
+ unique_labels_a = np.unique(a)
351
+ unique_labels_b = np.unique(b)
352
+
353
+ max_label_a = unique_labels_a.max()
354
+ label_map = np.zeros(max_label_a + 1, dtype=int)
355
+
356
+ for label_a in unique_labels_a:
357
+ mask_a = (a == label_a)
358
+
359
+ labels_b = b[mask_a]
360
+ if labels_b.size > 0:
361
+ count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
362
+ label_map[label_a] = np.argmax(count_b)
363
+ else:
364
+ label_map[label_a] = background_class # default background
365
+
366
+ merged_a = label_map[a]
367
+ return merged_a
368
+
369
+ pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
370
+ kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
371
+
372
+ for pl in pseudo_labels.unique():
373
+ mask = (pseudo_labels == pl).reshape(-1)
374
+ # filter small region
375
+ binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
376
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
377
+ for i in range(1, num_labels):
378
+ temp_mask = labels == i
379
+ if np.sum(temp_mask) <= 8:
380
+ mask[temp_mask.reshape(-1)] = False
381
+
382
+ if mask.any():
383
+ region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
384
+ similarity = (region_feature @ self.query_obj.T)
385
+ prob, index = torch.max(similarity, dim=-1)
386
+ temp_label = index.squeeze(0).item()
387
+ temp_prob = prob.squeeze(0).item()
388
+ if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
389
+ kmeans_mask[mask] = temp_label
390
+
391
+
392
+ raw_image = to_np_img(image[0])
393
+ height, width = raw_image.shape[:2]
394
+ masks = self.mask_generator.generate(raw_image)
395
+ # self.predictor.set_image(raw_image)
396
+
397
+ kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
398
+ kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
399
+
400
+ patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
401
+ patch_mask = patch_similarity.argmax(-1)
402
+ patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
403
+
404
+ sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
405
+ sam_mask = plot_results_only(sorted_masks).astype(int)
406
+
407
+ resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
408
+ merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
409
+
410
+ resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
411
+ patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
412
+
413
+ # filter small region for merge sam
414
+ binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
415
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
416
+ for i in range(1, num_labels):
417
+ temp_mask = labels == i
418
+ if np.sum(temp_mask) <= 32: # 448x448
419
+ merge_sam[temp_mask] = self.classes - 1 # set to background
420
+
421
+ # filter small region for patch merge sam
422
+ binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
423
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
424
+ for i in range(1, num_labels):
425
+ temp_mask = labels == i
426
+ if np.sum(temp_mask) <= 32: # 448x448
427
+ patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
428
+
429
+ score = 0. # default to normal
430
+ self.anomaly_flag = False
431
+ instance_masks = []
432
+ if self.class_name == 'pushpins':
433
+ # object count hist
434
+ kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
435
+ binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
436
+ dilate_binary = cv2.dilate(binary, kernel)
437
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
438
+ pushpins_count = num_labels - 1 # number of pushpins
439
+
440
+ for i in range(1, num_labels):
441
+ instance_mask = (labels == i).astype(np.uint8)
442
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
443
+ if instance_mask.any():
444
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
445
+
446
+ if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
447
+ self.anomaly_flag = True
448
+ print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
449
+
450
+ # patch hist
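+ # L2-normalized class histogram of the patch-level text-segmentation labels; the logical
+ # score is 1 minus its maximum cosine similarity to the few-shot reference histograms.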
451
+ clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
452
+ clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
453
+
454
+ if self.few_shot_inited:
455
+ patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
456
+ score = 1 - patch_hist_similarity.max()
457
+
458
+ binary_foreground = dilate_binary.astype(np.uint8)
459
+
460
+ if len(instance_masks) != 0:
461
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
462
+
463
+ if self.visualization:
464
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
465
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
466
+ plt.figure(figsize=(20, 3))
467
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
468
+ plt.subplot(1, len(image_list), ind)
469
+ plt.imshow(temp_image)
470
+ plt.title(temp_title)
471
+ plt.margins(0, 0)
472
+ plt.axis('off')
473
+ # Extract relative path from class_name onwards
474
+ if class_name in path:
475
+ relative_path = path.split(class_name, 1)[-1]
476
+ if relative_path.startswith('/'):
477
+ relative_path = relative_path[1:]
478
+ save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
479
+ else:
480
+ save_path = f'visualization/full_data/{class_name}/{path}.png'
481
+
482
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
483
+ plt.tight_layout()
484
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
485
+ plt.close()
486
+
487
+
488
+ # todo: same number in total but in different boxes or broken box
489
+ return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
490
+
491
+ elif self.class_name == 'splicing_connectors':
492
+ # object count hist for default
493
+ sam_mask_max_area = sorted_masks[0]['segmentation'] # background
494
+ binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
495
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
496
+ count = 0
497
+ for i in range(1, num_labels):
498
+ temp_mask = labels == i
499
+ if np.sum(temp_mask) <= 64: # 448x448 64
500
+ binary[temp_mask] = 0 # set to background
501
+ else:
502
+ count += 1
503
+ if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
504
+ print('number of connected components in splicing_connectors: {}, but the default is 1.'.format(count))
505
+ self.anomaly_flag = True
506
+
507
+ merge_sam[~(binary.astype(bool))] = self.query_obj.shape[0] - 1 # remove noise
508
+ patch_merge_sam[~(binary.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
509
+
510
+ # erode the cable and divide into left and right parts
511
+ kernel = np.ones((23, 23), dtype=np.uint8)
512
+ erode_binary = cv2.erode(binary, kernel)
513
+ h, w = erode_binary.shape
514
+ distance = 0
515
+
516
+ left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
517
+ left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
518
+ right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
519
+
520
+ # binary_cable = (merge_sam == 1).astype(np.uint8)
521
+ binary_cable = (patch_merge_sam == 1).astype(np.uint8)
522
+
523
+ kernel = np.ones((5, 5), dtype=np.uint8)
524
+ binary_cable = cv2.erode(binary_cable, kernel)
525
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
526
+ for i in range(1, num_labels):
527
+ temp_mask = labels == i
528
+ if np.sum(temp_mask) <= 64: # 448x448
529
+ binary_cable[temp_mask] = 0 # set to background
530
+
531
+
532
+ binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
533
+
534
+ binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
535
+
536
+ kernel = np.ones((5, 5), dtype=np.uint8)
537
+ binary_clamps = cv2.erode(binary_clamps, kernel)
538
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
539
+ for i in range(1, num_labels):
540
+ temp_mask = labels == i
541
+ if np.sum(temp_mask) <= 64: # 448x448
542
+ binary_clamps[temp_mask] = 0 # set to background
543
+ else:
544
+ instance_mask = temp_mask.astype(np.uint8)
545
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
546
+ if instance_mask.any():
547
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
548
+
549
+ binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
550
+
551
+ binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
552
+
553
+ query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
554
+ cable_feature = proj_patch_token[binary_cable.astype(bool).reshape(-1), :].mean(0, keepdim=True)
555
+ idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
556
+ foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
557
+
558
+
559
+ slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
560
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
561
+ cable_count = num_labels - 1
562
+ if cable_count != 1 and self.anomaly_flag is False: # two cables
563
+ print('cable count in splicing_connectors: {}, but the default cable count is 1.'.format(cable_count))
564
+ self.anomaly_flag = True
565
+
566
+ # {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
567
+ if self.few_shot_inited and self.foreground_pixel_hist_splicing_connectors != 0 and self.anomaly_flag is False:
568
+ ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
569
+ if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
570
+ print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist_splicing_connectors))
571
+ self.anomaly_flag = True
572
+
573
+ # left right hist for symmetry
574
+ ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
575
+ if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
576
+ print('left and right connectors are not symmetric.')
577
+ self.anomaly_flag = True
578
+
579
+ # left and right centroids distance
580
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
581
+ if num_labels - 1 == 2:
582
+ centroids = centroids[1:]
583
+ x1, y1 = centroids[0]
584
+ x2, y2 = centroids[1]
585
+ distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
586
+ if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
587
+ ratio = distance / self.splicing_connectors_distance
588
+ if ratio < 0.6 or ratio > 1.4: # too short or too long centroids distance (cable) # 0.6 1.4
589
+ print('cable is too short or too long.')
590
+ self.anomaly_flag = True
591
+
592
+ # patch hist
593
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
594
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
595
+
596
+ if self.few_shot_inited:
597
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
598
+ score = 1 - patch_hist_similarity.max()
599
+
600
+ # todo mismatch cable link
601
+ binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally separate cable and clamps
602
+ if binary_connector.any():
603
+ instance_masks.append(binary_connector.astype(bool).reshape(-1))
604
+ if binary_clamps.any():
605
+ instance_masks.append(binary_clamps.astype(bool).reshape(-1))
606
+ if binary_cable.any():
607
+ instance_masks.append(binary_cable.astype(bool).reshape(-1))
608
+
609
+ if len(instance_masks) != 0:
610
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
611
+
612
+ if self.visualization:
613
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
614
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
615
+ plt.figure(figsize=(25, 3))
616
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
617
+ plt.subplot(1, len(image_list), ind)
618
+ plt.imshow(temp_image)
619
+ plt.title(temp_title)
620
+ plt.margins(0, 0)
621
+ plt.axis('off')
622
+ # Extract relative path from class_name onwards
623
+ if class_name in path:
624
+ relative_path = path.split(class_name, 1)[-1]
625
+ if relative_path.startswith('/'):
626
+ relative_path = relative_path[1:]
627
+ save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
628
+ else:
629
+ save_path = f'visualization/full_data/{class_name}/{path}.png'
630
+
631
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
632
+ plt.tight_layout()
633
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
634
+ plt.close()
635
+
636
+ return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
637
+
638
+ elif self.class_name == 'screw_bag':
639
+ # pixel hist of kmeans mask
640
+ foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
641
+ if self.few_shot_inited and self.foreground_pixel_hist_screw_bag != 0 and self.anomaly_flag is False:
642
+ ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
643
+ # todo: optimize
644
+ if ratio < 0.94 or ratio > 1.06: # 82.95 | 81.3
645
+ print('foreground pixel histogram of screw bag: {}, the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist_screw_bag))
646
+ self.anomaly_flag = True
647
+
648
+ # patch hist
649
+ binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
650
+ patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
651
+ resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
652
+ patch_merge_sam[~(resized_binary_screw.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
653
+
654
+ clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
655
+ clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
656
+
657
+ if self.few_shot_inited:
658
+ patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
659
+ score = 1 - patch_hist_similarity.max()
660
+
661
+ for i in range(self.patch_query_obj.shape[0]-1):
662
+ binary_foreground = (patch_merge_sam == i).astype(np.uint8)
663
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
664
+ for i in range(1, num_labels):
665
+ instance_mask = (labels == i).astype(np.uint8)
666
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
667
+ if instance_mask.any():
668
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
669
+
670
+ if len(instance_masks) != 0:
671
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
672
+
673
+ if self.visualization:
674
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
675
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
676
+ plt.figure(figsize=(20, 3))
677
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
678
+ plt.subplot(1, len(image_list), ind)
679
+ plt.imshow(temp_image)
680
+ plt.title(temp_title)
681
+ plt.margins(0, 0)
682
+ plt.axis('off')
683
+ # Extract relative path from class_name onwards
684
+ if class_name in path:
685
+ relative_path = path.split(class_name, 1)[-1]
686
+ if relative_path.startswith('/'):
687
+ relative_path = relative_path[1:]
688
+ save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
689
+ else:
690
+ save_path = f'visualization/full_data/{class_name}/{path}.png'
691
+
692
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
693
+ plt.tight_layout()
694
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
695
+ plt.close()
696
+
697
+ # plt.axis('off')
698
+ # plt.imshow(patch_merge_sam)
699
+
700
+ # plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
701
+ # plt.close()
702
+
703
+
704
+ return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
705
+
706
+ elif self.class_name == 'breakfast_box':
707
+ # patch hist
708
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
709
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
710
+
711
+ if self.few_shot_inited:
712
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
713
+ score = 1 - patch_hist_similarity.max()
714
+
715
+ # todo: exist of foreground
716
+
717
+ binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
718
+
719
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
720
+ for i in range(1, num_labels):
721
+ instance_mask = (labels == i).astype(np.uint8)
722
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
723
+ if instance_mask.any():
724
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
725
+
726
+
727
+ if len(instance_masks) != 0:
728
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
729
+
730
+ if self.visualization:
731
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
732
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
733
+ plt.figure(figsize=(20, 3))
734
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
735
+ plt.subplot(1, len(image_list), ind)
736
+ plt.imshow(temp_image)
737
+ plt.title(temp_title)
738
+ plt.margins(0, 0)
739
+ plt.axis('off')
740
+ # Extract relative path from class_name onwards
741
+ if class_name in path:
742
+ relative_path = path.split(class_name, 1)[-1]
743
+ if relative_path.startswith('/'):
744
+ relative_path = relative_path[1:]
745
+ save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
746
+ else:
747
+ save_path = f'visualization/full_data/{class_name}/{path}.png'
748
+
749
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
750
+ plt.tight_layout()
751
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
752
+ plt.close()
753
+
754
+ # plt.axis('off')
755
+ # plt.imshow(patch_merge_sam)
756
+
757
+ # plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
758
+ # plt.close()
759
+
760
+ return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
761
+
762
+ elif self.class_name == 'juice_bottle':
763
+ # remove noise due to non sam mask
764
+ merge_sam[sam_mask == 0] = self.classes - 1
765
+ patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1 # 79.5
766
+
767
+ # [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
768
+ # fruit and liquid mismatch (todo if exist)
769
+ resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
770
+ binary_liquid = (resized_patch_merge_sam == 1)
771
+ binary_fruit = (resized_patch_merge_sam == 2)
772
+
773
+ query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
774
+ query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)
775
+
776
+ liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
777
+ liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
778
+
779
+ fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
780
+ fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
781
+
782
+ if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
783
+ print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
784
+ self.anomaly_flag = True
785
+
786
+ # # todo centroid of fruit and tag_0 mismatch (if exist) , only one tag, center
787
+
788
+ # patch hist
789
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
790
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
791
+
792
+ if self.few_shot_inited:
793
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
794
+ score = 1 - patch_hist_similarity.max()
795
+
796
+ binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1) ).astype(np.uint8)
797
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
798
+ for i in range(1, num_labels):
799
+ instance_mask = (labels == i).astype(np.uint8)
800
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
801
+ if instance_mask.any():
802
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
803
+
804
+ if len(instance_masks) != 0:
805
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
806
+
807
+ if self.visualization:
808
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
809
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
810
+ plt.figure(figsize=(20, 3))
811
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
812
+ plt.subplot(1, len(image_list), ind)
813
+ plt.imshow(temp_image)
814
+ plt.title(temp_title)
815
+ plt.margins(0, 0)
816
+ plt.axis('off')
817
+ # Extract relative path from class_name onwards
818
+ if class_name in path:
819
+ relative_path = path.split(class_name, 1)[-1]
820
+ if relative_path.startswith('/'):
821
+ relative_path = relative_path[1:]
822
+ save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
823
+ else:
824
+ save_path = f'visualization/full_data/{class_name}/{path}.png'
825
+
826
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
827
+ plt.tight_layout()
828
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
829
+ plt.close()
830
+
831
+ return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
832
+
833
+ return {"score": score, "instance_masks": instance_masks}
834
+
835
+
836
+ def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
837
+ few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)
838
+
839
+ with torch.no_grad():
840
+ image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
841
+ patch_tokens = [p[:, 1:, :] for p in patch_tokens]
842
+ patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
843
+
844
+ patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (bs, 1024, 1024x4)
845
+ # patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (bs, 1024, 1024x2)
846
+ patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
847
+ patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
848
+ patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
849
+ patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (bsx64x64, 1024x4)
850
+
851
+ with torch.no_grad():
852
+ patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2) # 4 x [bs, 32x32, 1024]
853
+ patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (bs, 1024, 1024x4)
854
+ patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
855
+ patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
856
+ patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
857
+ patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (bsx64x64, 1024x4)
858
+
859
+
860
+ cluster_features = None
861
+ for layer in self.cluster_feature_id:
862
+ temp_feat = patch_tokens[layer]
863
+ cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
864
+ if self.feat_size != self.ori_feat_size:
865
+ cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
866
+ cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
867
+ cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
868
+ cluster_features = F.normalize(cluster_features, p=2, dim=-1)
869
+
870
+ if self.feat_size != self.ori_feat_size:
871
+ proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
872
+ proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
873
+ proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
874
+ proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
875
+
876
+ if not self.cluster_init:
877
+ num_clusters = self.cluster_num_dict[class_name]
878
+ _, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)
879
+
880
+ self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
881
+ self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
882
+ self.classes = self.query_obj.shape[0]
883
+
884
+ self.cluster_init = True
885
+
886
+ scores = []
887
+ foreground_pixel_hist = []
888
+ splicing_connectors_distance = []
889
+ patch_token_hist = []
890
+ mem_instance_masks = []
891
+
892
+ for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
893
+ # path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
894
+ self.anomaly_flag = False
895
+ results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
896
+ if self.class_name == 'pushpins':
897
+ patch_token_hist.append(results["clip_patch_hist"])
898
+ mem_instance_masks.append(results['instance_masks'])
899
+
900
+ elif self.class_name == 'splicing_connectors':
901
+ foreground_pixel_hist.append(results["foreground_pixel_count"])
902
+ splicing_connectors_distance.append(results["distance"])
903
+ patch_token_hist.append(results["sam_patch_hist"])
904
+ mem_instance_masks.append(results['instance_masks'])
905
+
906
+ elif self.class_name == 'screw_bag':
907
+ foreground_pixel_hist.append(results["foreground_pixel_count"])
908
+ patch_token_hist.append(results["clip_patch_hist"])
909
+ mem_instance_masks.append(results['instance_masks'])
910
+
911
+ elif self.class_name == 'breakfast_box':
912
+ patch_token_hist.append(results["sam_patch_hist"])
913
+ mem_instance_masks.append(results['instance_masks'])
914
+
915
+ elif self.class_name == 'juice_bottle':
916
+ patch_token_hist.append(results["sam_patch_hist"])
917
+ mem_instance_masks.append(results['instance_masks'])
918
+
919
+ scores.append(results["score"])
920
+
921
+ if len(foreground_pixel_hist) != 0:
922
+ self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
923
+ if len(splicing_connectors_distance) != 0:
924
+ self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
925
+ if len(patch_token_hist) != 0: # patch hist
926
+ self.patch_token_hist = np.stack(patch_token_hist)
927
+ if len(mem_instance_masks) != 0:
928
+ self.mem_instance_masks = mem_instance_masks
929
+
930
+ # for interests matching
931
+ len_feature_list = len(self.feature_list)
932
+ for idx, batch_mem_patch_feature in enumerate(patch_tokens_clip.chunk(len_feature_list, dim=-1)): # 4 stages batch_mem_patch_feature (bsx64x64, 1024)
933
+ mem_instance_features = []
934
+ for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks): # k shot mem_patch_feature (64x64, 1024)
935
+ mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
936
+ mem_instance_features = torch.cat(mem_instance_features, dim=0)
937
+ mem_instance_features = F.normalize(mem_instance_features, dim=-1) # 4 stages
938
+ # mem_instance_features_multi_stage.append(mem_instance_features)
939
+ self.mem_instance_features_multi_stage[idx].append(mem_instance_features)
940
+
941
+
942
+ mem_patch_feature_clip_coreset = patch_tokens_clip
943
+ mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
944
+
945
+ return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset
946
+
947
+ def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
948
+ few_shot_samples = self.transform(few_shot_samples).to(self.device)
949
+
950
+ scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)
951
+
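+ # Greedy k-center (coreset) subsampling keeps 25% of the few-shot patch features as the
+ # CLIP / DINOv2 memory banks used by the PatchCore-style structural score.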
952
+ clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
953
+ mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()
954
+
955
+ dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
956
+ mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
957
+
958
+ self.mem_patch_feature_clip_coreset.append(mem_patch_feature_clip_coreset)
959
+ self.mem_patch_feature_dinov2_coreset.append(mem_patch_feature_dinov2_coreset)
960
+
961
+
962
+ def setup(self, data: dict) -> None:
963
+ """Setup the few-shot samples for the model.
964
+
965
+ The evaluation script will call this method to pass the k images for few shot learning and the object class
966
+ name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
967
+ the organizing committee if your model requires any additional dataset-related information at setup-time.
968
+ """
969
+ few_shot_samples = data.get("few_shot_samples")
970
+ class_name = data.get("dataset_category")
971
+ few_shot_paths = data.get("few_shot_samples_path")
972
+ self.class_name = class_name
973
+
974
+ print(few_shot_samples.shape)
975
+
976
+ self.total_size = few_shot_samples.size(0)
977
+
978
+ self.k_shot = 4 if self.total_size > 4 else self.total_size
979
+
980
+ self.cluster_init = False
981
+ self.mem_instance_features_multi_stage = [[],[],[],[]]
982
+
983
+ self.mem_patch_feature_clip_coreset = []
984
+ self.mem_patch_feature_dinov2_coreset = []
985
+
986
+ # Check if coreset files already exist
987
+ clip_file = 'memory_bank/mem_patch_feature_clip_{}.pt'.format(self.class_name)
988
+ dinov2_file = 'memory_bank/mem_patch_feature_dinov2_{}.pt'.format(self.class_name)
989
+ instance_file = 'memory_bank/mem_instance_features_multi_stage_{}.pt'.format(self.class_name)
990
+
991
+ files_exist = os.path.exists(clip_file) and os.path.exists(dinov2_file) and os.path.exists(instance_file)
992
+
993
+ if self.save_coreset_features and not files_exist:
994
+ print(f"Coreset files not found for {self.class_name}, computing and saving...")
995
+ for i in range(self.total_size//self.k_shot):
996
+ self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)], few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
997
+
998
+ # Coreset Subsampling
999
+ self.mem_patch_feature_clip_coreset = torch.cat(self.mem_patch_feature_clip_coreset, dim=0)
1000
+ torch.save(self.mem_patch_feature_clip_coreset, clip_file)
1001
+
1002
+ self.mem_patch_feature_dinov2_coreset = torch.cat(self.mem_patch_feature_dinov2_coreset, dim=0)
1003
+ torch.save(self.mem_patch_feature_dinov2_coreset, dinov2_file)
1004
+
1005
+ print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
1006
+
1007
+ self.mem_instance_features_multi_stage = [ torch.cat(mem_instance_features, dim=0) for mem_instance_features in self.mem_instance_features_multi_stage ]
1008
+ self.mem_instance_features_multi_stage = torch.cat(self.mem_instance_features_multi_stage, dim=1)
1009
+ torch.save(self.mem_instance_features_multi_stage, instance_file)
1010
+
1011
+ print(self.mem_instance_features_multi_stage.shape)
1012
+
1013
+ elif self.save_coreset_features and files_exist:
1014
+ print(f"Coreset files found for {self.class_name}, loading existing files...")
1015
+ self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
1016
+
1017
+ self.mem_patch_feature_clip_coreset = torch.load(clip_file)
1018
+ self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
1019
+ self.mem_instance_features_multi_stage = torch.load(instance_file)
1020
+
1021
+ print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
1022
+ print(self.mem_instance_features_multi_stage.shape)
1023
+
1024
+ else:
1025
+ self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
1026
+
1027
+ self.mem_patch_feature_clip_coreset = torch.load(clip_file)
1028
+ self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
1029
+ self.mem_instance_features_multi_stage = torch.load(instance_file)
1030
+
1031
+
1032
+ self.few_shot_inited = True
1033
+
1034
+
model_ensemble_few_shot.py ADDED
@@ -0,0 +1,935 @@
1
+ import os
2
+
3
+ # Set cache directories to use checkpoint folder for model downloads
4
+ os.environ['TORCH_HOME'] = './checkpoint'
5
+ os.environ['HF_HOME'] = './checkpoint/huggingface'
6
+ os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
7
+ os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
8
+
9
+ # Create checkpoint subdirectories if they don't exist
10
+ os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
11
+ os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torchvision.transforms import v2
16
+ from torchvision.transforms.v2.functional import resize
17
+ import cv2
18
+ import json
19
+ import torch
20
+ import random
21
+ import logging
22
+ import argparse
23
+ import numpy as np
24
+ from PIL import Image
25
+ from skimage import measure
26
+ from tabulate import tabulate
27
+ from torchvision.ops.focal_loss import sigmoid_focal_loss
28
+ import torch.nn.functional as F
29
+ import torchvision.transforms as transforms
30
+ import torchvision.transforms.functional as TF
31
+ from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
32
+ from sklearn.mixture import GaussianMixture
33
+ import faiss
34
+ import open_clip_local as open_clip
35
+
36
+ from torch.utils.data.dataset import ConcatDataset
37
+ from scipy.optimize import linear_sum_assignment
38
+ from sklearn.random_projection import SparseRandomProjection
39
+ import cv2
40
+ from torchvision.transforms import InterpolationMode
41
+ from PIL import Image
42
+ import string
43
+
44
+ from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
45
+ from kmeans_pytorch import kmeans, kmeans_predict
46
+ from scipy.optimize import linear_sum_assignment
47
+ from scipy.stats import norm
48
+
49
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
50
+ from matplotlib import pyplot as plt
51
+
52
+ import matplotlib
53
+ matplotlib.use('Agg')
54
+
55
+ import pickle
56
+ from scipy.stats import norm
57
+
58
+ from open_clip_local.pos_embed import get_2d_sincos_pos_embed
59
+
60
+ def to_np_img(m):
61
+ m = m.permute(1, 2, 0).cpu().numpy()
62
+ mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
63
+ std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
64
+ m = m * std + mean
65
+ return np.clip((m * 255.), 0, 255).astype(np.uint8)
66
+
67
+
68
+ def setup_seed(seed):
69
+ torch.manual_seed(seed)
70
+ torch.cuda.manual_seed_all(seed)
71
+ np.random.seed(seed)
72
+ random.seed(seed)
73
+ torch.backends.cudnn.deterministic = True
74
+ torch.backends.cudnn.benchmark = False
75
+
76
+
77
+ class MyModel(nn.Module):
78
+ """Example model class for track 2.
79
+
80
+ This class applies few-shot anomaly detection with an ensemble of CLIP, DINOv2 and SAM features.
81
+ """
82
+
83
+ def __init__(self) -> None:
84
+ super().__init__()
85
+
86
+ setup_seed(42)
87
+ # NOTE: Create your transformation pipeline (if needed).
88
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
89
+ self.transform = v2.Compose(
90
+ [
91
+ v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
92
+ ],
93
+ )
94
+
95
+ # NOTE: Create your model.
96
+
97
+ self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
98
+ self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
99
+ self.feature_list = [6, 12, 18, 24]
100
+ self.embed_dim = 768
101
+ self.vision_width = 1024
102
+
103
+ self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
104
+ self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
105
+
106
+ self.memory_size = 2048
107
+ self.n_neighbors = 2
108
+
109
+ self.model_clip.eval()
110
+ self.test_args = None
111
+ self.align_corners = True # False
112
+ self.antialias = True # False
113
+ self.inter_mode = 'bilinear' # bilinear/bicubic
114
+
115
+ self.cluster_feature_id = [0, 1]
116
+
117
+ self.cluster_num_dict = {
118
+ "breakfast_box": 3, # unused
119
+ "juice_bottle": 8, # unused
120
+ "splicing_connectors": 10, # unused
121
+ "pushpins": 10,
122
+ "screw_bag": 10,
123
+ }
124
+ self.query_words_dict = {
125
+ "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
126
+ "juice_bottle": ['bottle', ['black background', 'background']],
127
+ "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
128
+ "screw_bag": [['screw'], 'plastic bag', 'background'],
129
+ "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
130
+ }
131
+ self.foreground_label_idx = { # for query_words_dict
132
+ "breakfast_box": [0, 1, 2, 3, 4, 5],
133
+ "juice_bottle": [0],
134
+ "pushpins": [0],
135
+ "screw_bag": [0],
136
+ "splicing_connectors":[0, 1]
137
+ }
138
+
139
+ self.patch_query_words_dict = {
140
+ "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
141
+ "juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
142
+ "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
143
+ "screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
144
+ "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
145
+ }
146
+
147
+
148
+ self.query_threshold_dict = {
149
+ "breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
150
+ "juice_bottle": [0., 0., 0.], # unused
151
+ "splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
152
+ "pushpins": [0.2, 0., 0., 0.],
153
+ "screw_bag": [0., 0., 0.,],
154
+ }
155
+
156
+ self.feat_size = 64
157
+ self.ori_feat_size = 32
158
+
159
+ self.visualization = False
160
+
161
+ self.pushpins_count = 15
162
+
163
+ self.splicing_connectors_count = [2, 3, 5] # corresponding to yellow, blue, and red
164
+ self.splicing_connectors_distance = 0
165
+ self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
166
+
167
+ self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
168
+ self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
169
+
170
+ # query words
171
+ self.foreground_pixel_hist = 0
172
+ # patch query words
173
+ self.patch_token_hist = []
174
+
175
+ self.few_shot_inited = False
176
+
177
+
178
+ from dinov2.dinov2.hub.backbones import dinov2_vitl14
179
+ self.model_dinov2 = dinov2_vitl14()
180
+ self.model_dinov2.to(self.device)
181
+ self.model_dinov2.eval()
182
+ self.feature_list_dinov2 = [6, 12, 18, 24]
183
+ self.vision_width_dinov2 = 1024
184
+
185
+ self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
186
+
187
+ self.mem_instance_masks = None
188
+
189
+ self.anomaly_flag = False
190
+ self.validation = False
191
+
192
+ def set_viz(self, viz):
193
+ self.visualization = viz
194
+
195
+ def set_val(self, val):
196
+ self.validation = val
197
+
198
+ def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
199
+ """Transform the input batch and pass it through the model.
200
+
201
+ This model returns a dictionary with the following keys
202
+ - ``anomaly_map`` - Anomaly map.
203
+ - ``pred_score`` - Predicted anomaly score.
204
+ """
205
+ self.anomaly_flag = False
206
+ batch = self.transform(batch).to(self.device)
207
+ results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])
208
+
209
+ hist_score = results['hist_score']
210
+ structural_score = results['structural_score']
211
+ instance_hungarian_match_score = results['instance_hungarian_match_score']
212
+
213
+ anomaly_map_structural = results['anomaly_map_structural']
214
+
215
+ if self.validation:
216
+ return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
217
+
218
+ def sigmoid(z):
219
+ return 1/(1 + np.exp(-z))
220
+
221
+ # standardization
222
+ standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
223
+ standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
224
+
225
+ pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
226
+ pred_score = sigmoid(pred_score)
227
+
228
+ if self.anomaly_flag:
229
+ pred_score = 1.
230
+ self.anomaly_flag = False
231
+
232
+ return {"pred_score": torch.tensor(pred_score), "anomaly_map": torch.tensor(anomaly_map_structural), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
233
+
234
+
235
+ def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
236
+
237
+ with torch.no_grad():
238
+ image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
239
+ # image_features /= image_features.norm(dim=-1, keepdim=True)
240
+ patch_tokens = [p[:, 1:, :] for p in patch_tokens]
241
+ patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
242
+
243
+ patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
244
+ # patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
245
+ patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
246
+ patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
247
+ patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
248
+ patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
249
+
250
+ with torch.no_grad():
251
+ patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
252
+ patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
253
+ patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
254
+ patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
255
+ patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
256
+ patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
257
+
258
+ '''adding for kmeans seg '''
259
+ if self.feat_size != self.ori_feat_size:
260
+ proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
261
+ proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
262
+ proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
263
+ proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
264
+
265
+ mid_features = None
266
+ for layer in self.cluster_feature_id:
267
+ temp_feat = patch_tokens[layer]
268
+ mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
269
+
270
+ if self.feat_size != self.ori_feat_size:
271
+ mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
272
+ mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
273
+ mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
274
+ mid_features = F.normalize(mid_features, p=2, dim=-1)
275
+
276
+ results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
277
+
278
+ hist_score = results['score']
279
+
280
+ '''calculate patchcore'''
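+ # PatchCore-style structural branch: per feature level, every patch embedding is compared
+ # (cosine similarity) against the few-shot memory bank; 1 - max similarity is the per-patch
+ # anomaly, and the level-averaged map yields both the anomaly map and its max as the score.
+ # CLIP features serve pushpins / screw_bag, DINOv2 features the remaining three categories.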
281
+ anomaly_maps_patchcore = []
282
+
283
+ if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
284
+ len_feature_list = len(self.feature_list)
285
+ for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
286
+ patch_feature = F.normalize(patch_feature, dim=-1)
287
+ mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
288
+ normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
289
+ normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
290
+ anomaly_map_patchcore = 1 - normal_map_patchcore
291
+
292
+ anomaly_maps_patchcore.append(anomaly_map_patchcore)
293
+
294
+ if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
295
+ len_feature_list = len(self.feature_list_dinov2)
296
+ for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
297
+ patch_feature = F.normalize(patch_feature, dim=-1)
298
+ mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
299
+ normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
300
+ normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
301
+ anomaly_map_patchcore = 1 - normal_map_patchcore
302
+
303
+ anomaly_maps_patchcore.append(anomaly_map_patchcore)
304
+
305
+ structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
306
+ anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
307
+
308
+ instance_masks = results["instance_masks"]
309
+ anomaly_instances_hungarian = []
310
+ instance_hungarian_match_score = 1.
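+ # Instance-matching branch: average-pool CLIP patch features inside every instance mask,
+ # build a cosine-distance cost matrix against the instances of the k reference images, and
+ # solve a Hungarian assignment; the mean matched cost over feature levels is the score.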
311
+ if self.mem_instance_masks is not None and len(instance_masks) != 0:
312
+ for patch_feature, batch_mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
313
+ instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
314
+ instance_features = torch.cat(instance_features, dim=0)
315
+ instance_features = F.normalize(instance_features, dim=-1)
316
+ mem_instance_features = []
317
+ for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks):
318
+ mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
319
+ mem_instance_features = torch.cat(mem_instance_features, dim=0)
320
+ mem_instance_features = F.normalize(mem_instance_features, dim=-1)
321
+
322
+ normal_instance_hungarian = (instance_features @ mem_instance_features.T)
323
+ cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
324
+
325
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
326
+ cost = cost_matrix[row_ind, col_ind].sum()
327
+ cost = cost / min(cost_matrix.shape)
328
+ anomaly_instances_hungarian.append(cost)
329
+
330
+ instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
331
+
332
+ results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score, "anomaly_map_structural": anomaly_map_structural}
333
+
334
+ return results
335
+
336
+
337
+ def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
338
+ def plot_results_only(sorted_anns):
339
+ cur = 1
340
+ img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
341
+ for ann in sorted_anns:
342
+ m = ann['segmentation']
343
+ img_color[m] = cur
344
+ cur += 1
345
+ return img_color
346
+
347
+ def merge_segmentations(a, b, background_class):
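+ # Relabel every SAM segment in `a` with the majority semantic label it overlaps in `b`;
+ # segments with no overlap fall back to the background class.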
348
+ unique_labels_a = np.unique(a)
349
+ unique_labels_b = np.unique(b)
350
+
351
+ max_label_a = unique_labels_a.max()
352
+ label_map = np.zeros(max_label_a + 1, dtype=int)
353
+
354
+ for label_a in unique_labels_a:
355
+ mask_a = (a == label_a)
356
+
357
+ labels_b = b[mask_a]
358
+ if labels_b.size > 0:
359
+ count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
360
+ label_map[label_a] = np.argmax(count_b)
361
+ else:
362
+ label_map[label_a] = background_class # default background
363
+
364
+ merged_a = label_map[a]
365
+ return merged_a
366
+
367
+ pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
368
+ kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
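+ # Each patch is assigned to its nearest few-shot k-means center; every sufficiently large
+ # connected region is then named by its most similar text query (query_obj), but only when
+ # the similarity clears the per-class threshold, otherwise it stays background.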
369
+
370
+ for pl in pseudo_labels.unique():
371
+ mask = (pseudo_labels == pl).reshape(-1)
372
+ # filter small region
373
+ binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
374
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
375
+ for i in range(1, num_labels):
376
+ temp_mask = labels == i
377
+ if np.sum(temp_mask) <= 8:
378
+ mask[temp_mask.reshape(-1)] = False
379
+
380
+ if mask.any():
381
+ region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
382
+ similarity = (region_feature @ self.query_obj.T)
383
+ prob, index = torch.max(similarity, dim=-1)
384
+ temp_label = index.squeeze(0).item()
385
+ temp_prob = prob.squeeze(0).item()
386
+ if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
387
+ kmeans_mask[mask] = temp_label
388
+
389
+
390
+ raw_image = to_np_img(image[0])
391
+ height, width = raw_image.shape[:2]
392
+ masks = self.mask_generator.generate(raw_image)
393
+ # self.predictor.set_image(raw_image)
394
+
395
+ kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
396
+ kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
397
+
398
+ patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
399
+ patch_mask = patch_similarity.argmax(-1)
400
+ patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
401
+
402
+ sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
403
+ sam_mask = plot_results_only(sorted_masks).astype(int) # np.int was removed in NumPy 1.24
404
+
405
+ resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
406
+ merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
407
+
408
+ resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
409
+ patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
410
+
411
+ # filter small region for merge sam
412
+ binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
413
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
414
+ for i in range(1, num_labels):
415
+ temp_mask = labels == i
416
+ if np.sum(temp_mask) <= 32: # 448x448
417
+ merge_sam[temp_mask] = self.classes - 1 # set to background
418
+
419
+ # filter small region for patch merge sam
420
+ binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
421
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
422
+ for i in range(1, num_labels):
423
+ temp_mask = labels == i
424
+ if np.sum(temp_mask) <= 32: # 448x448
425
+ patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
426
+
427
+ score = 0. # default to normal
428
+ self.anomaly_flag = False
429
+ instance_masks = []
430
+ if self.class_name == 'pushpins':
431
+ # object count hist
432
+ kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
433
+ binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
434
+ dilate_binary = cv2.dilate(binary, kernel)
435
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
436
+ pushpins_count = num_labels - 1 # number of pushpins
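+ # Every connected component of the dilated foreground is treated as one pushpin; any
+ # deviation from the canonical count of 15 is flagged as a logical anomaly.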
437
+
438
+ for i in range(1, num_labels):
439
+ instance_mask = (labels == i).astype(np.uint8)
440
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
441
+ if instance_mask.any():
442
+ instance_masks.append(instance_mask.astype(bool).reshape(-1)) # np.bool was removed in NumPy 1.24
443
+
444
+ if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
445
+ self.anomaly_flag = True
446
+ print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
447
+
448
+ # patch hist
449
+ clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
450
+ clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
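+ # The L2-normalized histogram of per-patch text labels is compared with the k few-shot
+ # histograms; the logical score is 1 minus the best cosine similarity.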
451
+
452
+ if self.few_shot_inited:
453
+ patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
454
+ score = 1 - patch_hist_similarity.max()
455
+
456
+ binary_foreground = dilate_binary.astype(np.uint8)
457
+
458
+ if len(instance_masks) != 0:
459
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
460
+
461
+ if self.visualization:
462
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
463
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
464
+ plt.figure(figsize=(20, 3))
465
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
466
+ plt.subplot(1, len(image_list), ind)
467
+ plt.imshow(temp_image)
468
+ plt.title(temp_title)
469
+ plt.margins(0, 0)
470
+ plt.axis('off')
471
+ # Extract relative path from class_name onwards
472
+ if class_name in path:
473
+ relative_path = path.split(class_name, 1)[-1]
474
+ if relative_path.startswith('/'):
475
+ relative_path = relative_path[1:]
476
+ save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
477
+ else:
478
+ save_path = f'visualization/few_shot/{class_name}/{path}.png'
479
+
480
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
481
+ plt.tight_layout()
482
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
483
+ plt.close()
484
+
485
+ # todo: same number in total but in different boxes or broken box
486
+ return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
487
+
488
+ elif self.class_name == 'splicing_connectors':
489
+ # object count hist for default
490
+ sam_mask_max_area = sorted_masks[0]['segmentation'] # background
491
+ binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
492
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
493
+ count = 0
494
+ for i in range(1, num_labels):
495
+ temp_mask = labels == i
496
+ if np.sum(temp_mask) <= 64: # 448x448 64
497
+ binary[temp_mask] = 0 # set to background
498
+ else:
499
+ count += 1
500
+ if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
501
+ print('number of connected components in splicing_connectors: {}, but the expected number of connected components is 1.'.format(count))
502
+ self.anomaly_flag = True
503
+
504
+ merge_sam[~(binary.astype(bool))] = self.query_obj.shape[0] - 1 # remove noise
505
+ patch_merge_sam[~(binary.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
506
+
507
+ # erode the cable and divide into left and right parts
508
+ kernel = np.ones((23, 23), dtype=np.uint8)
509
+ erode_binary = cv2.erode(binary, kernel)
510
+ h, w = erode_binary.shape
511
+ distance = 0
512
+
513
+ left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
514
+ left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
515
+ right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
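+ # A heavy erosion removes the thin cable so the two clamps separate; the foreground pixel
+ # counts of the left and right image halves are compared later to check symmetry.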
516
+
517
+ binary_cable = (patch_merge_sam == 1).astype(np.uint8)
518
+
519
+ kernel = np.ones((5, 5), dtype=np.uint8)
520
+ binary_cable = cv2.erode(binary_cable, kernel)
521
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
522
+ for i in range(1, num_labels):
523
+ temp_mask = labels == i
524
+ if np.sum(temp_mask) <= 64: # 448x448
525
+ binary_cable[temp_mask] = 0 # set to background
526
+
527
+
528
+ binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
529
+
530
+ binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
531
+
532
+ kernel = np.ones((5, 5), dtype=np.uint8)
533
+ binary_clamps = cv2.erode(binary_clamps, kernel)
534
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
535
+ for i in range(1, num_labels):
536
+ temp_mask = labels == i
537
+ if np.sum(temp_mask) <= 64: # 448x448
538
+ binary_clamps[temp_mask] = 0 # set to background
539
+ else:
540
+ instance_mask = temp_mask.astype(np.uint8)
541
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
542
+ if instance_mask.any():
543
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
544
+
545
+ binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
546
+
547
+ binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
548
+
549
+ query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
550
+ cable_feature = proj_patch_token[binary_cable.astype(bool).reshape(-1), :].mean(0, keepdim=True)
551
+ idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
552
+ foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
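+ # The clamp area is normalized by the clamp count implied by the predicted cable color
+ # (yellow -> 2, blue -> 3, red -> 5); a ratio far from the few-shot reference therefore
+ # signals a color / clamp-count mismatch.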
553
+
554
+
555
+ slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
556
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
557
+ cable_count = num_labels - 1
558
+ if cable_count != 1 and self.anomaly_flag is False: # two cables
559
+ print('cable count in splicing_connectors: {}, but the expected cable count is 1.'.format(cable_count))
560
+ self.anomaly_flag = True
561
+
562
+ # {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
563
+ if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
564
+ ratio = foreground_pixel_count / self.foreground_pixel_hist
565
+ if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
566
+ print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist))
567
+ self.anomaly_flag = True
568
+
569
+ # left right hist for symmetry
570
+ ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
571
+ if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
572
+ print('left and right connectors are not symmetric.')
573
+ self.anomaly_flag = True
574
+
575
+ # left and right centroids distance
576
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
577
+ if num_labels - 1 == 2:
578
+ centroids = centroids[1:]
579
+ x1, y1 = centroids[0]
580
+ x2, y2 = centroids[1]
581
+ distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
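+ # The normalized distance between the two clamp centroids is a proxy for cable length and
+ # is compared against the few-shot reference with a 0.6-1.4 tolerance below.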
582
+ if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
583
+ ratio = distance / self.splicing_connectors_distance
584
+ if ratio < 0.6 or ratio > 1.4: # centroid distance (cable) too short or too long
585
+ print('cable is too short or too long.')
586
+ self.anomaly_flag = True
587
+
588
+ # patch hist
589
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
590
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
591
+
592
+ if self.few_shot_inited:
593
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
594
+ score = 1 - patch_hist_similarity.max()
595
+
596
+ # todo mismatch cable link
597
+ binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally separate cable and clamps
598
+ if binary_connector.any():
599
+ instance_masks.append(binary_connector.astype(bool).reshape(-1))
600
+ if binary_clamps.any():
601
+ instance_masks.append(binary_clamps.astype(bool).reshape(-1))
602
+ if binary_cable.any():
603
+ instance_masks.append(binary_cable.astype(bool).reshape(-1))
604
+
605
+ if len(instance_masks) != 0:
606
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
607
+
608
+ if self.visualization:
609
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
610
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
611
+ plt.figure(figsize=(25, 3))
612
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
613
+ plt.subplot(1, len(image_list), ind)
614
+ plt.imshow(temp_image)
615
+ plt.title(temp_title)
616
+ plt.margins(0, 0)
617
+ plt.axis('off')
618
+ # Extract relative path from class_name onwards
619
+ if class_name in path:
620
+ relative_path = path.split(class_name, 1)[-1]
621
+ if relative_path.startswith('/'):
622
+ relative_path = relative_path[1:]
623
+ save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
624
+ else:
625
+ save_path = f'visualization/few_shot/{class_name}/{path}.png'
626
+
627
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
628
+ plt.tight_layout()
629
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
630
+ plt.close()
631
+
632
+ return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
633
+
634
+ elif self.class_name == 'screw_bag':
635
+ # pixel hist of kmeans mask
636
+ foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
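+ # The total number of screw/nut/washer pixels is a coarse part-count signal; a deviation of
+ # more than ~6% from the few-shot reference flags a missing or extra part.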
637
+ if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
638
+ ratio = foreground_pixel_count / self.foreground_pixel_hist
639
+ # todo: optimize
640
+ if ratio < 0.94 or ratio > 1.06:
641
+ print('foreground pixel histogram of screw bag: {}, but the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist))
642
+ self.anomaly_flag = True
643
+
644
+ # patch hist
645
+ binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
646
+ patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
647
+ resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
648
+ patch_merge_sam[~(resized_binary_screw.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
649
+
650
+ clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
651
+ clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
652
+
653
+ if self.few_shot_inited:
654
+ patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
655
+ score = 1 - patch_hist_similarity.max()
656
+
657
+ # # todo: count of screw, nut and washer, screw of different length
658
+ binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
659
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
660
+ for i in range(1, num_labels):
661
+ instance_mask = (labels == i).astype(np.uint8)
662
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
663
+ if instance_mask.any():
664
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
665
+
666
+ if len(instance_masks) != 0:
667
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
668
+
669
+ if self.visualization:
670
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
671
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
672
+ plt.figure(figsize=(20, 3))
673
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
674
+ plt.subplot(1, len(image_list), ind)
675
+ plt.imshow(temp_image)
676
+ plt.title(temp_title)
677
+ plt.margins(0, 0)
678
+ plt.axis('off')
679
+ # Extract relative path from class_name onwards
680
+ if class_name in path:
681
+ relative_path = path.split(class_name, 1)[-1]
682
+ if relative_path.startswith('/'):
683
+ relative_path = relative_path[1:]
684
+ save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
685
+ else:
686
+ save_path = f'visualization/few_shot/{class_name}/{path}.png'
687
+
688
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
689
+ plt.tight_layout()
690
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
691
+ plt.close()
692
+
693
+ return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
694
+
695
+ elif self.class_name == 'breakfast_box':
696
+ # patch hist
697
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
698
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
699
+
700
+ if self.few_shot_inited:
701
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
702
+ score = 1 - patch_hist_similarity.max()
703
+
704
+ # todo: exist of foreground
705
+
706
+ binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
707
+
708
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
709
+ for i in range(1, num_labels):
710
+ instance_mask = (labels == i).astype(np.uint8)
711
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
712
+ if instance_mask.any():
713
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
714
+
715
+ if len(instance_masks) != 0:
716
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
717
+
718
+ if self.visualization:
719
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
720
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
721
+ plt.figure(figsize=(20, 3))
722
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
723
+ plt.subplot(1, len(image_list), ind)
724
+ plt.imshow(temp_image)
725
+ plt.title(temp_title)
726
+ plt.margins(0, 0)
727
+ plt.axis('off')
728
+ # Extract relative path from class_name onwards
729
+ if class_name in path:
730
+ relative_path = path.split(class_name, 1)[-1]
731
+ if relative_path.startswith('/'):
732
+ relative_path = relative_path[1:]
733
+ save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
734
+ else:
735
+ save_path = f'visualization/few_shot/{class_name}/{path}.png'
736
+
737
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
738
+ plt.tight_layout()
739
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
740
+ plt.close()
741
+
742
+ return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
743
+
744
+ elif self.class_name == 'juice_bottle':
745
+ # remove noise due to non sam mask
746
+ merge_sam[sam_mask == 0] = self.classes - 1
747
+ patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1 # 79.5
748
+
749
+ # [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
750
+ # fruit and liquid mismatch (todo if exist)
751
+ resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
752
+ binary_liquid = (resized_patch_merge_sam == 1)
753
+ binary_fruit = (resized_patch_merge_sam == 2)
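+ # Logical consistency check: the liquid color and the fruit on the label are classified with
+ # CLIP text queries and must agree by index (cherry juice <-> cherry, orange juice <-> orange,
+ # milky liquid <-> banana); a mismatch is a logical anomaly.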
754
+
755
+ query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
756
+ query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)
757
+
758
+ liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
759
+ liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
760
+
761
+ fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
762
+ fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
763
+
764
+ if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
765
+ print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
766
+ self.anomaly_flag = True
767
+
768
+ # # todo centroid of fruit and tag_0 mismatch (if exist) , only one tag, center
769
+
770
+ # patch hist
771
+ sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
772
+ sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
773
+
774
+ if self.few_shot_inited:
775
+ patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
776
+ score = 1 - patch_hist_similarity.max()
777
+
778
+ binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1) ).astype(np.uint8)
779
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
780
+ for i in range(1, num_labels):
781
+ instance_mask = (labels == i).astype(np.uint8)
782
+ instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
783
+ if instance_mask.any():
784
+ instance_masks.append(instance_mask.astype(bool).reshape(-1))
785
+
786
+ if len(instance_masks) != 0:
787
+ instance_masks = np.stack(instance_masks) #[N, 64x64]
788
+
789
+ if self.visualization:
790
+ image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
791
+ title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
792
+ plt.figure(figsize=(20, 3))
793
+ for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
794
+ plt.subplot(1, len(image_list), ind)
795
+ plt.imshow(temp_image)
796
+ plt.title(temp_title)
797
+ plt.margins(0, 0)
798
+ plt.axis('off')
799
+ # Extract relative path from class_name onwards
800
+ if class_name in path:
801
+ relative_path = path.split(class_name, 1)[-1]
802
+ if relative_path.startswith('/'):
803
+ relative_path = relative_path[1:]
804
+ save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
805
+ else:
806
+ save_path = f'visualization/few_shot/{class_name}/{path}.png'
807
+
808
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
809
+ plt.tight_layout()
810
+ plt.savefig(save_path, bbox_inches='tight', dpi=150)
811
+ plt.close()
812
+
813
+ return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
814
+
815
+ return {"score": score, "instance_masks": instance_masks}
816
+
817
+
818
+ def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
819
+ few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)
820
+
821
+ with torch.no_grad():
822
+ image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
823
+ patch_tokens = [p[:, 1:, :] for p in patch_tokens]
824
+ patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
825
+
826
+ patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (bs, 1024, 1024x4)
827
+ # patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (bs, 1024, 1024x2)
828
+ patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
829
+ patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
830
+ patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
831
+ patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (bsx64x64, 1024x4)
832
+
833
+ with torch.no_grad():
834
+ patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2) # 4 x [bs, 32x32, 1024]
835
+ patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (bs, 1024, 1024x4)
836
+ patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
837
+ patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
838
+ patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
839
+ patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (bsx64x64, 1024x4)
840
+
841
+ cluster_features = None
842
+ for layer in self.cluster_feature_id:
843
+ temp_feat = patch_tokens[layer]
844
+ cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
845
+ if self.feat_size != self.ori_feat_size:
846
+ cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
847
+ cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
848
+ cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
849
+ cluster_features = F.normalize(cluster_features, p=2, dim=-1)
850
+
851
+ if self.feat_size != self.ori_feat_size:
852
+ proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
853
+ proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
854
+ proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
855
+ proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
856
+
857
+ num_clusters = self.cluster_num_dict[class_name]
858
+ _, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)
859
+
860
+ self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
861
+ self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
862
+ self.classes = self.query_obj.shape[0]
863
+
864
+ scores = []
865
+ foreground_pixel_hist = []
866
+ splicing_connectors_distance = []
867
+ patch_token_hist = []
868
+ mem_instance_masks = []
869
+
870
+ for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
871
+ # path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
872
+ self.anomaly_flag = False
873
+ results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
874
+ if self.class_name == 'pushpins':
875
+ patch_token_hist.append(results["clip_patch_hist"])
876
+ mem_instance_masks.append(results['instance_masks'])
877
+
878
+ elif self.class_name == 'splicing_connectors':
879
+ foreground_pixel_hist.append(results["foreground_pixel_count"])
880
+ splicing_connectors_distance.append(results["distance"])
881
+ patch_token_hist.append(results["sam_patch_hist"])
882
+ mem_instance_masks.append(results['instance_masks'])
883
+
884
+ elif self.class_name == 'screw_bag':
885
+ foreground_pixel_hist.append(results["foreground_pixel_count"])
886
+ patch_token_hist.append(results["clip_patch_hist"])
887
+ mem_instance_masks.append(results['instance_masks'])
888
+
889
+ elif self.class_name == 'breakfast_box':
890
+ patch_token_hist.append(results["sam_patch_hist"])
891
+ mem_instance_masks.append(results['instance_masks'])
892
+
893
+ elif self.class_name == 'juice_bottle':
894
+ patch_token_hist.append(results["sam_patch_hist"])
895
+ mem_instance_masks.append(results['instance_masks'])
896
+
897
+ scores.append(results["score"])
898
+
899
+ if len(foreground_pixel_hist) != 0:
900
+ self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
901
+ if len(splicing_connectors_distance) != 0:
902
+ self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
903
+ if len(patch_token_hist) != 0: # patch hist
904
+ self.patch_token_hist = np.stack(patch_token_hist)
905
+ if len(mem_instance_masks) != 0:
906
+ self.mem_instance_masks = mem_instance_masks
907
+
908
+ mem_patch_feature_clip_coreset = patch_tokens_clip
909
+ mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
910
+
911
+ return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset
912
+
913
+
914
+
915
+ def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
916
+ few_shot_samples = self.transform(few_shot_samples).to(self.device)
917
+ scores, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)
918
+
919
+ def setup(self, data: dict) -> None:
920
+ """Setup the few-shot samples for the model.
921
+
922
+ The evaluation script will call this method to pass the k images for few shot learning and the object class
923
+ name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
924
+ the organizing committee if your model requires any additional dataset-related information at setup-time.
925
+ """
926
+ few_shot_samples = data.get("few_shot_samples")
927
+ class_name = data.get("dataset_category")
928
+ few_shot_paths = data.get("few_shot_samples_path")
929
+ self.class_name = class_name
930
+
931
+ self.k_shot = few_shot_samples.size(0)
932
+ self.process(class_name, few_shot_samples, few_shot_paths)
933
+ self.few_shot_inited = True
934
+
935
+
prompt_ensemble.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ from typing import Union, List
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from imagenet_template import openai_imagenet_template
7
+
8
+
9
+ def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):
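+ # Prompt ensembling: every object name is expanded into normal/abnormal state phrases and a
+ # set of caption templates; the CLIP text embeddings are averaged and re-normalized so each
+ # object ends up with one "normal" and one "abnormal" text feature.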
10
+ prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
11
+ prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
12
+ prompt_state = [prompt_normal, prompt_abnormal]
13
+ prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
14
+ text_prompts = {}
15
+ for obj in objs:
16
+ text_features = []
17
+ for i in range(len(prompt_state)):
18
+ prompted_state = [state.format(obj) for state in prompt_state[i]]
19
+ prompted_sentence = []
20
+ for s in prompted_state:
21
+ for template in prompt_templates:
22
+ prompted_sentence.append(template.format(s))
23
+ prompted_sentence = tokenizer(prompted_sentence).to(device)
24
+ class_embeddings = model.encode_text(prompted_sentence)
25
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
26
+ class_embedding = class_embeddings.mean(dim=0)
27
+ class_embedding /= class_embedding.norm()
28
+ text_features.append(class_embedding)
29
+
30
+ text_features = torch.stack(text_features, dim=1).to(device)
31
+ text_prompts[obj] = text_features
32
+
33
+ return text_prompts
34
+
35
+
36
+ def encode_general_text(model, obj_list, tokenizer, device):
37
+ text_dir = '/data/yizhou/VAND2.0/wgd/general_texts/train2014'
38
+ text_name_list = sorted(os.listdir(text_dir))
39
+ bs = 100
40
+ sentences = []
41
+ embeddings = []
42
+ all_sentences = []
43
+ for text_name in tqdm(text_name_list):
44
+ with open(os.path.join(text_dir, text_name), 'r') as f:
45
+ for line in f.readlines():
46
+ sentences.append(line.strip())
47
+ if len(sentences) > bs:
48
+ prompted_sentences = tokenizer(sentences).to(device)
49
+ class_embeddings = model.encode_text(prompted_sentences)
50
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
51
+ embeddings.append(class_embeddings)
52
+ all_sentences.extend(sentences)
53
+ sentences = []
54
+ # if len(all_sentences) > 10000:
55
+ # break
56
+ embeddings = torch.cat(embeddings, 0)
57
+ print(embeddings.size(0))
58
+ embeddings_dict = {}
59
+ for obj in obj_list:
60
+ embeddings_dict[obj] = embeddings
61
+ return embeddings_dict, all_sentences
62
+
63
+
64
+ def encode_abnormal_text(model, obj_list, tokenizer, device):
65
+ embeddings = {}
66
+ sentences = {}
67
+ for obj in obj_list:
68
+ sentence_abnormal = []
69
+ with open(os.path.join('text_prompt', 'v1', obj + '_abnormal.txt'), 'r') as f:
70
+ for line in f.readlines():
71
+ sentence_abnormal.append(line.strip().lower())
72
+
73
+ prompted_sentences = tokenizer(sentence_abnormal).to(device)
74
+ class_embeddings = model.encode_text(prompted_sentences)
75
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
76
+ embeddings[obj] = class_embeddings
77
+ sentences[obj] = sentence_abnormal
78
+ return embeddings, sentences
79
+
80
+
81
+ def encode_normal_text(model, obj_list, tokenizer, device):
82
+ embeddings = {}
83
+ sentences = {}
84
+ for obj in obj_list:
85
+ sentence_abnormal = []
86
+ with open(os.path.join('text_prompt', 'v1', obj + '_normal.txt'), 'r') as f:
87
+ for line in f.readlines():
88
+ sentence_abnormal.append(line.strip().lower())
89
+
90
+ prompted_sentences = tokenizer(sentence_abnormal).to(device)
91
+ class_embeddings = model.encode_text(prompted_sentences)
92
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
93
+ embeddings[obj] = class_embeddings
94
+ sentences[obj] = sentence_abnormal
95
+ return embeddings, sentences
96
+
97
+
98
+ def encode_obj_text(model, query_words, tokenizer, device):
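+ # Encodes each query entry (a single word or a list of synonyms) with the OpenAI ImageNet
+ # prompt templates and averages the CLIP text embeddings into one L2-normalized vector per class.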
99
+ # query_words = ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box']
100
+ # query_words = ['liquid', 'glass', "top", 'black background']
101
+ # query_words = ["connector", "grid"]
102
+ # query_words = [['screw'], 'plastic bag', 'background']
103
+ # query_words = [['pushpin', 'pin'], ['plastic box'], 'box', 'black background']
104
+ query_features = []
105
+ with torch.no_grad():
106
+ for qw in query_words:
107
+ token_input = []
108
+ if type(qw) == list:
109
+ for qw2 in qw:
110
+ token_input.extend([temp(qw2) for temp in openai_imagenet_template])
111
+ else:
112
+ token_input = [temp(qw) for temp in openai_imagenet_template]
113
+ query = tokenizer(token_input).to(device)
114
+ feature = model.encode_text(query)
115
+ feature /= feature.norm(dim=-1, keepdim=True)
116
+ feature = feature.mean(dim=0)
117
+ feature /= feature.norm()
118
+ query_features.append(feature.unsqueeze(0))
119
+ query_features = torch.cat(query_features, dim=0)
120
+ return query_features
121
+
requirements.txt ADDED
@@ -0,0 +1,77 @@
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.12.11
3
+ aiosignal==1.3.2
4
+ antlr4-python3-runtime==4.9.3
5
+ async-timeout==5.0.1
6
+ attrs==25.3.0
7
+ certifi==2025.4.26
8
+ charset-normalizer==3.4.2
9
+ contourpy==1.3.2
10
+ cycler==0.12.1
11
+ einops==0.6.1
12
+ faiss-cpu==1.8.0
13
+ filelock==3.18.0
14
+ fonttools==4.58.2
15
+ FrEIA==0.2
16
+ frozenlist==1.6.2
17
+ fsspec==2024.12.0
18
+ ftfy==6.3.1
19
+ hf-xet==1.1.3
20
+ huggingface-hub==0.32.4
21
+ idna==3.10
22
+ imageio==2.37.0
23
+ imgaug==0.4.0
24
+ Jinja2==3.1.6
25
+ joblib==1.5.1
26
+ jsonargparse==4.29.0
27
+ kiwisolver==1.4.8
28
+ kmeans-pytorch==0.3
29
+ kornia==0.7.0
30
+ lazy_loader==0.4
31
+ lightning==2.2.5
32
+ lightning-utilities==0.14.3
33
+ markdown-it-py==3.0.0
34
+ MarkupSafe==3.0.2
35
+ matplotlib==3.10.3
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.4.4
39
+ networkx==3.4.2
40
+ omegaconf==2.3.0
41
+ open-clip-torch==2.24.0
42
+ opencv-python==4.8.1.78
43
+ packaging==24.2
44
+ pandas==2.0.3
45
+ pillow==11.2.1
46
+ propcache==0.3.1
47
+ protobuf==6.31.1
48
+ Pygments==2.19.1
49
+ pyparsing==3.2.3
50
+ python-dateutil==2.9.0.post0
51
+ pytorch-lightning==2.5.1.post0
52
+ pytz==2025.2
53
+ PyYAML==6.0.2
54
+ regex==2024.11.6
55
+ requests==2.32.3
56
+ rich==13.7.1
57
+ safetensors==0.5.3
58
+ scikit-image==0.25.2
59
+ scikit-learn==1.2.2
60
+ scipy==1.15.3
61
+ segment-anything==1.0
62
+ sentencepiece==0.2.0
63
+ shapely==2.1.1
64
+ six==1.17.0
65
+ sympy==1.14.0
66
+ tabulate==0.9.0
67
+ threadpoolctl==3.6.0
68
+ tifffile==2025.5.10
69
+ timm==1.0.15
70
+ torchmetrics==1.7.2
71
+ tqdm==4.67.1
72
+ triton==2.1.0
73
+ typing_extensions==4.14.0
74
+ tzdata==2025.2
75
+ urllib3==2.4.0
76
+ wcwidth==0.2.13
77
+ yarl==1.20.0