New backbone: add convnext_tiny alongside resnet18/resnet34

This commit is contained in:
alexiondev
2026-05-08 16:42:35 -04:00
parent 49a2502c76
commit 799aa9fa3d
9 changed files with 69 additions and 22 deletions

View File

@@ -0,0 +1 @@
{"sessionId":"ce21bc15-a2ef-47a3-ac8b-8dad11b440dd","pid":9390,"procStart":"45841","acquiredAt":1778247797938}

View File

@@ -0,0 +1,24 @@
{
"permissions": {
"allow": [
"Bash(python3 -c ' *)",
"Bash(python3 -)",
"Bash(python3 -m src.utils.detector test_image.jpg)",
"Bash(python3 -c \"import sys; print\\(sys.executable\\)\")",
"Bash(pip show *)",
"Bash(.venv/bin/python -m src.utils.detector test_image.jpg)",
"Bash(.venv/bin/python -)",
"Bash(cat)",
"Bash(.venv/bin/python /tmp/eval_model.py)",
"Bash(.venv/bin/python /tmp/eval_final.py)",
"Bash(.venv/bin/python identify.py test_image.jpg)",
"Bash(.venv/bin/python -m src.data.generate_aug_test_set)",
"Bash(.venv/bin/python -c ' *)",
"Bash(.venv/bin/python -m src.models.train --epochs 5)",
"Bash(.venv/bin/python -m src.models.evaluate)",
"Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)",
"Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)",
"Bash(.venv/bin/python src/models/regression_model.py)"
]
}
}

View File

@@ -13,6 +13,9 @@ All commands must be run from the project root using the local venv:
# Train the model # Train the model
.venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4 .venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4
# Evaluate a trained model on val and aug_test sets
.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34|convnext_tiny] [--model_path <path>]
# Run inference only (no registry lookup) # Run inference only (no registry lookup)
.venv/bin/python src/models/inference.py <image_path> .venv/bin/python src/models/inference.py <image_path>

View File

@@ -8,7 +8,11 @@ from src.registry.database import SpindaRegistry
from src.data.high_fidelity_generator import generate_high_fidelity_spinda from src.data.high_fidelity_generator import generate_high_fidelity_spinda
from src.utils.detector import SpindaDetector # Import the detector from src.utils.detector import SpindaDetector # Import the detector
def identify_spinda(image_path: str): def identify_spinda(
image_path: str,
model_path: str = "models/best_spinda_model.pth",
backbone: str = "resnet18",
):
if not os.path.exists(image_path): if not os.path.exists(image_path):
print(f"Error: File {image_path} not found.") print(f"Error: File {image_path} not found.")
return return
@@ -33,7 +37,7 @@ def identify_spinda(image_path: str):
# 2. Inference (Model Prediction) using the cropped image # 2. Inference (Model Prediction) using the cropped image
try: try:
inf = SpindaInference(model_path="models/best_spinda_model.pth") inf = SpindaInference(model_path=model_path, backbone=backbone)
coords, fingerprint = inf.predict(temp_cropped_path) coords, fingerprint = inf.predict(temp_cropped_path)
except Exception as e: except Exception as e:
print(f"Error during inference: {e}") print(f"Error during inference: {e}")
@@ -71,7 +75,10 @@ def identify_spinda(image_path: str):
print("\nNote: Accuracy depends on model training progress.") print("\nNote: Accuracy depends on model training progress.")
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) < 2: import argparse
print("Usage: python identify.py <image_path>") parser = argparse.ArgumentParser()
else: parser.add_argument("image_path")
identify_spinda(sys.argv[1]) parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
args = parser.parse_args()
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -30,7 +30,7 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.model_path): if not os.path.exists(args.model_path):

View File

@@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel
class SpindaInference: class SpindaInference:
"""Loads the trained model and predicts spot coordinates from an image crop.""" """Loads the trained model and predicts spot coordinates from an image crop."""
def __init__(self, model_path: str = "models/best_spinda_model.pth"): def __init__(
self,
model_path: str = "models/best_spinda_model.pth",
backbone: str = "resnet18",
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = SpindaRegressionModel(pretrained=False) self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
self.model.load_state_dict( self.model.load_state_dict(
torch.load(model_path, map_location=self.device) torch.load(model_path, map_location=self.device)
) )

View File

@@ -1,16 +1,18 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchvision import models from torchvision import models
from torchvision.models import ResNet18_Weights, ResNet34_Weights from torchvision.models import ResNet18_Weights, ResNet34_Weights, ConvNeXt_Tiny_Weights
# (factory, default_weights, feature_dim)
_BACKBONES = { _BACKBONES = {
"resnet18": (models.resnet18, ResNet18_Weights.DEFAULT), "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT, 512),
"resnet34": (models.resnet34, ResNet34_Weights.DEFAULT), "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT, 512),
"convnext_tiny": (models.convnext_tiny, ConvNeXt_Tiny_Weights.DEFAULT, 768),
} }
class SpindaRegressionModel(nn.Module): class SpindaRegressionModel(nn.Module):
"""ResNet backbone with 8 independent 16-class coordinate heads. """CNN backbone with 8 independent 16-class coordinate heads.
Each of the 8 output coordinates (4 spots × x, y) is treated as a Each of the 8 output coordinates (4 spots × x, y) is treated as a
16-class classification problem over the [0, 15] nibble grid. 16-class classification problem over the [0, 15] nibble grid.
@@ -25,23 +27,29 @@ class SpindaRegressionModel(nn.Module):
super().__init__() super().__init__()
if backbone not in _BACKBONES: if backbone not in _BACKBONES:
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}") raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
factory, default_weights = _BACKBONES[backbone] factory, default_weights, feat_dim = _BACKBONES[backbone]
weights = default_weights if pretrained else None weights = default_weights if pretrained else None
net = factory(weights=weights) net = factory(weights=weights)
# Strip the final FC; keep the feature extractor + average pool.
self.features = nn.Sequential(*list(net.children())[:-1]) if backbone in ("resnet18", "resnet34"):
# 8 coordinates × 16 classes each (512-dim output for both resnet18/34) # Strip the final FC; flatten the (B, 512, 1, 1) avgpool output.
self.classifier = nn.Linear(512, 8 * 16) self.features = nn.Sequential(*list(net.children())[:-1], nn.Flatten())
elif backbone == "convnext_tiny":
# Keep features + avgpool + LayerNorm (classifier[0]); drop the final Linear.
self.features = nn.Sequential(
net.features, net.avgpool, net.classifier[0], nn.Flatten()
)
self.classifier = nn.Linear(feat_dim, 8 * 16)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x) # (B, 512, 1, 1) x = self.features(x) # (B, feat_dim)
x = x.flatten(1) # (B, 512)
x = self.classifier(x) # (B, 128) x = self.classifier(x) # (B, 128)
return x.view(-1, 8, 16) # (B, 8, 16) return x.view(-1, 8, 16) # (B, 8, 16)
if __name__ == "__main__": if __name__ == "__main__":
for name in ("resnet18", "resnet34"): for name in ("resnet18", "resnet34", "convnext_tiny"):
model = SpindaRegressionModel(pretrained=False, backbone=name) model = SpindaRegressionModel(pretrained=False, backbone=name)
out = model(torch.randn(2, 3, 128, 128)) out = model(torch.randn(2, 3, 128, 128))
print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}") print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}")

View File

@@ -198,7 +198,7 @@ if __name__ == "__main__":
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch") parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path") parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
args = parser.parse_args() args = parser.parse_args()