Skip to content

Commit 426308d

Browse files
committed
testDeopt works
TODO: figure out why the earlier version didn't, in case it's important
1 parent 910b28e commit 426308d

6 files changed

Lines changed: 150 additions & 40 deletions

File tree

server/src/main/java/org/prlprg/fir/interpret/internal/Builtins.java

Lines changed: 115 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,12 @@ private static SEXP applyBinaryPreservingInt(
383383
int len = Math.max(i1.size(), i2.size());
384384
return SEXPs.integer(
385385
IntStream.range(0, len)
386-
.map(i -> (int) fn.applyAsDouble(i1.get(i % i1.size()), i2.get(i % i2.size())))
386+
.map(
387+
i ->
388+
i1.get(i % i1.size()) == Constants.NA_INT
389+
|| i2.get(i % i2.size()) == Constants.NA_INT
390+
? Constants.NA_INT
391+
: (int) fn.applyAsDouble(i1.get(i % i1.size()), i2.get(i % i2.size())))
387392
.toArray());
388393
}
389394
case RealSXP r1 when s2 instanceof RealSXP r2 -> {
@@ -393,23 +398,94 @@ private static SEXP applyBinaryPreservingInt(
393398
.mapToDouble(i -> fn.applyAsDouble(r1.get(i % r1.size()), r2.get(i % r2.size())))
394399
.toArray());
395400
}
401+
case LglSXP l1 when s2 instanceof LglSXP l2 -> {
402+
int len = Math.max(l1.size(), l2.size());
403+
return SEXPs.integer(
404+
IntStream.range(0, len)
405+
.map(
406+
i ->
407+
l1.get(i % l1.size()) == Logical.NA || l2.get(i % l2.size()) == Logical.NA
408+
? Constants.NA_INT
409+
: (int)
410+
fn.applyAsDouble(
411+
l1.get(i % l1.size()).toInt(), l2.get(i % l2.size()).toInt()))
412+
.toArray());
413+
}
396414
case IntSXP i1 when s2 instanceof RealSXP r2 -> {
397415
int len = Math.max(i1.size(), r2.size());
398416
return SEXPs.real(
399417
IntStream.range(0, len)
400-
.mapToDouble(i -> fn.applyAsDouble(i1.get(i % i1.size()), r2.get(i % r2.size())))
418+
.mapToDouble(
419+
i ->
420+
i1.get(i % i1.size()) == Constants.NA_INT
421+
? Double.NaN
422+
: fn.applyAsDouble(i1.get(i % i1.size()), r2.get(i % r2.size())))
423+
.toArray());
424+
}
425+
case IntSXP i1 when s2 instanceof LglSXP l2 -> {
426+
int len = Math.max(i1.size(), l2.size());
427+
return SEXPs.real(
428+
IntStream.range(0, len)
429+
.mapToDouble(
430+
i ->
431+
i1.get(i % i1.size()) == Constants.NA_INT
432+
|| l2.get(i % l2.size()) == Logical.NA
433+
? Double.NaN
434+
: fn.applyAsDouble(
435+
i1.get(i % i1.size()), l2.get(i % l2.size()).toInt()))
401436
.toArray());
402437
}
403438
case RealSXP r1 when s2 instanceof IntSXP i2 -> {
404439
int len = Math.max(r1.size(), i2.size());
405440
return SEXPs.real(
406441
IntStream.range(0, len)
407-
.mapToDouble(i -> fn.applyAsDouble(r1.get(i % r1.size()), i2.get(i % i2.size())))
442+
.mapToDouble(
443+
i ->
444+
i2.get(i % i2.size()) == Constants.NA_INT
445+
? Double.NaN
446+
: fn.applyAsDouble(r1.get(i % r1.size()), i2.get(i % i2.size())))
447+
.toArray());
448+
}
449+
case RealSXP r1 when s2 instanceof LglSXP l2 -> {
450+
int len = Math.max(r1.size(), l2.size());
451+
return SEXPs.real(
452+
IntStream.range(0, len)
453+
.mapToDouble(
454+
i ->
455+
l2.get(i % l2.size()) == Logical.NA
456+
? Double.NaN
457+
: fn.applyAsDouble(
458+
r1.get(i % r1.size()), l2.get(i % l2.size()).toInt()))
459+
.toArray());
460+
}
461+
case LglSXP l1 when s2 instanceof IntSXP i2 -> {
462+
int len = Math.max(l1.size(), i2.size());
463+
return SEXPs.real(
464+
IntStream.range(0, len)
465+
.mapToDouble(
466+
i ->
467+
l1.get(i % l1.size()) == Logical.NA
468+
|| i2.get(i % i2.size()) == Constants.NA_INT
469+
? Double.NaN
470+
: fn.applyAsDouble(
471+
l1.get(i % l1.size()).toInt(), i2.get(i % i2.size())))
472+
.toArray());
473+
}
474+
case LglSXP l1 when s2 instanceof RealSXP r2 -> {
475+
int len = Math.max(l1.size(), r2.size());
476+
return SEXPs.real(
477+
IntStream.range(0, len)
478+
.mapToDouble(
479+
i ->
480+
l1.get(i % l1.size()) == Logical.NA || r2.get(i % r2.size()).isNaN()
481+
? Double.NaN
482+
: fn.applyAsDouble(
483+
l1.get(i % l1.size()).toInt(), r2.get(i % r2.size())))
408484
.toArray());
409485
}
410486
default -> {}
411487
}
412-
throw interpreter.fail("`" + ctx + "` generic requires numeric args");
488+
throw interpreter.fail("`" + ctx + "` generic requires logical or numeric args");
413489
}
414490

