cuda.tile.mma
- cuda.tile.mma(x, y, /, acc)
Matrix multiply-accumulate.
Computes (x @ y) + acc as a single operation (where @ denotes matrix multiplication). Preserves the dtype of acc.
- Parameters:
Supported datatypes:

| Input | Acc/Output |
| --- | --- |
| f16 | f16 or f32 |
| bf16 | f32 |
| f32 | f32 |
| f64 | f64 |
| tf32 | f32 |
| f8e4m3fn | f16 or f32 |
| f8e5m2 | f16 or f32 |
| [u\|i]8 | i32 |
If x and y have different dtypes, they will NOT be promoted to a common dtype. The shapes of x and y are broadcast together over all axes except the last two.
- Return type:
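The dtype and broadcasting rules above can be mirrored in plain NumPy. This is a hedged reference sketch of the documented semantics, not the cuda.tile implementation; `mma_reference` is a hypothetical helper introduced here for illustration.

```python
import numpy as np

def mma_reference(x, y, acc):
    # (x @ y) + acc as one step; batch axes (all but the last two)
    # broadcast, and the accumulator's dtype is preserved.
    return (np.matmul(x, y) + acc).astype(acc.dtype)

# f16 inputs with an f32 accumulator, matching the f16 -> f16 or f32 table row.
x = np.ones((2, 4), dtype=np.float16)
y = np.ones((4, 2), dtype=np.float16)
acc = np.full((2, 2), 10.0, dtype=np.float32)

out = mma_reference(x, y, acc)
print(out.dtype)    # float32: the dtype of acc is preserved
print(out[0, 0])    # 1*4 + 10 = 14.0
```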
Examples
2D x 2D with accumulation.
```python
x = ct.ones((2, 4), dtype=ct.float32)
y = ct.ones((4, 2), dtype=ct.float32)
acc = ct.full((2, 2), 10.0, dtype=ct.float32)

# (x @ y) + acc: each element = 1*4 + 10 = 14
print(f"{ct.mma(x, y, acc):.1f}")
```
```python
import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.full((2, 2), 10.0, dtype=ct.float32)

    # (x @ y) + acc: each element = 1*4 + 10 = 14
    print(f"{ct.mma(x, y, acc):.1f}")

torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()
```
Output
[[14.0, 14.0], [14.0, 14.0]]
Batched: 3D x 2D with broadcast.
```python
x = ct.ones((2, 2, 4), dtype=ct.float32)
y = ct.ones((4, 2), dtype=ct.float32)
acc = ct.zeros((2, 2, 2), dtype=ct.float32)

print(f"{ct.mma(x, y, acc):.1f}")
```
```python
import cuda.tile as ct
import torch

@ct.kernel
def kernel():
    x = ct.ones((2, 2, 4), dtype=ct.float32)
    y = ct.ones((4, 2), dtype=ct.float32)
    acc = ct.zeros((2, 2, 2), dtype=ct.float32)

    print(f"{ct.mma(x, y, acc):.1f}")

torch.cuda.init()
ct.launch(torch.cuda.current_stream(), (1,), kernel, ())
torch.cuda.synchronize()
```
Output
[[[4.0, 4.0], [4.0, 4.0]], [[4.0, 4.0], [4.0, 4.0]]]
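The batched example above can be checked against NumPy's matmul broadcasting, which follows the same rule: the 2D y is reused for every batch of the 3D x. A minimal sketch, assuming only standard NumPy semantics:

```python
import numpy as np

# Batched mma semantics in NumPy: only the axes before the last two
# broadcast, so the (4, 2) y pairs with each (2, 4) batch slice of x.
x = np.ones((2, 2, 4), dtype=np.float32)
y = np.ones((4, 2), dtype=np.float32)
acc = np.zeros((2, 2, 2), dtype=np.float32)

out = np.matmul(x, y) + acc
print(out.shape)      # (2, 2, 2)
print(out[0, 0, 0])   # each element = 1*4 + 0 = 4.0
```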