vLLMBackend

  • VllmBackend 是 vLLM 自定义的 torch.compile 后端,仅在 CompilationLevel.PIECEWISE 下使用。
  • 它接收 Dynamo 生成的 fx.GraphModule ,按配置切分成若干子图进行编译,并返回一个可调用的拼接图(或带输入复制包装的可调用函数)。
  • 同时把 vLLM 的自定义后处理 pass( PostGradPassManager )注入到 Inductor 配置里,并管理编译缓存目录与命中策略。·
  • 位置:/vllm/vllm/compilation/backends.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class VllmBackend:
"""The compilation backend for `torch.compile` with vLLM.
It is used for compilation level of `CompilationLevel.PIECEWISE`,
where we customize the compilation.

The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.

This backend also adds the PostGradPassManager to Inductor config,
which handles the post-grad passes.
"""

vllm_config: VllmConfig
compilation_config: CompilationConfig
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
# the stiching graph module for all the piecewise graphs
split_gm: fx.GraphModule
piecewise_graphs: list[SplitItem]
returned_callable: Callable
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
):
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
# when multiple parts are initialized as independent
# models, we need to use the model_tag to distinguish
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag

# 初始化后处理pass管理器
self.post_grad_pass_manager = PostGradPassManager()
# 准备运行期需要的成员,用于cudgraph模式下“输入复制”的索引与静态缓冲
self.sym_tensor_indices = []
self.input_buffers = []

self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
# 创建编译器管理器,负责计算编译器相关哈希、初始化/读取写入缓存等
self.compiler_manager: CompilerManager = CompilerManager(
self.compilation_config
)
# torch.compile是JIT调用的,真正的编译行为发生在Dynamo调用该后端的__call__(graph,example_inputs)时,而不是在__init__函数中

@support_torch_compile [decorator]

用法:

  1. 直接装饰: @support_torch_compile ,自动按类型注解推断需要标记为动态的参数维度。

  2. 带参数装饰:@support_torch_compile(dynamic_arg_dims=..., enable_if=...),显式指定需要标记的动态维度,并按条件启用。

无参装饰时,装饰器返回的是一个“类处理函数”;带参装饰时,先构造参数,再返回处理函数。两种最终都会调用内部 _support_torch_compile(cls, dynamic_arg_dims, enable_if)