415491
/// Apply unary math op on SEXP, preserving int type (for +, -)
@@ -422,14 +498,29 @@ private static SEXP applyUnaryPreservingInt(
422498
return SEXPs.real(fn.applyAsDouble(s.asScalarReal().get()));
423499
}
424500
// Vector operations
425-
if (s instanceof IntSXP iv) {
426-
return SEXPs.integer(
427-
IntStream.range(0, iv.size()).map(i -> (int) fn.applyAsDouble(iv.get(i))).toArray());
428-
} else if (s instanceof RealSXP rv) {
429-
return SEXPs.real(
430-
IntStream.range(0, rv.size()).mapToDouble(i -> fn.applyAsDouble(rv.get(i))).toArray());
431-
}
432-
throw interpreter.fail("`" + ctx + "` unary requires a numeric arg");
501+
return switch (s) {
502+
case IntSXP iv ->
503+
SEXPs.integer(
504+
IntStream.range(0, iv.size())
505+
.map(
506+
i ->
507+
iv.get(i) == Constants.NA_INT
508+
? Constants.NA_INT
509+
: (int) fn.applyAsDouble(iv.get(i)))
510+
.toArray());
511+
case RealSXP rv ->
512+
SEXPs.real(
513+
IntStream.range(0, rv.size())
514+
.mapToDouble(i -> fn.applyAsDouble(rv.get(i)))
515+
.toArray());
516+
case LglSXP lv ->
517+
SEXPs.integer(
518+
IntStream.range(0, lv.size())
519+
.map(i -> lv.get(i) == Logical.NA ? Constants.NA_INT : lv.get(i).toInt())
520+
.map(j -> (int) fn.applyAsDouble(j))
521+
.toArray());
522+
default -> throw interpreter.fail("`" + ctx + "` unary requires a logical or numeric arg");
523+
};
433524
}
434525

