量化#

SGLang 支持多种量化方法,包括离线量化和在线动态量化。

离线量化在推理期间直接加载预量化的模型权重。对于 GPTQ 和 AWQ 等量化方法,这是必需的,因为这些方法需要使用校准数据集从原始权重中收集和预计算各种统计信息。

在线量化在运行时动态计算缩放参数,例如模型权重的最大/最小值。与 NVIDIA FP8 训练的延迟缩放机制类似,在线量化会即时计算适当的缩放因子,将高精度权重转换为低精度格式。

注意:为获得更好的性能、可用性和便利性,建议使用离线量化而非在线量化。

如果您使用预量化模型,请不要同时添加 --quantization 来启用在线量化。对于流行的预量化模型,请访问 HuggingFace 上的 UnslothModelCloudNeuralMagic 集合,获取一些经过质量验证的流行量化模型。量化模型必须在量化后通过基准进行验证,以防止异常的量化损失回归。

离线量化#

要加载已经量化的模型,只需加载模型权重和配置即可。再次强调,如果模型已经离线量化,启动引擎时无需添加 --quantization 参数。量化方法将从下载的 HuggingFace 配置中解析。例如,DeepSeek V3/R1 模型已经是 FP8 格式,因此无需添加冗余参数。

python3 -m sglang.launch_server \
    --model-path hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 \
    --port 30000 --host 0.0.0.0

请注意,如果您的模型是按通道量化(INT8 或 FP8)并带有按令牌动态量化激活的,您可以选择包含 --quantization w8a8_int8--quantization w8a8_fp8 来调用 sgl-kernel 中相应的 CUTLASS int8_kernel 或 fp8_kernel。此操作将忽略 HuggingFace 配置的量化设置。例如,对于 neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic,如果您使用 --quantization w8a8_fp8 执行,系统将使用 SGLang 的 W8A8Fp8Config 来调用 sgl-kernel,而不是 vLLM 内核的 CompressedTensorsConfig

python3 -m sglang.launch_server \
    --model-path neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic \
    --quantization w8a8_fp8 \
    --port 30000 --host 0.0.0.0

离线模型量化示例#

使用 Unsloth#

我们强烈建议使用 Unsloth 来量化和加载模型。请参考 SGLang with Unsloth 部署与推理指南

使用 auto-round#

# 安装
pip install auto-round
  • LLM 量化

# 对于 LLM
from auto_round import AutoRound
model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-autoround-4bit"
# 方案示例: "W2A16", "W3A16", "W4A16", "W8A16", "NVFP4", "MXFP4" (无实际内核), "GGUF:Q4_K_M" 等
scheme = "W4A16"
format = "auto_round"
autoround = AutoRound(model_id, scheme=scheme)
autoround.quantize_and_save(quant_path, format=format) # 量化并保存
  • VLM 量化

# 对于 VLMs
from auto_round import AutoRoundMLLM
model_name = "Qwen/Qwen2-VL-2B-Instruct"
quant_path = "Qwen2-VL-2B-Instruct-autoround-4bit"
scheme = "W4A16"
format = "auto_round"
autoround = AutoRoundMLLM(model_name, scheme)
autoround.quantize_and_save(quant_path, format=format) # 量化并保存
  • 命令行使用 (Gaudi/CPU/Intel GPU/CUDA)

auto-round \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --bits 4 \
    --group_size 128 \
    --format "auto_round" \
    --output_dir ./tmp_autoround
  • 已知问题

目前有几个限制影响 sglang 中的离线量化模型加载,这些问题可能会在未来的 sglang 更新中得到解决。如果您遇到任何问题,请考虑使用 Hugging Face Transformers 作为替代方案。

  1. 混合位量化限制

    混合位量化未完全支持。由于 vLLM 的层融合(例如 QKV 融合),在同一融合层中对组件应用不同的位宽可能会导致兼容性问题。

  2. 对量化 MoE 模型的支持有限

    由于内核限制(例如不支持 mlp.gate 层量化),量化 MoE 模型可能会遇到推理问题。请尝试跳过量化这些层以避免此类错误。

  3. 对量化 VLM 的支持有限

    VLM 失败情况

    Qwen2.5-VL-7B

    auto_round:auto_gptq 格式: 准确率接近零。

    GPTQ 格式: 失败并显示:

    输出大小与量化权重形状不匹配
    

    auto_round:auto_awq 和 AWQ 格式: 按预期工作。

使用 GPTQModel#

# 安装
pip install gptqmodel --no-build-isolation -v
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"

calibration_dataset = load_dataset(
    "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
  ).select(range(1024))["text"]

quant_config = QuantizeConfig(bits=4, group_size=128) # 量化配置
model = GPTQModel.load(model_id, quant_config) # 加载模型

model.quantize(calibration_dataset, batch_size=2) # 量化
model.save(quant_path) # 保存模型

使用 LLM Compressor#

# 安装
pip install llmcompressor

这里,我们以将 meta-llama/Meta-Llama-3-8B-Instruct 量化为 FP8 为例,详细说明如何进行离线量化。

