Source code
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
build/**
|
build/**
|
||||||
|
**/__pycache__/**
|
||||||
|
|||||||
97
src/data/dataset.py
Normal file
97
src/data/dataset.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import torchvision.transforms.v2 as T
|
||||||
|
import numpy as np
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
import random
|
||||||
|
|
||||||
|
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
||||||
|
|
||||||
|
class SpindaDataset(Dataset):
    """Endless synthetic Spinda dataset.

    Every access renders a fresh Spinda for a randomly drawn 32-bit PID on a
    random background colour; the requested index is ignored, so ``size``
    only determines how many samples ``len()`` reports.
    """

    def __init__(self, size: int = 10000, transform: Optional[T.Transform] = None):
        """
        Args:
            size: Virtual length of the dataset (samples are generated on demand).
            transform: Optional torchvision transform applied to the uint8
                channel-first image tensor.
        """
        self.transform = transform
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # RNG draw order is PID first, then R, G, B of the background.
        pid = random.getrandbits(32)
        background = tuple(random.randint(0, 255) for _ in range(3))

        # The generator returns BGR; flip the channel axis to get RGB.
        rendered_bgr = generate_high_fidelity_spinda(pid, bg_color=background)
        rendered_rgb = np.ascontiguousarray(rendered_bgr[:, :, ::-1])

        # Labels are the eight nibbles of the PID, lowest nibble first:
        # (TL_x, TL_y, TR_x, TR_y, BL_x, BL_y, BR_x, BR_y), each in [0, 15].
        # Stored as long because they are class indices for CrossEntropyLoss.
        nibbles = [(pid >> shift) & 0xF for shift in range(0, 32, 4)]
        target_tensor = torch.tensor(nibbles, dtype=torch.long)

        channel_first = torch.from_numpy(rendered_rgb).permute(2, 0, 1)  # C, H, W
        if self.transform:
            img_tensor = self.transform(channel_first)
        else:
            # No transform supplied: just scale to float in [0, 1].
            img_tensor = channel_first.float() / 255.0

        return img_tensor, target_tensor
|
||||||
|
|
||||||
|
def add_gaussian_noise(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` plus zero-mean Gaussian noise (std 0.05), clamped to [0, 1]."""
    noise = torch.randn_like(x) * 0.05
    return torch.clamp(x + noise, 0, 1)
|
||||||
|
|
||||||
|
def add_scan_lines(x: torch.Tensor) -> torch.Tensor:
    """Randomly darken alternating lines to mimic LCD scan lines in screen photos.

    With probability 0.5, every second slice along dimension 1 (starting at
    index 0) is scaled by a random factor in (0.75, 1.0]. The input tensor is
    never modified in place; the result is always clamped to [0, 1].
    """
    apply_effect = torch.rand(1).item() < 0.5
    if apply_effect:
        dim_factor = 1.0 - torch.rand(1).item() * 0.25
        x = x.clone()  # keep the caller's tensor untouched
        x[:, ::2, :] *= dim_factor
    return x.clamp(0, 1)
|
||||||
|
|
||||||
|
def get_default_augmentations() -> T.Compose:
    """Build the training-time domain-randomisation pipeline.

    Stages, in order: scale to float32, spatial jitter, colour/blur jitter,
    sensor noise and scan lines, then ImageNet mean/std normalisation.
    """
    # Spatial — wide ranges to cover camera angle and zoom variation.
    spatial = [
        T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8),
        T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True),
    ]
    # Colour / sensor — strong enough to cover screen glare and ambient lighting.
    photometric = [
        T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
        T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
    ]
    # Sensor noise and LCD scan lines (both expect float input in [0, 1]).
    sensor = [
        T.Lambda(add_gaussian_noise),
        T.Lambda(add_scan_lines),
    ]
    return T.Compose([
        T.ToDtype(torch.float32, scale=True),
        *spatial,
        *photometric,
        *sensor,
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: draw one augmented sample and show its shape and labels.
    ds = SpindaDataset(size=5, transform=get_default_augmentations())
    img, target = ds[0]
    print(f"Image shape: {img.shape}")
    # Targets are integer class labels in [0, 15] (one per PID nibble), not
    # normalised floats, so they are printed as-is without rescaling.
    print(f"Target (grid units 0-15): {target.tolist()}")
|
||||||
72
src/data/generate_aug_test_set.py
Normal file
72
src/data/generate_aug_test_set.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Generate a fixed augmented test set for measuring domain-adaptation progress.
|
||||||
|
|
||||||
|
Images are saved post-augmentation but pre-normalisation, so the loader only
|
||||||
|
needs to normalise them — making the dataset stable and epoch-comparable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms.v2 as T
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.data.dataset import add_gaussian_noise, add_scan_lines
|
||||||
|
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
||||||
|
|
||||||
|
|
||||||
|
# Augmentation without the final Normalize step — we bake everything else in.
# The stages mirror get_default_augmentations() in src/data/dataset.py minus
# its trailing T.Normalize, so saved images stay in the [0, 1] range.
_AUG = T.Compose([
    T.ToDtype(torch.float32, scale=True),
    T.RandomAffine(degrees=25, translate=(0.05, 0.05), shear=8),
    T.RandomResizedCrop(size=(128, 128), scale=(0.75, 1.0), ratio=(0.85, 1.15), antialias=True),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
    T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
    T.Lambda(add_gaussian_noise),
    T.Lambda(add_scan_lines),
])
|
||||||
|
|
||||||
|
|
||||||
|
def generate_aug_test_set(size: int = 500, output_dir: str = "data/aug_test") -> None:
    """Render ``size`` augmented Spinda images plus a ``metadata.json`` index.

    Each sample uses a random 32-bit PID and a random background colour, is
    passed through ``_AUG`` and saved as ``sample_NNNN.png`` in ``output_dir``.
    The metadata records, per sample, the image path, the PID as 8 hex digits
    and the eight nibble labels (lowest nibble first).

    Args:
        size: Number of samples to generate.
        output_dir: Directory that receives the PNGs and ``metadata.json``;
            created if missing.

    Raises:
        OSError: If an image cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)
    metadata = []

    print(f"Generating {size} augmented test samples → {output_dir}")
    for i in tqdm(range(size)):
        pid = random.getrandbits(32)
        pid_hex = f"{pid:08x}"

        # Random background colour, same as training
        bg = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        img_bgr = generate_high_fidelity_spinda(pid, bg_color=bg)
        img_rgb = img_bgr[:, :, ::-1].copy()

        # Apply augmentations and convert back to uint8 for PNG export.
        t = torch.from_numpy(img_rgb).permute(2, 0, 1)
        t = _AUG(t).clamp(0, 1)
        # Round before casting: a bare astype() truncates and would darken
        # every pixel by up to one level.
        aug_np = np.rint(t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        aug_bgr = cv2.cvtColor(aug_np, cv2.COLOR_RGB2BGR)

        img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
        # cv2.imwrite signals failure via its return value instead of raising.
        if not cv2.imwrite(img_path, aug_bgr):
            raise OSError(f"Failed to write image: {img_path}")

        # Nibbles 0..7 of the PID: (TL_x, TL_y, TR_x, TR_y, BL_x, BL_y, BR_x, BR_y).
        raw_coords = [(pid >> shift) & 0xF for shift in range(0, 32, 4)]
        metadata.append({"img_path": img_path, "pid_hex": pid_hex, "target": raw_coords})

    with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Done — augmented test set saved to {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Seed both RNGs: PIDs and background colours come from `random`, while
    # the augmentation pipeline and the noise/scan-line lambdas draw from
    # torch's global generator. Seeding only `random` would leave the
    # "fixed" test set different on every run.
    random.seed(99)
    torch.manual_seed(99)
    generate_aug_test_set()
|
||||||
47
src/data/generate_val_set.py
Normal file
47
src/data/generate_val_set.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from tqdm import tqdm
|
||||||
|
from src.data.high_fidelity_generator import generate_high_fidelity_spinda
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def generate_fixed_val_set(size: int = 1000, output_dir: str = "data/val"):
    """Generate a fixed set of clean Spinda images and their targets for validation.

    Each sample is rendered for a random 32-bit PID with the generator's
    default background and saved as ``sample_NNNN.png``; a ``metadata.json``
    file lists the image path, PID (8 hex digits) and eight nibble labels.

    Args:
        size: Number of samples to generate.
        output_dir: Destination directory; created if missing.

    Raises:
        OSError: If an image cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    metadata = []

    print(f"Generating {size} validation samples...")
    for i in tqdm(range(size)):
        pid = random.getrandbits(32)
        pid_hex = f"{pid:08x}"

        # Generate image
        img_bgr = generate_high_fidelity_spinda(pid)
        img_path = os.path.join(output_dir, f"sample_{i:04d}.png")
        # cv2.imwrite returns False on failure instead of raising; without
        # this check the metadata would reference files that do not exist.
        if not cv2.imwrite(img_path, img_bgr):
            raise OSError(f"Failed to write image: {img_path}")

        # Raw coordinates (0-15): nibbles 0..7 of the PID, i.e.
        # (TL_x, TL_y, TR_x, TR_y, BL_x, BL_y, BR_x, BR_y).
        raw_coords = [(pid >> shift) & 0xF for shift in range(0, 32, 4)]

        metadata.append({
            "img_path": img_path,
            "pid_hex": pid_hex,
            "target": raw_coords,
        })

    with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=4)

    print(f"Validation set generated in {output_dir}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # A fixed seed makes the generated PIDs (and thus the validation set)
    # identical from run to run.
    VAL_SEED = 42
    random.seed(VAL_SEED)
    generate_fixed_val_set()
|
||||||
81
src/data/generator.py
Normal file
81
src/data/generator.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
# Image and Grid constants
IMG_SIZE = 128  # The generated canvas is IMG_SIZE x IMG_SIZE pixels
GRID_SIZE = 16  # Presumably the number of nibble values per axis (0-15); not referenced in the code shown here
SPOT_RADIUS = 6  # Approximate radius in pixels

# Colors (BGR)
FACE_COLOR = (180, 220, 240)  # Pale cream/tan
SPOT_COLOR = (50, 50, 200)  # Reddish-orange
EYE_COLOR = (0, 0, 0)
|
||||||
|
|
||||||
|
def extract_coords(pid: int, mode: str = "standard") -> List[Tuple[int, int]]:
    """Split a 32-bit PID/EC into four (x, y) nibble pairs, one per byte.

    Within each byte the low nibble is x and the high nibble is y.

    Standard (Little-Endian): Byte 0 (LL), 1 (LR), 2 (UL), 3 (UR)
    BDSP (Big-Endian): Byte 3 (LL), 2 (LR), 1 (UL), 0 (UR)

    Any ``mode`` other than ``"bdsp"`` uses the standard byte order.
    """
    # Bit offsets of the bytes, in the order the spots are emitted.
    shifts = (24, 16, 8, 0) if mode == "bdsp" else (0, 8, 16, 24)

    pairs = []
    for shift in shifts:
        packed = (pid >> shift) & 0xFF
        pairs.append((packed & 0x0F, packed >> 4))
    return pairs
|
||||||
|
|
||||||
|
def generate_spinda_face(pid: int, mode: str = "standard") -> np.ndarray:
    """Draw a simple procedural Spinda face whose four spots encode the PID.

    Returns an IMG_SIZE x IMG_SIZE, 3-channel uint8 image on a white canvas
    (colours are given in BGR order).
    """
    outline = (0, 0, 0)
    canvas = np.full((IMG_SIZE, IMG_SIZE, 3), 255, dtype=np.uint8)
    cx = cy = IMG_SIZE // 2

    # Ears: filled disc then black outline, left ear first.
    for dx in (-40, 40):
        ear_center = (cx + dx, cy - 50)
        cv2.circle(canvas, ear_center, 20, FACE_COLOR, -1)
        cv2.circle(canvas, ear_center, 20, outline, 2)

    # Main face (oval), drawn over the ears, then its outline.
    face_center = (cx, cy)
    face_axes = (50, 60)
    cv2.ellipse(canvas, face_center, face_axes, 0, 0, 360, FACE_COLOR, -1)
    cv2.ellipse(canvas, face_center, face_axes, 0, 0, 360, outline, 2)

    # Fixed eyes.
    for dx in (-15, 15):
        cv2.circle(canvas, (cx + dx, cy - 10), 4, EYE_COLOR, -1)

    # Heuristic quadrant origins for Spot 1 (LL), Spot 2 (LR), Spot 3 (UL),
    # Spot 4 (UR); each nibble coordinate is scaled by 2.5 px within its quadrant.
    quadrant_origins = (
        (cx - 45, cy + 5),   # LL (Face)
        (cx + 5, cy + 5),    # LR (Face)
        (cx - 45, cy - 55),  # UL (Ear/Upper Face)
        (cx + 5, cy - 55),   # UR (Ear/Upper Face)
    )
    for (origin_x, origin_y), (gx, gy) in zip(quadrant_origins, extract_coords(pid, mode=mode)):
        spot_center = (int(origin_x + gx * 2.5), int(origin_y + gy * 2.5))
        cv2.circle(canvas, spot_center, SPOT_RADIUS, SPOT_COLOR, -1)

    return canvas
|
||||||
95
src/data/high_fidelity_generator.py
Normal file
95
src/data/high_fidelity_generator.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
# Constants
IMG_SIZE = 128  # Side length in pixels of the final square canvas
BASE_PATH = "assets"  # Directory holding the sprite PNGs (relative path)

# Offsets from ProfessorRex's Spinda_generator.py (Observer Perspective)
# This matches the script's PID_to_Coordinates logic:
# TL (Spot 1): Nibble 0, 1 (PID[-1], PID[-2])
# TR (Spot 2): Nibble 2, 3 (PID[-3], PID[-4]) + (24, 1)
# BL (Spot 3): Nibble 4, 5 (PID[3], PID[2]) + (6, 18)
# BR (Spot 4): Nibble 6, 7 (PID[1], PID[0]) + (18, 19)
# NOTE(review): extract_coords_rex() below hard-codes these same offsets
# rather than reading this table — keep the two in sync.
SPOT_BASE_OFFSETS = [
    (0, 0),  # Spot 1 (TL)
    (24, 1),  # Spot 2 (TR)
    (6, 18),  # Spot 3 (BL)
    (18, 19),  # Spot 4 (BR)
]
|
||||||
|
|
||||||
|
def extract_coords_rex(pid_hex: str) -> List[Tuple[int, int]]:
    """Return the top-left pixel position of each spot, in TL, TR, BL, BR order.

    Follows ProfessorRex's PID_to_Coordinates logic: every spot takes two hex
    digits of the PID as (x, y) and adds a fixed per-spot base offset.
    ``pid_hex`` is lower-cased and left-padded with zeros to 8 characters.
    """
    digits = pid_hex.lower().zfill(8)

    def nibble(position: int) -> int:
        return int(digits[position], 16)

    # (x digit index, y digit index, x offset, y offset) for each spot.
    layout = (
        (-1, -2, 0, 0),    # TL
        (-3, -4, 24, 1),   # TR
        (3, 2, 6, 18),     # BL
        (1, 0, 18, 19),    # BR
    )
    return [(nibble(xi) + dx, nibble(yi) + dy) for xi, yi, dx, dy in layout]
|
||||||
|
|
||||||
|
def generate_high_fidelity_spinda(
    pid: int,
    bg_color: tuple[int, int, int] = (255, 255, 255),
) -> np.ndarray:
    """Render a Spinda sprite whose spot pattern is determined by ``pid``.

    The sprite assets are read from ``BASE_PATH`` on every call. The four
    spot masks are stamped onto a pattern grid at the positions returned by
    ``extract_coords_rex``; pixels covered by a mask take their colour from
    ``Spinda_Head.png`` and are composited over ``Spinda_Base_Top.png``. The
    result is centred on an ``IMG_SIZE`` x ``IMG_SIZE`` canvas.

    Args:
        pid: PID value; formatted as 8 hex digits to derive the spot positions.
        bg_color: Canvas fill colour as three ints (channel order as passed to
            PIL's RGBA canvas).

    Returns:
        The canvas converted with ``cv2.COLOR_RGBA2BGR`` (3-channel array).
    """
    pid_hex = f"{pid:08x}"

    # 1. Load assets
    base_img = Image.open(os.path.join(BASE_PATH, "Spinda_Base_Top.png")).convert("RGBA")
    head_img = Image.open(os.path.join(BASE_PATH, "Spinda_Head.png")).convert("RGBA")
    head_data = np.array(head_img)

    spot_names = ["Spot_TL.png", "Spot_TR.png", "Spot_BL.png", "Spot_BR.png"]
    spots = [Image.open(os.path.join(BASE_PATH, name)).convert("RGBA") for name in spot_names]

    # 2. Build the pattern layer: True wherever any spot mask is active.
    # Done with array slicing instead of per-pixel Python loops because this
    # function runs once per training sample.
    W, H = base_img.size
    pattern = np.zeros((H, W), dtype=bool)

    for spot_img, (px_start, py_start) in zip(spots, extract_coords_rex(pid_hex)):
        # A mask pixel counts as active when its first channel exceeds 200.
        active = np.array(spot_img)[:, :, 0] > 200
        sh, sw = active.shape

        # Clip the stamp rectangle to the sprite bounds.
        x0, y0 = max(px_start, 0), max(py_start, 0)
        x1, y1 = min(px_start + sw, W), min(py_start + sh, H)
        if x0 >= x1 or y0 >= y1:
            continue  # stamp lies entirely outside the sprite

        pattern[y0:y1, x0:x1] |= active[
            y0 - py_start:y1 - py_start,
            x0 - px_start:x1 - px_start,
        ]

    # 3. Colourize: active pixels copy RGBA from Spinda_Head.png; everything
    # else stays fully transparent. Integer indexing (rather than a boolean
    # mask) only requires the head image to cover the active pixels.
    spot_layer = np.zeros((H, W, 4), dtype=np.uint8)
    ys, xs = np.nonzero(pattern)
    spot_layer[ys, xs] = head_data[ys, xs]

    # 4. Composite spots over the base sprite.
    spot_layer_img = Image.fromarray(spot_layer, "RGBA")
    combined = Image.alpha_composite(base_img, spot_layer_img)

    # 5. Final canvas (IMG_SIZE x IMG_SIZE), sprite centred, using its own
    # alpha channel as the paste mask.
    final_img = Image.new("RGBA", (IMG_SIZE, IMG_SIZE), (*bg_color, 255))
    offset = ((IMG_SIZE - W) // 2, (IMG_SIZE - H) // 2)
    final_img.paste(combined, offset, combined)

    return cv2.cvtColor(np.array(final_img), cv2.COLOR_RGBA2BGR)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Render one known PID (0x12345678) to disk for a visual check.
    test_pid = 0x12345678
    cv2.imwrite("sample_high_fidelity_v3.png", generate_high_fidelity_spinda(test_pid))
    print(f"Corrected High-fidelity sample saved to sample_high_fidelity_v3.png for PID: {hex(test_pid)}")
|
||||||
21
src/data/visualize_sample.py
Normal file
21
src/data/visualize_sample.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import cv2
|
||||||
|
from src.data.generator import generate_spinda_face
|
||||||
|
|
||||||
|
def save_sample_image():
    """Render the procedural face for a demo PID and write it to sample_spinda.png."""
    # 0x12345678 gives four clearly different spots (low nibble = X, high = Y):
    #   Byte 0: 0x78 (X=8, Y=7)
    #   Byte 1: 0x56 (X=6, Y=5)
    #   Byte 2: 0x34 (X=4, Y=3)
    #   Byte 3: 0x12 (X=2, Y=1)
    test_pid = 0x12345678

    face = generate_spinda_face(test_pid)
    cv2.imwrite("sample_spinda.png", face)
    print(f"Sample Spinda image saved to sample_spinda.png for PID: {hex(test_pid)}")


if __name__ == "__main__":
    save_sample_image()
|
||||||
58
src/models/inference.py
Normal file
58
src/models/inference.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms.v2 as T
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from src.models.regression_model import SpindaRegressionModel
|
||||||
|
|
||||||
|
|
||||||
|
class SpindaInference:
    """Loads the trained model and predicts spot coordinates from an image crop."""

    def __init__(
        self,
        model_path: str = "models/best_spinda_model.pth",
        backbone: str = "resnet18",
    ):
        """
        Args:
            model_path: Path to a state-dict checkpoint of SpindaRegressionModel.
            backbone: Backbone the checkpoint was trained with; must match,
                otherwise ``load_state_dict`` fails on mismatched keys/shapes.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SpindaRegressionModel(pretrained=False, backbone=backbone)
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device)
        )
        self.model.to(self.device).eval()

        # Resize to the 128x128 size used for training crops, then scale and
        # normalise with the mean/std used by the training pipeline.
        self.transform = T.Compose([
            T.Resize((128, 128)),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def predict(self, image_path: str) -> Tuple[List[int], str]:
        """Predict the 8 grid coordinates and return them with a fingerprint string.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            grid_coords: list of 8 integers in [0, 15]
            fingerprint: the coordinates as zero-padded two-digit numbers
                joined by "-", i.e. "X1-Y1-X2-Y2-X3-Y3-X4-Y4"

        Raises:
            FileNotFoundError: If the image is missing or cannot be decoded.
        """
        img_bgr = cv2.imread(image_path)
        # cv2.imread returns None (no exception) for missing/unreadable files.
        if img_bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(img_tensor)  # (1, 8, 16)

        grid_coords = logits.argmax(dim=2).squeeze(0).cpu().tolist()  # [8]
        fingerprint = "-".join(f"{c:02d}" for c in grid_coords)

        return grid_coords, fingerprint
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import sys

    # Expect exactly one positional argument: the image to classify.
    if len(sys.argv) >= 2:
        predictor = SpindaInference()
        coords, fingerprint = predictor.predict(sys.argv[1])
        print(f"Predicted Grid Coordinates: {coords}")
        print(f"Visual Fingerprint: {fingerprint}")
    else:
        print("Usage: python src/models/inference.py <image_path>")
|
||||||
47
src/models/regression_model.py
Normal file
47
src/models/regression_model.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import models
|
||||||
|
from torchvision.models import ResNet18_Weights, ResNet34_Weights
|
||||||
|
|
||||||
|
# Supported backbones: name -> (torchvision model factory, default pretrained weights).
_BACKBONES = {
    "resnet18": (models.resnet18, ResNet18_Weights.DEFAULT),
    "resnet34": (models.resnet34, ResNet34_Weights.DEFAULT),
}
|
||||||
|
|
||||||
|
|
||||||
|
class SpindaRegressionModel(nn.Module):
    """ResNet backbone with 8 independent 16-class coordinate heads.

    Each of the 8 output coordinates (4 spots × x, y) is treated as a
    16-class classification problem over the [0, 15] nibble grid.
    This eliminates the float→integer rounding step and lets CrossEntropy
    directly optimise for exact coordinate prediction.

    Output shape: (B, 8, 16) — unnormalised logits.
    Prediction:   output.argmax(dim=2) → (B, 8) integer coordinates.
    """

    def __init__(self, pretrained: bool = True, backbone: str = "resnet18"):
        """
        Args:
            pretrained: Load the backbone's default torchvision weights.
            backbone: Key into ``_BACKBONES``.

        Raises:
            ValueError: If ``backbone`` is not a known key.
        """
        super().__init__()
        if backbone not in _BACKBONES:
            raise ValueError(f"backbone must be one of {list(_BACKBONES)}; got {backbone!r}")
        factory, default_weights = _BACKBONES[backbone]
        weights = default_weights if pretrained else None
        net = factory(weights=weights)

        # Read the feature width from the backbone's own FC layer instead of
        # hard-coding 512, so a wider backbone added to _BACKBONES still works.
        feature_dim = net.fc.in_features

        # Strip the final FC; keep the feature extractor + average pool.
        self.features = nn.Sequential(*list(net.children())[:-1])
        # 8 coordinates × 16 classes each, produced by one linear layer.
        self.classifier = nn.Linear(feature_dim, 8 * 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)       # (B, feature_dim, 1, 1)
        x = x.flatten(1)           # (B, feature_dim)
        x = self.classifier(x)     # (B, 128)
        return x.view(-1, 8, 16)   # (B, 8, 16)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Shape check: run a random batch of two images through each backbone.
    for name in ("resnet18", "resnet34"):
        model = SpindaRegressionModel(pretrained=False, backbone=name)
        dummy_batch = torch.randn(2, 3, 128, 128)
        out = model(dummy_batch)
        print(f"{name}: output {out.shape}, predictions {out.argmax(dim=2)}")
|
||||||
210
src/models/train.py
Normal file
210
src/models/train.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
import torchvision.transforms.v2 as T
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.data.dataset import SpindaDataset, get_default_augmentations
|
||||||
|
from src.models.regression_model import SpindaRegressionModel
|
||||||
|
|
||||||
|
|
||||||
|
class SpindaEvalDataset(Dataset):
    """Loads a fixed evaluation set (clean val or augmented test).

    Images are stored post-augmentation, pre-normalisation; this class
    applies only the normalisation step so the on-disk images are stable.
    """

    def __init__(self, data_dir: str):
        """
        Args:
            data_dir: Directory containing ``metadata.json``, a list of
                entries with at least ``img_path`` and ``target`` keys.
        """
        with open(os.path.join(data_dir, "metadata.json"), encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.transform = T.Compose([
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the normalised image tensor and its long label tensor.

        Raises:
            FileNotFoundError: If the image listed in the metadata cannot be
                read (paths are used exactly as stored).
        """
        item = self.metadata[idx]
        img_bgr = cv2.imread(item["img_path"])
        # cv2.imread returns None instead of raising; fail with a clear
        # message rather than a TypeError on the slice below.
        if img_bgr is None:
            raise FileNotFoundError(f"Could not read image: {item['img_path']}")
        img_rgb = img_bgr[:, :, ::-1].copy()  # BGR -> RGB, contiguous copy
        img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1)
        img_tensor = self.transform(img_tensor)
        target = torch.tensor(item["target"], dtype=torch.long)
        return img_tensor, target
|
||||||
|
|
||||||
|
|
||||||
|
# BL_x (index 4) and BL_y (index 5) are the weakest coordinates; upweight them
# so the optimiser focuses more of its gradient on the hardest spot.
# One weight per coordinate, in label order
# (TL_x, TL_y, TR_x, TR_y, BL_x, BL_y, BR_x, BR_y).
_COORD_WEIGHTS = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.5, 2.5, 1.0, 1.0])
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_loss(
    logits: torch.Tensor, targets: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    """Cross-entropy over all coordinates, scaled per coordinate, then averaged.

    Args:
        logits: (B, 8, 16) unnormalised class scores.
        targets: (B, 8) integer class indices.
        weights: Per-coordinate multipliers broadcastable against (B, 8).

    Returns:
        Scalar tensor: the mean of the weighted per-element losses (not
        divided by the sum of the weights).
    """
    batch = logits.size(0)
    flat_ce = F.cross_entropy(
        logits.view(-1, 16),
        targets.view(-1),
        reduction="none",
    )
    weighted = flat_ce.view(batch, 8) * weights
    return weighted.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def _worker_init_fn(worker_id: int) -> None:
    """Seed Python's and NumPy's RNGs per DataLoader worker so they generate different PIDs.

    The seed is derived from ``torch.initial_seed()`` plus ``worker_id`` and
    reduced modulo 2**32 *after* the addition: ``np.random.seed`` rejects
    values of 2**32 or more, which the previous "modulo first, then add"
    order could produce.
    """
    seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
    epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-4,
    resume: bool = False,
    model_path: str = "models/best_spinda_model.pth",
    num_workers: int = 4,
    epoch_size: int = 200000,
    backbone: str = "resnet18",
    save_path: str = "",
) -> None:
    """Train the coordinate classifier and checkpoint the best model.

    Trains on freshly generated synthetic samples, evaluates on the fixed
    clean set in ``data/val`` (and on ``data/aug_test`` when that directory
    exists), halves the learning rate when validation loss plateaus, saves
    whenever the clean exact-match rate improves, and stops after 10 epochs
    without improvement.

    Args:
        epochs: Maximum number of epochs.
        batch_size: Batch size for all loaders.
        lr: Initial Adam learning rate.
        resume: Start from the weights in ``model_path`` instead of the
            pretrained backbone.
        model_path: Checkpoint to load when ``resume`` is set.
        num_workers: Worker processes for the training loader (0 = main process).
        epoch_size: Number of synthetic samples per epoch.
        backbone: Backbone name passed to ``SpindaRegressionModel``.
        save_path: Checkpoint destination; defaults to
            ``models/best_<backbone>_model.pth`` when empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SpindaDataset(size=epoch_size, transform=get_default_augmentations())
    val_ds = SpindaEvalDataset("data/val")
    aug_test_ds = SpindaEvalDataset("data/aug_test") if os.path.exists("data/aug_test") else None

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, worker_init_fn=_worker_init_fn if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
        # Pinned memory only helps host-to-GPU copies; skip it on CPU-only runs.
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    aug_loader = DataLoader(aug_test_ds, batch_size=batch_size, shuffle=False, num_workers=0) if aug_test_ds else None

    checkpoint_path = save_path or f"models/best_{backbone}_model.pth"
    # Empty when the checkpoint goes to the current directory (e.g. "model.pth").
    checkpoint_dir = os.path.dirname(checkpoint_path)

    if resume:
        model = SpindaRegressionModel(pretrained=False, backbone=backbone).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Resumed from {model_path}")
    else:
        model = SpindaRegressionModel(pretrained=True, backbone=backbone).to(device)
    coord_weights = _COORD_WEIGHTS.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_exact_rate = 0.0
    epochs_without_improvement = 0
    early_stop_patience = 10

    for epoch in range(epochs):
        # ── Training ──────────────────────────────────────────────────
        model.train()
        train_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]")
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)           # (B, 8) long, values 0-15

            optimizer.zero_grad()
            logits = model(images)                 # (B, 8, 16)
            loss = _weighted_loss(logits, targets, coord_weights)
            loss.backward()
            optimizer.step()

            # Accumulate a sample-weighted sum so the epoch mean is exact
            # even when the last batch is smaller.
            train_loss += loss.item() * images.size(0)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_loss /= len(train_loader.dataset)

        # ── Validation ────────────────────────────────────────────────
        model.eval()
        val_loss = 0.0
        exact_matches = 0

        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{epochs} [Val]"):
                images = images.to(device)
                targets = targets.to(device)

                logits = model(images)             # (B, 8, 16)
                # Validation uses plain (unweighted) cross-entropy.
                loss = F.cross_entropy(logits.view(-1, 16), targets.view(-1))
                val_loss += loss.item() * images.size(0)

                preds = logits.argmax(dim=2)       # (B, 8)
                # A sample counts only if all 8 coordinates are right.
                exact_matches += torch.all(preds == targets, dim=1).sum().item()

        val_loss /= len(val_loader.dataset)
        exact_rate = exact_matches / len(val_loader.dataset)

        # ── Augmented test set ────────────────────────────────────────
        aug_exact_rate = 0.0
        if aug_loader is not None:
            aug_exact = 0
            with torch.no_grad():
                for images, targets in aug_loader:
                    images, targets = images.to(device), targets.to(device)
                    logits = model(images)
                    preds = logits.argmax(dim=2)
                    aug_exact += torch.all(preds == targets, dim=1).sum().item()
            aug_exact_rate = aug_exact / len(aug_test_ds)

        aug_str = f"  Aug Test: {aug_exact_rate:.2%}" if aug_loader is not None else ""
        print(
            f"Epoch {epoch + 1}: "
            f"Train Loss: {train_loss:.4f}  "
            f"Val Loss: {val_loss:.4f}  "
            f"Clean Val: {exact_rate:.2%}"
            f"{aug_str}"
        )

        scheduler.step(val_loss)

        # Save on exact-match improvement (the metric that actually matters)
        if exact_rate > best_exact_rate:
            best_exact_rate = exact_rate
            epochs_without_improvement = 0
            # os.makedirs("") raises, so only create a directory when the
            # checkpoint path actually has a parent component.
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  → Saved best model to {checkpoint_path} (clean val exact match: {best_exact_rate:.2%})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print(f"  → No improvement for {early_stop_patience} epochs. Stopping early.")
                break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse

    # Each option name matches a train_model() keyword argument, so the
    # parsed namespace can be forwarded as-is.
    cli = argparse.ArgumentParser()
    cli.add_argument("--epochs", type=int, default=50)
    cli.add_argument("--batch_size", type=int, default=64)
    cli.add_argument("--lr", type=float, default=1e-4)
    cli.add_argument("--resume", action="store_true", help="Fine-tune from --model_path checkpoint")
    cli.add_argument("--model_path", type=str, default="models/best_spinda_model.pth")
    cli.add_argument("--num_workers", type=int, default=4, help="DataLoader worker count (0 = main process only)")
    cli.add_argument("--epoch_size", type=int, default=200000, help="Virtual dataset size per epoch")
    cli.add_argument("--backbone", type=str, default="resnet18", choices=["resnet18", "resnet34"])
    cli.add_argument("--save_path", type=str, default="", help="Override checkpoint save path")

    train_model(**vars(cli.parse_args()))
|
||||||
61
src/registry/database.py
Normal file
61
src/registry/database.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import sqlite3
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
class SpindaRegistry:
    """Handles the storage and lookup of Spinda PIDs based on their visual fingerprints.

    Entries live in a single SQLite table, ``registry``, that maps a
    fingerprint string ("X1-Y1-X2-Y2-X3-Y3-X4-Y4") to a PID hex string.
    One fingerprint may map to several PIDs; each (fingerprint, pid_hex)
    pair is stored at most once.

    Every method opens its own short-lived connection and closes it before
    returning, so an instance holds no open database handle between calls.
    """

    def __init__(self, db_path: str = "data/spinda_registry.db"):
        """Create the database file (and its parent directory) if needed.

        Args:
            db_path: Location of the SQLite database file. A bare filename
                (no directory component) is placed in the current directory.
        """
        self.db_path = db_path
        # os.path.dirname() returns "" for a bare filename, and
        # os.makedirs("") raises FileNotFoundError — only create a directory
        # when the path actually names one.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Initializes the database with a core table for PIDs and Fingerprints."""
        conn = sqlite3.connect(self.db_path)
        try:
            # `with conn` commits on success / rolls back on error; it does
            # NOT close the connection, hence the surrounding try/finally.
            with conn:
                # Primary table: Mapping the 8-integer fingerprint to PIDs
                # Fingerprint format: "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS registry (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        fingerprint TEXT NOT NULL,
                        pid_hex TEXT NOT NULL,
                        UNIQUE(fingerprint, pid_hex)
                    )
                ''')
                # Index so lookups by fingerprint avoid a full table scan.
                conn.execute(
                    'CREATE INDEX IF NOT EXISTS idx_fingerprint ON registry(fingerprint)'
                )
        finally:
            conn.close()

    def add_entry(self, fingerprint: str, pid_hex: str):
        """Adds a new Spinda entry to the registry.

        Inserting a (fingerprint, pid_hex) pair that is already stored is a
        silent no-op: the table's UNIQUE constraint rejects it and the
        resulting IntegrityError is deliberately ignored.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                conn.execute(
                    "INSERT INTO registry (fingerprint, pid_hex) VALUES (?, ?)",
                    (fingerprint, pid_hex),
                )
        except sqlite3.IntegrityError:
            # Entry already exists
            pass
        finally:
            conn.close()

    def lookup_by_fingerprint(self, fingerprint: str) -> List[str]:
        """Returns all PIDs associated with a specific visual fingerprint.

        Returns an empty list when the fingerprint is unknown. No ordering
        of the returned PIDs is requested from the database.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(
                "SELECT pid_hex FROM registry WHERE fingerprint = ?",
                (fingerprint,),
            ).fetchall()
        finally:
            conn.close()
        return [row[0] for row in rows]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: store one fingerprint/PID pair in a scratch database and
    # read it back by fingerprint.
    test_fp, sample_pid = "00-01-02-03-04-05-06-07", "ABCDE123"
    registry = SpindaRegistry("data/test_registry.db")

    registry.add_entry(test_fp, sample_pid)
    matches = registry.lookup_by_fingerprint(test_fp)
    print(f"Looked up {test_fp}, found PIDs: {matches}")
|
||||||
15
src/scrgr.egg-info/PKG-INFO
Normal file
15
src/scrgr.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
Metadata-Version: 2.4
|
||||||
|
Name: scrgr
|
||||||
|
Version: 0.1.0
|
||||||
|
Summary: Spinda Coordinate Regression & Global Registry
|
||||||
|
Requires-Python: >=3.10
|
||||||
|
Requires-Dist: torch>=2.0.0
|
||||||
|
Requires-Dist: torchvision>=0.15.0
|
||||||
|
Requires-Dist: opencv-python>=4.7.0
|
||||||
|
Requires-Dist: Pillow>=9.5.0
|
||||||
|
Requires-Dist: numpy>=1.24.0
|
||||||
|
Provides-Extra: dev
|
||||||
|
Requires-Dist: pytest>=7.3.0; extra == "dev"
|
||||||
|
Requires-Dist: ruff>=0.0.270; extra == "dev"
|
||||||
|
Requires-Dist: mypy>=1.3.0; extra == "dev"
|
||||||
|
Requires-Dist: tqdm>=4.65.0; extra == "dev"
|
||||||
12
src/scrgr.egg-info/SOURCES.txt
Normal file
12
src/scrgr.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
pyproject.toml
|
||||||
|
src/data/dataset.py
|
||||||
|
src/data/generate_val_set.py
|
||||||
|
src/data/generator.py
|
||||||
|
src/data/high_fidelity_generator.py
|
||||||
|
src/data/visualize_sample.py
|
||||||
|
src/models/regression_model.py
|
||||||
|
src/scrgr.egg-info/PKG-INFO
|
||||||
|
src/scrgr.egg-info/SOURCES.txt
|
||||||
|
src/scrgr.egg-info/dependency_links.txt
|
||||||
|
src/scrgr.egg-info/requires.txt
|
||||||
|
src/scrgr.egg-info/top_level.txt
|
||||||
1
src/scrgr.egg-info/dependency_links.txt
Normal file
1
src/scrgr.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
11
src/scrgr.egg-info/requires.txt
Normal file
11
src/scrgr.egg-info/requires.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
|
opencv-python>=4.7.0
|
||||||
|
Pillow>=9.5.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
|
||||||
|
[dev]
|
||||||
|
pytest>=7.3.0
|
||||||
|
ruff>=0.0.270
|
||||||
|
mypy>=1.3.0
|
||||||
|
tqdm>=4.65.0
|
||||||
4
src/scrgr.egg-info/top_level.txt
Normal file
4
src/scrgr.egg-info/top_level.txt
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
data
|
||||||
|
models
|
||||||
|
registry
|
||||||
|
utils
|
||||||
284
src/utils/detector.py
Normal file
284
src/utils/detector.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class SpindaDetector:
    """Detects and crops a Spinda face from a larger image.

    Two-tier strategy:

    Tier 1 — screenshot / clean image: find individual red spot candidates,
    cluster them, and derive the crop from known spot geometry.

    Tier 2 — real photos where spots merge: find the Spinda body blob and
    crop the face region from its top portion, using the body width as a
    scale reference.

    Both tiers produce a 128×128 BGR image calibrated to the training-data
    layout (Spinda_Base_Top.png, 52×43, centred in a 128×128 canvas).
    """

    # ── Tier 1 constants ────────────────────────────────────────────────
    # spot cluster span ≈ 24.5 px in the sprite → CROP_RATIO = 128/24.5 ≈ 5.22
    # Using 5.5 gives a 5 % margin over the exact ratio.
    _SPOT_CROP_RATIO: float = 5.5
    # The spot centroid sits at 44.4 % from the top of the 128×128 training crop.
    # Shifting the crop centre down by 0.5 − 0.444 = 0.056 of the crop side
    # places the centroid there.
    _SPOT_CENTER_OFFSET: float = 0.056

    # ── Tier 2 constants ────────────────────────────────────────────────
    # In the sprite: face (Spinda_Base_Top) = 52×43; full body = 52×58.
    # blob_w ≈ 52 sprite px → scale = blob_w / 52.
    # crop_side = 128 * scale  (same canvas ratio as training).
    _FACE_W_PX: float = 52.0
    _FACE_H_PX: float = 43.0
    _CANVAS_PX: float = 128.0
    # Blob scoring: reward area (log-scale) and penalise non-square aspect.
    # Spinda body aspect (h/w) ≈ 58/52 = 1.12.
    _BLOB_IDEAL_ASPECT: float = 1.12

    # ── Shared constants ────────────────────────────────────────────────
    # (width, height) of every crop returned by this class.
    _OUTPUT_SIZE: tuple[int, int] = (128, 128)
    # A crop side is never allowed below this fraction of the image's
    # shorter edge, whatever the detected spot/blob scale says.
    _MIN_CROP_FRACTION: float = 0.12

    def detect_and_crop(self, image_path: str) -> Optional[np.ndarray]:
        """Locate the Spinda face and return a 128×128 BGR crop, or None.

        Tier 1 (spot cluster) is tried first; Tier 2 (body blob) is the
        fallback. Progress and failure messages are printed to stdout.

        Args:
            image_path: Path of the image file, read with ``cv2.imread``.

        Returns:
            The resized crop, or None when the file cannot be read or
            neither tier produces a crop.
        """
        img = cv2.imread(image_path)
        if img is None:
            print(f"Error: could not read {image_path}")
            return None

        # Tier 1: individual spot candidates → tight cluster
        candidates = self._find_spot_candidates(img)
        if len(candidates) >= 2:
            cluster = self._find_tightest_cluster(candidates, img.shape)
            if cluster is not None:
                result = self._crop_from_cluster(img, cluster)
                if result is not None:
                    return result

        print("Tier 1 (spot cluster) failed; trying Tier 2 (body blob).")

        # Tier 2: full Spinda body blob → face crop
        blob_rect = self._find_spinda_blob(img)
        if blob_rect is None:
            print("Could not detect Spinda.")
            return None
        return self._crop_from_blob(img, blob_rect)

    # ──────────────────────────────────────────────────────────────────
    # Shared helpers
    # ──────────────────────────────────────────────────────────────────

    @staticmethod
    def _red_contours(
        img: np.ndarray, low_hue_max: int, high_hue_min: int, min_sat_val: int
    ):
        """Return the external contours of the red regions of a BGR image.

        Red wraps around the hue axis, so two HSV bands are thresholded —
        hue in [0, low_hue_max] and hue in [high_hue_min, 180] — both with
        saturation and value in [min_sat_val, 255]. The combined mask is
        cleaned with a 3×3 elliptical morphological opening before the
        contours are extracted.
        """
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        lo = cv2.inRange(
            hsv,
            np.array([0, min_sat_val, min_sat_val]),
            np.array([low_hue_max, 255, 255]),
        )
        hi = cv2.inRange(
            hsv,
            np.array([high_hue_min, min_sat_val, min_sat_val]),
            np.array([180, 255, 255]),
        )
        mask = cv2.bitwise_or(lo, hi)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        return contours

    @staticmethod
    def _bbox_span(points: np.ndarray) -> float:
        """Return the longer side of the axis-aligned bounding box of (N, 2) xy points."""
        return float(max(
            points[:, 0].max() - points[:, 0].min(),
            points[:, 1].max() - points[:, 1].min(),
        ))

    def _crop_and_resize(
        self, img: np.ndarray, center_x: int, center_y: int, side: int
    ) -> Optional[np.ndarray]:
        """Cut a ``side``-sized window centred on the given point and resize it.

        The window's top-left corner is clamped to the image origin and its
        bottom-right corner to the image size.

        NOTE(review): when the window is clamped at an image border the crop
        is no longer square, and the resize to 128×128 then stretches it.

        Returns:
            The resized crop, or None if the clamped window is empty.
        """
        h, w = img.shape[:2]
        x1 = max(0, center_x - side // 2)
        y1 = max(0, center_y - side // 2)
        x2 = min(w, x1 + side)
        y2 = min(h, y1 + side)

        crop = img[y1:y2, x1:x2]
        return None if crop.size == 0 else cv2.resize(crop, self._OUTPUT_SIZE)

    # ──────────────────────────────────────────────────────────────────
    # Tier 1 helpers
    # ──────────────────────────────────────────────────────────────────

    def _find_spot_candidates(
        self, img: np.ndarray
    ) -> list[tuple[int, int, float]]:
        """Return (cx, cy, area) for each red blob that could be a Spinda spot.

        Filters by area (spots are small, not tiny specks) and circularity
        (spots are roughly circular; elongated trainer-outfit blobs are excluded).
        """
        h, w = img.shape[:2]
        img_area = float(h * w)

        contours = self._red_contours(
            img, low_hue_max=12, high_hue_min=155, min_sat_val=60
        )

        candidates: list[tuple[int, int, float]] = []
        for cnt in contours:
            area = cv2.contourArea(cnt)
            # Keep blobs between 0.005 % and 1 % of the image area.
            if not (img_area * 0.00005 < area < img_area * 0.01):
                continue
            perimeter = cv2.arcLength(cnt, True)
            if perimeter == 0:
                continue
            # Circularity 4πA/P² is 1.0 for a perfect circle; reject < 0.25.
            if 4 * np.pi * area / (perimeter ** 2) < 0.25:
                continue
            M = cv2.moments(cnt)
            if M["m00"] == 0:
                continue
            # Centroid from the first-order image moments.
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])
            candidates.append((cx, cy, area))

        return candidates

    def _find_tightest_cluster(
        self,
        candidates: list[tuple[int, int, float]],
        img_shape: tuple[int, ...],
    ) -> Optional[np.ndarray]:
        """Return the xy-array of the tightest cluster of up to 4 candidates.

        For each candidate as seed, collects its k-1 nearest neighbours and
        measures the bounding-box span.  The seed whose group has the smallest
        span within plausible face-size bounds wins.

        Args:
            candidates: (cx, cy, area) triples from ``_find_spot_candidates``.
            img_shape: Image shape; only the first two entries (h, w) are used.

        Returns:
            A float32 array of shape (k, 2) with k = min(4, number of points),
            or None when no group's span lies strictly between 1 % and 45 %
            of the image diagonal.
        """
        h, w = img_shape[:2]
        img_diag = float(np.sqrt(h ** 2 + w ** 2))

        # Drop candidates whose area deviates wildly from the median —
        # the four spots are similar in size; UI elements are not.
        areas = np.array([c[2] for c in candidates])
        med = float(np.median(areas))
        filtered = [c for c in candidates if med / 6.0 < c[2] < med * 6.0]
        if len(filtered) < 2:
            # The filter left too little to cluster; fall back to everything.
            filtered = candidates

        pts = np.array([(c[0], c[1]) for c in filtered], dtype=np.float32)
        n = len(pts)
        k = min(4, n)

        best_group: Optional[np.ndarray] = None
        best_span = float("inf")

        for i in range(n):
            dists = np.linalg.norm(pts - pts[i], axis=1)
            # The seed itself has distance 0, so it is part of its own group.
            idx = np.argsort(dists)[:k]
            group = pts[idx]
            span = self._bbox_span(group)
            # Too small → degenerate noise cluster; too large → not a face.
            if not (img_diag * 0.01 < span < img_diag * 0.45):
                continue
            if span < best_span:
                best_span = span
                best_group = group

        return best_group

    def _crop_from_cluster(
        self, img: np.ndarray, cluster: np.ndarray
    ) -> Optional[np.ndarray]:
        """Crop based on spot-cluster geometry and resize to 128×128.

        The crop side is the cluster's bounding-box span scaled by
        ``_SPOT_CROP_RATIO``, but never less than 12 % of the image's
        shorter edge.
        """
        h, w = img.shape[:2]

        cx = int(cluster[:, 0].mean())
        cy = int(cluster[:, 1].mean())
        cluster_span = self._bbox_span(cluster)

        face_side = max(
            int(cluster_span * self._SPOT_CROP_RATIO),
            int(min(h, w) * self._MIN_CROP_FRACTION),
        )

        # Shift crop centre down so the spot centroid lands at ~44 % from the top,
        # matching its position in the training images.
        cy_adj = cy + int(face_side * self._SPOT_CENTER_OFFSET)
        return self._crop_and_resize(img, cx, cy_adj, face_side)

    # ──────────────────────────────────────────────────────────────────
    # Tier 2 helpers
    # ──────────────────────────────────────────────────────────────────

    def _find_spinda_blob(
        self, img: np.ndarray
    ) -> Optional[tuple[int, int, int, int]]:
        """Return (x, y, w, h) of the most Spinda-like red blob.

        Score = circularity + 0.2 × log(area / min_area) − 0.1 × |aspect − 1.12|

        This rewards large, roughly circular blobs with a body-like aspect ratio,
        and discards thin stripes (UI chrome) or very small stray blobs.

        Returns None when no contour passes the area and aspect filters.
        """
        h, w = img.shape[:2]
        img_area = float(h * w)
        min_area = img_area * 0.001

        # Use slightly looser HSV bounds to catch faded reds in real photos.
        contours = self._red_contours(
            img, low_hue_max=15, high_hue_min=150, min_sat_val=50
        )

        best_rect: Optional[tuple[int, int, int, int]] = None
        best_score = float("-inf")

        for cnt in contours:
            area = cv2.contourArea(cnt)
            # Keep blobs between 0.1 % and 10 % of the image area.
            if not (min_area < area < img_area * 0.10):
                continue
            rx, ry, rw, rh = cv2.boundingRect(cnt)
            if rw == 0 or rh == 0:
                continue
            aspect = rh / rw
            if not (0.4 < aspect < 2.5):
                continue
            perimeter = cv2.arcLength(cnt, True)
            circ = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0
            score = (
                circ
                + 0.2 * math.log(area / min_area)
                - 0.1 * abs(aspect - self._BLOB_IDEAL_ASPECT)
            )
            if score > best_score:
                best_score = score
                best_rect = (rx, ry, rw, rh)

        return best_rect

    def _crop_from_blob(
        self, img: np.ndarray, blob_rect: tuple[int, int, int, int]
    ) -> Optional[np.ndarray]:
        """Crop the face region from the body blob and resize to 128×128.

        The blob width corresponds to 52 sprite pixels (the face/body width).
        The face occupies the top 43/58 of the body height at that scale.
        The crop side = 128 × scale so the face fills the same fraction of the
        output as in the training images.

        Only the blob's x, y and width are used; its height is ignored.
        """
        h, w = img.shape[:2]
        bx, by, bw, _ = blob_rect

        scale = bw / self._FACE_W_PX
        face_h = int(self._FACE_H_PX * scale)
        crop_side = max(
            int(self._CANVAS_PX * scale),
            int(min(h, w) * self._MIN_CROP_FRACTION),
        )

        face_cx = bx + bw // 2
        face_cy = by + face_h // 2  # face centre = top of blob + half face height
        return self._crop_and_resize(img, face_cx, face_cy, crop_side)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Debug entry point: crop the Spinda face out of the image named on the
    # command line and write the result to debug_crop.png.
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python src/utils/detector.py <image_path>")
    else:
        face = SpindaDetector().detect_and_crop(cli_args[0])
        if face is None:
            print("Failed to detect Spinda.")
        else:
            cv2.imwrite("debug_crop.png", face)
            print("Saved debug_crop.png")
|
||||||
61
src/utils/resolver.py
Normal file
61
src/utils/resolver.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
class SpindaResolver:
    """Mathematically resolves a visual fingerprint back to its possible PIDs."""

    # A fingerprint is four spots × (x, y); each coordinate is one nibble.
    _NUM_COORDS = 8
    _COORD_MAX = 0xF

    @staticmethod
    def coordinates_to_pid(coords: List[int], mode: str = "standard") -> str:
        """
        Converts 8 grid coordinates (0-15) back to a 32-bit hex PID.

        Standard Mapping (Rex/Little-Endian):
        Byte 0 (TL): x=coords[0], y=coords[1]
        Byte 1 (TR): x=coords[2], y=coords[3]
        Byte 2 (BL): x=coords[4], y=coords[5]
        Byte 3 (BR): x=coords[6], y=coords[7]

        Args:
            coords: Exactly eight integers in the range 0-15, ordered
                x1, y1, x2, y2, x3, y3, x4, y4.
            mode: "bdsp" reverses the byte order before packing; any other
                value uses the standard mapping above.

        Returns:
            The PID as 8 lowercase hex digits, without a "0x" prefix.

        Raises:
            ValueError: If ``coords`` does not hold exactly eight values or
                any value lies outside 0-15. Without this check an
                out-of-range value would spill into the neighbouring nibble
                and yield a wrong PID without any error.
        """
        if len(coords) != SpindaResolver._NUM_COORDS:
            raise ValueError(
                f"expected {SpindaResolver._NUM_COORDS} coordinates, got {len(coords)}"
            )
        for position, value in enumerate(coords):
            if not 0 <= value <= SpindaResolver._COORD_MAX:
                raise ValueError(
                    f"coordinate {position} is {value}; expected 0-{SpindaResolver._COORD_MAX}"
                )

        # Each byte = (Y << 4) | X
        spot_bytes = [(y << 4) | x for x, y in zip(coords[0::2], coords[1::2])]

        if mode == "bdsp":
            # BDSP reads the bytes in reverse order (Big-Endian style)
            # So we reverse them back
            spot_bytes.reverse()

        # Byte i occupies bits 8*i .. 8*i+7, i.e. little-endian packing.
        pid = int.from_bytes(bytes(spot_bytes), "little")
        return f"{pid:08x}"

    @staticmethod
    def resolve_fingerprint(fingerprint: str) -> dict:
        """
        Takes a fingerprint string "X1-Y1-X2-Y2-X3-Y3-X4-Y4"
        and returns both possible PIDs (Standard and BDSP).

        Returns:
            A dict with the keys "standard" and "bdsp", each holding the PID
            produced by ``coordinates_to_pid`` in that mode.

        Raises:
            ValueError: If a field is not a decimal integer, or the parsed
                coordinates are rejected by ``coordinates_to_pid``.
        """
        coords = [int(field) for field in fingerprint.split("-")]

        return {
            "standard": SpindaResolver.coordinates_to_pid(coords, mode="standard"),
            "bdsp": SpindaResolver.coordinates_to_pid(coords, mode="bdsp"),
        }
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Worked example with a known PID, 0x12345678. Each byte packs one spot
    # as (Y << 4) | X, lowest byte first:
    #   Byte 0: 0x78 -> X=8, Y=7
    #   Byte 1: 0x56 -> X=6, Y=5
    #   Byte 2: 0x34 -> X=4, Y=3
    #   Byte 3: 0x12 -> X=2, Y=1
    # which gives the fingerprint 08-07-06-05-04-03-02-01.
    test_fp = "08-07-06-05-04-03-02-01"
    print(f"Fingerprint: {test_fp}")
    results = SpindaResolver.resolve_fingerprint(test_fp)
    print(f"Resolved PIDs: {results}")
|
||||||
Reference in New Issue
Block a user