import json
import os
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from src.data.dataset import SpindaDataset, get_default_augmentations
from src.models.regression_model import SpindaRegressionModel


class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test).

    Images are stored post-augmentation, pre-normalisation; this class applies
    only the normalisation step so the on-disk images are stable.
    """

    def __init__(self, data_dir: str):
        """Read ``metadata.json`` from *data_dir* and build the normalisation transform."""
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.transform = T.Compose([
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(normalised CHW float image, long target tensor)`` for entry *idx*.

        Raises:
            FileNotFoundError: if the image referenced by the metadata entry
                cannot be read from disk.
        """
        item = self.metadata[idx]
        img_path = item["img_path"]
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on the slice below.
            raise FileNotFoundError(f"Could not read evaluation image: {img_path}")
        # OpenCV decodes as BGR; reverse the channel axis and copy so that the
        # array has positive strides, which torch.from_numpy requires.
        img_rgb = img_bgr[:, :, ::-1].copy()
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target


# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
# so the optimiser focuses more of its gradient on the hardest spot.
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])


def _weighted_loss(
    logits: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    """Cross-entropy with per-coordinate weights.

    Args:
        logits: ``(B, C, K)`` class scores; the training loop passes ``(B, 8, 16)``.
        targets: ``(B, C)`` integer class indices.
        weights: ``(C,)`` multiplier applied to each coordinate's loss.

    Returns:
        Scalar tensor: the mean over all ``B * C`` weighted per-element losses
        (not renormalised by the sum of the weights).
    """
    num_classes = logits.size(-1)
    # reshape (rather than view) also accepts non-contiguous inputs.
    per_element = F.cross_entropy(
        logits.reshape(-1, num_classes), targets.reshape(-1), reduction="none"
    )
    return (per_element.view_as(targets) * weights).mean()


def _worker_init_fn(worker_id: int) -> None:
    """Seed ``random`` and NumPy differently in each DataLoader worker.

    Without this, workers could share RNG state and generate identical samples
    (the original note says: so they generate different PIDs).
    """
    # np.random.seed only accepts values in [0, 2**32 - 1], so wrap *after*
    # adding the worker id; wrapping before could still overflow the range.
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)


def _evaluate(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    desc: str = "",
) -> tuple[float, float]:
    """Return ``(mean unweighted CE loss, exact-match rate)`` of *model* over *loader*.

    A sample counts as an exact match only when every coordinate's argmax
    equals its target. When *desc* is non-empty a tqdm progress bar is shown.
    Puts the model in eval mode; the caller is responsible for switching back.
    """
    model.eval()
    loss_sum = 0.0
    exact = 0
    seen = 0
    batches = tqdm(loader, desc=desc) if desc else loader
    with torch.no_grad():
        for images, targets in batches:
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)  # (B, 8, 16)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
            batch = images.size(0)
            # The batch loss is a mean; weight by batch size for a per-sample average.
            loss_sum += loss.item() * batch
            preds = logits.argmax(dim=2)  # (B, 8)
            exact += torch.all(preds == targets, dim=1).sum().item()
            seen += batch
    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, exact / seen


