In the local-to-global attention block of the CrossViewSwapAttention class, I noticed that two rearrange operations are applied to the key tensor (the lines under the Todo comment in the code below). From my understanding, these two operations cancel each other out: the first reshapes the key tensor from the window-partitioned shape into a global feature map, and the second reshapes it back into the original window-partitioned shape. Could you explain the purpose of these operations? Why does the key tensor need to be reshaped twice in this way?
# local-to-local cross-attention
query = rearrange(query, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) # window partition
key = rearrange(key, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # window partition
val = rearrange(val, 'b n d (x w1) (y w2) -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # window partition
query = rearrange(self.cross_win_attend_1(query, key, val,
skip=rearrange(x,
'b d (x w1) (y w2) -> b x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) if self.skip else None),
'b x y w1 w2 d -> b (x w1) (y w2) d') # reverse window to feature; everything is restored to its original shape
query = query + self.mlp_1(self.prenorm_1(query))
x_skip = query
query = repeat(query, 'b x y d -> b n x y d', n=n) # b n x y d
# local-to-global cross-attention
query = rearrange(query, 'b n (x w1) (y w2) d -> b n x y w1 w2 d',
w1=self.q_win_size[0], w2=self.q_win_size[1]) # window partition
# TODO: don't these two operations cancel each other out?
key = rearrange(key, 'b n x y w1 w2 d -> b n (x w1) (y w2) d') # reverse window to feature
key = rearrange(key, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # grid partition
val = rearrange(val, 'b n x y w1 w2 d -> b n (x w1) (y w2) d') # reverse window to feature
val = rearrange(val, 'b n (w1 x) (w2 y) d -> b n x y w1 w2 d',
w1=self.feat_win_size[0], w2=self.feat_win_size[1]) # grid partition
query = rearrange(self.cross_win_attend_2(query,
key,
val,
skip=rearrange(x_skip,
'b (x w1) (y w2) d -> b x y w1 w2 d',
w1=self.q_win_size[0],
w2=self.q_win_size[1])
if self.skip else None),
'b x y w1 w2 d -> b (x w1) (y w2) d') # reverse grid to feature
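
One way to check whether the two rearranges on key cancel is to trace the index arithmetic behind the patterns. The first pattern merges with `(x w1)`, while the second splits with `(w1 x)`, so the axis order inside the parentheses is swapped. In einops, the leftmost name in a parenthesized group is the slow-varying index, so `(x w1)` means `h = x * W1 + w1` and `(w1 x)` means `h = w1 * X + x`. The sketch below reproduces only that arithmetic in plain Python on a toy 4x4 feature map with 2x2 windows; the helper names `window_partition` and `grid_partition` are hypothetical and not part of the repository.

```python
def window_partition(height, width, w1, w2):
    """Pixel groups for '(x w1) (y w2) -> x y w1 w2': h = x*w1 + i, w = y*w2 + j."""
    n_x, n_y = height // w1, width // w2
    return {
        (x, y): [(x * w1 + i, y * w2 + j) for i in range(w1) for j in range(w2)]
        for x in range(n_x) for y in range(n_y)
    }


def grid_partition(height, width, w1, w2):
    """Pixel groups for '(w1 x) (w2 y) -> x y w1 w2': h = i*n_x + x, w = j*n_y + y."""
    n_x, n_y = height // w1, width // w2
    return {
        (x, y): [(i * n_x + x, j * n_y + y) for i in range(w1) for j in range(w2)]
        for x in range(n_x) for y in range(n_y)
    }


# Toy sizes (assumed for illustration): 4x4 feature map, 2x2 windows.
win = window_partition(4, 4, 2, 2)
grid = grid_partition(4, 4, 2, 2)

print(win[(0, 0)])   # [(0, 0), (0, 1), (1, 0), (1, 1)] -> contiguous local patch
print(grid[(0, 0)])  # [(0, 0), (0, 2), (2, 0), (2, 2)] -> strided, spans the whole map
```

Under these assumptions the output shape is the same in both cases, but the pixels that land in each group differ: the window partition groups neighbouring pixels, while the grid partition groups pixels sampled with a stride across the whole feature map. That would make the second pair of rearranges a regrouping rather than a no-op, which is consistent with the "grid partition" comment and the "local-to-global" label in the code.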