@@ -418,8 +418,31 @@ struct FlatLeaf {
418418// / realistic ~10).
419419static constexpr int kFlattenMaxDepth = 12 ;
420420
421+ // / Recursively walk a record type and append every reachable flat
422+ // / leaf to ``out``. Returns false if any path bottoms out at a
423+ // / non-flat shape (allocatable / pointer member, dynamic-extent
424+ // / inner array, etc.); on false the caller falls back to a
425+ // / non-nested rewrite and the un-flattened struct surfaces a
426+ // / loud failure downstream.
427+ // /
428+ // / Three member shapes are recognised at each level:
429+ // / * **flat member** (scalar / static-shape array of scalar) —
430+ // / contributes one leaf with its intrinsic shape preserved
431+ // / (the ``outerDims`` accumulated above are prepended so
432+ // / intermediate ``array<N x RecordType>`` levels concat into
433+ // / the leaf's flat companion shape).
434+ // / * **pure record** (``RecordType`` directly) — recurses with
435+ // / no shape contribution.
436+ // / * **array of records** (``array<N x RecordType>``) — recurses
437+ // / into the inner record after pushing ``N`` onto
438+ // / ``outerDims``; every leaf produced by that recursion
439+ // / inherits ``N`` as a leading dim. This is what enables
440+ // / ``p_prog%pprog(i)%w(j, k)`` (where ``pprog: type(t)(10)``
441+ // / is an array-of-struct member) to flatten to a 3D companion
442+ // / ``p_prog_pprog_w`` of shape ``(10, 5, 5)``.
421443static bool collectFlatLeaves (fir::RecordType rec,
422444 llvm::SmallVectorImpl<std::string> &prefix,
445+ llvm::SmallVectorImpl<int64_t > &outerDims,
423446 llvm::SmallVectorImpl<FlatLeaf> &out,
424447 int depth = 0 ) {
425448 if (depth > kFlattenMaxDepth ) return false ;
@@ -428,16 +451,70 @@ static bool collectFlatLeaves(fir::RecordType rec,
428451 if (isFlatMemberType (pair.second )) {
429452 FlatLeaf leaf;
430453 leaf.path .assign (prefix.begin (), prefix.end ());
431- leaf.leafTy = pair.second ;
454+ // Compose the leaf's flat companion shape:
455+ // outerDims (accumulated array-of-record dims walked
456+ // on the way down) ++ memberDims (the leaf member's
457+ // own intrinsic shape, if any).
458+ mlir::Type leafEle = pair.second ;
459+ llvm::SmallVector<int64_t , 4 > memberDims;
460+ if (auto seq = mlir::dyn_cast<fir::SequenceType>(leafEle)) {
461+ for (auto d : seq.getShape ()) {
462+ if (d == fir::SequenceType::getUnknownExtent ()) {
463+ prefix.pop_back ();
464+ return false ; // dynamic extents in the
465+ // leaf require a runtime
466+ // shape we don't synthesise
467+ // in this path.
468+ }
469+ memberDims.push_back (d);
470+ }
471+ leafEle = seq.getEleTy ();
472+ }
473+ if (outerDims.empty () && memberDims.empty ()) {
474+ // Pure scalar leaf — no array wrapper.
475+ leaf.leafTy = leafEle;
476+ } else {
477+ llvm::SmallVector<int64_t , 6 > shape (outerDims.begin (),
478+ outerDims.end ());
479+ shape.append (memberDims.begin (), memberDims.end ());
480+ leaf.leafTy = fir::SequenceType::get (shape, leafEle);
481+ }
432482 out.push_back (std::move (leaf));
433483 } else if (auto innerRec = mlir::dyn_cast<fir::RecordType>(pair.second )) {
434- if (!collectFlatLeaves (innerRec, prefix, out, depth + 1 )) {
484+ if (!collectFlatLeaves (innerRec, prefix, outerDims, out, depth + 1 )) {
485+ prefix.pop_back ();
486+ return false ;
487+ }
488+ } else if (auto seq = mlir::dyn_cast<fir::SequenceType>(pair.second )) {
489+ // Array-of-record member: recurse INTO the inner record
490+ // with the outer extents pushed on so each leaf inherits
491+ // them as leading dims. Bail on dynamic extents — those
492+ // would need a runtime-shape companion the synth path
493+ // doesn't yet emit.
494+ auto innerRec = mlir::dyn_cast<fir::RecordType>(seq.getEleTy ());
495+ if (!innerRec) {
496+ prefix.pop_back ();
497+ return false ;
498+ }
499+ llvm::SmallVector<int64_t , 4 > theseDims;
500+ for (auto d : seq.getShape ()) {
501+ if (d == fir::SequenceType::getUnknownExtent ()) {
502+ prefix.pop_back ();
503+ return false ;
504+ }
505+ theseDims.push_back (d);
506+ }
507+ for (auto d : theseDims) outerDims.push_back (d);
508+ bool ok = collectFlatLeaves (innerRec, prefix, outerDims, out,
509+ depth + 1 );
510+ for (size_t i = 0 ; i < theseDims.size (); ++i) outerDims.pop_back ();
511+ if (!ok) {
435512 prefix.pop_back ();
436513 return false ;
437514 }
438515 } else {
439- // Member is e.g. an array of struct, or an allocatable —
440- // not flattenable here . Bail out so the pass leaves the
516+ // Member is e.g. allocatable / pointer — not flattenable
517+ // through this path . Bail so the pass leaves the
441518 // struct untouched and the loud-failure throw in
442519 // extract_vars points at the right gap.
443520 prefix.pop_back ();
@@ -448,6 +525,17 @@ static bool collectFlatLeaves(fir::RecordType rec,
448525 return true ;
449526}
450527
528+ // / Top-level entry point for the flat-leaf walker. Internal
529+ // / callers always start with empty ``outerDims``. Forwards to
530+ // / the recursive form above.
531+ static bool collectFlatLeaves (fir::RecordType rec,
532+ llvm::SmallVectorImpl<std::string> &prefix,
533+ llvm::SmallVectorImpl<FlatLeaf> &out,
534+ int depth = 0 ) {
535+ llvm::SmallVector<int64_t , 4 > outerDims;
536+ return collectFlatLeaves (rec, prefix, outerDims, out, depth);
537+ }
538+
451539// / Detect a "jagged" scalar-struct: every member is a 1-D array of the same
452540// / scalar element type, and at least two members have different extents.
453541// /
@@ -559,9 +647,42 @@ static int memberRank(mlir::Type memTy) {
559647// / first non-designate operand (typically the original ``hlfir.declare``
560648// / of the struct root). Returns the joined "_" path on success and an
561649// / empty string if the chain doesn't end in pure component selectors.
562- static std::string designateChainPath (hlfir::DesignateOp leaf,
563- hlfir::DesignateOp &outAnchor) {
564- llvm::SmallVector<std::string, 4 > parts;
650+ // / Walk a chain of ``hlfir.designate`` ops back from the leaf up
651+ // / to the underlying ``hlfir.declare``, collecting:
652+ // / * ``path`` — outer-first list of component names.
653+ // / * ``intermediateIndices`` — outer-first list of indices that
654+ // / appeared on NON-LEAF designates
655+ // / (i.e. on intermediate steps of
656+ // / the chain). Empty for the
657+ // / simple case where only the leaf
658+ // / carries indices.
659+ // /
660+ // / Two chain shapes are handled by separate downstream paths:
661+ // /
662+ // / 1. **Leaf-only indices** (the original case): every
663+ // / intermediate designate is a pure ``{component}`` selector,
664+ // / and any indices live on the leaf itself. Caller clones
665+ // / the leaf and swaps its memref to the flat companion —
666+ // / preserving triplet sections, shape operands, and any
667+ // / other leaf-side attributes.
668+ // /
669+ // / 2. **Intermediate indices** (array-of-record member): the
670+ // / chain has a ``designate(idx)`` step between component
671+ // / designates, e.g. ``p_prog%pprog(i)%w(j, k)``. Caller
672+ // / builds a fresh designate over the flat companion with
673+ // / indices merged across all chain steps. Triplet sections
674+ // / on intermediate steps aren't in scope here (rare; would
675+ // / need separate handling).
676+ // /
677+ // / Returns the joined ``"a_b_c"`` path key on success (matching the
678+ // / FlatLeaf naming the synth produces); empty string if the chain
679+ // / has no component step at all, or if a triplet section appears
680+ // / at a non-leaf level.
681+ static std::string walkDesignateChain (
682+ hlfir::DesignateOp leaf,
683+ llvm::SmallVectorImpl<mlir::Value> &intermediateIndices) {
684+ llvm::SmallVector<std::string, 4 > compsRev;
685+ llvm::SmallVector<llvm::SmallVector<mlir::Value, 4 >, 4 > intermediateIdxGroupsRev;
565686 hlfir::DesignateOp cur = leaf;
566687 for (int i = 0 ; i < kFlattenMaxDepth && cur; ++i) {
567688 mlir::StringAttr compAttr;
@@ -570,26 +691,44 @@ static std::string designateChainPath(hlfir::DesignateOp leaf,
570691 compAttr = a;
571692 break ;
572693 }
573- if (!compAttr) {
574- // Reached a non-component designate (a subscripted access
575- // ``a(i,j)``) — that's only valid as the LEAF of the chain,
576- // i.e. the very first call. Stop here.
577- break ;
694+ if (compAttr) compsRev.push_back (compAttr.getValue ().str ());
695+ bool isLeaf = (cur == leaf);
696+ if (!isLeaf) {
697+ // Intermediate steps must be plain (no triplets).
698+ // Triplet sections on intermediate levels would mean a
699+ // non-uniform slice through the array-of-record path
700+ // (e.g. ``p_prog%pprog(2:5)%w(j)``); not in scope.
701+ for (bool t : cur.getIsTriplet ()) if (t) return " " ;
702+ llvm::SmallVector<mlir::Value, 4 > these (
703+ cur.getIndices ().begin (), cur.getIndices ().end ());
704+ intermediateIdxGroupsRev.push_back (std::move (these));
578705 }
579- parts.push_back (compAttr.getValue ().str ());
580- outAnchor = cur;
581- // Walk to the parent; if it's another designate keep going.
706+ // Walk to parent.
582707 auto memref = cur.getMemref ();
583708 cur = mlir::dyn_cast_or_null<hlfir::DesignateOp>(memref.getDefiningOp ());
584709 }
585- if (parts .empty ()) return " " ;
586- // Reverse to outermost -first order matching FlatLeaf.path.
587- std::reverse (parts. begin (), parts. end ());
710+ if (compsRev .empty ()) return " " ;
711+ // Reverse to outer -first. Components join with "_" to match
712+ // FlatLeaf.path's canonical form.
588713 std::string joined;
589- for (unsigned i = 0 ; i < parts. size (); ++i ) {
590- if (i ) joined += " _" ;
591- joined += parts[i] ;
714+ for (auto it = compsRev. rbegin (); it != compsRev. rend (); ++it ) {
715+ if (!joined. empty () ) joined += " _" ;
716+ joined += *it ;
592717 }
718+ for (auto it = intermediateIdxGroupsRev.rbegin ();
719+ it != intermediateIdxGroupsRev.rend (); ++it)
720+ intermediateIndices.append (it->begin (), it->end ());
721+ return joined;
722+ }
723+
724+ // / Backwards-compatible wrapper used by callers that only need the
725+ // / path (no merged indices) — keeps the original entry point shape
726+ // / while ``walkDesignateChain`` is the canonical implementation.
727+ static std::string designateChainPath (hlfir::DesignateOp leaf,
728+ hlfir::DesignateOp &outAnchor) {
729+ llvm::SmallVector<mlir::Value, 4 > ignored;
730+ auto joined = walkDesignateChain (leaf, ignored);
731+ outAnchor = leaf;
593732 return joined;
594733}
595734
@@ -602,27 +741,53 @@ static bool rewriteDesignateChain(
602741 hlfir::DesignateOp leaf,
603742 const llvm::StringMap<mlir::Value> &leafBase) {
604743
605- hlfir::DesignateOp anchor ;
606- std::string path = designateChainPath (leaf, anchor );
744+ llvm::SmallVector<mlir::Value, 4 > intermediateIndices ;
745+ std::string path = walkDesignateChain (leaf, intermediateIndices );
607746 if (path.empty ()) return false ;
608747 auto it = leafBase.find (path);
609748 if (it == leafBase.end ()) return false ;
610749 auto newBase = it->second ;
611750
612- // ``leaf`` is the INNERMOST (component or component-with-indices) op.
613- // Its result is what the rest of the IR consumes.
614- if (leaf.getIndices ().empty ()) {
615- leaf.getResult ().replaceAllUsesWith (newBase);
751+ // Leaf-only path (no intermediate indices). Preserves the
752+ // leaf's full shape — including triplet sections, shape
753+ // operand, complex_part, etc. — by cloning and rewiring just
754+ // the memref + clearing the component attrs. Whole-leaf
755+ // access (``base{"a"}{"b"}`` with no indices) just RAUWs.
756+ if (intermediateIndices.empty ()) {
757+ if (leaf.getIndices ().empty ()) {
758+ leaf.getResult ().replaceAllUsesWith (newBase);
759+ leaf.erase ();
760+ return true ;
761+ }
762+ mlir::OpBuilder rb (leaf);
763+ auto *clone = rb.clone (*leaf.getOperation ());
764+ clone->setOperand (0 , newBase);
765+ clone->removeAttr (" component" );
766+ clone->removeAttr (" component_name" );
767+ leaf.getResult ().replaceAllUsesWith (clone->getResult (0 ));
616768 leaf.erase ();
617769 return true ;
618770 }
619771
772+ // Intermediate-indices path (array-of-record member surfaced
773+ // by ``collectFlatLeaves``'s extra outerDims). Build a fresh
774+ // designate over the flat companion with intermediate +
775+ // leaf indices merged in outer-first order. No triplets at
776+ // intermediate levels (walker bails on that). Whether the
777+ // leaf itself has triplets is rare in this shape — a section
778+ // on the innermost array of a record-of-record-of-... — and
779+ // is also out of scope; bail to keep the contract narrow.
780+ for (bool t : leaf.getIsTriplet ()) if (t) return false ;
781+ llvm::SmallVector<mlir::Value, 6 > merged (intermediateIndices.begin (),
782+ intermediateIndices.end ());
783+ for (auto v : leaf.getIndices ()) merged.push_back (v);
620784 mlir::OpBuilder rb (leaf);
621- auto *clone = rb.clone (*leaf.getOperation ());
622- clone->setOperand (0 , newBase);
623- clone->removeAttr (" component" );
624- clone->removeAttr (" component_name" );
625- leaf.getResult ().replaceAllUsesWith (clone->getResult (0 ));
785+ auto newOp = rb.create <hlfir::DesignateOp>(
786+ leaf.getLoc (),
787+ leaf.getResult ().getType (),
788+ newBase,
789+ mlir::ValueRange{merged});
790+ leaf.getResult ().replaceAllUsesWith (newOp.getResult ());
626791 leaf.erase ();
627792 return true ;
628793}
0 commit comments