Refactor/cleanup

This commit is contained in:
alexiondev
2026-05-08 17:18:58 -04:00
parent 799aa9fa3d
commit 1b904e04ea
18 changed files with 214 additions and 357 deletions

View File

@@ -18,7 +18,8 @@
"Bash(.venv/bin/python -m src.models.evaluate)", "Bash(.venv/bin/python -m src.models.evaluate)",
"Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)", "Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)",
"Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)", "Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)",
"Bash(.venv/bin/python src/models/regression_model.py)" "Bash(.venv/bin/python src/models/regression_model.py)",
"Bash(xargs cat -n)"
] ]
} }
} }

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
build/** build/**
**/__pycache__/** **/__pycache__/**
models/**

View File

@@ -11,10 +11,10 @@ All commands must be run from the project root using the local venv:
.venv/bin/python identify.py <image_path> .venv/bin/python identify.py <image_path>
# Train the model # Train the model
.venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4 .venv/bin/python -m src.models.train --epochs 50
# Evaluate a trained model on val and aug_test sets # Evaluate a trained model on val and aug_test sets
.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34] [--model_path <path>] .venv/bin/python -m src.models.evaluate [--backbone resnet34] [--model_path <path>]
# Run inference only (no registry lookup) # Run inference only (no registry lookup)
.venv/bin/python src/models/inference.py <image_path> .venv/bin/python src/models/inference.py <image_path>
@@ -26,7 +26,7 @@ All commands must be run from the project root using the local venv:
.venv/bin/python -m src.data.generate_aug_test_set .venv/bin/python -m src.data.generate_aug_test_set
# Generate a single sample image for visual inspection # Generate a single sample image for visual inspection
.venv/bin/python src/data/high_fidelity_generator.py .venv/bin/python src/data/renderer.py
# Lint # Lint
.venv/bin/ruff check src/ .venv/bin/ruff check src/
@@ -50,7 +50,7 @@ Image → Detector → Inference → Resolver → Registry
### 1. Detection (`src/utils/detector.py`) ### 1. Detection (`src/utils/detector.py`)
`SpindaDetector.detect_and_crop()` returns a **128×128 BGR** image, or `None`. `SpindaDetector.detect_and_crop()` returns a **128×128 BGR** numpy array, or `None`.
Two-tier strategy, tried in order: Two-tier strategy, tried in order:
- **Tier 1 (screenshots/sprites):** HSV-filter red pixels → find individual spot blobs → cluster to 4 spots → derive crop from cluster centroid + `_SPOT_CROP_RATIO=5.5` (= 128 / 24.5 px span) with a `_SPOT_CENTER_OFFSET=0.056` downward shift so the spot centroid lands at 44.4 % from the top of the crop (matching the training canvas). - **Tier 1 (screenshots/sprites):** HSV-filter red pixels → find individual spot blobs → cluster to 4 spots → derive crop from cluster centroid + `_SPOT_CROP_RATIO=5.5` (= 128 / 24.5 px span) with a `_SPOT_CENTER_OFFSET=0.056` downward shift so the spot centroid lands at 44.4 % from the top of the crop (matching the training canvas).
@@ -58,15 +58,18 @@ Two-tier strategy, tried in order:
### 2. Model (`src/models/regression_model.py`) ### 2. Model (`src/models/regression_model.py`)
ResNet-18 backbone with the final FC replaced by `Linear(512, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`. Configurable backbone (default: ResNet-34) with the final FC replaced by `Linear(feat_dim, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`.
Supported backbones: `resnet18` (512-d), `resnet34` (512-d), `convnext_tiny` (768-d).
### 3. Training (`src/models/train.py`) ### 3. Training (`src/models/train.py`)
- `SpindaDataset` (200 k virtual samples/epoch): generates a fresh random 32-bit PID per `__getitem__`, renders the sprite with a random background colour, then applies the full augmentation pipeline. - `SpindaDataset` (200 k virtual samples/epoch): generates a fresh random 32-bit PID per `__getitem__`, renders the sprite with a random background colour, then applies the full augmentation pipeline.
- `SpindaEvalDataset`: loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99). - `SpindaEvalDataset` (in `src/data/dataset.py`): loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99).
- `_worker_init_fn` re-seeds Python `random` and NumPy per worker so forked workers generate distinct PIDs. - `_worker_init_fn` re-seeds Python `random` and NumPy per worker so forked workers generate distinct PIDs.
- Weighted loss: BL_x ×1.5, BL_y ×2.5 — applied during training only; val loss is unweighted.
- Early stopping: patience = 10 epochs on clean-val exact-match rate. - Early stopping: patience = 10 epochs on clean-val exact-match rate.
- Best model checkpoint: `models/best_spinda_model.pth`. - Checkpoints saved to `models/best_{backbone}_model.pth`.
### 4. PID Encoding (domain invariant — must not be changed) ### 4. PID Encoding (domain invariant — must not be changed)
@@ -107,7 +110,8 @@ assets/ # Sprite assets used by the renderer
Spot_{TL,TR,BL,BR}.png Spot_{TL,TR,BL,BR}.png
models/ models/
best_spinda_model.pth best_resnet34_model.pth # current best (default)
best_convnext_tiny_model.pth # convnext experiment
``` ```
`metadata.json` format: `[{"img_path": "...", "pid_hex": "...", "target": [int×8]}, ...]` `metadata.json` format: `[{"img_path": "...", "pid_hex": "...", "target": [int×8]}, ...]`
@@ -115,8 +119,8 @@ models/
## Key invariants ## Key invariants
- **Visual collisions:** ~1.3 % of fingerprints are shared by multiple PIDs (many-to-one mapping). `SpindaRegistry` stores `(fingerprint, pid_hex)` pairs with a unique constraint so `lookup_by_fingerprint` can return *all* matching PIDs — this is intentional, not a bug. - **Visual collisions:** ~1.3 % of fingerprints are shared by multiple PIDs (many-to-one mapping). `SpindaRegistry` stores `(fingerprint, pid_hex)` pairs with a unique constraint so `lookup_by_fingerprint` can return *all* matching PIDs — this is intentional, not a bug.
- The **validation set** uses white backgrounds (no augmentation baked in) to give a stable epoch-comparable baseline. Do not add augmentation to `generate_val_set.py`. - The **validation set** uses white backgrounds (no augmentation baked in) to give a stable epoch-comparable baseline. Do not add augmentation to `generate_val_set.py`.
- The **augmented test set** is pre-generated and fixed. Regenerating it changes the baseline; do so intentionally. - The **augmented test set** is pre-generated and fixed. Regenerating it changes the baseline; do so intentionally.
- The crop output size is always **128×128** regardless of tier. The model transform chain also resizes to 128×128, so the inference path is robust to re-size. - The crop output size is always **128×128** regardless of tier. The model transform chain also resizes to 128×128, so the inference path is robust to re-size.
- `generate_high_fidelity_spinda()` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR). - `generate_high_fidelity_spinda()` in `src/data/renderer.py` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR).
- `SpindaInference.predict()` accepts either a file path or a BGR numpy array directly (e.g. from the detector).

View File

@@ -1,70 +1,59 @@
import argparse
import os import os
import sys import sys
import cv2 import cv2
import torch
from src.data.renderer import generate_high_fidelity_spinda
from src.models.inference import SpindaInference from src.models.inference import SpindaInference
from src.utils.resolver import SpindaResolver
from src.registry.database import SpindaRegistry from src.registry.database import SpindaRegistry
from src.data.high_fidelity_generator import generate_high_fidelity_spinda from src.utils.detector import SpindaDetector
from src.utils.detector import SpindaDetector # Import the detector from src.utils.resolver import SpindaResolver
def identify_spinda( def identify_spinda(
image_path: str, image_path: str,
model_path: str = "models/best_spinda_model.pth", model_path: str = "models/best_resnet34_model.pth",
backbone: str = "resnet18", backbone: str = "resnet34",
): ) -> None:
if not os.path.exists(image_path): if not os.path.exists(image_path):
print(f"Error: File {image_path} not found.") print(f"Error: File {image_path} not found.")
return return
print(f"--- Identifying Spinda in {image_path} ---") print(f"--- Identifying Spinda in {image_path} ---")
# 1. Detect and Crop Spinda # 1. Detect and crop
detector = SpindaDetector() detector = SpindaDetector()
cropped_img = detector.detect_and_crop(image_path) cropped_img = detector.detect_and_crop(image_path)
if cropped_img is None: if cropped_img is None:
print("Error: Could not detect Spinda in the image.") print("Error: Could not detect Spinda in the image.")
return return
# Save cropped image for debug/visual check
cv2.imwrite("detected_spinda_crop.png", cropped_img) cv2.imwrite("detected_spinda_crop.png", cropped_img)
print("Detected Spinda saved to detected_spinda_crop.png") print("Detected Spinda saved to detected_spinda_crop.png")
# We need to save the cropped image to a temporary file for the inference model to read # 2. Inference — pass the BGR array directly, no temp file needed
temp_cropped_path = "temp_cropped_spinda.png"
cv2.imwrite(temp_cropped_path, cropped_img)
# 2. Inference (Model Prediction) using the cropped image
try:
inf = SpindaInference(model_path=model_path, backbone=backbone) inf = SpindaInference(model_path=model_path, backbone=backbone)
coords, fingerprint = inf.predict(temp_cropped_path) coords, fingerprint = inf.predict(cropped_img)
except Exception as e:
print(f"Error during inference: {e}")
os.remove(temp_cropped_path) # Clean up temp file
return
finally:
os.remove(temp_cropped_path) # Clean up temp file
print(f"Visual Fingerprint: {fingerprint}") print(f"Visual Fingerprint: {fingerprint}")
print(f"Predicted Grid Coordinates: {coords}") print(f"Predicted Grid Coordinates: {coords}")
# 3. Resolution (Mathematical PIDs) # 3. Resolve to PIDs
resolved = SpindaResolver.resolve_fingerprint(fingerprint) resolved = SpindaResolver.resolve_fingerprint(fingerprint)
print("\nPossible PIDs:") print("\nPossible PIDs:")
print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}") print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}")
print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}") print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}")
# 4. Visual Verification # 4. Visual verification
print("\nGenerating visual verification image...") print("\nGenerating visual verification image...")
verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16)) verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16))
cv2.imwrite("prediction_verify.png", verify_img) cv2.imwrite("prediction_verify.png", verify_img)
print("Verification image saved to: prediction_verify.png") print("Verification image saved to: prediction_verify.png")
# 5. Registry Lookup # 5. Registry lookup
reg = SpindaRegistry() reg = SpindaRegistry()
matches = reg.lookup_by_fingerprint(fingerprint) matches = reg.lookup_by_fingerprint(fingerprint)
if matches: if matches:
print("\nMatches found in Global Registry:") print("\nMatches found in Global Registry:")
for pid in matches: for pid in matches:
@@ -72,13 +61,12 @@ def identify_spinda(
else: else:
print("\nNo matching entries in Global Registry.") print("\nNo matching entries in Global Registry.")
print("\nNote: Accuracy depends on model training progress.")
if __name__ == "__main__": if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("image_path") parser.add_argument("image_path")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--backbone", type=str, default="resnet34",
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
args = parser.parse_args() args = parser.parse_args()
identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone) identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone)

Binary file not shown.

View File

@@ -9,6 +9,7 @@ dependencies = [
"opencv-python>=4.7.0", "opencv-python>=4.7.0",
"Pillow>=9.5.0", "Pillow>=9.5.0",
"numpy>=1.24.0", "numpy>=1.24.0",
"tqdm>=4.65.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -16,7 +17,6 @@ dev = [
"pytest>=7.3.0", "pytest>=7.3.0",
"ruff>=0.0.270", "ruff>=0.0.270",
"mypy>=1.3.0", "mypy>=1.3.0",
"tqdm>=4.65.0",
] ]
[tool.ruff] [tool.ruff]

View File

@@ -1,69 +1,52 @@
import torch import json
from torch.utils.data import Dataset import os
import torchvision.transforms.v2 as T
import numpy as np
from typing import Tuple, Optional
import random import random
from typing import Optional
import cv2
import torch
import torchvision.transforms.v2 as T
from torch.utils.data import Dataset
from src.data.renderer import generate_high_fidelity_spinda
from src.utils.resolver import SpindaResolver
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
class SpindaDataset(Dataset): class SpindaDataset(Dataset):
"""PyTorch Dataset for generating synthetic Spinda samples with augmentations.""" """PyTorch Dataset for generating synthetic Spinda samples with augmentations."""
def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None): def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None):
"""
Args:
size: Virtual size of the dataset (since it's synthetic).
transform: Optional torchvision transforms to apply.
"""
self.size = size self.size = size
self.transform = transform self.transform = transform
def __len__(self) -> int: def __len__(self) -> int:
return self.size return self.size
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
# Generate a random 32-bit PID
pid = random.getrandbits(32) pid = random.getrandbits(32)
pid_hex = f"{pid:08x}" pid_hex = f"{pid:08x}"
# 1. Generate High-Fidelity Image on a random background colour
r = random.randint(0, 255) r = random.randint(0, 255)
g = random.randint(0, 255) g = random.randint(0, 255)
b = random.randint(0, 255) b = random.randint(0, 255)
img_bgr = generate_high_fidelity_spinda(pid, bg_color=(r, g, b)) img_bgr = generate_high_fidelity_spinda(pid, bg_color=(r, g, b))
# Convert BGR to RGB for PyTorch/Torchvision
img_rgb = img_bgr[:, :, ::-1].copy() img_rgb = img_bgr[:, :, ::-1].copy()
# 2. Get Ground Truth Coordinates (Target) target_tensor = torch.tensor(SpindaResolver.pid_to_coords(pid_hex), dtype=torch.long)
# Raw nibble values (0-15) for each spot, in TL/TR/BL/BR order.
raw_coords = []
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
raw_coords.extend([int(pid_hex[-1], 16), int(pid_hex[-2], 16)])
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4])
raw_coords.extend([int(pid_hex[-3], 16), int(pid_hex[-4], 16)])
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2])
raw_coords.extend([int(pid_hex[3], 16), int(pid_hex[2], 16)])
# BR (Spot 4): Nibble 6, 7 (PID[1], PID[0])
raw_coords.extend([int(pid_hex[1], 16), int(pid_hex[0], 16)])
# Integer labels in [0, 15] — used with CrossEntropyLoss.
target_tensor = torch.tensor(raw_coords, dtype=torch.long)
# 3. Apply Transforms
if self.transform: if self.transform:
# Convert to PIL or Tensor first if needed by transform img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1) # C, H, W
img_tensor = self.transform(img_tensor) img_tensor = self.transform(img_tensor)
else: else:
img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0 img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
return img_tensor, target_tensor return img_tensor, target_tensor
def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor: def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor:
return (x + torch.randn_like(x) * 0.05).clamp(0, 1) return (x + torch.randn_like(x) * 0.05).clamp(0, 1)
def add_scan_lines(x: torch.Tensor) -> torch.Tensor: def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
"""Simulate LCD scan lines seen in handheld-camera photos of 3DS screens.""" """Simulate LCD scan lines seen in handheld-camera photos of 3DS screens."""
if torch.rand(1).item() < 0.5: if torch.rand(1).item() < 0.5:
@@ -72,26 +55,52 @@ def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
x[:, ::2, :] *= (1.0 - strength) x[:, ::2, :] *= (1.0 - strength)
return x.clamp(0, 1) return x.clamp(0, 1)
def get_default_augmentations() -> T.Compose: def get_default_augmentations() -> T.Compose:
"""Domain randomisation pipeline calibrated for real handheld-photo conditions.""" """Domain randomisation pipeline calibrated for real handheld-photo conditions."""
return T.Compose([ return T.Compose([
T.ToDtype(torch.float32, scale=True), T.ToDtype(torch.float32, scale=True),
# Spatial — wider range to cover camera angle and zoom variation
T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8), T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8),
T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True), T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True),
# Colour / sensor — stronger to cover screen glare and ambient lighting
T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05), T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)), T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
# Sensor noise and LCD scan lines
T.Lambda(add_gaussian_noise), T.Lambda(add_gaussian_noise),
T.Lambda(add_scan_lines), T.Lambda(add_scan_lines),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]) ])
class SpindaEvalDataset(Dataset):
    """Fixed evaluation set (clean val or augmented test) read from disk.

    The images on disk are stored post-augmentation and pre-normalisation,
    so the only processing done here is float conversion plus normalisation.
    That keeps the stored images stable between runs.
    """

    def __init__(self, data_dir: str):
        # metadata.json lists one record per sample; the fields read by
        # __getitem__ are "img_path" and "target".
        metadata_path = os.path.join(data_dir, "metadata.json")
        with open(metadata_path) as fh:
            self.metadata = json.load(fh)

        to_float = T.ToDtype(torch.float32, scale=True)
        normalise = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = T.Compose([to_float, normalise])

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        item = self.metadata[idx]

        img_bgr = cv2.imread(item["img_path"])
        if img_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Image not found: {item['img_path']}")

        # Reverse the channel axis (BGR -> RGB) and materialise the reversed
        # view as a fresh array before handing it to torch; then move the
        # channel axis first (H, W, C -> C, H, W).
        chw = torch.from_numpy(img_bgr[:, :, ::-1].copy()).permute(2, 0, 1)

        return self.transform(chw), torch.tensor(item["target"], dtype=torch.long)
if __name__ == "__main__": if __name__ == "__main__":
# Test the dataset
ds = SpindaDataset(size=5, transform=get_default_augmentations()) ds = SpindaDataset(size=5, transform=get_default_augmentations())
img, target = ds[0] img, target = ds[0]
print(f"Image shape: {img.shape}") print(f"Image shape: {img.shape}")
print(f"Target (normalized 0-1): {target}") print(f"Target: {target.tolist()}")
print(f"Target (grid units): {target * 15.0}")

View File

@@ -15,7 +15,8 @@ import torchvision.transforms.v2 as T
from tqdm import tqdm from tqdm import tqdm
from src.data.dataset import add_gaussian_noise, add_scan_lines from src.data.dataset import add_gaussian_noise, add_scan_lines
from src.data.high_fidelity_generator import generate_high_fidelity_spinda from src.data.renderer import generate_high_fidelity_spinda
from src.utils.resolver import SpindaResolver
# Augmentation without the final Normalize step — we bake everything else in. # Augmentation without the final Normalize step — we bake everything else in.
@@ -53,13 +54,11 @@ def generate_aug_test_set(size: int = 500, output_dir: str = "data/aug_test") ->
img_path = os.path.join(output_dir, f"sample_{i:04d}.png") img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
cv2.imwrite(img_path, aug_bgr) cv2.imwrite(img_path, aug_bgr)
raw_coords = [ metadata.append({
int(pid_hex[-1], 16), int(pid_hex[-2], 16), "img_path": img_path,
int(pid_hex[-3], 16), int(pid_hex[-4], 16), "pid_hex": pid_hex,
int(pid_hex[3], 16), int(pid_hex[2], 16), "target": SpindaResolver.pid_to_coords(pid_hex),
int(pid_hex[1], 16), int(pid_hex[0], 16), })
]
metadata.append({"img_path": img_path, "pid_hex": pid_hex, "target": raw_coords})
with open(os.path.join(output_dir, "metadata.json"), "w") as f: with open(os.path.join(output_dir, "metadata.json"), "w") as f:
json.dump(metadata, f, indent=4) json.dump(metadata, f, indent=4)

View File

@@ -1,15 +1,17 @@
import os
import torch
import json import json
import os
import random import random
from tqdm import tqdm
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
import cv2
def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"): import cv2
from tqdm import tqdm
from src.data.renderer import generate_high_fidelity_spinda
from src.utils.resolver import SpindaResolver
def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val") -> None:
"""Generates a fixed set of Spinda images and their targets for validation.""" """Generates a fixed set of Spinda images and their targets for validation."""
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
metadata = [] metadata = []
print(f"Generating {size} validation samples...") print(f"Generating {size} validation samples...")
@@ -17,23 +19,14 @@ def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
pid = random.getrandbits(32) pid = random.getrandbits(32)
pid_hex = f"{pid:08x}" pid_hex = f"{pid:08x}"
# Generate image
img_bgr = generate_high_fidelity_spinda(pid) img_bgr = generate_high_fidelity_spinda(pid)
img_path = os.path.join(output_dir, f"sample_{i:04d}.png") img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
cv2.imwrite(img_path, img_bgr) cv2.imwrite(img_path, img_bgr)
# Extract raw coordinates (0-15)
raw_coords = [
int(pid_hex[-1], 16), int(pid_hex[-2], 16),
int(pid_hex[-3], 16), int(pid_hex[-4], 16),
int(pid_hex[3], 16), int(pid_hex[2], 16),
int(pid_hex[1], 16), int(pid_hex[0], 16)
]
metadata.append({ metadata.append({
"img_path": img_path, "img_path": img_path,
"pid_hex": pid_hex, "pid_hex": pid_hex,
"target": raw_coords "target": SpindaResolver.pid_to_coords(pid_hex),
}) })
with open(os.path.join(output_dir, "metadata.json"), "w") as f: with open(os.path.join(output_dir, "metadata.json"), "w") as f:
@@ -41,7 +34,7 @@ def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
print(f"Validation set generated in {output_dir}") print(f"Validation set generated in {output_dir}")
if __name__ == "__main__": if __name__ == "__main__":
# Seed for reproducibility of the validation set itself
random.seed(42) random.seed(42)
generate_fixed_val_set() generate_fixed_val_set()

View File

@@ -1,81 +0,0 @@
import numpy as np
import cv2
from typing import Tuple, List
# Image and Grid constants
IMG_SIZE = 128
GRID_SIZE = 16
SPOT_RADIUS = 6 # Approximate radius in pixels
# Colors (BGR)
FACE_COLOR = (180, 220, 240) # Pale cream/tan
SPOT_COLOR = (50, 50, 200) # Reddish-orange
EYE_COLOR = (0, 0, 0)
def extract_coords(pid: int, mode: str = "standard") -> List[Tuple[int, int]]:
    """Extract four (x, y) spot coordinates from a 32-bit PID/EC.

    Each of the four low bytes of ``pid`` encodes one spot: the low nibble is
    the x coordinate and the high nibble is the y coordinate (both 0-15).

    Byte-to-spot order:
        standard (little-endian): byte 0 (LL), 1 (LR), 2 (UL), 3 (UR)
        bdsp (big-endian):        byte 3 (LL), 2 (LR), 1 (UL), 0 (UR)

    Args:
        pid: PID/EC value; only the low 32 bits are read.
        mode: Either ``"standard"`` or ``"bdsp"``.

    Returns:
        A list of four ``(x, y)`` tuples in LL, LR, UL, UR order.

    Raises:
        ValueError: If ``mode`` is neither ``"standard"`` nor ``"bdsp"``.
    """
    # Previously any unrecognised mode (e.g. a typo such as "BDSP") silently
    # fell back to the standard byte order and produced a wrong spot layout.
    if mode not in ("standard", "bdsp"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'standard' or 'bdsp'")

    # Bit offset of each byte, least-significant byte first.
    shifts = (0, 8, 16, 24)
    if mode == "bdsp":
        # BDSP reads the bytes in reverse order.
        shifts = shifts[::-1]

    # Low nibble of the byte -> x, high nibble -> y.
    return [((pid >> s) & 0x0F, (pid >> (s + 4)) & 0x0F) for s in shifts]
def generate_spinda_face(pid: int, mode: str = "standard") -> np.ndarray:
    """Draw a procedural Spinda face whose four spots are placed from ``pid``.

    The drawing is layered in this order on a white IMG_SIZE x IMG_SIZE
    canvas: ears, face oval, eyes, then the four spots. Spot grid coordinates
    come from ``extract_coords(pid, mode=mode)``.

    Returns:
        The rendered 3-channel uint8 image (colours given in BGR order).
    """
    outline = (0, 0, 0)
    mid = IMG_SIZE // 2

    # White canvas.
    canvas = np.full((IMG_SIZE, IMG_SIZE, 3), 255, dtype=np.uint8)

    # Ears: filled disc plus a 2 px outline, left ear first, then right.
    for ear_x in (mid - 40, mid + 40):
        ear_center = (ear_x, mid - 50)
        cv2.circle(canvas, ear_center, 20, FACE_COLOR, -1)
        cv2.circle(canvas, ear_center, 20, outline, 2)

    # Main face: filled oval, then its outline.
    face_center = (mid, mid)
    face_axes = (50, 60)
    cv2.ellipse(canvas, face_center, face_axes, 0, 0, 360, FACE_COLOR, -1)
    cv2.ellipse(canvas, face_center, face_axes, 0, 0, 360, outline, 2)

    # Fixed eyes, symmetric about the vertical centre line.
    for dx in (-15, 15):
        cv2.circle(canvas, (mid + dx, mid - 10), 4, EYE_COLOR, -1)

    # Pixel origin of each spot's quadrant, in the same LL, LR, UL, UR order
    # that extract_coords returns. Heuristic layout, not sprite-accurate.
    quadrant_origins = (
        (mid - 45, mid + 5),   # LL (face)
        (mid + 5, mid + 5),    # LR (face)
        (mid - 45, mid - 55),  # UL (ear / upper face)
        (mid + 5, mid - 55),   # UR (ear / upper face)
    )

    spot_coords = extract_coords(pid, mode=mode)
    for (origin_x, origin_y), (x, y) in zip(quadrant_origins, spot_coords):
        # Each grid unit is scaled to 2.5 px inside its quadrant.
        spot_center = (int(origin_x + x * 2.5), int(origin_y + y * 2.5))
        cv2.circle(canvas, spot_center, SPOT_RADIUS, SPOT_COLOR, -1)

    return canvas

View File

@@ -1,15 +1,14 @@
import numpy as np
import cv2
from PIL import Image
import os import os
from typing import List, Tuple
import cv2
import numpy as np
from PIL import Image
# Constants # Constants
IMG_SIZE = 128 IMG_SIZE = 128
BASE_PATH = "assets" BASE_PATH = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "assets"))
# Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective) # Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective)
# This matches the script's PID_to_Coordinates logic:
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2]) # TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1) # TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1)
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18) # BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18)
@@ -21,10 +20,9 @@ SPOT_BASE_OFFSETS = [
(18, 19), # Spot 4 (BR) (18, 19), # Spot 4 (BR)
] ]
def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]:
"""Extracts coordinates following ProfessorRex's PID_to_Coordinates logic. def extract_coords_rex(pid_hex: str) -> list[tuple[int, int]]:
PID is expected to be an 8-char hex string. """Extracts coordinates following ProfessorRex's PID_to_Coordinates logic."""
"""
pid = pid_hex.lower().zfill(8) pid = pid_hex.lower().zfill(8)
TL = (int(pid[-1], 16), int(pid[-2], 16)) TL = (int(pid[-1], 16), int(pid[-2], 16))
TR = (int(pid[-3], 16) + 24, int(pid[-4], 16) + 1) TR = (int(pid[-3], 16) + 24, int(pid[-4], 16) + 1)
@@ -32,6 +30,7 @@ def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]:
BR = (int(pid[1], 16) + 18, int(pid[0], 16) + 19) BR = (int(pid[1], 16) + 18, int(pid[0], 16) + 19)
return [TL, TR, BL, BR] return [TL, TR, BL, BR]
def generate_high_fidelity_spinda( def generate_high_fidelity_spinda(
pid: int, pid: int,
bg_color: tuple[int, int, int] = (255, 255, 255), bg_color: tuple[int, int, int] = (255, 255, 255),
@@ -45,51 +44,38 @@ def generate_high_fidelity_spinda(
head_data = np.array(head_img) head_data = np.array(head_img)
spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"] spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"]
spots = [Image.open(os.path.join(BASE_PATH, name)).convert("RGBA") for name in spot_names] spots = [np.array(Image.open(os.path.join(BASE_PATH, name)).convert("RGBA")) for name in spot_names]
# 2. Create Pattern Layer (Integer grid like in Rex's script) # 2. Build pattern grid — mark pixels covered by any spot
W, H = base_img.size W, H = base_img.size
pattern_grid = np.zeros((H, W), dtype=np.uint8) pattern_grid = np.zeros((H, W), dtype=bool)
coords = extract_coords_rex(pid_hex) for spot_arr, (px_start, py_start) in zip(spots, extract_coords_rex(pid_hex)):
for i, (px_start, py_start) in enumerate(coords):
spot_arr = np.array(spots[i])
sh, sw = spot_arr.shape[:2] sh, sw = spot_arr.shape[:2]
active = spot_arr[:, :, 0] > 200 # white pixels in the spot mask
ys, xs = np.where(active)
tx = px_start + xs
ty = py_start + ys
valid = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
pattern_grid[ty[valid], tx[valid]] = True
for sy in range(sh): # 3. Colourize — copy head colours onto active spot pixels
for sx in range(sw):
# If spot pixel is white (active) in the mask
# Check for white pixels (R,G,B > 200)
if spot_arr[sy, sx, 0] > 200:
tx, ty = px_start + sx, py_start + sy
if 0 <= tx < W and 0 <= ty < H:
# Mark as active
pattern_grid[ty, tx] = 1
# 3. Colourize
# Create an empty RGBA layer for the spots
spot_layer = np.zeros((H, W, 4), dtype=np.uint8) spot_layer = np.zeros((H, W, 4), dtype=np.uint8)
for y in range(H): spot_layer[pattern_grid] = head_data[pattern_grid]
for x in range(W):
if pattern_grid[y, x] > 0:
# Take color from Spinda_Head.png
spot_layer[y, x] = head_data[y, x]
# 4. Composite # 4. Composite
spot_layer_img = Image.fromarray(spot_layer, "RGBA") combined = Image.alpha_composite(base_img, Image.fromarray(spot_layer, "RGBA"))
combined = Image.alpha_composite(base_img, spot_layer_img)
# 5. Final Canvas (128x128) # 5. Final Canvas (128×128)
final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255)) final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255))
offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2) offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2)
final_img.paste(combined, offset, combined) final_img.paste(combined, offset, combined)
return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR) return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR)
if __name__ == "__main__": if __name__ == "__main__":
# Test with 0x12345678
test_pid = 0x12345678 test_pid = 0x12345678
img = generate_high_fidelity_spinda(test_pid) img = generate_high_fidelity_spinda(test_pid)
cv2.imwrite("sample_high_fidelity_v3.png", img) cv2.imwrite("sample_spinda.png", img)
print(f"Corrected High-fidelity sample saved to sample_high_fidelity_v3.png for PID: {hex(test_pid)}") print(f"Sample saved to sample_spinda.png for PID: {hex(test_pid)}")

View File

@@ -1,21 +0,0 @@
import cv2
from src.data.generator import generate_spinda_face
def save_sample_image() -> None:
    """Render one sample Spinda face and write it to ``sample_spinda.png``.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    # 0x12345678 is used because each of its four bytes is distinct
    # (least-significant byte first):
    #   Byte 0: 0x78 (X=8, Y=7)
    #   Byte 1: 0x56 (X=6, Y=5)
    #   Byte 2: 0x34 (X=4, Y=3)
    #   Byte 3: 0x12 (X=2, Y=1)
    test_pid = 0x12345678
    out_path = "sample_spinda.png"

    # NOTE(review): presumably returns an image array that cv2.imwrite
    # accepts — confirm against src.data.generator.
    img = generate_spinda_face(test_pid)

    # cv2.imwrite signals failure through its boolean return value instead of
    # raising, so check it before claiming the file was saved.
    if not cv2.imwrite(out_path, img):
        raise OSError(f"Failed to write sample image to {out_path}")
    print(f"Sample Spinda image saved to {out_path} for PID: {hex(test_pid)}")


if __name__ == "__main__":
    save_sample_image()

View File

@@ -5,7 +5,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from src.models.regression_model import SpindaRegressionModel from src.models.regression_model import SpindaRegressionModel
from src.models.train import SpindaEvalDataset from src.data.dataset import SpindaEvalDataset
def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None: def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None:
@@ -29,8 +29,8 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.model_path): if not os.path.exists(args.model_path):

View File

@@ -2,8 +2,7 @@ import cv2
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.v2 as T import torchvision.transforms.v2 as T
from PIL import Image from typing import List, Tuple, Union
from typing import List, Tuple
from src.models.regression_model import SpindaRegressionModel from src.models.regression_model import SpindaRegressionModel
@@ -13,8 +12,8 @@ class SpindaInference:
def __init__( def __init__(
self, self,
model_path: str = "models/best_spinda_model.pth", model_path: str = "models/best_resnet34_model.pth",
backbone: str = "resnet18", backbone: str = "resnet34",
): ):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = SpindaRegressionModel(pretrained=False, backbone=backbone) self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
@@ -29,15 +28,23 @@ class SpindaInference:
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]) ])
def predict(self, image_path: str) -> Tuple[List[int], str]: def predict(self, image: Union[str, np.ndarray]) -> Tuple[List[int], str]:
"""Predict the 8 grid coordinates and return them with a fingerprint string. """Predict the 8 grid coordinates and return them with a fingerprint string.
Args:
image: file path (str) or a BGR numpy array (e.g. from the detector).
Returns: Returns:
grid_coords: list of 8 integers in [0, 15] grid_coords: list of 8 integers in [0, 15]
fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4" fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
""" """
img_bgr = cv2.imread(image_path) if isinstance(image, str):
img_bgr = cv2.imread(image)
if img_bgr is None:
raise FileNotFoundError(f"Image not found: {image}")
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
else:
img_rgb = image[:, :, ::-1].copy()
img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1) img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1)
img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device) img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)

View File

@@ -23,7 +23,7 @@ class SpindaRegressionModel(nn.Module):
Prediction: output.argmax(dim=2) → (B, 8) integer coordinates. Prediction: output.argmax(dim=2) → (B, 8) integer coordinates.
""" """
def __init__(self, pretrained: bool = True, backbone: str = "resnet18"): def __init__(self, pretrained: bool = True, backbone: str = "resnet34"):
super().__init__() super().__init__()
if backbone not in _BACKBONES: if backbone not in _BACKBONES:
raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}") raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")

View File

@@ -1,49 +1,17 @@
import os import os
import json
import random import random
import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torchvision.transforms.v2 as T from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm from tqdm import tqdm
from src.data.dataset import SpindaDataset, get_default_augmentations from src.data.dataset import SpindaDataset, SpindaEvalDataset, get_default_augmentations
from src.models.regression_model import SpindaRegressionModel from src.models.regression_model import SpindaRegressionModel
class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test).

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.

    Args:
        data_dir: directory containing ``metadata.json``, a JSON list whose
            items each provide an ``"img_path"`` and a ``"target"`` entry.
    """

    def __init__(self, data_dir: str):
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Scale to float32 and normalise only; no augmentation is applied here.
        self.transform = T.Compose([
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        """Return the number of entries listed in the metadata file."""
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the normalised image tensor and integer target for *idx*.

        Raises:
            FileNotFoundError: if the image listed in the metadata cannot be
                read from disk.
        """
        item = self.metadata[idx]
        img_bgr = cv2.imread(item["img_path"])
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with the offending path rather than with an
        # opaque TypeError on the slice below.
        if img_bgr is None:
            raise FileNotFoundError(f"Image not found: {item['img_path']}")

        # Reverse the channel axis (BGR -> RGB); copy() makes the reversed
        # view contiguous so torch.from_numpy accepts it.
        img_rgb = img_bgr[:, :, ::-1].copy()
        # Move channels first (H, W, C) -> (C, H, W) before the transforms.
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them # BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
# so the optimiser focuses more of its gradient on the hardest spot. # so the optimiser focuses more of its gradient on the hardest spot.
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0]) _COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])
@@ -73,7 +41,7 @@ def train_model(
model_path: str = "models/best_spinda_model.pth", model_path: str = "models/best_spinda_model.pth",
num_workers: int = 4, num_workers: int = 4,
epoch_size: int = 200000, epoch_size: int = 200000,
backbone: str = "resnet18", backbone: str = "resnet34",
save_path: str = "", save_path: str = "",
) -> None: ) -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -176,8 +144,7 @@ def train_model(
if exact_rate > best_exact_rate: if exact_rate > best_exact_rate:
best_exact_rate = exact_rate best_exact_rate = exact_rate
epochs_without_improvement = 0 epochs_without_improvement = 0
os.makedirs("models", exist_ok=True) os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path) torch.save(model.state_dict(), checkpoint_path)
print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})") print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
else: else:
@@ -198,7 +165,7 @@ if __name__ == "__main__":
parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch") parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"])
parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path") parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")
args = parser.parse_args() args = parser.parse_args()

View File

@@ -1,21 +1,18 @@
import sqlite3
import os import os
from typing import List, Optional, Tuple import sqlite3
class SpindaRegistry: class SpindaRegistry:
"""Handles the storage and lookup of Spinda PIDs based on their visual fingerprints.""" """Handles the storage and lookup of Spinda PIDs based on their visual fingerprints."""
def __init__(self, db_path: str = "data/spinda_registry.db"): def __init__(self, db_path: str = "data/spinda_registry.db"):
self.db_path = db_path self.db_path = db_path
os.makedirs(os.path.dirname(self.db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True)
self._init_db() self._init_db()
def _init_db(self): def _init_db(self) -> None:
"""Initializes the database with a core table for PIDs and Fingerprints."""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
# Primary table: Mapping the 8-integer fingerprint to PIDs
# Fingerprint format: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS registry ( CREATE TABLE IF NOT EXISTS registry (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -24,38 +21,34 @@ class SpindaRegistry:
UNIQUE(fingerprint, pid_hex) UNIQUE(fingerprint, pid_hex)
) )
''') ''')
# Index for fast O(1)-style lookups by fingerprint
cursor.execute('CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)') cursor.execute('CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)')
conn.commit() conn.commit()
def add_entry(self, fingerprint: str, pid_hex: str): def add_entry(self, fingerprint: str, pid_hex: str) -> None:
"""Adds a new Spinda entry to the registry.""" """Adds a new Spinda entry to the registry (idempotent)."""
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() conn.execute(
cursor.execute(
"INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)", "INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)",
(fingerprint, pid_hex) (fingerprint, pid_hex),
) )
conn.commit() conn.commit()
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
# Entry already exists
pass pass
def lookup_by_fingerprint(self, fingerprint: str) -> List[str]: def lookup_by_fingerprint(self, fingerprint: str) -> list[str]:
"""Returns all PIDs associated with a specific visual fingerprint.""" """Returns all PIDs associated with a specific visual fingerprint."""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() cursor = conn.execute(
cursor.execute("SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,)) "SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,)
results = cursor.fetchall() )
return [row[0] for row in results] return [row[0] for row in cursor.fetchall()]
if __name__ == "__main__": if __name__ == "__main__":
# Quick test
reg = SpindaRegistry("data/test_registry.db") reg = SpindaRegistry("data/test_registry.db")
test_fp = "00-01-02-03-04-05-06-07" test_fp = "00-01-02-03-04-05-06-07"
test_pid = "ABCDE123" test_pid = "ABCDE123"
reg.add_entry(test_fp, test_pid) reg.add_entry(test_fp, test_pid)
matches = reg.lookup_by_fingerprint(test_fp) matches = reg.lookup_by_fingerprint(test_fp)
print(f"Looked up {test_fp}, found PIDs: {matches}") print(f"Looked up {test_fp}, found PIDs: {matches}")

View File

@@ -1,10 +1,22 @@
from typing import List, Tuple from typing import Literal
class SpindaResolver: class SpindaResolver:
"""Mathematically resolves a visual fingerprint back to its possible PIDs.""" """Mathematically resolves a visual fingerprint back to its possible PIDs."""
@staticmethod @staticmethod
def coordinates_to_pid(coords: List[int], mode: str = "standard") -> str: def pid_to_coords(pid_hex: str) -> list[int]:
"""Converts an 8-char hex PID string to the 8 nibble target coordinates."""
p = pid_hex.lower().zfill(8)
return [
int(p[-1], 16), int(p[-2], 16),
int(p[-3], 16), int(p[-4], 16),
int(p[3], 16), int(p[2], 16),
int(p[1], 16), int(p[0], 16),
]
@staticmethod
def coordinates_to_pid(coords: list[int], mode: Literal["standard", "bdsp"] = "standard") -> str:
""" """
Converts 8 grid coordinates (0-15) back to a 32-bit hex PID. Converts 8 grid coordinates (0-15) back to a 32-bit hex PID.
@@ -14,20 +26,19 @@ class SpindaResolver:
Byte 2 (BL): x=coords[4], y=coords[5] Byte 2 (BL): x=coords[4], y=coords[5]
Byte 3 (BR): x=coords[6], y=coords[7] Byte 3 (BR): x=coords[6], y=coords[7]
""" """
if mode not in ("standard", "bdsp"):
raise ValueError(f"mode must be 'standard' or 'bdsp'; got {mode!r}")
# Each byte = (Y << 4) | X # Each byte = (Y << 4) | X
bytes_list = [] bytes_list = []
for i in range(0, 8, 2): for i in range(0, 8, 2):
x = coords[i] x = coords[i]
y = coords[i+1] y = coords[i + 1]
byte = (y << 4) | x bytes_list.append((y << 4) | x)
bytes_list.append(byte)
if mode == "bdsp": if mode == "bdsp":
# BDSP reads the bytes in reverse order (Big-Endian style)
# So we reverse them back
bytes_list = bytes_list[::-1] bytes_list = bytes_list[::-1]
# Combine bytes into 32-bit integer
pid = 0 pid = 0
for i, byte in enumerate(bytes_list): for i, byte in enumerate(bytes_list):
pid |= (byte << (i * 8)) pid |= (byte << (i * 8))
@@ -35,18 +46,18 @@ class SpindaResolver:
return f"{pid:08x}" return f"{pid:08x}"
@staticmethod @staticmethod
def resolve_fingerprint(fingerprint: str) -> dict: def resolve_fingerprint(fingerprint: str) -> dict[str, str]:
""" """
Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4" Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
and returns both possible PIDs (Standard and BDSP). and returns both possible PIDs (Standard and BDSP).
""" """
coords = [int(c) for c in fingerprint.split("-")] coords = [int(c) for c in fingerprint.split("-")]
return { return {
"standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"), "standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"),
"bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp") "bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp"),
} }
if __name__ == "__main__": if __name__ == "__main__":
# Test with a known fingerprint # Test with a known fingerprint
# PID 0x12345678 -> # PID 0x12345678 ->