70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
import os
from typing import List, Tuple, Union

import cv2
import numpy as np
import torch
import torchvision.transforms.v2 as T

from src.models.regression_model import SpindaRegressionModel
|
|
|
|
|
|
class SpindaInference:
    """Loads the trained model and predicts spot coordinates from an image crop."""

    def __init__(
        self,
        model_path: str = "models/best_resnet34_model.pth",
        backbone: str = "resnet34",
    ):
        """Build the model, load its weights and set up preprocessing.

        Args:
            model_path: path to a saved ``state_dict`` for
                ``SpindaRegressionModel``.
            backbone: backbone name forwarded to ``SpindaRegressionModel``;
                it must match the architecture the weights were saved from.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # pretrained=False: the weights are loaded from model_path just below.
        self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device).eval()

        # Applied to a CHW RGB tensor. NOTE(review): these steps must mirror
        # the training-time preprocessing, which is not visible in this file.
        self.transform = T.Compose([
            T.Resize((128, 128)),
            T.ToDtype(torch.float32, scale=True),  # integer input rescaled to [0, 1]
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_rgb(image: Union[str, os.PathLike, np.ndarray]) -> np.ndarray:
        """Return *image* as a new C-contiguous ``(H, W, 3)`` RGB array.

        Raises:
            FileNotFoundError: if a path cannot be read by OpenCV.
            TypeError: if *image* is neither a path nor a numpy array.
            ValueError: if an array is empty or has an unsupported shape.
        """
        if isinstance(image, (str, os.PathLike)):
            img_bgr = cv2.imread(os.fspath(image))
            # cv2.imread signals failure by returning None rather than raising.
            if img_bgr is None:
                raise FileNotFoundError(f"Image not found: {image}")
            return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        if not isinstance(image, np.ndarray):
            raise TypeError(
                f"Expected a file path or a numpy array, got {type(image).__name__}"
            )
        if image.size == 0:
            raise ValueError(f"Empty image array with shape {image.shape}")

        img = image
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        if img.ndim == 2:
            # Grayscale: replicate the single channel into R, G and B.
            return np.stack((img, img, img), axis=-1)
        if img.ndim == 3 and img.shape[2] in (3, 4):
            # BGR(A) -> RGB: channels 2, 1, 0 (alpha dropped if present).
            # .copy() yields a C-contiguous array, which torch.from_numpy needs
            # (it rejects negative strides), and keeps the caller's data private.
            return img[:, :, 2::-1].copy()
        raise ValueError(
            "Expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) array, "
            f"got shape {img.shape}"
        )

    def predict(
        self, image: Union[str, os.PathLike, np.ndarray]
    ) -> Tuple[List[int], str]:
        """Predict the 8 grid coordinates and return them with a fingerprint string.

        Args:
            image: file path (``str`` or path-like) or a BGR numpy array
                (e.g. from the detector). Arrays may also be grayscale
                ``(H, W)`` / ``(H, W, 1)`` or 4-channel BGRA.

        Returns:
            grid_coords: list of 8 integers in [0, 15]
            fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4", each value zero-padded
                to two digits.

        Raises:
            FileNotFoundError: if ``image`` is a path OpenCV cannot read.
            TypeError: if ``image`` is neither a path nor a numpy array.
            ValueError: if an array input is empty or has an unsupported shape.
        """
        img_rgb = self._to_rgb(image)

        # HWC -> CHW, then preprocess and add the batch dimension.
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            logits = self.model(img_tensor)  # (1, 8, 16)

        # For each of the 8 outputs, pick the most likely of the 16 bins.
        grid_coords = logits.argmax(dim=2).squeeze(0).cpu().tolist()  # [8]
        fingerprint = "-".join(f"{c:02d}" for c in grid_coords)
        return grid_coords, fingerprint
|
|
|
|
|
|
def main(argv: Union[List[str], None] = None) -> int:
    """Command-line entry point: print the predicted grid and fingerprint for one image.

    Args:
        argv: arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Only the first argument (the image path) is used.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Usage errors go to stderr with a non-zero status so shell callers
        # and CI scripts can detect the misuse.
        print("Usage: python src/models/inference.py <image_path>", file=sys.stderr)
        return 2

    inf = SpindaInference()
    coords, fingerprint = inf.predict(args[0])
    print(f"Predicted Grid Coordinates: {coords}")
    print(f"Visual Fingerprint: {fingerprint}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|