Assignment 05: Contraction Interface and L2 Optimization

In this assignment you will build a high-level configuration interface for tensor contractions, implement an optimizer that manipulates those configurations, and use it to derive and benchmark an L2-optimized cuTile kernel.

All code should be written in src/.

We assume the following import conventions:

import cuda.tile as ct
import cupy as cp
import torch
import triton
from dataclasses import dataclass, field

Use FP16 data type for tensor inputs and outputs, accumulate in FP32. We assume row-major order for all tensors.


Task 1: Config Class

Define the following Python types that together represent a tensor contraction configuration as introduced in the lecture.

a) Define enumeration types (e.g. using Python’s enum.Enum or simple class constants) for:

  • DimType: M, N, K, C

  • ExecType: SEQ, PAR, PRIM

  • PrimType: GEMM, BGEMM

  • LastType: NONE, ELWISE_MUL

  • FirstType: ZERO

  • DataType: FLOAT16, FLOAT32

b) Define a Config dataclass with the following fields, matching the interface shown in the lecture:

Field

Type

Description

data_type

DataType

Numeric precision of the operands

prim_main

PrimType

Main (B)GEMM primitive used inside the kernel

prim_last

LastType

Optional elementwise operation applied after the accumulation

prim_first

FirstType

Initialization of the accumulator

dim_types

list[DimType]

Per-dimension index type

exec_types

list[ExecType]

Per-dimension execution strategy

dim_sizes

list[int]

Per-dimension size

strides

list[list[int]]

Per-tensor, per-dimension stride (one inner list per tensor)


Task 2: Generating a Basic Config

Write a function generate_config that takes an einsum string and a list of shapes for the input tensors (the output shape is implied by the einsum) and returns a basic Config.

Requirements:

  • Classify each dimension index automatically by inspecting in which tensors it appears.

  • Compute strides for every tensor assuming row-major layout. A stride of 0 indicates that the dimension does not appear in that tensor.

  • Set all exec_types to SEQ.

  • Set data_type = DataType.FLOAT16, prim_main = PrimType.GEMM, prim_last = LastType.NONE, prim_first = FirstType.ZERO.


Task 3: Optimizer Class

Implement a class Optimizer that wraps a Config and exposes methods to transform it.

a) Implement the function split_dim(dim_id: int, outer_size: int, inner_size: int).

It splits one dimension into two. outer_size * inner_size must equal the original size; raise a ValueError otherwise.

After splitting:

  • Insert two new dimensions at the position of the original dimension.

  • The outer dimension (left) gets size = outer_size.

  • The inner dimension (right) gets size = inner_size.

  • Strides have to be updated accordingly.

  • Both new dimensions inherit dim_type and exec_type from the original.

b) Implement the function fuse_dims(dim_id_a: int, dim_id_b: int).

Fuse two dimensions into a single one. Two dimensions can only be fused if they are adjacent in every tensor they both appear in, i.e., the two dimensions are contiguous in memory (stride[a] == stride[b] * size[b] or stride[a] * size[a] == stride[b]) in every tensor.

Check this condition for all tensors before performing the fusion. Raise a descriptive ValueError if the check fails.

After a valid fusion:

  • The new size is size[a] * size[b].

  • Update the strides lists accordingly.

  • The fused dimension inherits the dim_type and exec_type of dim_id_a.

  • Remove one dimension from all lists.

c) Implement the function permute_dims(permutation: list[int]).

Reorder all per-dimension lists (dim_types, exec_types, dim_sizes, and each tensor’s strides list) according to permutation, following the syntax of torch.permute.

d) Implement the function make_executable().

Set exec types and permute the config’s dimensions so that the config becomes executable via cuTile. Use the parallel execution type where possible. Test the resulting configuration with your verify() function from e).

e) Implement the function verify().

Check that the current configuration is executable. Raise a descriptive ValueError for each violated condition:

  1. No K-dimension may have exec_type = PAR.

  2. All dimensions with exec_type = SEQ must appear to the left of all dimensions with exec_type = PRIM in the config.

  3. All dimensions with exec_type = PAR must appear to the left of all dimensions with exec_type = SEQ in the config.

  4. The rightmost dimensions must be PRIM and the PRIM dimensions must include at least one dimension of each type M, N, and K.


Task 4: L2-Optimized Batched Contraction

Consider the batched matrix multiplication expressed as cmk, ckn -> cmn with dimension sizes $|c| = 4$, $|m| = |n| = |k| = 4096$.

a) Use your generate_config function from Task 2 to produce the initial Config for this contraction. Report the resulting config.

b) Use your Optimizer and the implemented functions from Task 3 to transform the basic config into an L2-optimized one, following the general L2-reuse pattern from the lecture.

config.dim_sizes = [ [...], |m_l2|, |n_l2|, |m_prim|, |n_prim|, |k_prim|]

Choose the sizes for m_l2, m_prim, n_l2, n_prim and justify your choice with respect to L2 cache reuse. Report the final config.

c) Implement the kernel

Implement a cuTile kernel that computes cmk, ckn -> cmn following your optimized config from b). Verify correctness of your kernel.

d) Use triton.testing.do_bench (or a similar benchmark function provided by cuTile/Torch) to measure the average kernel runtime. Report the achieved performance in TFLOPS. Compare the performance of your L2-optimized kernel to a baseline kernel that maps BIDs in plain row-major order over (c, m, n) without any splitting or permuting. Report your findings.