The Attention module is only called in four places. In every call the mask has new axes for dimensions 2 and 3, and in all four cases the call is wrapped in alphafold.model.mapping.inference_subbatch, but that only affects the sizes, not the number of dimensions.
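For illustration, here is a minimal sketch of the kind of expansion the call sites perform, assuming a caller starts from a per-key mask of shape [batch_size, N_keys] (the starting shape and the variable names are illustrative, not the exact AlphaFold call sites):

```python
import jax.numpy as jnp

batch, n_keys = 2, 16
key_mask = jnp.ones((batch, n_keys), dtype=bool)  # hypothetical per-key mask at a call site

# Insert new axes at dimensions 2 and 3 (counting from 1), i.e. positions 1 and 2:
mask = key_mask[:, None, None, :]                 # -> (batch, 1, 1, N_keys)

print(mask.shape)  # (2, 1, 1, 16): already 4D, so it broadcasts against
                   # logits of shape [batch_size, N_heads, N_queries, N_keys]
```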
The docstring gives incorrect shapes for the arguments mask and nonbatched_bias. Based on the usage, they should be:
mask: A mask for the attention, shape [batch_size, N_heads, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys].
instead of:
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
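As a quick broadcasting check of these shapes against the 4D logits (a minimal sketch with made-up sizes; -1e9 here merely stands in for _SOFTMAX_MASK):

```python
import jax.numpy as jnp

batch, heads, n_q, n_k = 2, 4, 8, 8
logits = jnp.zeros((batch, heads, n_q, n_k))          # shape of the einsum output, bhqk

mask_4d = jnp.ones((batch, 1, n_q, n_k), dtype=bool)  # [batch_size, N_heads (broadcast), N_queries, N_keys]
out = jnp.where(mask_4d, logits, -1e9)                # broadcasts cleanly to (2, 4, 8, 8)

mask_3d = jnp.ones((batch, n_q, n_k), dtype=bool)     # the documented [batch_size, N_queries, N_keys]
# jnp.where(mask_3d, logits, -1e9)                    # fails: (2, 8, 8) does not broadcast
#                                                     # against (2, 4, 8, 8) since 2 != 4
```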
This is clear when looking at where mask is used:
```python
...
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
  logits += jnp.expand_dims(nonbatched_bias, axis=0)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
...
```
as the output of the einsum has shape bhqk.
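The shapes can be confirmed with a standalone check using made-up sizes (this is a sketch, not the AlphaFold module itself):

```python
import jax.numpy as jnp

batch, n_q, n_k, heads, channels = 2, 5, 7, 4, 8
q = jnp.zeros((batch, n_q, heads, channels))   # bqhc
k = jnp.zeros((batch, n_k, heads, channels))   # bkhc

logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
print(logits.shape)  # (2, 4, 5, 7) == [batch_size, N_heads, N_queries, N_keys]

# A nonbatched_bias of shape [N_heads, N_queries, N_keys] gains a leading
# batch axis via expand_dims and then broadcasts against the logits:
nonbatched_bias = jnp.zeros((heads, n_q, n_k))
print(jnp.expand_dims(nonbatched_bias, axis=0).shape)  # (1, 4, 5, 7)
```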
I believe some implementations of attention won't have a head dimension in the mask. Since it is not used in AlphaFold, it might be worth removing it from the mask at the call sites (and adding an expand_dims for the head dimension inside the Attention module). But only changing the docstring is easier, and the current code is still a valid implementation of Attention, so I think that is the way to go.
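For comparison, the alternative described above would amount to something like the following inside the module (a hedged sketch, not a proposed patch; the helper name and the -1e9 stand-in for _SOFTMAX_MASK are illustrative):

```python
import jax.numpy as jnp

_SOFTMAX_MASK = -1e9  # stand-in for the module-level constant

def masked_logits(logits, mask):
  """Sketch: callers pass a 3D mask [batch_size, N_queries, N_keys] and the
  head dimension is added here, inside the attention module, instead."""
  mask = jnp.expand_dims(mask, axis=1)  # -> [batch_size, 1, N_queries, N_keys]
  return jnp.where(mask, logits, _SOFTMAX_MASK)
```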
Linked to the following PR and fork.