# abhilash88 - "Update model.py" (commit 8a420df, verified)
import torch
import torch.nn as nn
from transformers import ViTModel, ViTPreTrainedModel
class AgeGenderViTModel(ViTPreTrainedModel):
    """ViT backbone followed by two small MLP heads: one for age, one for gender.

    ``forward`` returns ``{"logits": tensor}``. Column 0 of the tensor is the
    raw output of the age head; column 1 is the output of the gender head
    *after* a sigmoid, so despite the key name it is a value in (0, 1) rather
    than an unbounded logit.
    """

    def __init__(self, config):
        """Build the backbone and the prediction heads from ``config``.

        Args:
            config: Model configuration; ``config.hidden_size`` sets the input
                width of every head.
        """
        super().__init__(config)

        # The pooling layer is disabled; forward() picks position 0 of the
        # backbone's sequence output itself.
        self.vit = ViTModel(config, add_pooling_layer=False)

        def make_trunk():
            # hidden_size -> 256 -> 64 -> 1, with ReLU + Dropout(0.1) after
            # each of the first two linear layers.
            return [
                nn.Linear(config.hidden_size, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, 1),
            ]

        self.age_head = nn.Sequential(*make_trunk())
        # Same trunk as the age head, with a trailing sigmoid.
        self.gender_head = nn.Sequential(*make_trunk(), nn.Sigmoid())

        # NOTE(review): not used by forward(); presumably retained so that
        # checkpoint parameter names keep matching - confirm before removing.
        self.classifier = nn.Linear(config.hidden_size, 2)

        self.post_init()

    def forward(self, pixel_values=None, **kwargs):
        """Run the backbone and both heads.

        Args:
            pixel_values: Image tensor handed to the ViT backbone.
            **kwargs: Passed through unchanged to the ViT backbone.

        Returns:
            ``{"logits": tensor}`` where the tensor is the age-head output and
            the gender-head output concatenated along dim 1 (age first).
        """
        backbone_out = self.vit(pixel_values=pixel_values, **kwargs)

        # First backbone output, position 0 along the sequence axis
        # (the [CLS] token for ViT).
        cls_embedding = backbone_out[0][:, 0]

        age = self.age_head(cls_embedding)
        gender = self.gender_head(cls_embedding)

        return {"logits": torch.cat((age, gender), dim=1)}
# Convenience helpers for end users (kept at the end of model.py).
_DEFAULT_MODEL_ID = "abhilash88/age-gender-prediction"
_default_classifier = None


def _get_default_classifier():
    """Return the default pipeline, building it on first use only."""
    global _default_classifier
    if _default_classifier is None:
        # Imported lazily so importing this module stays cheap.
        from transformers import pipeline

        # NOTE: trust_remote_code=True runs Python code shipped in the model
        # repository; only use it with repositories you trust.
        _default_classifier = pipeline(
            "image-classification",
            model=_DEFAULT_MODEL_ID,
            trust_remote_code=True,
        )
    return _default_classifier


def predict_age_gender(image_path, *, classifier=None):
    """Predict age and gender for a single image.

    The default pipeline is created once and reused on later calls, instead
    of reloading the model for every prediction.

    Args:
        image_path: Path to image file or URL.
        classifier: Optional callable used in place of the default pipeline.
            It is called as ``classifier(image_path)`` and must return a
            sequence whose first item is a mapping with the keys ``'age'``,
            ``'gender'`` and ``'gender_confidence'``.

    Returns:
        Dictionary with ``'age'``, ``'gender'``, ``'confidence'`` and a
        human-readable ``'summary'`` string.

    Raises:
        IndexError: If the classifier returns an empty sequence.
        KeyError: If the first result lacks one of the expected keys.
    """
    if classifier is None:
        classifier = _get_default_classifier()

    # Only the first result is used.
    # NOTE(review): the 'age' / 'gender' / 'gender_confidence' keys are
    # presumably produced by the model's remote pipeline code - confirm.
    result = classifier(image_path)[0]

    age = result['age']
    gender = result['gender']
    confidence = result['gender_confidence']

    return {
        'age': age,
        'gender': gender,
        'confidence': confidence,
        'summary': f"{age} years, {gender} ({confidence:.1%} confidence)",
    }
def simple_predict(image_path):
    """Return only the one-line text summary for an image.

    Delegates to ``predict_age_gender`` and extracts its ``'summary'`` entry.

    Args:
        image_path: Path to image file or URL.

    Returns:
        The summary string from the prediction dictionary.
    """
    return predict_age_gender(image_path)['summary']