Skip to content

Commit 51282de

Browse files
committed
fix: Fixed bug with invalid chart types generated
1 parent 5fb6150 commit 51282de

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

backend/dataline/services/llm_flow/llm_calls/chart_generator.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import StrEnum
33

44
from mirascope.core import prompt_template
5-
from pydantic import BaseModel, ValidationInfo, field_validator
5+
from pydantic import BaseModel, model_validator
66

77

88
class ChartType(StrEnum):
@@ -201,22 +201,27 @@ class ChartType(StrEnum):
201201

202202

203203
class GeneratedChart(BaseModel):
204+
chart_type: ChartType
204205
chartjs_json: str
205206

206-
@field_validator("chartjs_json", mode="before")
207+
@model_validator(mode="before")
207208
@classmethod
208-
def check_alphanumeric(cls, v: str | dict, info: ValidationInfo) -> str:
209+
def check_json(cls, data: dict):
210+
v = data["chartjs_json"]
209211
if isinstance(v, dict):
210212
# check chart type - fails if not in valid types
211-
v["type"] = ChartType[v["type"]]
213+
v["type"] = ChartType[data["chart_type"]]
212214
# convert to json str
213215
v = json.dumps(v)
214216

215217
elif isinstance(v, str):
216218
v_dict = json.loads(v)
217219
# check chart type - this fails if not in valid types
218-
v_dict["type"] = ChartType[v_dict["type"]]
219-
return v
220+
v_dict["type"] = ChartType[data["chart_type"]]
221+
# convert back to json str
222+
v = json.dumps(v_dict)
223+
data["chartjs_json"] = v
224+
return data
220225

221226

222227
@prompt_template()

backend/dataline/services/llm_flow/toolkit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def get_response( # type: ignore[misc]
522522
base_url=state.options.openai_base_url,
523523
),
524524
)(
525-
chart_type=ChartType[args["chart_type"]],
525+
chart_type=chart_type,
526526
request=args["request"],
527527
chartjs_template=TEMPLATES[chart_type],
528528
)

0 commit comments

Comments
 (0)