New backbone
This commit is contained in:
@@ -30,7 +30,7 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.model_path):
|
||||
|
||||
@@ -11,9 +11,13 @@ from src.models.regression_model import SpindaRegressionModel
|
||||
class SpindaInference:
|
||||
"""Loads the trained model and predicts spot coordinates from an image crop."""
|
||||
|
||||
def __init__(self, model_path: str = "models/best_spinda_model.pth"):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "models/best_spinda_model.pth",
|
||||
backbone: str = "resnet18",
|
||||
):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = SpindaRegressionModel(pretrained=False)
|
||||
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
||||
self.model.load_state_dict(
|
||||
torch.load(model_path, map_location=self.device)
|
||||
)
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
from torchvision.models import ResNet18_Weights, ResNet34_Weights
|
||||
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ConvNeXt_Tiny_Weights
|
||||
|
||||
# (factory, default_weights, feature_dim)
|
||||
_BACKBONES = {
|
||||
"resnet18": (models.resnet18, ResNet18_Weights.DEFAULT),
|
||||
"resnet34": (models.resnet34, ResNet34_Weights.DEFAULT),
|
||||
"resnet18": (models.resnet18, ResNet18_Weights.DEFAULT, 512),
|
||||
"resnet34": (models.resnet34, ResNet34_Weights.DEFAULT, 512),
|
||||
"convnext_tiny": (models.convnext_tiny, ConvNeXt_Tiny_Weights.DEFAULT, 768),
|
||||
}
|
||||
|
||||
|
||||
class SpindaRegressionModel(nn.Module):
|
||||
"""ResNet backbone with 8 independent 16-class coordinate heads.
|
||||
"""CNN backbone with 8 independent 16-class coordinate heads.
|
||||
|
||||
Each of the 8 output coordinates (4 spots × x, y) is treated as a
|
||||
16-class classification problem over the [0, 15] nibble grid.
|
||||
@@ -25,23 +27,29 @@ class SpindaRegressionModel(nn.Module):
|
||||
super().__init__()
|
||||
if backbone not in _BACKBONES:
|
||||
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
|
||||
factory, default_weights = _BACKBONES[backbone]
|
||||
factory, default_weights, feat_dim = _BACKBONES[backbone]
|
||||
weights = default_weights if pretrained else None
|
||||
net = factory(weights=weights)
|
||||
# Strip the final FC; keep the feature extractor + average pool.
|
||||
self.features = nn.Sequential(*list(net.children())[:-1])
|
||||
# 8 coordinates × 16 classes each (512-dim output for both resnet18/34)
|
||||
self.classifier = nn.Linear(512, 8 * 16)
|
||||
|
||||
if backbone in ("resnet18", "resnet34"):
|
||||
# Strip the final FC; flatten the (B, 512, 1, 1) avgpool output.
|
||||
self.features = nn.Sequential(*list(net.children())[:-1], nn.Flatten())
|
||||
elif backbone == "convnext_tiny":
|
||||
# Keep features + avgpool + LayerNorm (classifier[0]); drop the final Linear.
|
||||
self.features = nn.Sequential(
|
||||
net.features, net.avgpool, net.classifier[0], nn.Flatten()
|
||||
)
|
||||
|
||||
self.classifier = nn.Linear(feat_dim, 8 * 16)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.features(x) # (B, 512, 1, 1)
|
||||
x = x.flatten(1) # (B, 512)
|
||||
x = self.features(x) # (B, feat_dim)
|
||||
x = self.classifier(x) # (B, 128)
|
||||
return x.view(-1, 8, 16) # (B, 8, 16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for name in ("resnet18", "resnet34"):
|
||||
for name in ("resnet18", "resnet34", "convnext_tiny"):
|
||||
model = SpindaRegressionModel(pretrained=False, backbone=name)
|
||||
out = model(torch.randn(2, 3, 128, 128))
|
||||
print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}")
|
||||
|
||||
@@ -198,7 +198,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
|
||||
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
|
||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user