mlp-fashion-mnist
A multi-layer perceptron (MLP) trained on the Fashion-MNIST dataset.
It is a PyTorch adaptation of the TensorFlow model in Chapter 10 of Aurélien Géron's book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow'.
Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/mlp_fashion_mnist.ipynb
Experiment tracking: https://wandb.ai/sadhaklal/mlp-fashion-mnist
Usage
```
!pip install -q datasets
```
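```python
from datasets import load_dataset

# Load the train/test splits from the Hugging Face Hub.
fashion_mnist = load_dataset("zalando-datasets/fashion_mnist")
features = fashion_mnist['train'].features
# Map integer class ids (0-9) to human-readable label names.
id2label = {idx: label for idx, label in enumerate(features['label'].names)}
```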
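The `id2label` dictionary is built by enumerating the label names in class-id order. The sketch below does the same offline and adds the inverse `label2id` lookup. It assumes the standard Fashion-MNIST class order; `CLASS_NAMES` is a hypothetical hard-coded list, and the exact strings returned by `features['label'].names` on the Hub may be spelled slightly differently.

```python
# Hypothetical hard-coded class names in the standard Fashion-MNIST order (ids 0-9).
# The strings in the Hub dataset's features['label'].names may differ in spelling.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Integer id -> human-readable name, built the same way as in the usage code.
id2label = {idx: name for idx, name in enumerate(CLASS_NAMES)}
# Inverse mapping: name -> integer id.
label2id = {name: idx for idx, name in id2label.items()}

print(id2label[9])          # Ankle boot
print(label2id["Sneaker"])  # 7
```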
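```python
import torch
import torchvision.transforms.v2 as v2

# Convert a PIL image to a float32 tensor of shape (1, 28, 28) with values in [0, 1].
tfms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

device = torch.device("cpu")
```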
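To make the preprocessing concrete, here is a NumPy sketch of what `ToImage()` followed by `ToDtype(torch.float32, scale=True)` does to one grayscale image: add a leading channel axis and rescale uint8 pixels from [0, 255] to [0.0, 1.0]. The helper `to_float_image` and the synthetic input are hypothetical and only mirror the torchvision behaviour for this simple case; they are not part of the model's code.

```python
import numpy as np

def to_float_image(pixels: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of ToImage() + ToDtype(float32, scale=True)
    for a single (H, W) uint8 grayscale image."""
    if pixels.dtype != np.uint8:
        raise TypeError("expected uint8 pixels")
    # ToImage: add a leading channel axis, (H, W) -> (1, H, W).
    chw = pixels[np.newaxis, :, :]
    # ToDtype(..., scale=True): map the uint8 range [0, 255] onto [0.0, 1.0].
    return chw.astype(np.float32) / 255.0

# Synthetic 28x28 image standing in for a Fashion-MNIST sample.
fake = np.zeros((28, 28), dtype=np.uint8)
fake[0, 0] = 255
x = to_float_image(fake)
print(x.shape, x.dtype)  # (1, 28, 28) float32
```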
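```python
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

# The mixin adds from_pretrained(), save_pretrained() and push_to_hub() to a plain nn.Module.
class MLP(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # Two hidden layers (300 and 100 units) and a 10-way output layer.
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        # Flatten each 28x28 image into a 784-dimensional vector.
        x = x.view(-1, 28 * 28)
        act = torch.relu(self.fc1(x))
        act = torch.relu(self.fc2(act))
        # Return raw logits; softmax is applied at inference time.
        return self.fc3(act)
```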
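The `MLP` class is a plain 784 → 300 → 100 → 10 stack with ReLU on the hidden layers. As a sketch of that architecture, the NumPy code below reproduces the same forward computation (`nn.Linear` computes `x @ W.T + b`) and counts the trainable parameters. The weights here are random and untrained, so this only illustrates shapes and arithmetic; it does not reproduce the published checkpoint. `init_params` and `forward` are hypothetical helper names.

```python
import numpy as np

# Layer sizes read off the MLP class: 784 -> 300 -> 100 -> 10.
LAYER_SIZES = [28 * 28, 300, 100, 10]

def init_params(rng: np.random.Generator):
    """Random, untrained weights in nn.Linear's layout: W has shape (out, in)."""
    return [
        (rng.standard_normal((n_out, n_in)).astype(np.float32) * 0.01,
         np.zeros(n_out, dtype=np.float32))
        for n_in, n_out in zip(LAYER_SIZES[:-1], LAYER_SIZES[1:])
    ]

def forward(params, x: np.ndarray) -> np.ndarray:
    # Equivalent of x.view(-1, 28 * 28): flatten every image to a 784-vector.
    act = x.reshape(-1, 28 * 28)
    for i, (w, b) in enumerate(params):
        act = act @ w.T + b            # affine layer, as in nn.Linear
        if i < len(params) - 1:
            act = np.maximum(act, 0)   # ReLU on the two hidden layers only
    return act                         # raw logits, no softmax

params = init_params(np.random.default_rng(0))
logits = forward(params, np.zeros((2, 1, 28, 28), dtype=np.float32))
# (784*300 + 300) + (300*100 + 100) + (100*10 + 10) = 235,500 + 30,100 + 1,010
n_params = sum(w.size + b.size for w, b in params)
print(logits.shape, n_params)  # (2, 10) 266610
```

Because the output layer returns logits rather than probabilities, a softmax has to be applied explicitly whenever probabilities are needed, as the usage code does.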
```python
model = MLP.from_pretrained("sadhaklal/mlp-fashion-mnist")
model.to(device)
```
```python
example = fashion_mnist['test'][0]

import matplotlib.pyplot as plt

plt.imshow(example['image'], cmap='gray')
print(f"Ground truth: {id2label[example['label']]}")
```
```python
img = tfms(example['image'])
# Add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
x_batch = img.unsqueeze(0)

model.eval()
x_batch = x_batch.to(device)
with torch.no_grad():
    logits = model(x_batch)
    proba = torch.softmax(logits, dim=-1)

# Highest class probability and its index for each image in the batch.
confidence, pred = proba.max(dim=-1)
print(f"Predicted class: {id2label[pred[0].item()]}")
print(f"Predicted confidence: {round(confidence[0].item(), 4)}")
```
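The last step turns the 10 logits into probabilities with a softmax and takes the maximum, which gives both the confidence and the predicted class index. The pure-Python sketch below spells that out with a numerically stable softmax (subtracting the largest logit before exponentiating). The logits are made up for illustration and are not output from the real model; `softmax` and `predict` are hypothetical helper names.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)                       # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits):
    """Mirror of proba.max(dim=-1) for one sample: (confidence, predicted index)."""
    proba = softmax(logits)
    pred = max(range(len(proba)), key=proba.__getitem__)
    return proba[pred], pred

# Made-up logits for a 10-class output (not produced by the real model).
logits = [0.1, -1.2, 0.3, 0.0, -0.5, 2.0, 0.2, 1.5, -0.3, 4.0]
confidence, pred = predict(logits)
print(pred)  # 9
print(round(confidence, 4))
```

Since softmax is monotonic, the predicted index is the same as the argmax of the raw logits; the softmax is only needed to report a probability-like confidence.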
Metric
Accuracy on the test set: 0.8829
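Accuracy here is the fraction of test images whose predicted class matches the ground-truth label. The sketch below shows that computation on toy, made-up predictions; the actual evaluation loop lives in the linked notebook. It also converts the reported figure into a count, assuming the metric was computed over the full 10,000-image Fashion-MNIST test split.

```python
def accuracy(preds, labels):
    """Fraction of predictions that match the ground-truth labels."""
    if len(preds) != len(labels) or not labels:
        raise ValueError("preds and labels must be non-empty and equal length")
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)

# Toy example with made-up predictions: 4 of 5 are correct.
print(accuracy([9, 2, 1, 1, 6], [9, 2, 1, 1, 4]))  # 0.8

# Assuming the full 10,000-image test split, 0.8829 means 8,829 correct predictions.
print(round(0.8829 * 10_000))  # 8829
```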
This model has been pushed to the Hub using the PyTorchModelHubMixin integration.