Refactor/cleanup
This commit is contained in:
@@ -18,7 +18,8 @@
|
|||||||
"Bash(.venv/bin/python -m src.models.evaluate)",
|
"Bash(.venv/bin/python -m src.models.evaluate)",
|
||||||
"Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)",
|
"Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)",
|
||||||
"Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)",
|
"Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)",
|
||||||
"Bash(.venv/bin/python src/models/regression_model.py)"
|
"Bash(.venv/bin/python src/models/regression_model.py)",
|
||||||
|
"Bash(xargs cat -n)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
|||||||
build/**
|
build/**
|
||||||
**/__pycache__/**
|
**/__pycache__/**
|
||||||
|
models/**
|
||||||
|
|||||||
24
CLAUDE.md
24
CLAUDE.md
@@ -11,10 +11,10 @@ All commands must be run from the project root using the local venv:
|
|||||||
.venv/bin/python identify.py <image_path>
|
.venv/bin/python identify.py <image_path>
|
||||||
|
|
||||||
# Train the model
|
# Train the model
|
||||||
.venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4
|
.venv/bin/python -m src.models.train --epochs 50
|
||||||
|
|
||||||
# Evaluate a trained model on val and aug_test sets
|
# Evaluate a trained model on val and aug_test sets
|
||||||
.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34] [--model_path <path>]
|
.venv/bin/python -m src.models.evaluate [--backbone resnet34] [--model_path <path>]
|
||||||
|
|
||||||
# Run inference only (no registry lookup)
|
# Run inference only (no registry lookup)
|
||||||
.venv/bin/python src/models/inference.py <image_path>
|
.venv/bin/python src/models/inference.py <image_path>
|
||||||
@@ -26,7 +26,7 @@ All commands must be run from the project root using the local venv:
|
|||||||
.venv/bin/python -m src.data.generate_aug_test_set
|
.venv/bin/python -m src.data.generate_aug_test_set
|
||||||
|
|
||||||
# Generate a single sample image for visual inspection
|
# Generate a single sample image for visual inspection
|
||||||
.venv/bin/python src/data/high_fidelity_generator.py
|
.venv/bin/python src/data/renderer.py
|
||||||
|
|
||||||
# Lint
|
# Lint
|
||||||
.venv/bin/ruff check src/
|
.venv/bin/ruff check src/
|
||||||
@@ -50,7 +50,7 @@ Image → Detector → Inference → Resolver → Registry
|
|||||||
|
|
||||||
### 1. Detection (`src/utils/detector.py`)
|
### 1. Detection (`src/utils/detector.py`)
|
||||||
|
|
||||||
`SpindaDetector.detect_and_crop()` returns a **128×128 BGR** image, or `None`.
|
`SpindaDetector.detect_and_crop()` returns a **128×128 BGR** numpy array, or `None`.
|
||||||
|
|
||||||
Two-tier strategy, tried in order:
|
Two-tier strategy, tried in order:
|
||||||
- **Tier 1 (screenshots/sprites):** HSV-filter red pixels → find individual spot blobs → cluster to 4 spots → derive crop from cluster centroid + `_SPOT_CROP_RATIO=5.5` (= 128 / 24.5 px span) with a `_SPOT_CENTER_OFFSET=0.056` downward shift so the spot centroid lands at 44.4 % from the top of the crop (matching the training canvas).
|
- **Tier 1 (screenshots/sprites):** HSV-filter red pixels → find individual spot blobs → cluster to 4 spots → derive crop from cluster centroid + `_SPOT_CROP_RATIO=5.5` (= 128 / 24.5 px span) with a `_SPOT_CENTER_OFFSET=0.056` downward shift so the spot centroid lands at 44.4 % from the top of the crop (matching the training canvas).
|
||||||
@@ -58,15 +58,18 @@ Two-tier strategy, tried in order:
|
|||||||
|
|
||||||
### 2. Model (`src/models/regression_model.py`)
|
### 2. Model (`src/models/regression_model.py`)
|
||||||
|
|
||||||
ResNet-18 backbone with the final FC replaced by `Linear(512, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`.
|
Configurable backbone (default: ResNet-34) with the final FC replaced by `Linear(feat_dim, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`.
|
||||||
|
|
||||||
|
Supported backbones: `resnet18` (512-d), `resnet34` (512-d), `convnext_tiny` (768-d).
|
||||||
|
|
||||||
### 3. Training (`src/models/train.py`)
|
### 3. Training (`src/models/train.py`)
|
||||||
|
|
||||||
- `SpindaDataset` (200 k virtual samples/epoch): generates a fresh random 32-bit PID per `__getitem__`, renders the sprite with a random background colour, then applies the full augmentation pipeline.
|
- `SpindaDataset` (200 k virtual samples/epoch): generates a fresh random 32-bit PID per `__getitem__`, renders the sprite with a random background colour, then applies the full augmentation pipeline.
|
||||||
- `SpindaEvalDataset`: loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99).
|
- `SpindaEvalDataset` (in `src/data/dataset.py`): loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99).
|
||||||
- `_worker_init_fn` re-seeds Python `random` and NumPy per worker so forked workers generate distinct PIDs.
|
- `_worker_init_fn` re-seeds Python `random` and NumPy per worker so forked workers generate distinct PIDs.
|
||||||
|
- Weighted loss: BL_x ×1.5, BL_y ×2.5 — applied during training only; val loss is unweighted.
|
||||||
- Early stopping: patience = 10 epochs on clean-val exact-match rate.
|
- Early stopping: patience = 10 epochs on clean-val exact-match rate.
|
||||||
- Best model checkpoint: `models/best_spinda_model.pth`.
|
- Checkpoints saved to `models/best_{backbone}_model.pth`.
|
||||||
|
|
||||||
### 4. PID Encoding (domain invariant — must not be changed)
|
### 4. PID Encoding (domain invariant — must not be changed)
|
||||||
|
|
||||||
@@ -107,7 +110,8 @@ assets/ # Sprite assets used by the renderer
|
|||||||
Spot_{TL,TR,BL,BR}.png
|
Spot_{TL,TR,BL,BR}.png
|
||||||
|
|
||||||
models/
|
models/
|
||||||
best_spinda_model.pth
|
best_resnet34_model.pth # current best (default)
|
||||||
|
best_convnext_tiny_model.pth # convnext experiment
|
||||||
```
|
```
|
||||||
|
|
||||||
`metadata.json` format: `[{"img_path": "...", "pid_hex": "...", "target": [int×8]}, ...]`
|
`metadata.json` format: `[{"img_path": "...", "pid_hex": "...", "target": [int×8]}, ...]`
|
||||||
@@ -115,8 +119,8 @@ models/
|
|||||||
## Key invariants
|
## Key invariants
|
||||||
|
|
||||||
- **Visual collisions:** ~1.3 % of fingerprints are shared by multiple PIDs (many-to-one mapping). `SpindaRegistry` stores `(fingerprint, pid_hex)` pairs with a unique constraint so `lookup_by_fingerprint` can return *all* matching PIDs — this is intentional, not a bug.
|
- **Visual collisions:** ~1.3 % of fingerprints are shared by multiple PIDs (many-to-one mapping). `SpindaRegistry` stores `(fingerprint, pid_hex)` pairs with a unique constraint so `lookup_by_fingerprint` can return *all* matching PIDs — this is intentional, not a bug.
|
||||||
|
|
||||||
- The **validation set** uses white backgrounds (no augmentation baked in) to give a stable epoch-comparable baseline. Do not add augmentation to `generate_val_set.py`.
|
- The **validation set** uses white backgrounds (no augmentation baked in) to give a stable epoch-comparable baseline. Do not add augmentation to `generate_val_set.py`.
|
||||||
- The **augmented test set** is pre-generated and fixed. Regenerating it changes the baseline; do so intentionally.
|
- The **augmented test set** is pre-generated and fixed. Regenerating it changes the baseline; do so intentionally.
|
||||||
- The crop output size is always **128×128** regardless of tier. The model transform chain also resizes to 128×128, so the inference path is robust to re-size.
|
- The crop output size is always **128×128** regardless of tier. The model transform chain also resizes to 128×128, so the inference path is robust to re-size.
|
||||||
- `generate_high_fidelity_spinda()` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR).
|
- `generate_high_fidelity_spinda()` in `src/data/renderer.py` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR).
|
||||||
|
- `SpindaInference.predict()` accepts either a file path or a BGR numpy array directly (e.g. from the detector).
|
||||||
|
|||||||
50
identify.py
50
identify.py
@@ -1,70 +1,59 @@
|
|||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
|
||||||
|
from src.data.renderer import generate_high_fidelity_spinda
|
||||||
from src.models.inference import SpindaInference
|
from src.models.inference import SpindaInference
|
||||||
from src.utils.resolver import SpindaResolver
|
|
||||||
from src.registry.database import SpindaRegistry
|
from src.registry.database import SpindaRegistry
|
||||||
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
from src.utils.detector import SpindaDetector
|
||||||
from src.utils.detector import SpindaDetector # Import the detector
|
from src.utils.resolver import SpindaResolver
|
||||||
|
|
||||||
|
|
||||||
def identify_spinda(
|
def identify_spinda(
|
||||||
image_path: str,
|
image_path: str,
|
||||||
model_path: str = "models/best_spinda_model.pth",
|
model_path: str = "models/best_resnet34_model.pth",
|
||||||
backbone: str = "resnet18",
|
backbone: str = "resnet34",
|
||||||
):
|
) -> None:
|
||||||
if not os.path.exists(image_path):
|
if not os.path.exists(image_path):
|
||||||
print(f"Error: File {image_path} not found.")
|
print(f"Error: File {image_path} not found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"--- Identifying Spinda in {image_path} ---")
|
print(f"--- Identifying Spinda in {image_path} ---")
|
||||||
|
|
||||||
# 1. Detect and Crop Spinda
|
# 1. Detect and crop
|
||||||
detector = SpindaDetector()
|
detector = SpindaDetector()
|
||||||
cropped_img = detector.detect_and_crop(image_path)
|
cropped_img = detector.detect_and_crop(image_path)
|
||||||
|
|
||||||
if cropped_img is None:
|
if cropped_img is None:
|
||||||
print("Error: Could not detect Spinda in the image.")
|
print("Error: Could not detect Spinda in the image.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Save cropped image for debug/visual check
|
|
||||||
cv2.imwrite("detected_spinda_crop.png", cropped_img)
|
cv2.imwrite("detected_spinda_crop.png", cropped_img)
|
||||||
print("Detected Spinda saved to detected_spinda_crop.png")
|
print("Detected Spinda saved to detected_spinda_crop.png")
|
||||||
|
|
||||||
# We need to save the cropped image to a temporary file for the inference model to read
|
# 2. Inference — pass the BGR array directly, no temp file needed
|
||||||
temp_cropped_path = "temp_cropped_spinda.png"
|
|
||||||
cv2.imwrite(temp_cropped_path, cropped_img)
|
|
||||||
|
|
||||||
# 2. Inference (Model Prediction) using the cropped image
|
|
||||||
try:
|
|
||||||
inf = SpindaInference(model_path=model_path, backbone=backbone)
|
inf = SpindaInference(model_path=model_path, backbone=backbone)
|
||||||
coords, fingerprint = inf.predict(temp_cropped_path)
|
coords, fingerprint = inf.predict(cropped_img)
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during inference: {e}")
|
|
||||||
os.remove(temp_cropped_path) # Clean up temp file
|
|
||||||
return
|
|
||||||
finally:
|
|
||||||
os.remove(temp_cropped_path) # Clean up temp file
|
|
||||||
|
|
||||||
print(f"Visual Fingerprint: {fingerprint}")
|
print(f"Visual Fingerprint: {fingerprint}")
|
||||||
print(f"Predicted Grid Coordinates: {coords}")
|
print(f"Predicted Grid Coordinates: {coords}")
|
||||||
|
|
||||||
# 3. Resolution (Mathematical PIDs)
|
# 3. Resolve to PIDs
|
||||||
resolved = SpindaResolver.resolve_fingerprint(fingerprint)
|
resolved = SpindaResolver.resolve_fingerprint(fingerprint)
|
||||||
print("\nPossible PIDs:")
|
print("\nPossible PIDs:")
|
||||||
print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}")
|
print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}")
|
||||||
print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}")
|
print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}")
|
||||||
|
|
||||||
# 4. Visual Verification
|
# 4. Visual verification
|
||||||
print("\nGenerating visual verification image...")
|
print("\nGenerating visual verification image...")
|
||||||
verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16))
|
verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16))
|
||||||
cv2.imwrite("prediction_verify.png", verify_img)
|
cv2.imwrite("prediction_verify.png", verify_img)
|
||||||
print("Verification image saved to: prediction_verify.png")
|
print("Verification image saved to: prediction_verify.png")
|
||||||
|
|
||||||
# 5. Registry Lookup
|
# 5. Registry lookup
|
||||||
reg = SpindaRegistry()
|
reg = SpindaRegistry()
|
||||||
matches = reg.lookup_by_fingerprint(fingerprint)
|
matches = reg.lookup_by_fingerprint(fingerprint)
|
||||||
|
|
||||||
if matches:
|
if matches:
|
||||||
print("\nMatches found in Global Registry:")
|
print("\nMatches found in Global Registry:")
|
||||||
for pid in matches:
|
for pid in matches:
|
||||||
@@ -72,13 +61,12 @@ def identify_spinda(
|
|||||||
else:
|
else:
|
||||||
print("\nNo matching entries in Global Registry.")
|
print("\nNo matching entries in Global Registry.")
|
||||||
|
|
||||||
print("\nNote: Accuracy depends on model training progress.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("image_path")
|
parser.add_argument("image_path")
|
||||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
parser.add_argument("--backbone", type=str, default="resnet34",
|
||||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||||
|
parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)
|
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)
|
||||||
|
|||||||
Binary file not shown.
@@ -9,6 +9,7 @@ dependencies = [
|
|||||||
"opencv-python>=4.7.0",
|
"opencv-python>=4.7.0",
|
||||||
"Pillow>=9.5.0",
|
"Pillow>=9.5.0",
|
||||||
"numpy>=1.24.0",
|
"numpy>=1.24.0",
|
||||||
|
"tqdm>=4.65.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -16,7 +17,6 @@ dev = [
|
|||||||
"pytest>=7.3.0",
|
"pytest>=7.3.0",
|
||||||
"ruff>=0.0.270",
|
"ruff>=0.0.270",
|
||||||
"mypy>=1.3.0",
|
"mypy>=1.3.0",
|
||||||
"tqdm>=4.65.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|||||||
@@ -1,69 +1,52 @@
|
|||||||
import torch
|
import json
|
||||||
from torch.utils.data import Dataset
|
import os
|
||||||
import torchvision.transforms.v2 as T
|
|
||||||
import numpy as np
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
import random
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms.v2 as T
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from src.data.renderer import generate_high_fidelity_spinda
|
||||||
|
from src.utils.resolver import SpindaResolver
|
||||||
|
|
||||||
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
|
||||||
|
|
||||||
class SpindaDataset(Dataset):
|
class SpindaDataset(Dataset):
|
||||||
"""PyTorch Dataset for generating synthetic Spinda samples with augmentations."""
|
"""PyTorch Dataset for generating synthetic Spinda samples with augmentations."""
|
||||||
|
|
||||||
def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None):
|
def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
size: Virtual size of the dataset (since it's synthetic).
|
|
||||||
transform: Optional torchvision transforms to apply.
|
|
||||||
"""
|
|
||||||
self.size = size
|
self.size = size
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Generate a random 32-bit PID
|
|
||||||
pid = random.getrandbits(32)
|
pid = random.getrandbits(32)
|
||||||
pid_hex = f"{pid:08x}"
|
pid_hex = f"{pid:08x}"
|
||||||
|
|
||||||
# 1. Generate High-Fidelity Image on a random background colour
|
|
||||||
r = random.randint(0, 255)
|
r = random.randint(0, 255)
|
||||||
g = random.randint(0, 255)
|
g = random.randint(0, 255)
|
||||||
b = random.randint(0, 255)
|
b = random.randint(0, 255)
|
||||||
img_bgr = generate_high_fidelity_spinda(pid, bg_color=(r, g, b))
|
img_bgr = generate_high_fidelity_spinda(pid, bg_color=(r, g, b))
|
||||||
|
|
||||||
# Convert BGR to RGB for PyTorch/Torchvision
|
|
||||||
img_rgb = img_bgr[:, :, ::-1].copy()
|
img_rgb = img_bgr[:, :, ::-1].copy()
|
||||||
|
|
||||||
# 2. Get Ground Truth Coordinates (Target)
|
target_tensor = torch.tensor(SpindaResolver.pid_to_coords(pid_hex), dtype=torch.long)
|
||||||
# Raw nibble values (0-15) for each spot, in TL/TR/BL/BR order.
|
|
||||||
raw_coords = []
|
|
||||||
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
|
|
||||||
raw_coords.extend([int(pid_hex[-1], 16), int(pid_hex[-2], 16)])
|
|
||||||
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4])
|
|
||||||
raw_coords.extend([int(pid_hex[-3], 16), int(pid_hex[-4], 16)])
|
|
||||||
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2])
|
|
||||||
raw_coords.extend([int(pid_hex[3], 16), int(pid_hex[2], 16)])
|
|
||||||
# BR (Spot 4): Nibble 6, 7 (PID[1], PID[0])
|
|
||||||
raw_coords.extend([int(pid_hex[1], 16), int(pid_hex[0], 16)])
|
|
||||||
|
|
||||||
# Integer labels in [0, 15] — used with CrossEntropyLoss.
|
|
||||||
target_tensor = torch.tensor(raw_coords, dtype=torch.long)
|
|
||||||
|
|
||||||
# 3. Apply Transforms
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
# Convert to PIL or Tensor first if needed by transform
|
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
|
||||||
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1) # C, H, W
|
|
||||||
img_tensor = self.transform(img_tensor)
|
img_tensor = self.transform(img_tensor)
|
||||||
else:
|
else:
|
||||||
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
|
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
|
||||||
|
|
||||||
return img_tensor, target_tensor
|
return img_tensor, target_tensor
|
||||||
|
|
||||||
|
|
||||||
def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor:
|
def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor:
|
||||||
return (x + torch.randn_like(x) * 0.05).clamp(0, 1)
|
return (x + torch.randn_like(x) * 0.05).clamp(0, 1)
|
||||||
|
|
||||||
|
|
||||||
def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
|
def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Simulate LCD scan lines seen in handheld-camera photos of 3DS screens."""
|
"""Simulate LCD scan lines seen in handheld-camera photos of 3DS screens."""
|
||||||
if torch.rand(1).item() < 0.5:
|
if torch.rand(1).item() < 0.5:
|
||||||
@@ -72,26 +55,52 @@ def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
|
|||||||
x[:, ::2, :] *= (1.0 - strength)
|
x[:, ::2, :] *= (1.0 - strength)
|
||||||
return x.clamp(0, 1)
|
return x.clamp(0, 1)
|
||||||
|
|
||||||
|
|
||||||
def get_default_augmentations() -> T.Compose:
|
def get_default_augmentations() -> T.Compose:
|
||||||
"""Domain randomisation pipeline calibrated for real handheld-photo conditions."""
|
"""Domain randomisation pipeline calibrated for real handheld-photo conditions."""
|
||||||
return T.Compose([
|
return T.Compose([
|
||||||
T.ToDtype(torch.float32, scale=True),
|
T.ToDtype(torch.float32, scale=True),
|
||||||
# Spatial — wider range to cover camera angle and zoom variation
|
|
||||||
T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8),
|
T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8),
|
||||||
T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True),
|
T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True),
|
||||||
# Colour / sensor — stronger to cover screen glare and ambient lighting
|
|
||||||
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
|
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
|
||||||
T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
|
T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
|
||||||
# Sensor noise and LCD scan lines
|
|
||||||
T.Lambda(add_gaussian_noise),
|
T.Lambda(add_gaussian_noise),
|
||||||
T.Lambda(add_scan_lines),
|
T.Lambda(add_scan_lines),
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class SpindaEvalDataset(Dataset):
|
||||||
|
"""Loads a fixed evaluation set (clean val or augmented test).
|
||||||
|
|
||||||
|
Images are stored post-augmentation, pre-normalisation; this class
|
||||||
|
applies only the normalisation step so the on-disk images are stable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_dir: str):
|
||||||
|
with open(os.path.join(data_dir, "metadata.json")) as f:
|
||||||
|
self.metadata = json.load(f)
|
||||||
|
self.transform = T.Compose([
|
||||||
|
T.ToDtype(torch.float32, scale=True),
|
||||||
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.metadata)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
item = self.metadata[idx]
|
||||||
|
img_bgr = cv2.imread(item["img_path"])
|
||||||
|
if img_bgr is None:
|
||||||
|
raise FileNotFoundError(f"Image not found: {item['img_path']}")
|
||||||
|
img_rgb = img_bgr[:, :, ::-1].copy()
|
||||||
|
img_tensor = self.transform(torch.from_numpy(img_rgb).permute(2, 0, 1))
|
||||||
|
target = torch.tensor(item["target"], dtype=torch.long)
|
||||||
|
return img_tensor, target
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Test the dataset
|
|
||||||
ds = SpindaDataset(size=5, transform=get_default_augmentations())
|
ds = SpindaDataset(size=5, transform=get_default_augmentations())
|
||||||
img, target = ds[0]
|
img, target = ds[0]
|
||||||
print(f"Image shape: {img.shape}")
|
print(f"Image shape: {img.shape}")
|
||||||
print(f"Target (normalized 0-1): {target}")
|
print(f"Target: {target.tolist()}")
|
||||||
print(f"Target (grid units): {target * 15.0}")
|
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ import torchvision.transforms.v2 as T
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.data.dataset import add_gaussian_noise, add_scan_lines
|
from src.data.dataset import add_gaussian_noise, add_scan_lines
|
||||||
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
from src.data.renderer import generate_high_fidelity_spinda
|
||||||
|
from src.utils.resolver import SpindaResolver
|
||||||
|
|
||||||
|
|
||||||
# Augmentation without the final Normalize step — we bake everything else in.
|
# Augmentation without the final Normalize step — we bake everything else in.
|
||||||
@@ -53,13 +54,11 @@ def generate_aug_test_set(size: int = 500, output_dir: str = "data/aug_test") ->
|
|||||||
img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
|
img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
|
||||||
cv2.imwrite(img_path, aug_bgr)
|
cv2.imwrite(img_path, aug_bgr)
|
||||||
|
|
||||||
raw_coords = [
|
metadata.append({
|
||||||
int(pid_hex[-1], 16), int(pid_hex[-2], 16),
|
"img_path": img_path,
|
||||||
int(pid_hex[-3], 16), int(pid_hex[-4], 16),
|
"pid_hex": pid_hex,
|
||||||
int(pid_hex[3], 16), int(pid_hex[2], 16),
|
"target": SpindaResolver.pid_to_coords(pid_hex),
|
||||||
int(pid_hex[1], 16), int(pid_hex[0], 16),
|
})
|
||||||
]
|
|
||||||
metadata.append({"img_path": img_path, "pid_hex": pid_hex, "target": raw_coords})
|
|
||||||
|
|
||||||
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
||||||
json.dump(metadata, f, indent=4)
|
json.dump(metadata, f, indent=4)
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
|
||||||
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
|
import cv2
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.data.renderer import generate_high_fidelity_spinda
|
||||||
|
from src.utils.resolver import SpindaResolver
|
||||||
|
|
||||||
|
|
||||||
|
def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val") -> None:
|
||||||
"""Generates a fixed set of Spinda images and their targets for validation."""
|
"""Generates a fixed set of Spinda images and their targets for validation."""
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
metadata = []
|
metadata = []
|
||||||
|
|
||||||
print(f"Generating {size} validation samples...")
|
print(f"Generating {size} validation samples...")
|
||||||
@@ -17,23 +19,14 @@ def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
|
|||||||
pid = random.getrandbits(32)
|
pid = random.getrandbits(32)
|
||||||
pid_hex = f"{pid:08x}"
|
pid_hex = f"{pid:08x}"
|
||||||
|
|
||||||
# Generate image
|
|
||||||
img_bgr = generate_high_fidelity_spinda(pid)
|
img_bgr = generate_high_fidelity_spinda(pid)
|
||||||
img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
|
img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
|
||||||
cv2.imwrite(img_path, img_bgr)
|
cv2.imwrite(img_path, img_bgr)
|
||||||
|
|
||||||
# Extract raw coordinates (0-15)
|
|
||||||
raw_coords = [
|
|
||||||
int(pid_hex[-1], 16), int(pid_hex[-2], 16),
|
|
||||||
int(pid_hex[-3], 16), int(pid_hex[-4], 16),
|
|
||||||
int(pid_hex[3], 16), int(pid_hex[2], 16),
|
|
||||||
int(pid_hex[1], 16), int(pid_hex[0], 16)
|
|
||||||
]
|
|
||||||
|
|
||||||
metadata.append({
|
metadata.append({
|
||||||
"img_path": img_path,
|
"img_path": img_path,
|
||||||
"pid_hex": pid_hex,
|
"pid_hex": pid_hex,
|
||||||
"target": raw_coords
|
"target": SpindaResolver.pid_to_coords(pid_hex),
|
||||||
})
|
})
|
||||||
|
|
||||||
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
||||||
@@ -41,7 +34,7 @@ def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
|
|||||||
|
|
||||||
print(f"Validation set generated in {output_dir}")
|
print(f"Validation set generated in {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Seed for reproducibility of the validation set itself
|
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
generate_fixed_val_set()
|
generate_fixed_val_set()
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
from typing import Tuple, List
|
|
||||||
|
|
||||||
# Image and Grid constants
|
|
||||||
IMG_SIZE = 128
|
|
||||||
GRID_SIZE = 16
|
|
||||||
SPOT_RADIUS = 6 # Approximate radius in pixels
|
|
||||||
|
|
||||||
# Colors (BGR)
|
|
||||||
FACE_COLOR = (180, 220, 240) # Pale cream/tan
|
|
||||||
SPOT_COLOR = (50, 50, 200) # Reddish-orange
|
|
||||||
EYE_COLOR = (0, 0, 0)
|
|
||||||
|
|
||||||
def extract_coords(pid: int, mode: str = "standard") -> List[Tuple[int, int]]:
|
|
||||||
"""Extracts four (x, y) coordinates from a 32-bit PID/EC.
|
|
||||||
|
|
||||||
Standard (Little-Endian): Byte 0 (LL), 1 (LR), 2 (UL), 3 (UR)
|
|
||||||
BDSP (Big-Endian): Byte 3 (LL), 2 (LR), 1 (UL), 0 (UR)
|
|
||||||
"""
|
|
||||||
bytes_list = [
|
|
||||||
(pid >> 0) & 0xFF, # Byte 0
|
|
||||||
(pid >> 8) & 0xFF, # Byte 1
|
|
||||||
(pid >> 16) & 0xFF, # Byte 2
|
|
||||||
(pid >> 24) & 0xFF, # Byte 3
|
|
||||||
]
|
|
||||||
|
|
||||||
if mode == "bdsp":
|
|
||||||
# BDSP reads the bytes in reverse order
|
|
||||||
ordered_bytes = [bytes_list[3], bytes_list[2], bytes_list[1], bytes_list[0]]
|
|
||||||
else:
|
|
||||||
ordered_bytes = bytes_list
|
|
||||||
|
|
||||||
coords = []
|
|
||||||
for byte in ordered_bytes:
|
|
||||||
x = byte & 0x0F
|
|
||||||
y = (byte >> 4) & 0x0F
|
|
||||||
coords.append((x, y))
|
|
||||||
return coords
|
|
||||||
|
|
||||||
def generate_spinda_face(pid: int, mode: str = "standard") -> np.ndarray:
|
|
||||||
"""Generates a procedural Spinda face with spots based on the PID."""
|
|
||||||
# Create blank canvas (white)
|
|
||||||
img = np.ones((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8) * 255
|
|
||||||
|
|
||||||
# Draw Ears
|
|
||||||
cv2.circle(img, (IMG_SIZE // 2 - 40, IMG_SIZE // 2 - 50), 20, FACE_COLOR, -1)
|
|
||||||
cv2.circle(img, (IMG_SIZE // 2 - 40, IMG_SIZE // 2 - 50), 20, (0, 0, 0), 2)
|
|
||||||
cv2.circle(img, (IMG_SIZE // 2 + 40, IMG_SIZE // 2 - 50), 20, FACE_COLOR, -1)
|
|
||||||
cv2.circle(img, (IMG_SIZE // 2 + 40, IMG_SIZE // 2 - 50), 20, (0, 0, 0), 2)
|
|
||||||
|
|
||||||
# Draw main face (oval)
|
|
||||||
center = (IMG_SIZE // 2, IMG_SIZE // 2)
|
|
||||||
axes = (50, 60)
|
|
||||||
cv2.ellipse(img, center, axes, 0, 0, 360, FACE_COLOR, -1)
|
|
||||||
cv2.ellipse(img, center, axes, 0, 0, 360, (0, 0, 0), 2) # Outline
|
|
||||||
|
|
||||||
# Fixed Eyes
|
|
||||||
cv2.circle(img, (center[0] - 15, center[1] - 10), 4, EYE_COLOR, -1)
|
|
||||||
cv2.circle(img, (center[0] + 15, center[1] - 10), 4, EYE_COLOR, -1)
|
|
||||||
|
|
||||||
# Define Spot Zones (Relative to the 16x16 grid)
|
|
||||||
# Heuristic mapping for a more "Pokemon-like" layout
|
|
||||||
# Spot 1 (LL), Spot 2 (LR), Spot 3 (UL), Spot 4 (UR)
|
|
||||||
# These offsets are designed to place spots in their respective quadrants
|
|
||||||
quadrant_offsets = [
|
|
||||||
(center[0] - 45, center[1] + 5), # LL (Face)
|
|
||||||
(center[0] + 5, center[1] + 5), # LR (Face)
|
|
||||||
(center[0] - 45, center[1] - 55), # UL (Ear/Upper Face)
|
|
||||||
(center[0] + 5, center[1] - 55), # UR (Ear/Upper Face)
|
|
||||||
]
|
|
||||||
|
|
||||||
coords = extract_coords(pid, mode=mode)
|
|
||||||
for i, (x, y) in enumerate(coords):
|
|
||||||
offset_x, offset_y = quadrant_offsets[i]
|
|
||||||
px = int(offset_x + x * 2.5) # Scale grid to quadrant
|
|
||||||
py = int(offset_y + y * 2.5)
|
|
||||||
|
|
||||||
cv2.circle(img, (px, py), SPOT_RADIUS, SPOT_COLOR, -1)
|
|
||||||
|
|
||||||
return img
|
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
from PIL import Image
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Tuple
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
IMG_SIZE = 128
|
IMG_SIZE = 128
|
||||||
BASE_PATH = "assets"
|
BASE_PATH = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "assets"))
|
||||||
|
|
||||||
# Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective)
|
# Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective)
|
||||||
# This matches the script's PID_to_Coordinates logic:
|
|
||||||
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
|
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
|
||||||
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1)
|
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1)
|
||||||
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18)
|
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18)
|
||||||
@@ -21,10 +20,9 @@ SPOT_BASE_OFFSETS = [
|
|||||||
(18, 19), # Spot 4 (BR)
|
(18, 19), # Spot 4 (BR)
|
||||||
]
|
]
|
||||||
|
|
||||||
def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]:
|
|
||||||
"""Extracts coordinates following ProfessorRex's PID_to_Coordinates logic.
|
def extract_coords_rex(pid_hex: str) -> list[tuple[int, int]]:
|
||||||
PID is expected to be an 8-char hex string.
|
"""Extracts coordinates following ProfessorRex's PID_to_Coordinates logic."""
|
||||||
"""
|
|
||||||
pid = pid_hex.lower().zfill(8)
|
pid = pid_hex.lower().zfill(8)
|
||||||
TL = (int(pid[-1], 16), int(pid[-2], 16))
|
TL = (int(pid[-1], 16), int(pid[-2], 16))
|
||||||
TR = (int(pid[-3], 16) + 24, int(pid[-4], 16) + 1)
|
TR = (int(pid[-3], 16) + 24, int(pid[-4], 16) + 1)
|
||||||
@@ -32,6 +30,7 @@ def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]:
|
|||||||
BR = (int(pid[1], 16) + 18, int(pid[0], 16) + 19)
|
BR = (int(pid[1], 16) + 18, int(pid[0], 16) + 19)
|
||||||
return [TL, TR, BL, BR]
|
return [TL, TR, BL, BR]
|
||||||
|
|
||||||
|
|
||||||
def generate_high_fidelity_spinda(
|
def generate_high_fidelity_spinda(
|
||||||
pid: int,
|
pid: int,
|
||||||
bg_color: tuple[int, int, int] = (255, 255, 255),
|
bg_color: tuple[int, int, int] = (255, 255, 255),
|
||||||
@@ -45,51 +44,38 @@ def generate_high_fidelity_spinda(
|
|||||||
head_data = np.array(head_img)
|
head_data = np.array(head_img)
|
||||||
|
|
||||||
spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"]
|
spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"]
|
||||||
spots = [Image.open(os.path.join(BASE_PATH, name)).convert("RGBA") for name in spot_names]
|
spots = [np.array(Image.open(os.path.join(BASE_PATH, name)).convert("RGBA")) for name in spot_names]
|
||||||
|
|
||||||
# 2. Create Pattern Layer (Integer grid like in Rex's script)
|
# 2. Build pattern grid — mark pixels covered by any spot
|
||||||
W, H = base_img.size
|
W, H = base_img.size
|
||||||
pattern_grid = np.zeros((H, W), dtype=np.uint8)
|
pattern_grid = np.zeros((H, W), dtype=bool)
|
||||||
|
|
||||||
coords = extract_coords_rex(pid_hex)
|
for spot_arr, (px_start, py_start) in zip(spots, extract_coords_rex(pid_hex)):
|
||||||
|
|
||||||
for i, (px_start, py_start) in enumerate(coords):
|
|
||||||
spot_arr = np.array(spots[i])
|
|
||||||
sh, sw = spot_arr.shape[:2]
|
sh, sw = spot_arr.shape[:2]
|
||||||
|
active = spot_arr[:, :, 0] > 200 # white pixels in the spot mask
|
||||||
|
ys, xs = np.where(active)
|
||||||
|
tx = px_start + xs
|
||||||
|
ty = py_start + ys
|
||||||
|
valid = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
pattern_grid[ty[valid], tx[valid]] = True
|
||||||
|
|
||||||
for sy in range(sh):
|
# 3. Colourize — copy head colours onto active spot pixels
|
||||||
for sx in range(sw):
|
|
||||||
# If spot pixel is white (active) in the mask
|
|
||||||
# Check for white pixels (R,G,B > 200)
|
|
||||||
if spot_arr[sy, sx, 0] > 200:
|
|
||||||
tx, ty = px_start + sx, py_start + sy
|
|
||||||
if 0 <= tx < W and 0 <= ty < H:
|
|
||||||
# Mark as active
|
|
||||||
pattern_grid[ty, tx] = 1
|
|
||||||
|
|
||||||
# 3. Colourize
|
|
||||||
# Create an empty RGBA layer for the spots
|
|
||||||
spot_layer = np.zeros((H, W, 4), dtype=np.uint8)
|
spot_layer = np.zeros((H, W, 4), dtype=np.uint8)
|
||||||
for y in range(H):
|
spot_layer[pattern_grid] = head_data[pattern_grid]
|
||||||
for x in range(W):
|
|
||||||
if pattern_grid[y, x] > 0:
|
|
||||||
# Take color from Spinda_Head.png
|
|
||||||
spot_layer[y, x] = head_data[y, x]
|
|
||||||
|
|
||||||
# 4. Composite
|
# 4. Composite
|
||||||
spot_layer_img = Image.fromarray(spot_layer, "RGBA")
|
combined = Image.alpha_composite(base_img, Image.fromarray(spot_layer, "RGBA"))
|
||||||
combined = Image.alpha_composite(base_img, spot_layer_img)
|
|
||||||
|
|
||||||
# 5. Final Canvas (128x128)
|
# 5. Final Canvas (128×128)
|
||||||
final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255))
|
final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255))
|
||||||
offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2)
|
offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2)
|
||||||
final_img.paste(combined, offset, combined)
|
final_img.paste(combined, offset, combined)
|
||||||
|
|
||||||
return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR)
|
return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Test with 0x12345678
|
|
||||||
test_pid = 0x12345678
|
test_pid = 0x12345678
|
||||||
img = generate_high_fidelity_spinda(test_pid)
|
img = generate_high_fidelity_spinda(test_pid)
|
||||||
cv2.imwrite("sample_high_fidelity_v3.png", img)
|
cv2.imwrite("sample_spinda.png", img)
|
||||||
print(f"Corrected High-fidelity sample saved to sample_high_fidelity_v3.png for PID: {hex(test_pid)}")
|
print(f"Sample saved to sample_spinda.png for PID: {hex(test_pid)}")
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
import cv2
|
|
||||||
from src.data.generator import generate_spinda_face
|
|
||||||
|
|
||||||
def save_sample_image():
|
|
||||||
# Example PID (from documentation/common examples if possible)
|
|
||||||
# Let's use a PID that should have distinct spots
|
|
||||||
# 0x12345678 ->
|
|
||||||
# Byte 0: 0x78 (X=8, Y=7)
|
|
||||||
# Byte 1: 0x56 (X=6, Y=5)
|
|
||||||
# Byte 2: 0x34 (X=4, Y=3)
|
|
||||||
# Byte 3: 0x12 (X=2, Y=1)
|
|
||||||
test_pid = 0x12345678
|
|
||||||
|
|
||||||
img = generate_spinda_face(test_pid)
|
|
||||||
|
|
||||||
# Save the image
|
|
||||||
cv2.imwrite("sample_spinda.png", img)
|
|
||||||
print(f"Sample Spinda image saved to sample_spinda.png for PID: {hex(test_pid)}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
save_sample_image()
|
|
||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from src.models.regression_model import SpindaRegressionModel
|
from src.models.regression_model import SpindaRegressionModel
|
||||||
from src.models.train import SpindaEvalDataset
|
from src.data.dataset import SpindaEvalDataset
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
|
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
|
||||||
@@ -29,8 +29,8 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
|
||||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not os.path.exists(args.model_path):
|
if not os.path.exists(args.model_path):
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.v2 as T
|
import torchvision.transforms.v2 as T
|
||||||
from PIL import Image
|
from typing import List, Tuple, Union
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from src.models.regression_model import SpindaRegressionModel
|
from src.models.regression_model import SpindaRegressionModel
|
||||||
|
|
||||||
@@ -13,8 +12,8 @@ class SpindaInference:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str = "models/best_spinda_model.pth",
|
model_path: str = "models/best_resnet34_model.pth",
|
||||||
backbone: str = "resnet18",
|
backbone: str = "resnet34",
|
||||||
):
|
):
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
|
||||||
@@ -29,15 +28,23 @@ class SpindaInference:
|
|||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
])
|
])
|
||||||
|
|
||||||
def predict(self, image_path: str) -> Tuple[List[int], str]:
|
def predict(self, image: Union[str, np.ndarray]) -> Tuple[List[int], str]:
|
||||||
"""Predict the 8 grid coordinates and return them with a fingerprint string.
|
"""Predict the 8 grid coordinates and return them with a fingerprint string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: file path (str) or a BGR numpy array (e.g. from the detector).
|
||||||
Returns:
|
Returns:
|
||||||
grid_coords: list of 8 integers in [0, 15]
|
grid_coords: list of 8 integers in [0, 15]
|
||||||
fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
||||||
"""
|
"""
|
||||||
img_bgr = cv2.imread(image_path)
|
if isinstance(image, str):
|
||||||
|
img_bgr = cv2.imread(image)
|
||||||
|
if img_bgr is None:
|
||||||
|
raise FileNotFoundError(f"Image not found: {image}")
|
||||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
else:
|
||||||
|
img_rgb = image[:, :, ::-1].copy()
|
||||||
|
|
||||||
img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1)
|
img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1)
|
||||||
img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)
|
img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class SpindaRegressionModel(nn.Module):
|
|||||||
Prediction: output.argmax(dim=2) → (B, 8) integer coordinates.
|
Prediction: output.argmax(dim=2) → (B, 8) integer coordinates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, pretrained: bool = True, backbone: str = "resnet18"):
|
def __init__(self, pretrained: bool = True, backbone: str = "resnet34"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if backbone not in _BACKBONES:
|
if backbone not in _BACKBONES:
|
||||||
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
|
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
|
||||||
|
|||||||
@@ -1,49 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torchvision.transforms.v2 as T
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from src.data.dataset import SpindaDataset, get_default_augmentations
|
from src.data.dataset import SpindaDataset, SpindaEvalDataset, get_default_augmentations
|
||||||
from src.models.regression_model import SpindaRegressionModel
|
from src.models.regression_model import SpindaRegressionModel
|
||||||
|
|
||||||
|
|
||||||
class SpindaEvalDataset(Dataset):
|
|
||||||
"""Loads a fixed evaluation set (clean val or augmented test).
|
|
||||||
|
|
||||||
Images are stored post-augmentation, pre-normalisation; this class
|
|
||||||
applies only the normalisation step so the on-disk images are stable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_dir: str):
|
|
||||||
with open(os.path.join(data_dir, "metadata.json")) as f:
|
|
||||||
self.metadata = json.load(f)
|
|
||||||
self.transform = T.Compose([
|
|
||||||
T.ToDtype(torch.float32, scale=True),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.metadata)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
item = self.metadata[idx]
|
|
||||||
img_bgr = cv2.imread(item["img_path"])
|
|
||||||
img_rgb = img_bgr[:, :, ::-1].copy()
|
|
||||||
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
|
|
||||||
img_tensor = self.transform(img_tensor)
|
|
||||||
target = torch.tensor(item["target"], dtype=torch.long)
|
|
||||||
return img_tensor, target
|
|
||||||
|
|
||||||
|
|
||||||
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
|
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
|
||||||
# so the optimiser focuses more of its gradient on the hardest spot.
|
# so the optimiser focuses more of its gradient on the hardest spot.
|
||||||
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])
|
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])
|
||||||
@@ -73,7 +41,7 @@ def train_model(
|
|||||||
model_path: str = "models/best_spinda_model.pth",
|
model_path: str = "models/best_spinda_model.pth",
|
||||||
num_workers: int = 4,
|
num_workers: int = 4,
|
||||||
epoch_size: int = 200000,
|
epoch_size: int = 200000,
|
||||||
backbone: str = "resnet18",
|
backbone: str = "resnet34",
|
||||||
save_path: str = "",
|
save_path: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@@ -176,8 +144,7 @@ def train_model(
|
|||||||
if exact_rate > best_exact_rate:
|
if exact_rate > best_exact_rate:
|
||||||
best_exact_rate = exact_rate
|
best_exact_rate = exact_rate
|
||||||
epochs_without_improvement = 0
|
epochs_without_improvement = 0
|
||||||
os.makedirs("models", exist_ok=True)
|
os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
|
||||||
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
|
|
||||||
torch.save(model.state_dict(), checkpoint_path)
|
torch.save(model.state_dict(), checkpoint_path)
|
||||||
print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
|
print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
|
||||||
else:
|
else:
|
||||||
@@ -198,7 +165,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
|
||||||
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
|
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
|
||||||
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
|
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
|
||||||
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"])
|
parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
|
||||||
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
|
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,18 @@
|
|||||||
import sqlite3
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
class SpindaRegistry:
|
class SpindaRegistry:
|
||||||
"""Handles the storage and lookup of Spinda PIDs based on their visual fingerprints."""
|
"""Handles the storage and lookup of Spinda PIDs based on their visual fingerprints."""
|
||||||
|
|
||||||
def __init__(self, db_path: str = "data/spinda_registry.db"):
|
def __init__(self, db_path: str = "data/spinda_registry.db"):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True)
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self) -> None:
|
||||||
"""Initializes the database with a core table for PIDs and Fingerprints."""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
# Primary table: Mapping the 8-integer fingerprint to PIDs
|
|
||||||
# Fingerprint format: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS registry (
|
CREATE TABLE IF NOT EXISTS registry (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
@@ -24,38 +21,34 @@ class SpindaRegistry:
|
|||||||
UNIQUE(fingerprint, pid_hex)
|
UNIQUE(fingerprint, pid_hex)
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
# Index for fast O(1)-style lookups by fingerprint
|
|
||||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)')
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)')
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def add_entry(self, fingerprint: str, pid_hex: str):
|
def add_entry(self, fingerprint: str, pid_hex: str) -> None:
|
||||||
"""Adds a new Spinda entry to the registry."""
|
"""Adds a new Spinda entry to the registry (idempotent)."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
conn.execute(
|
||||||
cursor.execute(
|
|
||||||
"INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)",
|
"INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)",
|
||||||
(fingerprint, pid_hex)
|
(fingerprint, pid_hex),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
# Entry already exists
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def lookup_by_fingerprint(self, fingerprint: str) -> List[str]:
|
def lookup_by_fingerprint(self, fingerprint: str) -> list[str]:
|
||||||
"""Returns all PIDs associated with a specific visual fingerprint."""
|
"""Returns all PIDs associated with a specific visual fingerprint."""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.execute(
|
||||||
cursor.execute("SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,))
|
"SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,)
|
||||||
results = cursor.fetchall()
|
)
|
||||||
return [row[0] for row in results]
|
return [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Quick test
|
|
||||||
reg = SpindaRegistry("data/test_registry.db")
|
reg = SpindaRegistry("data/test_registry.db")
|
||||||
test_fp = "00-01-02-03-04-05-06-07"
|
test_fp = "00-01-02-03-04-05-06-07"
|
||||||
test_pid = "ABCDE123"
|
test_pid = "ABCDE123"
|
||||||
|
|
||||||
reg.add_entry(test_fp, test_pid)
|
reg.add_entry(test_fp, test_pid)
|
||||||
matches = reg.lookup_by_fingerprint(test_fp)
|
matches = reg.lookup_by_fingerprint(test_fp)
|
||||||
print(f"Looked up {test_fp}, found PIDs: {matches}")
|
print(f"Looked up {test_fp}, found PIDs: {matches}")
|
||||||
|
|||||||
@@ -1,10 +1,22 @@
|
|||||||
from typing import List, Tuple
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
class SpindaResolver:
|
class SpindaResolver:
|
||||||
"""Mathematically resolves a visual fingerprint back to its possible PIDs."""
|
"""Mathematically resolves a visual fingerprint back to its possible PIDs."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def coordinates_to_pid(coords: List[int], mode: str = "standard") -> str:
|
def pid_to_coords(pid_hex: str) -> list[int]:
|
||||||
|
"""Converts an 8-char hex PID string to the 8 nibble target coordinates."""
|
||||||
|
p = pid_hex.lower().zfill(8)
|
||||||
|
return [
|
||||||
|
int(p[-1], 16), int(p[-2], 16),
|
||||||
|
int(p[-3], 16), int(p[-4], 16),
|
||||||
|
int(p[3], 16), int(p[2], 16),
|
||||||
|
int(p[1], 16), int(p[0], 16),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def coordinates_to_pid(coords: list[int], mode: Literal["standard", "bdsp"] = "standard") -> str:
|
||||||
"""
|
"""
|
||||||
Converts 8 grid coordinates (0-15) back to a 32-bit hex PID.
|
Converts 8 grid coordinates (0-15) back to a 32-bit hex PID.
|
||||||
|
|
||||||
@@ -14,20 +26,19 @@ class SpindaResolver:
|
|||||||
Byte 2 (BL): x=coords[4], y=coords[5]
|
Byte 2 (BL): x=coords[4], y=coords[5]
|
||||||
Byte 3 (BR): x=coords[6], y=coords[7]
|
Byte 3 (BR): x=coords[6], y=coords[7]
|
||||||
"""
|
"""
|
||||||
|
if mode not in ("standard", "bdsp"):
|
||||||
|
raise ValueError(f"mode must be 'standard' or 'bdsp'; got {mode!r}")
|
||||||
|
|
||||||
# Each byte = (Y << 4) | X
|
# Each byte = (Y << 4) | X
|
||||||
bytes_list = []
|
bytes_list = []
|
||||||
for i in range(0, 8, 2):
|
for i in range(0, 8, 2):
|
||||||
x = coords[i]
|
x = coords[i]
|
||||||
y = coords[i+1]
|
y = coords[i + 1]
|
||||||
byte = (y << 4) | x
|
bytes_list.append((y << 4) | x)
|
||||||
bytes_list.append(byte)
|
|
||||||
|
|
||||||
if mode == "bdsp":
|
if mode == "bdsp":
|
||||||
# BDSP reads the bytes in reverse order (Big-Endian style)
|
|
||||||
# So we reverse them back
|
|
||||||
bytes_list = bytes_list[::-1]
|
bytes_list = bytes_list[::-1]
|
||||||
|
|
||||||
# Combine bytes into 32-bit integer
|
|
||||||
pid = 0
|
pid = 0
|
||||||
for i, byte in enumerate(bytes_list):
|
for i, byte in enumerate(bytes_list):
|
||||||
pid |= (byte << (i * 8))
|
pid |= (byte << (i * 8))
|
||||||
@@ -35,18 +46,18 @@ class SpindaResolver:
|
|||||||
return f"{pid:08x}"
|
return f"{pid:08x}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_fingerprint(fingerprint: str) -> dict:
|
def resolve_fingerprint(fingerprint: str) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
|
||||||
and returns both possible PIDs (Standard and BDSP).
|
and returns both possible PIDs (Standard and BDSP).
|
||||||
"""
|
"""
|
||||||
coords = [int(c) for c in fingerprint.split("-")]
|
coords = [int(c) for c in fingerprint.split("-")]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"),
|
"standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"),
|
||||||
"bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp")
|
"bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Test with a known fingerprint
|
# Test with a known fingerprint
|
||||||
# PID 0x12345678 ->
|
# PID 0x12345678 ->
|
||||||
|
|||||||
Reference in New Issue
Block a user