def train_model(
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-4,
    resume: bool = False,
    model_path: str = "models/best_spinda_model.pth",
    num_workers: int = 4,
    epoch_size: int = 200000,
    backbone: str = "resnet18",
    save_path: str = "",
) -> None:
    """Train the Spinda model and checkpoint the best clean-val exact-match epoch.

    Args:
        epochs: Maximum number of epochs.
        batch_size: Batch size for the train and evaluation loaders.
        lr: Initial Adam learning rate (halved by ReduceLROnPlateau on val loss).
        resume: If true, start from the weights stored at *model_path*.
        model_path: Checkpoint (state dict) to load when *resume* is set.
        num_workers: Worker processes for the training loader (0 = main process).
        epoch_size: ``size`` passed to ``SpindaDataset`` (virtual samples per epoch).
        backbone: Backbone name passed to ``SpindaRegressionModel``.
        save_path: Checkpoint destination; defaults to
            ``models/best_{backbone}_model.pth`` when empty.

    Raises:
        ValueError: if the validation set in ``data/val`` is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SpindaDataset(size=epoch_size, transform=get_default_augmentations())
    val_ds = SpindaEvalDataset("data/val")
    if len(val_ds) == 0:
        # Fail before training instead of dividing by zero after a full epoch.
        raise ValueError("Validation set 'data/val' contains no samples")
    aug_test_ds = SpindaEvalDataset("data/aug_test") if os.path.exists("data/aug_test") else None
    if aug_test_ds is not None and len(aug_test_ds) == 0:
        # An empty augmented test set is treated the same as a missing one.
        aug_test_ds = None

    use_workers = num_workers > 0
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=_worker_init_fn if use_workers else None,
        persistent_workers=use_workers,
        # Pinned host memory only helps (and is only supported) with CUDA.
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    aug_loader = (
        DataLoader(aug_test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
        if aug_test_ds is not None
        else None
    )

    checkpoint_path = save_path or f"models/best_{backbone}_model.pth"

    if resume:
        model = SpindaRegressionModel(pretrained=False, backbone=backbone).to(device)
        # weights_only=True restricts unpickling to tensors/containers, so a
        # tampered checkpoint cannot execute arbitrary code on load.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Resumed from {model_path}")
    else:
        model = SpindaRegressionModel(pretrained=True, backbone=backbone).to(device)

    coord_weights = _COORD_WEIGHTS.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_exact_rate = 0.0
    epochs_without_improvement = 0
    early_stop_patience = 10

    for epoch in range(epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        train_loss = 0.0
        train_seen = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]")
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 8) long, values 0-15

            optimizer.zero_grad()
            logits = model(images)  # (B, 8, 16)
            loss = _weighted_loss(logits, targets, coord_weights)
            loss.backward()
            optimizer.step()

            batch = images.size(0)
            train_loss += loss.item() * batch
            train_seen += batch
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_loss /= max(train_seen, 1)

        # ── Validation ────────────────────────────────────────────────
        val_loss, exact_rate = _evaluate(
            model, val_loader, device, desc=f"Epoch {epoch + 1}/{epochs} [Val]"
        )

        # ── Augmented test set ────────────────────────────────────────
        aug_exact_rate = 0.0
        if aug_loader is not None:
            _, aug_exact_rate = _evaluate(model, aug_loader, device)

        aug_str = f"  Aug Test: {aug_exact_rate:.2%}" if aug_loader is not None else ""
        print(
            f"Epoch {epoch + 1}: "
            f"Train Loss: {train_loss:.4f}  "
            f"Val Loss: {val_loss:.4f}  "
            f"Clean Val: {exact_rate:.2%}"
            f"{aug_str}"
        )

        scheduler.step(val_loss)

        # Save on exact-match improvement (the metric that actually matters)
        if exact_rate > best_exact_rate:
            best_exact_rate = exact_rate
            epochs_without_improvement = 0
            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:
                # dirname is "" for a bare filename, and os.makedirs("") raises.
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  → Saved best model to {checkpoint_path}  (clean val exact match: {best_exact_rate:.2%})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print(f"  → No improvement for {early_stop_patience} epochs. Stopping early.")
                break


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resume", action="store_true",
                        help="Fine-tune from --model_path checkpoint")
    parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="DataLoader worker count (0 = main process only)")
    parser.add_argument("--epoch_size", type=int, default=200000,
                        help="Virtual dataset size per epoch")
    parser.add_argument("--backbone", type=str, default="resnet18",
                        choices=["resnet18", "resnet34", "convnext_tiny"])
    parser.add_argument("--save_path", type=str, default="",
                        help="Override checkpoint save path")
    args = parser.parse_args()

    train_model(
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        resume=args.resume,
        model_path=args.model_path,
        num_workers=args.num_workers,
        epoch_size=args.epoch_size,
        backbone=args.backbone,
        save_path=args.save_path,
    )