"""Report accuracy of a saved Spinda model checkpoint.

Module path: ``src/models/evaluate.py``.

Run as a script, this loads a ``SpindaRegressionModel`` checkpoint and prints
exact-match and per-coordinate accuracy for ``data/val`` and, when the
directory exists, ``data/aug_test``.
"""

import argparse
import os

import torch
from torch.utils.data import DataLoader

from src.models.regression_model import SpindaRegressionModel
from src.models.train import SpindaEvalDataset

# Report labels; entry ``i`` labels column ``i`` of the prediction/target tensors.
_COORD_NAMES = ("TL_x", "TL_y", "TR_x", "TR_y", "BL_x", "BL_y", "BR_x", "BR_y")
# Column indices highlighted with " ◄" in the report.
# NOTE(review): the reason these two columns are singled out is not visible here.
_FLAGGED_COORDS = frozenset({4, 5})


def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
    """Print exact-match and per-coordinate accuracy of ``model`` on ``ds``.

    The model output is reduced with ``argmax`` over dim 2, giving one
    predicted index per coordinate column, which is compared element-wise
    with the targets. A sample counts as an exact match only when every
    column agrees.

    This function does not switch the model to eval mode; the caller is
    expected to have called ``model.eval()`` beforehand.

    Args:
        model: Network to evaluate; called as ``model(imgs)``.
        device: Device the image and target batches are moved to.
        name: Label printed in the report header.
        ds: Dataset yielding ``(imgs, targets)`` pairs.

    Returns:
        None. Results are printed to stdout.
    """
    n = len(ds)
    header = f"\n── {name} ({n} samples) ──"
    if n == 0:
        # Nothing to average over; bail out instead of dividing by zero below.
        print(header)
        print(" No samples to evaluate.")
        return

    num_coords = len(_COORD_NAMES)
    loader = DataLoader(ds, batch_size=128, shuffle=False, num_workers=0)
    exact = 0
    coord_correct = [0] * num_coords
    with torch.no_grad():
        for imgs, targets in loader:
            imgs, targets = imgs.to(device), targets.to(device)
            preds = model(imgs).argmax(dim=2)
            matches = preds == targets
            exact += matches.all(dim=1).sum().item()
            # Count hits for all reported columns with a single reduction
            # instead of one ``.item()`` call per coordinate.
            batch_hits = matches[:, :num_coords].sum(dim=0).tolist()
            for c, hits in enumerate(batch_hits):
                coord_correct[c] += hits

    print(header)
    print(f" Exact match: {exact/n:.2%} ({exact}/{n})")
    for i, nm in enumerate(_COORD_NAMES):
        flag = " ◄" if i in _FLAGGED_COORDS else ""
        print(f" {nm}: {coord_correct[i]/n:.2%}{flag}")


def main() -> None:
    """Parse CLI arguments, load the checkpoint and print the reports.

    Raises:
        SystemExit: With exit code 1 when ``--model_path`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"Error: model not found at {args.model_path}")
        raise SystemExit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Backbone: {args.backbone}")
    print(f"Model: {args.model_path}")

    model = SpindaRegressionModel(pretrained=False, backbone=args.backbone).to(device)
    # NOTE(review): depending on the installed torch version, ``torch.load``
    # may unpickle arbitrary objects. Only load checkpoints from trusted
    # sources, or pass ``weights_only=True`` once the minimum supported torch
    # version allows it.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    evaluate(model, device, "Clean Val", SpindaEvalDataset("data/val"))
    # The augmented test set is optional; skip it when the directory is absent.
    if os.path.exists("data/aug_test"):
        evaluate(model, device, "Aug Test", SpindaEvalDataset("data/aug_test"))


if __name__ == "__main__":
    main()