-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Maybe a bug in the TVM inpmelentation #15
Comments
你好,感谢你对我们工作的关注。 你实现反向传播的思路是对的,按照你的思路写出来的代码应该也能够正常运行。只不过我们代码里"k_q_mask"的含义和你想的稍有不同,才造成了“可能是bug”的误会。 先说为什么计算K的梯度时,我们仍然使用的是q_k_mask:因为PAM的注意力机制构成的图是无向图,这意味着第i个query关注了哪些key,第i个key就被哪些query关注。因此,对于每个k_i,反过来找跟它‘结合’过的多个queries,也需要看q_k_mask[i]中的元素。 也正是因为q_k_mask兼任了你的实现中"k_q_mask"的功能,所以我们代码中的k_q_mask和你理解的稍有不同,实际上和你的实现中M的作用相近。k[i, j]的含义是:对于第i个key的第j个query,k[i ,j]存储了序列中第j个点作为query时,序列中第i个点在它关注列表中的索引。 这部分因为索引众多,逻辑确实很绕,希望上面的解释能帮助到你。 另外在完成TVM的代码后我们也实际验证过它训练出的性能,和naive implementation的差别确实是不大的。 |
非常感谢! 我推导的时候,自动带入了更复杂的有向图的情况,忽略了PAM是无向图这个前提条件;有向图在实现上会更复杂一点,因为“对于每个k_i, 反向找跟它结合过的queries”的时候不能再简单地用q_k_mask来进行索引了。 我再仔细检查下自己的计算细节,再次感谢! |
不客气,你实现的有向图的情况可以支持更灵活的attention机制,如果有兴趣单独开一个repo的话或许能帮助到更多的人 |
当然,等完善好相关实现,后续会把代码开源的~ |
TVM里的推导有点复杂,直接讲中文了哈。我仔细读了TVM部分的代码,发现反向传播部分有一个part好像有点问题,希望能和作者讨论一下。
https://github.com/alipay/Pyraformer/blob/84af4dbd93b7b96975b5034f0dde412005260123/pyraformer/hierarchical_mm_tvm.py#L74-L77
这个部分的实现是对应的是反向传播的计算,具体有两种情况,一种是attn=Q*K中反向传播(计算K的梯度),另一种是contex=attn*V中的反向传播(计算V的梯度);从矩阵计算的形式来说,这两者是等价的。所以这里我以第一种情况为例子:
反向计算K的梯度时,对于每个k_i,我们要反过来找跟它‘结合’过的多个queries; 所以条件判断时用q_k_mask[i, k]是不正确的,这样的话依然找的是每个q_i对应的keys。正确的写法应该是:
k_q_mask[i, k]>=0,
X[l, k_q_mask[i, k], q, idx] * Y[l, k_q_mask[i, k], q, j].
对应的解释是:对于每个k_i, 反向找跟它结合过的queries,也即k_q_mask[i, k];这里的X对应的attn的梯度,Y对应的是Q;由于X的维度是[batch, seq_n, heads, max_attn], 对于X的梯度,还需要进一步找到k_i在k_q_mask[i, k]这个query这一行的索引,也即idx。
我推导了一下,idx貌似不能直接使用q_k_mask和k_q_mask直接计算出来,这时我们需要一个新的index matrix来使计算变得简单。假设这个idx matrix为M,那么它应该是:M[k_q_mask[i,j], i] = torch.where(q_k_mask[k_q_mask[i,j]]=i)。这个M可以像k_q_mask那样提前计算出来,然后用于TVM的计算。然后真正的计算就变成了:
k_q_mask[i, k]>=0,
X[l, k_q_mask[i, k], q, M[k_q_mask[i,k], i]] * Y[l, k_q_mask[i, k], q, j]
另外不知道作者在完成TVM的代码部分之后,有没有比较它跑出来的结果跟naive impelmentation的差异大不大,还是只比较了efficiency。
这里的计算确实很绕,以上的推导只建立在我自已的理解上,不保证一定是对的,如果有可能的话,想跟作者仔细讨论一下这里TVM的实现。非常感谢!
The text was updated successfully, but these errors were encountered: