Assignment 03: Matrix Multiplication with cuTile

The file assignments/03_assignment/src/__main__.py contains the main function that runs all the tasks for this assignment. Each task is implemented in a separate file in the same directory. The results of each task are printed to the console when the main function is executed.
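
For orientation, a driver of roughly this shape would run the tasks; this is only a sketch, and the module names and entry points are placeholders, not the actual files in the repository:

# Hypothetical sketch of assignments/03_assignment/src/__main__.py.
# Assumes each task module exposes a main() function; the real file may differ.
import task1_fp16_vs_fp32
import task2_simple_matmul
import task3_benchmark
import task4_swizzle

def run_all():
    for name, task_main in [
        ("Task 1", task1_fp16_vs_fp32.main),
        ("Task 2", task2_simple_matmul.main),
        ("Task 3", task3_benchmark.main),
        ("Task 4", task4_swizzle.main),
    ]:
        print(f"--- {name} ---")
        task_main()

if __name__ == "__main__":
    run_all()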

Task 1: FP32 vs FP16 Performance

Output:

16bit TFLOPs:  1.1449544748761349
32bit TFLOPs:  0.019409971572566145
import cuda.tile as ct
import torch

def main():
    inner_size = 4096
    A = torch.randn((64, inner_size), device='cuda', dtype=torch.float16)
    B = torch.randn((inner_size, 64), device='cuda', dtype=torch.float16)
    C = torch.empty((64, 64), device='cuda', dtype=torch.float32)
    
    grid = (1, )

    torch.cuda.init()
    ct.launch(torch.cuda.current_stream(), grid, kernel_fp16, (A, B, C))
    torch.cuda.synchronize()

    expected = torch.empty((64, 64), device='cuda', dtype=torch.float16)
    torch.matmul(A, B, out=expected)
    expected = expected.to(torch.float32)  # Convert to float32 for comparison
    assert torch.allclose(C, expected, atol=1e-1), "The result is incorrect!"


@ct.kernel
def kernel_fp16(A, B, C):
    m_tile = 64
    n_tile = 64
    k_tile = 64

    # Start from a zero accumulator: C is allocated with torch.empty and is
    # uninitialized, so loading it here would fold garbage into the result.
    result = ct.zeros((m_tile, n_tile), dtype=torch.float32)
    # For these shapes (A is 64 x 4096, B is 4096 x 64) the i and j loops each
    # run exactly once, so a single accumulator and a single store suffice.
    for i in range(A.shape[0] // m_tile):
        for j in range(B.shape[1] // n_tile):
            for k in range(A.shape[1] // k_tile):
                A_block = ct.load(A, index=(i, k), shape=(m_tile, k_tile))
                B_block = ct.load(B, index=(k, j), shape=(k_tile, n_tile))
                result = ct.mma(A_block, B_block, acc=result)

    ct.store(C, index=(0, 0), tile=result)

if __name__ == "__main__":
    main()
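
The TFLOPs figures above compare the FP16 kernel against an FP32 variant that is not reproduced here. As a sketch of how such numbers can be obtained, the helper below times repeated launches with CUDA events and converts to TFLOPs via the 2·M·N·K FLOP count; measure_tflops and its arguments are illustrative, not the course's grading harness:

# Sketch: timing a cuTile kernel launch with CUDA events and converting to TFLOPs.
# `kernel`, `args` and `grid` are assumed to be set up as in main() above.
import torch
import cuda.tile as ct

def measure_tflops(kernel, args, grid, M, N, K, iters=100):
    stream = torch.cuda.current_stream()
    # Warm-up launch so compilation is not included in the timing.
    ct.launch(stream, grid, kernel, args)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        ct.launch(stream, grid, kernel, args)
    end.record()
    torch.cuda.synchronize()

    seconds = start.elapsed_time(end) / 1000.0 / iters  # elapsed_time is in ms
    flops = 2.0 * M * N * K                             # multiply-adds of a matmul
    return flops / seconds / 1e12

For the matrices above, M = N = 64 and K = 4096.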

Task 2: Simple Matrix Multiplication Kernel

import math

import cuda.tile as ct
import torch

def main():
    M = 321
    N = 123
    K = 23
    m_tile=64
    n_tile=32
    k_tile=128

    # Round each dimension up to the next power of two that is at least one tile
    # wide, so the padded problem is evenly divisible by the tile shape.
    M_padded = int(2**math.ceil(math.log2(max(M, m_tile))))
    N_padded = int(2**math.ceil(math.log2(max(N, n_tile))))
    K_padded = int(2**math.ceil(math.log2(max(K, k_tile))))

    A = torch.randn((M, K), device='cuda', dtype=torch.float16)
    B = torch.randn((K, N), device='cuda', dtype=torch.float16)
    C = torch.empty((M, N), device='cuda', dtype=torch.float32)

    grid = (math.ceil(M_padded / m_tile) * math.ceil(N_padded / n_tile), )
    torch.cuda.init()
    ct.launch(torch.cuda.current_stream(), grid, kernel_fp16, (A, B, C, m_tile, n_tile, k_tile, M_padded, N_padded, K_padded))
    torch.cuda.synchronize()

    expected = torch.empty((M, N), device='cuda', dtype=torch.float16)
    torch.matmul(A, B, out=expected)
    expected = expected.to(torch.float32)  # Convert to float32 for comparison
    # print("Expected:\n", expected)
    # print("Actual:\n", C)
    assert torch.allclose(C, expected, atol=1e-1), "The result is incorrect!"


@ct.kernel
def kernel_fp16(A, B, C, m_tile: ct.Constant[int], n_tile: ct.Constant[int], k_tile: ct.Constant[int],
                M_padded: ct.Constant[int], N_padded: ct.Constant[int], K_padded: ct.Constant[int]):
    # Decompose the 1D block index into a (row, column) position in the tile grid.
    bid = ct.bid(0)
    bid_x = bid % (N_padded // n_tile)
    bid_y = bid // (N_padded // n_tile)
    result = ct.zeros((m_tile, n_tile), dtype=torch.float32)
    for k in range(0, K_padded // k_tile):
        A_block = ct.load(A, index=(bid_y, k), shape=(m_tile, k_tile), padding_mode=ct.PaddingMode.ZERO)
        B_block = ct.load(B, index=(k, bid_x), shape=(k_tile, n_tile), padding_mode=ct.PaddingMode.ZERO)
        result = ct.mma(A_block, B_block, acc=result)
    # print("Result in kernel:\n", result)
    ct.store(C, index=(bid_y, bid_x), tile=result)

if __name__ == "__main__":
    main()
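
To see what the launch configuration above works out to for M = 321, N = 123, K = 23 with tiles (64, 32, 128), the arithmetic can be checked in plain Python on the host (no GPU needed); the values in the comments follow directly from the formulas in main():

import math

M, N, K = 321, 123, 23
m_tile, n_tile, k_tile = 64, 32, 128

# Round each dimension up to the next power of two that is at least one tile.
M_padded = int(2 ** math.ceil(math.log2(max(M, m_tile))))   # 512
N_padded = int(2 ** math.ceil(math.log2(max(N, n_tile))))   # 128
K_padded = int(2 ** math.ceil(math.log2(max(K, k_tile))))   # 128

grid = math.ceil(M_padded / m_tile) * math.ceil(N_padded / n_tile)  # 8 * 4 = 32 blocks

# Same bid -> (bid_x, bid_y) decomposition as in the kernel.
for bid in range(grid):
    bid_x = bid % (N_padded // n_tile)   # column of the output tile
    bid_y = bid // (N_padded // n_tile)  # row of the output tile
# e.g. bid = 5 -> (bid_x, bid_y) = (1, 1): second tile row, second tile column

Note that rounding up to a power of two launches more blocks than strictly necessary (8 tile rows here instead of ceil(321/64) = 6); rounding up to the next multiple of the tile size would already be sufficient.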

Task 3: Benchmarking the Matrix Multiplication Kernel

a) Benchmark your kernel with tile shape (64, 64, 64) for square matrix multiplications of the following sizes:

[benchmark plot]

b) Fix the matrix size at 2048 × 2048 × 2048, as well as 512 × 512 × 512, and benchmark all tile shape combinations (27 total):

[benchmark plot]

[benchmark plot]

Output:

BEST tile shape for 512x512x512 is (128, 64, 128) achieving 12.07 TFLOPS
BEST tile shape for 2048x2048x2048 is (128, 128, 64) achieving 54.57 TFLOPS
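
A sweep of this kind can be driven from the host by reusing the Task 2 launch pattern. The sketch below assumes a measure_tflops helper like the one shown after Task 1, the Task 2 kernel_fp16, and a tile-size set of {32, 64, 128} per dimension (which matches the 27 combinations and the reported best shapes, but is an assumption):

# Sketch: sweep all 27 tile-shape combinations and keep the fastest one.
import itertools
import math
import torch

def sweep_tile_shapes(M, N, K, tile_sizes=(32, 64, 128)):
    A = torch.randn((M, K), device='cuda', dtype=torch.float16)
    B = torch.randn((K, N), device='cuda', dtype=torch.float16)
    C = torch.empty((M, N), device='cuda', dtype=torch.float32)

    best = (0.0, None)
    for m_tile, n_tile, k_tile in itertools.product(tile_sizes, repeat=3):  # 27 combos
        M_p = int(2 ** math.ceil(math.log2(max(M, m_tile))))
        N_p = int(2 ** math.ceil(math.log2(max(N, n_tile))))
        K_p = int(2 ** math.ceil(math.log2(max(K, k_tile))))
        grid = (math.ceil(M_p / m_tile) * math.ceil(N_p / n_tile), )
        args = (A, B, C, m_tile, n_tile, k_tile, M_p, N_p, K_p)
        tflops = measure_tflops(kernel_fp16, args, grid, M, N, K)
        if tflops > best[0]:
            best = (tflops, (m_tile, n_tile, k_tile))
    return best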

Task 4: L2 Cache Optimization via Block Swizzling

@ct.kernel
def kernel_matmul_swizzle(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int], tk: ct.Constant[int], grid_x, grid_y):
    swizzle_size = 8
    pid = ct.bid(0)

    # Group the 1D program IDs into horizontal stripes of `swizzle_size` tile rows.
    num_pid_in_block = swizzle_size * grid_y
    block_index = pid // num_pid_in_block
    begin_m = block_index * swizzle_size

    # The last stripe may be shorter than swizzle_size tile rows.
    swizzle = swizzle_size
    if (begin_m + swizzle_size) > grid_x:
        swizzle = grid_x - begin_m

    # Within a stripe, consecutive PIDs walk down a column of tiles before
    # moving on to the next column.
    index_m = begin_m + pid % swizzle
    index_n = (pid // swizzle) % grid_y

    num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk))
    accumulator = ct.full((tm, tn), 0, dtype=ct.float32)

    for k in range(num_tiles_k):
        a = ct.load(A, index=(index_m, k), shape=(tm, tk), padding_mode=ct.PaddingMode.ZERO)
        b = ct.load(B, index=(k, index_n), shape=(tk, tn), padding_mode=ct.PaddingMode.ZERO)
        accumulator = ct.mma(a, b, accumulator)

    ct.store(C, index=(index_m, index_n), tile=accumulator)
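
The kernel receives the tile-grid dimensions as arguments. A host-side launch following the same pattern as Task 2 might look like the sketch below; launch_swizzled_matmul is an illustrative helper, not part of the assignment code:

# Sketch: launching the swizzled kernel, one block per output tile.
import math
import torch
import cuda.tile as ct

def launch_swizzled_matmul(A, B, C, tm, tn, tk):
    M, K = A.shape
    N = B.shape[1]
    grid_x = math.ceil(M / tm)   # number of tile rows of C
    grid_y = math.ceil(N / tn)   # number of tile columns of C
    grid = (grid_x * grid_y, )   # flattened to a 1D launch grid
    ct.launch(torch.cuda.current_stream(), grid, kernel_matmul_swizzle,
              (A, B, C, tm, tn, tk, grid_x, grid_y))
    torch.cuda.synchronize()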

Output:

swizzle_kernel TFLOPs:  68.14132984785671
non_swizzle_kernel TFLOPs:  27.46563761972282

PIDs are mapped onto horizontal 'stripes' across the output matrix, each stripe spanning 8 rows of tiles. Within a stripe, the PIDs traverse the tiles column by column: the first 8 PIDs compute a vertical column of 8 tiles from top to bottom, the next 8 PIDs the adjacent column, and so on. When a stripe is finished, the next stripe starts 8 tile rows further down. For the last stripe, the remaining height (the number of tile rows left) is computed dynamically to prevent out-of-bounds accesses.
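
This mapping can be reproduced in plain Python to make the traversal order visible; grid_x = 12, grid_y = 4 below are illustration values only, and the function simply mirrors the index arithmetic of the kernel:

# Host-side simulation of the kernel's pid -> (index_m, index_n) mapping.
def swizzled_index(pid, grid_x, grid_y, swizzle_size=8):
    num_pid_in_block = swizzle_size * grid_y
    block_index = pid // num_pid_in_block
    begin_m = block_index * swizzle_size
    swizzle = swizzle_size
    if begin_m + swizzle_size > grid_x:          # last stripe may be shorter
        swizzle = grid_x - begin_m
    index_m = begin_m + pid % swizzle
    index_n = (pid // swizzle) % grid_y
    return index_m, index_n

grid_x, grid_y = 12, 4
order = [swizzled_index(pid, grid_x, grid_y) for pid in range(grid_x * grid_y)]
# order[:8]   -> [(0, 0), (1, 0), ..., (7, 0)]: a column of 8 tiles, top to bottom
# order[8:16] -> [(0, 1), ..., (7, 1)]: the next column of the same stripe
# order[32:]  covers the last stripe (tile rows 8..11), whose height is only 4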

[benchmark plot for 512x512x512]

-> BEST tile shape for 512x512x512 is (128, 64, 32) achieving 10.77 TFLOPS

[benchmark plot for 2048x2048x2048]

-> BEST tile shape for 2048x2048x2048 is (128, 128, 64) achieving 54.77 TFLOPS