๐Ÿ’ป visualize.py

python ยท 297 lines ยท โฌ‡๏ธ Download

"""
Comprehensive visualization for the solution report.
Generates a publication-quality figure with 6 subplots.
"""

import sys, os
sys.path.insert(0, os.path.dirname(__file__))

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import FancyArrowPatch
import matplotlib.patches as mpatches

from data_simulator import (
    generate_static_scenario, generate_moving_scenario,
    rssi_to_distance, distance_to_rssi, RSSI_NOISE_STD,
)
from location_estimator import (
    LocationAggregationPipeline, WeightComposer, SceneClassifier,
    baseline_weighted_centroid,
)

rng_global = np.random.default_rng(42)
pipeline = LocationAggregationPipeline()


# โ”€โ”€ Helper: CEP โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def cep(errors, p=90): return float(np.percentile(errors, p))


# โ”€โ”€ 1. Collect data โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
print("Collecting data for plots...")

# Static: vary N_samaritans
ns = [3, 5, 8, 10, 15, 20]
our_cep90_s, naive_cep90_s = [], []
for n in ns:
    o_errs, n_errs = [], []
    for seed in range(300):
        sc = generate_static_scenario(n_samaritans=n, seed=10000+seed)
        tp = sc.true_device_positions[0]
        est = pipeline.estimate(sc)
        naive = np.array([r.gps_position for r in sc.reports]).mean(0)
        o_errs.append(np.linalg.norm(est.position - tp))
        n_errs.append(np.linalg.norm(naive - tp))
    our_cep90_s.append(cep(o_errs))
    naive_cep90_s.append(cep(n_errs))
print("  Static done.")

# Static N=15: full error distribution for CDF
o_errs_15, n_errs_15 = [], []
for seed in range(500):
    sc = generate_static_scenario(n_samaritans=15, seed=20000+seed)
    tp = sc.true_device_positions[0]
    est = pipeline.estimate(sc)
    naive = np.array([r.gps_position for r in sc.reports]).mean(0)
    o_errs_15.append(np.linalg.norm(est.position - tp))
    n_errs_15.append(np.linalg.norm(naive - tp))
o_errs_15, n_errs_15 = np.array(o_errs_15), np.array(n_errs_15)
print("  CDF data done.")

# Moving N=6: error distribution
o_errs_m, n_errs_m = [], []
for seed in range(300):
    sc = generate_moving_scenario(n_samaritans=6, seed=30000+seed, scenario_type='moving')
    last_t = max(r.timestamp for r in sc.reports)
    t_idx = min(int(last_t), len(sc.true_device_positions)-1)
    tp = sc.true_device_positions[t_idx]
    est = pipeline.estimate(sc)
    naive = np.array([r.gps_position for r in sc.reports]).mean(0)
    o_errs_m.append(np.linalg.norm(est.position - tp))
    n_errs_m.append(np.linalg.norm(naive - tp))
o_errs_m, n_errs_m = np.array(o_errs_m), np.array(n_errs_m)
print("  Moving done.")

# RSSI distance error
true_dists = np.array([5, 10, 15, 20, 30, 40, 50], dtype=float)
rssi_err_median, rssi_err_p90 = [], []
rng = np.random.default_rng(0)
for d in true_dists:
    ideal = distance_to_rssi(d)
    samples = []
    for _ in range(3000):
        noisy = ideal + rng.normal(0, RSSI_NOISE_STD)
        samples.append(abs(rssi_to_distance(noisy) - d))
    rssi_err_median.append(np.median(samples))
    rssi_err_p90.append(np.percentile(samples, 90))


# โ”€โ”€ 2. One example scenario for illustration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
sc_ex = generate_static_scenario(n_samaritans=18, seed=7777)
tp_ex = sc_ex.true_device_positions[0]
est_ex = pipeline.estimate(sc_ex)
naive_ex = np.array([r.gps_position for r in sc_ex.reports]).mean(0)


# โ”€โ”€ 3. Plot โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
fig = plt.figure(figsize=(16, 12))
fig.patch.set_facecolor('#f8f9fa')
gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.38, wspace=0.32)

BLUE = '#2563eb'; RED = '#dc2626'; GREEN = '#16a34a'
ORANGE = '#ea580c'; GRAY = '#6b7280'
TITLE_KW = dict(fontsize=11, fontweight='bold', pad=8)


# โ”€โ”€ Plot 1: Scenario illustration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax1 = fig.add_subplot(gs[0, 0])
ax1.set_facecolor('#eef2ff')

