zhiqing0205 committed
Commit: 74acc06
Parent(s): 4a80644
Add basic Python scripts and documentation
- LogSAD技术详解.md +621 -0
- README.md +102 -0
- compute_coreset.py +121 -0
- environment.yml +116 -0
- evaluation.py +257 -0
- imagenet_template.py +82 -0
- model_ensemble.py +1034 -0
- model_ensemble_few_shot.py +935 -0
- prompt_ensemble.py +121 -0
- requirements.txt +77 -0
LogSAD技术详解.md
ADDED
@@ -0,0 +1,621 @@
# LogSAD: A Technical Deep Dive into Training-Free Anomaly Detection with Vision and Language Foundation Models

## Project Overview

LogSAD (Towards Training-free Anomaly Detection with Vision and Language Foundation Models) is a training-free anomaly detection method published at CVPR 2025. By combining several pretrained vision and language foundation models, it detects both logical and structural anomalies on the MVTec LOCO dataset.

## Overall Architecture and Pipeline

### Core Idea
LogSAD leverages the strong representations of pretrained models and detects anomalies through multimodal feature fusion and logical reasoning, without any training on the target dataset.

### System Architecture
```
Input image (448x448)
        ↓
┌─────────────────────────────────────────────────┐
│           Multimodal feature extraction          │
│  ├─ CLIP ViT-L-14 (image + text features)        │
│  ├─ DINOv2 ViT-L-14 (image features)             │
│  └─ SAM ViT-H (instance segmentation)            │
└─────────────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────────────┐
│           Feature processing and fusion          │
│  ├─ K-means clustering segmentation              │
│  ├─ Text-guided semantic segmentation            │
│  └─ Multi-scale feature fusion                   │
└─────────────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────────────┐
│                Anomaly detection                 │
│  ├─ Structural anomaly detection (PatchCore)     │
│  ├─ Logical anomaly detection (histogram match)  │
│  └─ Instance matching (Hungarian algorithm)      │
└─────────────────────────────────────────────────┘
        ↓
  Final anomaly score
```

## Pretrained Models in Detail

### 1. CLIP ViT-L-14
**Role**: the core of vision-language understanding
- **Model**: `hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K`
- **Input size**: 448×448
- **Feature layers**: [6, 12, 18, 24]
- **Feature dimension**: 1024
- **Output feature map**: 32×32 → 64×64 (interpolated)

**Implementation**:
```python
# model_ensemble.py:96-97
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
self.feature_list = [6, 12, 18, 24]
```

**How it cooperates**:
- Provides semantic feature representations of the image
- Encodes the semantics of different objects through text prompts
- Used for semantic segmentation and anomaly classification

### 2. DINOv2 ViT-L-14
**Role**: provides richer visual features
- **Model**: `dinov2_vitl14`
- **Feature layers**: [6, 12, 18, 24]
- **Feature dimension**: 1024
- **Output feature map**: 32×32 → 64×64 (interpolated)

**Implementation**:
```python
# model_ensemble.py:181-186
from dinov2.dinov2.hub.backbones import dinov2_vitl14
self.model_dinov2 = dinov2_vitl14()
self.feature_list_dinov2 = [6, 12, 18, 24]
```

**How it cooperates**:
- Supplies stronger visual features for some categories (splicing_connectors, breakfast_box, juice_bottle)
- Complements the CLIP features and improves detection accuracy

### 3. SAM (Segment Anything Model)
**Role**: instance segmentation
- **Model**: ViT-H variant
- **Checkpoint**: `./checkpoint/sam_vit_h_4b8939.pth`
- **Function**: automatically generates object masks

**Implementation**:
```python
# model_ensemble.py:102-103
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth")
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
```

**How it cooperates**:
- Provides precise object boundaries
- Used for instance-level anomaly detection
- Fused with the semantic segmentation results

## Data Processing and Resizing in Detail

### Image Preprocessing Pipeline

1. **Input size standardization**:
```python
# evaluation.py:184
datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
```

2. **Normalization**:
```python
# model_ensemble.py:88-92
self.transform = v2.Compose([
    v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                 std=(0.26862954, 0.26130258, 0.27577711)),
])
```

3. **Feature map sizes**:
```python
# model_ensemble.py:155-156
self.feat_size = 64      # target feature map size
self.ori_feat_size = 32  # original feature map size
```

### Detailed Resize Flow

**CLIP feature processing**:
```python
# model_ensemble.py:245-255
# 1. Interpolate from 32x32 to 64x64
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size),
                                  mode=self.inter_mode, align_corners=self.align_corners)
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
```

**DINOv2 feature processing**:
```python
# model_ensemble.py:253-263
# same interpolation flow
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size),
                                    mode=self.inter_mode, align_corners=self.align_corners)
```

**Interpolation parameters** (consolidated in the runnable sketch at the end of this section):
- **Mode**: bilinear (`bilinear`)
- **Align corners**: `align_corners=True`
- **Anti-aliasing**: `antialias=True`

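The snippets above are fragments of the model code; as a compact illustration, the following is a minimal, self-contained sketch of the same resize flow. The tensor shapes and parameter values follow the numbers listed above, but the variable names are hypothetical and not taken from the repository.

```python
import torch
import torch.nn.functional as F

ori_feat_size, feat_size, vision_width, num_layers = 32, 64, 1024, 4

# patch tokens from 4 transformer layers, concatenated along the channel dim:
# shape (1, 32*32, 1024*4)
patch_tokens = torch.randn(1, ori_feat_size * ori_feat_size, vision_width * num_layers)

# reshape to (1, C, 32, 32) and upsample bilinearly to (1, C, 64, 64)
x = patch_tokens.view(1, ori_feat_size, ori_feat_size, -1).permute(0, 3, 1, 2)
x = F.interpolate(x, size=(feat_size, feat_size), mode="bilinear", align_corners=True)

# flatten back to one feature vector per patch: (64*64, 1024*4)
patch_features = x.permute(0, 2, 3, 1).reshape(-1, vision_width * num_layers)
print(patch_features.shape)  # torch.Size([4096, 4096])
```
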
## SAM Multi-Mask Handling

### Processing the Multiple Masks Generated by SAM

**Mask generation**:
```python
# model_ensemble.py:394
masks = self.mask_generator.generate(raw_image)
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
```

**Mask fusion strategy**:
```python
# model_ensemble.py:347-367
def merge_segmentations(a, b, background_class):
    """Fuse the SAM masks with the semantic segmentation result."""
    # each SAM region receives the semantic label chosen by majority vote
    for label_a in unique_labels_a:
        mask_a = (a == label_a)
        labels_b = b[mask_a]
        if labels_b.size > 0:
            count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
            label_map[label_a] = np.argmax(count_b)  # majority vote
```

**Multi-mask cooperation flow**:
1. SAM generates all candidate instance masks
2. K-means clustering produces a semantic segmentation mask
3. Text guidance produces a patch-level semantic mask
4. Masks from the different sources are fused by majority voting
5. Small noisy regions are filtered out (threshold: 32 pixels)

## Ground Truth Multi-Mask Handling

### Mask Layout in the MVTec LOCO Dataset

**File structure**:
```
dataset/
├── test/category/image_filename.png          # test image
├── ground_truth/category/image_filename/     # directory of GT masks for that image
│   ├── 000.png                               # mask of the first anomalous region
│   ├── 001.png                               # mask of the second anomalous region
│   ├── 002.png                               # mask of the third anomalous region
│   └── ...                                   # further anomalous regions
```

**Aggregating multiple masks at load time**:
```python
# anomalib/data/image/mvtec_loco.py:142-148
mask_samples = (
    mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"]
    .agg(list)  # aggregate the mask paths of one image into a list
    .reset_index()
    .rename(columns={"image_path": "mask_path"})
)
```

### Multi-Mask Fusion Strategy

**Step 1: mask path handling**:
```python
# anomalib/data/image/mvtec_loco.py:279-280
if isinstance(mask_path, str):
    mask_path = [mask_path]  # make sure mask_path is a list
```

**Step 2: stacking the semantic masks**:
```python
# anomalib/data/image/mvtec_loco.py:281-285
semantic_mask = (
    Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)  # normal image: zero mask
    if label_index == LabelName.NORMAL
    else Mask(torch.stack([self._read_mask(path) for path in mask_path]))  # anomalous image: stack all masks
)
```

**Step 3: generating the binary mask**:
```python
# anomalib/data/image/mvtec_loco.py:287
binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8))
```

### How the Fusion Works

**Shape changes**:
- Input: several masks, each of shape (H, W)
- After stacking: (N, H, W), where N is the number of masks
- `view(-1, H, W)`: reshapes to (N, H, W)
- `any(dim=0)`: logical OR along the first dimension, yielding (H, W)

**Fusion logic** (a small runnable example of the pixel-wise OR):
```python
import torch

mask1 = torch.tensor([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]])
mask2 = torch.tensor([[0, 0, 1],
                      [0, 1, 0],
                      [1, 0, 0]])

# stack: shape (2, 3, 3)
stacked = torch.stack([mask1, mask2])

# any(): pixel-wise OR over the mask dimension
result = stacked.any(dim=0).int()
# result == [[0, 1, 1],
#            [1, 1, 1],
#            [1, 1, 0]]
```

### Complete Data Loading Flow

**Structure of an MVTec LOCO data item**:
```python
# normal sample
item = {
    "image_path": "/path/to/normal_image.png",
    "label": 0,
    "image": torch.Tensor(...),
    "mask": torch.zeros(H, W),            # zero mask
    "mask_path": [],                      # empty list
    "semantic_mask": torch.zeros(H, W)    # zero mask
}

# anomalous sample
item = {
    "image_path": "/path/to/abnormal_image.png",
    "label": 1,
    "image": torch.Tensor(...),
    "mask": torch.Tensor(...),            # fused binary mask
    "mask_path": [                        # list of mask paths
        "/path/to/ground_truth/image/000.png",
        "/path/to/ground_truth/image/001.png",
        "/path/to/ground_truth/image/002.png"
    ],
    "semantic_mask": torch.Tensor(...)    # stacked original masks, shape (N, H, W)
}
```

### How the Masks Are Used at Evaluation Time

**Important**: LogSAD does **not** use the ground-truth masks during inference; detection relies on the input image alone. The ground-truth masks are used only for:

1. **Performance evaluation**: computing AUROC, F1, and related metrics
2. **Visual comparison**: contrasting predictions with the ground truth
3. **Metric computation**: pixel-level and semantic-level detection performance

**Consistency check**:
```python
# anomalib/data/image/mvtec_loco.py:158-174
# verify that mask files correspond to image files
image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem)
mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply(
    lambda x: {Path(mask_path).parent.stem for mask_path in x},
)
# ensures image '005.png' corresponds to masks '005/000.png', '005/001.png', ...
```

### Multi-Mask Scenarios in Practice

**Typical cases**:
1. **Splicing Connectors**: the connector, the cable, and the clamps may be annotated separately
2. **Juice Bottle**: liquid, label, and bottle defects may be annotated separately
3. **Breakfast Box**: missing items of different foods may be annotated separately
4. **Screw Bag**: anomalies of individual screws, nuts, and washers are annotated separately

**Benefits of this handling**:
- Detailed per-region anomaly information is preserved
- Joint evaluation of multiple anomaly types is supported
- Fine-grained performance analysis becomes easier
- It stays compatible with conventional binary anomaly-detection evaluation

## Category-Specific Logic in Detail

The code contains **five main special-case branches**, one per dataset category:

### 1. Pushpins

**Location**: `model_ensemble.py:432-479`

**Logic**:
```python
if self.class_name == 'pushpins':
    # 1. object counting
    pushpins_count = num_labels - 1
    if self.few_shot_inited and pushpins_count != self.pushpins_count:
        self.anomaly_flag = True

    # 2. patch histogram matching
    clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
    patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
    score = 1 - patch_hist_similarity.max()
```

A self-contained sketch of this patch-histogram matching appears at the end of this subsection.

**Anomaly types detected**:
- Wrong number of pushpins (the reference count is 15)
- Abnormal color distribution

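As an illustration of the patch-histogram matching used here (and again for breakfast_box below), the following is a minimal sketch: the patch-level semantic labels are counted with `np.bincount`, the histogram is normalized, and its similarity to the histograms remembered from the reference images gives a logical anomaly score. The array shapes and names are hypothetical placeholders, not the repository's exact variables.

```python
import numpy as np

num_classes = 5  # number of patch-level query categories (assumption for the example)

def patch_histogram(patch_mask: np.ndarray) -> np.ndarray:
    """Count patch labels and L2-normalize the histogram."""
    hist = np.bincount(patch_mask.reshape(-1), minlength=num_classes).astype(np.float64)
    return hist / (np.linalg.norm(hist) + 1e-8)

# histograms of the reference (few-shot / training) images, shape (K, num_classes)
memory_hists = np.stack([patch_histogram(np.random.randint(0, num_classes, (64, 64)))
                         for _ in range(4)])

# histogram of the test image
test_hist = patch_histogram(np.random.randint(0, num_classes, (64, 64)))

# anomaly score: 1 - best similarity against the memorized histograms
score = 1.0 - (test_hist @ memory_hists.T).max()
print(score)
```
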
### 2. Splicing Connectors

**Location**: `model_ensemble.py:481-615`

**Logic (several checks)**:
```python
elif self.class_name == 'splicing_connectors':
    # 1. connected-component check
    if count != 1:
        self.anomaly_flag = True

    # 2. cable color vs. clamp count check
    foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
    ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
    if ratio > 1.2 or ratio < 0.8:
        self.anomaly_flag = True

    # 3. left/right symmetry check
    ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
    if ratio > 1.2 or ratio < 0.8:
        self.anomaly_flag = True

    # 4. distance check
    distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
    ratio = distance / self.splicing_connectors_distance
    if ratio < 0.6 or ratio > 1.4:
        self.anomaly_flag = True
```

**Anomaly types detected**:
- Broken or missing cable
- Cable color not matching the clamp count (yellow: 2 clamps, blue: 3 clamps, red: 5 clamps)
- Left/right clamp asymmetry
- Abnormal cable length

### 3. Screw Bag

**Location**: `model_ensemble.py:617-670`

**Logic**:
```python
elif self.class_name == 'screw_bag':
    # foreground pixel statistics check
    foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])])
    ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
    if ratio < 0.94 or ratio > 1.06:
        self.anomaly_flag = True
```

**Anomaly types detected**:
- Wrong number of screws, nuts, or washers
- Abnormal foreground pixel ratio (threshold: ±6%)

### 4. Juice Bottle

**Location**: `model_ensemble.py:715-771`

**Logic**:
```python
elif self.class_name == 'juice_bottle':
    # liquid / fruit consistency check
    liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
    fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
    if liquid_idx != fruit_idx:
        self.anomaly_flag = True
```

**Anomaly types detected**:
- Liquid color not matching the fruit on the label
- Misplaced label

### 5. Breakfast Box

**Location**: `model_ensemble.py:672-713`

**Logic**:
```python
elif self.class_name == 'breakfast_box':
    # relies mainly on patch histogram matching
    sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
    patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
    score = 1 - patch_hist_similarity.max()
```

**Anomaly types detected**:
- Abnormal food layout
- Missing or extra items

## Few-Shot vs. Full-Data Mode

### Differences in Data Handling

**Few-shot mode** (`model_ensemble_few_shot.py`):
```python
# uses all few-shot samples directly
FEW_SHOT_SAMPLES = [0, 1, 2, 3]  # fixed set of 4 samples
self.k_shot = few_shot_samples.size(0)
```

**Full-data mode** (`model_ensemble.py`):
```python
# builds a coreset from the full training set
FEW_SHOT_SAMPLES = range(len(datamodule.train_data))  # all training samples
self.k_shot = 4 if self.total_size > 4 else self.total_size
```

### Coreset Subsampling

**Few-shot mode**: no coreset; the raw features are used directly
```python
# model_ensemble_few_shot.py:852
self.mem_patch_feature_clip_coreset = patch_tokens_clip
self.mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
```

**Full-data mode**: the memory bank is subsampled with the K-Center Greedy algorithm (see the sketch after this code block)
```python
# model_ensemble.py:892-896
clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()

dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
```

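For intuition, here is a minimal sketch of K-Center Greedy subsampling: starting from a random seed point, it repeatedly adds the embedding that is farthest from the already selected set, which keeps a small but well-spread coreset. This is a simplified re-implementation for illustration only; the `KCenterGreedy` sampler used in the repository may differ in details (for example random projection or GPU batching).

```python
import torch

def k_center_greedy(embeddings: torch.Tensor, sampling_ratio: float = 0.25) -> torch.Tensor:
    """Select a coreset of embeddings with the greedy k-center heuristic."""
    n = embeddings.shape[0]
    k = max(1, int(n * sampling_ratio))
    selected = [torch.randint(n, (1,)).item()]            # random seed point
    # distance of every point to its nearest selected center
    min_dist = torch.cdist(embeddings, embeddings[selected]).squeeze(1)
    for _ in range(k - 1):
        idx = int(torch.argmax(min_dist))                  # farthest remaining point
        selected.append(idx)
        new_dist = torch.cdist(embeddings, embeddings[idx:idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)
    return embeddings[selected]

coreset = k_center_greedy(torch.randn(1000, 64), sampling_ratio=0.25)
print(coreset.shape)  # torch.Size([250, 64])
```
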
### Differences in Cached Statistics

**Few-shot mode**:
```python
# model_ensemble_few_shot.py:185
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
```

**Full-data mode**:
```python
# model_ensemble.py:188
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
```

### Differences in the Computation Flow

**Few-shot mode**:
1. Compute the features of the 4 samples directly
2. No coreset computation
3. Run anomaly detection directly

**Full-data mode**:
1. Compute the features of all training samples (`compute_coreset.py`)
2. Select representative features with the K-Center Greedy algorithm
3. Save the coreset to the `memory_bank/` directory
4. Load the precomputed coreset for anomaly detection

## Implementation Details and Optimizations

### Memory Optimizations

**Batched processing**:
```python
# model_ensemble.py:926-928
for i in range(self.total_size//self.k_shot):
    self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)],
                 few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
```

**Feature caching**:
- The precomputed coreset features are stored in the `memory_bank/` directory
- The score statistics are precomputed and cached as well

### Multimodal Feature Fusion

**Feature-layer selection**:
- **Clustering features**: CLIP layers 0 and 1 (`cluster_feature_id = [0, 1]`)
- **Detection features**: the full features of layers 6, 12, 18, and 24

**Model selection per category** (the features are compared layer by layer against the coreset in a PatchCore-style manner; see the sketch after this code block):
```python
# model_ensemble.py:290-310
if self.class_name in ['pushpins', 'screw_bag']:
    # PatchCore-style detection on CLIP features
    len_feature_list = len(self.feature_list)
    for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1),
                                                mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):

if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']:
    # PatchCore-style detection on DINOv2 features
    len_feature_list = len(self.feature_list_dinov2)
    for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1),
                                                mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
```

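The loop bodies are omitted in the excerpt above. Conceptually, each test patch feature is compared against the memorized coreset and the distance to its nearest neighbor yields the structural anomaly score. The following is a minimal sketch of that idea under simplifying assumptions (single layer, cosine distance); it is not the repository's exact implementation.

```python
import torch
import torch.nn.functional as F

def structural_score(patch_features: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
    """PatchCore-style score: nearest-neighbor distance of each patch to the memory bank."""
    # cosine distance = 1 - cosine similarity between L2-normalized features
    q = F.normalize(patch_features, dim=-1)   # (num_patches, dim)
    m = F.normalize(memory, dim=-1)           # (memory_size, dim)
    dist = 1.0 - q @ m.T                      # (num_patches, memory_size)
    per_patch = dist.min(dim=-1).values       # nearest memory entry per patch
    return per_patch.max()                    # image-level score: worst patch

score = structural_score(torch.randn(4096, 1024), torch.randn(2500, 1024))
print(float(score))
```
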
## Text Prompt Engineering

### Semantic Query Dictionaries

**Object-level queries**:
```python
# model_ensemble.py:123-136
self.query_words_dict = {
    "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
    "juice_bottle": ['bottle', ['black background', 'background']],
    "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
    "screw_bag": [['screw'], 'plastic bag', 'background'],
    "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
}
```

**Patch-level queries**:
```python
# model_ensemble.py:138-145
self.patch_query_words_dict = {
    "juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
    "screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
    # ...
}
```

### Text Encoding Strategy

**Multi-template encoding**:
```python
# prompt_ensemble.py:98-120
def encode_obj_text(model, query_words, tokenizer, device):
    for qw in query_words:
        if type(qw) == list:
            for qw2 in qw:
                token_input.extend([temp(qw2) for temp in openai_imagenet_template])
        else:
            token_input = [temp(qw) for temp in openai_imagenet_template]
```

Each query word is expanded with the 80 OpenAI ImageNet prompt templates defined in `imagenet_template.py`, which makes the text features more robust; a sketch of this prompt ensembling follows below.

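A minimal sketch of the prompt-ensembling idea with open_clip: every template is applied to the class word, all prompts are encoded, and the L2-normalized embeddings are averaged into one text feature per class. The model name matches the one used above; the variable names and the two example classes are illustrative, and the repository's `encode_obj_text` may aggregate differently.

```python
import torch
import open_clip
from imagenet_template import openai_imagenet_template

model, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')

def encode_class(word: str) -> torch.Tensor:
    """Encode one class word with every template and average the embeddings."""
    prompts = [template(word) for template in openai_imagenet_template]
    with torch.no_grad():
        feats = model.encode_text(tokenizer(prompts))
    feats = feats / feats.norm(dim=-1, keepdim=True)   # normalize each prompt embedding
    mean = feats.mean(dim=0)
    return mean / mean.norm()                          # one unit-norm feature per class

query = torch.stack([encode_class(w) for w in ["pushpin", "plastic box"]])
print(query.shape)  # (num_classes, embedding_dim)
```
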
## Performance Evaluation

### Metrics

**Image-level metrics**:
- F1-Max (image)
- AUROC (image)

**Per-anomaly-type metrics**:
- F1-Max (logical): logical anomalies
- AUROC (logical): logical anomalies
- F1-Max (structural): structural anomalies
- AUROC (structural): structural anomalies

### Evaluation Flow

**Splitting the test data**:
```python
# evaluation.py:222-227
if 'logical' not in image_path[0]:
    image_metric_structure.update(output["pred_score"].cpu(), data["label"])
if 'structural' not in image_path[0]:
    image_metric_logical.update(output["pred_score"].cpu(), data["label"])
```

**Score fusion**:
```python
# model_ensemble.py:227-231
standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]

pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
pred_score = sigmoid(pred_score)
```

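The `instance_hungarian_match_score` above comes from matching the detected instances of a test image against the instances of the reference images with the Hungarian algorithm. The following is a rough, hypothetical sketch of that idea using `scipy.optimize.linear_sum_assignment`; the repository's actual cost definition and score aggregation may differ.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def instance_match_score(test_feats: np.ndarray, ref_feats: np.ndarray) -> float:
    """Match test instances to reference instances and return the mean matching cost."""
    # cost = 1 - cosine similarity between instance-level features
    t = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    r = ref_feats / np.linalg.norm(ref_feats, axis=1, keepdims=True)
    cost = 1.0 - t @ r.T
    row_ind, col_ind = linear_sum_assignment(cost)   # optimal one-to-one assignment
    return float(cost[row_ind, col_ind].mean())

score = instance_match_score(np.random.rand(5, 1024), np.random.rand(5, 1024))
print(score)
```
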
## Summary

LogSAD combines the strengths of several pretrained models to detect anomalies without any training:

1. **Multimodal cooperation**: CLIP provides semantic understanding, DINOv2 provides visual features, and SAM provides precise segmentation
2. **Logical reasoning**: special-case rules that encode domain knowledge catch complex logical anomalies
3. **Feature fusion**: multi-scale feature extraction and fusion improve detection accuracy
4. **Efficiency**: coreset subsampling and feature caching keep the method practical

The method achieves strong results on the MVTec LOCO dataset and demonstrates the potential of pretrained foundation models for anomaly detection.
README.md
ADDED
@@ -0,0 +1,102 @@
# Towards Training-free Anomaly Detection with Vision and Language Foundation Models (CVPR 2025)

<div>
<a href="https://arxiv.org/abs/2503.18325"><img src="https://img.shields.io/static/v1?label=Arxiv&message=LogSAD&color=red&logo=arxiv"></a>  
</div>

## System Requirements

**Hardware Requirements:**
- **GPU Memory:** 32GB VRAM (for running complete experiments)
- **Storage:** 70GB free disk space (for models, datasets, and results)
- **CUDA:** Compatible GPU with CUDA 12.1 support

**Software Requirements:**
- Python 3.10
- Conda (recommended for environment management)
- CUDA 12.1 runtime

> **Note:** The memory and storage requirements are for running the full experimental pipeline on all categories with visualization enabled. Smaller experiments on individual categories may require fewer resources.

## Installation

### Automated Setup (Recommended)

Run the setup script to automatically configure the complete environment:

```bash
bash scripts/setup_environment.sh
```

This script will:
- Create a conda environment named `logsad` with Python 3.10
- Install PyTorch with CUDA 12.1 support
- Install all required dependencies from `requirements.txt`
- Configure numpy compatibility

### Manual Setup

If you prefer manual setup, download the checkpoint for the [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and put it in the checkpoint folder.

After installation, activate the environment:
```bash
conda activate logsad
```

## Instructions for the MVTec LOCO dataset

### Quick Start (Recommended)

Run evaluation for all categories using the provided shell scripts:

**Few-shot Protocol:**
```bash
bash scripts/run_few_shot.sh
```

**Full-data Protocol:**
```bash
bash scripts/run_full_data.sh
```

### Manual Execution

#### Few-shot Protocol
Run the script for the few-shot protocol:

```
python evaluation.py --module_path model_ensemble_few_shot --category CATEGORY --dataset_path DATASET_PATH
```

#### Full-data Protocol
Run the script to compute the coreset for full-data scenarios:

```
python compute_coreset.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
```

Run the script for the full-data protocol:

```
python evaluation.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
```

**Available categories:** breakfast_box, juice_bottle, pushpins, screw_bag, splicing_connectors

## Acknowledgement
We are grateful for the following awesome projects when implementing LogSAD:
* [SAM](https://github.com/facebookresearch/segment-anything), [OpenCLIP](https://github.com/mlfoundations/open_clip), [DINOv2](https://github.com/facebookresearch/dinov2) and [NACLIP](https://github.com/sinahmr/NACLIP).

## Citation
If you find our paper helpful in your research or applications, please cite it with
```
@inproceedings{zhang2025logsad,
  title={Towards Training-free Anomaly Detection with Vision and Language Foundation Models},
  author={Zhang, Jinjin and Wang, Guodong and Jin, Yizhou and Huang, Di},
  year={2025},
  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
}
```
compute_coreset.py
ADDED
@@ -0,0 +1,121 @@
"""Coreset computation script for the full-data protocol."""

import os

# Set cache directories to use checkpoint folder for model downloads
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import argparse
import importlib
import importlib.util

import torch
import logging
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

# FEW_SHOT_SAMPLES = [0, 1, 2, 3]

def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True)
    parser.add_argument("--class_name", default='MyModel', type=str, required=False)
    parser.add_argument("--weights_path", type=str, required=False)
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
    parser.add_argument("--category", type=str, required=True)
    parser.add_argument("--viz", action='store_true', default=False)
    return parser.parse_args()


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.

    Returns:
        nn.Module: Loaded model.
    """
    # get model class
    model_class = getattr(importlib.import_module(module_path), class_name)
    # instantiate model
    model = model_class()
    # load weights
    if weights_path:
        model.load_state_dict(torch.load(weights_path))
    return model


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the coreset computation.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
    """
    # Disable verbose logging from all libraries
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger('anomalib').setLevel(logging.ERROR)
    logging.getLogger('lightning').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate model class here
    # Load the model here from checkpoint.
    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    # Create the dataset
    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)
    model.set_save_coreset_features(True)

    FEW_SHOT_SAMPLES = range(len(datamodule.train_data))  # traverse the whole training set to build the coreset

    # pass few-shot images and dataset category to model
    setup_data = {
        "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    print(f"✓ Coreset computation completed for {category}")
    print("  Memory bank features saved to memory_bank/ directory")


if __name__ == "__main__":
    args = parse_args()
    run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
environment.yml
ADDED
@@ -0,0 +1,116 @@
name: logsad
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - bzip2=1.0.8=h4bc722e_7
  - ca-certificates=2025.8.3=hbd8a1cb_0
  - ld_impl_linux-64=2.44=h1423503_1
  - libexpat=2.7.1=hecca717_0
  - libffi=3.4.6=h2dba641_1
  - libgcc=15.1.0=h767d61c_4
  - libgcc-ng=15.1.0=h69a702a_4
  - libgomp=15.1.0=h767d61c_4
  - liblzma=5.8.1=hb9d3cd8_2
  - libnsl=2.0.1=hb9d3cd8_1
  - libsqlite=3.50.4=h0c1763c_0
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.3.1=hb9d3cd8_2
  - ncurses=6.5=h2d0b736_3
  - openssl=3.5.2=h26f9b46_0
  - pip=25.2=pyh8b19718_0
  - python=3.10.18=hd6af730_0_cpython
  - readline=8.2=h8c095d6_2
  - setuptools=80.9.0=pyhff2d567_0
  - tk=8.6.13=noxft_hd72426e_102
  - wheel=0.45.1=pyhd8ed1ab_1
  - pip:
      - aiohappyeyeballs==2.6.1
      - aiohttp==3.12.11
      - aiosignal==1.3.2
      - antlr4-python3-runtime==4.9.3
      - async-timeout==5.0.1
      - attrs==25.3.0
      - certifi==2025.4.26
      - charset-normalizer==3.4.2
      - contourpy==1.3.2
      - cycler==0.12.1
      - einops==0.6.1
      - faiss-cpu==1.8.0
      - filelock==3.18.0
      - fonttools==4.58.2
      - freia==0.2
      - frozenlist==1.6.2
      - fsspec==2024.12.0
      - ftfy==6.3.1
      - hf-xet==1.1.3
      - huggingface-hub==0.32.4
      - idna==3.10
      - imageio==2.37.0
      - imgaug==0.4.0
      - jinja2==3.1.6
      - joblib==1.5.1
      - jsonargparse==4.29.0
      - kiwisolver==1.4.8
      - kmeans-pytorch==0.3
      - kornia==0.7.0
      - lazy-loader==0.4
      - lightning==2.2.5
      - lightning-utilities==0.14.3
      - markdown-it-py==3.0.0
      - markupsafe==3.0.2
      - matplotlib==3.10.3
      - mdurl==0.1.2
      - mpmath==1.3.0
      - multidict==6.4.4
      - networkx==3.4.2
      - numpy==1.23.1
      - omegaconf==2.3.0
      - open-clip-torch==2.24.0
      - opencv-python==4.8.1.78
      - packaging==24.2
      - pandas==2.0.3
      - pillow==11.2.1
      - propcache==0.3.1
      - protobuf==6.31.1
      - pygments==2.19.1
      - pyparsing==3.2.3
      - python-dateutil==2.9.0.post0
      - pytorch-lightning==2.5.1.post0
      - pytz==2025.2
      - pyyaml==6.0.2
      - regex==2024.11.6
      - requests==2.32.3
      - rich==13.7.1
      - safetensors==0.5.3
      - scikit-image==0.25.2
      - scikit-learn==1.2.2
      - scipy==1.15.3
      - segment-anything==1.0
      - sentencepiece==0.2.0
      - shapely==2.1.1
      - six==1.17.0
      - sympy==1.14.0
      - tabulate==0.9.0
      - threadpoolctl==3.6.0
      - tifffile==2025.5.10
      - timm==1.0.15
      - torch==2.1.2+cu121
      - torchmetrics==1.7.2
      - torchvision==0.16.2+cu121
      - tqdm==4.67.1
      - triton==2.1.0
      - typing-extensions==4.14.0
      - tzdata==2025.2
      - urllib3==2.4.0
      - wcwidth==0.2.13
      - yarl==1.20.0
prefix: /opt/conda/envs/logsad
evaluation.py
ADDED
@@ -0,0 +1,257 @@
"""Sample evaluation script for track 2."""

import os
from datetime import datetime
from pathlib import Path

# Set cache directories to use checkpoint folder for model downloads
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import argparse
import importlib
import importlib.util

import torch
import logging
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

FEW_SHOT_SAMPLES = [0, 1, 2, 3]

def write_results_to_markdown(category, results_data, module_path):
    """Write evaluation results to markdown file.

    Args:
        category (str): Dataset category name
        results_data (dict): Dictionary containing all metrics
        module_path (str): Model module path (for protocol identification)
    """
    # Determine protocol type from module path
    protocol = "Few-shot" if "few_shot" in module_path else "Full-data"

    # Create results directory
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    # Combined results file with simple protocol name
    protocol_suffix = "few_shot" if "few_shot" in module_path else "full_data"
    combined_file = results_dir / f"{protocol_suffix}_results.md"

    # Read existing results if file exists
    existing_results = {}
    if combined_file.exists():
        with open(combined_file, 'r') as f:
            content = f.read()
            # Parse existing results (basic parsing)
            lines = content.split('\n')
            for line in lines:
                if '|' in line and line.count('|') >= 8:
                    parts = [p.strip() for p in line.split('|')]
                    if len(parts) >= 8 and parts[1] != 'Category' and parts[1] != '-----':
                        existing_results[parts[1]] = {
                            'k_shots': parts[2],
                            'f1_image': parts[3],
                            'auroc_image': parts[4],
                            'f1_logical': parts[5],
                            'auroc_logical': parts[6],
                            'f1_structural': parts[7],
                            'auroc_structural': parts[8]
                        }

    # Add current results
    existing_results[category] = {
        'k_shots': str(len(FEW_SHOT_SAMPLES)),
        'f1_image': f"{results_data['f1_image']:.2f}",
        'auroc_image': f"{results_data['auroc_image']:.2f}",
        'f1_logical': f"{results_data['f1_logical']:.2f}",
        'auroc_logical': f"{results_data['auroc_logical']:.2f}",
        'f1_structural': f"{results_data['f1_structural']:.2f}",
        'auroc_structural': f"{results_data['auroc_structural']:.2f}"
    }

    # Write combined results
    with open(combined_file, 'w') as f:
        f.write(f"# All Categories - {protocol} Protocol Results\n\n")
        f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Protocol:** {protocol}\n")
        f.write(f"**Available Categories:** {', '.join(sorted(existing_results.keys()))}\n\n")

        f.write("## Summary Table\n\n")
        f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
        f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")

        # Sort categories alphabetically
        for cat in sorted(existing_results.keys()):
            data = existing_results[cat]
            f.write(f"| {cat} | {data['k_shots']} | {data['f1_image']} | {data['auroc_image']} | {data['f1_logical']} | {data['auroc_logical']} | {data['f1_structural']} | {data['auroc_structural']} |\n")

    print(f"\n✓ Results saved to:")
    print(f"  - Combined: {combined_file}")

def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True)
    parser.add_argument("--class_name", default='MyModel', type=str, required=False)
    parser.add_argument("--weights_path", type=str, required=False)
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
    parser.add_argument("--category", type=str, required=True)
    parser.add_argument("--viz", action='store_true', default=False)
    return parser.parse_args()


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.

    Returns:
        nn.Module: Loaded model.
    """
    # get model class
    model_class = getattr(importlib.import_module(module_path), class_name)
    # instantiate model
    model = model_class()
    # load weights
    if weights_path:
        model.load_state_dict(torch.load(weights_path))
    return model


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the evaluation script.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate model class here
    # Load the model here from checkpoint.
    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    # Create the dataset
    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)

    # Create the metrics
    image_metric = F1Max()
    pixel_metric = F1Max()

    image_metric_logical = F1Max()
    image_metric_structure = F1Max()

    image_metric_auroc = AUROC()
    pixel_metric_auroc = AUROC()

    image_metric_auroc_logical = AUROC()
    image_metric_auroc_structure = AUROC()

    # pass few-shot images and dataset category to model
    setup_data = {
        "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    # Loop over the test set and compute the metrics
    for data in datamodule.test_dataloader():
        with torch.no_grad():
            image_path = data['image_path']
            output = model(data["image"].to(device), data['image_path'])

            image_metric.update(output["pred_score"].cpu(), data["label"])
            image_metric_auroc.update(output["pred_score"].cpu(), data["label"])

            if 'logical' not in image_path[0]:
                image_metric_structure.update(output["pred_score"].cpu(), data["label"])
                image_metric_auroc_structure.update(output["pred_score"].cpu(), data["label"])
            if 'structural' not in image_path[0]:
                image_metric_logical.update(output["pred_score"].cpu(), data["label"])
                image_metric_auroc_logical.update(output["pred_score"].cpu(), data["label"])

    # Disable verbose logging from all libraries
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger('anomalib').setLevel(logging.ERROR)
    logging.getLogger('lightning').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # Set up our own logger for results only
    logger = logging.getLogger('evaluation')
    logger.handlers.clear()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    table_ls = [[category,
                 str(len(FEW_SHOT_SAMPLES)),
                 str(np.round(image_metric.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc.compute().item() * 100, decimals=2)),
                 # str(np.round(pixel_metric.compute().item() * 100, decimals=2)),
                 # str(np.round(pixel_metric_auroc.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_logical.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_structure.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)),
                 ]]

    results = tabulate(table_ls, headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)', 'F1-Max (logical)', 'AUROC (logical)', 'F1-Max (structural)', 'AUROC (structural)'], tablefmt="pipe")

    logger.info("\n%s", results)

    # Save results to markdown
    results_data = {
        'f1_image': np.round(image_metric.compute().item() * 100, decimals=2),
        'auroc_image': np.round(image_metric_auroc.compute().item() * 100, decimals=2),
        'f1_logical': np.round(image_metric_logical.compute().item() * 100, decimals=2),
        'auroc_logical': np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2),
        'f1_structural': np.round(image_metric_structure.compute().item() * 100, decimals=2),
        'auroc_structural': np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)
    }
    write_results_to_markdown(category, results_data, module_path)


if __name__ == "__main__":
    args = parse_args()
    run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
imagenet_template.py
ADDED
@@ -0,0 +1,82 @@
openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
model_ensemble.py
ADDED
@@ -0,0 +1,1034 @@
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Set cache directories to use checkpoint folder for model downloads
|
4 |
+
os.environ['TORCH_HOME'] = './checkpoint'
|
5 |
+
os.environ['HF_HOME'] = './checkpoint/huggingface'
|
6 |
+
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
|
7 |
+
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
|
8 |
+
|
9 |
+
# Create checkpoint subdirectories if they don't exist
|
10 |
+
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
|
11 |
+
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torchvision.transforms import v2
|
16 |
+
from torchvision.transforms.v2.functional import resize
|
17 |
+
import cv2
|
18 |
+
import json
|
19 |
+
import torch
|
20 |
+
import random
|
21 |
+
import logging
|
22 |
+
import argparse
|
23 |
+
import numpy as np
|
24 |
+
from PIL import Image
|
25 |
+
from skimage import measure
|
26 |
+
from tabulate import tabulate
|
27 |
+
from torchvision.ops.focal_loss import sigmoid_focal_loss
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torchvision.transforms as transforms
|
30 |
+
import torchvision.transforms.functional as TF
|
31 |
+
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
|
32 |
+
from sklearn.mixture import GaussianMixture
|
33 |
+
import faiss
|
34 |
+
import open_clip_local as open_clip
|
35 |
+
|
36 |
+
from torch.utils.data.dataset import ConcatDataset
|
37 |
+
from scipy.optimize import linear_sum_assignment
|
38 |
+
from sklearn.random_projection import SparseRandomProjection
|
39 |
+
import cv2
|
40 |
+
from torchvision.transforms import InterpolationMode
|
41 |
+
from PIL import Image
|
42 |
+
import string
|
43 |
+
|
44 |
+
from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
|
45 |
+
from kmeans_pytorch import kmeans, kmeans_predict
|
46 |
+
from scipy.optimize import linear_sum_assignment
|
47 |
+
from scipy.stats import norm
|
48 |
+
|
49 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
50 |
+
from matplotlib import pyplot as plt
|
51 |
+
|
52 |
+
import pickle
|
53 |
+
from scipy.stats import norm
|
54 |
+
|
55 |
+
from open_clip_local.pos_embed import get_2d_sincos_pos_embed
|
56 |
+
|
57 |
+
from anomalib.models.components import KCenterGreedy
|
58 |
+
|
59 |
+
def to_np_img(m):
|
60 |
+
m = m.permute(1, 2, 0).cpu().numpy()
|
61 |
+
mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
|
62 |
+
std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
|
63 |
+
m = m * std + mean
|
64 |
+
return np.clip((m * 255.), 0, 255).astype(np.uint8)
|
65 |
+
|
66 |
+
|
67 |
+
def setup_seed(seed):
|
68 |
+
torch.manual_seed(seed)
|
69 |
+
torch.cuda.manual_seed_all(seed)
|
70 |
+
np.random.seed(seed)
|
71 |
+
random.seed(seed)
|
72 |
+
torch.backends.cudnn.deterministic = True
|
73 |
+
torch.backends.cudnn.benchmark = False
|
74 |
+
|
75 |
+
|
76 |
+
class MyModel(nn.Module):
|
77 |
+
"""Example model class for track 2.
|
78 |
+
|
79 |
+
This class applies few-shot anomaly detection using the WinClip model from Anomalib.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self) -> None:
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
setup_seed(42)
|
86 |
+
# NOTE: Create your transformation pipeline (if needed).
|
87 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
88 |
+
self.transform = v2.Compose(
|
89 |
+
[
|
90 |
+
v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
91 |
+
],
|
92 |
+
)
|
93 |
+
|
94 |
+
# NOTE: Create your model.
|
95 |
+
|
96 |
+
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
97 |
+
self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
98 |
+
self.feature_list = [6, 12, 18, 24]
|
99 |
+
self.embed_dim = 768
|
100 |
+
self.vision_width = 1024
|
101 |
+
|
102 |
+
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
|
103 |
+
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
|
104 |
+
|
105 |
+
self.memory_size = 2048
|
106 |
+
self.n_neighbors = 2
|
107 |
+
|
108 |
+
self.model_clip.eval()
|
109 |
+
self.test_args = None
|
110 |
+
self.align_corners = True # False
|
111 |
+
self.antialias = True # False
|
112 |
+
self.inter_mode = 'bilinear' # bilinear/bicubic
|
113 |
+
|
114 |
+
self.cluster_feature_id = [0, 1]
|
115 |
+
|
116 |
+
self.cluster_num_dict = {
|
117 |
+
"breakfast_box": 3, # unused
|
118 |
+
"juice_bottle": 8, # unused
|
119 |
+
"splicing_connectors": 10, # unused
|
120 |
+
"pushpins": 10,
|
121 |
+
"screw_bag": 10,
|
122 |
+
}
|
123 |
+
self.query_words_dict = {
|
124 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
125 |
+
"juice_bottle": ['bottle', ['black background', 'background']],
|
126 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
127 |
+
"screw_bag": [['screw'], 'plastic bag', 'background'],
|
128 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
129 |
+
}
|
130 |
+
self.foreground_label_idx = { # for query_words_dict
|
131 |
+
"breakfast_box": [0, 1, 2, 3, 4, 5],
|
132 |
+
"juice_bottle": [0],
|
133 |
+
"pushpins": [0],
|
134 |
+
"screw_bag": [0],
|
135 |
+
"splicing_connectors":[0, 1]
|
136 |
+
}
|
137 |
+
|
138 |
+
self.patch_query_words_dict = {
|
139 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
140 |
+
"juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
141 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
142 |
+
"screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']], # 79.71
|
143 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
144 |
+
}
|
145 |
+
|
146 |
+
|
147 |
+
self.query_threshold_dict = {
|
148 |
+
"breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
|
149 |
+
"juice_bottle": [0., 0., 0.], # unused
|
150 |
+
"splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
|
151 |
+
"pushpins": [0.2, 0., 0., 0.],
|
152 |
+
"screw_bag": [0., 0., 0.,],
|
153 |
+
}
|
154 |
+
|
155 |
+
self.feat_size = 64
|
156 |
+
self.ori_feat_size = 32
|
157 |
+
|
158 |
+
self.visualization = False #False # True #False
|
159 |
+
|
160 |
+
self.pushpins_count = 15
|
161 |
+
|
162 |
+
self.splicing_connectors_count = [2, 3, 5] # coresponding to yellow, blue, and red
|
163 |
+
self.splicing_connectors_distance = 0
|
164 |
+
self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
|
165 |
+
|
166 |
+
self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
|
167 |
+
self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
|
168 |
+
|
169 |
+
# query words
|
170 |
+
self.foreground_pixel_hist = 0
|
171 |
+
self.foreground_pixel_hist_screw_bag = 366.0 # 4-shot statistics
|
172 |
+
self.foreground_pixel_hist_splicing_connectors = 4249.666666666667 # 4-shot statistics
|
173 |
+
# patch query words
|
174 |
+
self.patch_token_hist = []
|
175 |
+
|
176 |
+
self.few_shot_inited = False
|
177 |
+
|
178 |
+
self.save_coreset_features = False
|
179 |
+
|
180 |
+
|
181 |
+
from dinov2.dinov2.hub.backbones import dinov2_vitl14
|
182 |
+
self.model_dinov2 = dinov2_vitl14()
|
183 |
+
self.model_dinov2.to(self.device)
|
184 |
+
self.model_dinov2.eval()
|
185 |
+
self.feature_list_dinov2 = [6, 12, 18, 24]
|
186 |
+
self.vision_width_dinov2 = 1024
|
187 |
+
|
188 |
+
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
|
189 |
+
|
190 |
+
self.mem_instance_masks = None
|
191 |
+
|
192 |
+
self.anomaly_flag = False
|
193 |
+
self.validation = False #True #False
|
194 |
+
|
195 |
+
def set_save_coreset_features(self, save_coreset_features):
|
196 |
+
self.save_coreset_features = save_coreset_features
|
197 |
+
|
198 |
+
def set_viz(self, viz):
|
199 |
+
self.visualization = viz
|
200 |
+
|
201 |
+
def set_val(self, val):
|
202 |
+
self.validation = val
|
203 |
+
|
204 |
+
def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
|
205 |
+
"""Transform the input batch and pass it through the model.
|
206 |
+
|
207 |
+
This model returns a dictionary with the following keys
|
208 |
+
- ``anomaly_map`` - Anomaly map.
|
209 |
+
- ``pred_score`` - Predicted anomaly score.
|
210 |
+
"""
|
211 |
+
self.anomaly_flag = False
|
212 |
+
batch = self.transform(batch).to(self.device)
|
213 |
+
results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])
|
214 |
+
|
215 |
+
hist_score = results['hist_score']
|
216 |
+
structural_score = results['structural_score']
|
217 |
+
instance_hungarian_match_score = results['instance_hungarian_match_score']
|
218 |
+
|
219 |
+
|
220 |
+
if self.validation:
|
221 |
+
return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
222 |
+
|
223 |
+
def sigmoid(z):
|
224 |
+
return 1/(1 + np.exp(-z))
|
225 |
+
|
226 |
+
# standardization
|
227 |
+
standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
|
228 |
+
standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
|
229 |
+
|
230 |
+
pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
|
231 |
+
pred_score = sigmoid(pred_score)
|
232 |
+
|
233 |
+
|
234 |
+
if self.anomaly_flag:
|
235 |
+
pred_score = 1.
|
236 |
+
self.anomaly_flag = False
|
237 |
+
|
238 |
+
|
239 |
+
return {"pred_score": torch.tensor(pred_score), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
240 |
+
|
241 |
+
|
242 |
+
def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
|
243 |
+
|
244 |
+
with torch.no_grad():
|
245 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
|
246 |
+
# image_features /= image_features.norm(dim=-1, keepdim=True)
|
247 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
248 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
249 |
+
|
250 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
|
251 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
|
252 |
+
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
253 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
254 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
255 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
|
256 |
+
|
257 |
+
with torch.no_grad():
|
258 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
|
259 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
|
260 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
261 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
262 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
263 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
|
264 |
+
|
265 |
+
'''adding for kmeans seg '''
|
266 |
+
if self.feat_size != self.ori_feat_size:
|
267 |
+
proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
268 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
269 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
|
270 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
271 |
+
|
272 |
+
mid_features = None
|
273 |
+
for layer in self.cluster_feature_id:
|
274 |
+
temp_feat = patch_tokens[layer]
|
275 |
+
mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
|
276 |
+
|
277 |
+
if self.feat_size != self.ori_feat_size:
|
278 |
+
mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
279 |
+
mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
280 |
+
mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
281 |
+
mid_features = F.normalize(mid_features, p=2, dim=-1)
|
282 |
+
|
283 |
+
results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
|
284 |
+
|
285 |
+
hist_score = results['score']
|
286 |
+
|
287 |
+
'''calculate patchcore'''
|
288 |
+
anomaly_maps_patchcore = []
|
289 |
+
|
290 |
+
if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
|
291 |
+
len_feature_list = len(self.feature_list)
|
292 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
293 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
294 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
295 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
296 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
297 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
298 |
+
|
299 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
300 |
+
|
301 |
+
if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
|
302 |
+
len_feature_list = len(self.feature_list_dinov2)
|
303 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
|
304 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
305 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
306 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
307 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
308 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
309 |
+
|
310 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
311 |
+
|
312 |
+
structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
|
313 |
+
# anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
|
314 |
+
|
315 |
+
instance_masks = results["instance_masks"]
|
316 |
+
anomaly_instances_hungarian = []
|
317 |
+
instance_hungarian_match_score = 1.
|
318 |
+
if self.mem_instance_masks is not None and len(instance_masks) != 0:
|
319 |
+
for patch_feature, mem_instance_features_single_stage in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), self.mem_instance_features_multi_stage.chunk(len_feature_list, dim=1)):
|
320 |
+
instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
|
321 |
+
instance_features = torch.cat(instance_features, dim=0)
|
322 |
+
instance_features = F.normalize(instance_features, dim=-1)
|
323 |
+
|
324 |
+
normal_instance_hungarian = (instance_features @ mem_instance_features_single_stage.T)
|
325 |
+
cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
|
326 |
+
|
327 |
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
328 |
+
cost = cost_matrix[row_ind, col_ind].sum()
|
329 |
+
cost = cost / min(cost_matrix.shape)
|
330 |
+
anomaly_instances_hungarian.append(cost)
|
331 |
+
|
332 |
+
instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
|
333 |
+
|
334 |
+
results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score}
|
335 |
+
|
336 |
+
return results
|
337 |
+
|
338 |
+
|
339 |
+
def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
|
340 |
+
def plot_results_only(sorted_anns):
|
341 |
+
cur = 1
|
342 |
+
img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
|
343 |
+
for ann in sorted_anns:
|
344 |
+
m = ann['segmentation']
|
345 |
+
img_color[m] = cur
|
346 |
+
cur += 1
|
347 |
+
return img_color
|
348 |
+
|
349 |
+
def merge_segmentations(a, b, background_class):
|
350 |
+
unique_labels_a = np.unique(a)
|
351 |
+
unique_labels_b = np.unique(b)
|
352 |
+
|
353 |
+
max_label_a = unique_labels_a.max()
|
354 |
+
label_map = np.zeros(max_label_a + 1, dtype=int)
|
355 |
+
|
356 |
+
for label_a in unique_labels_a:
|
357 |
+
mask_a = (a == label_a)
|
358 |
+
|
359 |
+
labels_b = b[mask_a]
|
360 |
+
if labels_b.size > 0:
|
361 |
+
count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
|
362 |
+
label_map[label_a] = np.argmax(count_b)
|
363 |
+
else:
|
364 |
+
label_map[label_a] = background_class # default background
|
365 |
+
|
366 |
+
merged_a = label_map[a]
|
367 |
+
return merged_a
|
368 |
+
|
369 |
+
pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
|
370 |
+
kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
|
371 |
+
|
372 |
+
for pl in pseudo_labels.unique():
|
373 |
+
mask = (pseudo_labels == pl).reshape(-1)
|
374 |
+
# filter small region
|
375 |
+
binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
|
376 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
377 |
+
for i in range(1, num_labels):
|
378 |
+
temp_mask = labels == i
|
379 |
+
if np.sum(temp_mask) <= 8:
|
380 |
+
mask[temp_mask.reshape(-1)] = False
|
381 |
+
|
382 |
+
if mask.any():
|
383 |
+
region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
|
384 |
+
similarity = (region_feature @ self.query_obj.T)
|
385 |
+
prob, index = torch.max(similarity, dim=-1)
|
386 |
+
temp_label = index.squeeze(0).item()
|
387 |
+
temp_prob = prob.squeeze(0).item()
|
388 |
+
if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
|
389 |
+
kmeans_mask[mask] = temp_label
|
390 |
+
|
391 |
+
|
392 |
+
raw_image = to_np_img(image[0])
|
393 |
+
height, width = raw_image.shape[:2]
|
394 |
+
masks = self.mask_generator.generate(raw_image)
|
395 |
+
# self.predictor.set_image(raw_image)
|
396 |
+
|
397 |
+
kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
|
398 |
+
kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
399 |
+
|
400 |
+
patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
|
401 |
+
patch_mask = patch_similarity.argmax(-1)
|
402 |
+
patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
403 |
+
|
404 |
+
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
|
405 |
+
sam_mask = plot_results_only(sorted_masks).astype(np.int)
|
406 |
+
|
407 |
+
resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
408 |
+
merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
|
409 |
+
|
410 |
+
resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
411 |
+
patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
|
412 |
+
|
413 |
+
# filter small region for merge sam
|
414 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
415 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
416 |
+
for i in range(1, num_labels):
|
417 |
+
temp_mask = labels == i
|
418 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
419 |
+
merge_sam[temp_mask] = self.classes - 1 # set to background
|
420 |
+
|
421 |
+
# filter small region for patch merge sam
|
422 |
+
binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
|
423 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
424 |
+
for i in range(1, num_labels):
|
425 |
+
temp_mask = labels == i
|
426 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
427 |
+
patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
|
428 |
+
|
429 |
+
score = 0. # default to normal
|
430 |
+
self.anomaly_flag = False
|
431 |
+
instance_masks = []
|
432 |
+
if self.class_name == 'pushpins':
|
433 |
+
# object count hist
|
434 |
+
kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
|
435 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
436 |
+
dilate_binary = cv2.dilate(binary, kernel)
|
437 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
|
438 |
+
pushpins_count = num_labels - 1 # number of pushpins
|
439 |
+
|
440 |
+
for i in range(1, num_labels):
|
441 |
+
instance_mask = (labels == i).astype(np.uint8)
|
442 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
443 |
+
if instance_mask.any():
|
444 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
445 |
+
|
446 |
+
if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
|
447 |
+
self.anomaly_flag = True
|
448 |
+
print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
|
449 |
+
|
450 |
+
# patch hist
|
451 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
452 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
453 |
+
|
454 |
+
if self.few_shot_inited:
|
455 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
456 |
+
score = 1 - patch_hist_similarity.max()
|
457 |
+
|
458 |
+
binary_foreground = dilate_binary.astype(np.uint8)
|
459 |
+
|
460 |
+
if len(instance_masks) != 0:
|
461 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
462 |
+
|
463 |
+
if self.visualization:
|
464 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
465 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
466 |
+
plt.figure(figsize=(20, 3))
|
467 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
468 |
+
plt.subplot(1, len(image_list), ind)
|
469 |
+
plt.imshow(temp_image)
|
470 |
+
plt.title(temp_title)
|
471 |
+
plt.margins(0, 0)
|
472 |
+
plt.axis('off')
|
473 |
+
# Extract relative path from class_name onwards
|
474 |
+
if class_name in path:
|
475 |
+
relative_path = path.split(class_name, 1)[-1]
|
476 |
+
if relative_path.startswith('/'):
|
477 |
+
relative_path = relative_path[1:]
|
478 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
479 |
+
else:
|
480 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
481 |
+
|
482 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
483 |
+
plt.tight_layout()
|
484 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
485 |
+
plt.close()
|
486 |
+
|
487 |
+
|
488 |
+
# todo: same number in total but in different boxes or broken box
|
489 |
+
return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
490 |
+
|
491 |
+
elif self.class_name == 'splicing_connectors':
|
492 |
+
# object count hist for default
|
493 |
+
sam_mask_max_area = sorted_masks[0]['segmentation'] # background
|
494 |
+
binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
|
495 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
496 |
+
count = 0
|
497 |
+
for i in range(1, num_labels):
|
498 |
+
temp_mask = labels == i
|
499 |
+
if np.sum(temp_mask) <= 64: # 448x448 64
|
500 |
+
binary[temp_mask] = 0 # set to background
|
501 |
+
else:
|
502 |
+
count += 1
|
503 |
+
if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
|
504 |
+
print('number of connected component in splicing_connectors: {}, but the default connected component is 1.'.format(count))
|
505 |
+
self.anomaly_flag = True
|
506 |
+
|
507 |
+
merge_sam[~(binary.astype(np.bool))] = self.query_obj.shape[0] - 1 # remove noise
|
508 |
+
patch_merge_sam[~(binary.astype(np.bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
509 |
+
|
510 |
+
# erode the cable and divide into left and right parts
|
511 |
+
kernel = np.ones((23, 23), dtype=np.uint8)
|
512 |
+
erode_binary = cv2.erode(binary, kernel)
|
513 |
+
h, w = erode_binary.shape
|
514 |
+
distance = 0
|
515 |
+
|
516 |
+
left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
|
517 |
+
left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
|
518 |
+
right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
|
519 |
+
|
520 |
+
# binary_cable = (merge_sam == 1).astype(np.uint8)
|
521 |
+
binary_cable = (patch_merge_sam == 1).astype(np.uint8)
|
522 |
+
|
523 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
524 |
+
binary_cable = cv2.erode(binary_cable, kernel)
|
525 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
|
526 |
+
for i in range(1, num_labels):
|
527 |
+
temp_mask = labels == i
|
528 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
529 |
+
binary_cable[temp_mask] = 0 # set to background
|
530 |
+
|
531 |
+
|
532 |
+
binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
533 |
+
|
534 |
+
binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
|
535 |
+
|
536 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
537 |
+
binary_clamps = cv2.erode(binary_clamps, kernel)
|
538 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
|
539 |
+
for i in range(1, num_labels):
|
540 |
+
temp_mask = labels == i
|
541 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
542 |
+
binary_clamps[temp_mask] = 0 # set to background
|
543 |
+
else:
|
544 |
+
instance_mask = temp_mask.astype(np.uint8)
|
545 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
546 |
+
if instance_mask.any():
|
547 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
548 |
+
|
549 |
+
binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
550 |
+
|
551 |
+
binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
552 |
+
|
553 |
+
query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
|
554 |
+
cable_feature = proj_patch_token[binary_cable.astype(np.bool).reshape(-1), :].mean(0, keepdim=True)
|
555 |
+
idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
|
556 |
+
foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
|
557 |
+
|
558 |
+
|
559 |
+
slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
|
560 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
|
561 |
+
cable_count = num_labels - 1
|
562 |
+
if cable_count != 1 and self.anomaly_flag is False: # two cables
|
563 |
+
print('number of cable count in splicing_connectors: {}, but the default cable count is 1.'.format(cable_count))
|
564 |
+
self.anomaly_flag = True
|
565 |
+
|
566 |
+
# {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
|
567 |
+
if self.few_shot_inited and self.foreground_pixel_hist_splicing_connectors != 0 and self.anomaly_flag is False:
|
568 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
|
569 |
+
if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
|
570 |
+
print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist_splicing_connectors))
|
571 |
+
self.anomaly_flag = True
|
572 |
+
|
573 |
+
# left right hist for symmetry
|
574 |
+
ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
|
575 |
+
if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
|
576 |
+
print('left and right connectors are not symmetry.')
|
577 |
+
self.anomaly_flag = True
|
578 |
+
|
579 |
+
# left and right centroids distance
|
580 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
|
581 |
+
if num_labels - 1 == 2:
|
582 |
+
centroids = centroids[1:]
|
583 |
+
x1, y1 = centroids[0]
|
584 |
+
x2, y2 = centroids[1]
|
585 |
+
distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
|
586 |
+
if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
|
587 |
+
ratio = distance / self.splicing_connectors_distance
|
588 |
+
if ratio < 0.6 or ratio > 1.4: # too short or too long centroids distance (cable) # 0.6 1.4
|
589 |
+
print('cable is too short or too long.')
|
590 |
+
self.anomaly_flag = True
|
591 |
+
|
592 |
+
# patch hist
|
593 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
|
594 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
595 |
+
|
596 |
+
if self.few_shot_inited:
|
597 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
598 |
+
score = 1 - patch_hist_similarity.max()
|
599 |
+
|
600 |
+
# todo mismatch cable link
|
601 |
+
binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally seperate cable and clamps
|
602 |
+
if binary_connector.any():
|
603 |
+
instance_masks.append(binary_connector.astype(np.bool).reshape(-1))
|
604 |
+
if binary_clamps.any():
|
605 |
+
instance_masks.append(binary_clamps.astype(np.bool).reshape(-1))
|
606 |
+
if binary_cable.any():
|
607 |
+
instance_masks.append(binary_cable.astype(np.bool).reshape(-1))
|
608 |
+
|
609 |
+
if len(instance_masks) != 0:
|
610 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
611 |
+
|
612 |
+
if self.visualization:
|
613 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
|
614 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
|
615 |
+
plt.figure(figsize=(25, 3))
|
616 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
617 |
+
plt.subplot(1, len(image_list), ind)
|
618 |
+
plt.imshow(temp_image)
|
619 |
+
plt.title(temp_title)
|
620 |
+
plt.margins(0, 0)
|
621 |
+
plt.axis('off')
|
622 |
+
# Extract relative path from class_name onwards
|
623 |
+
if class_name in path:
|
624 |
+
relative_path = path.split(class_name, 1)[-1]
|
625 |
+
if relative_path.startswith('/'):
|
626 |
+
relative_path = relative_path[1:]
|
627 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
628 |
+
else:
|
629 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
630 |
+
|
631 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
632 |
+
plt.tight_layout()
|
633 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
634 |
+
plt.close()
|
635 |
+
|
636 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
637 |
+
|
638 |
+
elif self.class_name == 'screw_bag':
|
639 |
+
# pixel hist of kmeans mask
|
640 |
+
foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
|
641 |
+
if self.few_shot_inited and self.foreground_pixel_hist_screw_bag != 0 and self.anomaly_flag is False:
|
642 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
|
643 |
+
# todo: optimize
|
644 |
+
if ratio < 0.94 or ratio > 1.06: # 82.95 | 81.3
|
645 |
+
print('foreground pixel histagram of screw bag: {}, the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist_screw_bag))
|
646 |
+
self.anomaly_flag = True
|
647 |
+
|
648 |
+
# patch hist
|
649 |
+
binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
|
650 |
+
patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
651 |
+
resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
|
652 |
+
patch_merge_sam[~(resized_binary_screw.astype(np.bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
653 |
+
|
654 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
|
655 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
656 |
+
|
657 |
+
if self.few_shot_inited:
|
658 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
659 |
+
score = 1 - patch_hist_similarity.max()
|
660 |
+
|
661 |
+
for i in range(self.patch_query_obj.shape[0]-1):
|
662 |
+
binary_foreground = (patch_merge_sam == i).astype(np.uint8)
|
663 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
664 |
+
for i in range(1, num_labels):
|
665 |
+
instance_mask = (labels == i).astype(np.uint8)
|
666 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
667 |
+
if instance_mask.any():
|
668 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
669 |
+
|
670 |
+
if len(instance_masks) != 0:
|
671 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
672 |
+
|
673 |
+
if self.visualization:
|
674 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
675 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
676 |
+
plt.figure(figsize=(20, 3))
|
677 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
678 |
+
plt.subplot(1, len(image_list), ind)
|
679 |
+
plt.imshow(temp_image)
|
680 |
+
plt.title(temp_title)
|
681 |
+
plt.margins(0, 0)
|
682 |
+
plt.axis('off')
|
683 |
+
# Extract relative path from class_name onwards
|
684 |
+
if class_name in path:
|
685 |
+
relative_path = path.split(class_name, 1)[-1]
|
686 |
+
if relative_path.startswith('/'):
|
687 |
+
relative_path = relative_path[1:]
|
688 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
689 |
+
else:
|
690 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
691 |
+
|
692 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
693 |
+
plt.tight_layout()
|
694 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
695 |
+
plt.close()
|
696 |
+
|
697 |
+
# plt.axis('off')
|
698 |
+
# plt.imshow(patch_merge_sam)
|
699 |
+
|
700 |
+
# plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
|
701 |
+
# plt.close()
|
702 |
+
|
703 |
+
|
704 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
705 |
+
|
706 |
+
elif self.class_name == 'breakfast_box':
|
707 |
+
# patch hist
|
708 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
709 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
710 |
+
|
711 |
+
if self.few_shot_inited:
|
712 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
713 |
+
score = 1 - patch_hist_similarity.max()
|
714 |
+
|
715 |
+
# todo: exist of foreground
|
716 |
+
|
717 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
|
718 |
+
|
719 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
720 |
+
for i in range(1, num_labels):
|
721 |
+
instance_mask = (labels == i).astype(np.uint8)
|
722 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
723 |
+
if instance_mask.any():
|
724 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
725 |
+
|
726 |
+
|
727 |
+
if len(instance_masks) != 0:
|
728 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
729 |
+
|
730 |
+
if self.visualization:
|
731 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
732 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
733 |
+
plt.figure(figsize=(20, 3))
|
734 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
735 |
+
plt.subplot(1, len(image_list), ind)
|
736 |
+
plt.imshow(temp_image)
|
737 |
+
plt.title(temp_title)
|
738 |
+
plt.margins(0, 0)
|
739 |
+
plt.axis('off')
|
740 |
+
# Extract relative path from class_name onwards
|
741 |
+
if class_name in path:
|
742 |
+
relative_path = path.split(class_name, 1)[-1]
|
743 |
+
if relative_path.startswith('/'):
|
744 |
+
relative_path = relative_path[1:]
|
745 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
746 |
+
else:
|
747 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
748 |
+
|
749 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
750 |
+
plt.tight_layout()
|
751 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
752 |
+
plt.close()
|
753 |
+
|
754 |
+
# plt.axis('off')
|
755 |
+
# plt.imshow(patch_merge_sam)
|
756 |
+
|
757 |
+
# plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
|
758 |
+
# plt.close()
|
759 |
+
|
760 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
761 |
+
|
762 |
+
elif self.class_name == 'juice_bottle':
|
763 |
+
# remove noise due to non sam mask
|
764 |
+
merge_sam[sam_mask == 0] = self.classes - 1
|
765 |
+
patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1 # 79.5
|
766 |
+
|
767 |
+
# [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
768 |
+
# fruit and liquid mismatch (todo if exist)
|
769 |
+
resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
770 |
+
binary_liquid = (resized_patch_merge_sam == 1)
|
771 |
+
binary_fruit = (resized_patch_merge_sam == 2)
|
772 |
+
|
773 |
+
query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
|
774 |
+
query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)
|
775 |
+
|
776 |
+
liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
|
777 |
+
liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
|
778 |
+
|
779 |
+
fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
|
780 |
+
fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
|
781 |
+
|
782 |
+
if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
|
783 |
+
print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
|
784 |
+
self.anomaly_flag = True
|
785 |
+
|
786 |
+
# # todo centroid of fruit and tag_0 mismatch (if exist) , only one tag, center
|
787 |
+
|
788 |
+
# patch hist
|
789 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
790 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
791 |
+
|
792 |
+
if self.few_shot_inited:
|
793 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
794 |
+
score = 1 - patch_hist_similarity.max()
|
795 |
+
|
796 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1) ).astype(np.uint8)
|
797 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
798 |
+
for i in range(1, num_labels):
|
799 |
+
instance_mask = (labels == i).astype(np.uint8)
|
800 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
801 |
+
if instance_mask.any():
|
802 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
803 |
+
|
804 |
+
if len(instance_masks) != 0:
|
805 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
806 |
+
|
807 |
+
if self.visualization:
|
808 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
809 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
810 |
+
plt.figure(figsize=(20, 3))
|
811 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
812 |
+
plt.subplot(1, len(image_list), ind)
|
813 |
+
plt.imshow(temp_image)
|
814 |
+
plt.title(temp_title)
|
815 |
+
plt.margins(0, 0)
|
816 |
+
plt.axis('off')
|
817 |
+
# Extract relative path from class_name onwards
|
818 |
+
if class_name in path:
|
819 |
+
relative_path = path.split(class_name, 1)[-1]
|
820 |
+
if relative_path.startswith('/'):
|
821 |
+
relative_path = relative_path[1:]
|
822 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
823 |
+
else:
|
824 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
825 |
+
|
826 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
827 |
+
plt.tight_layout()
|
828 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
829 |
+
plt.close()
|
830 |
+
|
831 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
832 |
+
|
833 |
+
return {"score": score, "instance_masks": instance_masks}
|
834 |
+
|
835 |
+
|
836 |
+
def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
|
837 |
+
few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)
|
838 |
+
|
839 |
+
with torch.no_grad():
|
840 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
|
841 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
842 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
843 |
+
|
844 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (bs, 1024, 1024x4)
|
845 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (bs, 1024, 1024x2)
|
846 |
+
patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
847 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
848 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
849 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (bsx64x64, 1024x4)
|
850 |
+
|
851 |
+
with torch.no_grad():
|
852 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2) # 4 x [bs, 32x32, 1024]
|
853 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (bs, 1024, 1024x4)
|
854 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
855 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
856 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
857 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (bsx64x64, 1024x4)
|
858 |
+
|
859 |
+
|
860 |
+
cluster_features = None
|
861 |
+
for layer in self.cluster_feature_id:
|
862 |
+
temp_feat = patch_tokens[layer]
|
863 |
+
cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
|
864 |
+
if self.feat_size != self.ori_feat_size:
|
865 |
+
cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
866 |
+
cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
867 |
+
cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
868 |
+
cluster_features = F.normalize(cluster_features, p=2, dim=-1)
|
869 |
+
|
870 |
+
if self.feat_size != self.ori_feat_size:
|
871 |
+
proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
872 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
873 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
|
874 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
875 |
+
|
876 |
+
if not self.cluster_init:
|
877 |
+
num_clusters = self.cluster_num_dict[class_name]
|
878 |
+
_, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)
|
879 |
+
|
880 |
+
self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
|
881 |
+
self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
|
882 |
+
self.classes = self.query_obj.shape[0]
|
883 |
+
|
884 |
+
self.cluster_init = True
|
885 |
+
|
886 |
+
scores = []
|
887 |
+
foreground_pixel_hist = []
|
888 |
+
splicing_connectors_distance = []
|
889 |
+
patch_token_hist = []
|
890 |
+
mem_instance_masks = []
|
891 |
+
|
892 |
+
for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
|
893 |
+
# path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
|
894 |
+
self.anomaly_flag = False
|
895 |
+
results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
|
896 |
+
if self.class_name == 'pushpins':
|
897 |
+
patch_token_hist.append(results["clip_patch_hist"])
|
898 |
+
mem_instance_masks.append(results['instance_masks'])
|
899 |
+
|
900 |
+
elif self.class_name == 'splicing_connectors':
|
901 |
+
foreground_pixel_hist.append(results["foreground_pixel_count"])
|
902 |
+
splicing_connectors_distance.append(results["distance"])
|
903 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
904 |
+
mem_instance_masks.append(results['instance_masks'])
|
905 |
+
|
906 |
+
elif self.class_name == 'screw_bag':
|
907 |
+
foreground_pixel_hist.append(results["foreground_pixel_count"])
|
908 |
+
patch_token_hist.append(results["clip_patch_hist"])
|
909 |
+
mem_instance_masks.append(results['instance_masks'])
|
910 |
+
|
911 |
+
elif self.class_name == 'breakfast_box':
|
912 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
913 |
+
mem_instance_masks.append(results['instance_masks'])
|
914 |
+
|
915 |
+
elif self.class_name == 'juice_bottle':
|
916 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
917 |
+
mem_instance_masks.append(results['instance_masks'])
|
918 |
+
|
919 |
+
scores.append(results["score"])
|
920 |
+
|
921 |
+
if len(foreground_pixel_hist) != 0:
|
922 |
+
self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
|
923 |
+
if len(splicing_connectors_distance) != 0:
|
924 |
+
self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
|
925 |
+
if len(patch_token_hist) != 0: # patch hist
|
926 |
+
self.patch_token_hist = np.stack(patch_token_hist)
|
927 |
+
if len(mem_instance_masks) != 0:
|
928 |
+
self.mem_instance_masks = mem_instance_masks
|
929 |
+
|
930 |
+
# for interests matching
|
931 |
+
len_feature_list = len(self.feature_list)
|
932 |
+
for idx, batch_mem_patch_feature in enumerate(patch_tokens_clip.chunk(len_feature_list, dim=-1)): # 4 stages batch_mem_patch_feature (bsx64x64, 1024)
|
933 |
+
mem_instance_features = []
|
934 |
+
for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks): # k shot mem_patch_feature (64x64, 1024)
|
935 |
+
mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
|
936 |
+
mem_instance_features = torch.cat(mem_instance_features, dim=0)
|
937 |
+
mem_instance_features = F.normalize(mem_instance_features, dim=-1) # 4 stages
|
938 |
+
# mem_instance_features_multi_stage.append(mem_instance_features)
|
939 |
+
self.mem_instance_features_multi_stage[idx].append(mem_instance_features)
|
940 |
+
|
941 |
+
|
942 |
+
mem_patch_feature_clip_coreset = patch_tokens_clip
|
943 |
+
mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
|
944 |
+
|
945 |
+
return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset
|
946 |
+
|
947 |
+
def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
|
948 |
+
few_shot_samples = self.transform(few_shot_samples).to(self.device)
|
949 |
+
|
950 |
+
scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)
|
951 |
+
|
952 |
+
clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
|
953 |
+
mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()
|
954 |
+
|
955 |
+
dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
|
956 |
+
mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
|
957 |
+
|
958 |
+
self.mem_patch_feature_clip_coreset.append(mem_patch_feature_clip_coreset)
|
959 |
+
self.mem_patch_feature_dinov2_coreset.append(mem_patch_feature_dinov2_coreset)
|
960 |
+
|
961 |
+
|
962 |
+
def setup(self, data: dict) -> None:
|
963 |
+
"""Setup the few-shot samples for the model.
|
964 |
+
|
965 |
+
The evaluation script will call this method to pass the k images for few shot learning and the object class
|
966 |
+
name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
|
967 |
+
the organizing committee if if your model requires any additional dataset-related information at setup-time.
|
968 |
+
"""
|
969 |
+
few_shot_samples = data.get("few_shot_samples")
|
970 |
+
class_name = data.get("dataset_category")
|
971 |
+
few_shot_paths = data.get("few_shot_samples_path")
|
972 |
+
self.class_name = class_name
|
973 |
+
|
974 |
+
print(few_shot_samples.shape)
|
975 |
+
|
976 |
+
self.total_size = few_shot_samples.size(0)
|
977 |
+
|
978 |
+
self.k_shot = 4 if self.total_size > 4 else self.total_size
|
979 |
+
|
980 |
+
self.cluster_init = False
|
981 |
+
self.mem_instance_features_multi_stage = [[],[],[],[]]
|
982 |
+
|
983 |
+
self.mem_patch_feature_clip_coreset = []
|
984 |
+
self.mem_patch_feature_dinov2_coreset = []
|
985 |
+
|
986 |
+
# Check if coreset files already exist
|
987 |
+
clip_file = 'memory_bank/mem_patch_feature_clip_{}.pt'.format(self.class_name)
|
988 |
+
dinov2_file = 'memory_bank/mem_patch_feature_dinov2_{}.pt'.format(self.class_name)
|
989 |
+
instance_file = 'memory_bank/mem_instance_features_multi_stage_{}.pt'.format(self.class_name)
|
990 |
+
|
991 |
+
files_exist = os.path.exists(clip_file) and os.path.exists(dinov2_file) and os.path.exists(instance_file)
|
992 |
+
|
993 |
+
if self.save_coreset_features and not files_exist:
|
994 |
+
print(f"Coreset files not found for {self.class_name}, computing and saving...")
|
995 |
+
for i in range(self.total_size//self.k_shot):
|
996 |
+
self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)], few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
|
997 |
+
|
998 |
+
# Coreset Subsampling
|
999 |
+
self.mem_patch_feature_clip_coreset = torch.cat(self.mem_patch_feature_clip_coreset, dim=0)
|
1000 |
+
torch.save(self.mem_patch_feature_clip_coreset, clip_file)
|
1001 |
+
|
1002 |
+
self.mem_patch_feature_dinov2_coreset = torch.cat(self.mem_patch_feature_dinov2_coreset, dim=0)
|
1003 |
+
torch.save(self.mem_patch_feature_dinov2_coreset, dinov2_file)
|
1004 |
+
|
1005 |
+
print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
|
1006 |
+
|
1007 |
+
self.mem_instance_features_multi_stage = [ torch.cat(mem_instance_features, dim=0) for mem_instance_features in self.mem_instance_features_multi_stage ]
|
1008 |
+
self.mem_instance_features_multi_stage = torch.cat(self.mem_instance_features_multi_stage, dim=1)
|
1009 |
+
torch.save(self.mem_instance_features_multi_stage, instance_file)
|
1010 |
+
|
1011 |
+
print(self.mem_instance_features_multi_stage.shape)
|
1012 |
+
|
1013 |
+
elif self.save_coreset_features and files_exist:
|
1014 |
+
print(f"Coreset files found for {self.class_name}, loading existing files...")
|
1015 |
+
self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
|
1016 |
+
|
1017 |
+
self.mem_patch_feature_clip_coreset = torch.load(clip_file)
|
1018 |
+
self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
|
1019 |
+
self.mem_instance_features_multi_stage = torch.load(instance_file)
|
1020 |
+
|
1021 |
+
print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
|
1022 |
+
print(self.mem_instance_features_multi_stage.shape)
|
1023 |
+
|
1024 |
+
else:
|
1025 |
+
self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
|
1026 |
+
|
1027 |
+
self.mem_patch_feature_clip_coreset = torch.load(clip_file)
|
1028 |
+
self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
|
1029 |
+
self.mem_instance_features_multi_stage = torch.load(instance_file)
|
1030 |
+
|
1031 |
+
|
1032 |
+
self.few_shot_inited = True
|
1033 |
+
|
1034 |
+
|
model_ensemble_few_shot.py
ADDED
@@ -0,0 +1,935 @@
1 |
+
import os
|
2 |
+
|
3 |
+
# Set cache directories to use checkpoint folder for model downloads
|
4 |
+
os.environ['TORCH_HOME'] = './checkpoint'
|
5 |
+
os.environ['HF_HOME'] = './checkpoint/huggingface'
|
6 |
+
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
|
7 |
+
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
|
8 |
+
|
9 |
+
# Create checkpoint subdirectories if they don't exist
|
10 |
+
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
|
11 |
+
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torchvision.transforms import v2
|
16 |
+
from torchvision.transforms.v2.functional import resize
|
17 |
+
import cv2
|
18 |
+
import json
|
19 |
+
import torch
|
20 |
+
import random
|
21 |
+
import logging
|
22 |
+
import argparse
|
23 |
+
import numpy as np
|
24 |
+
from PIL import Image
|
25 |
+
from skimage import measure
|
26 |
+
from tabulate import tabulate
|
27 |
+
from torchvision.ops.focal_loss import sigmoid_focal_loss
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torchvision.transforms as transforms
|
30 |
+
import torchvision.transforms.functional as TF
|
31 |
+
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
|
32 |
+
from sklearn.mixture import GaussianMixture
|
33 |
+
import faiss
|
34 |
+
import open_clip_local as open_clip
|
35 |
+
|
36 |
+
from torch.utils.data.dataset import ConcatDataset
|
37 |
+
from scipy.optimize import linear_sum_assignment
|
38 |
+
from sklearn.random_projection import SparseRandomProjection
|
39 |
+
import cv2
|
40 |
+
from torchvision.transforms import InterpolationMode
|
41 |
+
from PIL import Image
|
42 |
+
import string
|
43 |
+
|
44 |
+
from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
|
45 |
+
from kmeans_pytorch import kmeans, kmeans_predict
|
46 |
+
from scipy.optimize import linear_sum_assignment
|
47 |
+
from scipy.stats import norm
|
48 |
+
|
49 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
50 |
+
from matplotlib import pyplot as plt
|
51 |
+
|
52 |
+
import matplotlib
|
53 |
+
matplotlib.use('Agg')
|
54 |
+
|
55 |
+
import pickle
|
56 |
+
from scipy.stats import norm
|
57 |
+
|
58 |
+
from open_clip_local.pos_embed import get_2d_sincos_pos_embed
|
59 |
+
|
60 |
+
def to_np_img(m):
|
61 |
+
m = m.permute(1, 2, 0).cpu().numpy()
|
62 |
+
mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
|
63 |
+
std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
|
64 |
+
m = m * std + mean
|
65 |
+
return np.clip((m * 255.), 0, 255).astype(np.uint8)
|
66 |
+
|
67 |
+
|
68 |
+
def setup_seed(seed):
|
69 |
+
torch.manual_seed(seed)
|
70 |
+
torch.cuda.manual_seed_all(seed)
|
71 |
+
np.random.seed(seed)
|
72 |
+
random.seed(seed)
|
73 |
+
torch.backends.cudnn.deterministic = True
|
74 |
+
torch.backends.cudnn.benchmark = False
|
75 |
+
|
76 |
+
|
77 |
+
class MyModel(nn.Module):
|
78 |
+
"""Example model class for track 2.
|
79 |
+
|
80 |
+
This class applies few-shot anomaly detection using an ensemble of CLIP, DINOv2 and SAM foundation models.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self) -> None:
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
setup_seed(42)
|
87 |
+
# NOTE: Create your transformation pipeline (if needed).
|
88 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
89 |
+
self.transform = v2.Compose(
|
90 |
+
[
|
91 |
+
v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
92 |
+
],
|
93 |
+
)
|
94 |
+
|
95 |
+
# NOTE: Create your model.
|
96 |
+
|
97 |
+
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
98 |
+
self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
99 |
+
self.feature_list = [6, 12, 18, 24]
|
100 |
+
self.embed_dim = 768
|
101 |
+
self.vision_width = 1024
|
102 |
+
|
103 |
+
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
|
104 |
+
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
|
105 |
+
|
106 |
+
self.memory_size = 2048
|
107 |
+
self.n_neighbors = 2
|
108 |
+
|
109 |
+
self.model_clip.eval()
|
110 |
+
self.test_args = None
|
111 |
+
self.align_corners = True # False
|
112 |
+
self.antialias = True # False
|
113 |
+
self.inter_mode = 'bilinear' # bilinear/bicubic
|
114 |
+
|
115 |
+
self.cluster_feature_id = [0, 1]
|
116 |
+
|
117 |
+
self.cluster_num_dict = {
|
118 |
+
"breakfast_box": 3, # unused
|
119 |
+
"juice_bottle": 8, # unused
|
120 |
+
"splicing_connectors": 10, # unused
|
121 |
+
"pushpins": 10,
|
122 |
+
"screw_bag": 10,
|
123 |
+
}
|
124 |
+
self.query_words_dict = {
|
125 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
126 |
+
"juice_bottle": ['bottle', ['black background', 'background']],
|
127 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
128 |
+
"screw_bag": [['screw'], 'plastic bag', 'background'],
|
129 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
130 |
+
}
|
131 |
+
self.foreground_label_idx = { # for query_words_dict
|
132 |
+
"breakfast_box": [0, 1, 2, 3, 4, 5],
|
133 |
+
"juice_bottle": [0],
|
134 |
+
"pushpins": [0],
|
135 |
+
"screw_bag": [0],
|
136 |
+
"splicing_connectors":[0, 1]
|
137 |
+
}
|
138 |
+
|
139 |
+
self.patch_query_words_dict = {
|
140 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
141 |
+
"juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
142 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
143 |
+
"screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
|
144 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
self.query_threshold_dict = {
|
149 |
+
"breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
|
150 |
+
"juice_bottle": [0., 0., 0.], # unused
|
151 |
+
"splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
|
152 |
+
"pushpins": [0.2, 0., 0., 0.],
|
153 |
+
"screw_bag": [0., 0., 0.,],
|
154 |
+
}
|
155 |
+
|
156 |
+
self.feat_size = 64
|
157 |
+
self.ori_feat_size = 32
|
158 |
+
|
159 |
+
self.visualization = False
|
160 |
+
|
161 |
+
self.pushpins_count = 15
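# canonical count: a normal sample contains exactly 15 pushpins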
|
162 |
+
|
163 |
+
self.splicing_connectors_count = [2, 3, 5] # corresponding to yellow, blue, and red
|
164 |
+
self.splicing_connectors_distance = 0
|
165 |
+
self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
|
166 |
+
|
167 |
+
self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
|
168 |
+
self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
|
169 |
+
|
170 |
+
# query words
|
171 |
+
self.foreground_pixel_hist = 0
|
172 |
+
# patch query words
|
173 |
+
self.patch_token_hist = []
|
174 |
+
|
175 |
+
self.few_shot_inited = False
|
176 |
+
|
177 |
+
|
178 |
+
from dinov2.dinov2.hub.backbones import dinov2_vitl14
|
179 |
+
self.model_dinov2 = dinov2_vitl14()
|
180 |
+
self.model_dinov2.to(self.device)
|
181 |
+
self.model_dinov2.eval()
|
182 |
+
self.feature_list_dinov2 = [6, 12, 18, 24]
|
183 |
+
self.vision_width_dinov2 = 1024
|
184 |
+
|
185 |
+
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
|
186 |
+
|
187 |
+
self.mem_instance_masks = None
|
188 |
+
|
189 |
+
self.anomaly_flag = False
|
190 |
+
self.validation = False #True #False
|
191 |
+
|
192 |
+
def set_viz(self, viz):
|
193 |
+
self.visualization = viz
|
194 |
+
|
195 |
+
def set_val(self, val):
|
196 |
+
self.validation = val
|
197 |
+
|
198 |
+
def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
|
199 |
+
"""Transform the input batch and pass it through the model.
|
200 |
+
|
201 |
+
This model returns a dictionary with the following keys
|
202 |
+
- ``anomaly_map`` - Anomaly map.
|
203 |
+
- ``pred_score`` - Predicted anomaly score.
|
204 |
+
"""
|
205 |
+
self.anomaly_flag = False
|
206 |
+
batch = self.transform(batch).to(self.device)
|
207 |
+
results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])
|
208 |
+
|
209 |
+
hist_score = results['hist_score']
|
210 |
+
structural_score = results['structural_score']
|
211 |
+
instance_hungarian_match_score = results['instance_hungarian_match_score']
|
212 |
+
|
213 |
+
anomaly_map_structural = results['anomaly_map_structural']
|
214 |
+
|
215 |
+
if self.validation:
|
216 |
+
return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
217 |
+
|
218 |
+
def sigmoid(z):
|
219 |
+
return 1/(1 + np.exp(-z))
|
220 |
+
|
221 |
+
# standardization
|
222 |
+
standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
|
223 |
+
standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
|
224 |
+
|
225 |
+
pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
|
226 |
+
pred_score = sigmoid(pred_score)
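# score fusion: take the stronger of the two standardized scores, then squash to (0, 1) with a sigmoid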
|
227 |
+
|
228 |
+
if self.anomaly_flag:
|
229 |
+
pred_score = 1.
|
230 |
+
self.anomaly_flag = False
|
231 |
+
|
232 |
+
return {"pred_score": torch.tensor(pred_score), "anomaly_map": torch.tensor(anomaly_map_structural), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
233 |
+
|
234 |
+
|
235 |
+
def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
|
236 |
+
|
237 |
+
with torch.no_grad():
|
238 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
|
239 |
+
# image_features /= image_features.norm(dim=-1, keepdim=True)
|
240 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
241 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
242 |
+
|
243 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
|
244 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
|
245 |
+
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
246 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
247 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
248 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
|
249 |
+
|
250 |
+
with torch.no_grad():
|
251 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
|
252 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
|
253 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
254 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
255 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
256 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
|
257 |
+
|
258 |
+
'''adding for kmeans seg '''
|
259 |
+
if self.feat_size != self.ori_feat_size:
|
260 |
+
proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
261 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
262 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
|
263 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
264 |
+
|
265 |
+
mid_features = None
|
266 |
+
for layer in self.cluster_feature_id:
|
267 |
+
temp_feat = patch_tokens[layer]
|
268 |
+
mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
|
269 |
+
|
270 |
+
if self.feat_size != self.ori_feat_size:
|
271 |
+
mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
272 |
+
mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
273 |
+
mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
274 |
+
mid_features = F.normalize(mid_features, p=2, dim=-1)
|
275 |
+
|
276 |
+
results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
|
277 |
+
|
278 |
+
hist_score = results['score']
|
279 |
+
|
280 |
+
'''calculate patchcore'''
|
281 |
+
anomaly_maps_patchcore = []
|
282 |
+
|
283 |
+
if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
|
284 |
+
len_feature_list = len(self.feature_list)
|
285 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
286 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
287 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
288 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
289 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
290 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
291 |
+
|
292 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
293 |
+
|
294 |
+
if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
|
295 |
+
len_feature_list = len(self.feature_list_dinov2)
|
296 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
|
297 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
298 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
299 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
300 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
301 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
302 |
+
|
303 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
304 |
+
|
305 |
+
structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
|
306 |
+
anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
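# structural branch (PatchCore-style): per-patch score is 1 - max similarity to the memory bank, averaged over stages; the image score is the max over patches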
|
307 |
+
|
308 |
+
instance_masks = results["instance_masks"]
|
309 |
+
anomaly_instances_hungarian = []
|
310 |
+
instance_hungarian_match_score = 1.
|
311 |
+
if self.mem_instance_masks is not None and len(instance_masks) != 0:
|
312 |
+
for patch_feature, batch_mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
313 |
+
instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
|
314 |
+
instance_features = torch.cat(instance_features, dim=0)
|
315 |
+
instance_features = F.normalize(instance_features, dim=-1)
|
316 |
+
mem_instance_features = []
|
317 |
+
for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks):
|
318 |
+
mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
|
319 |
+
mem_instance_features = torch.cat(mem_instance_features, dim=0)
|
320 |
+
mem_instance_features = F.normalize(mem_instance_features, dim=-1)
|
321 |
+
|
322 |
+
normal_instance_hungarian = (instance_features @ mem_instance_features.T)
|
323 |
+
cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
|
324 |
+
|
325 |
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
326 |
+
cost = cost_matrix[row_ind, col_ind].sum()
|
327 |
+
cost = cost / min(cost_matrix.shape)
|
328 |
+
anomaly_instances_hungarian.append(cost)
|
329 |
+
|
330 |
+
instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
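# logical/instance branch: Hungarian matching cost between test instances and the few-shot reference instances, averaged over feature stages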
|
331 |
+
|
332 |
+
results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score, "anomaly_map_structural": anomaly_map_structural}
|
333 |
+
|
334 |
+
return results
|
335 |
+
|
336 |
+
|
337 |
+
def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
|
338 |
+
def plot_results_only(sorted_anns):
|
339 |
+
cur = 1
|
340 |
+
img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
|
341 |
+
for ann in sorted_anns:
|
342 |
+
m = ann['segmentation']
|
343 |
+
img_color[m] = cur
|
344 |
+
cur += 1
|
345 |
+
return img_color
|
346 |
+
|
347 |
+
def merge_segmentations(a, b, background_class):
|
348 |
+
unique_labels_a = np.unique(a)
|
349 |
+
unique_labels_b = np.unique(b)
|
350 |
+
|
351 |
+
max_label_a = unique_labels_a.max()
|
352 |
+
label_map = np.zeros(max_label_a + 1, dtype=int)
|
353 |
+
|
354 |
+
for label_a in unique_labels_a:
|
355 |
+
mask_a = (a == label_a)
|
356 |
+
|
357 |
+
labels_b = b[mask_a]
|
358 |
+
if labels_b.size > 0:
|
359 |
+
count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
|
360 |
+
label_map[label_a] = np.argmax(count_b)
|
361 |
+
else:
|
362 |
+
label_map[label_a] = background_class # default background
|
363 |
+
|
364 |
+
merged_a = label_map[a]
|
365 |
+
return merged_a
|
366 |
+
|
367 |
+
pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
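# assign each patch to its nearest k-means centre (centres were fitted on the few-shot reference features)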
|
368 |
+
kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
|
369 |
+
|
370 |
+
for pl in pseudo_labels.unique():
|
371 |
+
mask = (pseudo_labels == pl).reshape(-1)
|
372 |
+
# filter small region
|
373 |
+
binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
|
374 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
375 |
+
for i in range(1, num_labels):
|
376 |
+
temp_mask = labels == i
|
377 |
+
if np.sum(temp_mask) <= 8:
|
378 |
+
mask[temp_mask.reshape(-1)] = False
|
379 |
+
|
380 |
+
if mask.any():
|
381 |
+
region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
|
382 |
+
similarity = (region_feature @ self.query_obj.T)
|
383 |
+
prob, index = torch.max(similarity, dim=-1)
|
384 |
+
temp_label = index.squeeze(0).item()
|
385 |
+
temp_prob = prob.squeeze(0).item()
|
386 |
+
if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
|
387 |
+
kmeans_mask[mask] = temp_label
|
388 |
+
|
389 |
+
|
390 |
+
raw_image = to_np_img(image[0])
|
391 |
+
height, width = raw_image.shape[:2]
|
392 |
+
masks = self.mask_generator.generate(raw_image)
|
393 |
+
# self.predictor.set_image(raw_image)
|
394 |
+
|
395 |
+
kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
|
396 |
+
kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
397 |
+
|
398 |
+
patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
|
399 |
+
patch_mask = patch_similarity.argmax(-1)
|
400 |
+
patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
401 |
+
|
402 |
+
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
|
403 |
+
sam_mask = plot_results_only(sorted_masks).astype(int)
|
404 |
+
|
405 |
+
resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
406 |
+
merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
|
407 |
+
|
408 |
+
resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
409 |
+
patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
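# each SAM region is relabelled with the majority semantic label from the resized k-means / patch masks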
|
410 |
+
|
411 |
+
# filter small region for merge sam
|
412 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
413 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
414 |
+
for i in range(1, num_labels):
|
415 |
+
temp_mask = labels == i
|
416 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
417 |
+
merge_sam[temp_mask] = self.classes - 1 # set to background
|
418 |
+
|
419 |
+
# filter small region for patch merge sam
|
420 |
+
binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
|
421 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
422 |
+
for i in range(1, num_labels):
|
423 |
+
temp_mask = labels == i
|
424 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
425 |
+
patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
|
426 |
+
|
427 |
+
score = 0. # default to normal
|
428 |
+
self.anomaly_flag = False
|
429 |
+
instance_masks = []
|
430 |
+
if self.class_name == 'pushpins':
|
431 |
+
# object count hist
|
432 |
+
kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
|
433 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
434 |
+
dilate_binary = cv2.dilate(binary, kernel)
|
435 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
|
436 |
+
pushpins_count = num_labels - 1 # number of pushpins
|
437 |
+
|
438 |
+
for i in range(1, num_labels):
|
439 |
+
instance_mask = (labels == i).astype(np.uint8)
|
440 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
441 |
+
if instance_mask.any():
|
442 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
443 |
+
|
444 |
+
if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
|
445 |
+
self.anomaly_flag = True
|
446 |
+
print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
|
447 |
+
|
448 |
+
# patch hist
|
449 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
450 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
451 |
+
|
452 |
+
if self.few_shot_inited:
|
453 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
454 |
+
score = 1 - patch_hist_similarity.max()
|
455 |
+
|
456 |
+
binary_foreground = dilate_binary.astype(np.uint8)
|
457 |
+
|
458 |
+
if len(instance_masks) != 0:
|
459 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
460 |
+
|
461 |
+
if self.visualization:
|
462 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
463 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
464 |
+
plt.figure(figsize=(20, 3))
|
465 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
466 |
+
plt.subplot(1, len(image_list), ind)
|
467 |
+
plt.imshow(temp_image)
|
468 |
+
plt.title(temp_title)
|
469 |
+
plt.margins(0, 0)
|
470 |
+
plt.axis('off')
|
471 |
+
# Extract relative path from class_name onwards
|
472 |
+
if class_name in path:
|
473 |
+
relative_path = path.split(class_name, 1)[-1]
|
474 |
+
if relative_path.startswith('/'):
|
475 |
+
relative_path = relative_path[1:]
|
476 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
477 |
+
else:
|
478 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
479 |
+
|
480 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
481 |
+
plt.tight_layout()
|
482 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
483 |
+
plt.close()
|
484 |
+
|
485 |
+
# todo: same number in total but in different boxes or broken box
|
486 |
+
return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
487 |
+
|
488 |
+
elif self.class_name == 'splicing_connectors':
|
489 |
+
# object count hist for default
|
490 |
+
sam_mask_max_area = sorted_masks[0]['segmentation'] # background
|
491 |
+
binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
|
492 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
493 |
+
count = 0
|
494 |
+
for i in range(1, num_labels):
|
495 |
+
temp_mask = labels == i
|
496 |
+
if np.sum(temp_mask) <= 64: # 448x448 64
|
497 |
+
binary[temp_mask] = 0 # set to background
|
498 |
+
else:
|
499 |
+
count += 1
|
500 |
+
if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
|
501 |
+
print('number of connected components in splicing_connectors: {}, but the default number of connected components is 1.'.format(count))
|
502 |
+
self.anomaly_flag = True
|
503 |
+
|
504 |
+
merge_sam[~(binary.astype(bool))] = self.query_obj.shape[0] - 1 # remove noise
|
505 |
+
patch_merge_sam[~(binary.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
506 |
+
|
507 |
+
# erode the cable and divide into left and right parts
|
508 |
+
kernel = np.ones((23, 23), dtype=np.uint8)
|
509 |
+
erode_binary = cv2.erode(binary, kernel)
|
510 |
+
h, w = erode_binary.shape
|
511 |
+
distance = 0
|
512 |
+
|
513 |
+
left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
|
514 |
+
left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
|
515 |
+
right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
|
516 |
+
|
517 |
+
binary_cable = (patch_merge_sam == 1).astype(np.uint8)
|
518 |
+
|
519 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
520 |
+
binary_cable = cv2.erode(binary_cable, kernel)
|
521 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
|
522 |
+
for i in range(1, num_labels):
|
523 |
+
temp_mask = labels == i
|
524 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
525 |
+
binary_cable[temp_mask] = 0 # set to background
|
526 |
+
|
527 |
+
|
528 |
+
binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
529 |
+
|
530 |
+
binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
|
531 |
+
|
532 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
533 |
+
binary_clamps = cv2.erode(binary_clamps, kernel)
|
534 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
|
535 |
+
for i in range(1, num_labels):
|
536 |
+
temp_mask = labels == i
|
537 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
538 |
+
binary_clamps[temp_mask] = 0 # set to background
|
539 |
+
else:
|
540 |
+
instance_mask = temp_mask.astype(np.uint8)
|
541 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
542 |
+
if instance_mask.any():
|
543 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
544 |
+
|
545 |
+
binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
546 |
+
|
547 |
+
binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
548 |
+
|
549 |
+
query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
|
550 |
+
cable_feature = proj_patch_token[binary_cable.astype(bool).reshape(-1), :].mean(0, keepdim=True)
|
551 |
+
idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
|
552 |
+
foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
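# normalise the eroded foreground area by the expected clamp count for the detected cable colour (yellow: 2, blue: 3, red: 5)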
|
553 |
+
|
554 |
+
|
555 |
+
slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
|
556 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
|
557 |
+
cable_count = num_labels - 1
|
558 |
+
if cable_count != 1 and self.anomaly_flag is False: # two cables
|
559 |
+
print('cable count in splicing_connectors: {}, but the default cable count is 1.'.format(cable_count))
|
560 |
+
self.anomaly_flag = True
|
561 |
+
|
562 |
+
# {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
|
563 |
+
if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
|
564 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist
|
565 |
+
if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
|
566 |
+
print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist))
|
567 |
+
self.anomaly_flag = True
|
568 |
+
|
569 |
+
# left right hist for symmetry
|
570 |
+
ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
|
571 |
+
if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
|
572 |
+
print('left and right connectors are not symmetric.')
|
573 |
+
self.anomaly_flag = True
|
574 |
+
|
575 |
+
# left and right centroids distance
|
576 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
|
577 |
+
if num_labels - 1 == 2:
|
578 |
+
centroids = centroids[1:]
|
579 |
+
x1, y1 = centroids[0]
|
580 |
+
x2, y2 = centroids[1]
|
581 |
+
distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
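# normalised distance between the two clamp centroids, used as a proxy for cable length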
|
582 |
+
if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
|
583 |
+
ratio = distance / self.splicing_connectors_distance
|
584 |
+
if ratio < 0.6 or ratio > 1.4: # too short or too long centroids distance (cable) # 0.6 1.4
|
585 |
+
print('cable is too short or too long.')
|
586 |
+
self.anomaly_flag = True
|
587 |
+
|
588 |
+
# patch hist
|
589 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
|
590 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
591 |
+
|
592 |
+
if self.few_shot_inited:
|
593 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
594 |
+
score = 1 - patch_hist_similarity.max()
|
595 |
+
|
596 |
+
# todo mismatch cable link
|
597 |
+
binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally separate cable and clamps
|
598 |
+
if binary_connector.any():
|
599 |
+
instance_masks.append(binary_connector.astype(bool).reshape(-1))
|
600 |
+
if binary_clamps.any():
|
601 |
+
instance_masks.append(binary_clamps.astype(bool).reshape(-1))
|
602 |
+
if binary_cable.any():
|
603 |
+
instance_masks.append(binary_cable.astype(bool).reshape(-1))
|
604 |
+
|
605 |
+
if len(instance_masks) != 0:
|
606 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
607 |
+
|
608 |
+
if self.visualization:
|
609 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
|
610 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
|
611 |
+
plt.figure(figsize=(25, 3))
|
612 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
613 |
+
plt.subplot(1, len(image_list), ind)
|
614 |
+
plt.imshow(temp_image)
|
615 |
+
plt.title(temp_title)
|
616 |
+
plt.margins(0, 0)
|
617 |
+
plt.axis('off')
|
618 |
+
# Extract relative path from class_name onwards
|
619 |
+
if class_name in path:
|
620 |
+
relative_path = path.split(class_name, 1)[-1]
|
621 |
+
if relative_path.startswith('/'):
|
622 |
+
relative_path = relative_path[1:]
|
623 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
624 |
+
else:
|
625 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
626 |
+
|
627 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
628 |
+
plt.tight_layout()
|
629 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
630 |
+
plt.close()
|
631 |
+
|
632 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
633 |
+
|
634 |
+
elif self.class_name == 'screw_bag':
|
635 |
+
# pixel hist of kmeans mask
|
636 |
+
foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
|
637 |
+
if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
|
638 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist
|
639 |
+
# todo: optimize
|
640 |
+
if ratio < 0.94 or ratio > 1.06:
|
641 |
+
print('foreground pixel histogram of screw bag: {}, the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist))
|
642 |
+
self.anomaly_flag = True
|
643 |
+
|
644 |
+
# patch hist
|
645 |
+
binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
|
646 |
+
patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
647 |
+
resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
|
648 |
+
patch_merge_sam[~(resized_binary_screw.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
649 |
+
|
650 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
|
651 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
652 |
+
|
653 |
+
if self.few_shot_inited:
|
654 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
655 |
+
score = 1 - patch_hist_similarity.max()
|
656 |
+
|
657 |
+
# # todo: count of screw, nut and washer, screw of different length
|
658 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
|
659 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
660 |
+
for i in range(1, num_labels):
|
661 |
+
instance_mask = (labels == i).astype(np.uint8)
|
662 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
663 |
+
if instance_mask.any():
|
664 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
665 |
+
|
666 |
+
if len(instance_masks) != 0:
|
667 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
668 |
+
|
669 |
+
if self.visualization:
|
670 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
671 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
672 |
+
plt.figure(figsize=(20, 3))
|
673 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
674 |
+
plt.subplot(1, len(image_list), ind)
|
675 |
+
plt.imshow(temp_image)
|
676 |
+
plt.title(temp_title)
|
677 |
+
plt.margins(0, 0)
|
678 |
+
plt.axis('off')
|
679 |
+
# Extract relative path from class_name onwards
|
680 |
+
if class_name in path:
|
681 |
+
relative_path = path.split(class_name, 1)[-1]
|
682 |
+
if relative_path.startswith('/'):
|
683 |
+
relative_path = relative_path[1:]
|
684 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
685 |
+
else:
|
686 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
687 |
+
|
688 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
689 |
+
plt.tight_layout()
|
690 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
691 |
+
plt.close()
|
692 |
+
|
693 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
694 |
+
|
695 |
+
elif self.class_name == 'breakfast_box':
|
696 |
+
# patch hist
|
697 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
698 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
699 |
+
|
700 |
+
if self.few_shot_inited:
|
701 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
702 |
+
score = 1 - patch_hist_similarity.max()
|
703 |
+
|
704 |
+
# todo: exist of foreground
|
705 |
+
|
706 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
|
707 |
+
|
708 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
709 |
+
for i in range(1, num_labels):
|
710 |
+
instance_mask = (labels == i).astype(np.uint8)
|
711 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
712 |
+
if instance_mask.any():
|
713 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
714 |
+
|
715 |
+
if len(instance_masks) != 0:
|
716 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
717 |
+
|
718 |
+
if self.visualization:
|
719 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
720 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
721 |
+
plt.figure(figsize=(20, 3))
|
722 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
723 |
+
plt.subplot(1, len(image_list), ind)
|
724 |
+
plt.imshow(temp_image)
|
725 |
+
plt.title(temp_title)
|
726 |
+
plt.margins(0, 0)
|
727 |
+
plt.axis('off')
|
728 |
+
# Extract relative path from class_name onwards
|
729 |
+
if class_name in path:
|
730 |
+
relative_path = path.split(class_name, 1)[-1]
|
731 |
+
if relative_path.startswith('/'):
|
732 |
+
relative_path = relative_path[1:]
|
733 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
734 |
+
else:
|
735 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
736 |
+
|
737 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
738 |
+
plt.tight_layout()
|
739 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
740 |
+
plt.close()
|
741 |
+
|
742 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
743 |
+
|
744 |
+
elif self.class_name == 'juice_bottle':
|
745 |
+
# remove noise due to non sam mask
|
746 |
+
merge_sam[sam_mask == 0] = self.classes - 1
|
747 |
+
patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1 # 79.5
|
748 |
+
|
749 |
+
# [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
750 |
+
# fruit and liquid mismatch (todo if exist)
|
751 |
+
resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
752 |
+
binary_liquid = (resized_patch_merge_sam == 1)
|
753 |
+
binary_fruit = (resized_patch_merge_sam == 2)
|
754 |
+
|
755 |
+
query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
|
756 |
+
query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)
|
757 |
+
|
758 |
+
liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
|
759 |
+
liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
|
760 |
+
|
761 |
+
fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
|
762 |
+
fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
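# zero-shot CLIP classification of the liquid and the fruit icon; a mismatch between the two is treated as a logical anomaly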
|
763 |
+
|
764 |
+
if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
|
765 |
+
print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
|
766 |
+
self.anomaly_flag = True
|
767 |
+
|
768 |
+
# # todo centroid of fruit and tag_0 mismatch (if exist) , only one tag, center
|
769 |
+
|
770 |
+
# patch hist
|
771 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
772 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
773 |
+
|
774 |
+
if self.few_shot_inited:
|
775 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
776 |
+
score = 1 - patch_hist_similarity.max()
|
777 |
+
|
778 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1) ).astype(np.uint8)
|
779 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
780 |
+
for i in range(1, num_labels):
|
781 |
+
instance_mask = (labels == i).astype(np.uint8)
|
782 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
783 |
+
if instance_mask.any():
|
784 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
785 |
+
|
786 |
+
if len(instance_masks) != 0:
|
787 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
788 |
+
|
789 |
+
if self.visualization:
|
790 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
791 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
792 |
+
plt.figure(figsize=(20, 3))
|
793 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
794 |
+
plt.subplot(1, len(image_list), ind)
|
795 |
+
plt.imshow(temp_image)
|
796 |
+
plt.title(temp_title)
|
797 |
+
plt.margins(0, 0)
|
798 |
+
plt.axis('off')
|
799 |
+
# Extract relative path from class_name onwards
|
800 |
+
if class_name in path:
|
801 |
+
relative_path = path.split(class_name, 1)[-1]
|
802 |
+
if relative_path.startswith('/'):
|
803 |
+
relative_path = relative_path[1:]
|
804 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
805 |
+
else:
|
806 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
807 |
+
|
808 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
809 |
+
plt.tight_layout()
|
810 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
811 |
+
plt.close()
|
812 |
+
|
813 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
814 |
+
|
815 |
+
return {"score": score, "instance_masks": instance_masks}
|
816 |
+
|
817 |
+
|
818 |
+
def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
|
819 |
+
few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)
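# resize the reference images to the 448x448 working resolution before feature extraction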
|
820 |
+
|
821 |
+
with torch.no_grad():
|
822 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
|
823 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
824 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
825 |
+
|
826 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (bs, 1024, 1024x4)
|
827 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (bs, 1024, 1024x2)
|
828 |
+
patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
829 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
830 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
831 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (bsx64x64, 1024x4)
|
832 |
+
|
833 |
+
with torch.no_grad():
|
834 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2) # 4 x [bs, 32x32, 1024]
|
835 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (bs, 1024, 1024x4)
|
836 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
837 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
838 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
839 |
+
        patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1)  # (bsx64x64, 1024x4)

        cluster_features = None
        for layer in self.cluster_feature_id:
            temp_feat = patch_tokens[layer]
            cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
        if self.feat_size != self.ori_feat_size:
            cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
        cluster_features = F.normalize(cluster_features, p=2, dim=-1)

        if self.feat_size != self.ori_feat_size:
            proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
        proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)

        # k-means cluster centers over the few-shot patch features of this category
        num_clusters = self.cluster_num_dict[class_name]
        _, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)

        self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
        self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
        self.classes = self.query_obj.shape[0]

        # per-category reference statistics collected from the k few-shot images
        scores = []
        foreground_pixel_hist = []
        splicing_connectors_distance = []
        patch_token_hist = []
        mem_instance_masks = []

        for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
            # path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
            self.anomaly_flag = False
            results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
            if self.class_name == 'pushpins':
                patch_token_hist.append(results["clip_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'splicing_connectors':
                foreground_pixel_hist.append(results["foreground_pixel_count"])
                splicing_connectors_distance.append(results["distance"])
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'screw_bag':
                foreground_pixel_hist.append(results["foreground_pixel_count"])
                patch_token_hist.append(results["clip_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'breakfast_box':
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'juice_bottle':
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            scores.append(results["score"])

        if len(foreground_pixel_hist) != 0:
            self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
        if len(splicing_connectors_distance) != 0:
            self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
        if len(patch_token_hist) != 0:  # patch hist
            self.patch_token_hist = np.stack(patch_token_hist)
        if len(mem_instance_masks) != 0:
            self.mem_instance_masks = mem_instance_masks

        # memory banks consumed by the coreset-based patch matching
        mem_patch_feature_clip_coreset = patch_tokens_clip
        mem_patch_feature_dinov2_coreset = patch_tokens_dinov2

        return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset

    def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
        few_shot_samples = self.transform(few_shot_samples).to(self.device)
        scores, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)

    def setup(self, data: dict) -> None:
        """Setup the few-shot samples for the model.

        The evaluation script will call this method to pass the k images for few shot learning and the object class
        name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
        the organizing committee if your model requires any additional dataset-related information at setup-time.
        """
        few_shot_samples = data.get("few_shot_samples")
        class_name = data.get("dataset_category")
        few_shot_paths = data.get("few_shot_samples_path")
        self.class_name = class_name

        self.k_shot = few_shot_samples.size(0)
        self.process(class_name, few_shot_samples, few_shot_paths)
        self.few_shot_inited = True
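The setup() / process() / process_k_shot() methods above form the few-shot interface that the evaluation script drives. The sketch below only illustrates how a caller is expected to fill the data dict; the class name FewShotEnsembleModel, the helper load_tensor and the reference paths are hypothetical placeholders and do not exist in this repository.

```python
# Illustrative caller for the setup() contract (names marked "hypothetical" are not from the repo).
import torch

ref_paths = ["000.png", "001.png", "002.png", "003.png"]      # hypothetical reference image paths
few_shot = torch.stack([load_tensor(p) for p in ref_paths])   # load_tensor: hypothetical loader returning a (3, H, W) tensor

model = FewShotEnsembleModel()                                # hypothetical name for the model class in model_ensemble_few_shot.py
model.setup({
    "few_shot_samples": few_shot,           # read via data.get("few_shot_samples"); k_shot = tensor.size(0)
    "dataset_category": "breakfast_box",    # MVTec LOCO category name, read via data.get("dataset_category")
    "few_shot_samples_path": ref_paths,     # read via data.get("few_shot_samples_path")
})
# setup() then runs process() -> process_k_shot() over the k reference images and caches
# the CLIP/DINOv2 coreset features together with the per-category histogram statistics.
```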
prompt_ensemble.py
ADDED
@@ -0,0 +1,121 @@
import os
from typing import Union, List
import torch
import numpy as np
from tqdm import tqdm
from imagenet_template import openai_imagenet_template


def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):
    prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
    prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
    prompt_state = [prompt_normal, prompt_abnormal]
    prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
    text_prompts = {}
    for obj in objs:
        text_features = []
        for i in range(len(prompt_state)):
            prompted_state = [state.format(obj) for state in prompt_state[i]]
            prompted_sentence = []
            for s in prompted_state:
                for template in prompt_templates:
                    prompted_sentence.append(template.format(s))
            prompted_sentence = tokenizer(prompted_sentence).to(device)
            class_embeddings = model.encode_text(prompted_sentence)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            text_features.append(class_embedding)

        text_features = torch.stack(text_features, dim=1).to(device)
        text_prompts[obj] = text_features

    return text_prompts


def encode_general_text(model, obj_list, tokenizer, device):
    text_dir = '/data/yizhou/VAND2.0/wgd/general_texts/train2014'
    text_name_list = sorted(os.listdir(text_dir))
    bs = 100
    sentences = []
    embeddings = []
    all_sentences = []
    for text_name in tqdm(text_name_list):
        with open(os.path.join(text_dir, text_name), 'r') as f:
            for line in f.readlines():
                sentences.append(line.strip())
        if len(sentences) > bs:
            prompted_sentences = tokenizer(sentences).to(device)
            class_embeddings = model.encode_text(prompted_sentences)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            embeddings.append(class_embeddings)
            all_sentences.extend(sentences)
            sentences = []
        # if len(all_sentences) > 10000:
        #     break
    embeddings = torch.cat(embeddings, 0)
    print(embeddings.size(0))
    embeddings_dict = {}
    for obj in obj_list:
        embeddings_dict[obj] = embeddings
    return embeddings_dict, all_sentences


def encode_abnormal_text(model, obj_list, tokenizer, device):
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_abnormal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_abnormal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_abnormal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_abnormal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_abnormal
    return embeddings, sentences


def encode_normal_text(model, obj_list, tokenizer, device):
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_abnormal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_normal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_abnormal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_abnormal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_abnormal
    return embeddings, sentences


def encode_obj_text(model, query_words, tokenizer, device):
    # query_words = ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box']
    # query_words = ['liquid', 'glass', "top", 'black background']
    # query_words = ["connector", "grid"]
    # query_words = [['screw'], 'plastic bag', 'background']
    # query_words = [['pushpin', 'pin'], ['plastic box'], 'box', 'black background']
    query_features = []
    with torch.no_grad():
        for qw in query_words:
            token_input = []
            if type(qw) == list:
                for qw2 in qw:
                    token_input.extend([temp(qw2) for temp in openai_imagenet_template])
            else:
                token_input = [temp(qw) for temp in openai_imagenet_template]
            query = tokenizer(token_input).to(device)
            feature = model.encode_text(query)
            feature /= feature.norm(dim=-1, keepdim=True)
            feature = feature.mean(dim=0)
            feature /= feature.norm()
            query_features.append(feature.unsqueeze(0))
    query_features = torch.cat(query_features, dim=0)
    return query_features
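encode_text_with_prompt_ensemble averages CLIP text features over every state x template sentence and returns, per object, one normal and one abnormal embedding stacked along the last dimension. The usage sketch below assumes an open_clip ViT-L-14 model with the OpenAI weights; that checkpoint choice is illustrative and is not fixed by this file.

```python
# Usage sketch for encode_text_with_prompt_ensemble (checkpoint choice is an assumption, not pinned by the repo).
import torch
import open_clip
from prompt_ensemble import encode_text_with_prompt_ensemble

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, _ = open_clip.create_model_and_transforms("ViT-L-14", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-L-14")
model = model.to(device).eval()

with torch.no_grad():
    text_prompts = encode_text_with_prompt_ensemble(model, ["juice_bottle"], tokenizer, device)

# text_prompts["juice_bottle"] has shape (embed_dim, 2):
# column 0 is the averaged "normal" text embedding, column 1 the averaged "abnormal" one.
```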
requirements.txt
ADDED
@@ -0,0 +1,77 @@
aiohappyeyeballs==2.6.1
aiohttp==3.12.11
aiosignal==1.3.2
antlr4-python3-runtime==4.9.3
async-timeout==5.0.1
attrs==25.3.0
certifi==2025.4.26
charset-normalizer==3.4.2
contourpy==1.3.2
cycler==0.12.1
einops==0.6.1
faiss-cpu==1.8.0
filelock==3.18.0
fonttools==4.58.2
FrEIA==0.2
frozenlist==1.6.2
fsspec==2024.12.0
ftfy==6.3.1
hf-xet==1.1.3
huggingface-hub==0.32.4
idna==3.10
imageio==2.37.0
imgaug==0.4.0
Jinja2==3.1.6
joblib==1.5.1
jsonargparse==4.29.0
kiwisolver==1.4.8
kmeans-pytorch==0.3
kornia==0.7.0
lazy_loader==0.4
lightning==2.2.5
lightning-utilities==0.14.3
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.4.4
networkx==3.4.2
omegaconf==2.3.0
open-clip-torch==2.24.0
opencv-python==4.8.1.78
packaging==24.2
pandas==2.0.3
pillow==11.2.1
propcache==0.3.1
protobuf==6.31.1
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.1.post0
pytz==2025.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.7.1
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.2.2
scipy==1.15.3
segment-anything==1.0
sentencepiece==0.2.0
shapely==2.1.1
six==1.17.0
sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
tifffile==2025.5.10
timm==1.0.15
torchmetrics==1.7.2
tqdm==4.67.1
triton==2.1.0
typing_extensions==4.14.0
tzdata==2025.2
urllib3==2.4.0
wcwidth==0.2.13
yarl==1.20.0