The test case `MultiHeadAttentionTest::test_query_mask_propagation` has been disabled on GPU because it fails with the following error:
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=False,
):
    if bias is not None:
        raise ValueError(
            "torch's `dot_product_attention` doesn't support `bias`."
        )
    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
    if len(query.shape) != 4:
        raise ValueError(
            "`dot_product_attention` only supports 3D and 4D inputs. "
            f"Received: query.shape={query.shape}, key.shape={key.shape}, "
            f"value.shape={value.shape}."
        )
    bias = bias if bias is None else convert_to_tensor(bias)
    mask = mask if mask is None else convert_to_tensor(mask, dtype="bool")
    if mask is not None:
        # Explicit set `is_causal` to `False` when `mask` is not `None`.
        is_causal = False
        mask = torch.where(mask, 0.0, _get_large_negative(query.dtype))
    axis0, axis1 = 1, 2
    query = torch.transpose(query, axis0, axis1)
    key = torch.transpose(key, axis0, axis1)
    value = torch.transpose(value, axis0, axis1)
    if flash_attention:
        is_enabled = is_flash_attention_enabled(
            query=query,
            key=key,
            value=value,
            mask=mask,
            is_causal=is_causal,
        )
        if not is_enabled:
            raise ValueError(
                "Flash attention is not enabled in `torch` backend. "
                "The dtype of the inputs should be float16/bfloat16 "
                "and your GPU should support flash attention implementation."
            )
        with torch.nn.attention.sdpa_kernel(
            backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION],
        ):
            attention_output = torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=mask,
                is_causal=is_causal,
                scale=scale,
            )
    else:
        if mask is not None:
            mask = mask.contiguous()
>       attention_output = torch.nn.functional.scaled_dot_product_attention(
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            attn_mask=mask,
            is_causal=is_causal,
            scale=scale,
        )
E       RuntimeError: (*bias): last dimension must be contiguous

keras/src/backend/torch/nn.py:959: RuntimeError
Comment From: divyashreepathihalli
Issues tracking this bug upstream in PyTorch:
https://github.com/pytorch/pytorch/issues/109607
https://discuss.pytorch.org/t/weird-behavior-of-f-scaled-dot-product-attention/203279
https://github.com/pytorch/pytorch/issues/139424
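For reference, below is a minimal stand-alone sketch of the upstream failure mode described in the linked PyTorch issues, not the Keras test itself: on CUDA, `scaled_dot_product_attention` can reject an `attn_mask` whose last dimension does not have unit stride. The shapes, the transposed zero mask, and the choice to force the memory-efficient backend are illustrative assumptions; whether it actually raises depends on the PyTorch version and on which SDPA backend gets selected.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical reproduction sketch (not the Keras test): trigger the upstream
# stride check by passing an additive mask whose last dimension has non-unit stride.
# Requires a CUDA device; behavior depends on the PyTorch version and SDPA backend.
device, dtype = "cuda", torch.float16
q = torch.randn(1, 8, 16, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 16, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 16, 64, device=device, dtype=dtype)

# Transposing the square zero mask keeps its values but makes the last
# dimension's stride != 1.
mask = torch.zeros(1, 8, 16, 16, device=device, dtype=dtype).transpose(-1, -2)
assert mask.stride(-1) != 1

# Force the memory-efficient backend, i.e. the kernel whose error message
# refers to the mask as `*bias`.
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    # On affected versions this raises:
    # RuntimeError: (*bias): last dimension must be contiguous
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

If this reproduces, it suggests the problem lies in how the memory-efficient kernel validates the mask strides rather than in the Keras-side call, which already calls `.contiguous()` on the mask, query, key, and value before dispatching.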