The test case `MultiHeadAttentionTest::test_query_mask_propagation` has been disabled on GPU because it fails with the error below.

Error:

    def dot_product_attention(
        query,
        key,
        value,
        bias=None,
        mask=None,
        scale=None,
        is_causal=False,
        flash_attention=False,
    ):
        if bias is not None:
            raise ValueError(
                "torch's `dot_product_attention` doesn't support `bias`."
            )
        query = convert_to_tensor(query)
        key = convert_to_tensor(key)
        value = convert_to_tensor(value)
        if len(query.shape) != 4:
            raise ValueError(
                "`dot_product_attention` only supports 3D and 4D inputs. "
                f"Received: query.shape={query.shape}, key.shape={key.shape}, "
                f"value.shape={value.shape}."
            )
        bias = bias if bias is None else convert_to_tensor(bias)
        mask = mask if mask is None else convert_to_tensor(mask, dtype="bool")
        if mask is not None:
            # Explicit set `is_causal` to `False` when `mask` is not `None`.
            is_causal = False
            mask = torch.where(mask, 0.0, _get_large_negative(query.dtype))

        axis0, axis1 = 1, 2
        query = torch.transpose(query, axis0, axis1)
        key = torch.transpose(key, axis0, axis1)
        value = torch.transpose(value, axis0, axis1)

        if flash_attention:
            is_enabled = is_flash_attention_enabled(
                query=query,
                key=key,
                value=value,
                mask=mask,
                is_causal=is_causal,
            )
            if not is_enabled:
                raise ValueError(
                    "Flash attention is not enabled in `torch` backend. "
                    "The dtype of the inputs should be float16/bfloat16 "
                    "and your GPU should support flash attention implementation."
                )

            with torch.nn.attention.sdpa_kernel(
                backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION],
            ):
                attention_output = torch.nn.functional.scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=mask,
                    is_causal=is_causal,
                    scale=scale,
                )
        else:
            if mask is not None:
                mask = mask.contiguous()
>           attention_output = torch.nn.functional.scaled_dot_product_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                attn_mask=mask,
                is_causal=is_causal,
                scale=scale,
            )
E           RuntimeError: (*bias): last dimension must be contiguous

keras/src/backend/torch/nn.py:959: RuntimeError

Comment From: divyashreepathihalli

Upstream PyTorch reports tracking this bug:

- https://github.com/pytorch/pytorch/issues/109607
- https://discuss.pytorch.org/t/weird-behavior-of-f-scaled-dot-product-attention/203279
- https://github.com/pytorch/pytorch/issues/139424