Skip to content

Commit d9e57d7

Browse files
authored
Fix lossless casts of vector reduces down to bools (#9012)
Fixes #9011
1 parent 6a7ed97 commit d9e57d7

2 files changed

Lines changed: 32 additions & 2 deletions

File tree

src/IROperator.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,12 +544,20 @@ Expr lossless_cast(Type t,
544544
}
545545
}
546546
} else if (const VectorReduce *op = e.as<VectorReduce>()) {
547-
if (op->op == VectorReduce::Add ||
547+
if ((t.bits() > 1 && op->op == VectorReduce::Add) ||
548548
op->op == VectorReduce::Min ||
549549
op->op == VectorReduce::Max) {
550550
Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache);
551551
if (v.defined()) {
552-
return VectorReduce::make(op->op, v, op->type.lanes());
552+
auto reduce_op = op->op;
553+
if (t.bits() == 1) {
554+
// UInt(1) == Bool() is the only 1-bit type we expect to see
555+
internal_assert(t.is_uint()) << "Unexpected type: " << t << "\n";
556+
reduce_op = (op->op == VectorReduce::Min ?
557+
VectorReduce::And :
558+
VectorReduce::Or);
559+
}
560+
return VectorReduce::make(reduce_op, v, op->type.lanes());
553561
}
554562
}
555563
}

test/correctness/lossless_cast.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@ int lossless_cast_test() {
8181
e = cast(i64, 1024) * cast(i64, 1024) * cast(i64, 1024);
8282
res |= check_lossless_cast(i32, e, (cast(i32, 1024) * 1024) * 1024);
8383

84+
// Check narrowing a vector reduction of something narrowable to bool ...
85+
auto make_reduce = [&](Type t, VectorReduce::Operator op) {
86+
return VectorReduce::make(op,
87+
cast(t.with_lanes(4), Ramp::make(x, 1, 4) > 4), 2);
88+
};
89+
90+
// It's OK to narrow it to 8-bit.
91+
e = make_reduce(UInt(16), VectorReduce::Add);
92+
res |= check_lossless_cast(UInt(8), e, make_reduce(UInt(8), VectorReduce::Add));
93+
94+
// ... but we can't reduce it all the way to bool if the operator isn't
95+
// legal for bools (issue #9011)
96+
e = make_reduce(UInt(8), VectorReduce::Add);
97+
res |= check_lossless_cast(Bool(), e, Expr());
98+
99+
// Min or Max, however, can just become And and Or
100+
e = make_reduce(UInt(8), VectorReduce::Min);
101+
res |= check_lossless_cast(Bool(), e, make_reduce(Bool(), VectorReduce::And));
102+
103+
e = make_reduce(UInt(8), VectorReduce::Max);
104+
res |= check_lossless_cast(Bool(), e, make_reduce(Bool(), VectorReduce::Or));
105+
84106
return res;
85107
}
86108

0 commit comments

Comments
 (0)