Skip to content

Commit 03e8beb

Browse files
authored
Merge pull request #48 from ClarkCGA:feature/issue-36-knn-plots
Separate KNN ablation plots into individual per-class subplots; cast category to str (fixes #46)
2 parents 6ed6b42 + e5ce6b3 commit 03e8beb

3 files changed

Lines changed: 79 additions & 50 deletions

File tree

gelos/comp_plots.py

Lines changed: 63 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -122,76 +122,100 @@ def distance_matrix(
122122
def knn_purity_plot(
123123
metric_result: dict,
124124
output_path: str | Path = None,
125+
class_labels: dict[str, str] | None = None,
125126
**kwargs,
126127
) -> None:
127128
"""Render KNN class purity comparison as line plots.
128129
129-
Creates a figure with two subplots:
130-
- Top: overall purity vs k, one line per experiment
131-
- Bottom: per-class purity vs k, faceted by experiment
130+
Creates a figure with:
131+
- Top row (full width): overall purity vs k, one line per experiment.
132+
- Facet grid below: one subplot per class, each showing purity vs k with
133+
one line per experiment — makes cross-model comparison direct.
132134
133135
Args:
134136
metric_result: Output from ``knn_purity_comparison`` metric.
135137
output_path: Path to save the figure. Shows interactively if None.
138+
class_labels: Optional mapping from class id (as string) to display
139+
name. Used for facet subplot titles. Falls back to raw class id.
136140
"""
137141
df = metric_result.get("comparison_df", pd.DataFrame())
138142
if df.empty:
139143
logger.warning("no data for KNN purity plot, skipping")
140144
return
141145

146+
class_labels = class_labels or {}
147+
142148
overall = df[df["class"] == "overall"]
143149
per_class = df[df["class"] != "overall"]
144-
experiments = overall["experiment"].unique()
145-
146-
fig, (ax_top, ax_bot) = plt.subplots(2, 1, figsize=(10, 8))
150+
experiments = list(overall["experiment"].unique())
151+
classes = sorted(per_class["class"].unique())
152+
n_classes = len(classes)
147153

148-
# --- Top: overall purity ---
149154
markers = ["o", "s", "^", "D", "v", "P", "X", "*"]
155+
156+
# Layout: 1 row for overall, then a facet grid (up to 4 cols) for classes.
157+
n_cols = min(4, n_classes) if n_classes else 1
158+
n_facet_rows = (n_classes + n_cols - 1) // n_cols if n_classes else 0
159+
fig_height = 4 + 2.5 * n_facet_rows
160+
fig = plt.figure(figsize=(3.5 * n_cols, fig_height))
161+
gs = fig.add_gridspec(1 + n_facet_rows, n_cols, hspace=0.5, wspace=0.3)
162+
163+
# --- Top: overall purity (spans all columns) ---
164+
ax_top = fig.add_subplot(gs[0, :])
150165
for i, exp in enumerate(experiments):
151166
exp_data = overall[overall["experiment"] == exp].sort_values("k")
152-
marker = markers[i % len(markers)]
153-
ax_top.plot(exp_data["k"], exp_data["purity"], marker=marker, label=exp)
154-
167+
ax_top.plot(
168+
exp_data["k"],
169+
exp_data["purity"],
170+
marker=markers[i % len(markers)],
171+
label=exp,
172+
)
155173
ax_top.set_xlabel("k")
156174
ax_top.set_ylabel("Purity")
157175
ax_top.set_ylim(0, 1.05)
158176
ax_top.set_title("Overall KNN Class Purity by Experiment")
159177
ax_top.legend()
160178
ax_top.grid(True, alpha=0.3)
161179

162-
# --- Bottom: per-class purity ---
163-
classes = sorted(per_class["class"].unique())
164-
n_classes = len(classes)
165-
n_experiments = len(experiments)
166-
k_values = sorted(per_class["k"].unique())
167-
x = np.arange(len(k_values))
168-
total_bars = n_classes * n_experiments
169-
width = 0.8 / max(total_bars, 1)
180+
# --- Facets: one subplot per class ---
181+
facet_axes = []
182+
for idx, cls in enumerate(classes):
183+
row = 1 + idx // n_cols
184+
col = idx % n_cols
185+
ax = fig.add_subplot(gs[row, col])
186+
facet_axes.append(ax)
170187

171-
for i, exp in enumerate(experiments):
172-
for j, cls in enumerate(classes):
188+
for i, exp in enumerate(experiments):
173189
subset = per_class[
174190
(per_class["experiment"] == exp) & (per_class["class"] == cls)
175191
].sort_values("k")
176-
offset = (i * n_classes + j - total_bars / 2) * width + width / 2
177-
bar_vals = [
178-
subset[subset["k"] == k]["purity"].values[0]
179-
if len(subset[subset["k"] == k]) > 0
180-
else 0
181-
for k in k_values
182-
]
183-
ax_bot.bar(x + offset, bar_vals, width, label=f"{exp}{cls}")
184-
185-
ax_bot.set_xlabel("k")
186-
ax_bot.set_ylabel("Purity")
187-
ax_bot.set_ylim(0, 1.05)
188-
ax_bot.set_xticks(x)
189-
ax_bot.set_xticklabels([str(k) for k in k_values])
190-
ax_bot.set_title("Per-Class KNN Purity by Experiment")
191-
ax_bot.legend(fontsize=7, ncol=2)
192-
ax_bot.grid(True, alpha=0.3, axis="y")
193-
194-
fig.tight_layout()
192+
if subset.empty:
193+
continue
194+
ax.plot(
195+
subset["k"],
196+
subset["purity"],
197+
marker=markers[i % len(markers)],
198+
label=exp,
199+
)
200+
201+
display = class_labels.get(str(cls), str(cls))
202+
ax.set_title(display)
203+
ax.set_xlabel("k")
204+
ax.set_ylabel("Purity")
205+
ax.set_ylim(0, 1.05)
206+
ax.grid(True, alpha=0.3)
207+
208+
# Single shared legend for facets, placed below the grid.
209+
if facet_axes:
210+
handles, labels = facet_axes[0].get_legend_handles_labels()
211+
if handles:
212+
fig.legend(
213+
handles,
214+
labels,
215+
loc="lower center",
216+
ncol=min(len(labels), 4),
217+
bbox_to_anchor=(0.5, -0.02),
218+
)
195219

196220
if output_path:
197221
plt.savefig(output_path, dpi=300, bbox_inches="tight")

gelos/comparison.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ class ComparisonContext:
3838
comp_plots: list[dict]
3939
output_dir: Path
4040
figures_dir: Path
41+
class_labels: dict[str, str]
4142

4243

4344
def _resolve_embedding_path(exp: ComparisonExperiment, processed_data_dir: Path) -> Path:
@@ -85,6 +86,11 @@ def setup_comparison(
8586
comp_metrics = yaml_config.get("comp_metrics", [])
8687
comp_plots = yaml_config.get("comp_plots", [])
8788

89+
# Normalize class label keys to strings so they can be looked up against the
90+
# CSV ``class`` column (which is string-typed on read).
91+
raw_labels = yaml_config.get("class_labels", {}) or {}
92+
class_labels = {str(k): str(v) for k, v in raw_labels.items()}
93+
8894
output_dir = processed_data_dir / "comparisons" / config_stem
8995
output_dir.mkdir(exist_ok=True, parents=True)
9096
figures_dir = figures_base_dir / "comparisons" / config_stem
@@ -98,6 +104,7 @@ def setup_comparison(
98104
comp_plots=comp_plots,
99105
output_dir=output_dir,
100106
figures_dir=figures_dir,
107+
class_labels=class_labels,
101108
)
102109

103110

@@ -123,8 +130,7 @@ def run_comparison(
123130

124131
# Determine whether any metric requires loading embeddings
125132
any_needs_embeddings = any(
126-
getattr(COMP_METRICS.get(m["type"]), "requires_embeddings", True)
127-
for m in ctx.comp_metrics
133+
getattr(COMP_METRICS.get(m["type"]), "requires_embeddings", True) for m in ctx.comp_metrics
128134
)
129135

130136
# Build experiment lists: always build labels-only, load arrays only when needed
@@ -136,18 +142,15 @@ def run_comparison(
136142
for exp in ctx.experiments:
137143
emb_path = _resolve_embedding_path(exp, processed_data_dir)
138144
if not emb_path.exists():
139-
logger.warning(
140-
f"embeddings not found for '{exp.label}' at {emb_path}, skipping"
141-
)
145+
logger.warning(f"embeddings not found for '{exp.label}' at {emb_path}, skipping")
142146
continue
143147
emb = np.load(emb_path)
144148
loaded.append((exp.label, emb))
145149
logger.info(f"loaded embeddings for '{exp.label}': shape={emb.shape}")
146150

147151
if len(loaded) < 2:
148152
logger.warning(
149-
f"need at least 2 experiments with embeddings for comparison, "
150-
f"got {len(loaded)}"
153+
f"need at least 2 experiments with embeddings for comparison, got {len(loaded)}"
151154
)
152155
return {}
153156

@@ -212,6 +215,7 @@ def run_comparison(
212215
p_fn(
213216
metric_results[source_metric],
214217
output_path=output_path,
218+
class_labels=ctx.class_labels,
215219
**p_params,
216220
)
217221
logger.info(f"comparison plot saved to {output_path}")

gelos/plotting.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def scatter_2d(
3636
plot a 2d transform of embeddings colored according to chip category
3737
"""
3838
category_column, color_dict, legend_patches = build_style_from_config(style_cfg)
39-
colors = chip_gdf[category_column].loc[chip_indices].map(color_dict)
39+
colors = chip_gdf[category_column].loc[chip_indices].astype(str).map(color_dict)
4040
transform_title = TRANSFORM_TITLES[t_type]
4141

4242
fig = plt.figure(figsize=(10, 8))
@@ -63,9 +63,10 @@ def scatter_2d(
6363
def build_style_from_config(style_cfg: dict) -> tuple[str, dict, list[Patch]]:
6464
"""Extract category_column, color_dict, and legend_patches from the style config section."""
6565
category_column = style_cfg["category_column"]
66-
color_dict = style_cfg["colors"]
66+
color_dict = {str(k): v for k, v in style_cfg["colors"].items()}
67+
label_dict = {str(k): v for k, v in style_cfg["labels"].items()}
6768
legend_patches = [
68-
Patch(color=color, label=style_cfg["labels"][k]) for k, color in color_dict.items()
69+
Patch(color=color, label=label_dict[k]) for k, color in color_dict.items()
6970
]
7071
return category_column, color_dict, legend_patches
7172

@@ -83,7 +84,7 @@ def temporal_cosine_similarity(
8384
n_timesteps: int = 4,
8485
timestep_labels: list[str] | None = None,
8586
n_cols: int = 6,
86-
ylim: tuple[float, float] = (0,1)
87+
ylim: tuple[float, float] = (0.5,1)
8788
) -> None:
8889
"""Plot cosine similarity between consecutive timesteps per land-cover category.
8990

0 commit comments

Comments
 (0)