435526
private static void registerBinaryMathToRealBuiltin(
@@ -2488,22 +2579,31 @@ private static Value sexpToValueOfType(
24882579
}
24892580

24902581
private static double sexpToDouble(SEXP sexp, InternalInterpreter interpreter, String ctx) {
2491-
if (sexp.asScalarInteger().isPresent()) return sexp.asScalarInteger().get();
2582+
if (sexp.asScalarInteger().isPresent())
2583+
return sexp.asScalarInteger().get() == Constants.NA_INT
2584+
? Double.NaN
2585+
: sexp.asScalarInteger().get();
24922586
if (sexp.asScalarReal().isPresent()) return sexp.asScalarReal().get();
24932587
if (sexp.asScalarLogical().isPresent()) return sexp.asScalarLogical().get().toInt();
24942588
throw interpreter.fail(ctx + " requires a numeric scalar");
24952589
}
24962590

24972591
private static @Nullable Double sexpToDoubleOpt(SEXP sexp) {
24982592
if (sexp.asScalarReal().isPresent()) return sexp.asScalarReal().get();
2499-
if (sexp.asScalarInteger().isPresent()) return (double) sexp.asScalarInteger().get();
2593+
if (sexp.asScalarInteger().isPresent())
2594+
return sexp.asScalarInteger().get() == Constants.NA_INT
2595+
? Double.NaN
2596+
: (double) sexp.asScalarInteger().get();
25002597
if (sexp.asScalarLogical().isPresent()) return (double) sexp.asScalarLogical().get().toInt();
25012598
return null;
25022599
}
25032600

25042601
private static int sexpToInt(SEXP sexp, InternalInterpreter interpreter, String ctx) {
25052602
if (sexp.asScalarInteger().isPresent()) return sexp.asScalarInteger().get();
2506-
if (sexp.asScalarReal().isPresent()) return (int) sexp.asScalarReal().get().doubleValue();
2603+
if (sexp.asScalarReal().isPresent())
2604+
return sexp.asScalarReal().get().isNaN()
2605+
? Constants.NA_INT
2606+
: (int) sexp.asScalarReal().get().doubleValue();
25072607
if (sexp.asScalarLogical().isPresent()) return sexp.asScalarLogical().get().toInt();
25082608
throw interpreter.fail(ctx + " requires a numeric scalar");
25092609
}

server/src/main/java/org/prlprg/fir/interpret/internal/InternalInterpreter.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ private void checkStack() {
462462

463463
/// Executes a statement instruction.
464464
private void run(Statement statement) {
465-
System.out.println("R: " + statement);
466465
var assignee = statement.assignee();
467466
var value = run(assignee, statement.expression());
468467

@@ -481,7 +480,6 @@ private void run(Statement statement) {
481480

482481
/// Executes a jump instruction and returns the next control-flow action.
483482
private ControlFlow run(Jump jump) {
484-
System.out.println("R: " + jump);
485483
return switch (jump) {
486484
case Goto(_, var next) -> new ControlFlow.Goto(next);
487485
case If(_, var condition, var ifTrue, var ifFalse) -> {

server/src/test/java/org/prlprg/snapshot/SnapshotStore.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ public SnapshotStore() {}
3232
/// - OR [Query#verifyNoRegression(Object, Object, Example, SnapshotStore)] fails, and
3333
/// [TestConfig#IGNORE_SNAPSHOTS] is unset.
3434
/// If other checks (e.g. [Query#verifyExtra(Object, Example, SnapshotStore)]) fail, they are
35-
// reported
36-
/// (this call raises an exception) but the snapshot is still saved for debugging.
35+
/// reported (this call raises an exception) but the snapshot is still saved for debugging.
3736
///
3837
/// Returns the actual computed value (and only returns if it passes verification).
3938
public <T> T verify(Example example, Query<T> query) {
@@ -133,8 +132,17 @@ private <T> void verifyNoRegression(
133132
/// For when verification is expected to fail in some cases, and in those cases the rest of
134133
/// the enclosing test is expected to not necessarily work either (abort != fail).
135134
public <T> void assumeVerify(Example example, Query<T> query, T actual) {
135+
assumeVerify(example, query, actual, null);
136+
}
137+
138+
/// Same as [#verify(Example, Query, Object, String)] but aborts the test instead of failing.
139+
///
140+
/// For when verification is expected to fail in some cases, and in those cases the rest of
141+
/// the enclosing test is expected to not necessarily work either (abort != fail).
142+
public <T> void assumeVerify(
143+
Example example, Query<T> query, T actual, @Nullable String context) {
136144
try {
137-
verify(example, query, actual);
145+
verify(example, query, actual, context);
138146
} catch (AssertionError e) {
139147
throw new TestAbortedException("Verification failed for " + query.name(), e);
140148
}

server/src/test/java/org/prlprg/snapshot/fir/interpret/InterpretAfterOptTest.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void testRepeat(Example example, SnapshotStore store) {
5353
void testDeopt(Example example, SnapshotStore store) {
5454
var optimization = optimizations();
5555

56-
var module = store.load(example, FirQuery.INSTANCE);
56+
var module = store.load(example, new OptimizedFirQuery(optimization));
5757
var interpreter = new TestInterpreter(module);
5858

5959
var deoptFnName =
@@ -74,54 +74,54 @@ void testDeopt(Example example, SnapshotStore store) {
7474
new AssertionError(
7575
"deopt example doesn't start by declaring a function (the deopt function)")));
7676

77-
// Warmup
78-
for (int i = 1; i <= 3; i++) {
79-
store.assumeVerify(example, InterpretQuery.MAIN, interpreter.call("main"));
80-
}
81-
optimization.run(interpreter.feedback(), module);
82-
store.assumeVerify(example, InterpretQuery.MAIN, interpreter.call("main"));
77+
// Module is already optimized, but make sure it runs
78+
store.assumeVerify(
79+
example,
80+
InterpretQuery.deopt_int(deoptFnName),
81+
interpreter.call(deoptFnName, SEXPs.integer(1)),
82+
"run before deopts");
8383

8484
// Test
8585

8686
for (int i = 1; i <= 3; i++) {
8787
store.verify(
8888
example,
89-
InterpretQuery.DEOPT_REAL,
89+
InterpretQuery.deopt_real(deoptFnName),
9090
interpreter.call(deoptFnName, SEXPs.real(1)),
9191
"phase 1 run " + i);
9292
}
9393
optimization.run(interpreter.feedback(), module);
9494
store.verify(
9595
example,
96-
InterpretQuery.DEOPT_REAL,
96+
InterpretQuery.deopt_real(deoptFnName),
9797
interpreter.call(deoptFnName, SEXPs.real(1)),
9898
"phase 1 post-opt run");
9999

100100
for (int i = 1; i <= 3; i++) {
101101
store.verify(
102102
example,
103-
InterpretQuery.DEOPT_INT,
103+
InterpretQuery.deopt_int(deoptFnName),
104104
interpreter.call(deoptFnName, SEXPs.integer(1)),
105105
"phase 2 run " + i);
106106
}
107107
optimization.run(interpreter.feedback(), module);
108108
store.verify(
109109
example,
110-
InterpretQuery.DEOPT_INT,
110+
InterpretQuery.deopt_int(deoptFnName),
111111
interpreter.call(deoptFnName, SEXPs.integer(1)),
112112
"phase 2 post-opt run");
113113

114114
for (int i = 1; i <= 3; i++) {
115115
store.verify(
116116
example,
117-
InterpretQuery.DEOPT_LGL,
117+
InterpretQuery.deopt_lgl(deoptFnName),
118118
interpreter.call(deoptFnName, SEXPs.TRUE),
119119
"phase 3 run " + i);
120120
}
121121
optimization.run(interpreter.feedback(), module);
122122
store.verify(
123123
example,
124-
InterpretQuery.DEOPT_LGL,
124+
InterpretQuery.deopt_lgl(deoptFnName),
125125
interpreter.call(deoptFnName, SEXPs.TRUE),
126126
"phase 3 post-opt run");
127127
}

server/src/test/java/org/prlprg/snapshot/fir/interpret/InterpretQuery.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,17 @@ public record InterpretQuery(@Override String name, String functionName, SEXP...
2222
implements Query<InterpretOutput> {
2323
public static final InterpretQuery MAIN = new InterpretQuery("interpret", "main");
2424

25-
public static final InterpretQuery DEOPT_INT =
26-
new InterpretQuery("interpret.deopt.integer", "f", SEXPs.integer(1));
27-
public static final InterpretQuery DEOPT_REAL =
28-
new InterpretQuery("interpret.deopt.real", "f", SEXPs.real(1));
29-
public static final InterpretQuery DEOPT_LGL =
30-
new InterpretQuery("interpret.deopt.logical", "f", SEXPs.TRUE);
25+
public static InterpretQuery deopt_int(String deoptFnName) {
26+
return new InterpretQuery("interpret.deopt.integer", deoptFnName, SEXPs.integer(1));
27+
}
28+
29+
public static InterpretQuery deopt_real(String deoptFnName) {
30+
return new InterpretQuery("interpret.deopt.real", deoptFnName, SEXPs.real(1));
31+
}
32+
33+
public static InterpretQuery deopt_lgl(String deoptFnName) {
34+
return new InterpretQuery("interpret.deopt.logical", deoptFnName, SEXPs.TRUE);
35+
}
3136

3237
@Override
3338
public InterpretOutput compute(Example example, SnapshotStore store) {

server/src/test/java/org/prlprg/snapshot/fir/interpret/TestInterpreter.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ public TestInterpreter(Module module) {
2525
}
2626

2727
public InterpretOutput call(String functionName, SEXP... arguments) {
28-
System.out.println("Test " + functionName + "(" + Strings.join(", ", arguments) + ")");
2928
@SuppressWarnings({"DataFlowIssue"})
3029
SexpResult[] result = new SexpResult[] {null};
3130
var checkpointTrace =

0 commit comments

Comments
 (0)