设备:每个rank都会独立执行模型的forward,因此编译也是在各自的GPU上发生,这部分编译包装会在每个rank上独立发生

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# in vllm/vllm/compilation.decorators.py
'''
作用:为任意包含forward的类(通常是mm.Module)注入torch.compile支持
自动推断或接收显式的“动态维度标注”,在首次调用时标记输入的动态形状,避免因形状变化导致重复编译
通过混入一个包装基类,改写类的__init__与__call__,统一管理编译开关、缓存、日志与调度策略
支持条件启用(enable_if),以及与vllm的编译监控、追踪文件收集,dynamo配置补丁集合
'''
def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
动态维度推断:
- 基于forward的类型注解:(torch.Tensor 、 Optional[torch.Tensor] 、 IntermediateTensors 、 Optional[IntermediateTensors] )默认把第0维视为动态维度
- 若调用方显式传入 dynamic_arg_dims ,则使用该配置;否则按注解推断并记录调试日志。
- 校验:必须至少有一个动态参数维度;且显式指定的参数名必须存在于 forward 形参中,否则抛错。
"""

TorchCompileWrapperWithCustomDispatcher

  1. self.compiled_callable如何得到
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# in vllm/vllm/compilation.decorators.py
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
# 将 TorchCompileWrapperWithCustomDispatcher 混入到原类的 __bases__ ,以注入编译相关能力
# 避免重复装饰已混入的类
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
return cls
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)

old_init = cls.__init__

setattr(cls, IGNORE_COMPILE_KEY, False)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
# 首先调用旧的init函数
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# ...
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level
)
cls.__init__ = __init__

这里的TorchCompileWrapperWithCustomDispatcher

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# compiled_callable 是包装类 TorchCompileWrapperWithCustomDispatcher 在__init__时通过 torch.compile(self.forward, fullgraph=True, backend=..., options=...) 得到的、可调用的已编译函数指针

# vllm/vllm/compilation/wrapper.py
class TorchCompileWrapperWithCustomDispatcher:
""" 专门用于torch.compile的wrapper class,遵守固定的dispatch逻辑"""
def __init__(
self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0
):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
if compiled_callable is None:
# 使用默认compile方法,编译forward函数
backend = vllm_config.compilation_config.init_backend(vllm_config)
options = None
if isinstance(backend, str) and backend == "inductor":
options = (
get_current_vllm_config().compilation_config.inductor_compile_config
)

compiled_callable = torch.compile(
self.forward, fullgraph=True, backend=backend, options=options
)

self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: list[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

# read the env var to determine whether to use the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = (
compilation_level >= CompilationLevel.DYNAMO_ONCE
)

def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""
return self.compiled_callable(*args, **kwargs)
  1. dynamic_arg_dims 是一个 “参数名 → 动态维度索引” 的映射,用来告诉 Dynamo 哪些输入张量的哪些维度在不同调用间会变化,需要被标记为动态形状,以生成可适配不同形状的图(或避免不必要的守卫/重编译)。
  • 在装饰器的 call 首次调用时,根据函数签名绑定实参后,调用 torch._dynamo.mark_dynamic(arg, dims) 对这些入参进行标记。支持负索引(如 -1 表示最后一维)。如果是 IntermediateTensors ,则会对其内部的每个张量按同样规则标记。
  • 该映射只针对“调用时传入的输入张量”,例如 batch 的 input_idspositions 、中间张量等。模型权重参数不会在这里被标记动态。
  • Qwen2/Qwen3 的实际映射(来自源码):
    • vllm/vllm/model_executor/models/qwen2.py 的@support_torch_compile(dynamic_arg_dims=...)
      • input_ids: 0 (batch 或序列的第 0 维)
      • positions: -1 (最后一维;Qwen2-VL 的位置张量维度在 MRoPE 情况下不同)
      • intermediate_tensors: 0
      • inputs_embeds: 0
    • vllm/vllm/model_executor/models/qwen3.py 同样声明了 dynamic_arg_dims (Qwen3 继承自 Qwen2,forward 沿用 Qwen2 的逻辑)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# in vllm/vllm/compilation.decorators.py
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:

# TorchCompileWrapperWithCustomDispatcher...(之前已经展示)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
# 首先调用旧的init函数
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# 读取vllm config
self.vllm_config = vllm_config
enable_compile = enable_if is None or enable_if(vllm_config)
# 计算 self.do_not_compile ,在以下任一条件成立时关闭编译:
# - CompilationLevel 为 NO_COMPILATION 或 DYNAMO_AS_IS
# - 环境不支持 Dynamo( supports_dynamo() 为假)
# - 类被标记忽略
self.do_not_compile = (
vllm_config.compilation_config.level
in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
or not enable_compile
)
# 如果不需要编译。直接返回
if self.do_not_compile:
return
# 若启用编译,增加编译统计计数,并调用混入基类的__init__,将编译级别装入包装器
compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level
)
cls.__init__ = __init__
traces_files(基于当前vllm_config的追踪文件集合)

加入时机与内容:

  • 必定加入:顶层 forward 方法的源码文件,即 self.original_code_object.co_filename (你的日志里就是 …/site-packages/vllm/…/qwen2.py )。
  • 可能加入:Dynamo 追踪期间内联的所有函数的源码文件。装饰器在首次编译时临时“打补丁”了 InliningInstructionTranslator.inline_call ,对每个被内联的函数,执行 self.vllm_config.compilation_config.traced_files.add(code.co_filename)

在 Qwen2/Qwen3 的典型路径下,会看到至少:

  • 顶层: vllm/model_executor/models/qwen2.py (Qwen3 继承未重写 forward,所以日志指向 qwen2.py)
  • 层实现: vllm/model_executor/layers/linear.py vllm/model_executor/layers/layernorm.pyvllm/model_executor/layers/rotary_embedding/...vllm/model_executor/layers/attention/... 等被 Python 层直接调用的函数/方法所在文件
  • 工具/接口:若在 Python 层有函数调用并被 Dynamo 内联(例如某些 utils.py 中的纯 Python 函数),其文件也会加入

集合定义位置: vllm/vllm/config/compilation.py:359traced_files: set[str]

具体集合取决于“首次编译时实际走过的 Python 调用路径”。因此除了“顶层forward文件”是确定的之外,其他文件集合需要以你的具体模型/配置的首轮输入为准。可以在编译后直接查看 get_current_vllm_config().compilation_config.traced_files 获取当前实例的实际集合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# in vllm/vllm/compilation/decorators.py
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:

# 调用阶段(核心路径)
def __call__(self, *args, **kwargs):
# 早退条件
# torch.compiler.is_compiling表明对象正在编译上下文
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
# 首次调用时
if len(self.compiled_codes) < 1:
# inspect.signature绑定实参
sig = inspect.signature(self.__class__.forward)
bound_args = sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
# 依据dynamic_arg_dims标记动态维度
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
# 对 torch.Tensor 用 torch._dynamo.mark_dynamic(arg, dims)
if isinstance(arg, torch.Tensor):
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.mark_dynamic(arg, dims)
# 对 IntermediateTensors 遍历内部张量并标记
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
dims = [
tensor.ndim + dim if dim < 0 else dim for dim in dims
]
torch._dynamo.mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
f" {dims} for argument {k} with type {type(arg)}."
)
# 启动编译监控,打印日志Start compiling function
# original_code_object指向被编译为forward的代码对象
start_monitoring_torch_compile(self.vllm_config)
logger.debug("Start compiling function %s", self.original_code_object)

# 如果还没有任何已捕获的编译产物,或者当前选择不走自定义分发器
# 进入该分支后,将有Dynamo进行一次完整的trace/编译,并直接调用编译好的入口返回结果
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
# 清理dynamo缓存,避免跨实例复用导致的不可控编译(vllm需要精确控制每个实例的编译与缓存,避免不受控的重用导致不一致的图或守卫)
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

# 收集追踪文件以构建重编译触发集合
# 1. 顶层forward所在文件
# 作用:当这些文件变更时触发重编译
""" What's the usage of trace_files? """
self.vllm_config.compilation_config.traced_files.add(
# self.original_code_object就是编译为forward函数的对象
# 所以在trace_files中加入了 forward 函数所在的文件名
self.original_code_object.co_filename
)

