cuda.tile.mma

cuda.tile.mma(x, y, /, acc)

Matrix multiply-accumulate.

Computes (x @ y) + acc as a single operation, where @ denotes matrix multiplication. The result has the same dtype as acc.
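As a host-side reference for these semantics, the NumPy sketch below computes (x @ y) + acc and returns a result in acc's dtype. `mma_reference` is a hypothetical helper, not part of cuda.tile. It assumes that products are accumulated at the accumulator's precision, modeled here by widening x and y to acc's dtype before multiplying; the GPU's actual rounding behavior may differ.

```python
import numpy as np


def mma_reference(x: np.ndarray, y: np.ndarray, acc: np.ndarray) -> np.ndarray:
    """Hypothetical host-side model of cuda.tile.mma: (x @ y) + acc.

    Assumption: x and y are widened to acc's dtype so that both the
    multiply and the accumulation happen at the accumulator's precision.
    """
    out = np.matmul(x.astype(acc.dtype), y.astype(acc.dtype)) + acc
    # The result keeps the accumulator's dtype.
    return out.astype(acc.dtype, copy=False)


# f16 inputs with an f32 accumulator (one of the supported combinations).
x = np.ones((2, 4), dtype=np.float16)
y = np.ones((4, 2), dtype=np.float16)
acc = np.full((2, 2), 10.0, dtype=np.float32)
result = mma_reference(x, y, acc)
print(result.dtype, result.tolist())  # float32 [[14.0, 14.0], [14.0, 14.0]]

# i8 inputs with an i32 accumulator: 4 * (100 * 100) = 40000 would not fit in i8.
xi = np.full((2, 4), 100, dtype=np.int8)
yi = np.full((4, 2), 100, dtype=np.int8)
acc_i = np.zeros((2, 2), dtype=np.int32)
print(mma_reference(xi, yi, acc_i).tolist())  # [[40000, 40000], [40000, 40000]]
```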

Parameters:
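  • x (Tile) – Left-hand operand of the matrix multiply; 2D or 3D.

  • y (Tile) – Right-hand operand of the matrix multiply; 2D or 3D.

  • acc (Tile) – Accumulator added to the product; its dtype determines the dtype of the result.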
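The accumulator is what allows mma calls to be chained. A common pattern in tiled matrix multiplication is to loop over the shared K dimension in chunks and feed each partial result back in as acc. The NumPy sketch below models that loop on the host; `tiled_matmul` and the chunk size `tile_k` are illustrative assumptions, and the line `acc = np.matmul(a_tile, b_tile) + acc` stands in for `acc = ct.mma(a_tile, b_tile, acc)` inside a kernel.

```python
import numpy as np


def tiled_matmul(a: np.ndarray, b: np.ndarray, tile_k: int) -> np.ndarray:
    """Hypothetical host-side model of a K-tiled GEMM built on mma.

    Each iteration performs one multiply-accumulate step,
    acc = (a_tile @ b_tile) + acc, mirroring how acc is threaded through mma.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2 and k % tile_k == 0, "K must match and divide evenly"
    acc = np.zeros((m, n), dtype=np.float32)  # f32 accumulator starts at zero
    for k0 in range(0, k, tile_k):
        a_tile = a[:, k0:k0 + tile_k].astype(np.float32)
        b_tile = b[k0:k0 + tile_k, :].astype(np.float32)
        acc = np.matmul(a_tile, b_tile) + acc  # one mma step
    return acc


rng = np.random.default_rng(0)
# Small integer values keep every sum exact, so the comparison can be exact.
a = rng.integers(-3, 4, size=(4, 8)).astype(np.float16)
b = rng.integers(-3, 4, size=(8, 4)).astype(np.float16)
c = tiled_matmul(a, b, tile_k=2)
print(np.array_equal(c, a.astype(np.float32) @ b.astype(np.float32)))  # True
```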

Supported datatypes:

Input       Acc/Output
----------  ----------
f16         f16 or f32
bf16        f32
f32         f32
f64         f64
tf32        f32
f8e4m3fn    f16 or f32
f8e5m2      f16 or f32
[u|i]8      i32
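The supported-datatypes table can be read as a mapping from input dtype to the accumulator dtypes it may be paired with. The sketch below encodes exactly those rows as plain strings, with the `[u|i]8` row expanded to u8 and i8. The strings are the table's spellings rather than cuda.tile dtype objects, and `mma_dtypes_supported` is a hypothetical name used only for illustration.

```python
# Rows of the supported-datatypes table: input dtype -> allowed acc/output dtypes.
MMA_SUPPORTED = {
    "f16": {"f16", "f32"},
    "bf16": {"f32"},
    "f32": {"f32"},
    "f64": {"f64"},
    "tf32": {"f32"},
    "f8e4m3fn": {"f16", "f32"},
    "f8e5m2": {"f16", "f32"},
    "u8": {"i32"},
    "i8": {"i32"},
}


def mma_dtypes_supported(input_dtype: str, acc_dtype: str) -> bool:
    """Hypothetical helper: True if the (input, acc) pair appears in the table."""
    return acc_dtype in MMA_SUPPORTED.get(input_dtype, set())


print(mma_dtypes_supported("bf16", "f32"))  # True
print(mma_dtypes_supported("bf16", "f16"))  # False: bf16 only pairs with f32
```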

If x and y have different dtypes, they are NOT promoted to a common dtype. The shapes of x and y are broadcast against each other on all axes except the last two, which are the matrix axes.
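One way to read the broadcasting rule: the last two axes of each operand are the matrix axes (M x K for x, K x N for y), and only the leading batch axis of a 3D operand is broadcast. The sketch below computes the result shape under that reading, assuming NumPy-style broadcasting for the batch axis (a missing or size-1 batch axis stretches to match the other operand). `mma_result_shape` is a hypothetical helper, not a cuda.tile function.

```python
def mma_result_shape(x_shape: tuple, y_shape: tuple) -> tuple:
    """Hypothetical shape rule for mma on 2D/3D operands.

    The last two axes are the matrix axes: x is (..., M, K) and y is (..., K, N).
    A leading batch axis is broadcast, assuming NumPy-style rules.
    """
    if len(x_shape) not in (2, 3) or len(y_shape) not in (2, 3):
        raise ValueError("mma operands must be 2D or 3D")
    *x_batch, m, k_x = x_shape
    *y_batch, k_y, n = y_shape
    if k_x != k_y:
        raise ValueError(f"inner dimensions differ: {k_x} vs {k_y}")
    # A missing batch axis behaves like a size-1 axis.
    bx = x_batch[0] if x_batch else 1
    by = y_batch[0] if y_batch else 1
    if bx != by and 1 not in (bx, by):
        raise ValueError(f"batch axes {bx} and {by} cannot be broadcast")
    batch = (max(bx, by),) if (x_batch or y_batch) else ()
    return batch + (m, n)


print(mma_result_shape((2, 4), (4, 2)))     # (2, 2)
print(mma_result_shape((2, 2, 4), (4, 2)))  # (2, 2, 2)
```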

Return type:

Tile

Examples

2D x 2D with accumulation.

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.full((2, 2), 10.0, dtype=ct.float32)
    # (x @ y) + acc: each element = 1*4 + 10 = 14
    print(f"{ct.mma(x, y, acc):.1f}")


torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

[[14.0, 14.0], [14.0, 14.0]]

Batched: 3D x 2D, with y broadcast across the batch axis of x.

import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.zeros((2, 2, 2), dtype=ct.float32)
    # y is broadcast across the batch axis: each element = 1*4 + 0 = 4
    print(f"{ct.mma(x, y, acc):.1f}")


torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()

Output

[[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]
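To see that y is shared by every batch slice rather than reshaped, the NumPy sketch below repeats the 3D x 2D case with non-uniform values and checks that each slice of the result equals x[b] @ y + acc[b]. It uses np.matmul's broadcasting as a host-side stand-in for ct.mma and assumes the two follow the same batch-broadcasting rule for this case.

```python
import numpy as np

rng = np.random.default_rng(1)
# Small integer values keep the float32 arithmetic exact.
x = rng.integers(0, 4, size=(2, 2, 4)).astype(np.float32)    # batch of two 2x4 matrices
y = rng.integers(0, 4, size=(4, 2)).astype(np.float32)       # one 4x2 matrix, shared
acc = rng.integers(0, 4, size=(2, 2, 2)).astype(np.float32)  # one 2x2 accumulator per batch

result = np.matmul(x, y) + acc  # y is broadcast over the leading batch axis
print(result.shape)  # (2, 2, 2)

# Every batch slice is multiplied by the same y.
for b in range(x.shape[0]):
    assert np.array_equal(result[b], x[b] @ y + acc[b])
```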