Refactor/cleanup
This commit is contained in:
@@ -2,8 +2,7 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2 as T
|
||||
from PIL import Image
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from src.models.regression_model import SpindaRegressionModel
|
||||
|
||||
@@ -13,8 +12,8 @@ class SpindaInference:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
backbone: str = "resnet18",
|
||||
model_path: str = "models/best_resnet34_model.pth",
|
||||
backbone: str = "resnet34",
|
||||
):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
||||
@@ -29,15 +28,23 @@ class SpindaInference:
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def predict(self, image_path: str) -> Tuple[List[int], str]:
|
||||
def predict(self, image: Union[str, np.ndarray]) -> Tuple[List[int], str]:
|
||||
"""Predict the 8 grid coordinates and return them with a fingerprint string.
|
||||
|
||||
Args:
|
||||
image: file path (str) or a BGR numpy array (e.g. from the detector).
|
||||
Returns:
|
||||
grid_coords: list of 8 integers in [0, 15]
|
||||
fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
||||
"""
|
||||
img_bgr = cv2.imread(image_path)
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
if isinstance(image, str):
|
||||
img_bgr = cv2.imread(image)
|
||||
if img_bgr is None:
|
||||
raise FileNotFoundError(f"Image not found: {image}")
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
img_rgb = image[:, :, ::-1].copy()
|
||||
|
||||
img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1)
|
||||
img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user