abhilash88 commited on
Commit
57c0617
·
verified ·
1 Parent(s): 1648c07

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +62 -109
model.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
- Age-Gender Prediction Model with Working One-Liner Pipeline
3
- Based on Hugging Face official documentation and best practices
4
 
5
- Usage:
6
  from transformers import pipeline
7
  classifier = pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
8
  result = classifier("image.jpg")
@@ -14,131 +14,91 @@ import torch.nn as nn
14
  from transformers import (
15
  ViTModel,
16
  ViTImageProcessor,
17
- PreTrainedModel,
18
- PretrainedConfig,
19
- ImageClassificationPipeline,
20
- Pipeline
21
  )
22
  from PIL import Image
23
  import numpy as np
24
 
25
 
26
- class AgeGenderConfig(PretrainedConfig):
27
- """Configuration class following HuggingFace standards"""
28
- model_type = "age-gender-vit"
29
 
30
- def __init__(self, **kwargs):
31
- super().__init__(**kwargs)
32
- self.vit_model_name = kwargs.get("vit_model_name", "google/vit-base-patch16-224")
33
- self.hidden_size = kwargs.get("hidden_size", 768)
34
- self.intermediate_size = kwargs.get("intermediate_size", 256)
35
- self.final_size = kwargs.get("final_size", 64)
36
- self.dropout_rate = kwargs.get("dropout_rate", 0.1)
37
-
38
-
39
- class AgeGenderViTModel(PreTrainedModel):
40
- """Age-Gender ViT Model following HuggingFace standards"""
41
- config_class = AgeGenderConfig
42
-
43
- def __init__(self, config=None):
44
- if config is None:
45
- config = AgeGenderConfig()
46
  super().__init__(config)
47
 
48
- self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
49
 
50
  # Age head: 768 → 256 → 64 → 1
51
  self.age_head = nn.Sequential(
52
- nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.1),
53
- nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.1),
 
 
 
 
54
  nn.Linear(64, 1)
55
  )
56
 
57
  # Gender head: 768 → 256 → 64 → 1
58
  self.gender_head = nn.Sequential(
59
- nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.1),
60
- nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.1),
61
- nn.Linear(64, 1), nn.Sigmoid()
 
 
 
 
 
62
  )
63
 
64
- # Required for some pipeline compatibility
65
- self.classifier = nn.Linear(2, 2)
 
 
 
66
 
67
- def forward(self, pixel_values, **kwargs):
68
  """Forward pass returning logits for pipeline"""
69
- vit_outputs = self.vit(pixel_values=pixel_values)
70
- pooled_output = vit_outputs.pooler_output
71
 
 
 
 
 
 
 
 
 
72
  age_output = self.age_head(pooled_output)
73
  gender_output = self.gender_head(pooled_output)
74
 
75
- # Create concatenated logits for pipeline processing
76
  logits = torch.cat([age_output, gender_output], dim=1)
 
77
  return {"logits": logits}
78
 
79
 
80
  class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
81
- """
82
- Custom pipeline following HuggingFace documentation standards
83
- Reference: https://huggingface.co/docs/transformers/add_new_pipeline
84
- """
85
-
86
- def _sanitize_parameters(self, **kwargs):
87
- """Sanitize parameters following HF guidelines"""
88
- preprocess_kwargs = {}
89
- postprocess_kwargs = {}
90
-
91
- # Handle any custom parameters here if needed
92
- if "top_k" in kwargs:
93
- postprocess_kwargs["top_k"] = kwargs["top_k"]
94
-
95
- return preprocess_kwargs, {}, postprocess_kwargs
96
-
97
- def preprocess(self, inputs, **kwargs):
98
- """Preprocess inputs following HF guidelines"""
99
- # Handle different input types
100
- if isinstance(inputs, str):
101
- if inputs.startswith(('http://', 'https://')):
102
- import requests
103
- from io import BytesIO
104
- response = requests.get(inputs)
105
- inputs = Image.open(BytesIO(response.content)).convert('RGB')
106
- else:
107
- inputs = Image.open(inputs).convert('RGB')
108
- elif isinstance(inputs, np.ndarray):
109
- inputs = Image.fromarray(inputs).convert('RGB')
110
- elif not isinstance(inputs, Image.Image):
111
- inputs = inputs.convert('RGB')
112
-
113
- # Use the model's image processor
114
- return super().preprocess(inputs, **kwargs)
115
-
116
- def _forward(self, model_inputs, **kwargs):
117
- """Forward pass following HF guidelines"""
118
- return self.model(**model_inputs)
119
 
