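"""ERank reranker: pointwise document scoring with a generative LLM.

The model (a Qwen3-based checkpoint such as Ucreate/ERank-4B, stored as
Safetensors) is prompted to judge each document's relevance to the query and
to emit an integer from 0 to 10 inside <answer> tags; the final rank_score
weights that integer by the probability the model assigned to its digit
tokens.
"""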
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import prompt_template, truncate, hybrid_scores
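# NOTE: `utils` is a local helper module in this repository (not shown here).
# From its usage below, it is assumed to provide:
#   - prompt_template: a str with {query}, {doc}, and {instruction} fields that
#     asks the model to wrap its 0-10 relevance judgment in <answer> tags
#   - truncate(tokenizer, text, length): clips `text` to at most `length` tokens
#   - hybrid_scores(results, alpha): mixes first-stage and rerank scores
#     (a reconstruction is sketched at the bottom of this file)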

class ERank_Transformer:
    
    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                                      This can be a Hugging Face model ID or a local path.
        """
        # left padding so that, in a batch, generated tokens directly follow each
        # prompt (required for batched generation with decoder-only models)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        self.reranker.to("cuda")  # assumes a CUDA GPU is available
  
    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int=None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                         and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): Maximum number of tokens to keep from the query
                                             and each document's content. Defaults to None (no truncation).

        Returns:
            list: A new list of document dictionaries, sorted by their "rank_score" in descending order.
        """
        
        # prepare one single-turn chat per document: each query-document pair
        # is scored independently (pointwise reranking)
        messages = [
            [{
                "role": "user", 
                "content": prompt_template.format(
                    query=truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction
                )
            }] for doc in docs
        ]

        # render each chat with the model's chat template, then batch-tokenize
        texts = [
            self.tokenizer.apply_chat_template(
                each,
                tokenize=False,
                add_generation_prompt=True,
            ) for each in messages
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)

        # LLM completion; the generous max_new_tokens budget leaves room for the
        # model's reasoning before the final <answer> tag
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=8192,
            output_scores=True,
            return_dict_in_generate=True
        )
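        # outputs.sequences holds [batch, prompt_len + new_tokens] token ids (prompt
        # included); outputs.scores is one [batch, vocab] logits tensor per step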
  
        # extract and organize results
        results = []
        scores = outputs.scores
        generated_ids = outputs.sequences
        answer_token_ids = self.tokenizer.encode("<answer>", add_special_tokens=False)
        for idx in range(len(texts)):
            
            # locate the last occurrence of <answer> by scanning backwards over
            # every possible start position in the generated sequence
            output_ids = generated_ids[idx].tolist()
            start_index = -1
            for i in range(len(output_ids) - len(answer_token_ids), -1, -1):
                if output_ids[i:i + len(answer_token_ids)] == answer_token_ids:
                    start_index = i + len(answer_token_ids)
                    break
            
            # read consecutive digit tokens after <answer>, accumulating the
            # product of the probabilities the model assigned to each of them
            answer = ""
            prob = 1.0
            if start_index != -1:
                for t in range(start_index - inputs.input_ids.size(1), len(scores)):
                    generated_token_id = generated_ids[idx][inputs.input_ids.size(1) + t]
                    token = self.tokenizer.decode(generated_token_id)
                    if token.isdigit():
                        logits = scores[t][idx]
                        probs = F.softmax(logits, dim=-1)
                        prob *= probs[generated_token_id].item()
                        answer += token
                    else:
                        break

            # fall back to -1 if no answer was found, it is not an integer,
            # or it exceeds the 0-10 scale
            try:
                answer = int(answer)
                assert answer <= 10
            except (ValueError, AssertionError):
                answer = -1
                
            # final score: the integer judgment weighted by the model's confidence in it
            results.append({
                **docs[idx],
                "rank_score": answer * prob
            })
            
        # sort the reranking results for the query
        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results
    
    
if __name__ == "__main__":
    
    # select a model
    model_name_or_path = "Ucreate/ERank-4B"
    # model_name_or_path = "Ucreate/ERank-14B"
    # model_name_or_path = "Ucreate/ERank-32B"
    reranker = ERank_Transformer(model_name_or_path)
    
    # input data
    instruction = "Retrieve relevant documents for the query."
    query = "I am happy"
    docs = [
        {"content": "excited", "first_stage_score": 46.7},
        {"content": "sad", "first_stage_score": 1.5},
        {"content": "peaceful", "first_stage_score": 2.3},
    ]

    # rerank
    results = reranker.rerank(query, docs, instruction, truncate_length=2048)
    print(results)
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0},
    # ]
    
    # Optional: hybrid with first-stage scores
    alpha = 0.2
    hybrid_results = hybrid_scores(results, alpha)
    print(hybrid_results)
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84, 'hybrid_score': 1.18},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98, 'hybrid_score': 0.01},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0, 'hybrid_score': -1.19}
    # ]
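
    # ------------------------------------------------------------------
    # NOTE: `hybrid_scores` is defined in the local `utils` module, which is
    # not shown here. The sketch below is an assumption reconstructed from the
    # printed example above (z-normalize both score columns using the
    # population std, then mix with weight `alpha` on the first-stage side);
    # the shipped implementation may differ.
    #
    # import statistics
    #
    # def hybrid_scores(results: list, alpha: float) -> list:
    #     def z(values):
    #         std = statistics.pstdev(values) or 1.0
    #         mean = statistics.mean(values)
    #         return [(v - mean) / std for v in values]
    #     first = z([r["first_stage_score"] for r in results])
    #     rank = z([r["rank_score"] for r in results])
    #     mixed = [
    #         {**r, "hybrid_score": round(alpha * f + (1 - alpha) * k, 2)}
    #         for r, f, k in zip(results, first, rank)
    #     ]
    #     mixed.sort(key=lambda x: x["hybrid_score"], reverse=True)
    #     return mixed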