Skip to content

Commit 4d94b2e

Browse files
authored
Merge pull request #57 from MTSWebServices/feature/ab-test-abstraction
Reviewed and fixed: added metric_funcs docstrings, bootstrap test, pinned setuptools<82, resolved CLAUDE.md conflict
2 parents 263d657 + d3c2366 commit 4d94b2e

10 files changed

Lines changed: 378 additions & 11 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ settings.json
4444
.mypy_cache/
4545
.pytest_cache/
4646

47+
# Claude Code
48+
CLAUDE.md
49+
4750
# Tests artifacts
4851
reports/
4952
coverage.xml

ambrosia/preprocessing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .ml_var_reducer import MLVarianceReducer
2222
from .preprocessor import Preprocessor
2323
from .robust import IQRPreprocessor, RobustPreprocessor
24-
from .transformers import BoxCoxTransformer, LogTransformer
24+
from .transformers import BoxCoxTransformer, LinearizationTransformer, LogTransformer
2525

2626
__all__ = [
2727
"AggregatePreprocessor",
@@ -32,5 +32,6 @@
3232
"RobustPreprocessor",
3333
"IQRPreprocessor",
3434
"BoxCoxTransformer",
35+
"LinearizationTransformer",
3536
"LogTransformer",
3637
]

ambrosia/preprocessing/preprocessor.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ambrosia.preprocessing.aggregate import AggregatePreprocessor
3535
from ambrosia.preprocessing.cuped import Cuped, MultiCuped
3636
from ambrosia.preprocessing.robust import IQRPreprocessor, RobustPreprocessor
37-
from ambrosia.preprocessing.transformers import BoxCoxTransformer, LogTransformer
37+
from ambrosia.preprocessing.transformers import BoxCoxTransformer, LinearizationTransformer, LogTransformer
3838

3939

