Skip to content

vllm.model_executor.models.qwen_vl

Inference-only Qwen-VL model compatible with HuggingFace weights.

QwenImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • ifs: Image feature size (256)
  • hs: Hidden size

hidden_size must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one.

Source code in vllm/model_executor/models/qwen_vl.py
class QwenImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs that are passed as embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size (256)
        - hs: Hidden size

    `hidden_size` must match the hidden size of the language model backbone
    and is stored in the visual config of the model if we have one.
    """

    # Tag identifying this variant of image input.
    type: Literal["image_embeds"] = "image_embeds"
    # Embedding tensor; the middle dimension is fixed at 256 (`ifs` above),
    # while `bn` and `hs` are symbolic.
    data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]

QwenImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height
  • w: Width

Note that image_size is the value in the vision config to which we resize the image in the normalization transform. Currently, multi-image support can only be leveraged by passing image embeddings directly.

Source code in vllm/model_executor/models/qwen_vl.py
class QwenImagePixelInputs(TensorSchema):
    """
    Schema for image inputs that are passed as pixel tensors.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Note that image_size is the value in the vision config to which we resize
    the image in the normalization transform. Currently multi-image support
    can only be leveraged by passing image embeddings directly.
    """

    # Tag identifying this variant of image input.
    type: Literal["pixel_values"] = "pixel_values"
    # Pixel tensor; the channel dimension is fixed at 3, while `bn`, `h` and
    # `w` are symbolic.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]

QwenVLForConditionalGeneration

Bases: QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal

Source code in vllm/model_executor/models/qwen_vl.py
@MULTIMODAL_REGISTRY.register_processor(
    QwenVLMultiModalProcessor,
    info=QwenVLProcessingInfo,
    dummy_inputs=QwenVLDummyInputsBuilder,
)
class QwenVLForConditionalGeneration(
    QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal
):
    """Qwen-VL model that accepts images as pixel values or as embeddings."""

    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": ["w2", "w1"],
    }

    embed_input_ids = SupportsMultiModal.embed_input_ids

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Return the module prefixes of the language model, the connector and
        the vision tower of this multimodal model.
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Build the prompt placeholder for the `i`-th item of `modality`.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return f"Picture {i}: <img></img>"

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QwenVLModel] = QwenVLModel,
    ) -> None:
        """
        Build the base model inside the composite-model marking context.

        Args:
            vllm_config: Configuration forwarded to the base class.
            prefix: Module prefix forwarded to the base class.
            transformer_type: Transformer class forwarded to the base class.
        """
        composite_scope = self._mark_composite_model(
            vllm_config,
            language_targets=QWenBlock,
            tower_targets={"image": VisionTransformer},
        )
        with composite_scope:
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                transformer_type=transformer_type,
            )

        # Annotation only: narrows the attribute type for type checkers.
        self.transformer: QwenVLModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> QwenImageInputs | None:
        """
        Wrap the image-related keyword arguments in a schema object.

        `pixel_values` takes precedence over `image_embeds` when both are
        given. Returns `None` when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is None:
            return QwenImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        # Height and width are both bound to the configured image size.
        image_size = self.config.visual["image_size"]
        return QwenImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": image_size, "w": image_size},
        )

    def _process_image_input(self, image_input: QwenImageInputs) -> torch.Tensor:
        """
        Turn a parsed image input into embeddings.

        Embedding inputs are returned unchanged; pixel inputs are passed
        through `self.transformer.visual`.
        """
        is_embedding = image_input["type"] == "image_embeds"
        payload = image_input["data"]
        return payload if is_embedding else self.transformer.visual(payload)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """
        Compute image embeddings from the keyword arguments.

        Returns an empty list when the arguments carry no image input.
        """
        parsed = self._parse_and_validate_image_input(**kwargs)
        return [] if parsed is None else self._process_image_input(parsed)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run the underlying transformer and return its result.

        When `intermediate_tensors` is provided, `inputs_embeds` is dropped
        before calling the transformer.
        """
        # NOTE(review): presumably intermediate tensors replace the input
        # embeddings on later pipeline-parallel stages; confirm.
        effective_embeds = inputs_embeds if intermediate_tensors is None else None

        return self.transformer(
            input_ids, positions, intermediate_tensors, effective_embeds
        )

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen_vl.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Return the module prefixes of the language model, the connector and
    the vision tower of this multimodal model.
    """
    module_prefixes = {
        "language_model": "transformer.h",
        "connector": "transformer.visual.attn_pool",
        "tower_model": "transformer.visual.transformer",
    }
    return MultiModelKeys.from_string_field(**module_prefixes)

QwenVLMLP

Bases: Module

MLP for the visual component of the Qwen model.

Source code in vllm/model_executor/models/qwen_vl.py
class QwenVLMLP(nn.Module):
    """Feed-forward block (linear, GELU, linear) of the Qwen visual component."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Feature size between the two projections.
            quant_config: Quantization config forwarded to both projections.
            prefix: Module prefix; `.c_fc` / `.c_proj` are appended to it.
        """
        super().__init__()

        # Up-projection: hidden_size -> intermediate_size.
        up_projection = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_fc = up_projection

        self.act_fn = get_act_fn("gelu")

        # Down-projection: intermediate_size -> hidden_size.
        down_projection = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.c_proj = down_projection

    def forward(self, x):
        """Apply `c_fc`, the activation and `c_proj` in sequence."""
        # Both projections return a pair; only the first element is used.
        expanded, _ = self.c_fc(x)
        projected, _ = self.c_proj(self.act_fn(expanded))
        return projected

