from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import prompt_template, truncate, hybrid_scores


class ERank_Transformer:

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        # Left padding is assumed here: batched generation with a decoder-only
        # model needs each prompt flush against its generated tokens, and the
        # score indexing in rerank() relies on every prompt ending at the same
        # (padded) position. A pad token is set as a fallback for tokenizers
        # that do not define one.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        self.reranker.to("cuda")

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding how it should evaluate
                the documents.
            truncate_length (int, optional): The maximum token length to which the query and
                document content are truncated. Defaults to None (no truncation).

        Returns:
            list: A new list of document dictionaries, sorted by their "rank_score" in
                descending order.
        """
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction
                )
            }] for doc in docs
        ]
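
        # Render each chat with the model's chat template, then tokenize the whole
        # batch at once, padding to the longest prompt.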
        texts = [
            self.tokenizer.apply_chat_template(
                each,
                tokenize=False,
                add_generation_prompt=True,
            ) for each in messages
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)
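
        # Generate the reasoning chain and final answer for every document, keeping
        # the per-step logits so token probabilities can be recovered afterwards.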
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=8192,
            output_scores=True,
            return_dict_in_generate=True
        )

        results = []
        scores = outputs.scores
        generated_ids = outputs.sequences
        answer_token_ids = self.tokenizer.encode("<answer>", add_special_tokens=False)
        for idx in range(len(texts)):
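            # Scan backwards for the last "<answer>" tag; the relevance score is the
            # run of digit tokens that immediately follows it.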
            output_ids = generated_ids[idx].tolist()
            start_index = -1
            for i in range(len(output_ids) - len(answer_token_ids), -1, -1):
                if output_ids[i:i + len(answer_token_ids)] == answer_token_ids:
                    start_index = i + len(answer_token_ids)
                    break
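
            # Decode digit tokens after the tag, multiplying their softmax
            # probabilities to obtain the model's confidence in the answer.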
            answer = ""
            prob = 1.0
            if start_index != -1:
                for t in range(start_index - inputs.input_ids.size(1), len(scores)):
                    generated_token_id = generated_ids[idx][inputs.input_ids.size(1) + t]
                    token = self.tokenizer.decode(generated_token_id)
                    if token.isdigit():
                        logits = scores[t][idx]
                        probs = F.softmax(logits, dim=-1)
                        prob *= probs[generated_token_id].item()
                        answer += token
                    else:
                        break

            # Fall back to -1 when no valid 0-10 relevance score was produced.
            try:
                answer = int(answer)
                assert answer <= 10
            except (ValueError, AssertionError):
                answer = -1
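
            # The final rank score is the integer answer weighted by the probability
            # of the digit tokens that produced it.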
            results.append({
                **docs[idx],
                "rank_score": answer * prob
            })

        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results


if __name__ == "__main__":

    model_name_or_path = "Ucreate/ERank-4B"

    reranker = ERank_Transformer(model_name_or_path)

    instruction = "Retrieve relevant documents for the query."
    query = "I am happy"
    docs = [
        {"content": "excited", "first_stage_score": 46.7},
        {"content": "sad", "first_stage_score": 1.5},
        {"content": "peaceful", "first_stage_score": 2.3},
    ]

    results = reranker.rerank(query, docs, instruction, truncate_length=2048)
    print(results)
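
    # Blend the rank_score with the first-stage retrieval score. hybrid_scores
    # comes from utils; alpha presumably controls the mixing weight.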
    alpha = 0.2
    hybrid_results = hybrid_scores(results, alpha)
    print(hybrid_results)