Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:

# check for remaining dims in the outside dimensions
rem_a_out, rem_b_out = 0, 0
if a.lshape[-2] % mB != 0 or (kB == 1 and a.lshape[-2] != 1):
if a.lshape[-2] % mB != 0 or (mB == 1 and a.lshape[-2] != 1):
rem_a_out = 1
if b.lshape[-1] % nB != 0 or (kB == 1 and b.lshape[-1] != 1):
if b.lshape[-1] % nB != 0 or (nB == 1 and b.lshape[-1] != 1):
Comment thread
brownbaerchen marked this conversation as resolved.
rem_b_out = 1

# get the flags from all processes
Expand Down
15 changes: 15 additions & 0 deletions tests/core/linalg/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,21 @@ def test_matmul(self):
# self.assertTrue(ht.allclose(ret_batched, c, 1e-2))
self.assertTrue(max_diff < 1e-4)

def test_matmul_edge_case_1(self):
# test edge cases as documented in #2093

if ht.comm.size == 4:
split = 0
shape = (8, 6)

A = ht.ones(shape, split=split)
B = ht.ones(shape[::-1], split=split)

C = A @ B
self.assertTrue(ht.allclose(C, shape[-1]))
else:
self.skipTest('This edge case requires four tasks')

def test_matrix_norm(self):
a = ht.arange(9, dtype=ht.float) - 4
b = a.reshape((3, 3))
Expand Down
Loading