New backbone
This commit is contained in:
@@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel
|
||||
class SpindaInference:
|
||||
"""Loads the trained model and predicts spot coordinates from an image crop."""
|
||||
|
||||
def __init__(self, model_path: str = "models/best_spinda_model.pth"):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
backbone: str = "resnet18",
|
||||
):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = SpindaRegressionModel(pretrained=False)
|
||||
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
||||
self.model.load_state_dict(
|
||||
torch.load(model_path, map_location=self.device)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user