File size: 828 Bytes
a642e69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Silence the Hugging Face `tokenizers` parallelism warning (presumably emitted when
# torch.compile forks worker processes after the tokenizer has been used). It must be
# set before the tokenizer is first used for the library to honour it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

# Single source of truth for the checkpoint and target device. Previously "cuda:1"
# was hard-coded for both the model and the inputs, and the two had to be kept in
# sync by hand.
# NOTE(review): "gg-hf/..." looks like a pre-release staging repo; the public
# checkpoint is presumably "google/gemma-2-2b-it" -- confirm access before changing.
MODEL_ID = "gg-hf/gemma-2-2b-it"
DEVICE = "cuda:1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

# A static KV cache keeps tensor shapes fixed across decoding steps, which is what
# allows the compiled forward below to be reused rather than recompiled per step.
model.generation_config.cache_implementation = "static"

# mode="reduce-overhead" uses CUDA graphs to cut per-step launch overhead;
# fullgraph=True makes compilation raise on graph breaks instead of silently
# falling back to eager execution for parts of the model.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]

# Render the conversation with the model's chat template and tokenize it in one
# call. add_generation_prompt=True appends the assistant-turn header so the model
# answers instead of continuing the user turn. Inputs follow the model's device,
# so they cannot drift out of sync with DEVICE.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=256)
# For a decoder-only model, generate() returns prompt + new tokens, so the decoded
# strings below include the prompt text followed by the model's reply.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))