120
  def postprocess(self, model_outputs, top_k=1, **kwargs):
121
- """
122
- Postprocess model outputs to age/gender format
123
- This is where LABEL_0/LABEL_1 gets converted to real predictions
124
- """
125
- # Extract logits from model output
126
- logits = model_outputs["logits"]
127
 
128
- # Get age and gender from concatenated logits
129
- age_raw = logits[0, 0].item() # First element is age
130
- gender_raw = logits[0, 1].item() # Second element is gender
 
131
 
132
- # Apply the scaling we discovered through testing
133
- # age_raw ~0.7 maps to realistic ages using this formula:
134
  age = int(max(18, min(70, ((age_raw - 1.5) / 1.0) * 50 + 20)))
135
 
136
- # Process gender (already sigmoid'd in the model)
137
  gender_prob = gender_raw
138
  gender = "Female" if gender_prob > 0.5 else "Male"
139
  confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
140
 
141
- # Return in the standard pipeline format with age/gender keys
142
  return [{
143
  "label": f"{age} years, {gender}",
144
  "score": confidence,
@@ -150,13 +110,13 @@ class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
150
  }]
151
 
152
 
153
- # Manual functions for advanced users
154
- def predict_age_gender(image_path):
155
- """Manual prediction function for advanced usage"""
156
  import torch.nn as nn
157
  from transformers import ViTImageProcessor, ViTModel
158
 
159
- class SimpleAgeGenderModel(nn.Module):
160
  def __init__(self):
161
  super().__init__()
162
  self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
@@ -181,7 +141,7 @@ def predict_age_gender(image_path):
181
  return age_output, gender_output
182
 
183
  # Load model
184
- model = SimpleAgeGenderModel()
185
  model_url = "https://huggingface.co/abhilash88/age-gender-prediction/resolve/main/pytorch_model.bin"
186
  weights = torch.hub.load_state_dict_from_url(model_url, map_location='cpu')
187
  filtered_weights = {k: v for k, v in weights.items() if not k.startswith('classifier.')}
@@ -204,7 +164,7 @@ def predict_age_gender(image_path):
204
  with torch.no_grad():
205
  age_raw, gender_raw = model(inputs["pixel_values"])
206
 
207
- # Apply same scaling as pipeline
208
  age_val = age_raw.item()
209
  age = int(max(18, min(70, ((age_val - 1.5) / 1.0) * 50 + 20)))
210
 
@@ -220,29 +180,22 @@ def predict_age_gender(image_path):
220
  }
221
 
222
 
223
- # Test function
224
  if __name__ == "__main__":
225
- print("🧪 Testing Age-Gender Prediction Model...")
226
 
227
  try:
228
- # Test the one-liner pipeline
229
  from transformers import pipeline
230
 
 
231
  classifier = pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
232
-
233
- # Test with a sample URL
234
  test_url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=300"
235
  result = classifier(test_url)
236
 
237
- print(f"✅ Pipeline result: {result[0]}")
238
- print(f"✅ Age: {result[0]['age']}, Gender: {result[0]['gender']}")
239
-
240
- # Test manual approach
241
- manual_result = predict_age_gender(test_url)
242
- print(f"✅ Manual result: {manual_result['summary']}")
243
 
244
- print("🎉 TRUE ONE-LINER WORKING!")
 
 
245
 
246
  except Exception as e:
247
- print(f"❌ Error: {e}")
248
- print("Upload the corrected files to enable the one-liner")
 
