"""Train the Spinda coordinate-classification model and report exact-match accuracy on fixed evaluation sets."""
import os
|
|
import json
|
|
import random
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
import torchvision.transforms.v2 as T
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from tqdm import tqdm
|
|
|
|
from src.data.dataset import SpindaDataset, get_default_augmentations
|
|
from src.models.regression_model import SpindaRegressionModel
|
|
|
|
|
|
class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test) from disk.

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.

    ``data_dir`` must contain a ``metadata.json`` file. It is indexed by
    integer position, and each record provides an ``"img_path"`` (handed to
    ``cv2.imread`` unchanged) and a ``"target"`` (converted to a
    ``torch.long`` tensor).
    """

    def __init__(self, data_dir: str):
        """Read ``metadata.json`` from *data_dir* and build the normalisation transform.

        Raises:
            FileNotFoundError: if ``metadata.json`` does not exist in *data_dir*.
        """
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.transform = T.Compose([
            # uint8 [0, 255] -> float32 [0, 1]
            T.ToDtype(torch.float32, scale=True),
            # The usual ImageNet channel statistics.
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of records listed in the metadata."""
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(normalised CHW float image, long target tensor)`` for record *idx*.

        Raises:
            FileNotFoundError: if the image cannot be read from its ``img_path``.
        """
        item = self.metadata[idx]
        img_bgr = cv2.imread(item["img_path"])
        if img_bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; fail loudly with the offending path instead
            # of crashing later on a None subscript.
            raise FileNotFoundError(
                f"Could not read evaluation image: {item['img_path']}"
            )
        # OpenCV loads BGR; reverse the channel axis to get RGB. The copy makes
        # the negatively-strided view contiguous so torch.from_numpy accepts it.
        img_rgb = img_bgr[:, :, ::-1].copy()
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)  # HWC -> CHW
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
|
|
|
|
|
|
# Per-coordinate loss weights (8 entries, one per predicted coordinate).
# Indices 4 (BL_x) and 5 (BL_y) are the weakest coordinates, so they carry
# larger weights to steer more of the gradient towards the hardest spot.
_COORD_WEIGHTS = torch.tensor([1.0] * 4 + [1.5, 2.5] + [1.0] * 2)
|
|
|
|
|
|
def _weighted_loss(
    logits: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    """Cross-entropy with per-coordinate weights.

    Args:
        logits: ``(B, K, C)`` class scores; ``(B, 8, 16)`` in this project.
        targets: ``(B, K)`` integer class indices.
        weights: ``(K,)`` multiplier applied to each coordinate's loss
            (anything broadcastable against ``(B, K)`` works).

    Returns:
        Scalar tensor: the mean over all ``B * K`` weighted per-element
        losses. The result is not renormalised by the sum of the weights.
    """
    num_classes = logits.size(-1)
    # reshape (not view) so non-contiguous inputs, e.g. permuted logits,
    # are accepted; the coordinate/class counts come from the tensors
    # themselves rather than being hard-coded.
    per_element = F.cross_entropy(
        logits.reshape(-1, num_classes), targets.reshape(-1), reduction="none"
    )
    return (per_element.reshape(targets.shape) * weights).mean()
|
|
|
|
|
|
def _worker_init_fn(worker_id: int) -> None:
    """Seed Python's and NumPy's global RNGs differently in each DataLoader worker.

    Without this, workers can share the same ``random`` / ``numpy.random``
    state and would then generate identical samples (the same PIDs).

    Args:
        worker_id: index of the worker, as passed in by ``DataLoader``.
    """
    # Apply the modulo AFTER adding worker_id: numpy.random.seed only accepts
    # values in [0, 2**32 - 1], and reducing first could push the sum past
    # that limit and raise ValueError inside the worker.
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
|
|
|
|
|
|
def _evaluate_exact_match(
    model: nn.Module, loader: DataLoader, device: torch.device, desc: str = ""
) -> tuple[float, float]:
    """Return ``(mean loss, exact-match rate)`` of *model* over *loader*.

    The loss is the unweighted cross-entropy, averaged per sample. A sample
    counts as an exact match only when every coordinate is predicted
    correctly. The caller is responsible for putting the model in eval mode.
    A non-empty *desc* wraps the loader in a tqdm progress bar.
    """
    total_loss = 0.0
    exact_matches = 0
    batches = tqdm(loader, desc=desc) if desc else loader
    with torch.no_grad():
        for images, targets in batches:
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)  # (B, 8, 16)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
            # Re-weight the batch mean by batch size so a smaller final batch
            # does not skew the dataset-level average.
            total_loss += loss.item() * images.size(0)

            preds = logits.argmax(dim=-1)  # (B, 8)
            exact_matches += torch.all(preds == targets, dim=1).sum().item()

    num_samples = len(loader.dataset)
    return total_loss / num_samples, exact_matches / num_samples


def train_model(
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-4,
    resume: bool = False,
    model_path: str = "models/best_spinda_model.pth",
    num_workers: int = 4,
    epoch_size: int = 200000,
    backbone: str = "resnet18",
    save_path: str = "",
) -> None:
    """Train a ``SpindaRegressionModel`` and checkpoint the best epoch.

    Each epoch trains on a ``SpindaDataset`` of ``epoch_size`` samples with
    the weighted cross-entropy loss, then evaluates on ``data/val`` and, if
    the directory exists and is non-empty, on ``data/aug_test``. The learning
    rate is driven by ``ReduceLROnPlateau`` (factor 0.5, patience 3) on the
    validation loss. The model is saved whenever the clean-validation
    exact-match rate improves, and training stops after 10 consecutive
    epochs without such an improvement.

    Args:
        epochs: maximum number of epochs.
        batch_size: batch size for all loaders.
        lr: initial Adam learning rate.
        resume: load weights from ``model_path`` instead of starting from
            the pretrained backbone.
        model_path: checkpoint to load when ``resume`` is set.
        num_workers: worker processes for the training loader
            (0 = load in the main process).
        epoch_size: ``size`` passed to ``SpindaDataset``.
        backbone: backbone name passed to ``SpindaRegressionModel``; also
            used in the default checkpoint filename.
        save_path: checkpoint destination; defaults to
            ``models/best_{backbone}_model.pth`` when empty.

    Raises:
        ValueError: if the validation set is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SpindaDataset(size=epoch_size, transform=get_default_augmentations())
    val_ds = SpindaEvalDataset("data/val")
    if len(val_ds) == 0:
        # Fail before training rather than with a ZeroDivisionError after
        # the first full epoch.
        raise ValueError("Validation set at data/val is empty")
    aug_test_ds = SpindaEvalDataset("data/aug_test") if os.path.exists("data/aug_test") else None

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=_worker_init_fn if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # without CUDA just triggers a warning.
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    aug_loader = None
    if aug_test_ds is not None and len(aug_test_ds) > 0:
        aug_loader = DataLoader(aug_test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    checkpoint_path = save_path or f"models/best_{backbone}_model.pth"

    if resume:
        model = SpindaRegressionModel(pretrained=False, backbone=backbone).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Resumed from {model_path}")
    else:
        model = SpindaRegressionModel(pretrained=True, backbone=backbone).to(device)

    coord_weights = _COORD_WEIGHTS.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_exact_rate = 0.0
    epochs_without_improvement = 0
    early_stop_patience = 10

    for epoch in range(epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]")
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 8) long, values 0-15

            optimizer.zero_grad()
            logits = model(images)  # (B, 8, 16)
            loss = _weighted_loss(logits, targets, coord_weights)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_loss /= len(train_loader.dataset)

        # ── Validation ────────────────────────────────────────────────
        model.eval()
        val_loss, exact_rate = _evaluate_exact_match(
            model, val_loader, device, desc=f"Epoch {epoch + 1}/{epochs} [Val]"
        )

        # ── Augmented test set ────────────────────────────────────────
        aug_str = ""
        if aug_loader is not None:
            _, aug_exact_rate = _evaluate_exact_match(model, aug_loader, device)
            aug_str = f" Aug Test: {aug_exact_rate:.2%}"

        print(
            f"Epoch {epoch + 1}: "
            f"Train Loss: {train_loss:.4f} "
            f"Val Loss: {val_loss:.4f} "
            f"Clean Val: {exact_rate:.2%}"
            f"{aug_str}"
        )

        scheduler.step(val_loss)

        # Save on exact-match improvement (the metric that actually matters)
        if exact_rate > best_exact_rate:
            best_exact_rate = exact_rate
            epochs_without_improvement = 0
            # dirname is "" for a bare filename, and os.makedirs("") raises.
            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print(f" → No improvement for {early_stop_patience} epochs. Stopping early.")
                break
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # (flag, add_argument keyword arguments), in the order shown by --help.
    _cli_options = (
        ("--epochs", {"type": int, "default": 50}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--resume", {"action": "store_true", "help": "Fine-tune from --model_path checkpoint"}),
        ("--model_path", {"type": str, "default": "models/best_spinda_model.pth"}),
        ("--num_workers", {"type": int, "default": 4, "help": "DataLoader worker count (0 = main process only)"}),
        ("--epoch_size", {"type": int, "default": 200000, "help": "Virtual dataset size per epoch"}),
        ("--backbone", {"type": str, "default": "resnet18", "choices": ["resnet18", "resnet34", "convnext_tiny"]}),
        ("--save_path", {"type": str, "default": "", "help": "Override checkpoint save path"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in _cli_options:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    # Each option's dest is identical to a train_model keyword, so the parsed
    # namespace can be forwarded as keyword arguments unchanged.
    train_model(**vars(options))
|