|
import functools

import torch
import torch.nn as nn

from transformers import ViTModel, ViTPreTrainedModel
|
|
|
class AgeGenderViTModel(ViTPreTrainedModel):
    """ViT backbone with two small MLP heads: one for age, one for gender.

    The embedding at sequence position 0 of the backbone's first output (the
    [CLS] position in ViT) feeds both heads. Each head ends in a single unit.
    The two results are concatenated column-wise into ``logits``: column 0 is
    the raw age-head output, and column 1 is the sigmoid-squashed gender-head
    output.
    """

    def __init__(self, config):
        super().__init__(config)

        # No pooling layer: forward() takes position 0 of the sequence output itself.
        self.vit = ViTModel(config, add_pooling_layer=False)

        width = config.hidden_size

        def build_head(*extra_layers):
            # hidden -> 256 -> 64 -> 1, with ReLU + Dropout(0.1) after each hidden
            # Linear. The layer order fixes the Sequential indices (Linear at 0, 3, 6),
            # which are part of the state_dict keys, so keep it unchanged.
            return nn.Sequential(
                nn.Linear(width, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, 1),
                *extra_layers,
            )

        self.age_head = build_head()
        self.gender_head = build_head(nn.Sigmoid())

        # NOTE(review): `classifier` is never used by forward(). It is kept so the
        # module's parameter names (and therefore saved/loaded checkpoints) stay the
        # same; confirm whether the published checkpoint contains it before removing.
        # Unused parameters can also trip DDP unless find_unused_parameters=True.
        self.classifier = nn.Linear(width, 2)

        self.post_init()

    def forward(self, pixel_values=None, **kwargs):
        """Run the backbone and both heads.

        Args:
            pixel_values: Image tensor passed to ``ViTModel``.
            **kwargs: Any further arguments, forwarded unchanged to ``ViTModel``.

        Returns:
            ``{"logits": tensor}`` where column 0 is the age-head output and
            column 1 is the gender-head sigmoid output. No loss is computed here.
        """
        backbone_out = self.vit(pixel_values=pixel_values, **kwargs)
        # Index 0 works for both ModelOutput and tuple returns (return_dict=False).
        cls_embedding = backbone_out[0][:, 0]
        # The tuple is evaluated left to right, so age_head still runs before gender_head.
        return {
            "logits": torch.cat(
                (self.age_head(cls_embedding), self.gender_head(cls_embedding)),
                dim=1,
            )
        }
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_age_gender_pipeline():
    """Build the hub image-classification pipeline once and reuse it afterwards."""
    from transformers import pipeline

    # SECURITY NOTE: trust_remote_code=True executes Python code downloaded from the
    # model repository. Consider pinning a `revision=` if the repo is not yours.
    return pipeline(
        "image-classification",
        model="abhilash88/age-gender-prediction",
        trust_remote_code=True,
    )


def predict_age_gender(image_path):
    """
    Simple one-liner function for age-gender prediction

    The underlying pipeline is created on the first call and cached for the rest
    of the process, so later calls skip model loading. The model stays in memory
    as a result.

    Args:
        image_path: Path to image file or URL

    Returns:
        Dictionary with 'age', 'gender', 'confidence' (the pipeline's
        'gender_confidence' value) and a formatted 'summary' string
    """
    classifier = _load_age_gender_pipeline()
    raw = classifier(image_path)
    # Assumes the remote pipeline returns a sequence whose first element is a dict
    # with 'age', 'gender' and 'gender_confidence' keys. That schema is defined by
    # the hub repo's custom code, not here — TODO confirm.
    result = raw[0]

    # The ':.1%' format presumes gender_confidence is a fraction in [0, 1].
    return {
        'age': result['age'],
        'gender': result['gender'],
        'confidence': result['gender_confidence'],
        'summary': f"{result['age']} years, {result['gender']} ({result['gender_confidence']:.1%} confidence)"
    }
|
|
|
|
|
def simple_predict(image_path):
    """Return only the human-readable summary string for *image_path*.

    Thin wrapper: runs ``predict_age_gender`` and keeps its 'summary' entry.
    """
    return predict_age_gender(image_path)['summary']