# 2. 通过打补丁InliningInstructionTranslator.inline_call
# 在函数内联时收集所有被跟踪函数的 code.co_filename(重点关注!)
# 保留原始
inline_call = InliningInstructionTranslator.inline_call
# 打补丁函数
def patched_inline_call(parent, func, args, kwargs):
code = func.get_code()# 取出code函数
# 加入追踪集中
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
# 调用原始的inline_call(...)完成内联
return inline_call(parent, func, args, kwargs)


# Dynamo配置补丁(减少编译开销)
dynamo_config_patches = {}
try:
# 如果存在enable_cpp_symbolic_shape_guards
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
# 将其设置为false,禁用cpp形状守卫生成
# 生成cpp守卫的加速收益不明显,但编译耗时会增加
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
except AttributeError:
# torch 2.6上配置不存在时捕获error,记录log,不影响编译
logger.debug("enable_cpp_symbolic_shape_guards config not available")

with (
""" What is this function doing? """
patch.object(
InliningInstructionTranslator, "inline_call", patched_inline_call
),
# 在编译期间暂时应用配置
torch._dynamo.config.patch(**dynamo_config_patches),
# 按配置启动cuda graph相关的分区包装(无需求时为no-op)
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
# 为torch2.7做张量子类的临时兼容(避免trace过程中断)
_torch27_patch_tensor_subclasses(),
):
# 触发一次完整的Dynamo/AOT编译流程,生成或复用合适的后端执行路径,随后返回推理结果
output = self.compiled_callable(*args, **kwargs)
return output

# 当自定义分发器为真且已有compile_codes时,类会走另一条路径
# 绕过守卫检查直接执行已经捕获的高校代码
with self.dispatch_to_code(0):
model_output = self.forward(*args, **kwargs)
return model_output

cls.__call__ = __call__
return cls

日志对应时序(与“Start compiling function … qwen2.py …” 日志对齐)

  • 设备与模型加载阶段
    • 设置设备: gpu_worker.py:170 绑定到 cuda:{LOCAL_RANK} 。
    • 加载模型: gpu_model_runner.py 完成权重加载,打印显存与耗时日志。
    • 决策是否包 CUDA Graph: gpu_model_runner.py:2880-2960 。
  • 预热与首次编译触发
    • 预热调用: gpu_worker.py:345+ 遍历 compile_sizes 调用 _dummy_run ,进入模型的 call
    • 首次编译入口: decorators.py:call 检测到 compiled_codes 为空,按 dynamic_arg_dims 标记输入的动态维度,调用 start_monitoring_torch_compile(…) 。
    • 打印编译开始日志: logger.debug(“Start compiling function %s”, self.original_code_object) ,这就是你看到的 “Start compiling function <code object forward at …, file …/qwen2.py, line …>”。
    • 收集追踪文件:同一 call 中,在首次执行 self.compiled_callable(*args, **kwargs) 前打补丁 InliningInstructionTranslator.inline_call ,将顶层 forward 的 co_filename 以及所有内联函数的 co_filename 加入 traced_files 。
    • 编译完成:Dynamo/Inductor 完成编译后, bytecode_hook 把新字节码保存到 self.compiled_codes ,必要时把反编译源码 dump 到 compile_debug_dump_path() 。
  • CUDA Graph 捕获(若启用)
    • 完成预热后: gpu_worker.py:365+ 调用 capture_model() 。如启用 FULL 模式,会做一次全图捕获;如是 piecewise,按照分区逐段包装/捕获。
    • 随后进入实际推理,调度至已编译图或 CUDA Graph。
1
2
3
4
5
self.compiled_callable是什么?怎么得到的?
这个_support_torch_compile在device=cuda,tp=2的配置下,是在哪个设备上执行的?
这里的dynamic_arg_dims是什么东西?是需要进行分配的batch输入吗?还是模型参数?
maybe_use_cudagraph_partition_wrapper
InliningInstructionTranslator是什么

评论