2929"""
3030import itertools
3131from copy import deepcopy
32- from typing import Dict , List , Optional , Union
32+ from typing import Callable , Dict , List , Optional , Union
3333from warnings import warn
3434
3535import numpy as np
@@ -88,6 +88,12 @@ class Tester(ABToolAbstract):
8888 metrics : MetricNameType, optional
8989 Metrics (columns of dataframe) which is used to calculate
9090 experiment result.
91+ metric_funcs : Dict[str, Callable], optional
92+ Dictionary mapping metric names to callable functions.
93+ Each function receives a ``pd.DataFrame`` (group data) and must
94+ return an array-like of numeric values. When provided, the
95+ function is used instead of column lookup for the corresponding
96+ metric name. Only supported for pandas DataFrames.
9197
9298 Attributes
9399 ----------
@@ -241,6 +247,7 @@ def __init__(
241247 id_column : Optional [types .ColumnNameType ] = None ,
242248 first_type_errors : types .StatErrorType = 0.05 ,
243249 metrics : Optional [types .MetricNamesType ] = None ,
250+ metric_funcs : Optional [Dict [str , Callable ]] = None ,
244251 ):
245252 """
246253 Tester class constructor to initialize the object.
@@ -257,6 +264,7 @@ def __init__(
257264 self .set_experiment_results (experiment_results = experiment_results )
258265 self .set_errors (first_type_errors )
259266 self .set_metrics (metrics )
267+ self .__metric_funcs = metric_funcs or {}
260268
261269 @staticmethod
262270 def __filter_data (
@@ -372,9 +380,15 @@ def __pre_run(method: str, args: types._UsageArgumentsType, **kwargs) -> types.T
372380 if method not in accepted_methods :
373381 raise ValueError (f'Choose method from { ", " .join (accepted_methods )} ' )
374382 result : types .TesterResult = {}
383+ metric_funcs : Dict = args .get ("metric_funcs" , {})
375384 for metric in args ["metrics" ]:
376- a_values : np .ndarray = args ["data_a_group" ][metric ].values
377- b_values : np .ndarray = args ["data_b_group" ][metric ].values
385+ metric_func = metric_funcs .get (metric )
386+ if metric_func is not None :
387+ a_values : np .ndarray = np .asarray (metric_func (args ["data_a_group" ]))
388+ b_values : np .ndarray = np .asarray (metric_func (args ["data_b_group" ]))
389+ else :
390+ a_values = args ["data_a_group" ][metric ].values
391+ b_values = args ["data_b_group" ][metric ].values
378392 if method == "theory" :
379393 # TODO: Make it SolverClass ~ method
380394 # solver = SolverClass(...)
@@ -386,6 +400,7 @@ def __pre_run(method: str, args: types._UsageArgumentsType, **kwargs) -> types.T
386400 alpha = np .array (args ["alpha" ]),
387401 effect_type = args ["effect_type" ],
388402 criterion = args ["criterion" ],
403+ metric_func = metric_func ,
389404 ** kwargs ,
390405 )
391406 sub_result = solver .solve ()
@@ -473,6 +488,7 @@ def run(
473488 criterion : Optional [ABStatCriterion ] = None ,
474489 correction_method : Union [str , None ] = "bonferroni" ,
475490 as_table : bool = True ,
491+ metric_funcs : Optional [Dict [str , Callable ]] = None ,
476492 ** kwargs ,
477493 ) -> types .TesterResult :
478494 """
@@ -515,6 +531,11 @@ def run(
515531 as_table : bool, default: ``True``
516532 Return the test results as a pandas dataframe.
517533 If ``False``, a list of dicts with results will be returned.
534+ metric_funcs : Dict[str, Callable], optional
535+ Dictionary mapping metric names to callable functions.
536+ Each function receives a group ``pd.DataFrame`` and returns
537+ array-like values. Overrides functions set in constructor
538+ for matching metric names. Only pandas DataFrames supported.
518539 **kwargs : Dict
519540 Other keyword arguments.
520541
@@ -556,6 +577,8 @@ def run(
556577 chosen_args : types ._UsageArgumentsType = Tester ._prepare_arguments (arguments_choice )
557578 chosen_args ["effect_type" ] = effect_type
558579 chosen_args ["criterion" ] = criterion
580+ effective_metric_funcs = {** self .__metric_funcs , ** (metric_funcs or {})}
581+ chosen_args ["metric_funcs" ] = effective_metric_funcs
559582
560583 hypothesis_num : int = len (list (itertools .combinations (chosen_args ["experiment_results" ], 2 ))) * len (
561584 chosen_args ["metrics" ]
@@ -602,6 +625,7 @@ def test(
602625 criterion : Optional [ABStatCriterion ] = None ,
603626 correction_method : Union [str , None ] = "bonferroni" ,
604627 as_table : bool = True ,
628+ metric_funcs : Optional [Dict [str , Callable ]] = None ,
605629 ** kwargs ,
606630) -> types .TesterResult :
607631 """
@@ -649,6 +673,10 @@ def test(
649673 as_table : bool, default: ``True``
650674 Return the test results as a pandas dataframe.
651675 If ``False``, a list of dicts with results will be returned.
676+ metric_funcs : Dict[str, Callable], optional
677+ Dictionary mapping metric names to callable functions.
678+ Each function receives a group ``pd.DataFrame`` and returns
679+ array-like values. Only pandas DataFrames supported.
652680 **kwargs : Dict
653681 Other keyword arguments.
654682
@@ -673,5 +701,6 @@ def test(
673701 criterion = criterion ,
674702 correction_method = correction_method ,
675703 as_table = as_table ,
704+ metric_funcs = metric_funcs ,
676705 ** kwargs ,
677706 )
0 commit comments