@@ -122,76 +122,100 @@ def distance_matrix(
122122def knn_purity_plot (
123123 metric_result : dict ,
124124 output_path : str | Path = None ,
125+ class_labels : dict [str , str ] | None = None ,
125126 ** kwargs ,
126127) -> None :
127128 """Render KNN class purity comparison as line plots.
128129
129- Creates a figure with two subplots:
130- - Top: overall purity vs k, one line per experiment
131- - Bottom: per-class purity vs k, faceted by experiment
130+ Creates a figure with:
131+ - Top row (full width): overall purity vs k, one line per experiment.
132+ - Facet grid below: one subplot per class, each showing purity vs k with
133+ one line per experiment — makes cross-model comparison direct.
132134
133135 Args:
134136 metric_result: Output from ``knn_purity_comparison`` metric.
135137 output_path: Path to save the figure. Shows interactively if None.
138+ class_labels: Optional mapping from class id (as string) to display
139+ name. Used for facet subplot titles. Falls back to raw class id.
136140 """
137141 df = metric_result .get ("comparison_df" , pd .DataFrame ())
138142 if df .empty :
139143 logger .warning ("no data for KNN purity plot, skipping" )
140144 return
141145
146+ class_labels = class_labels or {}
147+
142148 overall = df [df ["class" ] == "overall" ]
143149 per_class = df [df ["class" ] != "overall" ]
144- experiments = overall ["experiment" ].unique ()
145-
146- fig , ( ax_top , ax_bot ) = plt . subplots ( 2 , 1 , figsize = ( 10 , 8 ) )
150+ experiments = list ( overall ["experiment" ].unique () )
151+ classes = sorted ( per_class [ "class" ]. unique ())
152+ n_classes = len ( classes )
147153
148- # --- Top: overall purity ---
149154 markers = ["o" , "s" , "^" , "D" , "v" , "P" , "X" , "*" ]
155+
156+ # Layout: 1 row for overall, then a facet grid (up to 4 cols) for classes.
157+ n_cols = min (4 , n_classes ) if n_classes else 1
158+ n_facet_rows = (n_classes + n_cols - 1 ) // n_cols if n_classes else 0
159+ fig_height = 4 + 2.5 * n_facet_rows
160+ fig = plt .figure (figsize = (3.5 * n_cols , fig_height ))
161+ gs = fig .add_gridspec (1 + n_facet_rows , n_cols , hspace = 0.5 , wspace = 0.3 )
162+
163+ # --- Top: overall purity (spans all columns) ---
164+ ax_top = fig .add_subplot (gs [0 , :])
150165 for i , exp in enumerate (experiments ):
151166 exp_data = overall [overall ["experiment" ] == exp ].sort_values ("k" )
152- marker = markers [i % len (markers )]
153- ax_top .plot (exp_data ["k" ], exp_data ["purity" ], marker = marker , label = exp )
154-
167+ ax_top .plot (
168+ exp_data ["k" ],
169+ exp_data ["purity" ],
170+ marker = markers [i % len (markers )],
171+ label = exp ,
172+ )
155173 ax_top .set_xlabel ("k" )
156174 ax_top .set_ylabel ("Purity" )
157175 ax_top .set_ylim (0 , 1.05 )
158176 ax_top .set_title ("Overall KNN Class Purity by Experiment" )
159177 ax_top .legend ()
160178 ax_top .grid (True , alpha = 0.3 )
161179
162- # --- Bottom: per-class purity ---
163- classes = sorted (per_class ["class" ].unique ())
164- n_classes = len (classes )
165- n_experiments = len (experiments )
166- k_values = sorted (per_class ["k" ].unique ())
167- x = np .arange (len (k_values ))
168- total_bars = n_classes * n_experiments
169- width = 0.8 / max (total_bars , 1 )
180+ # --- Facets: one subplot per class ---
181+ facet_axes = []
182+ for idx , cls in enumerate (classes ):
183+ row = 1 + idx // n_cols
184+ col = idx % n_cols
185+ ax = fig .add_subplot (gs [row , col ])
186+ facet_axes .append (ax )
170187
171- for i , exp in enumerate (experiments ):
172- for j , cls in enumerate (classes ):
188+ for i , exp in enumerate (experiments ):
173189 subset = per_class [
174190 (per_class ["experiment" ] == exp ) & (per_class ["class" ] == cls )
175191 ].sort_values ("k" )
176- offset = (i * n_classes + j - total_bars / 2 ) * width + width / 2
177- bar_vals = [
178- subset [subset ["k" ] == k ]["purity" ].values [0 ]
179- if len (subset [subset ["k" ] == k ]) > 0
180- else 0
181- for k in k_values
182- ]
183- ax_bot .bar (x + offset , bar_vals , width , label = f"{ exp } — { cls } " )
184-
185- ax_bot .set_xlabel ("k" )
186- ax_bot .set_ylabel ("Purity" )
187- ax_bot .set_ylim (0 , 1.05 )
188- ax_bot .set_xticks (x )
189- ax_bot .set_xticklabels ([str (k ) for k in k_values ])
190- ax_bot .set_title ("Per-Class KNN Purity by Experiment" )
191- ax_bot .legend (fontsize = 7 , ncol = 2 )
192- ax_bot .grid (True , alpha = 0.3 , axis = "y" )
193-
194- fig .tight_layout ()
192+ if subset .empty :
193+ continue
194+ ax .plot (
195+ subset ["k" ],
196+ subset ["purity" ],
197+ marker = markers [i % len (markers )],
198+ label = exp ,
199+ )
200+
201+ display = class_labels .get (str (cls ), str (cls ))
202+ ax .set_title (display )
203+ ax .set_xlabel ("k" )
204+ ax .set_ylabel ("Purity" )
205+ ax .set_ylim (0 , 1.05 )
206+ ax .grid (True , alpha = 0.3 )
207+
208+ # Single shared legend for facets, placed below the grid.
209+ if facet_axes :
210+ handles , labels = facet_axes [0 ].get_legend_handles_labels ()
211+ if handles :
212+ fig .legend (
213+ handles ,
214+ labels ,
215+ loc = "lower center" ,
216+ ncol = min (len (labels ), 4 ),
217+ bbox_to_anchor = (0.5 , - 0.02 ),
218+ )
195219
196220 if output_path :
197221 plt .savefig (output_path , dpi = 300 , bbox_inches = "tight" )
0 commit comments