Source code

This commit is contained in:
alexiondev
2026-05-08 09:25:35 -04:00
parent 4867c1ac52
commit 037f7131c2
18 changed files with 1178 additions and 0 deletions

210
src/models/train.py Normal file
View File

@@ -0,0 +1,210 @@
import os
import json
import random
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from src.data.dataset import SpindaDataset, get_default_augmentations
from src.models.regression_model import SpindaRegressionModel
class SpindaEvalDataset(Dataset):
    """Fixed evaluation set (clean val or augmented test) read from disk.

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.

    ``data_dir`` must contain a ``metadata.json`` holding an indexable
    collection of records. Each record provides ``"img_path"`` (handed to
    ``cv2.imread`` unchanged) and ``"target"`` (converted to a long tensor).
    """

    def __init__(self, data_dir: str):
        # Explicit encoding so the JSON parses the same way on every platform.
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.transform = T.Compose([
            # uint8 -> float32, rescaled by ToDtype(scale=True).
            T.ToDtype(torch.float32, scale=True),
            # Per-channel mean/std (the values commonly used for ImageNet).
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the normalised image tensor and its target for record *idx*.

        Raises:
            FileNotFoundError: if the image at ``img_path`` cannot be read.
        """
        item = self.metadata[idx]
        img_path = item["img_path"]
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on the slice below.
            raise FileNotFoundError(f"Could not read evaluation image: {img_path!r}")
        # OpenCV loads BGR; reverse the channel axis to get RGB. The copy makes
        # the negatively-strided view contiguous so torch.from_numpy accepts it.
        img_rgb = img_bgr[:, :, ::-1].copy()
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)  # HWC -> CHW
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
# Per-coordinate loss multipliers, one entry per predicted coordinate.
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates, so they are
# weighted above 1.0 to steer more of the gradient towards the hardest spot.
_COORD_WEIGHTS = torch.tensor(
    [
        1.0, 1.0, 1.0, 1.0,  # indices 0-3
        1.5, 2.5,            # BL_x, BL_y
        1.0, 1.0,            # indices 6-7
    ]
)
def _weighted_loss(
    logits: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    """Cross-entropy with a separate weight for each coordinate.

    Args:
        logits: class scores shaped ``(B, C, K)`` (``(B, 8, 16)`` for the
            model trained in this file).
        targets: long class indices shaped ``(B, C)``.
        weights: multipliers broadcastable against ``(B, C)``, normally ``(C,)``.

    Returns:
        Scalar tensor: the plain mean of the weighted per-element losses
        (it is not divided by the sum of the weights).
    """
    # Take the class count from the input instead of hard-coding it, and use
    # reshape so non-contiguous logits (e.g. after a transpose) also work.
    num_classes = logits.size(-1)
    per = F.cross_entropy(
        logits.reshape(-1, num_classes), targets.reshape(-1), reduction="none"
    )
    return (per.reshape(targets.shape) * weights).mean()
def _worker_init_fn(worker_id: int) -> None:
    """Seed Python's ``random`` and NumPy with a per-worker value.

    Intended as a DataLoader ``worker_init_fn`` so that workers do not share
    one random stream (the original note: so they generate different PIDs —
    presumably the IDs sampled by SpindaDataset; confirm there).

    Args:
        worker_id: index of the DataLoader worker being initialised.
    """
    # Reduce modulo 2**32 *after* adding worker_id: np.random.seed rejects
    # values above 2**32 - 1, which the sum could otherwise reach.
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    desc: str = "",
) -> tuple[float, float]:
    """Score *model* on every batch of *loader* with gradients disabled.

    Args:
        model: network returning logits shaped ``(B, coords, classes)``.
        loader: yields ``(images, targets)`` batches.
        device: device the batches are moved to before the forward pass.
        desc: progress-bar label; when empty no progress bar is shown.

    Returns:
        ``(mean_loss, exact_rate)``: the unweighted cross-entropy averaged per
        sample, and the fraction of samples whose every coordinate is
        predicted correctly.
    """
    model.eval()
    batches = tqdm(loader, desc=desc) if desc else loader
    loss_sum = 0.0
    exact = 0
    with torch.no_grad():
        for images, targets in batches:
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
            # Batch mean times batch size, so a short last batch counts fairly.
            loss_sum += loss.item() * images.size(0)
            preds = logits.argmax(dim=2)
            exact += torch.all(preds == targets, dim=1).sum().item()
    n = len(loader.dataset)
    return loss_sum / n, exact / n


def train_model(
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-4,
    resume: bool = False,
    model_path: str = "models/best_spinda_model.pth",
    num_workers: int = 4,
    epoch_size: int = 200000,
    backbone: str = "resnet18",
    save_path: str = "",
) -> None:
    """Train a SpindaRegressionModel and checkpoint the best clean-val epoch.

    Training data comes from ``SpindaDataset``; validation is read from
    ``data/val`` and, when ``data/aug_test`` exists, an augmented test set is
    scored as well. The state dict is saved whenever the clean-validation
    exact-match rate improves, and training stops after 10 consecutive epochs
    without such an improvement.

    Args:
        epochs: maximum number of epochs.
        batch_size: batch size for all three loaders.
        lr: initial Adam learning rate (halved by ReduceLROnPlateau when the
            validation loss stalls for more than 3 epochs).
        resume: load weights from ``model_path`` instead of starting from a
            pretrained backbone.
        model_path: checkpoint read when ``resume`` is set.
        num_workers: training DataLoader worker count (0 = main process only).
        epoch_size: ``size`` handed to ``SpindaDataset``.
        backbone: backbone name handed to ``SpindaRegressionModel``.
        save_path: checkpoint destination; empty means
            ``models/best_{backbone}_model.pth``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SpindaDataset(size=epoch_size, transform=get_default_augmentations())
    val_ds = SpindaEvalDataset("data/val")
    aug_test_ds = SpindaEvalDataset("data/aug_test") if os.path.exists("data/aug_test") else None

    use_workers = num_workers > 0
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=_worker_init_fn if use_workers else None,
        persistent_workers=use_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a warning.
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    # An empty augmented set is skipped so its rate is never divided by zero.
    aug_loader = None
    if aug_test_ds is not None and len(aug_test_ds) > 0:
        aug_loader = DataLoader(aug_test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    checkpoint_path = save_path or f"models/best_{backbone}_model.pth"

    if resume:
        model = SpindaRegressionModel(pretrained=False, backbone=backbone).to(device)
        # NOTE(security): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Resumed from {model_path}")
    else:
        model = SpindaRegressionModel(pretrained=True, backbone=backbone).to(device)

    coord_weights = _COORD_WEIGHTS.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_exact_rate = 0.0
    epochs_without_improvement = 0
    early_stop_patience = 10

    for epoch in range(epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]")
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 8) long, values 0-15
            optimizer.zero_grad()
            logits = model(images)  # (B, 8, 16)
            loss = _weighted_loss(logits, targets, coord_weights)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        train_loss /= len(train_loader.dataset)

        # ── Validation ────────────────────────────────────────────────
        val_loss, exact_rate = _evaluate(
            model, val_loader, device, desc=f"Epoch {epoch + 1}/{epochs} [Val]"
        )

        # ── Augmented test set ────────────────────────────────────────
        aug_str = ""
        if aug_loader is not None:
            _, aug_exact_rate = _evaluate(model, aug_loader, device)
            aug_str = f" Aug Test: {aug_exact_rate:.2%}"

        print(
            f"Epoch {epoch + 1}: "
            f"Train Loss: {train_loss:.4f} "
            f"Val Loss: {val_loss:.4f} "
            f"Clean Val: {exact_rate:.2%}"
            f"{aug_str}"
        )
        scheduler.step(val_loss)

        # Save on exact-match improvement (the metric that actually matters)
        if exact_rate > best_exact_rate:
            best_exact_rate = exact_rate
            epochs_without_improvement = 0
            # A bare filename has an empty dirname, and os.makedirs("") raises,
            # so only create the parent directory when there is one.
            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print(f" → No improvement for {early_stop_patience} epochs. Stopping early.")
                break
if __name__ == "__main__":
    import argparse

    # Command-line entry point. Every flag's destination name matches a
    # train_model keyword, so the parsed namespace is forwarded wholesale.
    cli = argparse.ArgumentParser()
    cli.add_argument("--epochs", type=int, default=50)
    cli.add_argument("--batch_size", type=int, default=64)
    cli.add_argument("--lr", type=float, default=1e-4)
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Fine-tune from --model_path checkpoint",
    )
    cli.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    cli.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="DataLoader worker count (0 = main process only)",
    )
    cli.add_argument(
        "--epoch_size",
        type=int,
        default=200000,
        help="Virtual dataset size per epoch",
    )
    cli.add_argument(
        "--backbone",
        type=str,
        default="resnet18",
        choices=["resnet18", "resnet34"],
    )
    cli.add_argument(
        "--save_path",
        type=str,
        default="",
        help="Override checkpoint save path",
    )
    train_model(**vars(cli.parse_args()))