|
import torch |
|
import math |
|
from vllm import LLM, SamplingParams |
|
from utils import prompt_template, truncate |
|
|
|
|
|
class ERank_vLLM:
    """Pointwise document reranker backed by a vLLM ``LLM`` engine.

    Every (query, document) pair is rendered into a single-turn chat prompt.
    The tail of each generation is scanned for an ``<answer>`` tag followed by
    digit tokens; the integer they spell, multiplied by the probability of
    those digit tokens, becomes the document's ``rank_score``.
    """

    # How many trailing generated tokens are scanned for the ``<answer>`` tag.
    _ANSWER_TAIL_TOKENS = 10
    # Largest integer accepted as an answer; anything above is treated as
    # unparseable.
    _MAX_ANSWER = 10

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        # Use every visible CUDA device for tensor parallelism, but never ask
        # for a degree of 0 when no device is visible.
        num_gpu = max(1, torch.cuda.device_count())
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            # Presumably enabled so prompts sharing a prefix can reuse cached
            # work -- depends on the layout of prompt_template.
            enable_prefix_caching=True,
        )
        self.tokenizer = self.ranker.get_tokenizer()
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=4096,
            # Per-token logprobs are required to weight the parsed answer.
            logprobs=20,
        )

    @staticmethod
    def _score_from_logprobs(token_logprobs) -> float:
        """Convert the per-token logprobs of one generation into a rank score.

        Args:
            token_logprobs: Sequence with one mapping per generated token.
                Only the first value of each mapping is read (assumed to be
                the sampled token -- confirm against the vLLM version in use);
                it must expose ``decoded_token`` and ``logprob``.

        Returns:
            float: ``answer * prob`` where ``answer`` is the integer formed by
            the digit tokens that follow ``<answer>`` within the last
            ``_ANSWER_TAIL_TOKENS`` tokens and ``prob`` is the product of
            those tokens' probabilities. ``answer`` is ``-1`` when no integer
            can be parsed or when it exceeds ``_MAX_ANSWER``.
        """
        seen_text = ""
        digits = ""
        in_answer = False
        prob = 1.0

        for candidates in token_logprobs[-ERank_vLLM._ANSWER_TAIL_TOKENS:]:
            detail = next(iter(candidates.values()))
            # decoded_token may be None when the token text is unavailable.
            token = detail.decoded_token or ""

            if in_answer and token.isdigit():
                digits += token
                prob *= math.exp(detail.logprob)
                continue

            if digits:
                # The digit run is over; later digit tokens (e.g. after the
                # closing tag) must not be glued onto the answer.
                break

            seen_text += token
            # The tag may be split across several tokens, so match on the
            # accumulated text rather than on a single token.
            if seen_text.endswith("<answer>"):
                in_answer = True

        try:
            answer = int(digits)
        except ValueError:
            # No digits found, or digit-like characters int() cannot parse.
            answer = -1
        else:
            if answer > ERank_vLLM._MAX_ANSWER:
                answer = -1

        return answer * prob

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document
                content to. A falsy value (None or 0) disables truncation. Defaults to None.

        Returns:
            list: A new list of document dictionaries (copies of the inputs with an added
            "rank_score" key), sorted by "rank_score" in descending order. Documents whose
            model output could not be parsed receive a negative score. An empty ``docs``
            yields an empty list without calling the model.

        Raises:
            KeyError: If a document has no "content" key.
        """
        if not docs:
            return []

        # The query is identical for every document, so truncate it once.
        if truncate_length:
            query = truncate(self.tokenizer, query, length=truncate_length)

        messages = []
        for doc in docs:
            content = doc["content"]
            if truncate_length:
                content = truncate(self.tokenizer, content, length=truncate_length)
            messages.append([{
                "role": "user",
                "content": prompt_template.format(
                    query=query,
                    doc=content,
                    instruction=instruction,
                ),
            }])

        outputs = self.ranker.chat(messages, self.sampling_params)

        results = [
            {**doc, "rank_score": self._score_from_logprobs(output.outputs[0].logprobs)}
            for doc, output in zip(docs, outputs)
        ]
        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results