Update Reinforcement_simple_model.py
Reinforcement_simple_model.py (new file, 192 lines added)
import random
import pickle
import os
import hashlib
import numpy as np
from collections import defaultdict
from sklearn.feature_extraction.text import CountVectorizer

# ------------------------------
# RL Chatbot Agent (Q-Learning)
# ------------------------------
class RLChatbot:
    def __init__(self, actions, alpha=0.1, gamma=0.9, epsilon=0.2):
        self.actions = actions    # possible responses
        self.alpha = alpha        # learning rate
        self.gamma = gamma        # discount factor
        self.epsilon = epsilon    # exploration rate
        # One row of Q-values per state, one column per candidate response.
        self.q_table = defaultdict(lambda: np.zeros(len(actions)))
        self.vectorizer = CountVectorizer()  # currently unused; kept as a hook for a richer featurizer

    def featurize(self, text):
        """Convert input text to a stable hashed state ID (string key).

        hashlib is used instead of the built-in hash(), which is salted per
        process and would invalidate the pickled Q-table between runs.
        """
        digest = hashlib.md5(text.lower().encode("utf-8")).hexdigest()
        return str(int(digest, 16) % (10**8))

    def choose_action(self, state):
        """Epsilon-greedy action selection."""
        if random.random() < self.epsilon:
            return random.randint(0, len(self.actions) - 1)  # explore
        return int(np.argmax(self.q_table[state]))           # exploit

    def update(self, state, action, reward, next_state):
        """Q-learning update."""
        old_q = self.q_table[state][action]
        next_max = np.max(self.q_table[next_state])
        self.q_table[state][action] += self.alpha * (reward + self.gamma * next_max - old_q)

    def save(self, path="rl_chatbot.pkl"):
        with open(path, "wb") as f:
            pickle.dump((dict(self.q_table), self.actions), f)

    def load(self, path="rl_chatbot.pkl"):
        if os.path.exists(path):
            with open(path, "rb") as f:
                data = pickle.load(f)
            self.q_table = defaultdict(lambda: np.zeros(len(self.actions)), data[0])
            self.actions = data[1]
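
# Worked example of the Q-learning update above (illustrative numbers, not part
# of the committed file): with alpha=0.1, gamma=0.9, an unseen state
# (old_q = 0.0, next_max = 0.0) and a reward of 5, the update gives
#     Q(state, action) <- 0.0 + 0.1 * (5 + 0.9 * 0.0 - 0.0) = 0.5
# Repeating the same +5 feedback moves the value toward 5 (0.5, 0.95, 1.355, ...),
# while the -1 default reward below slowly pushes unhelpful replies negative,
# so argmax increasingly prefers responses that earned good feedback.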

# ------------------------------
# Simulated training environment
# ------------------------------
def simulated_reward(user_input, bot_response):
    """Fake reward function for simulation:
    Higher reward if bot_response 'matches' intent."""
    if "hello" in user_input.lower() and "hello" in bot_response.lower():
        return 5
    if "bye" in user_input.lower() and "bye" in bot_response.lower():
        return 5
    if "help" in user_input.lower() and "help" in bot_response.lower():
        return 5
    return -1  # default negative reward
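
# Example behaviour (follows directly from the rules above, using replies from
# the action list defined in the main block):
#     simulated_reward("hello", "Hello! How can I help you today?")   -> 5
#     simulated_reward("i need help", "I can help debug your reinforcement learning agent.") -> 5
#     simulated_reward("what is docker", "Goodbye! Have a great day.") -> -1
# Only greeting / farewell / help keyword overlap is rewarded; everything else
# receives the small negative default.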

# ------------------------------
# Main program
# ------------------------------
if __name__ == "__main__":
    # actions = [
    #     "Hello! How can I help you?",
    #     "Goodbye! Have a nice day.",
    #     "I can help with your problems. What do you need?",
    #     "I'm not sure I understand.",
    #     "Please tell me more."
    # ]

    actions = [
        # General greetings & casual
        "Hello! How can I help you today?",
        "Hi there! What’s on your mind?",
        "Goodbye! Have a great day.",
        "See you later! Keep coding.",
        "I’m here to help with your questions.",

        # AI/ML related
        "Are you working on machine learning today?",
        "Which model architecture are you using?",
        "Do you want to discuss prompt engineering or fine-tuning?",
        "I can explain how transformers work in detail.",
        "Would you like me to write example PyTorch code for you?",
        "I can help debug your reinforcement learning agent.",
        "What dataset are you using for your project?",
        "Let’s talk about optimizing training performance.",
        "Are you running your model on CPU or GPU?",
        "I can guide you on hyperparameter tuning.",

        # Developer workflow
        "Would you like me to generate example code?",
        "I can help write a FastAPI endpoint for your AI model.",
        "Do you need help with Hugging Face Transformers?",
        "We can integrate this with a Flask web app.",
        "Do you want me to explain FAISS indexing?",
        "I can walk you through a RAG (Retrieval-Augmented Generation) pipeline.",
        "Let’s debug your Python code step-by-step.",
        "Would you like me to explain gradient descent?",

        # More conversational fallback
        "I’m not sure I understand, could you rephrase?",
        "Can you provide more details?",
        "Let’s break down the problem together.",
        "Interesting question! Let’s explore it.",
        "I can provide documentation links if you need.",
        "That’s a complex topic, but I can simplify it for you."
    ]

    agent = RLChatbot(actions)

    # ------------------------------
    # Simulated training phase
    # ------------------------------
    # training_data = [
    #     "hello", "hi there", "bye", "goodbye", "i need help", "can you help me",
    #     "what's up", "please help", "bye bye", "see you"
    # ]
    training_data = [
        # Greetings / casual
        "hello", "hi there", "hey", "good morning", "good evening",
        "what's up", "how are you", "yo", "long time no see", "how's it going",

        # General help
        "i need help", "can you help me", "please help", "i have a question",
        "i'm stuck", "can you guide me", "how do i fix this", "explain this to me",
        "can you give me an example", "show me sample code",

        # AI / ML specific
        "how to train a model", "what is reinforcement learning",
        "how does fine tuning work", "what is transfer learning",
        "explain gradient descent", "how to improve accuracy",
        "what is overfitting", "what is prompt engineering",
        "how to load a huggingface model", "how to use pytorch",
        "how to deploy a model", "difference between supervised and unsupervised learning",

        # Coding / debugging
        "why is my code not working", "how to debug python code",
        "what does this error mean", "how to fix module not found error",
        "how to install requirements", "what is virtual environment",
        "how to use git", "how to clone a repository",
        "what is docker", "how to run flask app",

        # Farewells
        "bye", "goodbye", "see you", "bye bye", "take care", "catch you later"
    ]

    for episode in range(200):
        user_msg = random.choice(training_data)
        state = agent.featurize(user_msg)
        action = agent.choose_action(state)
        bot_reply = actions[action]
        reward = simulated_reward(user_msg, bot_reply)
        next_state = agent.featurize("end")  # stateless
        agent.update(state, action, reward, next_state)

    print("✅ Training completed (simulated)")

    # Save trained model
    agent.save()

    # ------------------------------
    # Interactive chat
    # ------------------------------
    print("\n🤖 RL Chatbot is ready! Type 'quit' to exit.")
    agent.load()

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["quit", "exit"]:
            break

        state = agent.featurize(user_input)
        action = agent.choose_action(state)
        bot_reply = actions[action]
        print(f"Bot: {bot_reply}")

        # Get human feedback (reward)
        try:
            reward = int(input("Rate this reply (-5 to 5): "))
        except ValueError:
            reward = 0  # default if invalid
        next_state = agent.featurize("end")
        agent.update(state, action, reward, next_state)
        agent.save()

    print("💾 Chatbot model updated and saved.")
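
Run as a script (python Reinforcement_simple_model.py), the file performs the 200-episode simulated phase, writes rl_chatbot.pkl, and then enters the interactive loop, where every rating you type triggers another Q-update and re-saves the table. One limitation of the hashed featurize() is that every distinct sentence becomes its own state, so feedback on "hi there" never transfers to "hi". The sketch below shows one way to coarsen the state space with keyword buckets; keyword_state and its word lists are illustrative additions, not part of the committed file, and would need tuning against the real action set.

import re

def keyword_state(text):
    """Map free text to one of a few coarse intent buckets (sketch only)."""
    words = set(re.findall(r"[a-z']+", text.lower()))
    if words & {"hello", "hi", "hey", "morning", "evening"}:
        return "greet"
    if words & {"bye", "goodbye", "later", "care"}:
        return "farewell"
    if words & {"help", "stuck", "error", "debug", "fix"}:
        return "help"
    return "other"

# Possible drop-in use: call keyword_state(user_input) wherever the script
# currently calls agent.featurize(user_input), so that similar messages
# share a single row of Q-values and human feedback generalizes.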