Refactor/cleanup

This commit is contained in:
alexiondev
2026-05-08 17:18:58 -04:00
parent 799aa9fa3d
commit 1b904e04ea
18 changed files with 214 additions and 357 deletions

View File

@@ -1,70 +1,59 @@
import argparse
import os
import sys
import cv2
import torch
from src.data.renderer import generate_high_fidelity_spinda
from src.models.inference import SpindaInference
from src.utils.resolver import SpindaResolver
from src.registry.database import SpindaRegistry
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
from src.utils.detector import SpindaDetector # Import the detector
from src.utils.detector import SpindaDetector
from src.utils.resolver import SpindaResolver
def identify_spinda(
image_path: str,
model_path: str = "models/best_spinda_model.pth",
backbone: str = "resnet18",
):
model_path: str = "models/best_resnet34_model.pth",
backbone: str = "resnet34",
) -> None:
if not os.path.exists(image_path):
print(f"Error: File {image_path} not found.")
return
print(f"--- Identifying Spinda in {image_path} ---")
# 1. Detect and Crop Spinda
# 1. Detect and crop
detector = SpindaDetector()
cropped_img = detector.detect_and_crop(image_path)
if cropped_img is None:
print("Error: Could not detect Spinda in the image.")
return
# Save cropped image for debug/visual check
cv2.imwrite("detected_spinda_crop.png", cropped_img)
print("Detected Spinda saved to detected_spinda_crop.png")
# We need to save the cropped image to a temporary file for the inference model to read
temp_cropped_path = "temp_cropped_spinda.png"
cv2.imwrite(temp_cropped_path, cropped_img)
# 2. Inference (Model Prediction) using the cropped image
try:
inf = SpindaInference(model_path=model_path, backbone=backbone)
coords, fingerprint = inf.predict(temp_cropped_path)
except Exception as e:
print(f"Error during inference: {e}")
os.remove(temp_cropped_path) # Clean up temp file
return
finally:
os.remove(temp_cropped_path) # Clean up temp file
# 2. Inference — pass the BGR array directly, no temp file needed
inf = SpindaInference(model_path=model_path, backbone=backbone)
coords, fingerprint = inf.predict(cropped_img)
print(f"Visual Fingerprint: {fingerprint}")
print(f"Predicted Grid Coordinates: {coords}")
# 3. Resolution (Mathematical PIDs)
# 3. Resolve to PIDs
resolved = SpindaResolver.resolve_fingerprint(fingerprint)
print("\nPossible PIDs:")
print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}")
print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}")
# 4. Visual Verification
# 4. Visual verification
print("\nGenerating visual verification image...")
verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16))
cv2.imwrite("prediction_verify.png", verify_img)
print("Verification image saved to: prediction_verify.png")
# 5. Registry Lookup
# 5. Registry lookup
reg = SpindaRegistry()
matches = reg.lookup_by_fingerprint(fingerprint)
if matches:
print("\nMatches found in Global Registry:")
for pid in matches:
@@ -72,13 +61,12 @@ def identify_spinda(
else:
print("\nNo matching entries in Global Registry.")
print("\nNote: Accuracy depends on model training progress.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("image_path")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
parser.add_argument("--backbone", type=str, default="resnet34",
choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
args = parser.parse_args()
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)