Skip to content

Commit 7976c7a

Browse files
committed
New features, a lot of struggling with AoS to SoA
1 parent 30b068b commit 7976c7a

2 files changed

Lines changed: 236 additions & 32 deletions

File tree

dace/frontend/hlfir/passes/FlattenStructs.cpp

Lines changed: 197 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,31 @@ struct FlatLeaf {
418418
/// realistic ~10).
419419
static constexpr int kFlattenMaxDepth = 12;
420420

421+
/// Recursively walk a record type and append every reachable flat
422+
/// leaf to ``out``. Returns false if any path bottoms out at a
423+
/// non-flat shape (allocatable / pointer member, dynamic-extent
424+
/// inner array, etc.); on false the caller falls back to a
425+
/// non-nested rewrite and the un-flattened struct surfaces a
426+
/// loud failure downstream.
427+
///
428+
/// Three member shapes are recognised at each level:
429+
/// * **flat member** (scalar / static-shape array of scalar) —
430+
/// contributes one leaf with its intrinsic shape preserved
431+
/// (the ``outerDims`` accumulated above are prepended so
432+
/// intermediate ``array<N x RecordType>`` levels concat into
433+
/// the leaf's flat companion shape).
434+
/// * **pure record** (``RecordType`` directly) — recurses with
435+
/// no shape contribution.
436+
/// * **array of records** (``array<N x RecordType>``) — recurses
437+
/// into the inner record after pushing ``N`` onto
438+
/// ``outerDims``; every leaf produced by that recursion
439+
/// inherits ``N`` as a leading dim. This is what enables
440+
/// ``p_prog%pprog(i)%w(j, k)`` (where ``pprog: type(t)(10)``
441+
/// is an array-of-struct member) to flatten to a 3D companion
442+
/// ``p_prog_pprog_w`` of shape ``(10, 5, 5)``.
421443
static bool collectFlatLeaves(fir::RecordType rec,
422444
llvm::SmallVectorImpl<std::string> &prefix,
445+
llvm::SmallVectorImpl<int64_t> &outerDims,
423446
llvm::SmallVectorImpl<FlatLeaf> &out,
424447
int depth = 0) {
425448
if (depth > kFlattenMaxDepth) return false;
@@ -428,16 +451,70 @@ static bool collectFlatLeaves(fir::RecordType rec,
428451
if (isFlatMemberType(pair.second)) {
429452
FlatLeaf leaf;
430453
leaf.path.assign(prefix.begin(), prefix.end());
431-
leaf.leafTy = pair.second;
454+
// Compose the leaf's flat companion shape:
455+
// outerDims (accumulated array-of-record dims walked
456+
// on the way down) ++ memberDims (the leaf member's
457+
// own intrinsic shape, if any).
458+
mlir::Type leafEle = pair.second;
459+
llvm::SmallVector<int64_t, 4> memberDims;
460+
if (auto seq = mlir::dyn_cast<fir::SequenceType>(leafEle)) {
461+
for (auto d : seq.getShape()) {
462+
if (d == fir::SequenceType::getUnknownExtent()) {
463+
prefix.pop_back();
464+
return false; // dynamic extents in the
465+
// leaf require a runtime
466+
// shape we don't synthesise
467+
// in this path.
468+
}
469+
memberDims.push_back(d);
470+
}
471+
leafEle = seq.getEleTy();
472+
}
473+
if (outerDims.empty() && memberDims.empty()) {
474+
// Pure scalar leaf — no array wrapper.
475+
leaf.leafTy = leafEle;
476+
} else {
477+
llvm::SmallVector<int64_t, 6> shape(outerDims.begin(),
478+
outerDims.end());
479+
shape.append(memberDims.begin(), memberDims.end());
480+
leaf.leafTy = fir::SequenceType::get(shape, leafEle);
481+
}
432482
out.push_back(std::move(leaf));
433483
} else if (auto innerRec = mlir::dyn_cast<fir::RecordType>(pair.second)) {
434-
if (!collectFlatLeaves(innerRec, prefix, out, depth + 1)) {
484+
if (!collectFlatLeaves(innerRec, prefix, outerDims, out, depth + 1)) {
485+
prefix.pop_back();
486+
return false;
487+
}
488+
} else if (auto seq = mlir::dyn_cast<fir::SequenceType>(pair.second)) {
489+
// Array-of-record member: recurse INTO the inner record
490+
// with the outer extents pushed on so each leaf inherits
491+
// them as leading dims. Bail on dynamic extents — those
492+
// would need a runtime-shape companion the synth path
493+
// doesn't yet emit.
494+
auto innerRec = mlir::dyn_cast<fir::RecordType>(seq.getEleTy());
495+
if (!innerRec) {
496+
prefix.pop_back();
497+
return false;
498+
}
499+
llvm::SmallVector<int64_t, 4> theseDims;
500+
for (auto d : seq.getShape()) {
501+
if (d == fir::SequenceType::getUnknownExtent()) {
502+
prefix.pop_back();
503+
return false;
504+
}
505+
theseDims.push_back(d);
506+
}
507+
for (auto d : theseDims) outerDims.push_back(d);
508+
bool ok = collectFlatLeaves(innerRec, prefix, outerDims, out,
509+
depth + 1);
510+
for (size_t i = 0; i < theseDims.size(); ++i) outerDims.pop_back();
511+
if (!ok) {
435512
prefix.pop_back();
436513
return false;
437514
}
438515
} else {
439-
// Member is e.g. an array of struct, or an allocatable —
440-
// not flattenable here. Bail out so the pass leaves the
516+
// Member is e.g. allocatable / pointer — not flattenable
517+
// through this path. Bail so the pass leaves the
441518
// struct untouched and the loud-failure throw in
442519
// extract_vars points at the right gap.
443520
prefix.pop_back();
@@ -448,6 +525,17 @@ static bool collectFlatLeaves(fir::RecordType rec,
448525
return true;
449526
}
450527

528+
/// Top-level entry point for the flat-leaf walker. Internal
529+
/// callers always start with empty ``outerDims``. Forwards to
530+
/// the recursive form above.
531+
static bool collectFlatLeaves(fir::RecordType rec,
532+
llvm::SmallVectorImpl<std::string> &prefix,
533+
llvm::SmallVectorImpl<FlatLeaf> &out,
534+
int depth = 0) {
535+
llvm::SmallVector<int64_t, 4> outerDims;
536+
return collectFlatLeaves(rec, prefix, outerDims, out, depth);
537+
}
538+
451539
/// Detect a "jagged" scalar-struct: every member is a 1-D array of the same
452540
/// scalar element type, and at least two members have different extents.
453541
///
@@ -559,9 +647,42 @@ static int memberRank(mlir::Type memTy) {
559647
/// first non-designate operand (typically the original ``hlfir.declare``
560648
/// of the struct root). Returns the joined "_" path on success and an
561649
/// empty string if the chain doesn't end in pure component selectors.
562-
static std::string designateChainPath(hlfir::DesignateOp leaf,
563-
hlfir::DesignateOp &outAnchor) {
564-
llvm::SmallVector<std::string, 4> parts;
650+
/// Walk a chain of ``hlfir.designate`` ops back from the leaf up
651+
/// to the underlying ``hlfir.declare``, collecting:
652+
/// * ``path`` — outer-first list of component names.
653+
/// * ``intermediateIndices`` — outer-first list of indices that
654+
/// appeared on NON-LEAF designates
655+
/// (i.e. on intermediate steps of
656+
/// the chain). Empty for the
657+
/// simple case where only the leaf
658+
/// carries indices.
659+
///
660+
/// Two chain shapes are handled by separate downstream paths:
661+
///
662+
/// 1. **Leaf-only indices** (the original case): every
663+
/// intermediate designate is a pure ``{component}`` selector,
664+
/// and any indices live on the leaf itself. Caller clones
665+
/// the leaf and swaps its memref to the flat companion —
666+
/// preserving triplet sections, shape operands, and any
667+
/// other leaf-side attributes.
668+
///
669+
/// 2. **Intermediate indices** (array-of-record member): the
670+
/// chain has a ``designate(idx)`` step between component
671+
/// designates, e.g. ``p_prog%pprog(i)%w(j, k)``. Caller
672+
/// builds a fresh designate over the flat companion with
673+
/// indices merged across all chain steps. Triplet sections
674+
/// on intermediate steps aren't in scope here (rare; would
675+
/// need separate handling).
676+
///
677+
/// Returns the joined ``"a_b_c"`` path key on success (matching the
678+
/// FlatLeaf naming the synth produces); empty string if the chain
679+
/// has no component step at all, or if a triplet section appears
680+
/// at a non-leaf level.
681+
static std::string walkDesignateChain(
682+
hlfir::DesignateOp leaf,
683+
llvm::SmallVectorImpl<mlir::Value> &intermediateIndices) {
684+
llvm::SmallVector<std::string, 4> compsRev;
685+
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> intermediateIdxGroupsRev;
565686
hlfir::DesignateOp cur = leaf;
566687
for (int i = 0; i < kFlattenMaxDepth && cur; ++i) {
567688
mlir::StringAttr compAttr;
@@ -570,26 +691,44 @@ static std::string designateChainPath(hlfir::DesignateOp leaf,
570691
compAttr = a;
571692
break;
572693
}
573-
if (!compAttr) {
574-
// Reached a non-component designate (a subscripted access
575-
// ``a(i,j)``) — that's only valid as the LEAF of the chain,
576-
// i.e. the very first call. Stop here.
577-
break;
694+
if (compAttr) compsRev.push_back(compAttr.getValue().str());
695+
bool isLeaf = (cur == leaf);
696+
if (!isLeaf) {
697+
// Intermediate steps must be plain (no triplets).
698+
// Triplet sections on intermediate levels would mean a
699+
// non-uniform slice through the array-of-record path
700+
// (e.g. ``p_prog%pprog(2:5)%w(j)``); not in scope.
701+
for (bool t : cur.getIsTriplet()) if (t) return "";
702+
llvm::SmallVector<mlir::Value, 4> these(
703+
cur.getIndices().begin(), cur.getIndices().end());
704+
intermediateIdxGroupsRev.push_back(std::move(these));
578705
}
579-
parts.push_back(compAttr.getValue().str());
580-
outAnchor = cur;
581-
// Walk to the parent; if it's another designate keep going.
706+
// Walk to parent.
582707
auto memref = cur.getMemref();
583708
cur = mlir::dyn_cast_or_null<hlfir::DesignateOp>(memref.getDefiningOp());
584709
}
585-
if (parts.empty()) return "";
586-
// Reverse to outermost-first order matching FlatLeaf.path.
587-
std::reverse(parts.begin(), parts.end());
710+
if (compsRev.empty()) return "";
711+
// Reverse to outer-first. Components join with "_" to match
712+
// FlatLeaf.path's canonical form.
588713
std::string joined;
589-
for (unsigned i = 0; i < parts.size(); ++i) {
590-
if (i) joined += "_";
591-
joined += parts[i];
714+
for (auto it = compsRev.rbegin(); it != compsRev.rend(); ++it) {
715+
if (!joined.empty()) joined += "_";
716+
joined += *it;
592717
}
718+
for (auto it = intermediateIdxGroupsRev.rbegin();
719+
it != intermediateIdxGroupsRev.rend(); ++it)
720+
intermediateIndices.append(it->begin(), it->end());
721+
return joined;
722+
}
723+
724+
/// Backwards-compatible wrapper used by callers that only need the
725+
/// path (no merged indices) — keeps the original entry point shape
726+
/// while ``walkDesignateChain`` is the canonical implementation.
727+
static std::string designateChainPath(hlfir::DesignateOp leaf,
728+
hlfir::DesignateOp &outAnchor) {
729+
llvm::SmallVector<mlir::Value, 4> ignored;
730+
auto joined = walkDesignateChain(leaf, ignored);
731+
outAnchor = leaf;
593732
return joined;
594733
}
595734

@@ -602,27 +741,53 @@ static bool rewriteDesignateChain(
602741
hlfir::DesignateOp leaf,
603742
const llvm::StringMap<mlir::Value> &leafBase) {
604743

605-
hlfir::DesignateOp anchor;
606-
std::string path = designateChainPath(leaf, anchor);
744+
llvm::SmallVector<mlir::Value, 4> intermediateIndices;
745+
std::string path = walkDesignateChain(leaf, intermediateIndices);
607746
if (path.empty()) return false;
608747
auto it = leafBase.find(path);
609748
if (it == leafBase.end()) return false;
610749
auto newBase = it->second;
611750

612-
// ``leaf`` is the INNERMOST (component or component-with-indices) op.
613-
// Its result is what the rest of the IR consumes.
614-
if (leaf.getIndices().empty()) {
615-
leaf.getResult().replaceAllUsesWith(newBase);
751+
// Leaf-only path (no intermediate indices). Preserves the
752+
// leaf's full shape — including triplet sections, shape
753+
// operand, complex_part, etc. — by cloning and rewiring just
754+
// the memref + clearing the component attrs. Whole-leaf
755+
// access (``base{"a"}{"b"}`` with no indices) just RAUWs.
756+
if (intermediateIndices.empty()) {
757+
if (leaf.getIndices().empty()) {
758+
leaf.getResult().replaceAllUsesWith(newBase);
759+
leaf.erase();
760+
return true;
761+
}
762+
mlir::OpBuilder rb(leaf);
763+
auto *clone = rb.clone(*leaf.getOperation());
764+
clone->setOperand(0, newBase);
765+
clone->removeAttr("component");
766+
clone->removeAttr("component_name");
767+
leaf.getResult().replaceAllUsesWith(clone->getResult(0));
616768
leaf.erase();
617769
return true;
618770
}
619771

772+
// Intermediate-indices path (array-of-record member surfaced
773+
// by ``collectFlatLeaves``'s extra outerDims). Build a fresh
774+
// designate over the flat companion with intermediate +
775+
// leaf indices merged in outer-first order. No triplets at
776+
// intermediate levels (walker bails on that). Whether the
777+
// leaf itself has triplets is rare in this shape — a section
778+
// on the innermost array of a record-of-record-of-... — and
779+
// is also out of scope; bail to keep the contract narrow.
780+
for (bool t : leaf.getIsTriplet()) if (t) return false;
781+
llvm::SmallVector<mlir::Value, 6> merged(intermediateIndices.begin(),
782+
intermediateIndices.end());
783+
for (auto v : leaf.getIndices()) merged.push_back(v);
620784
mlir::OpBuilder rb(leaf);
621-
auto *clone = rb.clone(*leaf.getOperation());
622-
clone->setOperand(0, newBase);
623-
clone->removeAttr("component");
624-
clone->removeAttr("component_name");
625-
leaf.getResult().replaceAllUsesWith(clone->getResult(0));
785+
auto newOp = rb.create<hlfir::DesignateOp>(
786+
leaf.getLoc(),
787+
leaf.getResult().getType(),
788+
newBase,
789+
mlir::ValueRange{merged});
790+
leaf.getResult().replaceAllUsesWith(newOp.getResult());
626791
leaf.erase();
627792
return true;
628793
}

tests/hlfir/derived_type_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,45 @@ def test_nested_struct_lowered_via_phase2(tmp_path: Path):
13101310
np.testing.assert_array_equal(d, d_ref)
13111311

13121312

1313+
def test_array_of_nested_struct_member(tmp_path: Path):
1314+
"""Phase 2 extension: a struct member is an ARRAY of another
1315+
struct (``type(simple_type) :: pprog(10)``). The flat
1316+
companion folds the array dim into the leaf's shape so
1317+
``p_prog%pprog(i)%w(j, k)`` rewrites to a 3D companion
1318+
``p_prog_pprog_w(i, j, k)``. Exercises ``collectFlatLeaves``'s
1319+
``array<N x RecordType>`` branch and ``walkDesignateChain``'s
1320+
intermediate-indices path together.
1321+
"""
1322+
src = """
1323+
module lib
1324+
implicit none
1325+
type simple_type
1326+
real :: w(5, 5)
1327+
end type simple_type
1328+
type simple_type2
1329+
type(simple_type) :: pprog(10)
1330+
end type simple_type2
1331+
end module lib
1332+
1333+
subroutine main(d)
1334+
use lib
1335+
implicit none
1336+
real, intent(out) :: d(5, 5)
1337+
type(simple_type2) :: p_prog
1338+
p_prog%pprog(1)%w(1, 1) = 47.0
1339+
d(1, 1) = p_prog%pprog(1)%w(1, 1)
1340+
end subroutine main
1341+
"""
1342+
mod = f2py_compile(src, tmp_path / "ref", "array_of_nested_struct_ref")
1343+
d_ref = np.asarray(mod.main(), dtype=np.float32)
1344+
1345+
sdfg = _build(src, tmp_path)
1346+
d = np.zeros((5, 5), order="F", dtype=np.float32)
1347+
sdfg(d=d)
1348+
np.testing.assert_array_equal(d, d_ref)
1349+
assert d[0][0] == 47.0
1350+
1351+
13131352
def test_aos_member_to_member_array_copy(tmp_path: Path):
13141353
"""AoS pattern ``a(i)%b = a(j)%c`` where ``b`` and ``c`` are array
13151354
members — the assignment is a whole-array copy of one inner row.

0 commit comments

Comments
 (0)