Safetensors
qwen3
File size: 3,516 Bytes
c2e0fef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import math
from vllm import LLM, SamplingParams
from utils import prompt_template, truncate


class ERank_vLLM:
    """Pointwise document reranker backed by a vLLM chat model.

    Each (query, document) pair is rendered with ``prompt_template`` and sent to
    the model. The integer the model writes between ``<answer>`` and
    ``</answer>`` near the end of its generation is multiplied by the model's
    probability of emitting those digit tokens to form the ``rank_score``.
    """

    # Number of trailing generated tokens scanned for the <answer> tag.
    _TAIL_TOKENS = 10
    # Parsed answers above this value are treated as invalid (score -1).
    _MAX_SCORE = 10

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                                      This can be a Hugging Face model ID or a local path.
        """
        # tensor_parallel_size must be >= 1, so clamp when no CUDA devices are reported.
        num_gpu = max(torch.cuda.device_count(), 1)
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            enable_prefix_caching=True,
        )
        self.tokenizer = self.ranker.get_tokenizer()
        # Greedy decoding; logprobs are requested so the extracted score can be
        # weighted by the probability of its digit tokens.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=4096,
            logprobs=20,
        )

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                         and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document
                content to. Falsy values (None, 0) disable truncation. Defaults to None.

        Returns:
            list: New document dictionaries (shallow copies of the inputs plus a
            "rank_score" key), sorted by "rank_score" in descending order. Documents
            whose answer could not be parsed or exceeds 10 get a negative score.
            An empty ``docs`` list yields an empty list without calling the model.
        """
        if not docs:
            return []

        def _maybe_truncate(text):
            # Truncation is applied only when a truthy length was given.
            if truncate_length:
                return truncate(self.tokenizer, text, length=truncate_length)
            return text

        # The query is the same for every document, so truncate it only once.
        query_text = _maybe_truncate(query)

        # prepare one single-turn chat per document
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=query_text,
                    doc=_maybe_truncate(doc["content"]),
                    instruction=instruction,
                ),
            }]
            for doc in docs
        ]

        # LLM generate
        outputs = self.ranker.chat(messages, self.sampling_params)

        # extract and organize results
        results = []
        for doc, output in zip(docs, outputs):
            completion = output.outputs[0]
            tail_ids = completion.token_ids[-self._TAIL_TOKENS:]
            tail_logprobs = completion.logprobs[-self._TAIL_TOKENS:]
            tail = []
            for token_id, step in zip(tail_ids, tail_logprobs):
                # Look up the sampled token explicitly instead of relying on dict
                # order; fall back to the first entry if it is somehow absent.
                detail = step.get(token_id) or next(iter(step.values()))
                tail.append((detail.decoded_token, detail.logprob))

            score, prob = self._parse_answer(tail)
            results.append({
                **doc,
                "rank_score": score * prob,
            })

        # sort the reranking results for the query
        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results

    @classmethod
    def _parse_answer(cls, tokens):
        """Extract the model's score from a sequence of generated tokens.

        Args:
            tokens: Iterable of ``(decoded_token, logprob)`` pairs in generation
                order. ``decoded_token`` may be None (treated as an empty string).

        Returns:
            tuple: ``(score, prob)``. ``score`` is the integer formed by the digit
            tokens that follow ``<answer>`` (up to ``</answer>``). It is -1 when
            there are no such digits, they cannot be parsed, or the value exceeds
            ``_MAX_SCORE``. ``prob`` is the product of those digit tokens'
            probabilities (1.0 if there were none).
        """
        text = ""
        digits = ""
        in_answer = False
        prob = 1.0
        for token, logprob in tokens:
            token = token or ""
            if in_answer and token.isdigit():
                digits += token
                prob *= math.exp(logprob)
                continue
            # Non-answer tokens are accumulated so tags split across several
            # tokens are still recognised.
            text += token
            if text.endswith("<answer>"):
                in_answer = True
            elif text.endswith("</answer>"):
                # Stop collecting so stray digits after the closing tag are not
                # appended to the score.
                in_answer = False

        try:
            score = int(digits)
        except ValueError:
            # No digits at all, or digit characters int() rejects (e.g. "²").
            return -1, prob
        if score > cls._MAX_SCORE:
            return -1, prob
        return score, prob