QwenVLProcessor

This model doesn't define its own HF processor, so we implement our own here.

We call the wrapped tokenizer to automatically insert image pad tokens: https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

The image processor is defined here: https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354

Source code in vllm/model_executor/models/qwen_vl.py
class QwenVLProcessor:
    """
    Processor for Qwen-VL, implemented here because the model does not
    define its own HF processor.

    The wrapped tokenizer is called to automatically insert image pad tokens:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

    The image processor is defined here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        """
        Store the config and tokenizer and build the image transform
        (bicubic resize to a square, tensor conversion, normalization).
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        # Images are resized to a square whose side comes from the visual config.
        side = config.visual["image_size"]

        resize = transforms.Resize(
            (side, side),
            interpolation=InterpolationMode.BICUBIC,
        )
        normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        self.image_transform = transforms.Compose(
            [resize, transforms.ToTensor(), normalize]
        )

    @property
    def image_start_tag(self) -> str:
        """Image start tag as defined by the wrapped tokenizer."""
        return self.tokenizer.image_start_tag  # type: ignore

    @property
    def image_end_tag(self) -> str:
        """Image end tag as defined by the wrapped tokenizer."""
        return self.tokenizer.image_end_tag  # type: ignore

    @property
    def image_pad_tag(self) -> str:
        """Image pad tag as defined by the wrapped tokenizer."""
        return self.tokenizer.image_pad_tag  # type: ignore

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """
        Tokenize `text` and transform `images` into one `BatchFeature`.

        Args:
            text: A single text, a list of texts, or `None` (treated as empty).
            images: A single image, a list of images, or `None` (treated as
                empty).
            return_tensors: Passed to `BatchFeature` as `tensor_type`.

        Returns:
            The tokenizer output, plus a stacked `pixel_values` entry when at
            least one image was given.
        """
        # Normalize both arguments to lists.
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]

        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]

        features = {**self.tokenizer(text)}

        if images:
            transformed = [self.image_transform(img) for img in images]
            features["pixel_values"] = torch.stack(transformed)

        return BatchFeature(features, tensor_type=return_tensors)

VisualAttention

Bases: Module

Self-attention layer class. The layer takes input of size [s, b, h] and returns output of the same size.

Source code in vllm/model_executor/models/qwen_vl.py
class VisualAttention(nn.Module):
    """self-attention layer class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: int | None = None,
        vdim: int | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        assert self._qkv_same_embed_dim, (
            "Visual Attention implementation only supports self-attention"
        )
        self.in_proj = ReplicatedLinear(
            embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
        )
        self.out_proj = ReplicatedLinear(
            embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
        )
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)

        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            attention_probs = torch.baddbmm(
                attn_mask, q_scaled, key_layer.transpose(-2, -1)
            )
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        ).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b,
            self.num_attention_heads_per_partition,
            sq,
            self.hidden_size_per_attention_head,
        )

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output

_get_tokenizer_without_image_pad cached

_get_tokenizer_without_image_pad(
    tokenizer: PreTrainedTokenizer,
) -> PreTrainedTokenizer

The logic of adding image pad tokens should only be applied in QwenVLProcessor, so it is patched out here.

The definition of the wrapped tokenizer can be found here: https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py

Source code in vllm/model_executor/models/qwen_vl.py
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
    tokenizer: PreTrainedTokenizer,
) -> PreTrainedTokenizer:
    """
    Return a deep copy of `tokenizer` whose `tokenize` and `_decode` no
    longer contain the image pad handling.

    That logic should only be applied in
    [`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
    so it is patched out of the copy. The passed-in tokenizer is not
    modified.

    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    original_cls = tokenizer.__class__

    class TokenizerWithoutImagePad(original_cls):  # type: ignore
        def tokenize(
            self,
            text: str,
            allowed_special: Set[str] | str = "all",
            disallowed_special: Collection[str] | str = (),
            **kwargs,
        ) -> list[bytes | str]:
            # NFC-normalize, encode with the inner tokenizer, then map each
            # id back to its token through the decoder table.
            normalized = unicodedata.normalize("NFC", text)
            encoded_ids = self.tokenizer.encode(
                normalized,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )
            return [self.decoder[token_id] for token_id in encoded_ids]

        def _decode(
            self,
            token_ids: int | list[int],
            skip_special_tokens: bool = False,
            errors: str | None = None,
            **kwargs,
        ) -> str:
            # NOTE(review): `skip_special_tokens` is accepted but not used
            # here; confirm that this is intended.
            id_list = [token_ids] if isinstance(token_ids, int) else token_ids
            return self.tokenizer.decode(
                id_list,
                errors=errors or self.errors,
            )

    TokenizerWithoutImagePad.__name__ = f"{original_cls.__name__}WithoutImagePad"

    # Swap the class on a deep copy so the caller's tokenizer stays untouched.
    patched_tokenizer = copy.deepcopy(tokenizer)
    patched_tokenizer.__class__ = TokenizerWithoutImagePad
    return patched_tokenizer