import os

import torch
from tqdm import tqdm

from imagenet_template import openai_imagenet_template


def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):
    """Build ensembled CLIP text embeddings for the normal/abnormal states of each object.

    Each state phrase is expanded through every template, encoded, L2-normalized,
    and averaged into one embedding per state. Returns a dict mapping each object
    name to a [embed_dim, 2] tensor whose columns are the normal and abnormal
    embeddings.
    """
    prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
    prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
    prompt_state = [prompt_normal, prompt_abnormal]
    prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
    text_prompts = {}
    for obj in objs:
        text_features = []
        for state_prompts in prompt_state:
            prompted_state = [state.format(obj) for state in state_prompts]
            prompted_sentence = []
            for s in prompted_state:
                for template in prompt_templates:
                    prompted_sentence.append(template.format(s))
            prompted_sentence = tokenizer(prompted_sentence).to(device)
            class_embeddings = model.encode_text(prompted_sentence)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            text_features.append(class_embedding)

        text_features = torch.stack(text_features, dim=1).to(device)
        text_prompts[obj] = text_features

    return text_prompts
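
# Usage sketch (hypothetical names; assumes `image_features` is an L2-normalized
# CLIP image embedding of shape [..., embed_dim]): text_prompts[obj] has shape
# [embed_dim, 2] (columns: normal, abnormal), so a two-way softmax over the
# similarities yields an anomaly score, e.g.:
#
#     text_prompts = encode_text_with_prompt_ensemble(model, ['bottle'], tokenizer, device)
#     logits = 100.0 * image_features @ text_prompts['bottle']   # [..., 2]
#     anomaly_score = logits.softmax(dim=-1)[..., 1]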


def encode_general_text(model, obj_list, tokenizer, device):
    """Encode a corpus of generic caption files in batches; the same embedding
    matrix is shared across every object in obj_list."""
    text_dir = '/data/yizhou/VAND2.0/wgd/general_texts/train2014'
    text_name_list = sorted(os.listdir(text_dir))
    bs = 100
    sentences = []
    embeddings = []
    all_sentences = []
    for text_name in tqdm(text_name_list):
        with open(os.path.join(text_dir, text_name), 'r') as f:
            for line in f.readlines():
                sentences.append(line.strip())
        if len(sentences) > bs:
            prompted_sentences = tokenizer(sentences).to(device)
            class_embeddings = model.encode_text(prompted_sentences)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            embeddings.append(class_embeddings)
            all_sentences.extend(sentences)
            sentences = []
        # if len(all_sentences) > 10000:
        #     break

    # Flush the trailing partial batch; without this, up to `bs` sentences at
    # the end of the corpus would be silently dropped from the embeddings.
    if sentences:
        prompted_sentences = tokenizer(sentences).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings.append(class_embeddings)
        all_sentences.extend(sentences)

    embeddings = torch.cat(embeddings, 0)
    print(f'Encoded {embeddings.size(0)} general sentences')
    embeddings_dict = {}
    for obj in obj_list:
        embeddings_dict[obj] = embeddings
    return embeddings_dict, all_sentences
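
# Usage sketch (hypothetical setup): every object shares the same embedding
# matrix, so caption retrieval against an L2-normalized image feature is
# object-independent, e.g.:
#
#     embeddings_dict, all_sentences = encode_general_text(model, ['bottle'], tokenizer, device)
#     sims = image_feature @ embeddings_dict['bottle'].T   # [num_sentences]
#     print([all_sentences[i] for i in sims.topk(5).indices.tolist()])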


def encode_abnormal_text(model, obj_list, tokenizer, device):
    """Encode per-object abnormal-state sentences from text_prompt/v1/<obj>_abnormal.txt.

    Returns (embeddings, sentences): dicts mapping each object name to an
    L2-normalized [num_sentences, embed_dim] tensor and to the raw lowercased
    sentences.
    """
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_abnormal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_abnormal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_abnormal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_abnormal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_abnormal
    return embeddings, sentences


def encode_normal_text(model, obj_list, tokenizer, device):
    """Encode per-object normal-state sentences from text_prompt/v1/<obj>_normal.txt.

    Returns (embeddings, sentences): dicts mapping each object name to an
    L2-normalized [num_sentences, embed_dim] tensor and to the raw lowercased
    sentences.
    """
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_normal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_normal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_normal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_normal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_normal
    return embeddings, sentences
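
# Usage sketch (hypothetical scoring, not defined in this file): the per-object
# normal and abnormal sentence embeddings can be contrasted, e.g. via the max
# cosine similarity of L2-normalized patch features to each set:
#
#     ab_emb, _ = encode_abnormal_text(model, objs, tokenizer, device)
#     no_emb, _ = encode_normal_text(model, objs, tokenizer, device)
#     score = (patch_feats @ ab_emb[obj].T).max(-1).values \
#         - (patch_feats @ no_emb[obj].T).max(-1).values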


def encode_obj_text(model, query_words, tokenizer, device):
    """Encode query words with the openai_imagenet_template ensemble.

    Each entry of query_words is a string, or a list of synonym strings that are
    ensembled into a single embedding. Example inputs:
        ['orange', 'nectarine', 'cereals', 'banana chips', 'almonds', 'white box']
        ['liquid', 'glass', 'top', 'black background']
        ['connector', 'grid']
        [['screw'], 'plastic bag', 'background']
        [['pushpin', 'pin'], ['plastic box'], 'box', 'black background']
    """
    query_features = []
    with torch.no_grad():
        for qw in query_words:
            token_input = []
            if isinstance(qw, list):
                for qw2 in qw:
                    token_input.extend([temp(qw2) for temp in openai_imagenet_template])
            else:
                token_input = [temp(qw) for temp in openai_imagenet_template]
            query = tokenizer(token_input).to(device)
            feature = model.encode_text(query)
            feature /= feature.norm(dim=-1, keepdim=True)
            feature = feature.mean(dim=0)
            feature /= feature.norm()
            query_features.append(feature.unsqueeze(0))
    query_features = torch.cat(query_features, dim=0)
    return query_features
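
# Usage sketch: nested lists group synonyms into one ensembled embedding, as in
# the docstring examples above; the call below is hypothetical:
#
#     query_features = encode_obj_text(
#         model, [['pushpin', 'pin'], 'plastic box', 'black background'], tokenizer, device)
#     # query_features: [3, embed_dim], one L2-normalized row per query entry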