
Add evaluation helpers for single-label multiclass metrics (recall, precision, micro-accuracy) #28

@atecon

Description


Feature Request

Add evaluation functions for single-label multiclass classification

It would be valuable to provide users with template/helper functions for computing common evaluation metrics for single-label multiclass prediction problems. These are useful in everyday classification analyses and would bring the package closer to the functionality of the `sklearn.metrics` module in Python's scikit-learn.

The following functions serve as complete, concrete templates:

function scalar RecallMacroKClass (const series y, const series f,
                                   const bool weighted[FALSE])
  /* Macro recall from class-wise recalls via aggregate().
     Computes per-class recall (TP_i / (TP_i + FN_i)) by grouping on the
     true class y. In single-label multiclass classification, the weighted
     macro recall equals accuracy and also equals micro-precision.

     Parameters:
       y: series of true class labels
       f: series of predicted class labels
       weighted: if TRUE, compute the support-weighted macro recall;
                 otherwise return the unweighted average of per-class recalls

     Returns:
       Unweighted: average of the per-class recalls
       Weighted: support-weighted average, following the sklearn convention

     Note: only classes observed in y are averaged. By default sklearn also
     includes classes that appear only in f (scoring their recall as 0), so
     the unweighted value can differ from sklearn's when f predicts a class
     that never occurs in y. The weighted value is unaffected.

     See also:
       - https://scikit-learn.org/stable/modules/model_evaluation.html#accuracy-score
       - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html
  */

  series iscorrect = y == f ? 1 : 0
  matrix recall_by_class = aggregate(iscorrect, y, "mean")
  /* aggregate() returns one row per true class:
     - Column 1: class label
     - Column 2: count (support) of that true class
     - Column 3: mean of iscorrect for that class = recall_i
  */

  if weighted
    /* Weighted macro recall: sum_i (n_i/N) * recall_i */
    matrix frequencies = recall_by_class[,2] ./ sumc(recall_by_class[,2])
    return sumc(frequencies .* recall_by_class[,3])
  endif

  /* Unweighted macro recall: average of all per-class recalls */
  return meanc(recall_by_class[,3])
end function
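To cross-check the hansl output outside gretl, here is a minimal pure-Python sketch of the same computation (standard library only; `recall_macro_kclass` is a hypothetical name, not a scikit-learn function). Like the `aggregate()` version, it averages only over classes present in `y`, and the weighted variant reduces to plain accuracy.

```python
from collections import Counter


def recall_macro_kclass(y, f, weighted=False):
    """Macro recall over the classes observed in y (true labels).

    Mirrors the aggregate() approach: group the 0/1 "is correct" indicator
    by true class; its mean within class i is TP_i / (TP_i + FN_i).
    """
    if not y or len(y) != len(f):
        raise ValueError("y and f must be non-empty and of equal length")

    support = Counter(y)                                     # n_i per true class
    hits = Counter(yt for yt, yp in zip(y, f) if yt == yp)   # TP_i per true class
    recalls = {c: hits[c] / n_c for c, n_c in support.items()}

    if weighted:
        n = len(y)
        # sum_i (n_i / N) * recall_i, which collapses to accuracy
        return sum(support[c] / n * r for c, r in recalls.items())
    return sum(recalls.values()) / len(recalls)


# Usage with a small made-up example (labels are arbitrary).
y = [0, 0, 0, 1, 1, 2]
f = [0, 0, 1, 1, 2, 2]
accuracy = sum(yt == yp for yt, yp in zip(y, f)) / len(y)
print(recall_macro_kclass(y, f))
print(recall_macro_kclass(y, f, weighted=True), accuracy)
```

On this sample the per-class recalls are 2/3, 1/2 and 1, so the unweighted macro recall is their mean, while the weighted value matches the accuracy of 4/6.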


function scalar PrecisionMacroKClass (const series y, const series f,
                                      const bool weighted[FALSE])
  /* Macro precision from confusion-matrix counts via xtab.
     Rows of cm are true classes and columns are predicted classes, so
     per-class precision_i = TP_i / (TP_i + FP_i) = cm[i,i] / colsum_i.
     If weighted is nonzero, weight class-wise precisions by true-class
     support (row sums), following sklearn's weighted macro precision.
     y and f must take exactly the same set of values, so that cm is
     square and its diagonal pairs matching labels. */

  xtab y f --quiet
  matrix cm = $result[-end,-end]
  scalar k = rows(cm)

  if k == 0 || cols(cm) != k
    printf "Error: confusion matrix must be square (k x k) and non-empty.\n"
    return NA
  endif

  scalar precision_sum = 0
  scalar weighted_sum = 0
  scalar total_support = 0

  loop i=1..k --quiet
    /* colsum_i = TP_i + FP_i; support_i = TP_i + FN_i */
    scalar colsum_i = 0
    scalar support_i = 0
    scalar precision_i = 0

    loop r=1..k --quiet
      colsum_i += cm[r,i]
      support_i += cm[i,r]
    endloop

    /* precision_i stays 0 if class i is never predicted (0/0) */
    if colsum_i > 0
      precision_i = cm[i,i] / colsum_i
    endif

    precision_sum += precision_i
    weighted_sum += support_i * precision_i
    total_support += support_i
  endloop

  if weighted
    /* Weighted macro precision: sum_i (n_i/N) * precision_i */
    return weighted_sum / total_support
  endif

  /* Unweighted macro precision: average of all per-class precisions */
  return precision_sum / k
end function
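A pure-Python sketch of the same confusion-matrix approach, handy for checking results against scikit-learn. Unlike the hansl template, it assumes the label set is the union of the values in `y` and `f` (scikit-learn's default when `labels` is not given), so mismatched label sets do not raise an error; a class that is never predicted gets precision 0. The names `confusion_counts` and `precision_macro_kclass` are hypothetical.

```python
def confusion_counts(y, f, labels):
    """cm[i][j] = count of observations with true label labels[i] and
    predicted label labels[j] (rows: truth, columns: prediction)."""
    index = {c: k for k, c in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for yt, yp in zip(y, f):
        cm[index[yt]][index[yp]] += 1
    return cm


def precision_macro_kclass(y, f, weighted=False):
    if not y or len(y) != len(f):
        raise ValueError("y and f must be non-empty and of equal length")

    # Assumption: use the union of observed labels, as sklearn does by default.
    labels = sorted(set(y) | set(f))
    cm = confusion_counts(y, f, labels)
    k = len(labels)

    precisions, supports = [], []
    for i in range(k):
        colsum = sum(cm[r][i] for r in range(k))  # TP_i + FP_i (column sum)
        support = sum(cm[i])                       # TP_i + FN_i (row sum)
        # 0/0 -> 0, the value sklearn returns (with a warning) by default
        precisions.append(cm[i][i] / colsum if colsum > 0 else 0.0)
        supports.append(support)

    if weighted:
        # weight each class by its true-class support n_i
        return sum(s * p for s, p in zip(supports, precisions)) / sum(supports)
    return sum(precisions) / k


y = [0, 0, 0, 1, 1, 2]
f = [0, 0, 1, 1, 2, 2]
print(confusion_counts(y, f, [0, 1, 2]))
print(precision_macro_kclass(y, f), precision_macro_kclass(y, f, weighted=True))
```

Here the per-class precisions are 1, 1/2 and 1/2, giving an unweighted macro precision of 2/3 and a support-weighted one of 0.75.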


function scalar PrecisionMicroKClass (const series y, const series f)
  /* In single-label multiclass classification, sklearn's micro-averaged
     precision is identical to the overall classification accuracy: every
     observation is counted exactly once, either as a TP of its true class
     or as an FP of its predicted class, so the pooled denominator is N. */

  series iscorrect = y == f ? 1 : 0

  return mean(iscorrect)
end function
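A pure-Python sketch that makes the pooling explicit: true and false positives are summed across all classes before dividing, which is what micro averaging means. `precision_micro_kclass` is a hypothetical name; the function is only meant to show why the result coincides with accuracy.

```python
def precision_micro_kclass(y, f):
    """Micro-averaged precision: pool TP and FP over all classes, then divide."""
    if not y or len(y) != len(f):
        raise ValueError("y and f must be non-empty and of equal length")

    labels = set(y) | set(f)
    tp = sum(yt == yp for yt, yp in zip(y, f))
    # FP of class c: observations predicted as c whose true label is not c
    fp = sum(
        sum(yp == c and yt != c for yt, yp in zip(y, f))
        for c in labels
    )
    # Each observation is exactly one TP or one FP, so tp + fp == len(y)
    return tp / (tp + fp)


y = [0, 0, 0, 1, 1, 2]
f = [0, 0, 1, 1, 2, 2]
accuracy = sum(yt == yp for yt, yp in zip(y, f)) / len(y)
print(precision_micro_kclass(y, f), accuracy)
```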

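References: the scikit-learn model-evaluation guide and `recall_score` documentation linked in the `RecallMacroKClass` comments.

These helpers/templates would make it straightforward to compute macro- and micro-averaged recall and precision for single-label multiclass problems, using the same definitions as scikit-learn.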
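The identities quoted in the `RecallMacroKClass` and `PrecisionMicroKClass` comments follow directly from the definitions. With $n_i = TP_i + FN_i$ the number of observations whose true class is $i$ and $N = \sum_i n_i$, and using that each observation is predicted as exactly one class (so it contributes to exactly one $TP_i$ or $FP_i$, giving $\sum_i (TP_i + FP_i) = N$):

```latex
\begin{aligned}
\text{Recall}_{\text{weighted}}
  &= \sum_{i=1}^{k} \frac{n_i}{N}\,\frac{TP_i}{TP_i + FN_i}
   = \sum_{i=1}^{k} \frac{n_i}{N}\,\frac{TP_i}{n_i}
   = \frac{1}{N}\sum_{i=1}^{k} TP_i
   = \text{Accuracy},\\[4pt]
\text{Precision}_{\text{micro}}
  &= \frac{\sum_{i=1}^{k} TP_i}{\sum_{i=1}^{k} \left(TP_i + FP_i\right)}
   = \frac{1}{N}\sum_{i=1}^{k} TP_i
   = \text{Accuracy}.
\end{aligned}
```

Classes with $n_i = 0$ carry zero weight in the first sum, so including or excluding them does not change the weighted recall.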

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
