Skip to content

Commit 78ca46d

Browse files
authored
Merge pull request #302 from RamiAwar/S18.1-enable-disable-schemas-and-tables
Toggle schemas and tables. FE, BE, the whole package.
2 parents e2388f3 + e29b2e7 commit 78ca46d

17 files changed

Lines changed: 456 additions & 148 deletions

File tree

backend/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,22 @@ They stay if you ship something you're proud of, you've shipped too late.
9191
```
9292
pre-commit install
9393
```
94+
95+
# Example DBMS-based databases
96+
97+
## SQL Server:
98+
99+
`docker run -p 1433:1433 -e 'ACCEPT_EULA=Y' -e 'SA_PASSWORD=My_password1' -d chriseaton/adventureworks:latest`
100+
DSN: `mssql://SA:My_password1@localhost/AdventureWorks?TrustServerCertificate=yes&driver=ODBC+Driver+18+for+SQL+Server`
101+
102+
## PostgreSQL:
103+
104+
Build & Run image locally: `bash ./scripts/postgres_img_with_sample_data.sh`
105+
106+
DSN: `postgres://postgres:dvdrental@localhost:5432/dvdrental`
107+
108+
## MySQL:
109+
110+
`docker run -p 3306:3306 -d sakiladb/mysql`
111+
112+
DSN: `mysql://sakila:p_ssW0rd@127.0.0.1:3306/sakila`
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Add Connection options column (JSON)
2+
3+
Revision ID: fa9cefccac47
4+
Revises: 3f6e32040035
5+
Create Date: 2024-08-15 13:02:24.137632
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
import sqlalchemy as sa
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = "fa9cefccac47"
17+
down_revision: Union[str, None] = "3f6e32040035"
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
with op.batch_alter_table("connections", schema=None) as batch_op:
25+
batch_op.add_column(sa.Column("options", sa.JSON(), nullable=True))
26+
27+
# ### end Alembic commands ###
28+
29+
30+
def downgrade() -> None:
31+
# ### commands auto generated by Alembic - please adjust! ###
32+
with op.batch_alter_table("connections", schema=None) as batch_op:
33+
batch_op.drop_column("options")
34+
35+
# ### end Alembic commands ###

