Skip to content

Commit d19226b

Browse files
committed
interpret box/unbox in deopts
1 parent b241a2e commit d19226b

2 files changed

Lines changed: 170 additions & 6 deletions

File tree

server/src/main/java/org/prlprg/fir/interpret/internal/InternalInterpreter.java

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package org.prlprg.fir.interpret.internal;
22

3+
import static org.prlprg.fir.GlobalModules.BOX_FUN;
4+
import static org.prlprg.fir.GlobalModules.UNBOX_FUN;
35
import static org.prlprg.sexp.ArgumentMatcher.matchArguments;
46

57
import com.google.common.collect.ImmutableList;
68
import java.util.HashMap;
9+
import java.util.LinkedHashMap;
710
import java.util.List;
811
import java.util.Map;
912
import java.util.Objects;
@@ -1195,7 +1198,8 @@ case Consume(var variable) -> {
11951198
// Reverse-evaluate from deopt to checkpoint
11961199
cursor = new CFGCursor(deoptBc, deoptBc.statements().size() - 1);
11971200
while (cursor.instructionIndex() >= 0) {
1198-
var expression = ((Statement) Objects.requireNonNull(cursor.instruction())).expression();
1201+
var stmt = (Statement) Objects.requireNonNull(cursor.instruction());
1202+
var expression = stmt.expression();
11991203
switch (expression) {
12001204
case MkEnv() -> {
12011205
try {
@@ -1206,6 +1210,51 @@ case MkEnv() -> {
12061210
}
12071211
}
12081212
case Store(var storeType, _, _) when storeType == StoreType.LOCAL_VAR -> {}
1213+
case Call call when stmt.assignee() != null && isReversiblePureFun(call) -> {
1214+
var assigneeValue = topFrame().get(stmt.assignee());
1215+
if (assigneeValue == null) {
1216+
throw fail("deopt box/unbox assignee is uninitialized: " + stmt.assignee());
1217+
}
1218+
1219+
var inverseCall =
1220+
switch (call.callee()) {
1221+
case StaticFnCallee(var isDispatch, var functionRef, var signature)
1222+
when !isDispatch && functionRef.get() == BOX_FUN ->
1223+
new Call(
1224+
new StaticFnCallee(
1225+
false,
1226+
UNBOX_FUN,
1227+
new Signature(
1228+
ImmutableList.of(signature.returnType()),
1229+
signature.parameterTypes().getFirst(),
1230+
signature.effects())),
1231+
ImmutableList.of(new Constant(assigneeValue)));
1232+
case StaticFnCallee(var isDispatch, var functionRef, var signature)
1233+
when !isDispatch && functionRef.get() == UNBOX_FUN ->
1234+
new Call(
1235+
new StaticFnCallee(
1236+
false,
1237+
BOX_FUN,
1238+
new Signature(
1239+
ImmutableList.of(signature.returnType()),
1240+
signature.parameterTypes().getFirst(),
1241+
signature.effects())),
1242+
ImmutableList.of(new Constant(assigneeValue)));
1243+
default -> throw new UnreachableError();
1244+
};
1245+
var argumentRegister =
1246+
switch (call.callArguments().getFirst()) {
1247+
case Read(var register) -> register;
1248+
case Consume(var register) -> register;
1249+
default ->
1250+
throw fail(
1251+
"deopt box/unbox argument must be a register, got: "
1252+
+ call.callArguments().getFirst());
1253+
};
1254+
var argumentValue = Objects.requireNonNull(run(null, inverseCall));
1255+
topFrame().put(argumentRegister, argumentValue);
1256+
recordTypeFeedback(topFrame().scopeFeedback(), argumentRegister, argumentValue);
1257+
}
12091258
default ->
12101259
throw fail(
12111260
"unexpected expression in deopt branch: "
@@ -1229,12 +1278,28 @@ private DeoptSnapshot snapshotAtCheckpoint(Target deopt) {
12291278
}
12301279

12311280
var env = topFrame().environment().deepCopyUserEnvs();
1281+
var localRegs = new LinkedHashMap<Register, Value>();
12321282
for (var i = 0; i < deopt.bb().statements().size(); i++) {
12331283
var stmt = deopt.bb().statements().get(i);
12341284
switch (stmt.expression()) {
12351285
case MkEnv() -> env = new UserEnvSXP(env);
1286+
case Call call when stmt.assignee() != null && isReversiblePureFun(call) ->
1287+
localRegs.put(
1288+
stmt.assignee(),
1289+
Objects.requireNonNull(
1290+
run(
1291+
null,
1292+
call.mapArguments(
1293+
arg ->
1294+
switch (arg) {
1295+
case Read(var register) when localRegs.containsKey(register) ->
1296+
new Constant(localRegs.get(register));
1297+
case Consume(var register) when localRegs.containsKey(register) ->
1298+
new Constant(localRegs.get(register));
1299+
default -> arg;
1300+
}))));
12361301
case Store(var storeType, var variable, var arg) when storeType == StoreType.LOCAL_VAR -> {
1237-
var value = run(arg);
1302+
var value = runInSnapshotDeopt(arg, localRegs);
12381303
if (!(value instanceof Value.Sexp(var valueSexp))) {
12391304
throw fail("Can't store non-SEXP in environment: " + value);
12401305
}
@@ -1244,16 +1309,13 @@ case Store(var storeType, var variable, var arg) when storeType == StoreType.LOC
12441309
throw fail(
12451310
"Unsupported expression in deopt branch at index " + i + ": " + stmt.expression());
12461311
}
1247-
if (stmt.assignee() != null) {
1248-
throw fail("Deopt branch statement has assignee at index " + i + ": " + stmt);
1249-
}
12501312
}
12511313

12521314
var deoptJump = (Deopt) deopt.bb().jump();
12531315
var pc = deoptJump.pc();
12541316
var bcStack =
12551317
deoptJump.stack().stream()
1256-
.map(this::run)
1318+
.map(arg -> runInSnapshotDeopt(arg, localRegs))
12571319
.map(
12581320
value -> {
12591321
if (!(value instanceof Value.Sexp(var sexp))) {
@@ -1266,6 +1328,21 @@ case Store(var storeType, var variable, var arg) when storeType == StoreType.LOC
12661328
return new DeoptSnapshot(pc, bcStack, env, stackToString());
12671329
}
12681330

1331+
private Value runInSnapshotDeopt(Argument arg, LinkedHashMap<Register, Value> localRegs) {
1332+
return switch (arg) {
1333+
case Read(var register) when localRegs.containsKey(register) -> localRegs.get(register);
1334+
case Consume(var register) when localRegs.containsKey(register) -> localRegs.get(register);
1335+
default -> run(arg);
1336+
};
1337+
}
1338+
1339+
private boolean isReversiblePureFun(Call call) {
1340+
return call.callee() instanceof StaticFnCallee(var isDispatch, var functionRef, _)
1341+
&& !isDispatch
1342+
&& call.callArguments().size() == 1
1343+
&& (functionRef.get() == BOX_FUN || functionRef.get() == UNBOX_FUN);
1344+
}
1345+
12691346
private void recordTypeFeedback(AbstractionFeedback feedback, Register variable, Value value) {
12701347
feedback.recordType(variable, inferType(value, Ownership.SHARED));
12711348
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package org.prlprg.fir.interpret;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertSame;
5+
import static org.prlprg.fir.interpret.internal.Builtins.registerBuiltins;
6+
7+
import org.junit.jupiter.api.Test;
8+
import org.prlprg.fir.interpret.internal.InternalInterpreter;
9+
import org.prlprg.fir.ir.ParseUtil;
10+
import org.prlprg.fir.ir.value.Value;
11+
import org.prlprg.sexp.SEXPs;
12+
13+
/// Test [InternalInterpreter] behavior with larger, parsed modules.
14+
class InternalInterpretTest {
15+
@Test
16+
void deoptRestoreReverseEvaluatesBoxAndUnbox() {
17+
var module =
18+
ParseUtil.parseModule(
19+
"""
20+
fun main(x) {
21+
(reg x:v1(I)) --> v1(I) { reg xi:I, reg result:v1(I) |
22+
xi = unbox< v1(I) --> I >(x);
23+
result = f< I --> v1(I) >(xi);
24+
return result;
25+
}
26+
}
27+
28+
fun f(x) {
29+
(reg x:v1(I)) --> v1(I) { reg checked:v1(I), reg i:I, reg roundTrip:v1(I) |
30+
check Ok() else Deopt();
31+
Ok():
32+
checked = x ?: v1(I);
33+
return checked;
34+
Deopt():
35+
i = unbox< v1(I) --> I >(x);
36+
roundTrip = box< I --> v1(I) >(i);
37+
deopt 0 [roundTrip];
38+
}
39+
(reg x:I) --> v1(I) { reg checked:v1(I), reg roundTrip:v1(I) |
40+
check Ok() else Deopt();
41+
Ok():
42+
checked = x ?: v1(I);
43+
return checked;
44+
Deopt():
45+
roundTrip = box< I --> v1(I) >(x);
46+
deopt 0 [roundTrip];
47+
}
48+
}
49+
""");
50+
var interpreter = new InternalInterpreter(module);
51+
registerBuiltins(interpreter);
52+
53+
var result = interpreter.call("main", new Value.Sexp(SEXPs.integer(1)));
54+
55+
assertEquals(new Value.Sexp(SEXPs.integer(1)), result);
56+
}
57+
58+
@Test
59+
void checkpointSnapshotKeepsDeoptOnlyBoxedRegistersLocal() {
60+
var module =
61+
ParseUtil.parseModule(
62+
"""
63+
fun main(x) {
64+
(reg x:I) --> v1(I) { reg boxed:v1(I), var y:v1(I) |
65+
check Ok() else Deopt();
66+
Ok():
67+
boxed = box< I --> v1(I) >(x);
68+
return boxed;
69+
Deopt():
70+
mkenv;
71+
boxed = box< I --> v1(I) >(x);
72+
st y = boxed;
73+
deopt 0 [];
74+
}
75+
}
76+
""");
77+
var interpreter = new InternalInterpreter(module);
78+
registerBuiltins(interpreter);
79+
80+
var snapshots =
81+
interpreter.checkpointTrace().track(() -> interpreter.call("main", new Value.Int(1)));
82+
83+
assertEquals(1, snapshots.size());
84+
assertEquals(SEXPs.integer(1), snapshots.getFirst().env().getLocal("y").orElseThrow());
85+
assertSame(snapshots.getFirst().env().parent(), interpreter.globalEnv());
86+
}
87+
}

0 commit comments

Comments
 (0)