Abnormal output when using xgrammar with DeepSeek R1 distilled models #187

Open
distance92 opened this issue Feb 8, 2025 · 2 comments
@distance92
There is no output when using DeepSeek R1's distilled Qwen2.5-7B model; the original Qwen model produces output.

@distance92 (Author) commented Feb 8, 2025

from fastapi import FastAPI, Request
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import uvicorn
import torch
import xgrammar as xgr
import json
import datetime

# Device configuration
DEVICE = "cuda"  # use CUDA
DEVICE_ID = "0"  # CUDA device ID; leave empty to use the default device
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE  # combined device string, e.g. "cuda:0"


# Free GPU memory
def torch_gc():
    if torch.cuda.is_available():  # check that CUDA is available
        with torch.cuda.device(CUDA_DEVICE):  # select the CUDA device
            torch.cuda.empty_cache()  # release cached GPU memory
            torch.cuda.ipc_collect()  # collect fragmented IPC memory


# Create the FastAPI app
app = FastAPI()

# Load the pretrained tokenizer and model
model_name = "/root/model/DeepSeek-R1-Distill-Qwen-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype='auto', device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
full_vocab_size = config.vocab_size

# Compile the built-in JSON grammar
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=full_vocab_size)
grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar: xgr.CompiledGrammar = grammar_compiler.compile_builtin_json_grammar()
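# Note: vocab_size is taken from the model config rather than len(tokenizer),
# because some models pad the embedding table beyond the tokenizer's
# vocabulary, and xgrammar's token mask must match the logits dimension.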


# Endpoint handling POST requests
@app.post("/")
async def generate_json(request: Request):
    global model, tokenizer, compiled_grammar  # declare globals
    json_post = await request.json()  # parse the JSON body of the POST request
    prompt = json_post.get('prompt')  # user prompt
    system_prompt = json_post.get('system')  # optional system prompt

    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}]

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Instantiate the xgrammar LogitsProcessor and pass it to generate()
    xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
    generated_ids = model.generate(
        **model_inputs, max_new_tokens=512, logits_processor=[xgr_logits_processor]
    )

    raw_response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("raw_response: " + repr(raw_response))

    # Strip the prompt tokens from the generated sequences
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    now = datetime.datetime.now()  # current time
    time = now.strftime("%Y-%m-%d %H:%M:%S")  # format the timestamp as a string

    # Build the response JSON
    answer = {
        "response": response,
        "status": 200,
        "time": time
    }

    # Build the log line
    log = "[" + time + '] prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)  # print the log
    torch_gc()  # free GPU memory
    return answer  # return the response


# Main entry point
if __name__ == '__main__':
    uvicorn.run(app, host='172.26.123.154', port=6006, workers=1)  # start the app on the given host and port
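
For reference, the endpoint can be exercised with a request like the following (host and port match the uvicorn.run call above; the 'system' and 'prompt' values are made-up examples):

import requests

# Illustrative request against the endpoint above; the payload values
# are placeholders for testing.
resp = requests.post(
    "http://172.26.123.154:6006/",
    json={
        "system": "You are a helpful assistant that answers in JSON.",
        "prompt": "Give me a short introduction to large language models.",
    },
)
print(resp.json())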

@distance92 distance92 changed the title from "DeepSeek R1 distilled models are not supported" to "Abnormal output when using xgrammar with DeepSeek R1 distilled models" Feb 8, 2025
@Ubospica (Collaborator)
Hi @distance92, thanks for reporting the error! Could you provide more details about your experiment, such as the error message or the output before the exception?
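
For anyone gathering that information, one minimal way to capture it is to run the same prompt with and without the grammar constraint and print both raw outputs. This is just a debugging sketch reusing model, tokenizer, and compiled_grammar from the script above, not an official xgrammar diagnostic; the prompt is a placeholder:

# Compare unconstrained vs. grammar-constrained output for the same prompt.
messages = [{"role": "user", "content": "Return a JSON object describing a cat."}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
prompt_len = model_inputs.input_ids.shape[1]

# 1) Unconstrained: shows what the model "wants" to emit first
#    (R1-distilled models typically open with reasoning tokens such as <think>).
out = model.generate(**model_inputs, max_new_tokens=64)
print("unconstrained:", repr(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=False)))

# 2) Constrained: the JSON grammar forces the first token to begin a JSON
#    value, which may conflict with the model's preferred opening tokens.
processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)
out = model.generate(**model_inputs, max_new_tokens=64, logits_processor=[processor])
print("constrained:", repr(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=False)))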
