ankitkushwaha90 commited on
Commit
65802e8
·
verified ·
1 Parent(s): 0ee3719

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +95 -30
train_model.py CHANGED
@@ -1,47 +1,112 @@
1
  # Install required packages first:
2
- # pip install torch transformers safetensors
3
 
4
  import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
6
 
7
  # -----------------------------
8
- # 1️⃣ Load the trained model
9
  # -----------------------------
10
- model_path = "./mini_gpt_safetensor" # folder where model was saved
11
 
12
- print("📥 Loading model and tokenizer...")
13
- tokenizer = AutoTokenizer.from_pretrained(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  tokenizer.pad_token = tokenizer.eos_token # GPT models don't have pad_token
15
 
16
- model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") # load model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # -----------------------------
19
- # 2️⃣ Generate text
20
  # -----------------------------
21
- def generate_text(prompt, max_length=50):
22
- # Tokenize prompt
23
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
24
- input_ids = input_ids.to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Generate text
27
- output_ids = model.generate(
28
- input_ids,
29
- max_length=max_length,
30
- do_sample=True, # for randomness
31
- top_k=50, # sample from top 50 tokens
32
- top_p=0.95, # nucleus sampling
33
- temperature=0.7,
34
- num_return_sequences=1
35
- )
36
 
37
- # Decode output
38
- output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
- return output_text
 
 
40
 
41
  # -----------------------------
42
- # 3️⃣ Test generation
43
  # -----------------------------
44
- prompt = "Hello, I am training a mini GPT model"
45
- generated_text = generate_text(prompt, max_length=50)
46
- print("\n📝 Generated text:")
47
- print(generated_text)
 
1
  # Install required packages first:
2
+ # pip install torch transformers datasets accelerate safetensors
3
 
4
  import torch
5
+ from datasets import Dataset
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ AutoModelForCausalLM,
9
+ Trainer,
10
+ TrainingArguments,
11
+ DataCollatorForLanguageModeling
12
+ )
13
 
14
  # -----------------------------
15
+ # 1️⃣ Create a small custom dataset
16
  # -----------------------------
17
+ print("📥 Creating small dataset for training...")
18
 
19
+ train_texts = [
20
+ "Hello, my name is Ankit.",
21
+ "I love programming in Python.",
22
+ "Transformers library makes NLP easy.",
23
+ "PyTorch is great for deep learning.",
24
+ "I am learning to fine-tune GPT models."
25
+ ]
26
+
27
+ test_texts = [
28
+ "Hello, I am training a small GPT.",
29
+ "Deep learning is fun!",
30
+ "Python is my favorite programming language."
31
+ ]
32
+
33
+ # Convert to Hugging Face Dataset
34
+ train_data = Dataset.from_dict({"text": train_texts})
35
+ test_data = Dataset.from_dict({"text": test_texts})
36
+
37
+ # -----------------------------
38
+ # 2️⃣ Load tokenizer
39
+ # -----------------------------
40
+ print("📝 Loading tokenizer...")
41
+ tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
42
  tokenizer.pad_token = tokenizer.eos_token # GPT models don't have pad_token
43
 
44
+ # Tokenize dataset
45
+ def tokenize(batch):
46
+ return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=32)
47
+
48
+ train_data = train_data.map(tokenize, batched=True)
49
+ test_data = test_data.map(tokenize, batched=True)
50
+
51
+ train_data.set_format('torch', columns=['input_ids', 'attention_mask'])
52
+ test_data.set_format('torch', columns=['input_ids', 'attention_mask'])
53
+
54
+ # -----------------------------
55
+ # 3️⃣ Load model
56
+ # -----------------------------
57
+ print("🤖 Loading model...")
58
+ model = AutoModelForCausalLM.from_pretrained("distilgpt2")
59
+
60
+ # -----------------------------
61
+ # 4️⃣ Data collator
62
+ # -----------------------------
63
+ data_collator = DataCollatorForLanguageModeling(
64
+ tokenizer=tokenizer,
65
+ mlm=False
66
+ )
67
 
68
  # -----------------------------
69
+ # 5️⃣ Training arguments
70
  # -----------------------------
71
+ training_args = TrainingArguments(
72
+ output_dir="./mini_gpt_safetensor",
73
+ overwrite_output_dir=True,
74
+ per_device_train_batch_size=2,
75
+ per_device_eval_batch_size=2,
76
+ num_train_epochs=3,
77
+ save_strategy="epoch",
78
+ logging_steps=10,
79
+ learning_rate=5e-5,
80
+ weight_decay=0.01,
81
+ fp16=True if torch.cuda.is_available() else False,
82
+ save_total_limit=2,
83
+ push_to_hub=False,
84
+ report_to=None,
85
+ optim="adamw_torch",
86
+ save_safetensors=True # saves in safetensors format
87
+ )
88
 
89
+ # -----------------------------
90
+ # 6️⃣ Trainer
91
+ # -----------------------------
92
+ trainer = Trainer(
93
+ model=model,
94
+ args=training_args,
95
+ train_dataset=train_data,
96
+ eval_dataset=test_data,
97
+ data_collator=data_collator
98
+ )
99
 
100
+ # -----------------------------
101
+ # 7️⃣ Train model
102
+ # -----------------------------
103
+ print("🏋️ Training model...")
104
+ trainer.train()
105
 
106
  # -----------------------------
107
+ # 8️⃣ Save model in safetensor format
108
  # -----------------------------
109
+ print("💾 Saving model in safetensors format...")
110
+ trainer.save_model("./mini_gpt_safetensor")
111
+
112
+ print("✅ Training complete and model saved!")