99import pandas as pd
1010import pyreadstat
1111from fastapi import Depends , UploadFile
12- from sqlalchemy import create_engine
1312from sqlalchemy .exc import OperationalError
1413
1514from dataline .config import config
1615from dataline .errors import ValidationError
1716from dataline .models .connection .model import ConnectionModel
18- from dataline .models .connection .schema import ConnectionOut , ConnectionUpdateIn
17+ from dataline .models .connection .schema import (
18+ ConnecitonSchemaTable ,
19+ ConnectionOptions ,
20+ ConnectionOut ,
21+ ConnectionSchema ,
22+ ConnectionUpdateIn ,
23+ )
1924from dataline .repositories .base import AsyncSession , NotFoundError , NotUniqueError
2025from dataline .repositories .connection import (
2126 ConnectionCreate ,
2429 ConnectionUpdate ,
2530)
2631from dataline .services .file_parsers .excel_parser import ExcelParserService
32+ from dataline .services .llm_flow .utils import DatalineSQLDatabase as SQLDatabase
2733from dataline .utils .utils import (
2834 forward_connection_errors ,
2935 generate_short_uuid ,
@@ -54,37 +60,29 @@ async def get_connections(self, session: AsyncSession) -> list[ConnectionOut]:
5460 async def delete_connection (self , session : AsyncSession , connection_id : UUID ) -> None :
5561 await self .connection_repo .delete_by_uuid (session , connection_id )
5662
57- async def get_connection_details (self , dsn : str ) -> tuple [ str , str , str ] :
63+ async def get_connection_details (self , dsn : str ) -> SQLDatabase :
5864 # Check if connection can be established before saving it
5965 try :
60- engine = create_engine (dsn )
61- with engine .connect ():
62- pass
63-
64- dialect = engine .url .get_dialect ().name
65- database = engine .url .database
66+ db = SQLDatabase .from_uri (dsn )
67+ database = db ._engine .url .database
6668
6769 if not database :
6870 raise ValidationError ("Invalid DSN. Database name is missing, append '/DBNAME'." )
6971
70- return dialect , database , dsn
72+ return db
7173
7274 except OperationalError as exc :
7375 # Try again replacing localhost with host.docker.internal to connect with DBs running in docker
7476 if "localhost" in dsn :
7577 dsn = dsn .replace ("localhost" , "host.docker.internal" )
7678 try :
77- engine = create_engine (dsn )
78- with engine .connect ():
79- pass
80-
81- dialect = engine .url .get_dialect ().name
82- database = engine .url .database
79+ db = SQLDatabase .from_uri (dsn )
80+ database = db ._engine .url .database
8381
8482 if not database :
8583 raise ValidationError ("Invalid DSN. Database name is missing, append '/DBNAME'." )
8684
87- return dialect , database , dsn
85+ return db
8886 except OperationalError as e :
8987 logger .error (e )
9088 raise ValidationError ("Failed to connect to database, please check your DSN." )
@@ -124,10 +122,14 @@ async def update_connection(
124122 raise NotUniqueError ("Connection DSN already exists." )
125123
126124 # Check if connection can be established before saving it
127- dialect , database , dsn = await self .get_connection_details (data .dsn )
128- update .dsn = dsn
129- update .database = database
130- update .dialect = dialect
125+ db = await self .get_connection_details (data .dsn )
126+ update .dsn = str (db ._engine .url .render_as_string (hide_password = False ))
127+ update .database = db ._engine .url .database
128+ update .dialect = db .dialect
129+ # TODO: What do we do for the options? Enable everything?
130+ elif data .options :
131+ # only modify options if dsn hasn't changed
132+ update .options = data .options
131133
132134 if data .name :
133135 update .name = data .name
@@ -140,20 +142,35 @@ async def create_connection(
140142 session : AsyncSession ,
141143 dsn : str ,
142144 name : str ,
143- type : str | None = None ,
145+ connection_type : str | None = None ,
144146 is_sample : bool = False ,
145147 ) -> ConnectionOut :
146148 # Check if connection can be established before saving it
147- dialect , database , dsn = await self .get_connection_details (dsn )
148- if not type :
149- type = dialect
149+ db = await self .get_connection_details (dsn )
150+ if not connection_type :
151+ connection_type = db . dialect
150152
151153 # Check if connection already exists
152154 await self .check_dsn_already_exists (session , dsn )
153-
155+ connection_schemas : list [ConnectionSchema ] = [
156+ ConnectionSchema (
157+ name = schema ,
158+ tables = [ConnecitonSchemaTable (name = table , enabled = True ) for table in tables ],
159+ enabled = True ,
160+ )
161+ for schema , tables in db ._all_tables_per_schema .items ()
162+ ]
154163 connection = await self .connection_repo .create (
155164 session ,
156- ConnectionCreate (dsn = dsn , database = database , name = name , dialect = dialect , type = type , is_sample = is_sample ),
165+ ConnectionCreate (
166+ dsn = dsn ,
167+ database = db ._engine .url .database ,
168+ name = name ,
169+ dialect = db .dialect ,
170+ type = connection_type ,
171+ is_sample = is_sample ,
172+ options = ConnectionOptions (schemas = connection_schemas ),
173+ ),
157174 )
158175 return ConnectionOut .model_validate (connection )
159176
@@ -186,7 +203,9 @@ async def create_csv_connection(self, session: AsyncSession, file: UploadFile, n
186203
187204 # Create connection with the locally copied file
188205 dsn = get_sqlite_dsn (str (file_path .absolute ()))
189- return await self .create_connection (session , dsn = dsn , name = name , type = ConnectionType .csv .value , is_sample = False )
206+ return await self .create_connection (
207+ session , dsn = dsn , name = name , connection_type = ConnectionType .csv .value , is_sample = False
208+ )
190209
191210 async def create_excel_connection (self , session : AsyncSession , file : UploadFile , name : str ) -> ConnectionOut :
192211 generated_name = generate_short_uuid () + ".sqlite"
@@ -201,7 +220,7 @@ async def create_excel_connection(self, session: AsyncSession, file: UploadFile,
201220 # Create connection with the locally copied file
202221 dsn = get_sqlite_dsn (str (file_path .absolute ()))
203222 return await self .create_connection (
204- session , dsn = dsn , name = name , type = ConnectionType .excel .value , is_sample = False
223+ session , dsn = dsn , name = name , connection_type = ConnectionType .excel .value , is_sample = False
205224 )
206225
207226 async def create_sas7bdat_connection (self , session : AsyncSession , file : UploadFile , name : str ) -> ConnectionOut :
@@ -244,7 +263,7 @@ async def create_sas7bdat_connection(self, session: AsyncSession, file: UploadFi
244263 # Create connection with the locally copied file
245264 dsn = get_sqlite_dsn (str (file_path .absolute ()))
246265 return await self .create_connection (
247- session , dsn = dsn , name = name , type = ConnectionType .sas .value , is_sample = False
266+ session , dsn = dsn , name = name , connection_type = ConnectionType .sas .value , is_sample = False
248267 )
249268 finally :
250269 # Clean up the temporary file
0 commit comments