diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 0000000..d8cd204 --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"ce21bc15-a2ef-47a3-ac8b-8dad11b440dd","pid":9390,"procStart":"45841","acquiredAt":1778247797938} \ No newline at end of file diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..8ddeffa --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,24 @@ +{ + "permissions": { + "allow": [ + "Bash(python3 -c ' *)", + "Bash(python3 -)", + "Bash(python3 -m src.utils.detector test_image.jpg)", + "Bash(python3 -c \"import sys; print\\(sys.executable\\)\")", + "Bash(pip show *)", + "Bash(.venv/bin/python -m src.utils.detector test_image.jpg)", + "Bash(.venv/bin/python -)", + "Bash(cat)", + "Bash(.venv/bin/python /tmp/eval_model.py)", + "Bash(.venv/bin/python /tmp/eval_final.py)", + "Bash(.venv/bin/python identify.py test_image.jpg)", + "Bash(.venv/bin/python -m src.data.generate_aug_test_set)", + "Bash(.venv/bin/python -c ' *)", + "Bash(.venv/bin/python -m src.models.train --epochs 5)", + "Bash(.venv/bin/python -m src.models.evaluate)", + "Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)", + "Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)", + "Bash(.venv/bin/python src/models/regression_model.py)" + ] + } +} diff --git a/CLAUDE.md b/CLAUDE.md index 46a5d9a..724f881 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,6 +13,9 @@ All commands must be run from the project root using the local venv: # Train the model .venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4 +# Evaluate a trained model on val and aug_test sets +.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34|convnext_tiny] [--model_path PATH] + # Run inference only (no registry lookup) .venv/bin/python src/models/inference.py 
diff --git a/identify.py b/identify.py index 8a64cee..6343279 100644 --- a/identify.py +++ b/identify.py @@ -8,7 +8,11 @@ from src.registry.database import SpindaRegistry from src.data.high_fidelity_generator import generate_high_fidelity_spinda from src.utils.detector import SpindaDetector # Import the detector -def identify_spinda(image_path: str): +def identify_spinda( + image_path: str, + model_path: str = "models/best_spinda_model.pth", + backbone: str = "resnet18", +): if not os.path.exists(image_path): print(f"Error: File {image_path} not found.") return @@ -33,7 +37,7 @@ def identify_spinda(image_path: str): # 2. Inference (Model Prediction) using the cropped image try: - inf = SpindaInference(model_path="models/best_spinda_model.pth") + inf = SpindaInference(model_path=model_path, backbone=backbone) coords, fingerprint = inf.predict(temp_cropped_path) except Exception as e: print(f"Error during inference: {e}") @@ -71,7 +75,10 @@ def identify_spinda(image_path: str): print("\nNote: Accuracy depends on model training progress.") if __name__ == "__main__": - if len(sys.argv) < 2: - print("Usage: python identify.py ") - else: - identify_spinda(sys.argv[1]) + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("image_path") + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) + parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") + args = parser.parse_args() + identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone) diff --git a/prediction_verify.png b/prediction_verify.png index 9d72984..ebef9af 100644 Binary files a/prediction_verify.png and b/prediction_verify.png differ diff --git a/src/models/evaluate.py b/src/models/evaluate.py index 469bbc7..073f430 100644 --- a/src/models/evaluate.py +++ b/src/models/evaluate.py @@ -30,7 +30,7 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: 
Spinda if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) args = parser.parse_args() if not os.path.exists(args.model_path): diff --git a/src/models/inference.py b/src/models/inference.py index f63c53c..b808109 100644 --- a/src/models/inference.py +++ b/src/models/inference.py @@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel class SpindaInference: """Loads the trained model and predicts spot coordinates from an image crop.""" - def __init__(self, model_path: str = "models/best_spinda_model.pth"): + def __init__( + self, + model_path: str = "models/best_spinda_model.pth", + backbone: str = "resnet18", + ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = SpindaRegressionModel(pretrained=False) + self.model = SpindaRegressionModel(pretrained=False, backbone=backbone) self.model.load_state_dict( torch.load(model_path, map_location=self.device) ) diff --git a/src/models/regression_model.py b/src/models/regression_model.py index c52197e..473959f 100644 --- a/src/models/regression_model.py +++ b/src/models/regression_model.py @@ -1,16 +1,18 @@ import torch import torch.nn as nn from torchvision import models -from torchvision.models import ResNet18_Weights, ResNet34_Weights +from torchvision.models import ResNet18_Weights, ResNet34_Weights, ConvNeXt_Tiny_Weights +# (factory, default_weights, feature_dim) _BACKBONES = { - "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT), - "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT), + "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT, 512), + "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT, 512), + "convnext_tiny": 
(models.convnext_tiny, ConvNeXt_Tiny_Weights.DEFAULT, 768), } class SpindaRegressionModel(nn.Module): - """ResNet backbone with 8 independent 16-class coordinate heads. + """CNN backbone with 8 independent 16-class coordinate heads. Each of the 8 output coordinates (4 spots × x, y) is treated as a 16-class classification problem over the [0, 15] nibble grid. @@ -25,23 +27,29 @@ class SpindaRegressionModel(nn.Module): super().__init__() if backbone not in _BACKBONES: raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}") - factory, default_weights = _BACKBONES[backbone] + factory, default_weights, feat_dim = _BACKBONES[backbone] weights = default_weights if pretrained else None net = factory(weights=weights) - # Strip the final FC; keep the feature extractor + average pool. - self.features = nn.Sequential(*list(net.children())[:-1]) - # 8 coordinates × 16 classes each (512-dim output for both resnet18/34) - self.classifier = nn.Linear(512, 8 * 16) + + if backbone in ("resnet18", "resnet34"): + # Strip the final FC; flatten the (B, 512, 1, 1) avgpool output. + self.features = nn.Sequential(*list(net.children())[:-1], nn.Flatten()) + elif backbone == "convnext_tiny": + # Keep features + avgpool + LayerNorm (classifier[0]); drop the final Linear. 
+ self.features = nn.Sequential( + net.features, net.avgpool, net.classifier[0], nn.Flatten() + ) + + self.classifier = nn.Linear(feat_dim, 8 * 16) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.features(x) # (B, 512, 1, 1) - x = x.flatten(1) # (B, 512) + x = self.features(x) # (B, feat_dim) x = self.classifier(x) # (B, 128) return x.view(-1, 8, 16) # (B, 8, 16) if __name__ == "__main__": - for name in ("resnet18", "resnet34"): + for name in ("resnet18", "resnet34", "convnext_tiny"): model = SpindaRegressionModel(pretrained=False, backbone=name) out = model(torch.randn(2, 3, 128, 128)) print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}") diff --git a/src/models/train.py b/src/models/train.py index 09b3e4c..0049a59 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -198,7 +198,7 @@ if __name__ == "__main__": parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)") parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path") args = parser.parse_args()