From 799aa9fa3d079c7fe4ca27fc663eeea2078c339f Mon Sep 17 00:00:00 2001 From: alexiondev <1363939+alexiondev@users.noreply.github.com> Date: Fri, 8 May 2026 16:42:35 -0400 Subject: [PATCH] New backbone --- .claude/scheduled_tasks.lock | 1 + .claude/settings.local.json | 24 ++++++++++++++++++++++++ CLAUDE.md | 3 +++ identify.py | 19 +++++++++++++------ prediction_verify.png | Bin 2122 -> 2126 bytes src/models/evaluate.py | 2 +- src/models/inference.py | 8 ++++++-- src/models/regression_model.py | 32 ++++++++++++++++++++------------ src/models/train.py | 2 +- 9 files changed, 69 insertions(+), 22 deletions(-) create mode 100644 .claude/scheduled_tasks.lock create mode 100644 .claude/settings.local.json diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 0000000..d8cd204 --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"ce21bc15-a2ef-47a3-ac8b-8dad11b440dd","pid":9390,"procStart":"45841","acquiredAt":1778247797938} \ No newline at end of file diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..8ddeffa --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,24 @@ +{ + "permissions": { + "allow": [ + "Bash(python3 -c ' *)", + "Bash(python3 -)", + "Bash(python3 -m src.utils.detector test_image.jpg)", + "Bash(python3 -c \"import sys; print\\(sys.executable\\)\")", + "Bash(pip show *)", + "Bash(.venv/bin/python -m src.utils.detector test_image.jpg)", + "Bash(.venv/bin/python -)", + "Bash(cat)", + "Bash(.venv/bin/python /tmp/eval_model.py)", + "Bash(.venv/bin/python /tmp/eval_final.py)", + "Bash(.venv/bin/python identify.py test_image.jpg)", + "Bash(.venv/bin/python -m src.data.generate_aug_test_set)", + "Bash(.venv/bin/python -c ' *)", + "Bash(.venv/bin/python -m src.models.train --epochs 5)", + "Bash(.venv/bin/python -m src.models.evaluate)", + "Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)", + "Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)", + "Bash(.venv/bin/python src/models/regression_model.py)" + ] + } +} diff --git a/CLAUDE.md b/CLAUDE.md index 46a5d9a..724f881 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,6 +13,9 @@ All commands must be run from the project root using the local venv: # Train the model .venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4 +# Evaluate a trained model on val and aug_test sets +.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34] [--model_path ] + # Run inference only (no registry lookup) .venv/bin/python src/models/inference.py diff --git a/identify.py b/identify.py index 8a64cee..6343279 100644 --- a/identify.py +++ b/identify.py @@ -8,7 +8,11 @@ from src.registry.database import SpindaRegistry from src.data.high_fidelity_generator import generate_high_fidelity_spinda from src.utils.detector import SpindaDetector # Import the detector -def identify_spinda(image_path: str): +def identify_spinda( + image_path: str, + model_path: str = "models/best_spinda_model.pth", + backbone: str = "resnet18", +): if not os.path.exists(image_path): print(f"Error: File {image_path} not found.") return @@ -33,7 +37,7 @@ def identify_spinda(image_path: str): # 2. Inference (Model Prediction) using the cropped image try: - inf = SpindaInference(model_path="models/best_spinda_model.pth") + inf = SpindaInference(model_path=model_path, backbone=backbone) coords, fingerprint = inf.predict(temp_cropped_path) except Exception as e: print(f"Error during inference: {e}") @@ -71,7 +75,10 @@ def identify_spinda(image_path: str): print("\nNote: Accuracy depends on model training progress.") if __name__ == "__main__": - if len(sys.argv) < 2: - print("Usage: python identify.py ") - else: - identify_spinda(sys.argv[1]) + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("image_path") + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) + parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") + args = parser.parse_args() + identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone) diff --git a/prediction_verify.png b/prediction_verify.png index 9d72984bc05a22f5d1daba7a741fbc6ca768ef4a..ebef9af14ab3064f5f164e6d2e4495b3f8a0c46b 100644 GIT binary patch literal 2126 zcmds3`#V&57(e4SsNo@{NQ6zTB#-+wX==<=uB9>}QjC~2W;I02;}VKkMM5rlLTTF2 zxNEdXiD}AhnL)$4CXLMJRH4hAgnkbz|4WuAgSq4sqoC%)^lJi799Ufa+l+N3M*EHHTLR*$k4aRhUYI1px$)1K6I zJeNagZ|9DjIoNE9RD;u){J>=1-qY=u=dw0u{;pSq+Ryz;r{P03UpYSaXdkszf~kVo zd(x52JZ;43HMiawqdD|k$N>>Z4)O#Ub?a7f_B571r7ALDa@*HC5tK(P8ETFz;ijuS z%|QH9wrI-8Xgx*@00T#2$7o#;f4+M>nz*T<{cY^R84nSbv$lcoOv@Qpe5US6ZDobN zS9d{x&p)6KKk%L@*sq?)+!@Qc9#l{I7`E52>d3c6x&6%G1Xd1gvQ zO~yuE?;K1BG2e-EK0#t4-UQ7J49j(=j~;i1(om2>A3^tXQ0=+Q3Ly&S5g(m`_P*@n zBm0XoHAjj15hx@>-*3#j&Zvz-o@dvHUkD()6mm4=co2<`cPu`XsoZLPsqsO*DEf%4 zx>b33d0ZvR+CD(Q(y7dfy<9vkBe+){d%4_HA4C#LPNY=vIF5vyx*&TX9^qXb%9BIw zPT0Fps|}@>i$mv#&`tk*7PGC*)S@jUTsjXlM8%@GzPK?!+@UwnhX?7&rN;j`H8mB8 zcf3-W`>G(Rr6?rhd&`>8w(4GMGp#`tEPo?JTU&M`d>~^eyTo5IkF^;t$5LHGQ?4X} za>*3$Ym`;o4$f#Kf-_FN4b=rU`q^(mqcp5+>+e38t|Plv7YG)D7bC(yg^vA=^alHMv2dzF7M4V7&!!t!_D z|3b9EG)`ZrE!I{Y$E#m^Jrl?PmV>RDR%l88s37|p2x`*d-pQ@-!lh>kb~(zY#@jP$ zoqUUcV_b~(zUXk){^qk*iRLYlNZ@5e!`~&ki1Ei;Gd*!aJ(@5;kF^;b>?T4V0EGDo z9oDR%6X8j_0S^tt`q1>Gr|*^-$K}MoHP24Ko&uKX5YE7*XdCmTKawWRGS8Nf$U$6Y ze=@BLz+96S@v1<3XJI#zQ_%2g0aI(NrP%FpUCP=X1q$gPx*_ul66i|B6TV)wV=FY- zIQGqbaP2TT9M3KEiJ@^(LBi?ZxsF6P%t_3a&Zk`v3k*m8jzXa^RDm%9h)L zU)uB)rKj_{o!OSZ0+$K;t8VLtUb`~CGPtU+!2?zW@X);;RAguzx8fclE*m^-RtTUo z9=xxMUz+EZPj6I;1M2p1&eqSrRe@+dgJ@bAf?AP$wht~XN@wDpd53_pI@2R|O z>Cws8#iCn z0VTM{KG}J7QHCF^rYZM`{%HA2GmrZ++0H>(MK!V_B-E3HfX7>dBKAqSa_}e(bj%e< mgxFu11uy=L|IjCo$$$*)y(|TazAfNC0m2@&LqD^*5cLnAVe|q3 literal 2122 zcmds3`#+O=9Dg=Cxx_h4rJ@TG6+30Q7TH!JX-pKAZZ6HCi?n6RUZRmhD$#604$7hB zWG=DjwWRXe(g}OnSXjBOxunb`?R=lf>2Emahwt-j&+GgCyx-6J^ZtHbPcp&FO-F0F z76d^$?z`|l;7)^=1`=F74^FTkNJGyZ@9IlI$e(TsWA!gvII867oT0b-1e(yXQ8>*) zv(97+!PfHTd(z}qbuTe_Nj%ytW1-fET3wTuz3kl=EDTo}m~v9AlV-o~PWhCi(Vn zk)@I+*?1=8bC+|-1&!i9OT>v6JjAG@}V*L+yLMW`XCY-c3=cqbNu^c0bCxnNPiTTf-eM{j?h`DhwTMa9d|@3_7&?QKhO_EN~l{I z{M5D-MPyZ9-{?^I?Ov8z`iDNQn^A&U??Y8F*FA2i@GS(Hay+oP14--Z`PNp(q#(n0 zkKeMSJ<(w~AqkATeJq*ybm$0l=9e2uyuRl{F;ex0)R%8%si@fDHlYUx`)Cd~xi_Zk z7S;*!;Jh;6_>^s4(EXiH9q^r=DqLFD290*M4_uDaR^i$8$UNT-7?q}`tbo95w&#qk zMa5;y!hqnRAJDsRWDgAX&K`2Rx%5t42=|PSgd4;_TC}Yv_8-h-Dpm%OB2GTpYz6+e zP6IfxWN~(WCj6R!_HKD4=KE%-a5?Q2a4nKt9TzP#U&$68vaC?c^ZnD)I~+aKV+Ag! z0}-DisvnVQ_Q2E=aMcBbl6^(UFqXXRrP9@%b zmx%{=*dyNS?6z8?*l2rN&;L>`ukZ~+P2pm{1VYfZ7 z`ra|h96K_7ChySUBubU4I>h^ z=a{{aonKU-Ndk&2MkHtq@5eT9_`f{Yq?0oVY4g^G0W}~#t!M$++qNcofJ>d~$QiTq4!1>b4ArJ&n38?#HbLFo zCjW$Vd}(`TAEK21wP1@KT?A%wLvc!$H3n@|W6E;adJdZA&3-tK;^c7EVi$w6brQ^h zm?*DkoyiV`~&Qf_q diff --git a/src/models/evaluate.py b/src/models/evaluate.py index 469bbc7..073f430 100644 --- a/src/models/evaluate.py +++ b/src/models/evaluate.py @@ -30,7 +30,7 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) args = parser.parse_args() if not os.path.exists(args.model_path): diff --git a/src/models/inference.py b/src/models/inference.py index f63c53c..b808109 100644 --- a/src/models/inference.py +++ b/src/models/inference.py @@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel class SpindaInference: """Loads the trained model and predicts spot coordinates from an image crop.""" - def __init__(self, model_path: str = "models/best_spinda_model.pth"): + def __init__( + self, + model_path: str = "models/best_spinda_model.pth", + backbone: str = "resnet18", + ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = SpindaRegressionModel(pretrained=False) + self.model = SpindaRegressionModel(pretrained=False, backbone=backbone) self.model.load_state_dict( torch.load(model_path, map_location=self.device) ) diff --git a/src/models/regression_model.py b/src/models/regression_model.py index c52197e..473959f 100644 --- a/src/models/regression_model.py +++ b/src/models/regression_model.py @@ -1,16 +1,18 @@ import torch import torch.nn as nn from torchvision import models -from torchvision.models import ResNet18_Weights, ResNet34_Weights +from torchvision.models import ResNet18_Weights, ResNet34_Weights, ConvNeXt_Tiny_Weights +# (factory, default_weights, feature_dim) _BACKBONES = { - "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT), - "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT), + "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT, 512), + "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT, 512), + "convnext_tiny": (models.convnext_tiny, ConvNeXt_Tiny_Weights.DEFAULT, 768), } class SpindaRegressionModel(nn.Module): - """ResNet backbone with 8 independent 16-class coordinate heads. + """CNN backbone with 8 independent 16-class coordinate heads. Each of the 8 output coordinates (4 spots × x, y) is treated as a 16-class classification problem over the [0, 15] nibble grid. @@ -25,23 +27,29 @@ class SpindaRegressionModel(nn.Module): super().__init__() if backbone not in _BACKBONES: raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}") - factory, default_weights = _BACKBONES[backbone] + factory, default_weights, feat_dim = _BACKBONES[backbone] weights = default_weights if pretrained else None net = factory(weights=weights) - # Strip the final FC; keep the feature extractor + average pool. - self.features = nn.Sequential(*list(net.children())[:-1]) - # 8 coordinates × 16 classes each (512-dim output for both resnet18/34) - self.classifier = nn.Linear(512, 8 * 16) + + if backbone in ("resnet18", "resnet34"): + # Strip the final FC; flatten the (B, 512, 1, 1) avgpool output. + self.features = nn.Sequential(*list(net.children())[:-1], nn.Flatten()) + elif backbone == "convnext_tiny": + # Keep features + avgpool + LayerNorm (classifier[0]); drop the final Linear. + self.features = nn.Sequential( + net.features, net.avgpool, net.classifier[0], nn.Flatten() + ) + + self.classifier = nn.Linear(feat_dim, 8 * 16) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.features(x) # (B, 512, 1, 1) - x = x.flatten(1) # (B, 512) + x = self.features(x) # (B, feat_dim) x = self.classifier(x) # (B, 128) return x.view(-1, 8, 16) # (B, 8, 16) if __name__ == "__main__": - for name in ("resnet18", "resnet34"): + for name in ("resnet18", "resnet34", "convnext_tiny"): model = SpindaRegressionModel(pretrained=False, backbone=name) out = model(torch.randn(2, 3, 128, 128)) print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}") diff --git a/src/models/train.py b/src/models/train.py index 09b3e4c..0049a59 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -198,7 +198,7 @@ if __name__ == "__main__": parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)") parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"]) + parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path") args = parser.parse_args()