from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import prompt_template, truncate, hybrid_scores


class ERank_Transformer:
    """Pointwise reranker that asks a causal LM to grade each document.

    For every document the model generates a completion; the integer that
    follows ``ANSWER_TAG`` in that completion is read back and multiplied by
    the probability the model assigned to its digit tokens, giving the
    document's ``rank_score``.
    """

    # Marker that precedes the numeric score in the generated text.
    # NOTE(review): the literal was empty in the reviewed source, which makes
    # the marker search meaningless; "<answer>" is presumed to be the tag that
    # utils.prompt_template asks the model to emit -- confirm against utils.
    ANSWER_TAG = "<answer>"
    # Largest score accepted as valid; anything above is treated as a failure.
    MAX_SCORE = 10
    # Sentinel used when no valid score could be read from the completion.
    INVALID_SCORE = -1

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be
                loaded. This can be a Hugging Face model ID or a local path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Decoder-only models continue from the last position of each row, so
        # a padded batch must be padded on the left for generate() to work.
        self.tokenizer.padding_side = "left"
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        self.reranker.to("cuda")

    @staticmethod
    def _index_after_last(sequence: list, pattern: list) -> int:
        """Return the index just past the last occurrence of *pattern*, or -1."""
        width = len(pattern)
        if width == 0:
            # An empty pattern matches everywhere; treat it as "not found"
            # rather than pointing at an arbitrary position.
            return -1
        for start in range(len(sequence) - width, -1, -1):
            if sequence[start:start + width] == pattern:
                return start + width
        return -1

    @classmethod
    def _parse_score(cls, digits: str) -> int:
        """Convert the collected digit string to a score, or INVALID_SCORE."""
        try:
            score = int(digits)
        except ValueError:
            # Empty string, or characters str.isdigit() accepts but int() does not.
            return cls.INVALID_SCORE
        return score if score <= cls.MAX_SCORE else cls.INVALID_SCORE

    def rerank(
        self,
        query: str,
        docs: list,
        instruction: str,
        truncate_length: int = None,
        max_new_tokens: int = 8192,
    ) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary
                represents a document and must contain a "content" key. Any
                other keys are copied to the output unchanged.
            instruction (str): The instruction for the model, guiding it on
                how to evaluate the documents.
            truncate_length (int, optional): When truthy, the query and each
                document content are passed through ``utils.truncate`` with
                this value as ``length``. Defaults to None (no truncation).
            max_new_tokens (int, optional): Generation budget forwarded to
                ``generate``. Defaults to 8192.

        Returns:
            list: A new list of document dictionaries, each extended with a
            "rank_score" key and sorted by it in descending order. The score
            is the parsed integer times the product of the probabilities of
            its digit tokens; when no valid integer (0..MAX_SCORE) is found
            the integer part is INVALID_SCORE.
        """
        if not docs:
            # Nothing to score; avoid sending an empty batch to the tokenizer.
            return []

        # prepare messages (the query is truncated once, not once per document)
        if truncate_length:
            query_text = truncate(self.tokenizer, query, length=truncate_length)
            contents = [
                truncate(self.tokenizer, doc["content"], length=truncate_length)
                for doc in docs
            ]
        else:
            query_text = query
            contents = [doc["content"] for doc in docs]

        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=query_text,
                    doc=content,
                    instruction=instruction,
                ),
            }]
            for content in contents
        ]

        # encode tokens
        texts = [
            self.tokenizer.apply_chat_template(
                each,
                tokenize=False,
                add_generation_prompt=True,
            )
            for each in messages
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)

        # LLM completion
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # extract and organize results
        prompt_length = inputs.input_ids.size(1)
        step_scores = outputs.scores  # one logits tensor per generated step
        answer_token_ids = self.tokenizer.encode(self.ANSWER_TAG, add_special_tokens=False)

        results = []
        for idx, (doc, sequence) in enumerate(zip(docs, outputs.sequences)):
            # Only the completion is searched: a tag occurring in the prompt
            # would otherwise yield a negative step index into step_scores.
            completion_ids = sequence[prompt_length:].tolist()
            answer_start = self._index_after_last(completion_ids, answer_token_ids)

            answer = ""
            prob = 1.0
            if answer_start != -1:
                # collect the run of digit tokens that follows the answer tag
                last_step = min(len(completion_ids), len(step_scores))
                for step in range(answer_start, last_step):
                    token_id = completion_ids[step]
                    token = self.tokenizer.decode(token_id)
                    if not token.isdigit():
                        break
                    probs = F.softmax(step_scores[step][idx], dim=-1)
                    prob *= probs[token_id].item()
                    answer += token

            # in case the answer is not a digit or exceeds MAX_SCORE
            score = self._parse_score(answer)

            # append to the final results
            results.append({
                **doc,
                "rank_score": score * prob,
            })

        # sort the reranking results for the query
        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results


def main():
    """Run a small reranking demo, then blend with first-stage scores."""
    # select a model (other published sizes: "Ucreate/ERank-14B", "Ucreate/ERank-32B")
    model_name_or_path = "Ucreate/ERank-4B"
    reranker = ERank_Transformer(model_name_or_path)

    # input data
    instruction = "Retrieve relevant documents for the query."
    query = "I am happy"
    docs = [
        {"content": "excited", "first_stage_score": 46.7},
        {"content": "sad", "first_stage_score": 1.5},
        {"content": "peaceful", "first_stage_score": 2.3},
    ]

    # rerank
    results = reranker.rerank(query, docs, instruction, truncate_length=2048)
    print(results)
    # Example output:
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0},
    # ]

    # Optional: hybrid with first-stage scores
    alpha = 0.2
    hybrid_results = hybrid_scores(results, alpha)
    print(hybrid_results)
    # Example output:
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84, 'hybrid_score': 1.18},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98, 'hybrid_score': 0.01},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0, 'hybrid_score': -1.19},
    # ]


if __name__ == "__main__":
    main()