backend/dataline/api/conversation/router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def execute_sql(
132132
connection = await connection_service.get_connection(session, connection_id)
133133

134134
# Refresh chart data
135-
db = SQLDatabase.from_uri(connection.dsn)
135+
db = SQLDatabase.from_dataline_connection(connection)
136136
query_run_data = execute_sql_query(db, sql)
137137

138138
# Execute query
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, TypedDict
22

3-
from sqlalchemy import Boolean, String
3+
from sqlalchemy import Boolean, JSON, String
44
from sqlalchemy.orm import Mapped, mapped_column, relationship
55

66
from dataline.models.base import DBModel, UUIDMixin
@@ -9,6 +9,21 @@
99
from dataline.models.conversation.model import ConversationModel
1010

1111

12+
class ConnectionSchemaTable(TypedDict):
13+
name: str
14+
enabled: bool
15+
16+
17+
class ConnectionSchema(TypedDict):
18+
name: str
19+
tables: list[ConnectionSchemaTable]
20+
enabled: bool
21+
22+
23+
class ConnectionOptions(TypedDict):
24+
schemas: list[ConnectionSchema]
25+
26+
1227
class ConnectionModel(DBModel, UUIDMixin, kw_only=True):
1328
__tablename__ = "connections"
1429
dsn: Mapped[str] = mapped_column("dsn", String, nullable=False, unique=True)
@@ -17,6 +32,7 @@ class ConnectionModel(DBModel, UUIDMixin, kw_only=True):
1732
type: Mapped[str] = mapped_column("type", String, nullable=False)
1833
dialect: Mapped[str | None] = mapped_column("dialect", String)
1934
is_sample: Mapped[bool] = mapped_column("is_sample", Boolean, nullable=False, default=False, server_default="false")
35+
options: Mapped[ConnectionOptions | None] = mapped_column("options", JSON, nullable=True)
2036

2137
# Relationships
2238
conversations: Mapped[list["ConversationModel"]] = relationship("ConversationModel", back_populates="connection")

backend/dataline/models/connection/schema.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from dataline.config import config
99

1010

11+
class ConnecitonSchemaTable(BaseModel):
12+
name: str
13+
enabled: bool
14+
15+
16+
class ConnectionSchema(BaseModel):
17+
name: str
18+
tables: list[ConnecitonSchemaTable]
19+
enabled: bool
20+
21+
22+
class ConnectionOptions(BaseModel):
23+
schemas: list[ConnectionSchema]
24+
25+
1126
class Connection(BaseModel):
1227
model_config = ConfigDict(from_attributes=True)
1328

@@ -18,6 +33,7 @@ class Connection(BaseModel):
1833
dialect: str
1934
type: str
2035
is_sample: bool
36+
options: Optional[ConnectionOptions] = None
2137

2238

2339
class ConnectionOut(Connection):
@@ -141,6 +157,7 @@ def validate_dsn_format(cls, value: str) -> str:
141157
class ConnectionUpdateIn(BaseModel):
142158
name: Optional[str] = None
143159
dsn: Optional[str] = None
160+
options: Optional[ConnectionOptions] = None
144161

145162
@field_validator("dsn")
146163
def validate_dsn_format(cls, value: str) -> str:

backend/dataline/repositories/connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from dataline.models.connection.model import ConnectionModel
88
from dataline.repositories.base import AsyncSession, BaseRepository
9+
from dataline.models.connection.schema import ConnectionOptions
10+
911

1012
class ConnectionType(Enum):
1113
csv = "csv"
@@ -16,6 +18,7 @@ class ConnectionType(Enum):
1618
snowflake = "snowflake"
1719
sas = "sas"
1820

21+
1922
class ConnectionCreate(BaseModel):
2023
model_config = ConfigDict(from_attributes=True, extra="ignore")
2124

@@ -25,6 +28,7 @@ class ConnectionCreate(BaseModel):
2528
dialect: str
2629
type: str
2730
is_sample: bool = False
31+
options: ConnectionOptions | None = None
2832

2933

3034
class ConnectionUpdate(BaseModel):
@@ -36,6 +40,7 @@ class ConnectionUpdate(BaseModel):
3640
dialect: str | None = None
3741
type: str | None = None
3842
is_sample: bool | None = None
43+
options: ConnectionOptions | None = None
3944

4045

4146
class ConnectionRepository(BaseRepository[ConnectionModel, ConnectionCreate, ConnectionUpdate]):

backend/dataline/repositories/result.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,23 @@ def model(self) -> Type[ResultModel]:
1818
return ResultModel
1919

2020
async def get_dsn_from_result(self, session: AsyncSession, result_id: UUID) -> str:
21+
connection = await self.get_connection_from_result(session, result_id)
22+
return connection.dsn
23+
24+
async def get_connection_from_result(self, session: AsyncSession, result_id: UUID) -> ConnectionModel:
2125
query = (
22-
select(ConnectionModel.dsn)
26+
select(ConnectionModel)
2327
.join(ConversationModel)
2428
.join(MessageModel)
2529
.join(ResultModel)
2630
.where(ResultModel.id == result_id)
2731
)
2832
result = await session.execute(query)
29-
dsn = result.fetchone()
30-
if not dsn:
31-
raise ValueError(f"Could not find DSN for result_id: {result_id}")
33+
connection = result.fetchone()
34+
if not connection:
35+
raise ValueError(f"Could not find connection for result_id: {result_id}")
3236

33-
return dsn[0]
37+
return connection[0]
3438

3539
async def get_chart_from_sql_query(self, session: AsyncSession, sql_string_result_id: UUID) -> ResultModel:
3640
query = (

backend/dataline/services/connection.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99
import pandas as pd
1010
import pyreadstat
1111
from fastapi import Depends, UploadFile
12-
from sqlalchemy import create_engine
1312
from sqlalchemy.exc import OperationalError
1413

1514
from dataline.config import config
1615
from dataline.errors import ValidationError
1716
from dataline.models.connection.model import ConnectionModel
18-
from dataline.models.connection.schema import ConnectionOut, ConnectionUpdateIn
17+
from dataline.models.connection.schema import (
18+
ConnecitonSchemaTable,
19+
ConnectionOptions,
20+
ConnectionOut,
21+
ConnectionSchema,
22+
ConnectionUpdateIn,
23+
)
1924
from dataline.repositories.base import AsyncSession, NotFoundError, NotUniqueError
2025
from dataline.repositories.connection import (
2126
ConnectionCreate,
@@ -24,6 +29,7 @@
2429
ConnectionUpdate,
2530
)
2631
from dataline.services.file_parsers.excel_parser import ExcelParserService
32+
from dataline.services.llm_flow.utils import DatalineSQLDatabase as SQLDatabase
2733
from dataline.utils.utils import (
2834
forward_connection_errors,
2935
generate_short_uuid,
@@ -54,37 +60,29 @@ async def get_connections(self, session: AsyncSession) -> list[ConnectionOut]:
5460
async def delete_connection(self, session: AsyncSession, connection_id: UUID) -> None:
5561
await self.connection_repo.delete_by_uuid(session, connection_id)
5662

57-
async def get_connection_details(self, dsn: str) -> tuple[str, str, str]:
63+
async def get_connection_details(self, dsn: str) -> SQLDatabase:
5864
# Check if connection can be established before saving it
5965
try:
60-
engine = create_engine(dsn)
61-
with engine.connect():
62-
pass
63-
64-
dialect = engine.url.get_dialect().name
65-
database = engine.url.database
66+
db = SQLDatabase.from_uri(dsn)
67+
database = db._engine.url.database
6668

6769
if not database:
6870
raise ValidationError("Invalid DSN. Database name is missing, append '/DBNAME'.")
6971

70-
return dialect, database, dsn
72+
return db
7173

7274
except OperationalError as exc:
7375
# Try again replacing localhost with host.docker.internal to connect with DBs running in docker
7476
if "localhost" in dsn:
7577
dsn = dsn.replace("localhost", "host.docker.internal")
7678
try:
77-
engine = create_engine(dsn)
78-
with engine.connect():
79-
pass
80-
81-
dialect = engine.url.get_dialect().name
82-
database = engine.url.database
79+
db = SQLDatabase.from_uri(dsn)
80+
database = db._engine.url.database
8381

8482
if not database:
8583
raise ValidationError("Invalid DSN. Database name is missing, append '/DBNAME'.")
8684

87-
return dialect, database, dsn
85+
return db
8886
except OperationalError as e:
8987
logger.error(e)
9088
raise ValidationError("Failed to connect to database, please check your DSN.")
@@ -124,10 +122,14 @@ async def update_connection(
124122
raise NotUniqueError("Connection DSN already exists.")
125123

126124
# Check if connection can be established before saving it
127-
dialect, database, dsn = await self.get_connection_details(data.dsn)
128-
update.dsn = dsn
129-
update.database = database
130-
update.dialect = dialect
125+
db = await self.get_connection_details(data.dsn)
126+
update.dsn = str(db._engine.url.render_as_string(hide_password=False))
127+
update.database = db._engine.url.database
128+
update.dialect = db.dialect
129+
# TODO: What do we do for the options? Enable everything?
130+
elif data.options:
131+
# only modify options if dsn hasn't changed
132+
update.options = data.options
131133

132134
if data.name:
133135
update.name = data.name
@@ -140,20 +142,35 @@ async def create_connection(
140142
session: AsyncSession,
141143
dsn: str,
142144
name: str,
143-
type: str | None = None,
145+
connection_type: str | None = None,
144146
is_sample: bool = False,
145147
) -> ConnectionOut:
146148
# Check if connection can be established before saving it
147-
dialect, database, dsn = await self.get_connection_details(dsn)
148-
if not type:
149-
type = dialect
149+
db = await self.get_connection_details(dsn)
150+
if not connection_type:
151+
connection_type = db.dialect
150152

151153
# Check if connection already exists
152154
await self.check_dsn_already_exists(session, dsn)
153-
155+
connection_schemas: list[ConnectionSchema] = [
156+
ConnectionSchema(
157+
name=schema,
158+
tables=[ConnecitonSchemaTable(name=table, enabled=True) for table in tables],
159+
enabled=True,
160+
)
161+
for schema, tables in db._all_tables_per_schema.items()
162+
]
154163
connection = await self.connection_repo.create(
155164
session,
156-
ConnectionCreate(dsn=dsn, database=database, name=name, dialect=dialect, type=type, is_sample=is_sample),
165+
ConnectionCreate(
166+
dsn=dsn,
167+
database=db._engine.url.database,
168+
name=name,
169+
dialect=db.dialect,
170+
type=connection_type,
171+
is_sample=is_sample,
172+
options=ConnectionOptions(schemas=connection_schemas),
173+
),
157174
)
158175
return ConnectionOut.model_validate(connection)
159176

@@ -186,7 +203,9 @@ async def create_csv_connection(self, session: AsyncSession, file: UploadFile, n
186203

187204
# Create connection with the locally copied file
188205
dsn = get_sqlite_dsn(str(file_path.absolute()))
189-
return await self.create_connection(session, dsn=dsn, name=name, type=ConnectionType.csv.value, is_sample=False)
206+
return await self.create_connection(
207+
session, dsn=dsn, name=name, connection_type=ConnectionType.csv.value, is_sample=False
208+
)
190209

191210
async def create_excel_connection(self, session: AsyncSession, file: UploadFile, name: str) -> ConnectionOut:
192211
generated_name = generate_short_uuid() + ".sqlite"
@@ -201,7 +220,7 @@ async def create_excel_connection(self, session: AsyncSession, file: UploadFile,
201220
# Create connection with the locally copied file
202221
dsn = get_sqlite_dsn(str(file_path.absolute()))
203222
return await self.create_connection(
204-
session, dsn=dsn, name=name, type=ConnectionType.excel.value, is_sample=False
223+
session, dsn=dsn, name=name, connection_type=ConnectionType.excel.value, is_sample=False
205224
)
206225

207226
async def create_sas7bdat_connection(self, session: AsyncSession, file: UploadFile, name: str) -> ConnectionOut:
@@ -244,7 +263,7 @@ async def create_sas7bdat_connection(self, session: AsyncSession, file: UploadFi
244263
# Create connection with the locally copied file
245264
dsn = get_sqlite_dsn(str(file_path.absolute()))
246265
return await self.create_connection(
247-
session, dsn=dsn, name=name, type=ConnectionType.sas.value, is_sample=False
266+
session, dsn=dsn, name=name, connection_type=ConnectionType.sas.value, is_sample=False
248267
)
249268
finally:
250269
# Clean up the temporary file

backend/dataline/services/conversation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ async def query(
114114
user_with_model_details = await self.settings_service.get_model_details(session)
115115

116116
# Create query graph
117-
query_graph = QueryGraphService(
118-
dsn=connection.dsn,
119-
)
117+
query_graph = QueryGraphService(connection=connection)
120118
history = await self.get_conversation_history(session, conversation_id)
121119

122120
messages: list[BaseMessage] = []

0 commit comments

Comments
 (0)