class MatcherCustomOp(ABC):
    """Base class for ops that dispatch to one of two implementations.

    The implementation is chosen once, at construction time, from the
    ``enabled`` flag and stored as ``self.forward``; calling the instance
    delegates to that bound method. The model dtype and the device are read
    from the current vLLM config (each is ``None`` when the corresponding
    sub-config is falsy) and are used by the ``empty*`` tensor helpers.
    """

    def __init__(self, enabled: bool) -> None:
        """Record dtype/device from the current config and pick ``forward``.

        Args:
            enabled: When true, ``forward`` is bound to ``forward_custom``;
                otherwise it is bound to ``forward_native``.
        """
        vllm_config = get_current_vllm_config()

        # Either sub-config may be missing/falsy; fall back to None then.
        if vllm_config.model_config:
            self.model_dtype = vllm_config.model_config.dtype
        else:
            self.model_dtype = None

        if vllm_config.device_config:
            self.device = vllm_config.device_config.device
        else:
            self.device = None

        self.enabled = enabled

        # Bind the chosen implementation once so __call__ does no branching.
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation used as ``forward`` when ``enabled`` is true."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation used as ``forward`` when ``enabled`` is false."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate to the implementation bound to ``self.forward``."""
        return self.forward(*args, **kwargs)

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return ``torch.empty(...)`` using the stored model dtype and device.

        Positional and remaining keyword arguments are passed through as-is.
        """
        dtype, device = self.model_dtype, self.device
        return torch.empty(*args, dtype=dtype, device=device, **kwargs)

    def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return ``torch.empty(...)`` with dtype int64 on the stored device."""
        device = self.device
        return torch.empty(*args, dtype=torch.int64, device=device, **kwargs)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return ``torch.empty(...)`` with dtype float32 on the stored device."""
        device = self.device
        return torch.empty(*args, dtype=torch.float32, device=device, **kwargs)

    def inputs(self) -> list[torch.Tensor]:
        """Utility hook for the inputs to the pattern.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError