1import torch
2
3
4# ===========================================================================
5# Task 1: Dot Product
6# ===========================================================================
7
def dot_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute the dot product of two 1D vectors via an explicit loop.

    Args:
        a: 1D tensor.
        b: 1D tensor with the same number of elements as ``a``.

    Returns:
        A 0-dim tensor holding ``sum_i a[i] * b[i]``, in the promoted
        dtype of the inputs.
    """

    assert a.ndim == 1 and b.ndim == 1, "Input tensors must be 1D vectors."
    assert a.size() == b.size(), "Input vectors must have the same size."

    # Accumulate in the promoted dtype of the inputs: a fixed float32
    # accumulator would break in-place accumulation for float64 vectors.
    result = torch.zeros((), dtype=torch.result_type(a, b))

    for _a, _b in zip(a, b):
        result += _a * _b

    return result
20
21
22# ===========================================================================
23# Task 2: Matrix–Matrix Multiplication
24# ===========================================================================
25
def matmul_loops(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Matrix product C = A @ B via three nested for loops.

    Args:
        A: 2D tensor of shape (m, k).
        B: 2D tensor of shape (k, n).

    Returns:
        Tensor of shape (m, n) in the promoted dtype of A and B.
    """

    assert A.ndim == 2 and B.ndim == 2, "Input tensors must be 2D matrices."
    m, k = A.shape
    k2, n = B.shape
    assert k == k2, "Incompatible matrix dimensions"

    # Allocate the output in the promoted dtype of the inputs so that
    # non-float32 operands (e.g. float64) round-trip without downcasting.
    C = torch.zeros(m, n, dtype=torch.result_type(A, B))

    for i in range(m):
        for j in range(n):
            # C[i, j] is the inner product of row i of A and column j of B.
            sum_ij = 0.0
            for p in range(k):
                sum_ij += A[i, p] * B[p, j]
            C[i, j] = sum_ij

    return C
44
45
def matmul_dot(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Matrix product C = A @ B via slicing and calls to dot_product.

    Args:
        A: 2D tensor of shape (m, k).
        B: 2D tensor of shape (k, n).

    Returns:
        Tensor of shape (m, n).
    """

    assert A.ndim == 2 and B.ndim == 2, "Input tensors must be 2D matrices."
    m, k = A.shape
    k2, n = B.shape
    assert k == k2, "Incompatible matrix dimensions"

    C = torch.zeros(m, n)

    for i in range(m):
        for j in range(n):
            # Each output entry is the dot product of one row of A
            # with one column of B.
            C[i, j] = dot_product(A[i, :], B[:, j])
    return C
62
63
64# ===========================================================================
65# Task 3: Einsum acsxp, bspy -> abcxy
66# ===========================================================================
67
def einsum_loops(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Einsum acsxp, bspy -> abcxy via nested for loops.

    Works for any conforming shapes: A is indexed (a, c, s, x, p) and
    B is indexed (b, s, p, y); s and p are contracted.

    Args:
        A: 5D tensor of shape (a, c, s, x, p).
        B: 4D tensor of shape (b, s, p, y); its s and p dims must match A's.

    Returns:
        Tensor of shape (a, b, c, x, y).
    """

    assert A.ndim == 5 and B.ndim == 4, "Input tensors must have the correct number of dimensions."

    size_a, size_c, size_s, size_x, size_p = A.shape
    size_b, s_b, p_b, size_y = B.shape
    # Only the contracted dimensions need to agree between the operands.
    assert size_s == s_b and size_p == p_b, "Contraction dimensions of A and B must match."

    C = torch.zeros(size_a, size_b, size_c, size_x, size_y)

    # Loop over the five output indices plus the two contracted ones.
    for a in range(size_a):
        for b in range(size_b):
            for c in range(size_c):
                for x in range(size_x):
                    for y in range(size_y):
                        for s in range(size_s):
                            for p in range(size_p):
                                C[a, b, c, x, y] += A[a, c, s, x, p] * B[b, s, p, y]

    return C
91
92
def einsum_gemm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Einsum acsxp, bspy -> abcxy via loops over a, b, c, s and a GEMM (xp, py -> xy).

    Works for any conforming shapes: A is indexed (a, c, s, x, p) and
    B is indexed (b, s, p, y); s and p are contracted.

    Args:
        A: 5D tensor of shape (a, c, s, x, p).
        B: 4D tensor of shape (b, s, p, y); its s and p dims must match A's.

    Returns:
        Tensor of shape (a, b, c, x, y).
    """

    assert A.ndim == 5 and B.ndim == 4, "Input tensors must have the correct number of dimensions."

    size_a, size_c, size_s, size_x, size_p = A.shape
    size_b, s_b, p_b, size_y = B.shape
    # Only the contracted dimensions need to agree between the operands.
    assert size_s == s_b and size_p == p_b, "Contraction dimensions of A and B must match."

    C = torch.zeros(size_a, size_b, size_c, size_x, size_y)
    for a in range(size_a):
        for b in range(size_b):
            for c in range(size_c):
                for s in range(size_s):
                    # The (x, p) slice of A times the (p, y) slice of B is
                    # one GEMM; summing over s completes the contraction.
                    C[a, b, c, :, :] += matmul_loops(A[a, c, s, :, :], B[b, s, :, :])

    return C
111
112
113# ===========================================================================
114# Task runners
115# ===========================================================================
116
def task1():
    """Check the hand-written dot product against torch.dot."""
    size = 128
    lhs, rhs = torch.rand(size), torch.rand(size)

    result_custom = dot_product(lhs, rhs)
    result_torch = torch.dot(lhs, rhs)
    assert torch.allclose(result_custom, result_torch), (
        f"Task 1 mismatch: custom={result_custom:.6f}, torch={result_torch:.6f}"
    )
    print("Task 1 passed!")
127
128
def task2():
    """Check both loop-based matmul variants against torch.matmul."""
    lhs = torch.rand(8, 32)
    rhs = torch.rand(32, 16)
    expected = torch.matmul(lhs, rhs)

    assert torch.allclose(matmul_loops(lhs, rhs), expected, atol=1e-5), (
        "Task 2 matmul_loops mismatch!"
    )
    assert torch.allclose(matmul_dot(lhs, rhs), expected, atol=1e-5), (
        "Task 2 matmul_dot mismatch!"
    )
    print("Task 2 passed!")
143
144
def task3():
    """Check both einsum implementations against torch.einsum."""
    # A has shape [a, c, s, x, p] = [2, 4, 5, 4, 3]
    # B has shape [b, s, p, y] = [3, 5, 3, 5]
    # C has shape [a, b, c, x, y] = [2, 3, 4, 4, 5]
    lhs = torch.rand(2, 4, 5, 4, 3)
    rhs = torch.rand(3, 5, 3, 5)
    reference = torch.einsum("acsxp, bspy -> abcxy", lhs, rhs)

    assert torch.allclose(einsum_loops(lhs, rhs), reference, atol=1e-5), (
        "Task 3A mismatch: pure-loop einsum differs from reference!"
    )
    assert torch.allclose(einsum_gemm(lhs, rhs), reference, atol=1e-5), (
        "Task 3B mismatch: loop+GEMM einsum differs from reference!"
    )
    print("Task 3 passed!")
162
163
def main():
    """Run every task check in order."""
    for task in (task1, task2, task3):
        task()
168
169
170if __name__ == "__main__":
171 main()