Skip to content

Commit 549e9bb

Browse files
committed
Fix #96
1 parent 7f4356d commit 549e9bb

5 files changed

Lines changed: 161 additions & 130 deletions

File tree

src/translation.ml

Lines changed: 149 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,12 +1319,85 @@ let build_extended_tvar_names sig_indices sig_names body_tvars =
13191319
else
13201320
sig_names
13211321

1322+
(* Walk an ML AST and collect the source-order indices of parameters that are
   NOT simply forwarded unchanged at some self-recursive call site.

   [is_self_call depth f] must return true when the head [f] of an application
   is a self-recursive reference, given that [depth] binders have been entered
   since the start of [body]. [n_params] is the number of leading parameters
   of the function; [body] is the function body with those lambdas already
   stripped (e.g. by collect_lams).

   After collect_lams, param source index [i] has de Bruijn index
   [n_params - i] at depth 0, shifted by [depth] under binders. An argument
   counts as forwarded only when it is exactly that variable, optionally
   wrapped in a single [MLmagic].

   The arguments of a self-call are themselves traversed, so recursive calls
   nested inside an argument (for example inside a continuation lambda, or a
   self-call used directly as an argument) are inspected as well.

   Returns the non-forwarded indices without duplicates, in unspecified
   order. *)
let detect_non_forwarded_params_generic ~is_self_call n_params body =
  let non_fwd = Hashtbl.create 4 in
  let is_forwarded depth i arg =
    let expected_db = n_params - i + depth in
    match arg with
    | MLmagic (MLrel db) | MLrel db -> db = expected_db
    | _ -> false
  in
  let rec walk depth = function
    | MLapp (f, args) when is_self_call depth f ->
      List.iteri
        (fun i arg ->
          (* Only the first [n_params] positions correspond to parameters;
             any further arguments are applied to the call's result. *)
          if i < n_params && not (is_forwarded depth i arg) then
            Hashtbl.replace non_fwd i true;
          (* Keep descending: the argument may itself contain another
             self-recursive call whose arguments must be checked too. Args
             are not under any new binder, so [depth] is unchanged. *)
          walk depth arg )
        args
    | MLapp (f, args) ->
      walk depth f;
      List.iter (walk depth) args
    | MLlam (_, _, body) -> walk (depth + 1) body
    | MLletin (_, _, e1, e2) ->
      walk depth e1;
      (* The let-bound variable is in scope only in [e2]. *)
      walk (depth + 1) e2
    | MLcase (_, scrut, branches) ->
      walk depth scrut;
      Array.iter
        (fun (ids, _, _, body) ->
          (* Each pattern variable of the branch adds one binder. *)
          walk (depth + List.length ids) body )
        branches
    | MLcons (_, _, args) -> List.iter (walk depth) args
    | MLtuple args -> List.iter (walk depth) args
    | MLfix (_, _, bodies, _) ->
      (* All functions of the nested fix are bound inside every body. *)
      let n = Array.length bodies in
      Array.iter (walk (depth + n)) bodies
    | MLmagic e -> walk depth e
    | MLparray (elts, def) ->
      Array.iter (walk depth) elts;
      walk depth def
    | MLrel _
     |MLglob _
     |MLexn _
     |MLdummy _
     |MLaxiom _
     |MLuint _
     |MLfloat _
     |MLstring _ -> ()
  in
  walk 0 body;
  Hashtbl.fold (fun k _ acc -> k :: acc) non_fwd []
1377+
1378+
(* Non-forwarded parameter detection for one function of a local fixpoint.

   Inside a local fix, a self-reference is an [MLrel]. Once collect_lams has
   stripped the [n_params] lambda parameters, the binding of function
   [fix_idx] (out of [n_fix] mutually recursive functions) sits at de Bruijn
   index [n_params + n_fix - fix_idx]; every binder entered below that point
   shifts it by one more.

   Returns the source-order indices of the parameters that are not forwarded
   unchanged at some self-recursive call in [body]. *)
let detect_non_forwarded_params_fix n_params n_fix fix_idx body =
  let is_self_call depth head =
    match head with
    | MLrel db -> db = n_params + n_fix - fix_idx + depth
    | _ -> false
  in
  detect_non_forwarded_params_generic ~is_self_call n_params body
1389+
13221390
(** Convert ML params to C++ types with const/ref wrapping, and create
13231391
forwarding-ref template parameters for function-typed params. convert_fn:
13241392
function to convert ml_type -> cpp_type (typically
13251393
convert_ml_type_to_cpp_type env Refset'.empty tvar_names) Returns
13261394
(cpp_params, all_temps_with_funs). *)
1327-
let build_lifted_cpp_params convert_fn base_temps params =
1395+
let build_lifted_cpp_params ?(non_fwd_source_indices = []) convert_fn base_temps params =
1396+
let n_total = List.length params in
1397+
(* Non-forwarded check in source order (for fun_tys, which iterates List.rev) *)
1398+
let is_non_fwd_source j = List.mem j non_fwd_source_indices in
1399+
(* Non-forwarded check in de Bruijn order (for cpp_params replacement) *)
1400+
let is_non_fwd_db j = List.mem (n_total - 1 - j) non_fwd_source_indices in
13281401
let cpp_params =
13291402
List.map
13301403
(fun (id, ty) ->
@@ -1338,7 +1411,7 @@ let build_lifted_cpp_params convert_fn base_temps params =
13381411
List.filter_map
13391412
(fun (x, ty, j) ->
13401413
match ty with
1341-
| Tmod (TMconst, Tfun (dom, cod_f)) ->
1414+
| Tmod (TMconst, Tfun (dom, cod_f)) when not (is_non_fwd_source j) ->
13421415
let cod_f = if is_cpp_unit_type cod_f then Tvoid else cod_f in
13431416
Some (x, TTfun (dom, cod_f), fun_tparam_id j)
13441417
| _ -> None )
@@ -1349,7 +1422,7 @@ let build_lifted_cpp_params convert_fn base_temps params =
13491422
List.mapi
13501423
(fun j (x, ty) ->
13511424
match ty with
1352-
| Tmod (TMconst, Tfun (_, _)) ->
1425+
| Tmod (TMconst, Tfun (_, _)) when not (is_non_fwd_db j) ->
13531426
(x, Tref (Tref (Tvar (0, Some (fun_tparam_id (n_params - j - 1))))))
13541427
| _ -> (x, ty) )
13551428
cpp_params
@@ -5224,6 +5297,7 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
52245297
(* Restore outer type vars *)
52255298
set_current_type_vars saved_tvars;
52265299
(* Build a lifted Dfundef for each fixpoint function (usually just one) *)
5300+
let n_fix = Array.length funs in
52275301
List.iteri
52285302
(fun i ((renamed_id, fix_ty), params, body) ->
52295303
let cpp_ty =
@@ -5234,8 +5308,19 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
52345308
| Tfun (dom, cod) -> (dom, cod)
52355309
| _ -> ([], cpp_ty)
52365310
in
5311+
(* Detect params that are not simply forwarded at recursive call
5312+
sites — these must keep std::function type to avoid infinite
5313+
recursive template instantiation. *)
5314+
let non_fwd_source_indices =
5315+
let lam_params, stripped_body =
5316+
Mlutil.collect_lams funs.(i)
5317+
in
5318+
detect_non_forwarded_params_fix
5319+
(List.length lam_params) n_fix i stripped_body
5320+
in
52375321
let cpp_params, all_temps_with_funs =
52385322
build_lifted_cpp_params
5323+
~non_fwd_source_indices
52395324
(convert_ml_type_to_cpp_type env Refset'.empty all_tvar_names)
52405325
all_temps
52415326
params
@@ -6000,6 +6085,7 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
60006085
in
60016086
set_current_type_vars saved_tvars;
60026087
(* Build lifted declarations *)
6088+
let n_fix = Array.length funs in
60036089
List.iteri
60046090
(fun i ((renamed_id, fix_ty), params, body) ->
60056091
let cpp_ty =
@@ -6014,8 +6100,16 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
60146100
| Tfun (dom, cod) -> (dom, cod)
60156101
| _ -> ([], cpp_ty)
60166102
in
6103+
let non_fwd_source_indices =
6104+
let lam_params, stripped_body =
6105+
Mlutil.collect_lams funs.(i)
6106+
in
6107+
detect_non_forwarded_params_fix
6108+
(List.length lam_params) n_fix i stripped_body
6109+
in
60176110
let cpp_params, all_temps_with_funs =
60186111
build_lifted_cpp_params
6112+
~non_fwd_source_indices
60196113
(convert_ml_type_to_cpp_type
60206114
env
60216115
Refset'.empty
@@ -7804,99 +7898,32 @@ and tvar_subst_stmt (tvars : Id.t list) (s : cpp_stmt) : cpp_stmt =
78047898
(tvar_subst_type tvars)
78057899
s
78067900

7807-
(** Detect function-typed parameter positions that receive a freshly constructed
7808-
lambda in a self-recursive call.
7809-
7810-
Higher-order function parameters are normally emitted as C++ template
7811-
parameters constrained with a [MapsTo] concept:
7812-
7813-
template <MapsTo<T1, unsigned int> F0, MapsTo<T1, shared_ptr<tree>, T1,
7814-
shared_ptr<tree>, T1> F1> static T1 tree_rect(F0 &&f, F1 &&f0, const
7815-
shared_ptr<tree> &t);
7816-
7817-
This is preferred because template parameters preserve the exact lambda type,
7818-
enabling the compiler to inline the call — there is no type-erasure overhead
7819-
as there would be with [std::function].
7820-
7821-
However, this breaks when a self-recursive function passes a *new* lambda at
7822-
a function-typed parameter position in its own recursive call. This is the
7823-
continuation-passing style (CPS) pattern:
7824-
7825-
template <MapsTo<unsigned int, unsigned int> F1> static unsigned int
7826-
fact_cps(unsigned int n, F1 &&k) { ... return fact_cps(n_, [&](unsigned int
7827-
r) { return k(n_ * r); }); }
7828-
7829-
Each recursive call wraps [k] inside a fresh lambda with a unique type.
7830-
Because [F1] is a template parameter, the compiler must instantiate a new
7831-
specialization of [fact_cps] for every nesting depth. That creates an
7832-
infinite chain of template instantiations and the compiler rejects the
7833-
program.
7834-
7835-
The fix is to emit those specific parameters as [std::function] instead.
7836-
[std::function] is a concrete type — the continuation's type is always
7837-
[std::function<unsigned int(unsigned int)>] regardless of how many lambdas
7838-
are wrapped around it, so the recursive call resolves to the same function
7839-
and no new instantiation is needed:
7840-
7841-
static unsigned int fact_cps( unsigned int n, const std::function<unsigned
7842-
int(unsigned int)> k) { ... return fact_cps(n_, [&](unsigned int r) { return
7843-
k(n_ * r); }); }
7844-
7845-
Only parameters that actually exhibit this pattern are affected. For example,
7846-
in [partition_cps p l k], the predicate [p] is passed unchanged to the
7847-
recursive call ([partition_cps p rest (fun ...)]), so [p] stays as a template
7848-
parameter while [k] becomes [std::function]. Similarly, non-recursive
7849-
higher-order functions like [tree_rect] never trigger this issue and keep
7850-
full template parameters throughout.
7851-
7852-
The detection works by walking the function body's ML AST, looking for
7853-
self-recursive calls [MLapp(MLglob(self_ref, _), args)]. For each such call,
7854-
we check which argument positions contain a lambda ([MLlam]). Those positions
7855-
are the CPS parameters that must use [std::function]. *)
7856-
let detect_cps_params (self_ref : GlobRef.t) (n_params : int) (body : ml_ast) :
7857-
int list =
7858-
let cps_set = Hashtbl.create 4 in
7859-
let rec walk = function
7860-
| MLapp (MLglob (r, _), args)
7861-
when Environ.QGlobRef.equal Environ.empty_env r self_ref ->
7862-
List.iteri
7863-
(fun i arg ->
7864-
if i < n_params && contains_lambda arg then
7865-
Hashtbl.replace cps_set i true )
7866-
args
7867-
| MLapp (f, args) ->
7868-
walk f;
7869-
List.iter walk args
7870-
| MLlam (_, _, body) -> walk body
7871-
| MLletin (_, _, e1, e2) ->
7872-
walk e1;
7873-
walk e2
7874-
| MLcase (_, scrut, branches) ->
7875-
walk scrut;
7876-
Array.iter (fun (_, _, _, body) -> walk body) branches
7877-
| MLcons (_, _, args) -> List.iter walk args
7878-
| MLtuple args -> List.iter walk args
7879-
| MLfix (_, _, bodies, _) -> Array.iter walk bodies
7880-
| MLmagic e -> walk e
7881-
| MLparray (elts, def) ->
7882-
Array.iter walk elts;
7883-
walk def
7884-
| MLrel _
7885-
|MLglob _
7886-
|MLexn _
7887-
|MLdummy _
7888-
|MLaxiom _
7889-
|MLuint _
7890-
|MLfloat _
7891-
|MLstring _ -> ()
7892-
and contains_lambda = function
7893-
| MLlam _ -> true
7894-
| MLletin (_, _, _, body) -> contains_lambda body
7895-
| MLmagic e -> contains_lambda e
7896-
| _ -> false
7897-
in
7898-
walk body;
7899-
Hashtbl.fold (fun k _ acc -> k :: acc) cps_set []
7901+
(** [detect_non_forwarded_params self_ref n_params body] returns the
    source-order indices of the parameters of the global function [self_ref]
    that are NOT passed through unchanged at some self-recursive call site in
    [body].

    Background: function-typed parameters are normally emitted as C++ template
    parameters constrained with a [MapsTo] concept, which keeps the exact
    lambda type available for inlining. If a recursive call supplies a
    *different* expression (anything other than the parameter variable itself)
    in such a position, every recursion level would need a new template
    instantiation with a distinct type, i.e. infinite recursive template
    instantiation.

    Parameters reported here are therefore emitted as [std::function], a
    concrete type that does not change with wrapping. Parameters that ARE
    forwarded as-is (e.g. the predicate [p] in
    [partition_cps p rest (fun ...)]) remain template parameters, and
    non-recursive higher-order functions such as [tree_rect] are unaffected
    because they contain no self-recursive call.

    A call is treated as self-recursive when its head is an [MLglob] whose
    reference equals [self_ref]; the binder depth is irrelevant for global
    references. *)
let detect_non_forwarded_params (self_ref : GlobRef.t) (n_params : int)
    (body : ml_ast) : int list =
  let is_self_call _depth head =
    match head with
    | MLglob (r, _) -> Environ.QGlobRef.equal Environ.empty_env r self_ref
    | _ -> false
  in
  detect_non_forwarded_params_generic ~is_self_call n_params body
79007927

79017928
(** Generate a C++ function definition from an ML function body.
79027929
@@ -8078,25 +8105,25 @@ let gen_dfun n b cty ty temps =
80788105
else
80798106
ids
80808107
in
8081-
(* Detect which function-typed parameters are CPS parameters (see
8082-
[detect_cps_params] above for the full explanation). These are excluded
8083-
from template-parameter promotion below — they keep their [Tmod(TMconst,
8084-
Tfun(dom, cod))] type which prints as [const std::function<R(Args...)>].
8108+
(* Detect which function-typed parameters are NOT simply forwarded at
8109+
self-recursive call sites. These are excluded from template-parameter
8110+
promotion below — they keep their [Tmod(TMconst, Tfun(dom, cod))] type
8111+
which prints as [const std::function<R(Args...)>].
80858112
8086-
[detect_cps_params] returns source-order indices (param 0 = first Rocq
8087-
parameter). We need two index-checking helpers because the parameter list
8088-
[ids] is in de Bruijn order (innermost first = last source param first),
8089-
while [List.rev ids] is in source order:
8113+
[detect_non_forwarded_params] returns source-order indices (param 0 =
8114+
first Rocq parameter). We need two index-checking helpers because the
8115+
parameter list [ids] is in de Bruijn order (innermost first = last source
8116+
param first), while [List.rev ids] is in source order:
80908117
80918118
Source order (Rocq):   p0 p1 p2   indices 0, 1, 2
80928119
De Bruijn order (ids): p2 p1 p0   indices 0, 1, 2
80938120
8094-
So CPS source index [i] maps to de Bruijn index [n_ids - 1 - i]. *)
8095-
let cps_param_indices = detect_cps_params n (List.length ids) b in
8096-
let cps_set = IntSet.of_list cps_param_indices in
8097-
let is_cps_param_source i = IntSet.mem i cps_set in
8121+
So non-forwarded source index [i] maps to de Bruijn index [n_ids - 1 - i]. *)
8122+
let non_fwd_param_indices = detect_non_forwarded_params n (List.length ids) b in
8123+
let non_fwd_set = IntSet.of_list non_fwd_param_indices in
8124+
let is_non_fwd_param_source i = IntSet.mem i non_fwd_set in
80988125
let n_ids = List.length ids in
8099-
let is_cps_param_db i = IntSet.mem (n_ids - 1 - i) cps_set in
8126+
let is_non_fwd_param_db i = IntSet.mem (n_ids - 1 - i) non_fwd_set in
81008127
let all_params = missing @ ids in
81018128
(* Type class instance parameters become C++ template type parameters. We
81028129
assign unique names (_tcI0, _tcI1, ...) to avoid collision with: - User
@@ -8330,7 +8357,7 @@ let gen_dfun n b cty ty temps =
83308357
(x, wrapped) )
83318358
ids_with_owned
83328359
in
8333-
(* Promote non-CPS function-typed parameters to C++ template parameters.
8360+
(* Promote forwarded function-typed parameters to C++ template parameters.
83348361
83358362
Function-typed parameters (those with C++ type [Tmod(TMconst, Tfun(...))])
83368363
are normally promoted to template parameters with [MapsTo] concept
@@ -8348,20 +8375,20 @@ let gen_dfun n b cty ty temps =
83488375
[f0] unchanged — the template type stays the same at every recursion
83498376
depth.
83508377
8351-
CPS parameters are excluded from this promotion. A CPS parameter
8352-
receives a *new* lambda at each recursive call site, which means the
8378+
Non-forwarded parameters are excluded from this promotion. A parameter
8379+
that receives a *different* expression at a recursive call site means the
83538380
template type would be different at each recursion depth, causing
83548381
infinite template instantiation. These parameters keep their
83558382
[const std::function<R(Args...)>] type, which is a concrete
8356-
(non-template) type that stays the same regardless of lambda wrapping.
8383+
(non-template) type that stays the same regardless of wrapping.
83578384
83588385
For example, [partition_cps p l k] has three parameters:
8359-
- [p] is passed unchanged to the recursive call → template [F0 &&p]
8386+
- [p] is forwarded unchanged to the recursive call → template [F0 &&p]
83608387
- [l] is not function-typed → stays as-is
8361-
- [k] receives a new lambda at the recursive call → [const std::function<...> k]
8388+
- [k] receives a different expression at the recursive call → [const std::function<...> k]
83628389
83638390
This loop iterates [List.rev ids] which is in source order,
8364-
so we use [is_cps_param_source] for the CPS guard. *)
8391+
so we use [is_non_fwd_param_source] for the guard. *)
83658392
(* Determine which tvars are "primary" — deducible from non-function domain
83668393
params or the return type. Function-typed params that reference tvars
83678394
outside this set (e.g., erased HKT type variables) get TTtypename (no
@@ -8375,7 +8402,7 @@ let gen_dfun n b cty ty temps =
83758402
List.filter_map
83768403
(fun (x, ty, i) ->
83778404
match ty with
8378-
| Tmod (TMconst, Tfun (fdom, fcod)) when not (is_cps_param_source i) ->
8405+
| Tmod (TMconst, Tfun (fdom, fcod)) when not (is_non_fwd_param_source i) ->
83798406
let fun_idx = get_tvar_indices (Tfun (fdom, fcod)) in
83808407
let has_undeclared =
83818408
List.exists (fun idx -> not (IntSet.mem idx primary)) fun_idx
@@ -8388,16 +8415,16 @@ let gen_dfun n b cty ty temps =
83888415
| _ -> None )
83898416
(List.mapi (fun i (x, ty) -> (x, ty, i)) (List.rev ids))
83908417
in
8391-
(* Replace the parameter type of promoted (non-CPS) function params with the
8392-
template type variable [F&&]. CPS params are left untouched — they keep
8393-
[Tmod(TMconst, Tfun(dom, cod))] which prints as [const
8418+
(* Replace the parameter type of promoted (forwarded) function params with the
8419+
template type variable [F&&]. Non-forwarded params are left untouched — they
8420+
keep [Tmod(TMconst, Tfun(dom, cod))] which prints as [const
83948421
std::function<R(Args...)>]. This loop iterates [ids] which is in de Bruijn
8395-
order, so we use [is_cps_param_db] for the CPS guard. *)
8422+
order, so we use [is_non_fwd_param_db] for the guard. *)
83968423
let ids =
83978424
List.mapi
83988425
(fun i (x, ty) ->
83998426
match ty with
8400-
| Tmod (TMconst, Tfun (dom, cod)) when not (is_cps_param_db i) ->
8427+
| Tmod (TMconst, Tfun (dom, cod)) when not (is_non_fwd_param_db i) ->
84018428
( x,
84028429
Tref
84038430
(Tref (Tvar (0, Some (fun_tparam_id (List.length ids - i - 1)))))

0 commit comments

Comments
 (0)