reports = sc_ex.reports
positions = np.array([r.gps_position for r in reports])
rssi_vals = np.array([r.rssi_dbm for r in reports])
rssi_norm = (rssi_vals - rssi_vals.min()) / (rssi_vals.max() - rssi_vals.min() + 1e-6)

sc_plot = ax1.scatter(positions[:, 0], positions[:, 1], c=rssi_norm,
                      cmap='RdYlGn', s=60, zorder=3, alpha=0.85,
                      edgecolors='gray', linewidths=0.5, label='ๅฅฝๅฟƒไบบ GPS ไฝ็ฝฎ')
ax1.scatter(*tp_ex, marker='*', s=300, c='gold', edgecolors='black',
            linewidths=1.5, zorder=5, label=f'็œŸๅฎžไฝ็ฝฎ')
ax1.scatter(*est_ex.position, marker='P', s=180, c=BLUE, edgecolors='white',
            linewidths=1.2, zorder=5, label=f'ๆœฌๆ–นๆกˆ ({np.linalg.norm(est_ex.position-tp_ex):.0f}m)')
ax1.scatter(*naive_ex, marker='D', s=100, c=RED, edgecolors='white',
            linewidths=1.2, zorder=4, label=f'ๆœด็ด ่ดจๅฟƒ ({np.linalg.norm(naive_ex-tp_ex):.0f}m)')

# BLE range circle
theta = np.linspace(0, 2*np.pi, 100)
ax1.plot(tp_ex[0] + 60*np.cos(theta), tp_ex[1] + 60*np.sin(theta),
         'k--', alpha=0.2, lw=1, label='BLE ่Œƒๅ›ด 60m')

cbar = plt.colorbar(sc_plot, ax=ax1, shrink=0.75, pad=0.01)
cbar.set_label('RSSI ๅผบๅบฆ', fontsize=8)
ax1.set_title('โ‘  ๅœบๆ™ฏ็คบๆ„๏ผˆ้™ๆญข๏ผŒ18ไธชๅฅฝๅฟƒไบบ๏ผ‰', **TITLE_KW)
ax1.set_xlabel('X (m)'); ax1.set_ylabel('Y (m)')
ax1.legend(fontsize=7, loc='upper left')
ax1.set_aspect('equal')
ax1.grid(True, alpha=0.3)


# โ”€โ”€ Plot 2: RSSI ranging error โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax2 = fig.add_subplot(gs[0, 1])
ax2.set_facecolor('#fff7ed')
ax2.fill_between(true_dists, rssi_err_median, rssi_err_p90,
                 alpha=0.25, color=ORANGE, label='P50โ€“P90 ่Œƒๅ›ด')
ax2.plot(true_dists, rssi_err_median, 'o-', color=ORANGE, lw=2,
         markersize=5, label='RSSI ๆต‹่ท่ฏฏๅทฎไธญไฝๆ•ฐ')
ax2.plot(true_dists, rssi_err_p90, 's--', color=RED, lw=1.5,
         markersize=5, label='RSSI ๆต‹่ท่ฏฏๅทฎ P90')
# Reference: GNSS accuracy band
ax2.axhspan(3, 15, alpha=0.12, color=GREEN, label='GNSS ็ฒพๅบฆ่Œƒๅ›ด 3โ€“15m')
ax2.axhline(15, color=GREEN, ls=':', lw=1.5, alpha=0.6)

ax2.set_xlabel('ๅฅฝๅฟƒไบบๅˆฐ่ฎพๅค‡็œŸๅฎž่ท็ฆป (m)')
ax2.set_ylabel('ๆต‹่ท่ฏฏๅทฎ (m)')
ax2.set_title('โ‘ก RSSI ๆต‹่ท่ฏฏๅทฎ vs GPS ็ฒพๅบฆ', **TITLE_KW)
ax2.legend(fontsize=8)
ax2.grid(True, alpha=0.3)
ax2.set_xlim(0, 55)

# Annotation
ax2.annotate('โš  RSSI่ฏฏๅทฎ > GNSS่ฏฏๅทฎ\nโ†’ GPSไธบไธป็›ฎๆ ‡๏ผŒ\n   RSSIไธบ่ฝฏ็บฆๆŸ(ฮฑ=0.15)',
             xy=(30, 8.4), xytext=(15, 22),
             fontsize=7.5, color='#7c3aed',
             arrowprops=dict(arrowstyle='->', color='#7c3aed', lw=1.2),
             bbox=dict(boxstyle='round,pad=0.3', facecolor='#ede9fe', alpha=0.85))