4040
class Preprocessor:
@@ -378,6 +378,50 @@ def multicuped(
378378
self.transformers.append(transformer)
379379
return self
380380

381+
def linearize(
    self,
    numerator: types.ColumnNameType,
    denominator: types.ColumnNameType,
    transformed_name: Optional[types.ColumnNameType] = None,
    load_path: Optional[Path] = None,
) -> Preprocessor:
    """
    Add a linearized version of a ratio metric to the stored dataframe.

    A ``LinearizationTransformer`` is applied in place to ``self.dataframe``
    and writes a new column computed as

        linearized_i = numerator_i - ratio * denominator_i

    Without ``load_path`` the transformer is fitted on ``self.dataframe``
    itself, i.e. ratio = mean(numerator) / mean(denominator) of the data
    held by this ``Preprocessor``. With ``load_path`` the parameters
    (column names and ratio) are read from the file instead, and the
    ``numerator``, ``denominator`` and ``transformed_name`` arguments
    are not used.

    Parameters
    ----------
    numerator : ColumnNameType
        Column name of the ratio numerator (e.g. ``"revenue"``).
    denominator : ColumnNameType
        Column name of the ratio denominator (e.g. ``"orders"``).
    transformed_name : ColumnNameType, optional
        Name for the new linearized column. Defaults to
        ``"{numerator}_lin"``.
    load_path : Path, optional
        Path to a json file with pre-fitted parameters.

    Returns
    -------
    self : Preprocessor
        Instance object.
    """
    linearizer = LinearizationTransformer()
    if load_path is not None:
        # Pre-fitted parameters: restore them and only apply the transformation.
        linearizer.load_params(load_path)
        linearizer.transform(self.dataframe, inplace=True)
    else:
        # No saved parameters: estimate the ratio on the stored dataframe and apply it.
        linearizer.fit_transform(
            self.dataframe,
            numerator,
            denominator,
            transformed_name,
            inplace=True,
        )
    # Record the step so it shows up in the list of applied transformations.
    self.transformers.append(linearizer)
    return self
424+
381425
def transformations(self) -> List:
382426
"""
383427
List of all transformations which were called.

ambrosia/preprocessing/transformers.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Module contains tools for metrics transformations during a
1717
preprocessing task.
1818
"""
19-
from typing import Dict, Union
19+
from typing import Dict, Optional, Union
2020

2121
import numpy as np
2222
import pandas as pd
@@ -386,3 +386,134 @@ def inverse_transform(self, dataframe: pd.DataFrame, inplace: bool = False) -> U
386386
transformed: pd.DataFrame = dataframe if inplace else dataframe.copy()
387387
transformed[self.column_names] = np.exp(transformed[self.column_names].values)
388388
return None if inplace else transformed
389+
390+
391+
class LinearizationTransformer(AbstractFittableTransformer):
    """
    Linearization transformer for ratio metrics.

    Converts a ratio metric (numerator / denominator) into a per-unit linearized
    metric that is approximately normally distributed, enabling correct t-test usage:

        linearized_i = numerator_i - ratio * denominator_i

    where ratio = mean(numerator) / mean(denominator), estimated on the reference
    (control group / historical) data passed to fit().

    The constructor takes no arguments: column names are supplied to ``fit``
    (or ``fit_transform``), or restored with ``load_params_dict``.

    Attributes
    ----------
    numerator : str
        Column name of the ratio numerator (e.g. "revenue").
    denominator : str
        Column name of the ratio denominator (e.g. "orders").
    transformed_name : str
        Name of the column written by ``transform``.
        Defaults to ``"{numerator}_lin"`` when not given to ``fit``.
    ratio : float
        Estimated mean(numerator) / mean(denominator).

    Examples
    --------
    >>> transformer = LinearizationTransformer()
    >>> transformer.fit(control_df, "revenue", "orders", "arpu_lin")
    >>> transformer.transform(experiment_df, inplace=True)
    """

    # Keys that make up the serialisable state of a fitted transformer.
    _PARAM_KEYS = ("numerator", "denominator", "transformed_name", "ratio")

    def __str__(self) -> str:
        return "Linearization transformation"

    def __init__(self) -> None:
        self.numerator: Optional[str] = None
        self.denominator: Optional[str] = None
        self.transformed_name: Optional[str] = None
        self.ratio: Optional[float] = None
        super().__init__()

    def get_params_dict(self) -> Dict:
        """
        Return the fitted parameters as a dictionary.

        The transformer must be fitted; ``_check_fitted`` is called first.
        """
        self._check_fitted()
        return {key: getattr(self, key) for key in self._PARAM_KEYS}

    def load_params_dict(self, params: Dict) -> None:
        """
        Restore the transformer state from a parameters dictionary.

        Parameters
        ----------
        params : Dict
            Must contain the keys ``"numerator"``, ``"denominator"``,
            ``"transformed_name"`` and ``"ratio"``.

        Raises
        ------
        TypeError
            If any required key is missing. Nothing is assigned in that case.
        """
        # Validate every key before assigning anything, so that an incomplete
        # dictionary cannot leave the instance half-loaded.
        for key in self._PARAM_KEYS:
            if key not in params:
                raise TypeError(f"params argument must contain: {key}")
        for key in self._PARAM_KEYS:
            setattr(self, key, params[key])
        self.fitted = True

    def fit(
        self,
        dataframe: pd.DataFrame,
        numerator: str,
        denominator: str,
        transformed_name: Optional[str] = None,
    ) -> "LinearizationTransformer":
        """
        Estimate ratio = mean(numerator) / mean(denominator) on reference data.

        Parameters
        ----------
        dataframe : pd.DataFrame
            Reference dataframe (typically control group or historical data).
        numerator : str
            Column name of the ratio numerator.
        denominator : str
            Column name of the ratio denominator.
        transformed_name : str, optional
            Name for the linearized column. Defaults to ``"{numerator}_lin"``.

        Returns
        -------
        self : LinearizationTransformer
            Fitted instance.

        Raises
        ------
        ValueError
            If the mean of the denominator column is zero or undefined
            (empty or all-NaN column).
        """
        self._check_cols(dataframe, [numerator, denominator])
        denom_mean = dataframe[denominator].mean()
        # An empty or all-NaN column has a NaN mean, which would slip past the
        # zero check below and yield a "fitted" transformer with a NaN ratio.
        if pd.isna(denom_mean):
            raise ValueError(
                f"Mean of denominator column '{denominator}' is undefined "
                "(empty or all-NaN data); cannot compute ratio."
            )
        if denom_mean == 0:
            raise ValueError(f"Mean of denominator column '{denominator}' is zero; cannot compute ratio.")
        # Compute the ratio before touching instance state so a failure here
        # leaves the transformer unchanged. It is stored as a builtin float to
        # match the annotation and keep the parameters dictionary free of NumPy
        # scalar types (presumably it is written to json - see ``load_path`` users).
        ratio = float(dataframe[numerator].mean() / denom_mean)
        self.numerator = numerator
        self.denominator = denominator
        self.transformed_name = transformed_name if transformed_name is not None else f"{numerator}_lin"
        self.ratio = ratio
        self.fitted = True
        return self

    def transform(self, dataframe: pd.DataFrame, inplace: bool = False) -> Union[pd.DataFrame, None]:
        """
        Apply linearization: transformed = numerator - ratio * denominator.

        Parameters
        ----------
        dataframe : pd.DataFrame
            Dataframe to transform.
        inplace : bool, default: ``False``
            If ``True`` modifies dataframe in place, otherwise returns a copy.

        Returns
        -------
        Union[pd.DataFrame, None]
            ``None`` if ``inplace`` is ``True``, otherwise a copy of the
            dataframe with the linearized column added.
        """
        self._check_fitted()
        self._check_cols(dataframe, [self.numerator, self.denominator])
        df = dataframe if inplace else dataframe.copy()
        df[self.transformed_name] = df[self.numerator] - self.ratio * df[self.denominator]
        return None if inplace else df

    def fit_transform(
        self,
        dataframe: pd.DataFrame,
        numerator: str,
        denominator: str,
        transformed_name: Optional[str] = None,
        inplace: bool = False,
    ) -> Union[pd.DataFrame, None]:
        """
        Fit and transform in one step.

        Parameters
        ----------
        dataframe : pd.DataFrame
            Reference dataframe for fitting and transformation.
        numerator : str
            Column name of the ratio numerator.
        denominator : str
            Column name of the ratio denominator.
        transformed_name : str, optional
            Name for the linearized column.
        inplace : bool, default: ``False``
            If ``True`` modifies dataframe in place.

        Returns
        -------
        Union[pd.DataFrame, None]
            ``None`` if ``inplace`` is ``True``, otherwise the transformed copy.
        """
        self.fit(dataframe, numerator, denominator, transformed_name)
        return self.transform(dataframe, inplace)

ambrosia/tester/handlers.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,23 @@ class SparkCriteria(enum.Enum):
5151

5252
class TheoreticalTesterHandler:
5353
def __init__(
54-
self, group_a, group_b, column: str, alpha: np.ndarray, effect_type: str, criterion: StatCriterion, **kwargs
54+
self,
55+
group_a,
56+
group_b,
57+
column: str,
58+
alpha: np.ndarray,
59+
effect_type: str,
60+
criterion: StatCriterion,
61+
metric_func=None,
62+
**kwargs,
5563
):
5664
self.group_a = group_a
5765
self.group_b = group_b
5866
self.column = column
5967
self.alpha = alpha
6068
self.effect_type = effect_type
6169
self.criterion = criterion
70+
self.metric_func = metric_func
6271
self.kwargs = kwargs
6372

6473
def _correct_criterion(self, criterion: tp.Any) -> bool:
@@ -79,8 +88,12 @@ def get_criterion(self, criterion: str, data_example: types.SparkOrPandas):
7988

8089
def _set_kwargs(self):
8190
if isinstance(self.group_a, pd.DataFrame):
82-
self.group_a = self.group_a[self.column].values
83-
self.group_b = self.group_b[self.column].values
91+
if self.metric_func is not None:
92+
self.group_a = np.asarray(self.metric_func(self.group_a))
93+
self.group_b = np.asarray(self.metric_func(self.group_b))
94+
else:
95+
self.group_a = self.group_a[self.column].values
96+
self.group_b = self.group_b[self.column].values
8497
elif isinstance(self.group_a, types.SparkDataFrame):
8598
self.kwargs["column"] = self.column
8699
self.kwargs["alpha"] = self.alpha

ambrosia/tester/tester.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"""
3030
import itertools
3131
from copy import deepcopy
32-
from typing import Dict, List, Optional, Union
32+
from typing import Callable, Dict, List, Optional, Union
3333
from warnings import warn
3434

3535
import numpy as np
@@ -88,6 +88,12 @@ class Tester(ABToolAbstract):
8888
metrics : MetricNameType, optional
8989
Metrics (columns of dataframe) which is used to calculate
9090
experiment result.
91+
metric_funcs : Dict[str, Callable], optional
92+
Dictionary mapping metric names to callable functions.
93+
Each function receives a ``pd.DataFrame`` (group data) and must
94+
return an array-like of numeric values. When provided, the
95+
function is used instead of column lookup for the corresponding
96+
metric name. Only supported for pandas DataFrames.
9197
9298
Attributes
9399
----------
@@ -241,6 +247,7 @@ def __init__(
241247
id_column: Optional[types.ColumnNameType] = None,
242248
first_type_errors: types.StatErrorType = 0.05,
243249
metrics: Optional[types.MetricNamesType] = None,
250+
metric_funcs: Optional[Dict[str, Callable]] = None,
244251
):
245252
"""
246253
Tester class constructor to initialize the object.
@@ -257,6 +264,7 @@ def __init__(
257264
self.set_experiment_results(experiment_results=experiment_results)
258265
self.set_errors(first_type_errors)
259266
self.set_metrics(metrics)
267+
self.__metric_funcs = metric_funcs or {}
260268

261269
@staticmethod
262270
def __filter_data(
@@ -372,9 +380,15 @@ def __pre_run(method: str, args: types._UsageArgumentsType, **kwargs) -> types.T
372380
if method not in accepted_methods:
373381
raise ValueError(f'Choose method from {", ".join(accepted_methods)}')
374382
result: types.TesterResult = {}
383+
metric_funcs: Dict = args.get("metric_funcs", {})
375384
for metric in args["metrics"]:
376-
a_values: np.ndarray = args["data_a_group"][metric].values
377-
b_values: np.ndarray = args["data_b_group"][metric].values
385+
metric_func = metric_funcs.get(metric)
386+
if metric_func is not None:
387+
a_values: np.ndarray = np.asarray(metric_func(args["data_a_group"]))
388+
b_values: np.ndarray = np.asarray(metric_func(args["data_b_group"]))
389+
else:
390+
a_values = args["data_a_group"][metric].values
391+
b_values = args["data_b_group"][metric].values
378392
if method == "theory":
379393
# TODO: Make it SolverClass ~ method
380394
# solver = SolverClass(...)
@@ -386,6 +400,7 @@ def __pre_run(method: str, args: types._UsageArgumentsType, **kwargs) -> types.T
386400
alpha=np.array(args["alpha"]),
387401
effect_type=args["effect_type"],
388402
criterion=args["criterion"],
403+
metric_func=metric_func,
389404
**kwargs,
390405
)
391406
sub_result = solver.solve()
@@ -473,6 +488,7 @@ def run(
473488
criterion: Optional[ABStatCriterion] = None,
474489
correction_method: Union[str, None] = "bonferroni",
475490
as_table: bool = True,
491+
metric_funcs: Optional[Dict[str, Callable]] = None,
476492
**kwargs,
477493
) -> types.TesterResult:
478494
"""
@@ -515,6 +531,11 @@ def run(
515531
as_table : bool, default: ``True``
516532
Return the test results as a pandas dataframe.
517533
If ``False``, a list of dicts with results will be returned.
534+
metric_funcs : Dict[str, Callable], optional
535+
Dictionary mapping metric names to callable functions.
536+
Each function receives a group ``pd.DataFrame`` and returns
537+
array-like values. Overrides functions set in constructor
538+
for matching metric names. Only pandas DataFrames supported.
518539
**kwargs : Dict
519540
Other keyword arguments.
520541
@@ -556,6 +577,8 @@ def run(
556577
chosen_args: types._UsageArgumentsType = Tester._prepare_arguments(arguments_choice)
557578
chosen_args["effect_type"] = effect_type
558579
chosen_args["criterion"] = criterion
580+
effective_metric_funcs = {**self.__metric_funcs, **(metric_funcs or {})}
581+
chosen_args["metric_funcs"] = effective_metric_funcs
559582

560583
hypothesis_num: int = len(list(itertools.combinations(chosen_args["experiment_results"], 2))) * len(
561584
chosen_args["metrics"]
@@ -602,6 +625,7 @@ def test(
602625
criterion: Optional[ABStatCriterion] = None,
603626
correction_method: Union[str, None] = "bonferroni",
604627
as_table: bool = True,
628+
metric_funcs: Optional[Dict[str, Callable]] = None,
605629
**kwargs,
606630
) -> types.TesterResult:
607631
"""
@@ -649,6 +673,10 @@ def test(
649673
as_table : bool, default: ``True``
650674
Return the test results as a pandas dataframe.
651675
If ``False``, a list of dicts with results will be returned.
676+
metric_funcs : Dict[str, Callable], optional
677+
Dictionary mapping metric names to callable functions.
678+
Each function receives a group ``pd.DataFrame`` and returns
679+
array-like values. Only pandas DataFrames supported.
652680
**kwargs : Dict
653681
Other keyword arguments.
654682
@@ -673,5 +701,6 @@ def test(
673701
criterion=criterion,
674702
correction_method=correction_method,
675703
as_table=as_table,
704+
metric_funcs=metric_funcs,
676705
**kwargs,
677706
)

poetry.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)