@@ -1319,12 +1319,85 @@ let build_extended_tvar_names sig_indices sig_names body_tvars =
13191319 else
13201320 sig_names
13211321
1322+ (* Walk an ML AST and collect source-order parameter indices that are NOT
1323+ simply forwarded unchanged at recursive call sites. [is_self_call depth f]
1324+ returns true when the head [f] of an application is a self-recursive
1325+ reference at the given binder depth.
1326+
1327+ After collect_lams, param source index [i] has de Bruijn index
1328+ [n_params - i] at depth 0, shifted by [depth] under binders. *)
1329+ let detect_non_forwarded_params_generic ~is_self_call n_params body =
1330+ let non_fwd = Hashtbl. create 4 in
1331+ let is_forwarded depth i arg =
1332+ let expected_db = n_params - i + depth in
1333+ match arg with
1334+ | MLmagic (MLrel db ) | MLrel db -> db = expected_db
1335+ | _ -> false
1336+ in
1337+ let rec walk depth = function
1338+ | MLapp (f , args ) when is_self_call depth f ->
1339+ List. iteri
1340+ (fun i arg ->
1341+ if i < n_params && not (is_forwarded depth i arg) then
1342+ Hashtbl. replace non_fwd i true )
1343+ args
1344+ | MLapp (f , args ) ->
1345+ walk depth f;
1346+ List. iter (walk depth) args
1347+ | MLlam (_ , _ , body ) -> walk (depth + 1 ) body
1348+ | MLletin (_ , _ , e1 , e2 ) ->
1349+ walk depth e1;
1350+ walk (depth + 1 ) e2
1351+ | MLcase (_ , scrut , branches ) ->
1352+ walk depth scrut;
1353+ Array. iter
1354+ (fun (ids , _ , _ , body ) ->
1355+ walk (depth + List. length ids) body )
1356+ branches
1357+ | MLcons (_ , _ , args ) -> List. iter (walk depth) args
1358+ | MLtuple args -> List. iter (walk depth) args
1359+ | MLfix (_ , _ , bodies , _ ) ->
1360+ let n = Array. length bodies in
1361+ Array. iter (walk (depth + n)) bodies
1362+ | MLmagic e -> walk depth e
1363+ | MLparray (elts , def ) ->
1364+ Array. iter (walk depth) elts;
1365+ walk depth def
1366+ | MLrel _
1367+ | MLglob _
1368+ | MLexn _
1369+ | MLdummy _
1370+ | MLaxiom _
1371+ | MLuint _
1372+ | MLfloat _
1373+ | MLstring _ -> ()
1374+ in
1375+ walk 0 body;
1376+ Hashtbl. fold (fun k _ acc -> k :: acc) non_fwd []
1377+
1378+ (* Detect non-forwarded params in a local fixpoint body. Self-references
1379+ use MLrel: after collect_lams strips [n_params] lambda params, the fix
1380+ binding for [fix_idx] in [n_fix] mutual funs is at
1381+ db = [n_params + n_fix - fix_idx], shifted by binder depth. *)
1382+ let detect_non_forwarded_params_fix n_params n_fix fix_idx body =
1383+ let base_self_db = n_params + n_fix - fix_idx in
1384+ detect_non_forwarded_params_generic
1385+ ~is_self_call: (fun depth -> function
1386+ | MLrel db -> db = base_self_db + depth
1387+ | _ -> false )
1388+ n_params body
1389+
13221390(* * Convert ML params to C++ types with const/ref wrapping, and create
13231391 forwarding-ref template parameters for function-typed params. convert_fn:
13241392 function to convert ml_type -> cpp_type (typically
13251393 convert_ml_type_to_cpp_type env Refset'.empty tvar_names) Returns
13261394 (cpp_params, all_temps_with_funs). *)
1327- let build_lifted_cpp_params convert_fn base_temps params =
1395+ let build_lifted_cpp_params ?(non_fwd_source_indices = [] ) convert_fn base_temps params =
1396+ let n_total = List. length params in
1397+ (* Non-forwarded check in source order (for fun_tys, which iterates List.rev) *)
1398+ let is_non_fwd_source j = List. mem j non_fwd_source_indices in
1399+ (* Non-forwarded check in de Bruijn order (for cpp_params replacement) *)
1400+ let is_non_fwd_db j = List. mem (n_total - 1 - j) non_fwd_source_indices in
13281401 let cpp_params =
13291402 List. map
13301403 (fun (id , ty ) ->
@@ -1338,7 +1411,7 @@ let build_lifted_cpp_params convert_fn base_temps params =
13381411 List. filter_map
13391412 (fun (x , ty , j ) ->
13401413 match ty with
1341- | Tmod (TMconst, Tfun (dom , cod_f )) ->
1414+ | Tmod (TMconst, Tfun (dom , cod_f )) when not (is_non_fwd_source j) ->
13421415 let cod_f = if is_cpp_unit_type cod_f then Tvoid else cod_f in
13431416 Some (x, TTfun (dom, cod_f), fun_tparam_id j)
13441417 | _ -> None )
@@ -1349,7 +1422,7 @@ let build_lifted_cpp_params convert_fn base_temps params =
13491422 List. mapi
13501423 (fun j (x , ty ) ->
13511424 match ty with
1352- | Tmod (TMconst, Tfun (_ , _ )) ->
1425+ | Tmod (TMconst, Tfun (_ , _ )) when not (is_non_fwd_db j) ->
13531426 (x, Tref (Tref (Tvar (0 , Some (fun_tparam_id (n_params - j - 1 ))))))
13541427 | _ -> (x, ty) )
13551428 cpp_params
@@ -5224,6 +5297,7 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
52245297 (* Restore outer type vars *)
52255298 set_current_type_vars saved_tvars;
52265299 (* Build a lifted Dfundef for each fixpoint function (usually just one) *)
5300+ let n_fix = Array. length funs in
52275301 List. iteri
52285302 (fun i ((renamed_id , fix_ty ), params , body ) ->
52295303 let cpp_ty =
@@ -5234,8 +5308,19 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
52345308 | Tfun (dom , cod ) -> (dom, cod)
52355309 | _ -> ([] , cpp_ty)
52365310 in
5311+ (* Detect params that are not simply forwarded at recursive call
5312+ sites — these must keep std::function type to avoid infinite
5313+ recursive template instantiation. *)
5314+ let non_fwd_source_indices =
5315+ let lam_params, stripped_body =
5316+ Mlutil. collect_lams funs.(i)
5317+ in
5318+ detect_non_forwarded_params_fix
5319+ (List. length lam_params) n_fix i stripped_body
5320+ in
52375321 let cpp_params, all_temps_with_funs =
52385322 build_lifted_cpp_params
5323+ ~non_fwd_source_indices
52395324 (convert_ml_type_to_cpp_type env Refset'. empty all_tvar_names)
52405325 all_temps
52415326 params
@@ -6000,6 +6085,7 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
60006085 in
60016086 set_current_type_vars saved_tvars;
60026087 (* Build lifted declarations *)
6088+ let n_fix = Array. length funs in
60036089 List. iteri
60046090 (fun i ((renamed_id , fix_ty ), params , body ) ->
60056091 let cpp_ty =
@@ -6014,8 +6100,16 @@ and gen_stmts env (k : cpp_expr -> cpp_stmt) ast =
60146100 | Tfun (dom , cod ) -> (dom, cod)
60156101 | _ -> ([] , cpp_ty)
60166102 in
6103+ let non_fwd_source_indices =
6104+ let lam_params, stripped_body =
6105+ Mlutil. collect_lams funs.(i)
6106+ in
6107+ detect_non_forwarded_params_fix
6108+ (List. length lam_params) n_fix i stripped_body
6109+ in
60176110 let cpp_params, all_temps_with_funs =
60186111 build_lifted_cpp_params
6112+ ~non_fwd_source_indices
60196113 (convert_ml_type_to_cpp_type
60206114 env
60216115 Refset'. empty
@@ -7804,99 +7898,32 @@ and tvar_subst_stmt (tvars : Id.t list) (s : cpp_stmt) : cpp_stmt =
78047898 (tvar_subst_type tvars)
78057899 s
78067900
7807- (* * Detect function-typed parameter positions that receive a freshly constructed
7808- lambda in a self-recursive call.
7809-
7810- Higher-order function parameters are normally emitted as C++ template
7811- parameters constrained with a [MapsTo] concept:
7812-
7813- template <MapsTo<T1, unsigned int> F0, MapsTo<T1, shared_ptr<tree>, T1,
7814- shared_ptr<tree>, T1> F1> static T1 tree_rect(F0 &&f, F1 &&f0, const
7815- shared_ptr<tree> &t);
7816-
7817- This is preferred because template parameters preserve the exact lambda type,
7818- enabling the compiler to inline the call — there is no type-erasure overhead
7819- as there would be with [std::function].
7820-
7821- However, this breaks when a self-recursive function passes a *new* lambda at
7822- a function-typed parameter position in its own recursive call. This is the
7823- continuation-passing style (CPS) pattern:
7824-
7825- template <MapsTo<unsigned int, unsigned int> F1> static unsigned int
7826- fact_cps(unsigned int n, F1 &&k) { ... return fact_cps(n_, [&](unsigned int
7827- r) { return k(n_ * r); }); }
7828-
7829- Each recursive call wraps [k] inside a fresh lambda with a unique type.
7830- Because [F1] is a template parameter, the compiler must instantiate a new
7831- specialization of [fact_cps] for every nesting depth. That creates an
7832- infinite chain of template instantiations and the compiler rejects the
7833- program.
7834-
7835- The fix is to emit those specific parameters as [std::function] instead.
7836- [std::function] is a concrete type — the continuation's type is always
7837- [std::function<unsigned int(unsigned int)>] regardless of how many lambdas
7838- are wrapped around it, so the recursive call resolves to the same function
7839- and no new instantiation is needed:
7840-
7841- static unsigned int fact_cps( unsigned int n, const std::function<unsigned
7842- int(unsigned int)> k) { ... return fact_cps(n_, [&](unsigned int r) { return
7843- k(n_ * r); }); }
7844-
7845- Only parameters that actually exhibit this pattern are affected. For example,
7846- in [partition_cps p l k], the predicate [p] is passed unchanged to the
7847- recursive call ([partition_cps p rest (fun ...)]), so [p] stays as a template
7848- parameter while [k] becomes [std::function]. Similarly, non-recursive
7849- higher-order functions like [tree_rect] never trigger this issue and keep
7850- full template parameters throughout.
7851-
7852- The detection works by walking the function body's ML AST, looking for
7853- self-recursive calls [MLapp(MLglob(self_ref, _), args)]. For each such call,
7854- we check which argument positions contain a lambda ([MLlam]). Those positions
7855- are the CPS parameters that must use [std::function]. *)
7856- let detect_cps_params (self_ref : GlobRef.t ) (n_params : int ) (body : ml_ast ) :
7857- int list =
7858- let cps_set = Hashtbl. create 4 in
7859- let rec walk = function
7860- | MLapp (MLglob (r, _), args)
7861- when Environ.QGlobRef. equal Environ. empty_env r self_ref ->
7862- List. iteri
7863- (fun i arg ->
7864- if i < n_params && contains_lambda arg then
7865- Hashtbl. replace cps_set i true )
7866- args
7867- | MLapp (f , args ) ->
7868- walk f;
7869- List. iter walk args
7870- | MLlam (_ , _ , body ) -> walk body
7871- | MLletin (_ , _ , e1 , e2 ) ->
7872- walk e1;
7873- walk e2
7874- | MLcase (_ , scrut , branches ) ->
7875- walk scrut;
7876- Array. iter (fun (_ , _ , _ , body ) -> walk body) branches
7877- | MLcons (_ , _ , args ) -> List. iter walk args
7878- | MLtuple args -> List. iter walk args
7879- | MLfix (_ , _ , bodies , _ ) -> Array. iter walk bodies
7880- | MLmagic e -> walk e
7881- | MLparray (elts , def ) ->
7882- Array. iter walk elts;
7883- walk def
7884- | MLrel _
7885- | MLglob _
7886- | MLexn _
7887- | MLdummy _
7888- | MLaxiom _
7889- | MLuint _
7890- | MLfloat _
7891- | MLstring _ -> ()
7892- and contains_lambda = function
7893- | MLlam _ -> true
7894- | MLletin (_ , _ , _ , body ) -> contains_lambda body
7895- | MLmagic e -> contains_lambda e
7896- | _ -> false
7897- in
7898- walk body;
7899- Hashtbl. fold (fun k _ acc -> k :: acc) cps_set []
7901+ (* * Detect function-typed parameters that are NOT simply forwarded at
7902+ self-recursive call sites.
7903+
7904+ Higher-order function parameters are normally emitted as C++ template
7905+ parameters constrained with a [MapsTo] concept, preserving the exact lambda
7906+ type for inlining. However, when a recursive call passes a *different*
7907+ expression (not the parameter variable itself) for a function-typed parameter,
7908+ each recursion level creates a new template instantiation with a distinct type,
7909+ leading to infinite recursive template instantiation.
7910+
7911+ The fix: detect which parameters are not forwarded unchanged at any recursive
7912+ call site. Those parameters are emitted as [std::function] instead of template
7913+ parameters, since [std::function] is a concrete type that stays the same
7914+ regardless of wrapping.
7915+
7916+ Parameters that ARE forwarded unchanged (e.g., a predicate [p] passed as-is in
7917+ [partition_cps p rest (fun ...)]) keep their template parameter status.
7918+ Non-recursive higher-order functions like [tree_rect] are unaffected since they
7919+ have no self-recursive calls. *)
7920+ let detect_non_forwarded_params (self_ref : GlobRef.t ) (n_params : int )
7921+ (body : ml_ast ) : int list =
7922+ detect_non_forwarded_params_generic
7923+ ~is_self_call: (fun _depth -> function
7924+ | MLglob (r , _ ) -> Environ.QGlobRef. equal Environ. empty_env r self_ref
7925+ | _ -> false )
7926+ n_params body
79007927
79017928(* * Generate a C++ function definition from an ML function body.
79027929
@@ -8078,25 +8105,25 @@ let gen_dfun n b cty ty temps =
80788105 else
80798106 ids
80808107 in
8081- (* Detect which function-typed parameters are CPS parameters (see
8082- [detect_cps_params] above for the full explanation). These are excluded
8083- from template-parameter promotion below — they keep their [Tmod(TMconst,
8084- Tfun(dom, cod))] type which prints as [const std::function<R(Args...)>].
8108+ (* Detect which function-typed parameters are NOT simply forwarded at
8109+ self-recursive call sites. These are excluded from template-parameter
8110+ promotion below — they keep their [Tmod(TMconst, Tfun(dom, cod))] type
8111+ which prints as [const std::function<R(Args...)>].
80858112
8086- [detect_cps_params ] returns source-order indices (param 0 = first Rocq
8087- parameter). We need two index-checking helpers because the parameter list
8088- [ids] is in de Bruijn order (innermost first = last source param first),
8089- while [List.rev ids] is in source order:
8113+ [detect_non_forwarded_params ] returns source-order indices (param 0 =
8114+ first Rocq parameter). We need two index-checking helpers because the
8115+ parameter list [ids] is in de Bruijn order (innermost first = last source
8116+ param first), while [List.rev ids] is in source order:
80908117
80918118 Source order (Rocq): p0 p1 p2 indices 0, 1, 2 De Bruijn order (ids): p2 p1
80928119 p0 indices 0, 1, 2
80938120
8094- So CPS source index [i] maps to de Bruijn index [n_ids - 1 - i]. *)
8095- let cps_param_indices = detect_cps_params n (List. length ids) b in
8096- let cps_set = IntSet. of_list cps_param_indices in
8097- let is_cps_param_source i = IntSet. mem i cps_set in
8121+ So non-forwarded source index [i] maps to de Bruijn index [n_ids - 1 - i]. *)
8122+ let non_fwd_param_indices = detect_non_forwarded_params n (List. length ids) b in
8123+ let non_fwd_set = IntSet. of_list non_fwd_param_indices in
8124+ let is_non_fwd_param_source i = IntSet. mem i non_fwd_set in
80988125 let n_ids = List. length ids in
8099- let is_cps_param_db i = IntSet. mem (n_ids - 1 - i) cps_set in
8126+ let is_non_fwd_param_db i = IntSet. mem (n_ids - 1 - i) non_fwd_set in
81008127 let all_params = missing @ ids in
81018128 (* Type class instance parameters become C++ template type parameters. We
81028129 assign unique names (_tcI0, _tcI1, ...) to avoid collision with: - User
@@ -8330,7 +8357,7 @@ let gen_dfun n b cty ty temps =
83308357 (x, wrapped) )
83318358 ids_with_owned
83328359 in
8333- (* Promote non-CPS function-typed parameters to C++ template parameters.
8360+ (* Promote forwarded function-typed parameters to C++ template parameters.
83348361
83358362 Function-typed parameters (those with C++ type [Tmod(TMconst, Tfun(...))])
83368363 are normally promoted to template parameters with [MapsTo] concept
@@ -8348,20 +8375,20 @@ let gen_dfun n b cty ty temps =
83488375 [f0] unchanged — the template type stays the same at every recursion
83498376 depth.
83508377
8351- CPS parameters are excluded from this promotion. A CPS parameter
8352- receives a *new* lambda at each recursive call site, which means the
8378+ Non-forwarded parameters are excluded from this promotion. A parameter
8379+ that receives a *different* expression at a recursive call site means the
83538380 template type would be different at each recursion depth, causing
83548381 infinite template instantiation. These parameters keep their
83558382 [const std::function<R(Args...)>] type, which is a concrete
8356- (non-template) type that stays the same regardless of lambda wrapping.
8383+ (non-template) type that stays the same regardless of wrapping.
83578384
83588385 For example, [partition_cps p l k] has three parameters:
8359- - [p] is passed unchanged to the recursive call → template [F0 &&p]
8386+ - [p] is forwarded unchanged to the recursive call → template [F0 &&p]
83608387 - [l] is not function-typed → stays as-is
8361- - [k] receives a new lambda at the recursive call → [const std::function<...> k]
8388+ - [k] receives a different expression at the recursive call → [const std::function<...> k]
83628389
83638390 This loop iterates [List.rev ids] which is in source order,
8364- so we use [is_cps_param_source ] for the CPS guard. *)
8391+ so we use [is_non_fwd_param_source ] for the guard. *)
83658392 (* Determine which tvars are "primary" — deducible from non-function domain
83668393 params or the return type. Function-typed params that reference tvars
83678394 outside this set (e.g., erased HKT type variables) get TTtypename (no
@@ -8375,7 +8402,7 @@ let gen_dfun n b cty ty temps =
83758402 List. filter_map
83768403 (fun (x , ty , i ) ->
83778404 match ty with
8378- | Tmod (TMconst, Tfun (fdom , fcod )) when not (is_cps_param_source i) ->
8405+ | Tmod (TMconst, Tfun (fdom , fcod )) when not (is_non_fwd_param_source i) ->
83798406 let fun_idx = get_tvar_indices (Tfun (fdom, fcod)) in
83808407 let has_undeclared =
83818408 List. exists (fun idx -> not (IntSet. mem idx primary)) fun_idx
@@ -8388,16 +8415,16 @@ let gen_dfun n b cty ty temps =
83888415 | _ -> None )
83898416 (List. mapi (fun i (x , ty ) -> (x, ty, i)) (List. rev ids))
83908417 in
8391- (* Replace the parameter type of promoted (non-CPS ) function params with the
8392- template type variable [F&&]. CPS params are left untouched — they keep
8393- [Tmod(TMconst, Tfun(dom, cod))] which prints as [const
8418+ (* Replace the parameter type of promoted (forwarded ) function params with the
8419+ template type variable [F&&]. Non-forwarded params are left untouched — they
8420+ keep [Tmod(TMconst, Tfun(dom, cod))] which prints as [const
83948421 std::function<R(Args...)>]. This loop iterates [ids] which is in de Bruijn
8395- order, so we use [is_cps_param_db ] for the CPS guard. *)
8422+ order, so we use [is_non_fwd_param_db ] for the guard. *)
83968423 let ids =
83978424 List. mapi
83988425 (fun i (x , ty ) ->
83998426 match ty with
8400- | Tmod (TMconst, Tfun (dom , cod )) when not (is_cps_param_db i) ->
8427+ | Tmod (TMconst, Tfun (dom , cod )) when not (is_non_fwd_param_db i) ->
84018428 ( x,
84028429 Tref
84038430 (Tref (Tvar (0 , Some (fun_tparam_id (List. length ids - i - 1 )))))
0 commit comments