@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
):
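# Cache the expert counts from the config; load_weights() below uses
# them to build the expert parameter mapping for the MoE layers.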
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
self.n_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: dict[str, nn.Parameter],
loaded_params: set[str],
expert_params_mapping: list[tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
"""
Load MoE expert weights.
Args:
name: The name of the weight to load.
loaded_weight: The weight to load.
params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be
generated by SharedFusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight
tensor covering all experts or stored as a separate weight
tensor for each expert.
When fused is True, loaded_weight has shape
[num_experts, hidden_in, hidden_out] for the gate/up/down
projections and [hidden_out, hidden_in] for the others, such
as the router.
When fused is False, loaded_weight has shape
[hidden_out, hidden_in].
Returns:
True if loaded_weight is one of the MoE expert weights and it was
loaded successfully, False otherwise.
"""
# Whether the MoE expert weights are loaded successfully.
expert_param_loaded = False
# If fused is True, the loaded weight is in the layout of:
# [num_experts, hidden_in, hidden_out], so we must transpose the last
# two dimensions to match the expected layout of the parameters.
if fused and loaded_weight.ndim == 3:
loaded_weight = loaded_weight.transpose(-1, -2)
# If the gate_proj and up_proj weights are fused into a single
# weight tensor, we need to split the weight tensor into a tuple
# of two weight tensors along the hidden_out dimension.
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-2)
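# Shape walk-through (illustrative sizes): a fused gate_up_proj of
# shape [num_experts, hidden_size, 2 * intermediate_size] becomes
# [num_experts, 2 * intermediate_size, hidden_size] after the
# transpose above and is then chunked into two tensors of shape
# [num_experts, intermediate_size, hidden_size]: the gate (w1) half
# first and the up (w3) half second, matching the shard_idx selection
# further below.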
# Iterate over all the expert parameters and load the weights if we find
# a match in weight name.
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
# Work on a separate reference so that the per-iteration slicing and
# dtype conversions below do not overwrite loaded_weight, which is
# reused in later iterations.
new_loaded_weight = loaded_weight
# If expert weights are fused into a single weight tensor, remove
# the expert index from the expected weight name.
if fused:
# The string between e_str and proj_str is the expert index.
e_str, _, proj_str, _ = weight_name.split(".")
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
# Skip if the current weight is not one of the MoE weights.
if weight_name not in name:
continue
# Replace the weight name with the parameter name.
full_param_name = name.replace(weight_name, param_name)
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
# Skip if the current weight is for the bias.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
# If the parameter is the fused w13 (gate and up) projection, the
# corresponding loaded weight is a tuple of two tensors, so select
# the correct one based on the shard id ("w1" for gate, "w3" for up).
if "w13" in full_param_name:
assert shard_id in ["w1", "w3"]
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
# If EP (expert parallel) is enabled, extract only the expert
# weights owned by the current EP rank and update expert_id to the
# first global expert index on this rank.
layer_idx = extract_layer_index(name)
expert_map = self.layers[layer_idx].feed_forward.experts.expert_map
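# expert_map maps each global expert index to a local slot on this
# EP rank, or -1 if the expert lives on another rank. For example
# (hypothetical layout), with 16 experts over 4 EP ranks, rank 1
# owns global experts 4..7, so local_expert_indices becomes
# [4, 5, 6, 7] and expert_id is set to 4 below.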
if expert_map is not None:
local_expert_indices = (
(expert_map != -1)
.nonzero()
.flatten()
.to(new_loaded_weight.device)
)
# Workaround for FP8 CPU indexing on older PyTorch:
# https://github.com/vllm-project/vllm/issues/32862
is_fp8_dtype = new_loaded_weight.dtype == (
current_platform.fp8_dtype()
) or (
new_loaded_weight.dtype.is_floating_point
and new_loaded_weight.element_size() == 1
)
if (
new_loaded_weight.device.type == "cpu"
and is_fp8_dtype
and not is_torch_equal_or_newer("2.11.0")
):
# PyTorch < 2.11 doesn't support CPU float8 indexing.
new_loaded_weight = new_loaded_weight.to(torch.float16)[
local_expert_indices
].to(new_loaded_weight.dtype)
else:
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non-fused expert weights.
pass
# Load the weight into the module parameter with corresponding
# shard id and expert id.
weight_loader(
param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Mapping from the stacked parameter name to the checkpoint shard
# name and the corresponding shard id.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
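# For example, a checkpoint weight named
# "layers.0.self_attn.q_proj.weight" is loaded into the "q" shard of
# the stacked parameter "layers.0.self_attn.qkv_proj.weight".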
# Indicate whether the expert weights are fused into a single weight
# tensor.
fused_experts_params = False
# Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor.
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts,
num_redundant_experts=self.n_redundant_experts,
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1,
)
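# In the fused mapping, num_experts is 1 because the checkpoint
# stores a single tensor covering all experts, and gate and up share
# the same checkpoint name ("gate_up_proj"), so both the w1 and w3
# shards map to it.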
# All the module parameters.
params_dict = dict(self.named_parameters())
# The module parameters that have been loaded.
loaded_params: set[str] = set()
# Iterate over all the weights and load them into module parameters.
for name, loaded_weight in weights:
# If the name contains "experts.gate_up_proj" or "experts.down_proj"
# without the expert indices, it means the expert weights are fused
# into a single weight tensor across all experts.
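# For example (illustrative keys), "...feed_forward.experts.gate_up_proj"
# indicates a fused checkpoint, whereas a per-expert checkpoint uses keys
# like "...feed_forward.experts.3.gate_proj.weight".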
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
# If quantization is configured and the weight name corresponds to
# one of the kv cache quantization scales, load it here and move on
# to the next weight.
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Iterate over stacked_params_mapping to check whether the current
# weight is one of the stacked parameters. If so, load it with the
# corresponding shard id. Note that MoE weights are handled
# separately in the for-else block below.
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip if the current weight is not one of the stacked
# parameters or if the current weight is a MoE weight.
if weight_name not in name or "experts" in name:
continue
# For ModelOpt checkpoints, we need to rename the self_attn
# weight/weight_scale names except for kv cache scales.
if not (
name.endswith((".k_scale", ".v_scale")) and "self_attn" in name
):
name = name.replace(weight_name, param_name)
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
# Remap kv cache scale names for ModelOpt checkpoints.
# TODO: ModelOpt should implement get_cache_scale() such that
# kv cache scale name remapping can be done there.
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
# Load the weight into the module parameter with the corresponding
# shard id, then break so the for-else branch below is skipped.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
# Handle normal (non-stacked) weights and MoE weights.
else:
# First, try to load MoE weights using load_moe_expert_weights.
# If successful, move on to next loaded weight.
if self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params,
):
continue
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if is_pp_missing_parameter(name, self):
continue
# Handle flat expert scale parameters that don't match
# per-expert patterns, i.e. one weight scale tensor for all
# experts.
scale_names = [
"w13_input_scale",
"w13_weight_scale",
"w2_input_scale",
"w2_weight_scale",
]
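# These scales are shared by all experts of the layer, so they are
# loaded once with expert_id=0 in the MoE-aware branch below.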
if "experts." in name and any(
scale_name in name for scale_name in scale_names
):
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
# If the weight loader supports specialized MoE loading, use it to
# avoid expensive runtime reflection.
if getattr(weight_loader, "supports_moe_loading", False):
# Map the weight name to the corresponding shard id.
shard_id = "w2" if "w2_" in name else "w1"
# Transpose if weight scales are FP8 block scales with
# three dimensions:
# [num_experts, hidden_in, hidden_out].
if (
name.endswith("weight_scale")
and loaded_weight.dtype == torch.float8_e4m3fn
and loaded_weight.ndim == 3
):
loaded_weight = loaded_weight.transpose(-1, -2)
# Load the weight into the module parameter with
# corresponding shard id and expert id.
weight_loader(
param, loaded_weight, name, shard_id=shard_id, expert_id=0
)
else:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue
# Handle normal (non-stacked, non-MoE) weights.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
# Finally, return the set of loaded parameters.
return loaded_params