Skip to content

Commit 8fa2b16

Browse files
committed
fix: complete the eval metrics Truth_Ratio calculation mentioned in the paper
1 parent 833ccae commit 8fa2b16

3 files changed

Lines changed: 20 additions & 3 deletions

File tree

configs/eval/tofu.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
defaults: # include all defined metrics files
55
- tofu_metrics: # When you import a metric here, its configuration automatically populates the
66
# metric key below, enabled by the @package directive at the top of each configuration file.
7+
- forget_Truth_Ratio
78
- forget_quality
89
- forget_Q_A_Prob
910
- forget_Q_A_ROUGE

configs/eval/tofu_metrics/forget_Truth_Ratio.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ pre_compute:
1010
access_key: wrong
1111

1212
handler: truth_ratio
13-
aggregator: closer_to_1_better
13+
aggregator: prob_mean

src/evals/metrics/memorization.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,21 @@ def closer_to_1_better(arr):
118118
# 1-tr is higher.
119119
def true_better(arr):
120120
return np.mean(np.maximum(0, 1 - arr))
121+
122+
# NEW: Use correctness probability: correct / (correct + wrong), higher is better
123+
def prob_mean(arr):
124+
# arr here will be the new truth_ratios = correct / (correct + wrong)
125+
return np.mean(arr)
121126

122127
if kwargs["aggregator"] == "closer_to_1_better":
128+
use_original_ratio = True
123129
aggregator = closer_to_1_better
124130
elif kwargs["aggregator"] == "true_better":
131+
use_original_ratio = True
125132
aggregator = true_better
133+
elif kwargs["aggregator"] == "prob_mean":
134+
aggregator = prob_mean
135+
use_original_ratio = False
126136
else:
127137
raise ValueError(f"Invalid truth ratio aggregator: {kwargs['aggregator']}")
128138

@@ -152,8 +162,14 @@ def true_better(arr):
152162

153163
correct_prob = np.exp(-correct_avg_losses)
154164
wrong_prob = np.exp(-wrong_avg_losses)
155-
156-
truth_ratios = wrong_prob / (correct_prob + 1e-10)
165+
166+
if use_original_ratio:
167+
# Original definition: wrong / correct
168+
truth_ratios = wrong_prob / (correct_prob + 1e-10)
169+
else:
170+
# New definition: correct / (correct + wrong)
171+
truth_ratios = correct_prob / (correct_prob + wrong_prob + 1e-10)
172+
157173
value_by_index = dict(
158174
zip(correct_indices, [{"score": val} for val in truth_ratios])
159175
)

0 commit comments

Comments (0)