import torch
import torch.nn as nn

# Attention hyperparameters
d_model = 512
n_heads = 8
d_k = d_v = d_model // n_heads  # per-head dimension: 512 // 8 = 64
# ScaledDotProductAttention and MultiheadAttentionCustom are the custom attention
# modules defined earlier; `ds` and `create_patch` are likewise assumed to be in scope.
attn = ScaledDotProductAttention(d_model=d_model, n_heads=n_heads)
mha_attn = MultiheadAttentionCustom(d_model, n_heads)
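# For reference, a minimal sketch of what a PatchTST-style scaled dot-product attention
# is assumed to compute: softmax(q @ k / sqrt(d_k)) @ v. This is an illustrative stand-in,
# not the actual ScaledDotProductAttention implementation used above.
def sdp_attention_sketch(q, k, v):
    # q: [bs x n_heads x seq_len x d_k], k: [bs x n_heads x d_k x seq_len],
    # v: [bs x n_heads x seq_len x d_v]
    scores = torch.matmul(q, k) / (q.shape[-1] ** 0.5)  # [bs x n_heads x seq_len x seq_len]
    weights = torch.softmax(scores, dim=-1)             # attention weights per head
    return torch.matmul(weights, v), weights            # output: [bs x n_heads x seq_len x d_v]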
# Per-head projection matrices for the manual scaled dot-product attention path
W_Q = nn.Linear(d_model, d_k * n_heads)
W_K = nn.Linear(d_model, d_k * n_heads)
W_V = nn.Linear(d_model, d_v * n_heads)
# Grab one sample from the dataset and split it into overlapping patches
X, _, _ = ds[0]
X = create_patch(X, patch_len=(10*50), stride=(5*50), constant_pad=True)  # patch_len and stride given in samples
patch_len = X.shape[-1]
X = X[None, ...].permute(0, 2, 1, 3)  # simulate batch size of 1: [bs x n_vars x num_patch x patch_len]
print(f'X input shape: {X.shape}')
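# Assuming PatchTST-style patching, num_patch is roughly (seq_len - patch_len) // stride + 1,
# with constant_pad=True possibly contributing one extra (padded) patch.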
# Project each patch to d_model, then fold the variable dim into the batch dim
# (channel independence): [bs x n_vars x num_patch x d_model] -> [(bs*n_vars) x num_patch x d_model]
W_P = nn.Linear(patch_len, d_model)
X = W_P(X)  # project patches to d_model
print(f"Projected X shape to d_model: {X.shape}")
X = torch.reshape(X, (X.shape[0]*X.shape[1], X.shape[2], X.shape[3]))
print(f"Reshape for attention: {X.shape}")
# Test the multi-head attention module on the first 50 patches only
print("\nTesting multi-head and scaled dot-product attention, with just the first 50 patches.")
mha_output, mha_attn_weights = mha_attn(Q=X[:, :50, :])  # self-attention: only Q is passed
print(f"MHA attention output shape: {mha_output.shape}, MHA attn weight shape: {mha_attn_weights.shape}")
# Test the scaled dot-product attention module with a manual head split
K = Q = V = X  # self-attention: queries, keys and values all come from X
# Linear projections, then split into n_heads heads
bs = X.shape[0]  # merged batch dim = bs * n_vars (e.g. 1 * 16)
q_s = W_Q(Q).reshape(bs, -1, n_heads, d_k).transpose(1, 2)      # [bs x n_heads x num_patch x d_k]
k_s = W_K(K).reshape(bs, -1, n_heads, d_k).permute(0, 2, 3, 1)  # [bs x n_heads x d_k x num_patch]
v_s = W_V(V).reshape(bs, -1, n_heads, d_v).transpose(1, 2)      # [bs x n_heads x num_patch x d_v]
print(f"Q shape: {q_s.shape}, K shape: {k_s.shape}, V shape: {v_s.shape}")
to_out = nn.Linear(n_heads * d_v, d_model)  # output projection back to d_model
# Keep only the first 50 patches; k_s is sliced on its last dim since it is already transposed
output, attn_weights = attn(q_s[:, :, :50, :], k_s[:, :, :, :50], v_s[:, :, :50, :])
output = output.transpose(1, 2).contiguous().view(bs, -1, n_heads * d_v)  # merge heads
output = to_out(output)  # apply the output projection
print(f"Attn output shape: {output.shape}, attn weight shape: {attn_weights.shape}")