1
  """
2
+ Age-Gender Prediction Model - Simplified Working Version
3
+ Uses standard ViT model_type to avoid CONFIG_MAPPING issues
4
 
5
+ EXACT Usage:
6
  from transformers import pipeline
7
  classifier = pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
8
  result = classifier("image.jpg")
 
14
  from transformers import (
15
  ViTModel,
16
  ViTImageProcessor,
17
+ ViTPreTrainedModel,
18
+ ViTConfig,
19
+ ImageClassificationPipeline
 
20
  )
21
  from PIL import Image
22
  import numpy as np
23
 
24
 
25
+ class AgeGenderViTModel(ViTPreTrainedModel):
26
+ """Age-Gender ViT Model using standard ViT architecture"""
 
27
 
28
+ def __init__(self, config):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  super().__init__(config)
30
 
31
+ self.vit = ViTModel(config, add_pooling_layer=False)
32
 
33
  # Age head: 768 → 256 → 64 → 1
34
  self.age_head = nn.Sequential(
35
+ nn.Linear(config.hidden_size, 256),
36
+ nn.ReLU(),
37
+ nn.Dropout(0.1),
38
+ nn.Linear(256, 64),
39
+ nn.ReLU(),
40
+ nn.Dropout(0.1),
41
  nn.Linear(64, 1)
42
  )
43
 
44
  # Gender head: 768 → 256 → 64 → 1
45
  self.gender_head = nn.Sequential(
46
+ nn.Linear(config.hidden_size, 256),
47
+ nn.ReLU(),
48
+ nn.Dropout(0.1),
49
+ nn.Linear(256, 64),
50
+ nn.ReLU(),
51
+ nn.Dropout(0.1),
52
+ nn.Linear(64, 1),
53
+ nn.Sigmoid()
54
  )
55
 
56
+ # Standard classifier for compatibility
57
+ self.classifier = nn.Linear(config.hidden_size, 2)
58
+
59
+ # Initialize weights
60
+ self.post_init()
61
 
62
+ def forward(self, pixel_values=None, **kwargs):
63
  """Forward pass returning logits for pipeline"""
 
 
64
 
65
+ # Get ViT outputs
66
+ outputs = self.vit(pixel_values=pixel_values, **kwargs)
67
+
68
+ # Use the last hidden state and pool it
69
+ sequence_output = outputs[0]
70
+ pooled_output = sequence_output[:, 0] # Use [CLS] token
71
+
72
+ # Get age and gender predictions
73
  age_output = self.age_head(pooled_output)
74
  gender_output = self.gender_head(pooled_output)
75
 
76
+ # Create logits for pipeline
77
  logits = torch.cat([age_output, gender_output], dim=1)
78
+
79
  return {"logits": logits}
80
 
81
 
82
  class AgeGenderImageClassificationPipeline(ImageClassificationPipeline):
83
+ """Custom pipeline that converts model outputs to age/gender"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def postprocess(self, model_outputs, top_k=1, **kwargs):
86
+ """Convert logits to age/gender predictions"""
 
 
 
 
 
87
 
88
+ # Extract logits
89
+ logits = model_outputs["logits"]
90
+ age_raw = logits[0, 0].item()
91
+ gender_raw = logits[0, 1].item()
92
 
93
+ # Apply scaling discovered through testing
 
94
  age = int(max(18, min(70, ((age_raw - 1.5) / 1.0) * 50 + 20)))
95
 
96
+ # Process gender
97
  gender_prob = gender_raw
98
  gender = "Female" if gender_prob > 0.5 else "Male"
99
  confidence = gender_prob if gender_prob > 0.5 else 1 - gender_prob
100
 
101
+ # Return standard pipeline format
102
  return [{
103
  "label": f"{age} years, {gender}",
104
  "score": confidence,
 
110
  }]
111
 
112
 
113
+ # Helper function for manual usage
114
+ def predict_age_gender_manual(image_path):
115
+ """Manual prediction without pipeline"""
116
  import torch.nn as nn
117
  from transformers import ViTImageProcessor, ViTModel
118
 
119
+ class SimpleModel(nn.Module):
120
  def __init__(self):
121
  super().__init__()
122
  self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
 
141
  return age_output, gender_output
142
 
143
  # Load model
144
+ model = SimpleModel()
145
  model_url = "https://huggingface.co/abhilash88/age-gender-prediction/resolve/main/pytorch_model.bin"
146
  weights = torch.hub.load_state_dict_from_url(model_url, map_location='cpu')
147
  filtered_weights = {k: v for k, v in weights.items() if not k.startswith('classifier.')}
 
164
  with torch.no_grad():
165
  age_raw, gender_raw = model(inputs["pixel_values"])
166
 
167
+ # Apply scaling
168
  age_val = age_raw.item()
169
  age = int(max(18, min(70, ((age_val - 1.5) / 1.0) * 50 + 20)))
170
 
 
180
  }
181
 
182
 
 
183
  if __name__ == "__main__":
184
+ print("🧪 Testing simplified Age-Gender model...")
185
 
186
  try:
 
187
  from transformers import pipeline
188
 
189
+ # Test pipeline
190
  classifier = pipeline("image-classification", model="abhilash88/age-gender-prediction", trust_remote_code=True)
 
 
191
  test_url = "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?w=300"
192
  result = classifier(test_url)
193
 
194
+ print(f"✅ Pipeline: Age {result[0]['age']}, Gender {result[0]['gender']}")
 
 
 
 
 
195
 
196
+ # Test manual
197
+ manual_result = predict_age_gender_manual(test_url)
198
+ print(f"✅ Manual: {manual_result['summary']}")
199
 
200
  except Exception as e:
201
+ print(f"❌ Error: {e}")