vllm.model_executor.layers.fla.ops.chunk_scaled_dot_kkt
chunk_scaled_dot_kkt_fwd
chunk_scaled_dot_kkt_fwd(
    k: Tensor,
    g: Tensor | None = None,
    beta: Tensor | None = None,
    cu_seqlens: LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: dtype = float32,
) -> Tensor
Compute beta * K * K^T.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| k | Tensor | The key tensor of shape | required |
| beta | Tensor | The beta tensor of shape | None |
| g | Tensor | The cumulative sum of the gate tensor of shape | None |
| cu_seqlens | LongTensor | The cumulative sequence lengths of the input tensor. Default: None. | None |
| chunk_size | int | The chunk size. Default: 64. | 64 |
| output_dtype | dtype | The dtype of the output tensor. Default: float32. | float32 |
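
The `cu_seqlens` and `chunk_size` parameters are described only briefly above. The sketch below shows how cumulative sequence lengths are commonly built for a packed, variable-length batch, and how chunks could be laid out over it. It assumes the usual convention of a leading 0, so that sequence `i` spans `cu_seqlens[i]:cu_seqlens[i + 1]`, and it assumes chunks do not cross sequence boundaries; neither is stated in this reference. `build_cu_seqlens` and `chunk_bounds` are hypothetical helper names, and plain Python lists stand in for the `LongTensor` the signature expects.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Hypothetical helper: prefix sums of sequence lengths with a leading 0."""
    return [0, *accumulate(seq_lens)]


def chunk_bounds(cu_seqlens, chunk_size=64):
    """Hypothetical helper: (start, end) token ranges of each chunk.

    Assumes chunks never cross a sequence boundary, so the last chunk of a
    sequence may be shorter than chunk_size.
    """
    bounds = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for s in range(start, end, chunk_size):
            bounds.append((s, min(s + chunk_size, end)))
    return bounds


cu = build_cu_seqlens([100, 30])
print(cu)                               # [0, 100, 130]
print(chunk_bounds(cu, chunk_size=64))  # [(0, 64), (64, 100), (100, 130)]
```

With two sequences of 100 and 30 tokens, the first sequence splits into a full chunk and a partial one, and the second fits in a single partial chunk.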
Returns:
| Type | Description |
|---|---|
| Tensor | beta * K * K^T of shape |
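
To make "beta * K * K^T" concrete, here is a naive pure-Python sketch for a single head. It assumes the product is formed independently within each chunk of `chunk_size` tokens and that `beta` scales each row of the result. The data layout, the omission of the gate `g`, the absence of any masking, and the helper name `naive_scaled_dot_kkt` are all assumptions made for illustration; the actual kernel may differ on each of these points.

```python
def naive_scaled_dot_kkt(k, beta, chunk_size=64):
    """Hypothetical reference: A[i][j] = beta[i] * dot(k[i], k[c0 + j]).

    k:    list of T key vectors (each a list of floats), single head.
    beta: list of T scalars.
    Returns T rows of length chunk_size; row i holds scores against the
    keys of the chunk containing token i, zero-padded in the last chunk.
    """
    T = len(k)
    out = []
    for i in range(T):
        c0 = (i // chunk_size) * chunk_size  # first token of i's chunk
        c1 = min(c0 + chunk_size, T)         # one past the chunk's last token
        row = [0.0] * chunk_size
        for j in range(c0, c1):
            dot = sum(a * b for a, b in zip(k[i], k[j]))
            row[j - c0] = beta[i] * dot
        out.append(row)
    return out


k = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
beta = [0.5, 1.0, 2.0]
A = naive_scaled_dot_kkt(k, beta, chunk_size=2)
print(A)  # [[0.5, 0.0], [0.0, 4.0], [4.0, 0.0]]
```

Each token only interacts with keys in its own chunk, so the result needs `chunk_size` columns per token rather than a full T-by-T matrix.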