"""Assignment 01: Tensors and Einsum."""

import torch


# ===========================================================================
# Task 1: Dot Product
# ===========================================================================

  8def dot_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
  9    """Dot product of two vectors a and b."""
 10
 11    assert a.ndim == 1 and b.ndim == 1, "Input tensors must be 1D vectors."
 12    assert a.size() == b.size(), "Input vectors must have the same size."
 13
 14    result = torch.tensor(0.0)
 15
 16    for _a, _b in zip(a, b):
 17        result += _a * _b
 18
 19    return result
 20
 21
# ===========================================================================
# Task 2: Matrix–Matrix Multiplication
# ===========================================================================

 26def matmul_loops(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 27    """Matrix product C = A @ B via nested for loops."""
 28
 29    assert A.ndim == 2 and B.ndim == 2, "Input tensors must be 2D matrices."
 30    m, k = A.shape
 31    k2, n = B.shape
 32    assert k == k2, "Incompatible matrix dimensions"
 33    
 34    C = torch.zeros(m, n)
 35
 36    for i in range(m):
 37        for j in range(n):
 38            sum_ij = 0.0
 39            for p in range(k):
 40                sum_ij += A[i, p] * B[p, j]
 41            C[i, j] = sum_ij
 42
 43    return C
 44
 45
 46def matmul_dot(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 47    """Matrix product C = A @ B via slicing and calls to dot_product."""
 48
 49    assert A.ndim == 2 and B.ndim == 2, "Input tensors must be 2D matrices."
 50    m, k = A.shape
 51    k2, n = B.shape
 52    assert k == k2, "Incompatible matrix dimensions"
 53    
 54    C = torch.zeros(m, n)
 55
 56    # TODO: implement using two for loops and calls to dot_product
 57
 58    for i in range(m):
 59        for j in range(n):
 60            C[i, j] = dot_product(A[i,:], B[:, j])
 61    return C
 62
 63
# ===========================================================================
# Task 3: Einsum  acsxp, bspy -> abcxy
# ===========================================================================

 68def einsum_loops(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 69    """Einsum acsxp, bspy -> abcxy via nested for loops."""
 70
 71    assert A.ndim == 5 and B.ndim == 4, "Input tensors must have the correct number of dimensions."
 72    assert A.size() == torch.Size([2, 4, 5, 4, 3]), "Input tensor A must have shape [2, 4, 5, 4, 3]."
 73    assert B.size() == torch.Size([3, 5, 3, 5]), "Input tensor B must have shape [3, 5, 3, 5]."
 74
 75    size_a, size_c, size_s, size_x, size_p = A.shape
 76    size_b, size_y = B.shape[0], B.shape[3]
 77
 78    C = torch.zeros(size_a, size_b, size_c, size_x, size_y)
 79    # TODO: implement using for loops over all seven index dimensions
 80
 81    for a in range(size_a):
 82        for b in range(size_b):
 83            for c in range(size_c):
 84                for x in range(size_x):
 85                    for y in range(size_y):
 86                        for s in range(size_s):
 87                            for p in range(size_p):
 88                                C[a, b, c, x, y] += A[a, c, s, x, p] * B[b, s, p, y]
 89
 90    return C
 91
 92
 93def einsum_gemm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 94    """Einsum acsxp, bspy -> abcxy via loops over a, b, c, s and a GEMM (xp, py -> xy)."""
 95
 96    assert A.ndim == 5 and B.ndim == 4, "Input tensors must have the correct number of dimensions."
 97    assert A.size() == torch.Size([2, 4, 5, 4, 3]), "Input tensor A must have shape [2, 4, 5, 4, 3]."
 98    assert B.size() == torch.Size([3, 5, 3, 5]), "Input tensor B must have shape [3, 5, 3, 5]."
 99
100    size_a, size_c, size_s, size_x, size_p = A.shape
101    size_b, size_y = B.shape[0], B.shape[3]
102
103    C = torch.zeros(size_a, size_b, size_c, size_x, size_y)
104    for a in range(size_a):
105        for b in range(size_b):
106            for c in range(size_c):
107                for s in range(size_s):
108                    C[a, b, c, :, :] += matmul_loops(A[a, c, s, :, :], B[b, s, :, :])
109
110    return C
111
112
# ===========================================================================
# Task runners
# ===========================================================================

def task1():
    """Validate dot_product against torch.dot on random 128-d vectors."""
    vec_a, vec_b = torch.rand(128), torch.rand(128)

    result_custom = dot_product(vec_a, vec_b)
    result_torch = torch.dot(vec_a, vec_b)
    assert torch.allclose(result_custom, result_torch), (
        f"Task 1 mismatch: custom={result_custom:.6f}, torch={result_torch:.6f}"
    )
    print("Task 1 passed!")


def task2():
    """Validate both loop-based matmul variants against torch.matmul."""
    lhs = torch.rand(8, 32)
    rhs = torch.rand(32, 16)

    result_loops = matmul_loops(lhs, rhs)
    result_dot = matmul_dot(lhs, rhs)
    expected = torch.matmul(lhs, rhs)

    assert torch.allclose(result_loops, expected, atol=1e-5), (
        "Task 2 matmul_loops mismatch!"
    )
    assert torch.allclose(result_dot, expected, atol=1e-5), (
        "Task 2 matmul_dot mismatch!"
    )
    print("Task 2 passed!")


def task3():
    """Validate both einsum implementations against torch.einsum.

    Shapes:
        A: [a, c, s, x, p] = [2, 4, 5, 4, 3]
        B: [b, s, p, y]    = [3, 5, 3, 5]
        C: [a, b, c, x, y] = [2, 3, 4, 4, 5]
    """
    A_ein = torch.rand(2, 4, 5, 4, 3)
    B_ein = torch.rand(3, 5, 3, 5)

    reference = torch.einsum("acsxp, bspy -> abcxy", A_ein, B_ein)
    result_loops = einsum_loops(A_ein, B_ein)
    result_gemm = einsum_gemm(A_ein, B_ein)

    assert torch.allclose(result_loops, reference, atol=1e-5), (
        "Task 3A mismatch: pure-loop einsum differs from reference!"
    )
    assert torch.allclose(result_gemm, reference, atol=1e-5), (
        "Task 3B mismatch: loop+GEMM einsum differs from reference!"
    )
    print("Task 3 passed!")


def main():
    """Run all three assignment tasks in sequence."""
    for runner in (task1, task2, task3):
        runner()


# Script entry point: run every task when executed directly.
if __name__ == "__main__":
    main()