import os
from typing import List, Tuple, Union

import cv2
import numpy as np
import torch
import torchvision.transforms.v2 as T

from src.models.regression_model import SpindaRegressionModel


class SpindaInference:
    """Loads the trained model and predicts spot coordinates from an image crop."""

    def __init__(
        self,
        model_path: str = "models/best_resnet34_model.pth",
        backbone: str = "resnet34",
    ):
        """Build the model, load its weights and prepare the preprocessing pipeline.

        Args:
            model_path: path to the checkpoint passed to ``torch.load``; its
                content is fed directly to ``load_state_dict``.
            backbone: backbone name forwarded to ``SpindaRegressionModel``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # pretrained=False: the weights come from the checkpoint loaded below.
        self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)

        # torch.load unpickles the file; weights_only=True restricts the
        # unpickler to tensors and plain containers so a tampered checkpoint
        # cannot execute arbitrary code. A state dict needs nothing more.
        state_dict = torch.load(
            model_path, map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.to(self.device).eval()

        # Resize to 128x128, convert to float32 with value scaling (uint8
        # input maps to [0, 1]), then normalise per channel.
        self.transform = T.Compose([
            T.Resize((128, 128)),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_rgb(image: Union[str, os.PathLike, np.ndarray]) -> np.ndarray:
        """Return ``image`` as a contiguous (H, W, 3) array in RGB channel order."""
        if isinstance(image, (str, os.PathLike)):
            # cv2.imread signals failure by returning None instead of raising.
            img_bgr = cv2.imread(os.fspath(image))
            if img_bgr is None:
                raise FileNotFoundError(f"Image not found: {image}")
        elif isinstance(image, np.ndarray):
            img_bgr = image
        else:
            raise TypeError(
                "image must be a file path or a numpy array, "
                f"got {type(image).__name__}"
            )

        if img_bgr.ndim != 3 or img_bgr.shape[2] != 3 or img_bgr.size == 0:
            raise ValueError(
                "Expected a non-empty BGR image of shape (H, W, 3), "
                f"got shape {img_bgr.shape}"
            )

        # Reversing the channel axis turns BGR into RGB. The reversed view has
        # a negative stride, which torch.from_numpy rejects, so materialise a
        # contiguous copy (this also leaves the caller's array untouched).
        return np.ascontiguousarray(img_bgr[:, :, ::-1])

    def predict(
        self, image: Union[str, os.PathLike, np.ndarray]
    ) -> Tuple[List[int], str]:
        """Predict the 8 grid coordinates and return them with a fingerprint string.

        Args:
            image: file path (str or path-like) or a BGR numpy array of shape
                (H, W, 3) (e.g. from the detector).
                NOTE(review): arrays are presumably uint8; float arrays would
                not be rescaled to [0, 1] by the transform - confirm callers.

        Returns:
            grid_coords: list of 8 integers in [0, 15]
            fingerprint: the coordinates as zero-padded two-digit numbers
                joined by "-", in the order "X1-Y1-X2-Y2-X3-Y3-X4-Y4"

        Raises:
            FileNotFoundError: if a path is given and OpenCV cannot read it.
            TypeError: if ``image`` is neither a path nor a numpy array.
            ValueError: if the image is empty or not of shape (H, W, 3).
        """
        img_rgb = self._load_rgb(image)

        # HWC -> CHW, as expected by the torchvision transforms.
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(img_tensor)  # (1, 8, 16)

        grid_coords = logits.argmax(dim=2).squeeze(0).cpu().tolist()  # [8]
        fingerprint = "-".join(f"{c:02d}" for c in grid_coords)

        return grid_coords, fingerprint


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python src/models/inference.py <image_path>")
    else:
        inf = SpindaInference()
        coords, fingerprint = inf.predict(sys.argv[1])
        print(f"Predicted Grid Coordinates: {coords}")
        print(f"Visual Fingerprint: {fingerprint}")