ankitkushwaha90 committed
Commit 0ee3719 · verified · 1 Parent(s): 994bc32

Create train_model.py

Files changed (1): train_model.py +47 -0
train_model.py ADDED
@@ -0,0 +1,47 @@
+ # Install required packages first:
+ # pip install torch transformers safetensors
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # -----------------------------
+ # 1️⃣ Load the trained model
+ # -----------------------------
+ model_path = "./mini_gpt_safetensor"  # folder where the model was saved
+
+ print("📥 Loading model and tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token  # GPT tokenizers define no pad token, so reuse EOS
+
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")  # load model weights
+
+ # -----------------------------
+ # 2️⃣ Generate text
+ # -----------------------------
+ def generate_text(prompt, max_length=50):
+     # Tokenize the prompt and move it to the model's device
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+     input_ids = input_ids.to(model.device)
+
+     # Generate text
+     output_ids = model.generate(
+         input_ids,
+         max_length=max_length,
+         do_sample=True,       # sample instead of greedy decoding
+         top_k=50,             # sample from the top 50 tokens
+         top_p=0.95,           # nucleus sampling
+         temperature=0.7,
+         num_return_sequences=1
+     )
+
+     # Decode output tokens back to text
+     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return output_text
+
+ # -----------------------------
+ # 3️⃣ Test generation
+ # -----------------------------
+ prompt = "Hello, I am training a mini GPT model"
+ generated_text = generate_text(prompt, max_length=50)
+ print("\n📝 Generated text:")
+ print(generated_text)
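
The script assumes a trained checkpoint already exists in ./mini_gpt_safetensor; that folder is not created by this commit. A minimal sketch of how such a folder could be produced with a small GPT-2-style model (the config values and base tokenizer here are illustrative assumptions, not taken from this repository):

    # Hypothetical save step (not part of this commit): creates the folder
    # that train_model.py loads from. Config values below are illustrative.
    from transformers import GPT2Config, GPT2LMHeadModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    config = GPT2Config(n_layer=4, n_head=4, n_embd=256)  # a "mini" GPT
    model = GPT2LMHeadModel(config)

    # ... training loop over your corpus would go here ...

    model.save_pretrained("./mini_gpt_safetensor", safe_serialization=True)  # writes model.safetensors
    tokenizer.save_pretrained("./mini_gpt_safetensor")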
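
One runtime note: because the pad token is set only on the tokenizer, model.generate() will typically warn that pad_token_id is unset and fall back to eos_token_id. A sketch of an equivalent generate call that passes the attention mask and pad token explicitly (behavior otherwise unchanged):

    # Tokenize once and keep the attention mask alongside input_ids
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,                             # input_ids plus attention_mask
        max_length=max_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # silences the pad-token warning
    )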