Refactor/cleanup
This commit is contained in:
@@ -5,7 +5,7 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
from src.models.train import SpindaEvalDataset
|
||||
from src.data.dataset import SpindaEvalDataset
|
||||
|
||||
|
||||
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
|
||||
@@ -29,8 +29,8 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
|
||||
parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.model_path):
|
||||
|
||||
@@ -2,8 +2,7 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2 as T
|
||||
from PIL import Image
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
|
||||
@@ -13,8 +12,8 @@ class SpindaInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
backbone: str = "resnet18",
|
||||
model_path: str = "models/best_resnet34_model.pth",
|
||||
backbone: str = "resnet34",
|
||||
):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
||||
@@ -29,15 +28,23 @@ class SpindaInference:
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def predict(self, image_path: str) -> Tuple[List[int], str]:
|
||||
def predict(self, image: Union[str, np.ndarray]) -> Tuple[List[int], str]:
|
||||
"""Predict the 8 grid coordinates and return them with a fingerprint string.
|
||||
|
||||
Args:
|
||||
image: file path (str) or a BGR numpy array (e.g. from the detector).
|
||||
Returns:
|
||||
grid_coords: list of 8 integers in [0, 15]
|
||||
fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
||||
"""
|
||||
img_bgr = cv2.imread(image_path)
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
if isinstance(image, str):
|
||||
img_bgr = cv2.imread(image)
|
||||
if img_bgr is None:
|
||||
raise FileNotFoundError(f"Image not found: {image}")
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
img_rgb = image[:, :, ::-1].copy()
|
||||
|
||||
img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1)
|
||||
img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class SpindaRegressionModel(nn.Module):
|
||||
Prediction: output.argmax(dim=2) → (B, 8) integer coordinates.
|
||||
"""
|
||||
|
||||
def __init__(self, pretrained: bool = True, backbone: str = "resnet18"):
|
||||
def __init__(self, pretrained: bool = True, backbone: str = "resnet34"):
|
||||
super().__init__()
|
||||
if backbone not in _BACKBONES:
|
||||
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
|
||||
|
||||
@@ -1,49 +1,17 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torchvision.transforms.v2 as T
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.data.dataset import SpindaDataset, get_default_augmentations
|
||||
from src.data.dataset import SpindaDataset, SpindaEvalDataset, get_default_augmentations
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
|
||||
|
||||
class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test).

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.

    Args:
        data_dir: directory containing ``metadata.json``. Each metadata
            entry is read through its ``"img_path"`` and ``"target"`` keys.
    """

    def __init__(self, data_dir: str):
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        # Only dtype conversion (with scaling) and per-channel normalisation
        # are applied here; no augmentation happens at load time.
        self.transform = T.Compose([
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of metadata entries."""
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the normalised image tensor and its target for entry *idx*.

        Returns:
            A ``(img_tensor, target)`` pair: the transformed channel-first
            image and the entry's ``"target"`` as a ``torch.long`` tensor.

        Raises:
            FileNotFoundError: if the image at ``item["img_path"]`` cannot
                be read.
        """
        item = self.metadata[idx]
        img_bgr = cv2.imread(item["img_path"])
        # cv2.imread signals failure by returning None instead of raising;
        # fail loudly here rather than with a TypeError on the slice below.
        if img_bgr is None:
            raise FileNotFoundError(f"Image not found: {item['img_path']}")
        # Reverse the channel axis (BGR -> RGB); copy() materialises the
        # negatively-strided view so torch.from_numpy can wrap it.
        img_rgb = img_bgr[:, :, ::-1].copy()
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
|
||||
|
||||
|
||||
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
|
||||
# so the optimiser focuses more of its gradient on the hardest spot.
|
||||
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])
|
||||
@@ -73,7 +41,7 @@ def train_model(
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
num_workers: int = 4,
|
||||
epoch_size: int = 200000,
|
||||
backbone: str = "resnet18",
|
||||
backbone: str = "resnet34",
|
||||
save_path: str = "",
|
||||
) -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -176,8 +144,7 @@ def train_model(
|
||||
if exact_rate > best_exact_rate:
|
||||
best_exact_rate = exact_rate
|
||||
epochs_without_improvement = 0
|
||||
os.makedirs("models", exist_ok=True)
|
||||
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
|
||||
torch.save(model.state_dict(), checkpoint_path)
|
||||
print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
|
||||
else:
|
||||
@@ -198,7 +165,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
|
||||
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user