Source code
This commit is contained in:
58
src/models/inference.py
Normal file
58
src/models/inference.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2 as T
|
||||
from PIL import Image
|
||||
from typing import List, Tuple
|
||||
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
|
||||
|
||||
class SpindaInference:
    """Load a trained Spinda model and predict spot grid coordinates for an image.

    The preprocessing pipeline (resize to 128x128, scale to float32 in [0, 1],
    normalise with the mean/std constants below) is built once in the
    constructor and applied to every image passed to :meth:`predict`.
    """

    def __init__(
        self,
        model_path: str = "models/best_spinda_model.pth",
        backbone: str = "resnet18",
    ):
        """Build the network, load its weights and prepare the preprocessing.

        Args:
            model_path: Path to a ``state_dict`` checkpoint file.
            backbone: Backbone name forwarded to ``SpindaRegressionModel``; it
                has to match the architecture the checkpoint was trained with.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict needs and avoids executing arbitrary
        # pickled code from an untrusted checkpoint file.
        state_dict = torch.load(
            model_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.to(self.device).eval()

        self.transform = T.Compose([
            T.Resize((128, 128)),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def predict(self, image_path: str) -> Tuple[List[int], str]:
        """Predict the 8 grid coordinates and return them with a fingerprint string.

        Args:
            image_path: Path of the image crop to analyse.

        Returns:
            grid_coords: list of 8 integers in [0, 15]
            fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4", each value zero-padded
                to two digits.

        Raises:
            FileNotFoundError: If ``image_path`` does not point to a file.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        import os  # local import: this module has no top-level ``os`` import

        # cv2.imread signals failure by returning None instead of raising, so
        # both failure modes are checked explicitly to give a usable message.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        img_bgr = cv2.imread(image_path)
        if img_bgr is None:
            raise ValueError(f"Could not decode image: {image_path}")

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        # HWC uint8 array -> CHW tensor, then preprocess and add a batch axis.
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(img_tensor)  # (1, 8, 16)

        grid_coords = logits.argmax(dim=2).squeeze(0).cpu().tolist()  # [8]
        fingerprint = "-".join(f"{c:02d}" for c in grid_coords)

        return grid_coords, fingerprint
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # The first positional argument is the image to analyse; without it only
    # the usage line is printed.
    cli_args = sys.argv[1:]
    if cli_args:
        predictor = SpindaInference()
        grid, tag = predictor.predict(cli_args[0])
        print(f"Predicted Grid Coordinates: {grid}")
        print(f"Visual Fingerprint: {tag}")
    else:
        print("Usage: python src/models/inference.py <image_path>")
