Skip to content

vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize

FlashInferA2APrepareAndFinalize

Bases: FusedMoEPrepareAndFinalize

Base class for FlashInfer MoE prepare and finalize operations.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """Prepare/finalize stage built on the FlashInfer all-to-all helpers.

    ``prepare`` forwards the activations and routing tensors to
    ``flashinfer_alltoall_dispatch``; ``finalize`` forwards the expert output
    to ``flashinfer_alltoall_combine``. The metadata returned by the dispatch
    is stored on the instance as ``alltoall_info`` and read back by
    ``finalize``, so ``prepare`` has to run before ``finalize``.
    """

    def __init__(
        self,
        num_dispatchers: int = 1,
    ):
        """Store the dispatcher count and look up the all2all manager.

        Args:
            num_dispatchers: Value later reported by ``num_dispatchers()``.
        """
        super().__init__()
        self.num_dispatchers_ = num_dispatchers
        communicator = get_ep_group().device_communicator
        self.all2all_manager = communicator.all2all_manager

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """Activation format of this stage: always ``Standard``."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """Return ``None``; this class reports no per-rank token limit."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Return ``None``; this class reports no top-k index dtype."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction time."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Return ``False`` unconditionally."""
        return False

    def _apply_router_weight_on_input(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        """Scale ``a1`` in place by the router weights when requested.

        Nothing happens when ``apply_router_weight_on_input`` is false.
        Otherwise ``topk_ids`` must have size 1 along dim 1, and ``a1`` is
        multiplied in place by ``topk_weights`` cast to ``a1``'s dtype.
        """
        if not apply_router_weight_on_input:
            return

        num_selected = topk_ids.size(1)
        assert num_selected == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        scale = topk_weights.to(a1.dtype)
        a1.mul_(scale)

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Optionally weight the input, then run the all-to-all dispatch.

        ``expert_map`` is accepted but not used by this implementation.

        Returns:
            A 5-tuple ``(a1q, a1q_scale, None, topk_ids, topk_weights)`` whose
            non-``None`` entries are the values handed back by
            ``flashinfer_alltoall_dispatch``.
        """
        # May modify ``a1`` in place before it is dispatched.
        self._apply_router_weight_on_input(
            a1, topk_weights, topk_ids, apply_router_weight_on_input
        )

        local_sizes = get_local_sizes()
        experts_per_token = topk_ids.size(1)

        dispatched = flashinfer_alltoall_dispatch(
            self.all2all_manager,
            local_sizes,
            a1,
            quant_config.a1_gscale,
            topk_ids,
            topk_weights,
            experts_per_token,
            num_experts,
            quant_config,
            defer_input_quant=defer_input_quant,
        )
        info, routed_ids, routed_weights, a1q, a1q_scale = dispatched

        # Kept on the instance because finalize() needs it for the combine.
        self.alltoall_info = info

        return a1q, a1q_scale, None, routed_ids, routed_weights

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Run the all-to-all combine and copy its result into ``output``.

        ``topk_weights``, ``apply_router_weight_on_input`` and
        ``weight_and_reduce_impl`` are accepted but not used here. The
        combine relies on the ``alltoall_info`` saved by ``prepare``.
        """
        experts_per_token = topk_ids.size(1)
        num_tokens = output.shape[0]

        combined = flashinfer_alltoall_combine(
            self.all2all_manager,
            fused_expert_output,
            top_k=experts_per_token,
            token_count=num_tokens,
            alltoall_info=self.alltoall_info,
        )
        output.copy_(combined)

_apply_router_weight_on_input

_apply_router_weight_on_input(
    a1: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
) -> None

Apply router weight on input if needed.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
def _apply_router_weight_on_input(
    self,
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> None:
    """Scale ``a1`` in place by the router weights when requested.

    Nothing happens when ``apply_router_weight_on_input`` is false.
    Otherwise ``topk_ids`` must have size 1 along dim 1, and ``a1`` is
    multiplied in place by ``topk_weights`` cast to ``a1``'s dtype.
    """
    if not apply_router_weight_on_input:
        return

    num_selected = topk_ids.size(1)
    assert num_selected == 1, (
        "apply_router_weight_on_input is only implemented for topk=1"
    )
    scale = topk_weights.to(a1.dtype)
    a1.mul_(scale)