New backbone

This commit is contained in:
alexiondev
2026-05-08 16:42:35 -04:00
parent 49a2502c76
commit 799aa9fa3d
9 changed files with 69 additions and 22 deletions

View File

@@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel
class SpindaInference:
"""Loads the trained model and predicts spot coordinates from an image crop."""
def __init__(self, model_path: str = "models/best_spinda_model.pth"):
def __init__(
self,
model_path: str = "models/best_spinda_model.pth",
backbone: str = "resnet18",
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = SpindaRegressionModel(pretrained=False)
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
self.model.load_state_dict(
torch.load(model_path, map_location=self.device)
)