File size: 2,451 Bytes
f118c1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Token position definitions (copied from MCQA task)
"""

import re
from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index


def get_token_positions(pipeline, causal_model):
    """
    Build the token positions used for interventions on the simple MCQA task.

    Three positions are produced, in this order:
        - "correct_symbol": the token that ends at the correct answer symbol
        - "correct_symbol_period": the token index immediately after it
          (presumably the period following the symbol; confirm against the
          prompt template)
        - "last_token": the index reported by get_last_token_index

    Args:
        pipeline: The language model pipeline with tokenizer
        causal_model: The causal model for the task

    Returns:
        list[TokenPosition]: List of TokenPosition objects for intervention experiments
    """
    def locate_correct_symbol(example):
        """
        Return the token index of the correct answer symbol in the prompt.

        Args:
            example (Dict): The input dictionary to a causal model; its
                "raw_input" entry is the prompt text

        Returns:
            list[int]: Single-element list holding the index of the symbol token

        Raises:
            ValueError: If the correct symbol never occurs in the prompt as a
                standalone uppercase letter
        """
        # The causal model tells us which answer slot is correct and which
        # symbol labels that slot.
        setting = causal_model.run_forward(example)
        answer_slot = setting["answer_pointer"]
        correct_symbol = setting[f"symbol{answer_slot}"]
        prompt = example["raw_input"]

        # Take the first standalone uppercase letter equal to the symbol.
        hit = next(
            (
                candidate
                for candidate in re.finditer(r"\b[A-Z]\b", prompt)
                if candidate.group() == correct_symbol
            ),
            None,
        )
        if hit is None:
            raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

        # Tokenize only the text up to and including the symbol, so that the
        # symbol is the final token of that prefix.
        # NOTE(review): assumes pipeline.load appends no trailing special
        # tokens to the prefix; confirm against the pipeline configuration.
        prefix_ids = list(pipeline.load(prompt[:hit.end()])["input_ids"][0])
        return [len(prefix_ids) - 1]

    # Insertion order of this mapping fixes the order of the returned list.
    indexers = {
        "correct_symbol": locate_correct_symbol,
        "correct_symbol_period": lambda x: [locate_correct_symbol(x)[0] + 1],
        "last_token": lambda x: get_last_token_index(x, pipeline),
    }
    return [
        TokenPosition(indexer, pipeline, id=position_id)
        for position_id, indexer in indexers.items()
    ]