diff --git a/.gitignore b/.gitignore index 06e00a5c..8229dee0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ __pycache__/ animation.screenflow/ README_files/ README.html +.DS_Store +python-package/examples/titanic.db .quarto # Byte-compiled / optimized / DLL files diff --git a/python-package/README.md b/python-package/README.md index be8057ea..9b29fb19 100644 --- a/python-package/README.md +++ b/python-package/README.md @@ -56,7 +56,7 @@ def server(input, output, session): # chat["df"]() reactive. @render.data_frame def data_table(): - return chat["df"]() + return chat.df() # Create Shiny app @@ -171,8 +171,8 @@ which you can then pass via: ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", data_description=Path("data_description.md").read_text() ) ``` @@ -185,8 +185,8 @@ You can add additional instructions of your own to the end of the system prompt, ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", extra_instructions=[ "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️", @@ -218,8 +218,8 @@ def my_chat_func(system_prompt: str) -> chatlas.Chat: my_chat_func = partial(chatlas.ChatAnthropic, model="claude-3-7-sonnet-latest") querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", create_chat_callback=my_chat_func ) ``` diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py new file mode 100644 index 00000000..ac7066d9 --- /dev/null +++ b/python-package/examples/app-database.py @@ -0,0 +1,57 @@ +from pathlib import Path + +from seaborn import load_dataset +from shiny import App, render, ui +from sqlalchemy import create_engine + +import querychat + +# Load titanic data and create SQLite database +db_path = Path(__file__).parent / "titanic.db" +engine = create_engine("sqlite:///" + str(db_path)) + +if not db_path.exists(): + # For example purposes, we'll create the database if it doesn't exist. Don't + # do this in your app! + titanic = load_dataset("titanic") + titanic.to_sql("titanic", engine, if_exists="replace", index=False) + +greeting = (Path(__file__).parent / "greeting.md").read_text() +data_desc = (Path(__file__).parent / "data_description.md").read_text() + +# 1. Configure querychat +querychat_config = querychat.init( + engine, + "titanic", + greeting=greeting, + data_description=data_desc, +) + +# Create UI +app_ui = ui.page_sidebar( + # 2. Place the chat component in the sidebar + querychat.sidebar("chat"), + # Main panel with data viewer + ui.card( + ui.output_data_frame("data_table"), + fill=True, + ), + title="querychat with Python (SQLite)", + fillable=True, +) + + +# Define server logic +def server(input, output, session): + # 3. Initialize querychat server with the config from step 1 + chat = querychat.server("chat", querychat_config) + + # 4. Display the filtered dataframe + @render.data_frame + def data_table(): + # Access filtered data via chat.df() reactive + return chat["df"]() + + +# Create Shiny app +app = App(app_ui, server) diff --git a/python-package/examples/app.py b/python-package/examples/app-dataframe.py similarity index 82% rename from python-package/examples/app.py rename to python-package/examples/app-dataframe.py index 926622ce..1966900f 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app-dataframe.py @@ -7,10 +7,8 @@ titanic = load_dataset("titanic") -with open(Path(__file__).parent / "greeting.md", "r") as f: - greeting = f.read() -with open(Path(__file__).parent / "data_description.md", "r") as f: - data_desc = f.read() +greeting = (Path(__file__).parent / "greeting.md").read_text() +data_desc = (Path(__file__).parent / "data_description.md").read_text() # 1. Configure querychat querychat_config = querychat.init( @@ -43,7 +41,7 @@ def server(input, output, session): @render.data_frame def data_table(): # Access filtered data via chat.df() reactive - return chat["df"]() + return chat.df() # Create Shiny app diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index acd772a2..1ac303bb 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -18,8 +18,9 @@ dependencies = [ "htmltools", "chatlas", "narwhals", + "chevron", + "sqlalchemy>=2.0.0" # Using 2.0+ for improved type hints and API ] - classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3.8", diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py new file mode 100644 index 00000000..d9322ff4 --- /dev/null +++ b/python-package/src/querychat/datasource.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from typing import ClassVar, Protocol + +import duckdb +import narwhals as nw +import pandas as pd +from sqlalchemy import inspect, text +from sqlalchemy.engine import Connection, Engine +from sqlalchemy.sql import sqltypes + + +class DataSource(Protocol): + db_engine: ClassVar[str] + + def get_schema(self, *, categorical_threshold) -> str: + """ + Return schema information about the table as a string. + + Args: + categorical_threshold: Maximum number of unique values for a text + column to be considered categorical + + Returns: + A string containing the schema information in a format suitable for + prompting an LLM about the data structure + + """ + ... + + def execute_query(self, query: str) -> pd.DataFrame: + """ + Execute SQL query and return results as DataFrame. + + Args: + query: SQL query to execute + + Returns: + Query results as a pandas DataFrame + + """ + ... + + def get_data(self) -> pd.DataFrame: + """ + Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + + """ + ... + + +class DataFrameSource: + """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + + db_engine: ClassVar[str] = "DuckDB" + + def __init__(self, df: pd.DataFrame, table_name: str): + """ + Initialize with a pandas DataFrame. + + Args: + df: The DataFrame to wrap + table_name: Name of the table in SQL queries + + """ + self._conn = duckdb.connect(database=":memory:") + self._df = df + self._table_name = table_name + self._conn.register(table_name, df) + + def get_schema(self, *, categorical_threshold: int) -> str: + """ + Generate schema information from DataFrame. + + Args: + table_name: Name to use for the table in schema description + categorical_threshold: Maximum number of unique values for a text column + to be considered categorical + + Returns: + String describing the schema + + """ + ndf = nw.from_native(self._df) + + schema = [f"Table: {self._table_name}", "Columns:"] + + for column in ndf.columns: + # Map pandas dtypes to SQL-like types + dtype = ndf[column].dtype + if dtype.is_integer(): + sql_type = "INTEGER" + elif dtype.is_float(): + sql_type = "FLOAT" + elif dtype == nw.Boolean: + sql_type = "BOOLEAN" + elif dtype == nw.Datetime: + sql_type = "TIME" + elif dtype == nw.Date: + sql_type = "DATE" + else: + sql_type = "TEXT" + + column_info = [f"- {column} ({sql_type})"] + + # For TEXT columns, check if they're categorical + if sql_type == "TEXT": + unique_values = ndf[column].drop_nulls().unique() + if unique_values.len() <= categorical_threshold: + categories = unique_values.to_list() + categories_str = ", ".join([f"'{c}'" for c in categories]) + column_info.append(f" Categorical values: {categories_str}") + + # For numeric columns, include range + elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: + rng = ndf[column].min(), ndf[column].max() + if rng[0] is None and rng[1] is None: + column_info.append(" Range: NULL to NULL") + else: + column_info.append(f" Range: {rng[0]} to {rng[1]}") + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """ + Execute query using DuckDB. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + + """ + return self._conn.execute(query).df() + + def get_data(self) -> pd.DataFrame: + """ + Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + + """ + return self._df.copy() + + +class SQLAlchemySource: + """ + A DataSource implementation that supports multiple SQL databases via SQLAlchemy. + + Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks. + """ + + db_engine: ClassVar[str] = "SQLAlchemy" + + def __init__(self, engine: Engine, table_name: str): + """ + Initialize with a SQLAlchemy engine. + + Args: + engine: SQLAlchemy engine + table_name: Name of the table to query + + """ + self._engine = engine + self._table_name = table_name + + # Validate table exists + inspector = inspect(self._engine) + if not inspector.has_table(table_name): + raise ValueError(f"Table '{table_name}' not found in database") + + def get_schema(self, *, categorical_threshold: int) -> str: + """ + Generate schema information from database table. + + Returns: + String describing the schema + + """ + inspector = inspect(self._engine) + columns = inspector.get_columns(self._table_name) + + schema = [f"Table: {self._table_name}", "Columns:"] + + for col in columns: + # Get SQL type name + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col['name']} ({sql_type})"] + + # For numeric columns, try to get range + if isinstance( + col["type"], + ( + sqltypes.Integer, + sqltypes.Numeric, + sqltypes.Float, + sqltypes.Date, + sqltypes.Time, + sqltypes.DateTime, + sqltypes.BigInteger, + sqltypes.SmallInteger, + # sqltypes.Interval, + ), + ): + try: + query = text( + f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}", + ) + with self._get_connection() as conn: + result = conn.execute(query).fetchone() + if result and result[0] is not None and result[1] is not None: + column_info.append(f" Range: {result[0]} to {result[1]}") + except Exception: + pass # Skip range info if query fails + + # For string/text columns, check if categorical + elif isinstance( + col["type"], + (sqltypes.String, sqltypes.Text, sqltypes.Enum), + ): + try: + count_query = text( + f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}", + ) + with self._get_connection() as conn: + distinct_count = conn.execute(count_query).scalar() + if distinct_count and distinct_count <= categorical_threshold: + values_query = text( + f"SELECT DISTINCT {col['name']} FROM {self._table_name} " + f"WHERE {col['name']} IS NOT NULL", + ) + values = [ + str(row[0]) + for row in conn.execute(values_query).fetchall() + ] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except Exception: + pass # Skip categorical info if query fails + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """ + Execute SQL query and return results as DataFrame. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + + """ + with self._get_connection() as conn: + return pd.read_sql_query(text(query), conn) + + def get_data(self) -> pd.DataFrame: + """ + Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + + """ + return self.execute_query(f"SELECT * FROM {self._table_name}") + + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + """Convert SQLAlchemy type to SQL type name.""" + if isinstance(type_, sqltypes.Integer): + return "INTEGER" + elif isinstance(type_, sqltypes.Float): + return "FLOAT" + elif isinstance(type_, sqltypes.Numeric): + return "NUMERIC" + elif isinstance(type_, sqltypes.Boolean): + return "BOOLEAN" + elif isinstance(type_, sqltypes.DateTime): + return "TIMESTAMP" + elif isinstance(type_, sqltypes.Date): + return "DATE" + elif isinstance(type_, sqltypes.Time): + return "TIME" + elif isinstance(type_, (sqltypes.String, sqltypes.Text)): + return "TEXT" + else: + return type_.__class__.__name__.upper() + + def _get_connection(self) -> Connection: + """Get a connection to use for queries.""" + return self._engine.connect() diff --git a/python-package/src/querychat/prompt/prompt.md b/python-package/src/querychat/prompt/prompt.md index fcb00a5e..7acf8066 100644 --- a/python-package/src/querychat/prompt/prompt.md +++ b/python-package/src/querychat/prompt/prompt.md @@ -4,13 +4,19 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. -You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_engine}} database containing this schema: {{schema}} For security reasons, you may only query this specific table. +{{#data_description}} +Additional helpful info about the data: + + {{data_description}} + +{{/data_description}} There are several tasks you may be asked to do: @@ -19,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. @@ -84,6 +90,12 @@ If you find yourself offering example questions to the user as part of your resp * Suggestion 3. ``` +## SQL tips + +* The SQL engine is {{db_engine}}. + +* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions. + ## DuckDB SQL tips * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`. diff --git a/python-package/src/querychat/querychat.py b/python-package/src/querychat/querychat.py index f8763e0d..5e693659 100644 --- a/python-package/src/querychat/querychat.py +++ b/python-package/src/querychat/querychat.py @@ -4,11 +4,14 @@ import sys from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Protocol, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union import chatlas -import duckdb +import chevron import narwhals as nw +import pandas as pd +import sqlalchemy +from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui if TYPE_CHECKING: @@ -16,29 +19,145 @@ from narwhals.typing import IntoFrame +from .datasource import DataFrameSource, DataSource, SQLAlchemySource + + +class CreateChatCallback(Protocol): + def __call__(self, system_prompt: str) -> chatlas.Chat: ... + + +class QueryChatConfig: + """ + Configuration class for querychat. + """ + + def __init__( + self, + data_source: DataSource, + system_prompt: str, + greeting: Optional[str], + create_chat_callback: CreateChatCallback, + ): + self.data_source = data_source + self.system_prompt = system_prompt + self.greeting = greeting + self.create_chat_callback = create_chat_callback + + +class QueryChat: + """ + An object representing a query chat session. This is created within a Shiny + server function or Shiny module server function by using + `querychat.server()`. Use this object to bridge the chat interface with the + rest of the Shiny app, for example, by displaying the filtered data. + """ + + def __init__( + self, + chat: chatlas.Chat, + sql: Callable[[], str], + title: Callable[[], Union[str, None]], + df: Callable[[], pd.DataFrame], + ): + """ + Initialize a QueryChat object. + + Args: + chat: The chat object for the session + sql: Reactive that returns the current SQL query + title: Reactive that returns the current title + df: Reactive that returns the filtered data frame + + """ + self._chat = chat + self._sql = sql + self._title = title + self._df = df + + def chat(self) -> chatlas.Chat: + """ + Get the chat object for this session. + + Returns: + The chat object + + """ + return self._chat + + def sql(self) -> str: + """ + Reactively read the current SQL query that is in effect. + + Returns: + The current SQL query as a string, or `""` if no query has been set. + + """ + return self._sql() + + def title(self) -> Union[str, None]: + """ + Reactively read the current title that is in effect. The title is a + short description of the current query that the LLM provides to us + whenever it generates a new SQL query. It can be used as a status string + for the data dashboard. + + Returns: + The current title as a string, or `None` if no title has been set + due to no SQL query being set. + + """ + return self._title() + + def df(self) -> pd.DataFrame: + """ + Reactively read the current filtered data frame that is in effect. + + Returns: + The current filtered data frame as a pandas DataFrame. If no query + has been set, this will return the unfiltered data frame from the + data source. + + """ + return self._df() + + def __getitem__(self, key: str) -> Any: + """ + Allow access to configuration parameters like a dictionary. For + backwards compatibility only; new code should use the attributes + directly instead. + """ + if key == "chat": + return self.chat + elif key == "sql": + return self.sql + elif key == "title": + return self.title + elif key == "df": + return self.df + + def system_prompt( - df: IntoFrame, - table_name: str, + data_source: DataSource, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, categorical_threshold: int = 10, ) -> str: """ - Create a system prompt for the chat model based on a DataFrame's - schema and optional context and instructions. + Create a system prompt for the chat model based on a data source's schema + and optional additional context and instructions. Parameters ---------- - df : IntoFrame - Input data to generate schema information from. - table_name : str - Name of the table to be used in SQL queries. + data_source : DataSource + A data source to generate schema information from data_description : str, optional - Description of the data, in plain text or Markdown format. + Optional description of the data, in plain text or Markdown format extra_instructions : str, optional - Additional instructions for the chat model, in plain text or Markdown format. + Optional additional instructions for the chat model, in plain text or + Markdown format categorical_threshold : int, default=10 - Maximum number of unique values for a text column to be considered categorical. + Threshold for determining if a column is categorical based on number of + unique values Returns ------- @@ -46,95 +165,22 @@ def system_prompt( The system prompt for the chat model. """ - schema = df_to_schema(df, table_name, categorical_threshold) - # Read the prompt file prompt_path = Path(__file__).parent / "prompt" / "prompt.md" prompt_text = prompt_path.read_text() - # Simple template replacement (a more robust template engine could be used) - if data_description: - data_description_section = ( - "Additional helpful info about the data:\n\n" - "\n" - f"{data_description}\n" - "" - ) - else: - data_description_section = "" - - # Replace variables in the template - prompt_text = prompt_text.replace("{{schema}}", schema) - prompt_text = prompt_text.replace("{{data_description}}", data_description_section) - prompt_text = prompt_text.replace( - "{{extra_instructions}}", - extra_instructions or "", + return chevron.render( + prompt_text, + { + "db_engine": data_source.db_engine, + "schema": data_source.get_schema( + categorical_threshold=categorical_threshold, + ), + "data_description": data_description, + "extra_instructions": extra_instructions, + }, ) - return prompt_text - - -def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str: - """ - Convert a DataFrame schema to a string representation for the system prompt. - - Parameters - ---------- - df : IntoFrame - The DataFrame to extract schema from - table_name : str - The name of the table in SQL queries - categorical_threshold : int - The maximum number of unique values for a text column to be considered categorical - - Returns - ------- - str - A string containing the schema information. - - """ - ndf = nw.from_native(df) - - schema = [f"Table: {table_name}", "Columns:"] - - for column in ndf.columns: - # Map pandas dtypes to SQL-like types - dtype = ndf[column].dtype - if dtype.is_integer(): - sql_type = "INTEGER" - elif dtype.is_float(): - sql_type = "FLOAT" - elif dtype == nw.Boolean: - sql_type = "BOOLEAN" - elif dtype == nw.Datetime: - sql_type = "TIME" - elif dtype == nw.Date: - sql_type = "DATE" - else: - sql_type = "TEXT" - - column_info = [f"- {column} ({sql_type})"] - - # For TEXT columns, check if they're categorical - if sql_type == "TEXT": - unique_values = ndf[column].drop_nulls().unique() - if unique_values.len() <= categorical_threshold: - categories = unique_values.to_list() - categories_str = ", ".join([f"'{c}'" for c in categories]) - column_info.append(f" Categorical values: {categories_str}") - - # For numeric columns, include range - elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: - rng = ndf[column].min(), ndf[column].max() - if rng[0] is None and rng[1] is None: - column_info.append(" Range: NULL to NULL") - else: - column_info.append(f" Range: {rng[0]} to {rng[1]}") - - schema.extend(column_info) - - return "\n".join(schema) - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ @@ -173,32 +219,8 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: return table_html + rows_notice -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - def __init__( - self, - df: pd.DataFrame, - conn: duckdb.DuckDBPyConnection, - system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.df = df - self.conn = conn - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback - - def init( - df: pd.DataFrame, + data_source: IntoFrame | sqlalchemy.Engine, table_name: str, greeting: Optional[str] = None, data_description: Optional[str] = None, @@ -207,14 +229,18 @@ def init( system_prompt_override: Optional[str] = None, ) -> QueryChatConfig: """ - Call this once outside of any server function to initialize querychat. + Initialize querychat with any compliant data source. Parameters ---------- - df : pd.DataFrame - A data frame + data_source : IntoFrame | sqlalchemy.Engine + Either a Narwhals-compatible data frame (e.g., Polars or Pandas) or a + SQLAlchemy engine containing the table to query against. table_name : str - A string containing a valid table name for the data frame + If a data_source is a data frame, a name to use to refer to the table in + SQL queries (usually the variable name of the data frame, but it doesn't + have to be). If a data_source is a SQLAlchemy engine, the table_name is + the name of the table in the database to query against. greeting : str, optional A string in Markdown format, containing the initial message data_description : str, optional @@ -238,6 +264,14 @@ def init( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) + data_source_obj: DataSource + if isinstance(data_source, sqlalchemy.Engine): + data_source_obj = SQLAlchemySource(data_source, table_name) + else: + data_source_obj = DataFrameSource( + nw.from_native(data_source).to_pandas(), + table_name, + ) # Process greeting if greeting is None: print( @@ -246,30 +280,21 @@ def init( file=sys.stderr, ) - # Create the system prompt - if system_prompt_override is None: - _system_prompt = system_prompt( - df, - table_name, - data_description, - extra_instructions, - ) - else: - _system_prompt = system_prompt_override - - # Set up DuckDB connection and register the data frame - conn = duckdb.connect(database=":memory:") - conn.register(table_name, df) + # Create the system prompt, or use the override + _system_prompt = system_prompt_override or system_prompt( + data_source_obj, + data_description, + extra_instructions, + ) # Default chat function if none provided create_chat_callback = create_chat_callback or partial( chatlas.ChatOpenAI, - model="gpt-4o", + model="gpt-4.1", ) return QueryChatConfig( - df=df, - conn=conn, + data_source=data_source_obj, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, @@ -338,7 +363,7 @@ def server( # noqa: D417 output: Outputs, session: Session, querychat_config: QueryChatConfig, -) -> dict[str, Any]: +) -> QueryChat: """ Initialize the querychat server. @@ -365,8 +390,7 @@ def _(): pass # Extract config parameters - df = querychat_config.df - conn = querychat_config.conn + data_source = querychat_config.data_source system_prompt = querychat_config.system_prompt greeting = querychat_config.greeting create_chat_callback = querychat_config.create_chat_callback @@ -378,9 +402,9 @@ def _(): @reactive.calc def filtered_df(): if current_query.get() == "": - return df + return data_source.get_data() else: - return conn.execute(current_query.get()).fetch_df() + return data_source.execute_query(current_query.get()) # This would handle appending messages to the chat UI async def append_output(text): @@ -405,7 +429,7 @@ async def update_dashboard(query: str, title: str): try: # Try the query to see if it errors - conn.execute(query) + data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") @@ -430,7 +454,7 @@ async def query(query: str): await append_output(f"\n```sql\n{query}\n```\n\n") try: - result_df = conn.execute(query).fetch_df() + result_df = data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") @@ -480,9 +504,4 @@ async def greet_on_startup(): await chat_ui.append_message_stream(stream) # Return the interface for other components to use - return { - "chat": chat, - "sql": current_query.get, - "title": current_title.get, - "df": filtered_df, - } + return QueryChat(chat, current_query.get, current_title.get, filtered_df)