Eval script
This commit is contained in:
51
src/models/evaluate.py
Normal file
51
src/models/evaluate.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
from src.models.train import SpindaEvalDataset
|
||||
|
||||
|
||||
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
|
||||
loader = DataLoader(ds, batch_size=128, shuffle=False, num_workers=0)
|
||||
exact, coord_correct = 0, [0] * 8
|
||||
with torch.no_grad():
|
||||
for imgs, targets in loader:
|
||||
imgs, targets = imgs.to(device), targets.to(device)
|
||||
preds = model(imgs).argmax(dim=2)
|
||||
exact += torch.all(preds == targets, dim=1).sum().item()
|
||||
for c in range(8):
|
||||
coord_correct[c] += (preds[:, c] == targets[:, c]).sum().item()
|
||||
n = len(ds)
|
||||
print(f"\n── {name} ({n} samples) ──")
|
||||
print(f" Exact match: {exact/n:.2%} ({exact}/{n})")
|
||||
names = ["TL_x", "TL_y", "TR_x", "TR_y", "BL_x", "BL_y", "BR_x", "BR_y"]
|
||||
for i, nm in enumerate(names):
|
||||
flag = " ◄" if i in (4, 5) else ""
|
||||
print(f" {nm}: {coord_correct[i]/n:.2%}{flag}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.model_path):
|
||||
print(f"Error: model not found at {args.model_path}")
|
||||
raise SystemExit(1)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
print(f"Backbone: {args.backbone}")
|
||||
print(f"Model: {args.model_path}")
|
||||
|
||||
model = SpindaRegressionModel(pretrained=False, backbone=args.backbone).to(device)
|
||||
model.load_state_dict(torch.load(args.model_path, map_location=device))
|
||||
model.eval()
|
||||
|
||||
evaluate(model, device, "Clean Val", SpindaEvalDataset("data/val"))
|
||||
if os.path.exists("data/aug_test"):
|
||||
evaluate(model, device, "Aug Test", SpindaEvalDataset("data/aug_test"))
|
||||
Reference in New Issue
Block a user