Eval script

This commit is contained in:
alexiondev
2026-05-08 09:37:01 -04:00
parent 037f7131c2
commit 49a2502c76

51
src/models/evaluate.py Normal file
View File

@@ -0,0 +1,51 @@
import argparse
import os
import torch
from torch.utils.data import DataLoader
from src.models.regression_model import SpindaRegressionModel
from src.models.train import SpindaEvalDataset
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
    """Print exact-match and per-coordinate accuracy of ``model`` on ``ds``.

    Batches are drawn from ``ds`` in order, moved to ``device`` and passed
    through ``model``. The argmax over dim 2 of the model output is compared
    element-wise with the targets.

    The function only disables gradient tracking; it does not switch the
    model to eval mode, so the caller must do that beforehand.

    Args:
        model: Network called on each image batch.
        device: Device the image and target batches are moved to.
        name: Label printed in the report header.
        ds: Sized dataset yielding ``(imgs, targets)`` pairs.
            Presumably targets hold one class index per coordinate
            (8 per sample, matching the printed names) - confirm
            against the dataset definition.

    Returns:
        None. The report is written to stdout.
    """
    names = ["TL_x", "TL_y", "TR_x", "TR_y", "BL_x", "BL_y", "BR_x", "BR_y"]
    n = len(ds)
    header = f"\n── {name} ({n} samples) ──"

    if n == 0:
        # An empty dataset would make every accuracy ratio below divide by
        # zero; report it instead of crashing.
        print(header)
        print(" No samples found; skipping evaluation.")
        return

    loader = DataLoader(ds, batch_size=128, shuffle=False, num_workers=0)
    exact = 0
    # Accumulate per-coordinate hits on the device so each batch needs one
    # reduction instead of one host sync per coordinate.
    coord_correct = torch.zeros(len(names), dtype=torch.long, device=device)
    with torch.no_grad():
        for imgs, targets in loader:
            imgs, targets = imgs.to(device), targets.to(device)
            preds = model(imgs).argmax(dim=2)
            matches = preds == targets
            # A sample counts as exact only if every coordinate matches.
            exact += matches.all(dim=1).sum().item()
            coord_correct += matches.sum(dim=0)

    print(header)
    print(f" Exact match: {exact/n:.2%} ({exact}/{n})")
    for nm, correct in zip(names, coord_correct.tolist()):
        print(f" {nm}: {correct/n:.2%}")
if __name__ == "__main__":
    # Command-line entry point: load a checkpoint and report accuracy on the
    # validation set and, when present, the augmented test set.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        # SystemExit with a string writes it to stderr and exits with
        # status 1, keeping the error out of the stdout report.
        raise SystemExit(f"Error: model not found at {args.model_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Backbone: {args.backbone}")
    print(f"Model: {args.model_path}")

    model = SpindaRegressionModel(pretrained=False, backbone=args.backbone).to(device)
    # Only a state dict is needed, so restrict unpickling to tensors and
    # plain containers; the path comes from the command line and a crafted
    # checkpoint could otherwise execute arbitrary code.
    state_dict = torch.load(args.model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    evaluate(model, device, "Clean Val", SpindaEvalDataset("data/val"))
    # The augmented test set is optional; skip it when the directory is absent.
    if os.path.exists("data/aug_test"):
        evaluate(model, device, "Aug Test", SpindaEvalDataset("data/aug_test"))