abhilash88 committed on
Commit
1b298f7
·
verified ·
1 Parent(s): 34b38de

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +344 -0
model.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Complete Age and Gender Prediction Model with Pipeline Support
3
+ Author: Abhilash Sahoo
4
+ License: Apache 2.0
5
+
6
+ This file provides both manual usage and Hugging Face pipeline support.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import (
12
+ ViTModel,
13
+ ViTImageProcessor,
14
+ PreTrainedModel,
15
+ PretrainedConfig,
16
+ ImageClassificationPipeline
17
+ )
18
+ from PIL import Image
19
+ import numpy as np
20
+ from typing import Union, Dict, Any, List
21
+ import requests
22
+ from io import BytesIO
23
+
24
+
25
class AgeGenderConfig(PretrainedConfig):
    """Hyper-parameter container for AgeGenderViTModel.

    Args:
        vit_model_name: Checkpoint name of the ViT backbone to load.
        hidden_size: Input width of both prediction heads.
        intermediate_size: Width of the first hidden layer of each head.
        final_size: Width of the second hidden layer of each head.
        dropout_rate: Dropout probability used inside both heads.
        num_age_classes: Stored on the config; not read by any code
            visible in this file.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "age-gender-vit"

    def __init__(
        self,
        vit_model_name="google/vit-base-patch16-224",
        hidden_size=768,
        intermediate_size=256,
        final_size=64,
        dropout_rate=0.1,
        num_age_classes=100,
        **kwargs
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Backbone checkpoint.
        self.vit_model_name = vit_model_name

        # Layer widths shared by the age and gender heads.
        self.hidden_size, self.intermediate_size, self.final_size = (
            hidden_size,
            intermediate_size,
            final_size,
        )

        # Regularisation and remaining settings.
        self.dropout_rate = dropout_rate
        self.num_age_classes = num_age_classes
46
+
47
+
48
class AgeGenderViTModel(PreTrainedModel):
    """ViT backbone with two small MLP heads: one for age, one for gender.

    Both heads take the backbone's pooled output and share the same layout
    (``hidden_size -> intermediate_size -> final_size -> 1``):

    - ``age_head`` ends in a plain linear layer (regression value).
    - ``gender_head`` ends in a sigmoid, so its output is already a
      probability in [0, 1] even though it is exposed as ``gender_logits``.
    """

    config_class = AgeGenderConfig

    def __init__(self, config=None):
        """Build the backbone and both heads; a default config is used if none is given."""
        config = AgeGenderConfig() if config is None else config
        super().__init__(config)

        # Backbone is loaded from the pretrained checkpoint named in the config.
        self.vit = ViTModel.from_pretrained(config.vit_model_name)

        # Age regression head: Linear/ReLU/Dropout x2, then Linear -> 1.
        self.age_head = nn.Sequential(*self._mlp_layers(config))

        # Gender head: same stack followed by a sigmoid.
        self.gender_head = nn.Sequential(*self._mlp_layers(config), nn.Sigmoid())

        # Placeholder layer kept for pipeline compatibility; forward() never calls it.
        self.classifier = nn.Linear(2, 2)

    @staticmethod
    def _mlp_layers(config):
        """Return the layer list hidden -> intermediate -> final -> 1 for one head."""
        return [
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.intermediate_size, config.final_size),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.final_size, 1),
        ]

    def forward(self, pixel_values, **kwargs):
        """Run the backbone and both heads.

        Args:
            pixel_values: Image tensor passed straight to the ViT backbone.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            An ad-hoc object with three attributes: ``age_logits`` (age head
            output), ``gender_logits`` (sigmoid output of the gender head) and
            ``logits`` (the two concatenated along dim 1).
        """
        features = self.vit(pixel_values=pixel_values).pooler_output

        age = self.age_head(features)
        gender = self.gender_head(features)

        # 'logits' exists only so generic pipeline code finds the attribute.
        fields = {
            'logits': torch.cat([age, gender], dim=1),
            'age_logits': age,
            'gender_logits': gender,
        }
        return type('ModelOutput', (), fields)()
111
+
112
+
113
class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
    """Image-classification pipeline whose post-processing reports age and gender."""

    def postprocess(self, model_outputs, top_k=1, **kwargs):
        """Turn raw model outputs into a single-element list of predictions.

        Args:
            model_outputs: Object exposing ``age_logits`` and ``gender_logits``
                (or a list whose first element does).
            top_k: Accepted for signature compatibility; not used.
            **kwargs: Ignored.

        Returns:
            A one-item list holding a dict with the label, score, age, gender
            and the rounded gender probabilities.
        """
        # Only the first entry is used when a list is supplied.
        outputs = model_outputs[0] if isinstance(model_outputs, list) else model_outputs

        age_tensor, gender_tensor = outputs.age_logits, outputs.gender_logits

        # Age is clamped to [0, 100] and truncated to an integer.
        age = int(torch.clamp(age_tensor, 0, 100).item())

        # The gender value is the probability of "Female".
        gender_prob = gender_tensor.item()
        if gender_prob > 0.5:
            gender = "Female"
            confidence = gender_prob
        else:
            gender = "Male"
            confidence = 1 - gender_prob

        prediction = {
            "label": f"{age} years, {gender}",
            "score": confidence,
            "age": age,
            "gender": gender,
            "gender_confidence": round(confidence, 3),
            "gender_probability_female": round(gender_prob, 3),
            "gender_probability_male": round(1 - gender_prob, 3)
        }
        return [prediction]
143
+
144
+
145
def create_model_and_processor():
    """Build the model with its published weights plus the matching image processor.

    Downloads the state dict from the Hugging Face repo, loads it into a
    freshly constructed AgeGenderViTModel and switches the model to eval mode.

    Returns:
        tuple: (model, processor)
    """
    weights_url = "https://huggingface.co/abhilash88/age-gender-prediction/resolve/main/pytorch_model.bin"

    net = AgeGenderViTModel()

    # Weights are always mapped to CPU here; callers move the model later.
    state = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    net.load_state_dict(state)
    net.eval()

    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    return net, image_processor
165
+
166
+
167
def _load_rgb_image(image: Union[str, Image.Image, np.ndarray]) -> Image.Image:
    """Return *image* as an RGB PIL image, fetching it from a URL or path if needed."""
    if isinstance(image, str):
        if image.startswith(('http://', 'https://')):
            # A timeout prevents hanging forever on an unresponsive host, and
            # raise_for_status surfaces HTTP errors instead of handing an
            # error page to PIL.
            response = requests.get(image, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        # Context manager closes the file handle once pixels are converted.
        with Image.open(image) as local_image:
            return local_image.convert('RGB')
    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert('RGB')
    if isinstance(image, Image.Image):
        # Normalise RGBA / grayscale / palette images to three channels.
        return image.convert('RGB')
    raise ValueError(f"Unsupported image type: {type(image)}")


def predict_age_gender(
    image: Union[str, Image.Image, np.ndarray],
    model=None,
    processor=None,
    device='auto'
) -> Dict[str, Any]:
    """
    Predict age and gender from an image (Manual approach)

    Args:
        image: Image path, URL, PIL Image, or numpy array
        model: Pre-loaded model (optional). A supplied model is always used,
            even when ``processor`` is omitted.
        processor: Pre-loaded processor (optional). A supplied processor is
            always used, even when ``model`` is omitted.
        device: Device to use ('auto', 'cpu', 'cuda'); 'auto' picks CUDA
            when available.

    Returns:
        Dictionary with keys ``age``, ``gender``, ``gender_confidence``,
        ``gender_probability_female``, ``gender_probability_male`` and
        ``summary``.

    Raises:
        ValueError: If ``image`` is not a str, PIL image or numpy array.
        requests.RequestException: If a URL cannot be downloaded.
    """
    # Resolve the image first so bad input fails before any model download.
    image = _load_rgb_image(image)

    # Auto-detect device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Only create what the caller did not supply; never discard their objects.
    if model is None:
        model, default_processor = create_model_and_processor()
        if processor is None:
            processor = default_processor
    elif processor is None:
        processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    model = model.to(device)

    # Process image
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        age_pred = outputs.age_logits
        gender_pred = outputs.gender_logits

    # Age is clamped to [0, 100] and truncated; gender_pred is P(Female).
    age = int(torch.clamp(age_pred, 0, 100).item())
    gender_prob = gender_pred.item()
    gender = "Female" if gender_prob > 0.5 else "Male"
    confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob

    return {
        "age": age,
        "gender": gender,
        "gender_confidence": round(confidence, 3),
        "gender_probability_female": round(gender_prob, 3),
        "gender_probability_male": round(1 - gender_prob, 3),
        "summary": f"{age} years, {gender} ({confidence:.1%} confidence)"
    }
233
+
234
+
235
def predict_age_gender_pipeline(image_input: Union[str, Image.Image]) -> Dict[str, Any]:
    """
    Predict using Hugging Face pipeline (requires proper repo setup)

    If the pipeline cannot be created or run, the error is printed and the
    manual ``predict_age_gender`` path is used instead.

    Args:
        image_input: Image path, URL, or PIL Image

    Returns:
        Dictionary with predictions. Both the pipeline path and the fallback
        path provide ``age``, ``gender``, ``confidence`` and ``summary``; the
        fallback additionally carries the extra keys returned by
        ``predict_age_gender``.
    """
    from transformers import pipeline

    try:
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository; only use it with repos you trust.
        classifier = pipeline(
            "image-classification",
            model="abhilash88/age-gender-prediction",
            trust_remote_code=True
        )

        # Make prediction
        result = classifier(image_input)[0]  # Get first result

        return {
            "age": result["age"],
            "gender": result["gender"],
            "confidence": result["gender_confidence"],
            "summary": result["label"]
        }

    except Exception as e:
        # Deliberate best-effort: any pipeline failure falls back to the manual path.
        print(f"Pipeline failed: {e}")
        print("Falling back to manual approach...")
        fallback = predict_age_gender(image_input)
        # Keep the return schema consistent with the pipeline path above.
        fallback.setdefault("confidence", fallback["gender_confidence"])
        return fallback
269
+
270
+
271
+ # Simple usage functions
272
def simple_predict(image_path: str) -> str:
    """
    One-call convenience wrapper that yields only the human-readable summary.

    Args:
        image_path: Path to image or URL

    Returns:
        The ``summary`` string produced by ``predict_age_gender``
    """
    return predict_age_gender(image_path)["summary"]
284
+
285
+
286
def batch_predict(image_list: List[str]) -> List[Dict[str, Any]]:
    """
    Predict on multiple images

    The model and processor are loaded once and reused for every image.
    A failure on one image does not abort the batch: that image gets an
    entry with ``error`` set and ``age`` / ``gender`` set to ``None``.

    Args:
        image_list: List of image paths or URLs

    Returns:
        List of prediction dictionaries, one per input and in input order,
        each carrying an ``image_path`` key. Empty input returns ``[]``
        without loading the model.
    """
    # Nothing to do: skip the expensive model download/load entirely.
    if not image_list:
        return []

    # Load model once for efficiency
    model, processor = create_model_and_processor()

    results = []
    for image_path in image_list:
        try:
            result = predict_age_gender(image_path, model, processor)
            result["image_path"] = image_path
            results.append(result)
        except Exception as e:
            # Best-effort batch: record the failure and keep going.
            results.append({
                "image_path": image_path,
                "error": str(e),
                "age": None,
                "gender": None
            })

    return results
314
+
315
+
316
+ # Example usage and testing
317
if __name__ == "__main__":
    # Smoke test: runs each public entry point against one remote test image.
    print("🚀 Testing Age-Gender Prediction Model...")

    try:
        # Test simple prediction
        print("📝 Testing simple prediction...")
        test_url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=300&h=300&fit=crop&crop=face"

        # Method 1: Simple string output
        simple_result = simple_predict(test_url)
        print(f"✅ Simple: {simple_result}")

        # Method 2: Detailed output
        detailed_result = predict_age_gender(test_url)
        print(f"✅ Detailed: {detailed_result}")

        # Method 3: Try pipeline (may fail if repo not updated)
        try:
            pipeline_result = predict_age_gender_pipeline(test_url)
            print(f"✅ Pipeline: {pipeline_result}")
        except Exception as pipeline_error:
            # Narrowed from a bare except so Ctrl-C / SystemExit still propagate.
            print("❌ Pipeline not working yet (needs repo file updates)")
            print(f"   Reason: {pipeline_error}")

        print("🎉 Model is working perfectly!")

    except Exception as e:
        print(f"❌ Error: {e}")
        print("Note: This test requires internet connection for test image")