-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathapp.py
More file actions
68 lines (56 loc) · 2.29 KB
/
Copy pathapp.py
File metadata and controls
68 lines (56 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import os, inferless
from typing import Optional
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# Presumably opts Hugging Face Hub downloads into the hf_transfer backend
# (requires the `hf_transfer` package to be installed — confirm in the image).
# NOTE(review): huggingface_hub appears to read this variable when its constants
# module is first imported, and transformers is imported above this line —
# confirm the flag actually takes effect, or move this line above the imports.
os.environ.update(HF_HUB_ENABLE_HF_TRANSFER="1")
# Request schema for one chat-generation call. Every field carries a default,
# so an empty request body is valid.
# (Kept as comments rather than a class docstring: pydantic copies a model's
# docstring into its JSON-schema description.)
@inferless.request
class RequestObjects(BaseModel):
    # The user turn sent to the model.
    prompt: str = Field(default="Explain reinforcement learning in simple words.")
    # The system turn placed before the user turn.
    system_prompt: Optional[str] = Field(default="You are a friendly bot")
    # Cap on the number of newly generated tokens.
    max_new_tokens: Optional[int] = Field(default=100)
    # Sampling temperature passed to generate().
    temperature: Optional[float] = Field(default=0.1)
    # Nucleus-sampling threshold passed to generate().
    top_p: Optional[float] = Field(default=1.0)
    # Sample (True) or decode greedily (False).
    do_sample: Optional[bool] = Field(default=True)
# Response schema: the decoded completion text for one request.
@inferless.response
class ResponseObjects(BaseModel):
    # Decoded model output (only the newly generated tokens).
    generated_text: str = Field("Test output")
class InferlessPythonModel:
    """Inferless handler serving a Qwen3 instruct model through transformers.

    ``initialize`` loads the tokenizer and model once, ``infer`` answers a
    single chat request, and ``finalize`` drops the loaded objects.
    """

    # Hugging Face Hub repository loaded by ``initialize``.
    MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"

    def initialize(self):
        """Load the tokenizer and the model weights onto the CUDA device."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            torch_dtype="auto",
            device_map="cuda",
        )

    @staticmethod
    def _build_messages(inputs: RequestObjects) -> list[dict[str, str]]:
        """Return the chat message list for ``inputs``.

        The system turn is omitted when ``system_prompt`` is None or empty;
        otherwise the chat template would receive ``content=None``.
        """
        messages = []
        if inputs.system_prompt:
            messages.append({"role": "system", "content": inputs.system_prompt})
        messages.append({"role": "user", "content": inputs.prompt})
        return messages

    @staticmethod
    def _generation_kwargs(inputs: RequestObjects) -> dict[str, object]:
        """Translate request fields into ``generate()`` keyword arguments.

        - ``None`` fields are left out so the model's generation config applies.
        - ``do_sample`` is coerced to bool (``None`` -> greedy).
        - A non-positive temperature is treated as a request for greedy
          decoding, because transformers rejects temperature <= 0 when sampling.
        - ``temperature``/``top_p`` are only forwarded when sampling; with
          greedy decoding they are ignored and transformers warns about them.
        """
        kwargs: dict[str, object] = {}
        if inputs.max_new_tokens is not None:
            kwargs["max_new_tokens"] = inputs.max_new_tokens

        do_sample = bool(inputs.do_sample)
        temperature = inputs.temperature
        if do_sample and temperature is not None and temperature <= 0:
            do_sample = False
        kwargs["do_sample"] = do_sample

        if do_sample:
            if temperature is not None:
                kwargs["temperature"] = temperature
            if inputs.top_p is not None:
                kwargs["top_p"] = inputs.top_p
        return kwargs

    def infer(self, inputs: RequestObjects) -> ResponseObjects:
        """Generate a reply to ``inputs.prompt`` and return only the new text.

        Args:
            inputs: the validated request (prompt, system prompt, sampling knobs).

        Returns:
            ResponseObjects whose ``generated_text`` is the decoded completion,
            with special tokens stripped and the prompt excluded.
        """
        text = self.tokenizer.apply_chat_template(
            self._build_messages(inputs),
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            generated_ids = self.model.generate(
                **model_inputs,
                **self._generation_kwargs(inputs),
            )

        # The output sequence starts with the prompt tokens; keep only what was generated.
        prompt_len = model_inputs["input_ids"].shape[-1]
        output_ids = generated_ids[0][prompt_len:].tolist()
        content = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        return ResponseObjects(generated_text=content)

    def finalize(self):
        """Drop the model and tokenizer references and release cached GPU memory."""
        self.model = None
        self.tokenizer = None
        # Return blocks cached by PyTorch's CUDA allocator now that the weights are unreferenced.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()