New backbone

This commit is contained in:
alexiondev
2026-05-08 16:42:35 -04:00
parent 49a2502c76
commit 799aa9fa3d
9 changed files with 69 additions and 22 deletions

View File

@@ -8,7 +8,11 @@ from src.registry.database import SpindaRegistry
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
from src.utils.detector import SpindaDetector # Import the detector
def identify_spinda(image_path: str):
def identify_spinda(
image_path: str,
model_path: str = "models/best_spinda_model.pth",
backbone: str = "resnet18",
):
if not os.path.exists(image_path):
print(f"Error: File {image_path} not found.")
return
@@ -33,7 +37,7 @@ def identify_spinda(image_path: str):
# 2. Inference (Model Prediction) using the cropped image
try:
inf = SpindaInference(model_path="models/best_spinda_model.pth")
inf = SpindaInference(model_path=model_path, backbone=backbone)
coords, fingerprint = inf.predict(temp_cropped_path)
except Exception as e:
print(f"Error during inference: {e}")
@@ -71,7 +75,10 @@ def identify_spinda(image_path: str):
print("\nNote: Accuracy depends on model training progress.")
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python identify.py <image_path>")
else:
identify_spinda(sys.argv[1])
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("image_path")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
args = parser.parse_args()
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)