
Refactor inference script to streamline usage; remove unnecessary classes and integrate threading for response generation
0f6159c
import threading

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# For 4-bit quantized inference (recommended)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
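# NF4 with double quantization stores the base weights in ~4 bits each, so the 1.7B
# model's weights should fit in roughly 1 GB of GPU memory (a rough estimate, not a measured figure)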
# First load the base model with quantization
base_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    quantization_config=bnb_config,
    device_map="auto"
)
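# device_map="auto" lets accelerate place the quantized weights on the available GPU
# (or CPU as a fallback), so no manual .to(device) call is needed for the model itself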
# Then load the adapter weights (LoRA)
model = PeftModel.from_pretrained(base_model, "zahemen9900/finsight-ai")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"
system_prompt = "You are Finsight, a finance bot trained to assist users with financial insights"
prompt = "What's your name, and what're you good at?"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt}
]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
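# add_generation_prompt=True appends the assistant-turn header from the chat template,
# so the model starts writing its reply instead of continuing the user's message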
# Tokenize the formatted prompt
inputs = tokenizer(formatted_prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}  # Move all tensors to device
# Create a streamer
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
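# TextIteratorStreamer pushes decoded text chunks onto an internal queue as they are
# generated; generation therefore runs on a separate thread while the main thread
# iterates over the streamer. timeout=20.0 raises an error if no new chunk arrives
# within 20 seconds.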
# Adjust generation parameters for more controlled responses
generation_config = {
    "max_new_tokens": 256,
    "temperature": 0.6,
    "top_p": 0.95,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "repetition_penalty": 1.2,
    "no_repeat_ngram_size": 4,
    "num_beams": 1,
    "early_stopping": False,
    "length_penalty": 1.0,
}
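# Note: early_stopping and length_penalty only affect beam search, so with num_beams=1
# they have no effect on the sampled output here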
# Combine inputs and generation config for the generate function
# (pass the attention_mask along with input_ids so generate() doesn't have to infer it)
generation_kwargs = {**generation_config, **inputs, "streamer": streamer}
# Start generation in a separate thread
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Iterate over the generated text as it streams in
print("Response: ", end="")
for text in streamer:
    print(text, end="", flush=True)
print()
# Make sure the generation thread has finished before exiting
thread.join()
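# Optional helper (a minimal sketch, not part of the committed script): wrapping the
# template/tokenize/generate/stream steps above into one function makes multi-turn chat
# easier. The name `ask` is hypothetical; it reuses model, tokenizer, device and
# generation_config defined above.
def ask(messages):
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    batch = tokenizer(formatted, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    # A fresh streamer per call, since a streamer can only be consumed once
    turn_streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = {**generation_config, **batch, "streamer": turn_streamer}
    worker = threading.Thread(target=model.generate, kwargs=kwargs)
    worker.start()
    reply = ""
    for chunk in turn_streamer:
        print(chunk, end="", flush=True)
        reply += chunk
    worker.join()
    print()
    return reply

# Example multi-turn usage (the follow-up question is illustrative):
# messages.append({"role": "assistant", "content": ask(messages)})
# messages.append({"role": "user", "content": "How can I start building an emergency fund?"})
# ask(messages)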