import os
import random
import argparse
import pandas as pd
import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms

from sklearn.metrics import accuracy_score, precision_score, recall_score

# Training hyper-parameters and reproducibility seed
SEED = 42
BATCH_SIZE = 8
EPOCHS = 5
LEARNING_RATE = 1e-4
VAL_RATIO = 0.2  # fraction of the records held out for validation (20%)

# Output artefacts: best PyTorch checkpoint and its ONNX export
MODEL_PTH = "melanoma_best_model.pt"
ONNX_PTH = "melanoma_best_model.onnx"

# Seed the torch, stdlib and NumPy random generators (in that order) with SEED.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(SEED)
del _seed_rng

class MelanomaDataset(Dataset):
    """Image dataset driven by a CSV file with ``path`` and ``is_Melanoma`` columns.

    ``path`` is an image path relative to ``root_dir`` (e.g. "all_images/xxx.jpg").
    ``is_Melanoma`` is the binary target, written as "True"/"False" or 1/0.

    Args:
        csv_file: Path of the CSV file, read once with ``pandas.read_csv``.
        root_dir: Directory that the ``path`` column is joined onto.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    @staticmethod
    def _parse_label(value):
        """Convert a raw ``is_Melanoma`` cell into 1 (positive) or 0 (negative).

        "true" (any case, surrounding whitespace ignored) and any value that
        parses as the number 1 are positive; everything else is negative.
        Numeric parsing matters because pandas may hand back ``1.0`` instead
        of ``1`` (for instance when the column is read as float), and a plain
        string comparison against "1" would mislabel those rows as negative.
        """
        text = str(value).strip().lower()
        if text == "true":
            return 1
        try:
            return int(float(text) == 1.0)
        except ValueError:
            # Not a number and not "true" (e.g. "false", "", arbitrary text).
            return 0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; ``label`` is the int 0 or 1."""
        row = self.df.iloc[idx]
        label = self._parse_label(row["is_Melanoma"])

        full_img_path = os.path.join(self.root_dir, row["path"])
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block exits.
        with Image.open(full_img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, performs a forward/backward/step cycle
    for every batch, and returns the sum of the per-batch loss values divided
    by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []

    for batch_images, batch_labels in dataloader:
        inputs = batch_images.to(device)
        # Labels are reshaped to a float column so they line up with the
        # model output passed to the criterion.
        targets = batch_labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without computing gradients.

    Returns a 4-tuple ``(avg_loss, accuracy, precision, recall)``:
    ``avg_loss`` is the sum of the per-batch loss values divided by
    ``len(dataloader)``; the other three come from scikit-learn, using
    predictions obtained by thresholding ``sigmoid(output)`` at 0.5.
    Precision and recall are computed with ``zero_division=0``.
    """
    model.eval()
    loss_sum = 0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            inputs = batch_images.to(device)
            targets = batch_labels.float().unsqueeze(1).to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            # Hard 0/1 decision: probability strictly above 0.5 is positive.
            hard_preds = (torch.sigmoid(logits) > 0.5).long()
            predicted.extend(hard_preds.cpu().numpy().flatten())
            expected.extend(targets.cpu().numpy().flatten())

    mean_loss = loss_sum / len(dataloader)
    accuracy = accuracy_score(expected, predicted)
    precision = precision_score(expected, predicted, zero_division=0)
    recall = recall_score(expected, predicted, zero_division=0)

    return mean_loss, accuracy, precision, recall

def main():
    """Train a ResNet-18 binary melanoma classifier and export it to ONNX.

    Reads ``--unified_dir`` from the command line; the directory must contain
    ``unified_labels.csv`` (image paths in the CSV are resolved relative to
    the same directory). The records are split into train/validation parts
    according to ``VAL_RATIO``, the model is trained for ``EPOCHS`` epochs,
    the weights with the lowest validation loss are written to ``MODEL_PTH``,
    and the final model is exported to ``ONNX_PTH``.

    Raises:
        FileNotFoundError: if ``unified_labels.csv`` is missing.
        ValueError: if the dataset is too small to yield a non-empty train
            and validation split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--unified_dir", type=str, required=True,
                        help="Cartella che contiene all_images/ e unified_labels.csv")
    args = parser.parse_args()

    unified_dir = args.unified_dir
    csv_path = os.path.join(unified_dir, "unified_labels.csv")
    images_root = unified_dir  # CSV paths such as "all_images/..." are joined onto this

    # Quick sanity check
    if not os.path.exists(csv_path):
        raise FileNotFoundError("unified_labels.csv non trovato in: " + csv_path)

    # Image preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # NOTE(review): no transforms.Normalize(...) is applied although the
        # backbone starts from ImageNet weights; adding it would also change
        # what consumers of the exported ONNX model must feed in.
    ])

    # Dataset
    dataset = MelanomaDataset(csv_file=csv_path, root_dir=images_root, transform=transform)
    print(f"[INFO] Record totali: {len(dataset)}")

    # Train/validation split
    val_size = int(len(dataset) * VAL_RATIO)
    train_size = len(dataset) - val_size

    # An empty split would only surface later as an opaque ZeroDivisionError
    # when the per-epoch loss is averaged, so fail early with a clear message.
    if train_size == 0 or val_size == 0:
        raise ValueError(
            "Dataset troppo piccolo per lo split train/val: "
            f"train={train_size}, val={val_size}"
        )

    train_ds, val_ds = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SEED),
    )
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    print(f"[INFO] Train set: {train_size}, Val set: {val_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model: ImageNet-pretrained ResNet-18 with a single-logit head.
    # `pretrained=True` is deprecated in favour of the `weights=` API; use the
    # equivalent IMAGENET1K_V1 weights when this torchvision provides the
    # enum, and fall back to the legacy flag on older versions.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    num_feats = model.fc.in_features
    model.fc = nn.Linear(num_feats, 1)  # 1 logit
    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    # Training
    best_val_loss = float("inf")
    # Tracks whether THIS run wrote MODEL_PTH, so a stale checkpoint left by a
    # previous run is never loaded by mistake.
    checkpoint_saved = False
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_prec, val_rec = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        val_acc_percent = val_acc * 100
        val_prec_percent = val_prec * 100
        val_rec_percent = val_rec * 100

        print(f"Epoch {epoch}/{EPOCHS} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc_percent:.2f}% | "
              f"Val Prec: {val_prec_percent:.2f}% | "
              f"Val Rec: {val_rec_percent:.2f}%")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), MODEL_PTH)
            checkpoint_saved = True
            print("[INFO] Miglior modello salvato.")

    # Reload the best weights. No checkpoint exists if the validation loss
    # never improved on +inf (e.g. it was NaN every epoch) or EPOCHS is 0.
    if checkpoint_saved:
        model.load_state_dict(torch.load(MODEL_PTH, map_location=device))
    else:
        print("[WARN] Nessun checkpoint salvato: esporto i pesi correnti.")

    # ONNX export; put the model in inference mode explicitly before tracing.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        ONNX_PTH,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        opset_version=11,
    )
    print(f"[INFO] Modello ONNX salvato in {ONNX_PTH}")

if __name__ == "__main__":
    main()
