Skip to content

Commit 7d5e098

Browse files
authored
Update wga.py
1 parent 04529e5 commit 7d5e098

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/trainer/unlearn/wga.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, beta=1.0, gamma=1.0, alpha=1.0, *args, **kwargs):
1111
if self.ref_model is None:
1212
self.ref_model = self._prepare_ref_model(self.model)
1313

14-
def compute_loss(self, model, inputs, return_outputs=False):
14+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
1515
forget_inputs = inputs["forget"]
1616
forget_inputs = {
1717
"input_ids": forget_inputs["input_ids"],

0 commit comments

Comments
 (0)