from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# 第 1 步:加载原始模型。
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = SparseAutoModelForCausalLM.from_pretrained(
  MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# 第 2 步:执行离线量化。
# 第 2.1 步:配置简单 PTQ 量化。
recipe = QuantizationModifier(
  targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# 第 2.2 步:应用量化算法。
oneshot(model=model, recipe=recipe)

# 第 3 步:保存模型。
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

然后,您可以使用以下命令直接将量化模型与 SGLang 一起使用:

python3 -m sglang.launch_server \
    --model-path $PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic \
    --port 30000 --host 0.0.0.0

使用 NVIDIA ModelOpt#

NVIDIA Model Optimizer (ModelOpt) 提供针对 NVIDIA 硬件优化的高级量化技术。SGLang 包含一个简化的工作流,用于使用 ModelOpt 量化模型并自动导出以供部署。

安装#

首先安装 ModelOpt。您可以直接安装,也可以作为 SGLang 的可选依赖项安装:

# 选项 1: 直接安装 ModelOpt
pip install nvidia-modelopt

# 选项 2: 安装支持 ModelOpt 的 SGLang (推荐)
pip install sglang[modelopt]
量化和导出工作流#

SGLang 提供了一个示例脚本,演示了完整的 ModelOpt 量化和导出工作流:

# 使用 ModelOpt FP8 量化并导出模型
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --export-dir ./quantized_tinyllama_fp8 \
    --quantization-method modelopt_fp8

# 对于 FP4 量化
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --export-dir ./quantized_tinyllama_fp4 \
    --quantization-method modelopt_fp4
可用量化方法#
  • modelopt_fp8: FP8 量化,在 NVIDIA Hopper 和 Blackwell GPU 上具有最佳性能

  • modelopt_fp4: FP4 量化,在 NVIDIA Blackwell GPU 上具有最佳性能

Python API 使用#

您也可以以编程方式使用 ModelOpt 量化:

import sglang as sgl
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

# 配置带 ModelOpt 量化和导出的模型
model_config = ModelConfig(
    model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    quantization="modelopt_fp8",  # 或 "modelopt_fp4"
    trust_remote_code=True,
)

load_config = LoadConfig(
    modelopt_export_path="./exported_model",
    modelopt_checkpoint_save_path="./checkpoint.pth",  # 可选,伪量化检查点
)
device_config = DeviceConfig(device="cuda")

# 加载和量化模型 (导出自动完成)
model_loader = get_model_loader(load_config, model_config)
quantized_model = model_loader.load_model(
    model_config=model_config,
    device_config=device_config,
)
部署量化模型#

量化和导出后,您可以使用 SGLang 部署模型:

# 部署导出的量化模型
python -m sglang.launch_server \
    --model-path ./quantized_tinyllama_fp8 \
    --quantization modelopt \
    --port 30000 --host 0.0.0.0

或使用 Python API:

import sglang as sgl

# 部署导出的 ModelOpt 量化模型
llm = sgl.Engine(
    model_path="./quantized_tinyllama_fp8",
    quantization="modelopt"
)

# 运行推理
prompts = ["你好,你好吗?", "法国的首都是什么?"]
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}
outputs = llm.generate(prompts, sampling_params)

for i, output in enumerate(outputs):
    print(f"提示: {prompts[i]}")
    print(f"输出: {output.outputs[0].text}")
高级功能#

检查点管理: 保存和恢复伪量化检查点以便重用:

# 量化期间保存伪量化检查点
python examples/usage/modelopt_quantize_and_export.py quantize \
    --model-path meta-llama/Llama-3.2-1B-Instruct \
    --export-dir ./quantized_model \
    --quantization-method modelopt_fp8 \
    --checkpoint-save-path ./my_checkpoint.pth

# 该检查点可在未来的量化运行中重用并跳过校准

仅导出工作流: 如果您已有现成的伪量化 ModelOpt 检查点,可以直接导出:

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader

model_config = ModelConfig(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    quantization="modelopt_fp8",
    trust_remote_code=True,
)

load_config = LoadConfig(
    modelopt_checkpoint_restore_path="./my_checkpoint.pth",
    modelopt_export_path="./exported_model",
)

# 加载和导出模型
model_loader = get_model_loader(load_config, model_config)
model_loader.load_model(model_config=model_config, device_config=DeviceConfig())
ModelOpt 的优势#
  • 硬件优化: 专门为 NVIDIA GPU 架构优化

  • 高级量化: 支尖端的 FP8 和 FP4 量化技术

  • 无缝集成: 自动导出到 HuggingFace 格式,便于部署

  • 基于校准: 使用校准数据集实现最佳量化质量

  • 生产就绪: 具有 NVIDIA 支持的企业级量化

在线量化#

要启用在线量化,只需在命令行中指定 --quantization。例如,您可以使用以下命令启动服务器,为模型 meta-llama/Meta-Llama-3.1-8B-Instruct 启用 FP8 量化:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --quantization fp8 \
    --port 30000 --host 0.0.0.0

我们的团队正在支持更多在线量化方法。SGLang 将很快支持包括但不限于 ["awq", "gptq", "marlin", "gptq_marlin", "awq_marlin", "bitsandbytes", "gguf"] 的方法。

SGLang 还支持基于 torchao 的量化方法。您只需在命令行中指定 --torchao-config 即可支持此功能。例如,如果您想为模型 meta-llama/Meta-Llama-3.1-8B-Instruct 启用 int4wo-128,可以使用以下命令启动服务器:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int4wo-128 \
    --port 30000 --host 0.0.0.0

SGLang 支持基于 torchao 的以下量化方法 ["int8dq", "int8wo", "fp8wo", "fp8dq-per_tensor", "fp8dq-per_row", "int4wo-32", "int4wo-64", "int4wo-128", "int4wo-256"]

注意:根据此问题"int8dq" 方法目前在使用 cuda 图捕获时存在一些问题。因此,我们建议在使用 "int8dq" 方法时禁用 cuda 图捕获。即,请使用以下命令:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --disable-cuda-graph \
    --port 30000 --host 0.0.0.0

参考#