import torch
from torch.nn import functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import prompt_template, truncate, hybrid_scores
class ERank_Transformer:
    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        # Fall back to CPU when no GPU is present so the class still loads
        # (previously this unconditionally called .to("cuda")).
        self.reranker.to("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): batched generation with a causal LM normally requires
        # left padding; confirm self.tokenizer.padding_side for these checkpoints.

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        The model is prompted once per document; the relevance grade is read from
        the digits generated right after the last "<answer>" marker, and weighted
        by the product of the softmax probabilities of those digit tokens.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document
                content to. Defaults to None (no truncation).

        Returns:
            list: A new list of document dictionaries, sorted by their "rank_score" in descending order.
                Documents whose answer could not be parsed (non-digit, or > 10) get rank_score -1 * prob.
        """
        # Build one single-turn chat per document from the shared prompt template.
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=truncate(self.tokenizer, query, length=truncate_length) if truncate_length else query,
                    doc=truncate(self.tokenizer, doc["content"], length=truncate_length) if truncate_length else doc["content"],
                    instruction=instruction
                )
            }] for doc in docs
        ]
        # Render chats to text and tokenize as one padded batch.
        texts = [
            self.tokenizer.apply_chat_template(
                each,
                tokenize=False,
                add_generation_prompt=True,
            ) for each in messages
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)
        # LLM completion; keep per-step scores so digit probabilities can be read back.
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=8192,
            output_scores=True,
            return_dict_in_generate=True
        )
        # extract and organize results
        results = []
        scores = outputs.scores
        generated_ids = outputs.sequences
        answer_token_ids = self.tokenizer.encode("<answer>", add_special_tokens=False)
        # Prompt length is constant across the batch (padded); hoist it out of the loop.
        input_len = inputs.input_ids.size(1)
        for idx in range(len(texts)):
            # Scan backwards for the LAST occurrence of "<answer>" in the sequence.
            # The start index is len(output_ids) - len(answer_token_ids) so a match
            # ending exactly at the tail is not missed (the original started one
            # position earlier — an off-by-one).
            output_ids = generated_ids[idx].tolist()
            start_index = -1
            for i in range(len(output_ids) - len(answer_token_ids), -1, -1):
                if output_ids[i:i + len(answer_token_ids)] == answer_token_ids:
                    start_index = i + len(answer_token_ids)
                    break
            # Collect consecutive digit tokens after "<answer>", accumulating the
            # product of their softmax probabilities as a confidence weight.
            answer = ""
            prob = 1.0
            if start_index != -1:
                for t in range(start_index - input_len, len(scores)):
                    generated_token_id = generated_ids[idx][input_len + t]
                    token = self.tokenizer.decode(generated_token_id)
                    if token.isdigit():
                        logits = scores[t][idx]
                        probs = F.softmax(logits, dim=-1)
                        prob *= probs[generated_token_id].item()
                        answer += token
                    else:
                        break
            # Parse the collected digits explicitly instead of the original
            # bare `except:` around int()+assert (which swallowed every exception
            # and relied on `assert`, a no-op under python -O). Empty strings,
            # non-digits, and grades above 10 all map to the sentinel -1.
            if answer.isdigit() and int(answer) <= 10:
                answer = int(answer)
            else:
                answer = -1
            # append to the final results
            results.append({
                **docs[idx],
                "rank_score": answer * prob
            })
        # sort the reranking results for the query
        results.sort(key=lambda x: x["rank_score"], reverse=True)
        return results
if __name__ == "__main__":
    # Pick one of the released ERank checkpoints (larger = slower, stronger).
    model_name_or_path = "Ucreate/ERank-4B"
    # model_name_or_path = "Ucreate/ERank-14B"
    # model_name_or_path = "Ucreate/ERank-32B"
    reranker = ERank_Transformer(model_name_or_path)

    # Demo inputs: an instruction steering the relevance judgement, a query,
    # and candidate documents carrying their first-stage retrieval scores.
    instruction = "Retrieve relevant documents for the query."
    query = "I am happy"
    docs = [
        {"content": "excited", "first_stage_score": 46.7},
        {"content": "sad", "first_stage_score": 1.5},
        {"content": "peaceful", "first_stage_score": 2.3},
    ]

    # Rerank and show the LLM-graded ordering. Sample output:
    #   [
    #     {'content': 'excited',  'first_stage_score': 46.7, 'rank_score': 4.84},
    #     {'content': 'peaceful', 'first_stage_score': 2.3,  'rank_score': 2.98},
    #     {'content': 'sad',      'first_stage_score': 1.5,  'rank_score': 0.0},
    #   ]
    results = reranker.rerank(query, docs, instruction, truncate_length=2048)
    print(results)

    # Optionally blend rank_score with the first-stage score. Sample output:
    #   [
    #     {'content': 'excited',  'first_stage_score': 46.7, 'rank_score': 4.84, 'hybrid_score': 1.18},
    #     {'content': 'peaceful', 'first_stage_score': 2.3,  'rank_score': 2.98, 'hybrid_score': 0.01},
    #     {'content': 'sad',      'first_stage_score': 1.5,  'rank_score': 0.0,  'hybrid_score': -1.19},
    #   ]
    alpha = 0.2
    hybrid_results = hybrid_scores(results, alpha)
    print(hybrid_results)