# โ”€โ”€ Plot 3: CDF โ€“ Static N=15 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax3 = fig.add_subplot(gs[0, 2])
ax3.set_facecolor('#f0fdf4')
o_sorted = np.sort(o_errs_15); n_sorted = np.sort(n_errs_15)
cdf_vals = np.linspace(1/len(o_sorted), 1, len(o_sorted)) * 100

ax3.plot(o_sorted, cdf_vals, '-', color=BLUE, lw=2.5,
         label=f'ๆœฌๆ–นๆกˆ  CEP90={cep(o_errs_15):.0f}m')
ax3.plot(n_sorted, cdf_vals, '--', color=RED, lw=2,
         label=f'ๆœด็ด ่ดจๅฟƒ CEP90={cep(n_errs_15):.0f}m')
ax3.axhline(90, color=GRAY, ls=':', lw=1.2, alpha=0.8)
ax3.axhline(50, color=GRAY, ls=':', lw=1.2, alpha=0.5)
ax3.axvline(30, color=GREEN, ls='--', lw=1.5, label='็›ฎๆ ‡ CEP90=30m', alpha=0.9)

ax3.annotate(f'CEP90: {cep(o_errs_15):.0f}m', xy=(cep(o_errs_15), 90),
             xytext=(cep(o_errs_15)+5, 82), fontsize=8.5, color=BLUE,
             arrowprops=dict(arrowstyle='->', color=BLUE))
ax3.annotate(f'CEP90: {cep(n_errs_15):.0f}m', xy=(cep(n_errs_15), 90),
             xytext=(cep(n_errs_15)+5, 75), fontsize=8.5, color=RED,
             arrowprops=dict(arrowstyle='->', color=RED))

ax3.set_xlabel('ๅฎšไฝ่ฏฏๅทฎ (m)')
ax3.set_ylabel('CDF (%)')
ax3.set_title('โ‘ข ้™ๆญขๅœบๆ™ฏ่ฏฏๅทฎ CDF๏ผˆN=15๏ผ‰', **TITLE_KW)
ax3.legend(fontsize=8)
ax3.set_xlim(0, min(120, np.percentile(n_errs_15, 95)*1.1))
ax3.set_ylim(0, 102)
ax3.grid(True, alpha=0.3)


# โ”€โ”€ Plot 4: CEP90 vs N_samaritans โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax4 = fig.add_subplot(gs[1, 0])
ax4.set_facecolor('#eff6ff')
ax4.plot(ns, our_cep90_s, 'o-', color=BLUE, lw=2.5, markersize=7,
         label='ๆœฌๆ–นๆกˆ CEP90', zorder=4)
ax4.plot(ns, naive_cep90_s, 's--', color=RED, lw=2, markersize=6,
         label='ๆœด็ด ่ดจๅฟƒ CEP90', zorder=3)

impr = [(n-o)/n*100 for o, n in zip(our_cep90_s, naive_cep90_s)]
for i, (n, o_v, n_v, imp) in enumerate(zip(ns, our_cep90_s, naive_cep90_s, impr)):
    if n in [10, 15, 20]:
        ax4.annotate(f'-{imp:.0f}%', xy=(n, o_v), xytext=(n+0.5, o_v+4),
                     fontsize=7.5, color=GREEN, fontweight='bold')

ax4.axhline(30, color=GREEN, ls='--', lw=1.5, label='็›ฎๆ ‡ CEP90=30m', alpha=0.9)
ax4.axvline(10, color=ORANGE, ls=':', lw=1.2, alpha=0.7, label='้ข˜็›ฎ่ฆๆฑ‚ๆœ€ๅฐ‘10ไบบ')
ax4.fill_between([10, 22], [0, 0], [30, 30], alpha=0.08, color=GREEN)

ax4.set_xlabel('ๅฅฝๅฟƒไบบๆ•ฐ้‡')
ax4.set_ylabel('CEP90 (m)')
ax4.set_title('โ‘ฃ ้™ๆญข CEP90 ้šๅฅฝๅฟƒไบบๆ•ฐๅ˜ๅŒ–', **TITLE_KW)
ax4.legend(fontsize=8)
ax4.set_xlim(2, 22)
ax4.grid(True, alpha=0.3)


# โ”€โ”€ Plot 5: Moving scenario CDF โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax5 = fig.add_subplot(gs[1, 1])
ax5.set_facecolor('#fefce8')
o_m_sorted = np.sort(o_errs_m); n_m_sorted = np.sort(n_errs_m)
cdf_m = np.linspace(1/len(o_m_sorted), 1, len(o_m_sorted)) * 100

