New backbone
This commit is contained in:
19
identify.py
19
identify.py
@@ -8,7 +8,11 @@ from src.registry.database import SpindaRegistry
|
||||
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
||||
from src.utils.detector import SpindaDetector # Import the detector
|
||||
|
||||
def identify_spinda(image_path: str):
|
||||
def identify_spinda(
|
||||
image_path: str,
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
backbone: str = "resnet18",
|
||||
):
|
||||
if not os.path.exists(image_path):
|
||||
print(f"Error: File {image_path} not found.")
|
||||
return
|
||||
@@ -33,7 +37,7 @@ def identify_spinda(image_path: str):
|
||||
|
||||
# 2. Inference (Model Prediction) using the cropped image
|
||||
try:
|
||||
inf = SpindaInference(model_path="models/best_spinda_model.pth")
|
||||
inf = SpindaInference(model_path=model_path, backbone=backbone)
|
||||
coords, fingerprint = inf.predict(temp_cropped_path)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
@@ -71,7 +75,10 @@ def identify_spinda(image_path: str):
|
||||
print("\nNote: Accuracy depends on model training progress.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python identify.py <image_path>")
|
||||
else:
|
||||
identify_spinda(sys.argv[1])
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("image_path")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
args = parser.parse_args()
|
||||
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)
|
||||
|
||||
Reference in New Issue
Block a user