File size: 1,543 Bytes
cdf4ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from transformers import AutoTokenizer, Pipeline
import torch

class PairTextClassificationPipeline(Pipeline):
    """Pipeline that scores a (premise, hypothesis) pair as an entailment probability.

    The pair is rendered into a fixed prompt template, run through the model,
    and the probability assigned to class 1 at the first token position is
    returned as a plain float.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        # Fall back to the model's own tokenizer when none is supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # The tokenizer must be set on the instance before the base
        # Pipeline.__init__ runs.
        self.tokenizer = tokenizer
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"

    def _sanitize_parameters(self, **kwargs):
        # No user-tunable parameters: preprocess/forward/postprocess all
        # receive empty kwarg dicts.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize one (premise, hypothesis) pair into model inputs.

        `inputs` is expected to be an indexable pair: (premise, hypothesis).
        """
        premise, hypothesis = inputs[0], inputs[1]
        rendered = self.prompt.format(text1=premise, text2=hypothesis)
        return self.tokenizer(rendered, return_tensors='pt', padding=True)

    def _forward(self, model_inputs):
        # Plain forward pass; no generation involved.
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Reduce model logits to the probability of class 1 as a float."""
        # Keep only the first token position's logits (tok_cls).
        first_tok_logits = model_outputs.logits[:, 0, :]
        probs = torch.softmax(first_tok_logits, dim=-1)
        # Probability mass on class index 1; .item() assumes batch size 1.
        return probs[:, 1].item()