File size: 3,902 Bytes
09d523d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass

@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
    hallucination_logits: torch.FloatTensor = None

class SelfCorrectiveLlama(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        self.num_new_tokens = 3

        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)
    
    def forward(
        self, 
        input_ids, 
        attention_mask=None, 
        labels=None, 
        hallucination_labels=None, 
        **kwargs
    ):
        # 1. Get the last hidden state from the base transformer model.
        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = transformer_outputs.last_hidden_state

        # 2. Calculate token logits and hallucination logits from the last hidden state.
        # Token logits
        logits = self.lm_head(last_hidden)

        # SwiGLU-based hallucination detector
        gate_output = self.hallucination_gate_proj(last_hidden)
        up_output = self.hallucination_up_proj(last_hidden)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Hallucination logits
        all_hallucination_logits = self.hallucination_detector(detector_hidden)

        # 3. Modify the token logits conditionally.
        deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
        additional_logits = torch.zeros_like(logits)

        # Conditionally add the deletion logits if we are in training and labels are provided.
        if hallucination_labels is not None and labels is not None:
            # Condition 1: The hallucination label is 0 (no hallucination).
            mask_no_hallucination = (hallucination_labels == 0)

            # Condition 2: The next token is one of the deletion tokens.
            # Check if the token ID is within the range of the last `num_new_tokens` in the vocab.
            vocab_size = logits.shape[-1]
            mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)

            # Combine the masks. The addition happens if either condition is true.
            # We need to align the shapes for broadcasting.
            combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)

            # Use the mask to conditionally apply the deletion logits.
            additional_logits[:, :, -self.num_new_tokens:] = torch.where(
                combined_mask,
                deletion_logits,
                torch.zeros_like(deletion_logits)
            )
        else:
            # Inference case: always add the deletion logits to the token logits.
            additional_logits[:, :, -self.num_new_tokens:] = deletion_logits

        logits = logits + additional_logits

        # 4. Return the custom output object.
        return SelfCorrectiveLlamaOutput(
            loss=None, # Loss calculation is handled by the Trainer
            logits=logits,
            hallucination_logits=all_hallucination_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=None,
            attentions=transformer_outputs.attentions
        )