|
||||
47
src/models/regression_model.py
Normal file
47
src/models/regression_model.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
from torchvision.models import ResNet18_Weights, ResNet34_Weights
|
||||
|
||||
# Registry of supported backbones: name -> (constructor, default pretrained weights).
_BACKBONES = dict(
    resnet18=(models.resnet18, ResNet18_Weights.DEFAULT),
    resnet34=(models.resnet34, ResNet34_Weights.DEFAULT),
)
|
||||
|
||||
|
||||
class SpindaRegressionModel(nn.Module):
    """ResNet backbone with 8 independent 16-class coordinate heads.

    Each of the 8 output coordinates (4 spots × x, y) is treated as a
    16-class classification problem over the [0, 15] nibble grid.
    This eliminates the float→integer rounding step and lets CrossEntropy
    directly optimise for exact coordinate prediction.

    Output shape: (B, 8, 16) — unnormalised logits.
    Prediction:   output.argmax(dim=2) → (B, 8) integer coordinates.
    """

    # Number of predicted coordinates and number of classes per coordinate.
    NUM_COORDS = 8
    NUM_CLASSES = 16

    def __init__(self, pretrained: bool = True, backbone: str = "resnet18"):
        """Create the backbone and the joint classification head.

        Args:
            pretrained: Initialise the backbone with its default pretrained
                weights when true, otherwise with random weights.
            backbone: Key into the ``_BACKBONES`` registry.

        Raises:
            ValueError: If ``backbone`` is not a registered name.
        """
        super().__init__()
        if backbone not in _BACKBONES:
            raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
        factory, default_weights = _BACKBONES[backbone]
        net = factory(weights=default_weights if pretrained else None)

        # Read the feature width from the backbone's own FC layer rather than
        # hard-coding 512, so a wider backbone added to the registry still works.
        feature_dim = net.fc.in_features

        # Strip the final FC; keep the feature extractor + average pool.
        self.features = nn.Sequential(*list(net.children())[:-1])
        # One linear layer emits all NUM_COORDS × NUM_CLASSES logits at once.
        self.classifier = nn.Linear(feature_dim, self.NUM_COORDS * self.NUM_CLASSES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, NUM_COORDS, NUM_CLASSES) for image batch ``x``."""
        x = self.features(x)    # (B, C, 1, 1)
        x = x.flatten(1)        # (B, C)
        x = self.classifier(x)  # (B, NUM_COORDS * NUM_CLASSES)
        return x.view(-1, self.NUM_COORDS, self.NUM_CLASSES)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: push a random batch through each backbone and show the
    # output shape together with the argmax predictions.
    for backbone_name in ("resnet18", "resnet34"):
        net = SpindaRegressionModel(pretrained=False, backbone=backbone_name)
        dummy_batch = torch.randn(2, 3, 128, 128)
        logits = net(dummy_batch)
        print(f"{backbone_name}: output {logits.shape}, predictions {logits.argmax(dim=2)}")
|
||||
210
src/models/train.py
Normal file
210
src/models/train.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms.v2 as T
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.data.dataset import SpindaDataset, get_default_augmentations
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
|
||||
|
||||
class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test).

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.

    ``metadata.json`` inside ``data_dir`` is read as a sequence of records;
    each record provides ``"img_path"`` (image file to load) and ``"target"``
    (the label values, converted to a ``torch.long`` tensor).
    """

    def __init__(self, data_dir: str):
        """Read ``<data_dir>/metadata.json`` and build the normalisation transform."""
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.transform = T.Compose([
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of metadata records."""
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(normalised CHW image tensor, long target tensor)`` for record ``idx``.

        Raises:
            FileNotFoundError: If the record's image file does not exist.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        item = self.metadata[idx]
        img_path = item["img_path"]

        # cv2.imread returns None on failure instead of raising; check both
        # failure modes so a bad record is reported with its path.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise ValueError(f"Could not decode image: {img_path}")

        # Reverse the channel axis (BGR -> RGB); copy() yields a positively
        # strided array, which torch.from_numpy requires.
        img_rgb = img_bgr[:, :, ::-1].copy()
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
|
||||
|
||||
|
||||
# Per-coordinate loss weights. BL_x (index 4) and BL_y (index 5) are the
# weakest coordinates, so they are upweighted to steer more of the gradient
# towards the hardest spot; every other coordinate keeps weight 1.0.
_COORD_WEIGHTS = torch.tensor([1.0] * 4 + [1.5, 2.5] + [1.0] * 2)
|
||||
|
||||
|
||||
def _weighted_loss(
    logits: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    """CrossEntropy with per-coordinate weights.

    Args:
        logits: Unnormalised scores of shape (B, C, K); (B, 8, 16) in this project.
        targets: Class indices of shape (B, C) with values in [0, K).
        weights: Per-coordinate weights of shape (C,), broadcast over the batch.

    Returns:
        Scalar tensor: the mean over all B*C entries of ``weight * cross_entropy``.
        The mean is not divided by the sum of the weights.
    """
    num_classes = logits.size(-1)
    # reshape (not view) so non-contiguous inputs are accepted as well.
    per_entry = F.cross_entropy(
        logits.reshape(-1, num_classes), targets.reshape(-1), reduction="none"
    )
    return (per_entry.view_as(targets) * weights).mean()
|
||||
|
||||
|
||||
def _worker_init_fn(worker_id: int) -> None:
    """Seed Python's and NumPy's global RNGs differently in every DataLoader worker.

    The seed is ``torch.initial_seed() + worker_id`` wrapped into the 32-bit
    range, so workers do not share a random stream (the original note phrases
    this as workers generating "different PIDs").

    Args:
        worker_id: Index of the worker, as passed by ``DataLoader``.
    """
    # Wrap AFTER adding worker_id: np.random.seed rejects values >= 2**32, and
    # wrapping only the base seed could still overflow once worker_id is added.
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
|
||||
|
||||
|
||||
def _evaluate(model, loader, device, desc=None):
    """Return ``(mean_loss, exact_match_rate)`` of *model* over *loader*.

    The loss is the unweighted cross-entropy, averaged per sample; a sample
    counts as an exact match only when every coordinate is predicted correctly.
    A tqdm progress bar is shown when *desc* is given.
    """
    model.eval()
    total_loss = 0.0
    exact_matches = 0
    batches = tqdm(loader, desc=desc) if desc else loader

    with torch.no_grad():
        for images, targets in batches:
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)  # (B, 8, 16)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
            total_loss += loss.item() * images.size(0)

            preds = logits.argmax(dim=2)  # (B, 8)
            exact_matches += torch.all(preds == targets, dim=1).sum().item()

    num_samples = len(loader.dataset)
    return total_loss / num_samples, exact_matches / num_samples


def train_model(
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-4,
    resume: bool = False,
    model_path: str = "models/best_spinda_model.pth",
    num_workers: int = 4,
    epoch_size: int = 200000,
    backbone: str = "resnet18",
    save_path: str = "",
) -> None:
    """Train the coordinate classifier and keep the best checkpoint.

    Each epoch trains on ``SpindaDataset``, evaluates on ``data/val`` and, when
    the directory exists and is non-empty, on ``data/aug_test``. The learning
    rate is halved when the validation loss plateaus (patience 3). A checkpoint
    is written whenever the clean-validation exact-match rate improves, and
    training stops after 10 epochs without improvement.

    Args:
        epochs: Maximum number of epochs.
        batch_size: Batch size for all loaders.
        lr: Initial Adam learning rate.
        resume: Start from the weights in ``model_path`` instead of the
            pretrained backbone.
        model_path: Checkpoint to load when ``resume`` is true.
        num_workers: Worker count for the training loader (0 = main process).
        epoch_size: ``size`` passed to ``SpindaDataset``.
        backbone: Backbone name passed to ``SpindaRegressionModel``.
        save_path: Checkpoint destination; defaults to
            ``models/best_<backbone>_model.pth`` when empty.

    Raises:
        ValueError: If the validation set contains no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SpindaDataset(size=epoch_size, transform=get_default_augmentations())
    val_ds = SpindaEvalDataset("data/val")
    if len(val_ds) == 0:
        # Fail before training instead of with ZeroDivisionError after an epoch.
        raise ValueError("Validation set 'data/val' is empty")
    aug_test_ds = SpindaEvalDataset("data/aug_test") if os.path.exists("data/aug_test") else None

    use_workers = num_workers > 0
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, worker_init_fn=_worker_init_fn if use_workers else None,
        persistent_workers=use_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    # An empty augmented set is treated the same as a missing one.
    has_aug_set = aug_test_ds is not None and len(aug_test_ds) > 0
    aug_loader = (
        DataLoader(aug_test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
        if has_aug_set else None
    )

    checkpoint_path = save_path or f"models/best_{backbone}_model.pth"

    if resume:
        model = SpindaRegressionModel(pretrained=False, backbone=backbone).to(device)
        # weights_only=True: a state_dict needs nothing beyond tensors, and this
        # avoids executing arbitrary pickled code from the checkpoint file.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        print(f"Resumed from {model_path}")
    else:
        model = SpindaRegressionModel(pretrained=True, backbone=backbone).to(device)

    coord_weights = _COORD_WEIGHTS.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_exact_rate = 0.0
    epochs_without_improvement = 0
    early_stop_patience = 10

    for epoch in range(epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]")
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 8) long, values 0-15

            optimizer.zero_grad()
            logits = model(images)  # (B, 8, 16)
            loss = _weighted_loss(logits, targets, coord_weights)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_loss /= len(train_loader.dataset)

        # ── Validation ────────────────────────────────────────────────
        val_loss, exact_rate = _evaluate(
            model, val_loader, device, desc=f"Epoch {epoch + 1}/{epochs} [Val]"
        )

        # ── Augmented test set ────────────────────────────────────────
        aug_exact_rate = 0.0
        if aug_loader is not None:
            _, aug_exact_rate = _evaluate(model, aug_loader, device)

        aug_str = f" Aug Test: {aug_exact_rate:.2%}" if aug_loader is not None else ""
        print(
            f"Epoch {epoch + 1}: "
            f"Train Loss: {train_loss:.4f} "
            f"Val Loss: {val_loss:.4f} "
            f"Clean Val: {exact_rate:.2%}"
            f"{aug_str}"
        )

        scheduler.step(val_loss)

        # Save on exact-match improvement (the metric that actually matters)
        if exact_rate > best_exact_rate:
            best_exact_rate = exact_rate
            epochs_without_improvement = 0
            # dirname is "" for a bare filename, and os.makedirs("") raises.
            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print(f" → No improvement for {early_stop_patience} epochs. Stopping early.")
                break
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Every flag shares its name with a ``train_model`` parameter, so the
    # parsed namespace is forwarded directly as keyword arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument("--epochs", type=int, default=50)
    cli.add_argument("--batch_size", type=int, default=64)
    cli.add_argument("--lr", type=float, default=1e-4)
    cli.add_argument(
        "--resume",
        action="store_true",
        help="Fine-tune from --model_path checkpoint",
    )
    cli.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    cli.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="DataLoader worker count (0 = main process only)",
    )
    cli.add_argument(
        "--epoch_size",
        type=int,
        default=200000,
        help="Virtual dataset size per epoch",
    )
    cli.add_argument(
        "--backbone",
        type=str,
        default="resnet18",
        choices=["resnet18", "resnet34"],
    )
    cli.add_argument(
        "--save_path",
        type=str,
        default="",
        help="Override checkpoint save path",
    )
    options = cli.parse_args()

    train_model(**vars(options))
|
||||
Reference in New Issue
Block a user