Skip to content

为什么运行 Transformers 示例（Transformers Example）得到的结果是乱码 #9

Description

@czheng1108

模型是从 https://huggingface.co/opendatalab/MinerU-Diffusion-V1-0320-2.5B 下载的，运行的脚本如下：

import torch
from transformers import AutoModel, AutoProcessor, AutoTokenizer

# Local paths to the downloaded checkpoint and the page image to recognize.
model_id = "/data1/models/MinerU-Diffusion-V1-0320-2.5B"
image_path = "/data1/test/page1.png"
device = torch.device("cuda:6")

# All three loaders below are called with trust_remote_code=True.
_remote_kwargs = {"trust_remote_code": True}

tokenizer = AutoTokenizer.from_pretrained(model_id, **_remote_kwargs)
# The slow (non-fast) processor implementation is requested explicitly.
processor = AutoProcessor.from_pretrained(model_id, use_fast=False, **_remote_kwargs)

# Weights are loaded in bfloat16, then the model is switched to eval mode and moved
# onto the target device.
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    **_remote_kwargs,
)
model = model.eval().to(device)

# Chat-formatted request: a system turn, then a user turn carrying the page image and
# the "Text Recognition:" instruction.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": "\nText Recognition:"}]},
]

# Render the chat template to a prompt string. A tuple return is tolerated by keeping
# its first element.
prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)
if isinstance(prompt_text, tuple):
    prompt_text = prompt_text[0]

# Upper bound on prompt length (the value the script previously passed as max_length).
# NOTE(review): confirm the model's actual context limit.
MAX_PROMPT_TOKENS = 4096

# Tokenize WITHOUT truncation. Presumably (Qwen2-VL-style processor, as the
# image_grid_thw output suggests -- confirm) the image is expanded into placeholder
# tokens inside input_ids, so truncating at max_length (right side by default) can drop
# image tokens together with the trailing instruction and generation prompt. That
# leaves input_ids out of sync with pixel_values and the model decodes garbage.
# Fail loudly instead of silently feeding a corrupted prompt.
inputs = processor(
    images=[image_path],
    text=prompt_text,
    return_tensors="pt",
)
prompt_len = inputs["input_ids"].shape[-1]
if prompt_len > MAX_PROMPT_TOKENS:
    raise ValueError(
        f"Prompt has {prompt_len} tokens, exceeding MAX_PROMPT_TOKENS={MAX_PROMPT_TOKENS}; "
        "use a smaller image (or raise the limit if the model supports it) "
        "instead of truncating, which would cut image/prompt tokens."
    )

# Cast and move in a single .to() call per tensor.
input_ids = inputs["input_ids"].to(device=device, dtype=torch.long)
pixel_values = inputs["pixel_values"].to(device=device, dtype=torch.bfloat16)
image_grid_thw = inputs.get("image_grid_thw")
if image_grid_thw is not None:
    image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long)

# Stop markers: handed to generate() as stopping criteria and reused below to trim the
# decoded text.
stop_strings = ("<|endoftext|>", "<|im_end|>")

with torch.no_grad():
    # Decoding parameters forwarded to the checkpoint's generate().
    # NOTE(review): temperature=1.0 presumably samples rather than decoding greedily;
    # for OCR a greedy/low-temperature setting may give more stable text -- confirm
    # which values this generate() accepts before changing it.
    generate_outputs = model.generate(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        mask_token_id=tokenizer.convert_tokens_to_ids("<|MASK|>"),
        gen_length=1024,
        block_length=32,
        denoising_steps=32,
        temperature=1.0,
        remasking_strategy="low_confidence_dynamic",
        dynamic_threshold=0.95,
        tokenizer=tokenizer,
        stopping_criteria=list(stop_strings),
    )

# A tuple result carries the token ids in its first element; otherwise the result is
# the ids themselves.
output_ids = generate_outputs[0] if isinstance(generate_outputs, tuple) else generate_outputs

# Decode the first sequence (special tokens kept), then cut everything from the
# earliest stop marker onward.
text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
for marker in stop_strings:
    text = text.partition(marker)[0]

print(text.strip())

运行结果如下：

(((((((您((是为(((提((((所:
费  
(是(((((是(明(((本险((((加(产品(((((明中费(为费天(险(由(:
您费联付(险(费费(((费(:
((连续(费((费费)((保本()险费费((人费费(:
((的(险(保:
费(“(费费()(三((险保(明险明保费((险())(:
保

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions