# ERank-14B / examples / ERank_vLLM.py
import torch
import math
from typing import Optional
from vllm import LLM, SamplingParams
from utils import prompt_template, truncate
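# Note (assumed, not verified against the actual utils module): prompt_template is
# expected to expose {query}, {doc}, and {instruction} placeholders, and
# truncate(tokenizer, text, length=...) is expected to clip the text to at most
# `length` tokens.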
class ERank_vLLM:
    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        # Shard the model across all visible GPUs.
        num_gpu = torch.cuda.device_count()
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            enable_prefix_caching=True
        )
        self.tokenizer = self.ranker.get_tokenizer()
        # Greedy decoding; top logprobs are requested so rerank() can read back
        # the probability of the generated answer token(s).
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=4096,
            logprobs=20
        )
    def rerank(self, query: str, docs: list, instruction: str, truncate_length: Optional[int] = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum number of tokens to keep from the query
                and from each document's content. Defaults to None (no truncation).

        Returns:
            list: A new list of document dictionaries, sorted by their "rank_score" in descending order.
        """
        # Prepare one chat message per document, combining the (optionally
        # truncated) query, the document content, and the instruction.
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction
                )
            }] for doc in docs
        ]
        # Run a single batched LLM generation over all prompts.
        outputs = self.ranker.chat(messages, self.sampling_params)
        # Extract and organize results.
        results = []
        for doc, output in zip(docs, outputs):
            # Parse the relevance score from the tail of the generation: once the
            # running text ends with "<answer>", collect the digit tokens that
            # follow and multiply their probabilities together.
            cur = ""
            answer = ""
            is_ans = False
            prob = 1.0
            for each in output.outputs[0].logprobs[-10:]:
                # Each entry maps token IDs to logprob details; take the top
                # entry for this position (the greedily sampled token).
                _, detail = next(iter(each.items()))
                token = detail.decoded_token
                logprob = detail.logprob
                if is_ans and token.isdigit():
                    answer += token
                    prob *= math.exp(logprob)
                else:
                    cur += token
                    if cur.endswith("<answer>"):
                        is_ans = True
            # Fall back to -1 if the answer is not a number or exceeds 10.
            try:
                answer = int(answer)
                assert answer <= 10
            except (ValueError, AssertionError):
                answer = -1
            # Append the document with its final score (answer weighted by its probability).
            results.append({
                **doc,
                "rank_score": answer * prob
            })
        # Sort the reranking results for the query, highest score first.
        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results
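

# Minimal usage sketch (not part of the original file): the model path, query,
# documents, and instruction below are made-up placeholders, so swap in the real
# ERank checkpoint path or Hugging Face model ID and your own data before running.
if __name__ == "__main__":
    reranker = ERank_vLLM("path/to/ERank-14B")  # hypothetical path or model ID
    example_docs = [
        {"id": 1, "content": "The Eiffel Tower is located in Paris, France."},
        {"id": 2, "content": "Photosynthesis converts sunlight into chemical energy."},
    ]
    ranked = reranker.rerank(
        query="Where is the Eiffel Tower?",
        docs=example_docs,
        instruction="Judge how relevant the document is to the web search query.",
        truncate_length=1024,
    )
    for item in ranked:
        print(item["rank_score"], item["content"])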