diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 8ddeffa..beb634a 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -18,7 +18,8 @@ "Bash(.venv/bin/python -m src.models.evaluate)", "Bash(.venv/bin/python -m src.models.evaluate --backbone resnet34 --model_path models/best_resnet34_model.pth)", "Bash(.venv/bin/python identify.py test_image2.jpg --backbone resnet34 --model_path models/best_resnet34_model.pth)", - "Bash(.venv/bin/python src/models/regression_model.py)" + "Bash(.venv/bin/python src/models/regression_model.py)", + "Bash(xargs cat -n)" ] } } diff --git a/.gitignore b/.gitignore index d5e366b..78854d7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ build/** **/__pycache__/** +models/** diff --git a/CLAUDE.md b/CLAUDE.md index 724f881..9a4096d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,10 +11,10 @@ All commands must be run from the project root using the local venv: .venv/bin/python identify.py # Train the model -.venv/bin/python -m src.models.train --epochs 50 --batch_size 64 --lr 1e-4 +.venv/bin/python -m src.models.train --epochs 50 # Evaluate a trained model on val and aug_test sets -.venv/bin/python -m src.models.evaluate [--backbone resnet18|resnet34] [--model_path ] +.venv/bin/python -m src.models.evaluate [--backbone resnet34] [--model_path ] # Run inference only (no registry lookup) .venv/bin/python src/models/inference.py @@ -26,7 +26,7 @@ All commands must be run from the project root using the local venv: .venv/bin/python -m src.data.generate_aug_test_set # Generate a single sample image for visual inspection -.venv/bin/python src/data/high_fidelity_generator.py +.venv/bin/python src/data/renderer.py # Lint .venv/bin/ruff check src/ @@ -50,7 +50,7 @@ Image → Detector → Inference → Resolver → Registry ### 1. Detection (`src/utils/detector.py`) -`SpindaDetector.detect_and_crop()` returns a **128×128 BGR** image, or `None`. +`SpindaDetector.detect_and_crop()` returns a **128×128 BGR** numpy array, or `None`. Two-tier strategy, tried in order: - **Tier 1 (screenshots/sprites):** HSV-filter red pixels → find individual spot blobs → cluster to 4 spots → derive crop from cluster centroid + `_SPOT_CROP_RATIO=5.5` (= 128 / 24.5 px span) with a `_SPOT_CENTER_OFFSET=0.056` downward shift so the spot centroid lands at 44.4 % from the top of the crop (matching the training canvas). @@ -58,15 +58,18 @@ Two-tier strategy, tried in order: ### 2. Model (`src/models/regression_model.py`) -ResNet-18 backbone with the final FC replaced by `Linear(512, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`. +Configurable backbone (default: ResNet-34) with the final FC replaced by `Linear(feat_dim, 8·16)`. Forward pass returns **(B, 8, 16)** — treating each of the 8 coordinates as a 16-class classification problem. Trained with `CrossEntropyLoss` on `view(-1, 16)` vs `view(-1)` targets; predictions use `argmax(dim=2)`. + +Supported backbones: `resnet18` (512-d), `resnet34` (512-d), `convnext_tiny` (768-d). ### 3. Training (`src/models/train.py`) - `SpindaDataset` (200 k virtual samples/epoch): generates a fresh random 32-bit PID per `__getitem__`, renders the sprite with a random background colour, then applies the full augmentation pipeline. -- `SpindaEvalDataset`: loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99). +- `SpindaEvalDataset` (in `src/data/dataset.py`): loads pre-generated images from disk (post-augmentation, pre-normalisation) and applies only the normalise step. Used for both `data/val/` (clean, seed=42) and `data/aug_test/` (augmented, seed=99). - `_worker_init_fn` re-seeds Python `random` and NumPy per worker so forked workers generate distinct PIDs. +- Weighted loss: BL_x ×1.5, BL_y ×2.5 — applied during training only; val loss is unweighted. - Early stopping: patience = 10 epochs on clean-val exact-match rate. -- Best model checkpoint: `models/best_spinda_model.pth`. +- Checkpoints saved to `models/best_{backbone}_model.pth`. ### 4. PID Encoding (domain invariant — must not be changed) @@ -107,7 +110,8 @@ assets/ # Sprite assets used by the renderer Spot_{TL,TR,BL,BR}.png models/ - best_spinda_model.pth + best_resnet34_model.pth # current best (default) + best_convnext_tiny_model.pth # convnext experiment ``` `metadata.json` format: `[{"img_path": "...", "pid_hex": "...", "target": [int×8]}, ...]` @@ -115,8 +119,8 @@ models/ ## Key invariants - **Visual collisions:** ~1.3 % of fingerprints are shared by multiple PIDs (many-to-one mapping). `SpindaRegistry` stores `(fingerprint, pid_hex)` pairs with a unique constraint so `lookup_by_fingerprint` can return *all* matching PIDs — this is intentional, not a bug. - - The **validation set** uses white backgrounds (no augmentation baked in) to give a stable epoch-comparable baseline. Do not add augmentation to `generate_val_set.py`. - The **augmented test set** is pre-generated and fixed. Regenerating it changes the baseline; do so intentionally. - The crop output size is always **128×128** regardless of tier. The model transform chain also resizes to 128×128, so the inference path is robust to re-size. -- `generate_high_fidelity_spinda()` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR). +- `generate_high_fidelity_spinda()` in `src/data/renderer.py` always takes `bg_color` as a `(R, G, B)` tuple in PIL order (not BGR). +- `SpindaInference.predict()` accepts either a file path or a BGR numpy array directly (e.g. from the detector). diff --git a/identify.py b/identify.py index 6343279..dd3563c 100644 --- a/identify.py +++ b/identify.py @@ -1,70 +1,59 @@ +import argparse import os import sys + import cv2 -import torch + +from src.data.renderer import generate_high_fidelity_spinda from src.models.inference import SpindaInference -from src.utils.resolver import SpindaResolver from src.registry.database import SpindaRegistry -from src.data.high_fidelity_generator import generate_high_fidelity_spinda -from src.utils.detector import SpindaDetector # Import the detector +from src.utils.detector import SpindaDetector +from src.utils.resolver import SpindaResolver + def identify_spinda( image_path: str, - model_path: str = "models/best_spinda_model.pth", - backbone: str = "resnet18", -): + model_path: str = "models/best_resnet34_model.pth", + backbone: str = "resnet34", +) -> None: if not os.path.exists(image_path): print(f"Error: File {image_path} not found.") return print(f"--- Identifying Spinda in {image_path} ---") - - # 1. Detect and Crop Spinda + + # 1. Detect and crop detector = SpindaDetector() cropped_img = detector.detect_and_crop(image_path) - if cropped_img is None: print("Error: Could not detect Spinda in the image.") return - - # Save cropped image for debug/visual check + cv2.imwrite("detected_spinda_crop.png", cropped_img) print("Detected Spinda saved to detected_spinda_crop.png") - # We need to save the cropped image to a temporary file for the inference model to read - temp_cropped_path = "temp_cropped_spinda.png" - cv2.imwrite(temp_cropped_path, cropped_img) - - # 2. Inference (Model Prediction) using the cropped image - try: - inf = SpindaInference(model_path=model_path, backbone=backbone) - coords, fingerprint = inf.predict(temp_cropped_path) - except Exception as e: - print(f"Error during inference: {e}") - os.remove(temp_cropped_path) # Clean up temp file - return - finally: - os.remove(temp_cropped_path) # Clean up temp file + # 2. Inference — pass the BGR array directly, no temp file needed + inf = SpindaInference(model_path=model_path, backbone=backbone) + coords, fingerprint = inf.predict(cropped_img) print(f"Visual Fingerprint: {fingerprint}") print(f"Predicted Grid Coordinates: {coords}") - # 3. Resolution (Mathematical PIDs) + # 3. Resolve to PIDs resolved = SpindaResolver.resolve_fingerprint(fingerprint) print("\nPossible PIDs:") print(f" Standard (Gen 3-8, HOME): 0x{resolved['standard']}") print(f" BDSP (Big-Endian Flip): 0x{resolved['bdsp']}") - # 4. Visual Verification + # 4. Visual verification print("\nGenerating visual verification image...") verify_img = generate_high_fidelity_spinda(int(resolved['standard'], 16)) cv2.imwrite("prediction_verify.png", verify_img) print("Verification image saved to: prediction_verify.png") - # 5. Registry Lookup + # 5. Registry lookup reg = SpindaRegistry() matches = reg.lookup_by_fingerprint(fingerprint) - if matches: print("\nMatches found in Global Registry:") for pid in matches: @@ -72,13 +61,12 @@ def identify_spinda( else: print("\nNo matching entries in Global Registry.") - print("\nNote: Accuracy depends on model training progress.") if __name__ == "__main__": - import argparse parser = argparse.ArgumentParser() parser.add_argument("image_path") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) - parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") + parser.add_argument("--backbone", type=str, default="resnet34", + choices=["resnet18", "resnet34", "convnext_tiny"]) + parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth") args = parser.parse_args() identify_spinda(args.image_path, model_path=args.model_path, backbone=args.backbone) diff --git a/models/best_spinda_model.pth b/models/best_spinda_model.pth index 4199b86..c175e04 100644 Binary files a/models/best_spinda_model.pth and b/models/best_spinda_model.pth differ diff --git a/pyproject.toml b/pyproject.toml index 57a13a7..b7eb365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "opencv-python>=4.7.0", "Pillow>=9.5.0", "numpy>=1.24.0", + "tqdm>=4.65.0", ] [project.optional-dependencies] @@ -16,7 +17,6 @@ dev = [ "pytest>=7.3.0", "ruff>=0.0.270", "mypy>=1.3.0", - "tqdm>=4.65.0", ] [tool.ruff] diff --git a/src/data/dataset.py b/src/data/dataset.py index e4a12c4..04fdb67 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -1,69 +1,52 @@ -import torch -from torch.utils.data import Dataset -import torchvision.transforms.v2 as T -import numpy as np -from typing import Tuple, Optional +import json +import os import random +from typing import Optional + +import cv2 +import torch +import torchvision.transforms.v2 as T +from torch.utils.data import Dataset + +from src.data.renderer import generate_high_fidelity_spinda +from src.utils.resolver import SpindaResolver -from src.data.high_fidelity_generator import generate_high_fidelity_spinda class SpindaDataset(Dataset): """PyTorch Dataset for generating synthetic Spinda samples with augmentations.""" - + def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None): - """ - Args: - size: Virtual size of the dataset (since it's synthetic). - transform: Optional torchvision transforms to apply. - """ self.size = size self.transform = transform def __len__(self) -> int: return self.size - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: - # Generate a random 32-bit PID + def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: pid = random.getrandbits(32) pid_hex = f"{pid:08x}" - - # 1. Generate High-Fidelity Image on a random background colour + r = random.randint(0, 255) g = random.randint(0, 255) b = random.randint(0, 255) img_bgr = generate_high_fidelity_spinda(pid, bg_color=(r, g, b)) - - # Convert BGR to RGB for PyTorch/Torchvision img_rgb = img_bgr[:, :, ::-1].copy() - - # 2. Get Ground Truth Coordinates (Target) - # Raw nibble values (0-15) for each spot, in TL/TR/BL/BR order. - raw_coords = [] - # TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2]) - raw_coords.extend([int(pid_hex[-1], 16), int(pid_hex[-2], 16)]) - # TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) - raw_coords.extend([int(pid_hex[-3], 16), int(pid_hex[-4], 16)]) - # BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) - raw_coords.extend([int(pid_hex[3], 16), int(pid_hex[2], 16)]) - # BR (Spot 4): Nibble 6, 7 (PID[1], PID[0]) - raw_coords.extend([int(pid_hex[1], 16), int(pid_hex[0], 16)]) - - # Integer labels in [0, 15] — used with CrossEntropyLoss. - target_tensor = torch.tensor(raw_coords, dtype=torch.long) - - # 3. Apply Transforms + + target_tensor = torch.tensor(SpindaResolver.pid_to_coords(pid_hex), dtype=torch.long) + if self.transform: - # Convert to PIL or Tensor first if needed by transform - img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1) # C, H, W + img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1) img_tensor = self.transform(img_tensor) else: img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0 - + return img_tensor, target_tensor + def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor: return (x + torch.randn_like(x) * 0.05).clamp(0, 1) + def add_scan_lines(x: torch.Tensor) -> torch.Tensor: """Simulate LCD scan lines seen in handheld-camera photos of 3DS screens.""" if torch.rand(1).item() < 0.5: @@ -72,26 +55,52 @@ def add_scan_lines(x: torch.Tensor) -> torch.Tensor: x[:, ::2, :] *= (1.0 - strength) return x.clamp(0, 1) + def get_default_augmentations() -> T.Compose: """Domain randomisation pipeline calibrated for real handheld-photo conditions.""" return T.Compose([ T.ToDtype(torch.float32, scale=True), - # Spatial — wider range to cover camera angle and zoom variation T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8), T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True), - # Colour / sensor — stronger to cover screen glare and ambient lighting T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05), T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)), - # Sensor noise and LCD scan lines T.Lambda(add_gaussian_noise), T.Lambda(add_scan_lines), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) + +class SpindaEvalDataset(Dataset): + """Loads a fixed evaluation set (clean val or augmented test). + + Images are stored post-augmentation, pre-normalisation; this class + applies only the normalisation step so the on-disk images are stable. + """ + + def __init__(self, data_dir: str): + with open(os.path.join(data_dir, "metadata.json")) as f: + self.metadata = json.load(f) + self.transform = T.Compose([ + T.ToDtype(torch.float32, scale=True), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def __len__(self) -> int: + return len(self.metadata) + + def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + item = self.metadata[idx] + img_bgr = cv2.imread(item["img_path"]) + if img_bgr is None: + raise FileNotFoundError(f"Image not found: {item['img_path']}") + img_rgb = img_bgr[:, :, ::-1].copy() + img_tensor = self.transform(torch.from_numpy(img_rgb).permute(2, 0, 1)) + target = torch.tensor(item["target"], dtype=torch.long) + return img_tensor, target + + if __name__ == "__main__": - # Test the dataset ds = SpindaDataset(size=5, transform=get_default_augmentations()) img, target = ds[0] print(f"Image shape: {img.shape}") - print(f"Target (normalized 0-1): {target}") - print(f"Target (grid units): {target * 15.0}") + print(f"Target: {target.tolist()}") diff --git a/src/data/generate_aug_test_set.py b/src/data/generate_aug_test_set.py index 17dcbb6..db9a546 100644 --- a/src/data/generate_aug_test_set.py +++ b/src/data/generate_aug_test_set.py @@ -15,7 +15,8 @@ import torchvision.transforms.v2 as T from tqdm import tqdm from src.data.dataset import add_gaussian_noise, add_scan_lines -from src.data.high_fidelity_generator import generate_high_fidelity_spinda +from src.data.renderer import generate_high_fidelity_spinda +from src.utils.resolver import SpindaResolver # Augmentation without the final Normalize step — we bake everything else in. @@ -53,13 +54,11 @@ def generate_aug_test_set(size: int = 500, output_dir: str = "data/aug_test") -> img_path = os.path.join(output_dir, f"sample_{i:04d}.png") cv2.imwrite(img_path, aug_bgr) - raw_coords = [ - int(pid_hex[-1], 16), int(pid_hex[-2], 16), - int(pid_hex[-3], 16), int(pid_hex[-4], 16), - int(pid_hex[3], 16), int(pid_hex[2], 16), - int(pid_hex[1], 16), int(pid_hex[0], 16), - ] - metadata.append({"img_path": img_path, "pid_hex": pid_hex, "target": raw_coords}) + metadata.append({ + "img_path": img_path, + "pid_hex": pid_hex, + "target": SpindaResolver.pid_to_coords(pid_hex), + }) with open(os.path.join(output_dir, "metadata.json"), "w") as f: json.dump(metadata, f, indent=4) diff --git a/src/data/generate_val_set.py b/src/data/generate_val_set.py index bce0c82..5ae78dd 100644 --- a/src/data/generate_val_set.py +++ b/src/data/generate_val_set.py @@ -1,47 +1,40 @@ -import os -import torch import json +import os import random -from tqdm import tqdm -from src.data.high_fidelity_generator import generate_high_fidelity_spinda -import cv2 -def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"): +import cv2 +from tqdm import tqdm + +from src.data.renderer import generate_high_fidelity_spinda +from src.utils.resolver import SpindaResolver + + +def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val") -> None: """Generates a fixed set of Spinda images and their targets for validation.""" os.makedirs(output_dir, exist_ok=True) - metadata = [] - + print(f"Generating {size} validation samples...") for i in tqdm(range(size)): pid = random.getrandbits(32) pid_hex = f"{pid:08x}" - - # Generate image + img_bgr = generate_high_fidelity_spinda(pid) img_path = os.path.join(output_dir, f"sample_{i:04d}.png") cv2.imwrite(img_path, img_bgr) - - # Extract raw coordinates (0-15) - raw_coords = [ - int(pid_hex[-1], 16), int(pid_hex[-2], 16), - int(pid_hex[-3], 16), int(pid_hex[-4], 16), - int(pid_hex[3], 16), int(pid_hex[2], 16), - int(pid_hex[1], 16), int(pid_hex[0], 16) - ] - + metadata.append({ "img_path": img_path, "pid_hex": pid_hex, - "target": raw_coords + "target": SpindaResolver.pid_to_coords(pid_hex), }) - + with open(os.path.join(output_dir, "metadata.json"), "w") as f: json.dump(metadata, f, indent=4) - + print(f"Validation set generated in {output_dir}") + if __name__ == "__main__": - # Seed for reproducibility of the validation set itself random.seed(42) generate_fixed_val_set() diff --git a/src/data/generator.py b/src/data/generator.py deleted file mode 100644 index 10b4b7d..0000000 --- a/src/data/generator.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np -import cv2 -from typing import Tuple, List - -# Image and Grid constants -IMG_SIZE = 128 -GRID_SIZE = 16 -SPOT_RADIUS = 6 # Approximate radius in pixels - -# Colors (BGR) -FACE_COLOR = (180, 220, 240) # Pale cream/tan -SPOT_COLOR = (50, 50, 200) # Reddish-orange -EYE_COLOR = (0, 0, 0) - -def extract_coords(pid: int, mode: str = "standard") -> List[Tuple[int, int]]: - """Extracts four (x, y) coordinates from a 32-bit PID/EC. - - Standard (Little-Endian): Byte 0 (LL), 1 (LR), 2 (UL), 3 (UR) - BDSP (Big-Endian): Byte 3 (LL), 2 (LR), 1 (UL), 0 (UR) - """ - bytes_list = [ - (pid >> 0) & 0xFF, # Byte 0 - (pid >> 8) & 0xFF, # Byte 1 - (pid >> 16) & 0xFF, # Byte 2 - (pid >> 24) & 0xFF, # Byte 3 - ] - - if mode == "bdsp": - # BDSP reads the bytes in reverse order - ordered_bytes = [bytes_list[3], bytes_list[2], bytes_list[1], bytes_list[0]] - else: - ordered_bytes = bytes_list - - coords = [] - for byte in ordered_bytes: - x = byte & 0x0F - y = (byte >> 4) & 0x0F - coords.append((x, y)) - return coords - -def generate_spinda_face(pid: int, mode: str = "standard") -> np.ndarray: - """Generates a procedural Spinda face with spots based on the PID.""" - # Create blank canvas (white) - img = np.ones((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8) * 255 - - # Draw Ears - cv2.circle(img, (IMG_SIZE // 2 - 40, IMG_SIZE // 2 - 50), 20, FACE_COLOR, -1) - cv2.circle(img, (IMG_SIZE // 2 - 40, IMG_SIZE // 2 - 50), 20, (0, 0, 0), 2) - cv2.circle(img, (IMG_SIZE // 2 + 40, IMG_SIZE // 2 - 50), 20, FACE_COLOR, -1) - cv2.circle(img, (IMG_SIZE // 2 + 40, IMG_SIZE // 2 - 50), 20, (0, 0, 0), 2) - - # Draw main face (oval) - center = (IMG_SIZE // 2, IMG_SIZE // 2) - axes = (50, 60) - cv2.ellipse(img, center, axes, 0, 0, 360, FACE_COLOR, -1) - cv2.ellipse(img, center, axes, 0, 0, 360, (0, 0, 0), 2) # Outline - - # Fixed Eyes - cv2.circle(img, (center[0] - 15, center[1] - 10), 4, EYE_COLOR, -1) - cv2.circle(img, (center[0] + 15, center[1] - 10), 4, EYE_COLOR, -1) - - # Define Spot Zones (Relative to the 16x16 grid) - # Heuristic mapping for a more "Pokemon-like" layout - # Spot 1 (LL), Spot 2 (LR), Spot 3 (UL), Spot 4 (UR) - # These offsets are designed to place spots in their respective quadrants - quadrant_offsets = [ - (center[0] - 45, center[1] + 5), # LL (Face) - (center[0] + 5, center[1] + 5), # LR (Face) - (center[0] - 45, center[1] - 55), # UL (Ear/Upper Face) - (center[0] + 5, center[1] - 55), # UR (Ear/Upper Face) - ] - - coords = extract_coords(pid, mode=mode) - for i, (x, y) in enumerate(coords): - offset_x, offset_y = quadrant_offsets[i] - px = int(offset_x + x * 2.5) # Scale grid to quadrant - py = int(offset_y + y * 2.5) - - cv2.circle(img, (px, py), SPOT_RADIUS, SPOT_COLOR, -1) - - return img diff --git a/src/data/high_fidelity_generator.py b/src/data/renderer.py similarity index 53% rename from src/data/high_fidelity_generator.py rename to src/data/renderer.py index 93546e2..84abdc5 100644 --- a/src/data/high_fidelity_generator.py +++ b/src/data/renderer.py @@ -1,15 +1,14 @@ -import numpy as np -import cv2 -from PIL import Image import os -from typing import List, Tuple + +import cv2 +import numpy as np +from PIL import Image # Constants IMG_SIZE = 128 -BASE_PATH = "assets" +BASE_PATH = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "assets")) # Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective) -# This matches the script's PID_to_Coordinates logic: # TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2]) # TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1) # BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18) @@ -21,10 +20,9 @@ SPOT_BASE_OFFSETS = [ (18, 19), # Spot 4 (BR) ] -def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]: - """Extracts coordinates following ProfessorRex's PID_to_Coordinates logic. - PID is expected to be an 8-char hex string. - """ + +def extract_coords_rex(pid_hex: str) -> list[tuple[int, int]]: + """Extracts coordinates following ProfessorRex's PID_to_Coordinates logic.""" pid = pid_hex.lower().zfill(8) TL = (int(pid[-1], 16), int(pid[-2], 16)) TR = (int(pid[-3], 16) + 24, int(pid[-4], 16) + 1) @@ -32,64 +30,52 @@ def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]: BR = (int(pid[1], 16) + 18, int(pid[0], 16) + 19) return [TL, TR, BL, BR] + def generate_high_fidelity_spinda( pid: int, bg_color: tuple[int, int, int] = (255, 255, 255), ) -> np.ndarray: """Generates a high-fidelity Spinda face using correctly sized assets.""" pid_hex = f"{pid:08x}" - + # 1. Load Assets base_img = Image.open(os.path.join(BASE_PATH, "Spinda_Base_Top.png")).convert("RGBA") head_img = Image.open(os.path.join(BASE_PATH, "Spinda_Head.png")).convert("RGBA") head_data = np.array(head_img) - - spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"] - spots = [Image.open(os.path.join(BASE_PATH, name)).convert("RGBA") for name in spot_names] - - # 2. Create Pattern Layer (Integer grid like in Rex's script) - W, H = base_img.size - pattern_grid = np.zeros((H, W), dtype=np.uint8) - - coords = extract_coords_rex(pid_hex) - - for i, (px_start, py_start) in enumerate(coords): - spot_arr = np.array(spots[i]) - sh, sw = spot_arr.shape[:2] - - for sy in range(sh): - for sx in range(sw): - # If spot pixel is white (active) in the mask - # Check for white pixels (R,G,B > 200) - if spot_arr[sy, sx, 0] > 200: - tx, ty = px_start + sx, py_start + sy - if 0 <= tx < W and 0 <= ty < H: - # Mark as active - pattern_grid[ty, tx] = 1 - # 3. Colourize - # Create an empty RGBA layer for the spots + spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"] + spots = [np.array(Image.open(os.path.join(BASE_PATH, name)).convert("RGBA")) for name in spot_names] + + # 2. Build pattern grid — mark pixels covered by any spot + W, H = base_img.size + pattern_grid = np.zeros((H, W), dtype=bool) + + for spot_arr, (px_start, py_start) in zip(spots, extract_coords_rex(pid_hex)): + sh, sw = spot_arr.shape[:2] + active = spot_arr[:, :, 0] > 200 # white pixels in the spot mask + ys, xs = np.where(active) + tx = px_start + xs + ty = py_start + ys + valid = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H) + pattern_grid[ty[valid], tx[valid]] = True + + # 3. Colourize — copy head colours onto active spot pixels spot_layer = np.zeros((H, W, 4), dtype=np.uint8) - for y in range(H): - for x in range(W): - if pattern_grid[y, x] > 0: - # Take color from Spinda_Head.png - spot_layer[y, x] = head_data[y, x] + spot_layer[pattern_grid] = head_data[pattern_grid] # 4. Composite - spot_layer_img = Image.fromarray(spot_layer, "RGBA") - combined = Image.alpha_composite(base_img, spot_layer_img) - - # 5. Final Canvas (128x128) + combined = Image.alpha_composite(base_img, Image.fromarray(spot_layer, "RGBA")) + + # 5. Final Canvas (128×128) final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255)) offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2) final_img.paste(combined, offset, combined) - + return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR) + if __name__ == "__main__": - # Test with 0x12345678 test_pid = 0x12345678 img = generate_high_fidelity_spinda(test_pid) - cv2.imwrite("sample_high_fidelity_v3.png", img) - print(f"Corrected High-fidelity sample saved to sample_high_fidelity_v3.png for PID: {hex(test_pid)}") + cv2.imwrite("sample_spinda.png", img) + print(f"Sample saved to sample_spinda.png for PID: {hex(test_pid)}") diff --git a/src/data/visualize_sample.py b/src/data/visualize_sample.py deleted file mode 100644 index b8170fc..0000000 --- a/src/data/visualize_sample.py +++ /dev/null @@ -1,21 +0,0 @@ -import cv2 -from src.data.generator import generate_spinda_face - -def save_sample_image(): - # Example PID (from documentation/common examples if possible) - # Let's use a PID that should have distinct spots - # 0x12345678 -> - # Byte 0: 0x78 (X=8, Y=7) - # Byte 1: 0x56 (X=6, Y=5) - # Byte 2: 0x34 (X=4, Y=3) - # Byte 3: 0x12 (X=2, Y=1) - test_pid = 0x12345678 - - img = generate_spinda_face(test_pid) - - # Save the image - cv2.imwrite("sample_spinda.png", img) - print(f"Sample Spinda image saved to sample_spinda.png for PID: {hex(test_pid)}") - -if __name__ == "__main__": - save_sample_image() diff --git a/src/models/evaluate.py b/src/models/evaluate.py index 073f430..a120672 100644 --- a/src/models/evaluate.py +++ b/src/models/evaluate.py @@ -5,7 +5,7 @@ import torch from torch.utils.data import DataLoader from src.models.regression_model import SpindaRegressionModel -from src.models.train import SpindaEvalDataset +from src.data.dataset import SpindaEvalDataset def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: SpindaEvalDataset) -> None: @@ -29,8 +29,8 @@ def evaluate(model: torch.nn.Module, device: torch.device, name: str, ds: Spinda if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) + parser.add_argument("--model_path", type=str, default="models/best_resnet34_model.pth") + parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"]) args = parser.parse_args() if not os.path.exists(args.model_path): diff --git a/src/models/inference.py b/src/models/inference.py index b808109..23a68fc 100644 --- a/src/models/inference.py +++ b/src/models/inference.py @@ -2,8 +2,7 @@ import cv2 import numpy as np import torch import torchvision.transforms.v2 as T -from PIL import Image -from typing import List, Tuple +from typing import List, Tuple, Union from src.models.regression_model import SpindaRegressionModel @@ -13,8 +12,8 @@ class SpindaInference: def __init__( self, - model_path: str = "models/best_spinda_model.pth", - backbone: str = "resnet18", + model_path: str = "models/best_resnet34_model.pth", + backbone: str = "resnet34", ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = SpindaRegressionModel(pretrained=False, backbone=backbone) @@ -29,15 +28,23 @@ class SpindaInference: T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) - def predict(self, image_path: str) -> Tuple[List[int], str]: + def predict(self, image: Union[str, np.ndarray]) -> Tuple[List[int], str]: """Predict the 8 grid coordinates and return them with a fingerprint string. + Args: + image: file path (str) or a BGR numpy array (e.g. from the detector). Returns: grid_coords: list of 8 integers in [0, 15] fingerprint: "X1-Y1-X2-Y2-X3-Y3-X4-Y4" """ - img_bgr = cv2.imread(image_path) - img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + if isinstance(image, str): + img_bgr = cv2.imread(image) + if img_bgr is None: + raise FileNotFoundError(f"Image not found: {image}") + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + else: + img_rgb = image[:, :, ::-1].copy() + img_tensor = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1) img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device) diff --git a/src/models/regression_model.py b/src/models/regression_model.py index 473959f..17130d2 100644 --- a/src/models/regression_model.py +++ b/src/models/regression_model.py @@ -23,7 +23,7 @@ class SpindaRegressionModel(nn.Module): Prediction: output.argmax(dim=2) → (B, 8) integer coordinates. """ - def __init__(self, pretrained: bool = True, backbone: str = "resnet18"): + def __init__(self, pretrained: bool = True, backbone: str = "resnet34"): super().__init__() if backbone not in _BACKBONES: raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}") diff --git a/src/models/train.py b/src/models/train.py index 0049a59..43aa06d 100644 --- a/src/models/train.py +++ b/src/models/train.py @@ -1,49 +1,17 @@ import os -import json import random -import cv2 import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import torchvision.transforms.v2 as T -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from tqdm import tqdm -from src.data.dataset import SpindaDataset, get_default_augmentations +from src.data.dataset import SpindaDataset, SpindaEvalDataset, get_default_augmentations from src.models.regression_model import SpindaRegressionModel -class SpindaEvalDataset(Dataset): - """Loads a fixed evaluation set (clean val or augmented test). - - Images are stored post-augmentation, pre-normalisation; this class - applies only the normalisation step so the on-disk images are stable. - """ - - def __init__(self, data_dir: str): - with open(os.path.join(data_dir, "metadata.json")) as f: - self.metadata = json.load(f) - self.transform = T.Compose([ - T.ToDtype(torch.float32, scale=True), - T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) - - def __len__(self) -> int: - return len(self.metadata) - - def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: - item = self.metadata[idx] - img_bgr = cv2.imread(item["img_path"]) - img_rgb = img_bgr[:, :, ::-1].copy() - img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1) - img_tensor = self.transform(img_tensor) - target = torch.tensor(item["target"], dtype=torch.long) - return img_tensor, target - - # BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them # so the optimiser focuses more of its gradient on the hardest spot. _COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0]) @@ -73,7 +41,7 @@ def train_model( model_path: str = "models/best_spinda_model.pth", num_workers: int = 4, epoch_size: int = 200000, - backbone: str = "resnet18", + backbone: str = "resnet34", save_path: str = "", ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -176,8 +144,7 @@ def train_model( if exact_rate > best_exact_rate: best_exact_rate = exact_rate epochs_without_improvement = 0 - os.makedirs("models", exist_ok=True) - os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True) torch.save(model.state_dict(), checkpoint_path) print(f" → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})") else: @@ -198,7 +165,7 @@ if __name__ == "__main__": parser.add_argument("--model_path", type=str, default="models/best_spinda_model.pth") parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)") parser.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch") - parser.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34", "convnext_tiny"]) + parser.add_argument("--backbone", type=str, default="resnet34", choices=["resnet18", "resnet34", "convnext_tiny"]) parser.add_argument("--save_path", type=str, default="", help="Override checkpoint save path") args = parser.parse_args() diff --git a/src/registry/database.py b/src/registry/database.py index c6e6ede..ac72415 100644 --- a/src/registry/database.py +++ b/src/registry/database.py @@ -1,21 +1,18 @@ -import sqlite3 import os -from typing import List, Optional, Tuple +import sqlite3 + class SpindaRegistry: """Handles the storage and lookup of Spinda PIDs based on their visual fingerprints.""" - + def __init__(self, db_path: str = "data/spinda_registry.db"): self.db_path = db_path - os.makedirs(os.path.dirname(self.db_path), exist_ok=True) + os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True) self._init_db() - def _init_db(self): - """Initializes the database with a core table for PIDs and Fingerprints.""" + def _init_db(self) -> None: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() - # Primary table: Mapping the 8-integer fingerprint to PIDs - # Fingerprint format: "X1-Y1-X2-Y2-X3-Y3-X4-Y4" cursor.execute(''' CREATE TABLE IF NOT EXISTS registry ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -24,38 +21,34 @@ class SpindaRegistry: UNIQUE(fingerprint, pid_hex) ) ''') - # Index for fast O(1)-style lookups by fingerprint cursor.execute('CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)') conn.commit() - def add_entry(self, fingerprint: str, pid_hex: str): - """Adds a new Spinda entry to the registry.""" + def add_entry(self, fingerprint: str, pid_hex: str) -> None: + """Adds a new Spinda entry to the registry (idempotent).""" try: with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute( + conn.execute( "INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)", - (fingerprint, pid_hex) + (fingerprint, pid_hex), ) conn.commit() except sqlite3.IntegrityError: - # Entry already exists pass - def lookup_by_fingerprint(self, fingerprint: str) -> List[str]: + def lookup_by_fingerprint(self, fingerprint: str) -> list[str]: """Returns all PIDs associated with a specific visual fingerprint.""" with sqlite3.connect(self.db_path) as conn: - cursor = conn.cursor() - cursor.execute("SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,)) - results = cursor.fetchall() - return [row[0] for row in results] + cursor = conn.execute( + "SELECT pid_hex FROM registry WHERE fingerprint = ?", (fingerprint,) + ) + return [row[0] for row in cursor.fetchall()] + if __name__ == "__main__": - # Quick test reg = SpindaRegistry("data/test_registry.db") test_fp = "00-01-02-03-04-05-06-07" test_pid = "ABCDE123" - reg.add_entry(test_fp, test_pid) matches = reg.lookup_by_fingerprint(test_fp) print(f"Looked up {test_fp}, found PIDs: {matches}") diff --git a/src/utils/resolver.py b/src/utils/resolver.py index c250cae..9a21ac6 100644 --- a/src/utils/resolver.py +++ b/src/utils/resolver.py @@ -1,55 +1,66 @@ -from typing import List, Tuple +from typing import Literal + class SpindaResolver: """Mathematically resolves a visual fingerprint back to its possible PIDs.""" @staticmethod - def coordinates_to_pid(coords: List[int], mode: str = "standard") -> str: + def pid_to_coords(pid_hex: str) -> list[int]: + """Converts an 8-char hex PID string to the 8 nibble target coordinates.""" + p = pid_hex.lower().zfill(8) + return [ + int(p[-1], 16), int(p[-2], 16), + int(p[-3], 16), int(p[-4], 16), + int(p[3], 16), int(p[2], 16), + int(p[1], 16), int(p[0], 16), + ] + + @staticmethod + def coordinates_to_pid(coords: list[int], mode: Literal["standard", "bdsp"] = "standard") -> str: """ Converts 8 grid coordinates (0-15) back to a 32-bit hex PID. - + Standard Mapping (Rex/Little-Endian): Byte 0 (TL): x=coords[0], y=coords[1] Byte 1 (TR): x=coords[2], y=coords[3] Byte 2 (BL): x=coords[4], y=coords[5] Byte 3 (BR): x=coords[6], y=coords[7] """ + if mode not in ("standard", "bdsp"): + raise ValueError(f"mode must be 'standard' or 'bdsp'; got {mode!r}") + # Each byte = (Y << 4) | X bytes_list = [] for i in range(0, 8, 2): x = coords[i] - y = coords[i+1] - byte = (y << 4) | x - bytes_list.append(byte) + y = coords[i + 1] + bytes_list.append((y << 4) | x) if mode == "bdsp": - # BDSP reads the bytes in reverse order (Big-Endian style) - # So we reverse them back bytes_list = bytes_list[::-1] - # Combine bytes into 32-bit integer pid = 0 for i, byte in enumerate(bytes_list): pid |= (byte << (i * 8)) - + return f"{pid:08x}" @staticmethod - def resolve_fingerprint(fingerprint: str) -> dict: + def resolve_fingerprint(fingerprint: str) -> dict[str, str]: """ Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4" and returns both possible PIDs (Standard and BDSP). """ coords = [int(c) for c in fingerprint.split("-")] - return { "standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"), - "bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp") + "bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp"), } + if __name__ == "__main__": # Test with a known fingerprint - # PID 0x12345678 -> + # PID 0x12345678 -> # Byte 0: 0x78 (X=8, Y=7) # Byte 1: 0x56 (X=6, Y=5) # Byte 2: 0x34 (X=4, Y=3)