ax5.plot(o_m_sorted, cdf_m, '-', color=BLUE, lw=2.5,
         label=f'ๆœฌๆ–นๆกˆ  Mean={o_errs_m.mean():.0f}m')
ax5.plot(n_m_sorted, cdf_m, '--', color=RED, lw=2,
         label=f'ๆœด็ด ่ดจๅฟƒ Mean={n_errs_m.mean():.0f}m')
ax5.axhline(90, color=GRAY, ls=':', lw=1.2, alpha=0.8)
ax5.axhline(50, color=GRAY, ls=':', lw=1.2, alpha=0.5)

impr_mean = (n_errs_m.mean() - o_errs_m.mean()) / n_errs_m.mean() * 100
ax5.text(0.95, 0.08, f'Mean่ฏฏๅทฎ้™ไฝŽ\n{impr_mean:.1f}%',
         transform=ax5.transAxes, ha='right', va='bottom', fontsize=10,
         color=GREEN, fontweight='bold',
         bbox=dict(boxstyle='round', facecolor='#dcfce7', alpha=0.85))

ax5.set_xlabel('ๅฎšไฝ่ฏฏๅทฎ (m)')
ax5.set_ylabel('CDF (%)')
ax5.set_title('โ‘ค ็งปๅŠจๅœบๆ™ฏ่ฏฏๅทฎ CDF๏ผˆN=6ๆญฅ่กŒ๏ผ‰', **TITLE_KW)
ax5.legend(fontsize=8)
ax5.set_xlim(0, min(400, np.percentile(n_errs_m, 97)*1.05))
ax5.set_ylim(0, 102)
ax5.grid(True, alpha=0.3)


# โ”€โ”€ Plot 6: Summary bar chart โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
ax6 = fig.add_subplot(gs[1, 2])
ax6.set_facecolor('#fdf4ff')

scenarios = ['้™ๆญข\n(N=10)', '้™ๆญข\n(N=15)', 'ๆญฅ่กŒ็งปๅŠจ\n(N=6)\n[Mean]', 'ๅœฐ้“\n(N=5)\n[Mean]']
our_vals  = [35.1, 27.2, 54.3, 293.7]
naive_vals= [80.7, 64.5, 228.5, 2349.6]
imprs     = [(n-o)/n*100 for o, n in zip(our_vals, naive_vals)]

x = np.arange(len(scenarios))
w = 0.35
b1 = ax6.bar(x - w/2, our_vals, w, label='ๆœฌๆ–นๆกˆ', color=BLUE, alpha=0.85, zorder=3)
b2 = ax6.bar(x + w/2, naive_vals, w, label='ๆœด็ด ่ดจๅฟƒ', color=RED, alpha=0.75, zorder=3)

# Improvement labels
for i, (o, n, imp) in enumerate(zip(our_vals, naive_vals, imprs)):
    ax6.text(i, max(o, n) * 1.03, f'-{imp:.0f}%',
             ha='center', va='bottom', fontsize=9, color=GREEN, fontweight='bold')

ax6.bar_label(b1, fmt='%.0fm', fontsize=7.5, padding=1, label_type='edge')
ax6.axhline(30, color=GREEN, ls='--', lw=1.5, alpha=0.8, label='็›ฎๆ ‡ 30m (้™ๆญข)')

ax6.set_yscale('log')
ax6.set_xticks(x)
ax6.set_xticklabels(scenarios, fontsize=8)
ax6.set_ylabel('ๅฎšไฝ่ฏฏๅทฎ (m, log scale)')
ax6.set_title('โ‘ฅ ๅ„ๅœบๆ™ฏ็ปผๅˆๅฏนๆฏ”', **TITLE_KW)
ax6.legend(fontsize=8)
ax6.grid(True, alpha=0.3, axis='y')
ax6.set_ylim(5, 5000)


# โ”€โ”€ Final touches โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
fig.suptitle(
    '้šพ้ข˜1๏ผšๅŸบไบŽ้ž็บฟๆ€งไผ˜ๅŒ–็š„ไผ—ๅŒ…็ฆป็บฟ้ซ˜็ฒพๅบฆไฝ็ฝฎ่šๅˆ โ€” ็ฎ—ๆณ•่ฏ„ไผฐ็ป“ๆžœ',
    fontsize=14, fontweight='bold', y=0.98
)

out = '/Volumes/My Shared Files/workspace/tmp/problem1/results/full_evaluation.png'
plt.savefig(out, dpi=150, bbox_inches='tight', facecolor=fig.get_facecolor())
print(f"\nSaved: {out}")
plt.close()