From 973a43301e47a89b290be2b5d6b9cdf0ba6fc253 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 3 Apr 2025 22:47:22 -0700 Subject: [PATCH 01/12] First attempt at genericizing data source --- python-package/examples/app.py | 4 +- python-package/querychat/datasource.py | 207 +++++++++++++++++++++++++ python-package/querychat/querychat.py | 162 +++++-------------- 3 files changed, 251 insertions(+), 122 deletions(-) create mode 100644 python-package/querychat/datasource.py diff --git a/python-package/examples/app.py b/python-package/examples/app.py index 926622ce..5e628f43 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app.py @@ -4,6 +4,7 @@ from shiny import App, render, ui import querychat +from querychat.datasource import DataFrameSource titanic = load_dataset("titanic") @@ -14,8 +15,7 @@ # 1. Configure querychat querychat_config = querychat.init( - titanic, - "titanic", + DataFrameSource(titanic, "titanic"), greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py new file mode 100644 index 00000000..495139ed --- /dev/null +++ b/python-package/querychat/datasource.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import Protocol +import pandas as pd +import duckdb +import sqlite3 +import narwhals as nw + + +class DataSource(Protocol): + def get_schema(self) -> str: + """Return schema information about the table as a string. + + Returns: + A string containing the schema information in a format suitable for + prompting an LLM about the data structure + """ + ... + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute SQL query and return results as DataFrame. + + Args: + query: SQL query to execute + + Returns: + Query results as a pandas DataFrame + """ + ... + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + ... 
+ + +class DataFrameSource: + """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + + def __init__(self, df: pd.DataFrame, table_name: str): + """Initialize with a pandas DataFrame. + + Args: + df: The DataFrame to wrap + table_name: Name of the table in SQL queries + """ + self._conn = duckdb.connect(database=":memory:") + self._df = df + self._table_name = table_name + self._conn.register(table_name, df) + + def get_schema(self, categorical_threshold: int = 10) -> str: + """Generate schema information from DataFrame. + + Args: + table_name: Name to use for the table in schema description + categorical_threshold: Maximum number of unique values for a text column + to be considered categorical + + Returns: + String describing the schema + """ + ndf = nw.from_native(self._df) + + schema = [f"Table: {self._table_name}", "Columns:"] + + for column in ndf.columns: + # Map pandas dtypes to SQL-like types + dtype = ndf[column].dtype + if dtype.is_integer(): + sql_type = "INTEGER" + elif dtype.is_float(): + sql_type = "FLOAT" + elif dtype == nw.Boolean: + sql_type = "BOOLEAN" + elif dtype == nw.Datetime: + sql_type = "TIME" + elif dtype == nw.Date: + sql_type = "DATE" + else: + sql_type = "TEXT" + + column_info = [f"- {column} ({sql_type})"] + + # For TEXT columns, check if they're categorical + if sql_type == "TEXT": + unique_values = ndf[column].drop_nulls().unique() + if unique_values.len() <= categorical_threshold: + categories = unique_values.to_list() + categories_str = ", ".join([f"'{c}'" for c in categories]) + column_info.append(f" Categorical values: {categories_str}") + + # For numeric columns, include range + elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: + rng = ndf[column].min(), ndf[column].max() + if rng[0] is None and rng[1] is None: + column_info.append(" Range: NULL to NULL") + else: + column_info.append(f" Range: {rng[0]} to {rng[1]}") + + schema.extend(column_info) + + return "\n".join(schema) + + def 
execute_query(self, query: str) -> pd.DataFrame: + """Execute query using DuckDB. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return self._conn.execute(query).df() + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return self._df.copy() + + +class SQLiteSource: + """A DataSource implementation that wraps a SQLite connection.""" + + def __init__(self, conn: sqlite3.Connection, table_name: str): + """Initialize with a SQLite connection. + + Args: + conn: SQLite database connection + """ + self._conn = conn + self._table_name = table_name + + def get_schema(self) -> str: + """Generate schema information from SQLite table. + + Returns: + String describing the schema + """ + # Get column info + cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") + columns = cursor.fetchall() + + schema = [f"Table: {self._table_name}", "Columns:"] + + for col in columns: + # col format: (cid, name, type, notnull, dflt_value, pk) + column_info = [f"- {col[1]} ({col[2].upper()})"] + + # For numeric columns, try to get range + if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + try: + cursor = self._conn.execute( + f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + ) + min_val, max_val = cursor.fetchone() + if min_val is not None and max_val is not None: + column_info.append(f" Range: {min_val} to {max_val}") + except sqlite3.Error: + pass # Skip range info if query fails + + # For text columns, check if categorical (limited distinct values) + elif col[2].upper() == "TEXT": + try: + cursor = self._conn.execute( + f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + ) + distinct_count = cursor.fetchone()[0] + if distinct_count <= 10: # Use fixed threshold for simplicity + cursor = self._conn.execute( + f"SELECT DISTINCT {col[1]} FROM {self._table_name} " + f"WHERE {col[1]} IS NOT NULL" 
+ ) + values = [str(row[0]) for row in cursor.fetchall()] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except sqlite3.Error: + pass # Skip categorical info if query fails + + schema.extend(column_info) + + return "\n".join(schema) + + def execute_query(self, query: str) -> pd.DataFrame: + """Execute query using SQLite. + + Args: + query: SQL query to execute + + Returns: + Query results as pandas DataFrame + """ + return pd.read_sql_query(query, self._conn) + + def get_data(self) -> pd.DataFrame: + """Return the unfiltered data as a DataFrame. + + Returns: + The complete dataset as a pandas DataFrame + """ + return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 4e492fb1..22b2f5ff 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -15,29 +15,49 @@ import narwhals as nw from narwhals.typing import IntoFrame +from .datasource import DataSource + + +class CreateChatCallback(Protocol): + def __call__(self, system_prompt: str) -> chatlas.Chat: ... + + +class QueryChatConfig: + """ + Configuration class for querychat. + """ + + def __init__( + self, + data_source: DataSource, + system_prompt: str, + greeting: Optional[str], + create_chat_callback: CreateChatCallback, + ): + self.data_source = data_source + self.system_prompt = system_prompt + self.greeting = greeting + self.create_chat_callback = create_chat_callback + def system_prompt( - df: IntoFrame, - table_name: str, + data_source: DataSource, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, - categorical_threshold: int = 10, ) -> str: """ - Create a system prompt for the chat model based on a data frame's + Create a system prompt for the chat model based on a data source's schema and optional additional context and instructions. 
Args: - df: A DataFrame to generate schema information from - table_name: A string containing the name of the table in SQL queries + data_source: A data source to generate schema information from data_description: Optional description of the data, in plain text or Markdown format extra_instructions: Optional additional instructions for the chat model, in plain text or Markdown format - categorical_threshold: The maximum number of unique values for a text column to be considered categorical Returns: A string containing the system prompt for the chat model """ - schema = df_to_schema(df, table_name, categorical_threshold) + schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -65,62 +85,6 @@ def system_prompt( return prompt_text -def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str: - """ - Convert a DataFrame schema to a string representation for the system prompt. - - Args: - df: The DataFrame to extract schema from - table_name: The name of the table in SQL queries - categorical_threshold: The maximum number of unique values for a text column to be considered categorical - - Returns: - A string containing the schema information - """ - - ndf = nw.from_native(df) - - schema = [f"Table: {table_name}", "Columns:"] - - for column in ndf.columns: - # Map pandas dtypes to SQL-like types - dtype = ndf[column].dtype - if dtype.is_integer(): - sql_type = "INTEGER" - elif dtype.is_float(): - sql_type = "FLOAT" - elif dtype == nw.Boolean: - sql_type = "BOOLEAN" - elif dtype == nw.Datetime: - sql_type = "TIME" - elif dtype == nw.Date: - sql_type = "DATE" - else: - sql_type = "TEXT" - - column_info = [f"- {column} ({sql_type})"] - - # For TEXT columns, check if they're categorical - if sql_type == "TEXT": - unique_values = ndf[column].drop_nulls().unique() - if unique_values.len() <= categorical_threshold: - categories = unique_values.to_list() - categories_str = 
", ".join([f"'{c}'" for c in categories]) - column_info.append(f" Categorical values: {categories_str}") - - # For numeric columns, include range - elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]: - rng = ndf[column].min(), ndf[column].max() - if rng[0] is None and rng[1] is None: - column_info.append(" Range: NULL to NULL") - else: - column_info.append(f" Range: {rng[0]} to {rng[1]}") - - schema.extend(column_info) - - return "\n".join(schema) - - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ Convert a DataFrame to an HTML table for display in chat. @@ -149,45 +113,18 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: return table_html + rows_notice -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - def __init__( - self, - df: pd.DataFrame, - conn: duckdb.DuckDBPyConnection, - system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.df = df - self.conn = conn - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback - - def init( - df: pd.DataFrame, - table_name: str, + data_source: DataSource, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, create_chat_callback: Optional[CreateChatCallback] = None, system_prompt_override: Optional[str] = None, ) -> QueryChatConfig: - """ - Call this once outside of any server function to initialize querychat. + """Initialize querychat with any compliant data source. 
Args: - df: A data frame - table_name: A string containing a valid table name for the data frame + data_source: A DataSource implementation that provides schema and query execution greeting: A string in Markdown format, containing the initial message data_description: Description of the data in plain text or Markdown extra_instructions: Additional instructions for the chat model @@ -197,12 +134,6 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ - # Validate table name (must begin with letter, contain only letters, numbers, underscores) - if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): - raise ValueError( - "Table name must begin with a letter and contain only letters, numbers, and underscores" - ) - # Process greeting if greeting is None: print( @@ -211,26 +142,18 @@ def init( file=sys.stderr, ) - # Create the system prompt - if system_prompt_override is None: - _system_prompt = system_prompt( - df, table_name, data_description, extra_instructions - ) - else: - _system_prompt = system_prompt_override - - # Set up DuckDB connection and register the data frame - conn = duckdb.connect(database=":memory:") - conn.register(table_name, df) + # Create the system prompt, or use the override + _system_prompt = system_prompt_override or system_prompt( + data_source, data_description, extra_instructions + ) # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4o" + chatlas.ChatOpenAI, model="gpt-4" ) return QueryChatConfig( - df=df, - conn=conn, + data_source=data_source, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, @@ -306,8 +229,7 @@ def _(): pass # Extract config parameters - df = querychat_config.df - conn = querychat_config.conn + data_source = querychat_config.data_source system_prompt = querychat_config.system_prompt greeting = querychat_config.greeting create_chat_callback = 
querychat_config.create_chat_callback @@ -319,9 +241,9 @@ def _(): @reactive.Calc def filtered_df(): if current_query.get() == "": - return df + return data_source.get_data() else: - return conn.execute(current_query.get()).fetch_df() + return data_source.execute_query(current_query.get()) # This would handle appending messages to the chat UI async def append_output(text): @@ -345,7 +267,7 @@ async def update_dashboard(query: str, title: str): try: # Try the query to see if it errors - conn.execute(query) + data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") @@ -370,7 +292,7 @@ async def query(query: str): await append_output(f"\n```sql\n{query}\n```\n\n") try: - result_df = conn.execute(query).fetch_df() + result_df = data_source.execute_query(query) except Exception as e: error_msg = str(e) await append_output(f"> Error: {error_msg}\n\n") From 8de0ac71d3e687ec66151b7e977ced697f2a590a Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 4 Apr 2025 08:53:21 -0700 Subject: [PATCH 02/12] Unify prompts by adding chevron Python dependency --- python-package/pyproject.toml | 1 + python-package/querychat/prompt/prompt.md | 6 ++++ python-package/querychat/querychat.py | 39 +++++++---------------- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index c709ee05..dca3b063 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "htmltools", "chatlas", "narwhals", + "chevron", ] [project.urls] diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 62d1ea17..154ce0cc 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -10,7 +10,13 @@ You have at your disposal a DuckDB database containing this schema: For security reasons, you may only query this specific table. 
+{{#data_description}} +Additional helpful info about the data: + + {{data_description}} + +{{/data_description}} There are several tasks you may be asked to do: diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 22b2f5ff..37af66e1 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -1,19 +1,13 @@ from __future__ import annotations -import sys import os -import re -import pandas as pd -import duckdb -import json +import sys from functools import partial -from typing import List, Dict, Any, Callable, Optional, Union, Protocol +from typing import Any, Dict, Optional, Protocol import chatlas -from htmltools import TagList, tags, HTML -from shiny import module, reactive, ui, Inputs, Outputs, Session -import narwhals as nw -from narwhals.typing import IntoFrame +import chevron +from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -64,26 +58,15 @@ def system_prompt( with open(prompt_path, "r") as f: prompt_text = f.read() - # Simple template replacement (a more robust template engine could be used) - if data_description: - data_description_section = ( - "Additional helpful info about the data:\n\n" - "\n" - f"{data_description}\n" - "" - ) - else: - data_description_section = "" - - # Replace variables in the template - prompt_text = prompt_text.replace("{{schema}}", schema) - prompt_text = prompt_text.replace("{{data_description}}", data_description_section) - prompt_text = prompt_text.replace( - "{{extra_instructions}}", extra_instructions or "" + return chevron.render( + prompt_text, + { + "schema": schema, + "data_description": data_description, + "extra_instructions": extra_instructions, + }, ) - return prompt_text - def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: """ From 53c7df3ddeda8b07f534a906165b04205ef83b31 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:03:24 -0700 Subject: [PATCH 03/12] Make 
prompt aware of what engine is being used --- python-package/querychat/datasource.py | 13 ++++++++++--- python-package/querychat/prompt/prompt.md | 10 +++++++--- python-package/querychat/querychat.py | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index 495139ed..e408e4b0 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import Protocol -import pandas as pd -import duckdb import sqlite3 +from typing import ClassVar, Protocol + +import duckdb import narwhals as nw +import pandas as pd class DataSource(Protocol): + db_engine: ClassVar[str] + def get_schema(self) -> str: """Return schema information about the table as a string. @@ -40,6 +43,8 @@ def get_data(self) -> pd.DataFrame: class DataFrameSource: """A DataSource implementation that wraps a pandas DataFrame using DuckDB.""" + db_engine: ClassVar[str] = "DuckDB" + def __init__(self, df: pd.DataFrame, table_name: str): """Initialize with a pandas DataFrame. @@ -128,6 +133,8 @@ def get_data(self) -> pd.DataFrame: class SQLiteSource: """A DataSource implementation that wraps a SQLite connection.""" + db_engine: ClassVar[str] = "SQLite" + def __init__(self, conn: sqlite3.Connection, table_name: str): """Initialize with a SQLite connection. diff --git a/python-package/querychat/prompt/prompt.md b/python-package/querychat/prompt/prompt.md index 154ce0cc..5155ae18 100644 --- a/python-package/querychat/prompt/prompt.md +++ b/python-package/querychat/prompt/prompt.md @@ -4,7 +4,7 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. 
-You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_engine}} database containing this schema: {{schema}} @@ -25,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. 
@@ -80,7 +80,11 @@ Example of question answering: If the user provides a vague help request, like "Help" or "Show me instructions", describe your own capabilities in a helpful way, including examples of questions they can ask. Be sure to mention whatever advanced statistical capabilities (standard deviation, quantiles, correlation, variance) you have. -## DuckDB SQL tips +## SQL tips + +* The SQL engine is {{db_engine}}. + +* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions. * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`. diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 37af66e1..fb0e6997 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -51,7 +51,6 @@ def system_prompt( Returns: A string containing the system prompt for the chat model """ - schema = data_source.get_schema() # Read the prompt file prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "prompt.md") @@ -61,7 +60,8 @@ def system_prompt( return chevron.render( prompt_text, { - "schema": schema, + "db_engine": data_source.db_engine, + "schema": data_source.get_schema(), "data_description": data_description, "extra_instructions": extra_instructions, }, From a2122f22da9233ce6edc3ece5ac12440e5a35f63 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Fri, 18 Apr 2025 14:37:45 -0700 Subject: [PATCH 04/12] Replace SQLite support with SQLAlchemy support --- python-package/pyproject.toml | 5 + python-package/querychat/datasource.py | 133 
++++++++++++++++++------- 2 files changed, 100 insertions(+), 38 deletions(-) diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index dca3b063..4ca437a2 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -23,6 +23,11 @@ dependencies = [ "chevron", ] +[project.optional-dependencies] +sqlalchemy = [ + "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API +] + [project.urls] Homepage = "https://github.com/posit-dev/querychat" Issues = "https://github.com/posit-dev/querychat/issues" diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e408e4b0..e33711e7 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -1,11 +1,13 @@ from __future__ import annotations -import sqlite3 from typing import ClassVar, Protocol import duckdb import narwhals as nw import pandas as pd +from sqlalchemy import inspect, text +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.sql import sqltypes class DataSource(Protocol): @@ -130,64 +132,93 @@ def get_data(self) -> pd.DataFrame: return self._df.copy() -class SQLiteSource: - """A DataSource implementation that wraps a SQLite connection.""" +class SQLAlchemySource: + """A DataSource implementation that supports multiple SQL databases via SQLAlchemy. - db_engine: ClassVar[str] = "SQLite" + Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks. + """ - def __init__(self, conn: sqlite3.Connection, table_name: str): - """Initialize with a SQLite connection. + db_engine: ClassVar[str] = "SQLAlchemy" + + def __init__(self, engine: Engine, table_name: str): + """Initialize with a SQLAlchemy engine. 
Args: - conn: SQLite database connection + engine: SQLAlchemy engine + table_name: Name of the table to query """ - self._conn = conn + self._engine = engine self._table_name = table_name + # Validate table exists + inspector = inspect(self._engine) + if table_name not in inspector.get_table_names(): + raise ValueError(f"Table '{table_name}' not found in database") + def get_schema(self) -> str: - """Generate schema information from SQLite table. + """Generate schema information from database table. Returns: String describing the schema """ - # Get column info - cursor = self._conn.execute(f"PRAGMA table_info({self._table_name})") - columns = cursor.fetchall() + inspector = inspect(self._engine) + columns = inspector.get_columns(self._table_name) schema = [f"Table: {self._table_name}", "Columns:"] for col in columns: - # col format: (cid, name, type, notnull, dflt_value, pk) - column_info = [f"- {col[1]} ({col[2].upper()})"] + # Get SQL type name + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col['name']} ({sql_type})"] # For numeric columns, try to get range - if col[2].upper() in ["INTEGER", "FLOAT", "REAL", "NUMERIC"]: + if isinstance( + col["type"], + ( + sqltypes.Integer, + sqltypes.Numeric, + sqltypes.Float, + sqltypes.Date, + sqltypes.Time, + sqltypes.DateTime, + sqltypes.BigInteger, + sqltypes.SmallInteger, + # sqltypes.Interval, + ), + ): try: - cursor = self._conn.execute( - f"SELECT MIN({col[1]}), MAX({col[1]}) FROM {self._table_name}" + query = text( + f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}" ) - min_val, max_val = cursor.fetchone() - if min_val is not None and max_val is not None: - column_info.append(f" Range: {min_val} to {max_val}") - except sqlite3.Error: + with self._get_connection() as conn: + result = conn.execute(query).fetchone() + if result and result[0] is not None and result[1] is not None: + column_info.append(f" Range: {result[0]} to {result[1]}") + except Exception: pass # Skip 
range info if query fails - # For text columns, check if categorical (limited distinct values) - elif col[2].upper() == "TEXT": + # For string/text columns, check if categorical + elif isinstance( + col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum) + ): try: - cursor = self._conn.execute( - f"SELECT COUNT(DISTINCT {col[1]}) FROM {self._table_name}" + count_query = text( + f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}" ) - distinct_count = cursor.fetchone()[0] - if distinct_count <= 10: # Use fixed threshold for simplicity - cursor = self._conn.execute( - f"SELECT DISTINCT {col[1]} FROM {self._table_name} " - f"WHERE {col[1]} IS NOT NULL" - ) - values = [str(row[0]) for row in cursor.fetchall()] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except sqlite3.Error: + with self._get_connection() as conn: + distinct_count = conn.execute(count_query).scalar() + if distinct_count and distinct_count <= 10: + values_query = text( + f"SELECT DISTINCT {col['name']} FROM {self._table_name} " + f"WHERE {col['name']} IS NOT NULL" + ) + values = [ + str(row[0]) + for row in conn.execute(values_query).fetchall() + ] + values_str = ", ".join([f"'{v}'" for v in values]) + column_info.append(f" Categorical values: {values_str}") + except Exception: pass # Skip categorical info if query fails schema.extend(column_info) @@ -195,7 +226,7 @@ def get_schema(self) -> str: return "\n".join(schema) def execute_query(self, query: str) -> pd.DataFrame: - """Execute query using SQLite. + """Execute SQL query and return results as DataFrame. Args: query: SQL query to execute @@ -203,7 +234,8 @@ def execute_query(self, query: str) -> pd.DataFrame: Returns: Query results as pandas DataFrame """ - return pd.read_sql_query(query, self._conn) + with self._get_connection() as conn: + return pd.read_sql_query(text(query), conn) def get_data(self) -> pd.DataFrame: """Return the unfiltered data as a DataFrame. 
@@ -211,4 +243,29 @@ def get_data(self) -> pd.DataFrame: Returns: The complete dataset as a pandas DataFrame """ - return pd.read_sql_query(f"SELECT * FROM {self._table_name}", self._conn) + return self.execute_query(f"SELECT * FROM {self._table_name}") + + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + """Convert SQLAlchemy type to SQL type name.""" + if isinstance(type_, sqltypes.Integer): + return "INTEGER" + elif isinstance(type_, sqltypes.Float): + return "FLOAT" + elif isinstance(type_, sqltypes.Numeric): + return "NUMERIC" + elif isinstance(type_, sqltypes.Boolean): + return "BOOLEAN" + elif isinstance(type_, sqltypes.DateTime): + return "TIMESTAMP" + elif isinstance(type_, sqltypes.Date): + return "DATE" + elif isinstance(type_, sqltypes.Time): + return "TIME" + elif isinstance(type_, (sqltypes.String, sqltypes.Text)): + return "TEXT" + else: + return type_.__class__.__name__.upper() + + def _get_connection(self) -> Connection: + """Get a connection to use for queries.""" + return self._engine.connect() From a218fb914963a4477598c8f4d0081bae043de286 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Wed, 23 Apr 2025 16:26:58 -0700 Subject: [PATCH 05/12] Don't fail when given table name's case differs from SQLAlchemy Inspector --- python-package/querychat/datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/querychat/datasource.py b/python-package/querychat/datasource.py index e33711e7..1fee9b9c 100644 --- a/python-package/querychat/datasource.py +++ b/python-package/querychat/datasource.py @@ -6,7 +6,7 @@ import narwhals as nw import pandas as pd from sqlalchemy import inspect, text -from sqlalchemy.engine import Engine, Connection +from sqlalchemy.engine import Connection, Engine from sqlalchemy.sql import sqltypes @@ -152,7 +152,7 @@ def __init__(self, engine: Engine, table_name: str): # Validate table exists inspector = inspect(self._engine) - if table_name not in inspector.get_table_names(): + 
if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") def get_schema(self) -> str: From dc0814ef6a68575d0bb9624f43596507d769f4e3 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Thu, 1 May 2025 16:58:29 -0400 Subject: [PATCH 06/12] Forgot import --- python-package/querychat/querychat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index fb0e6997..ed558362 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -7,6 +7,7 @@ import chatlas import chevron +import narwhals as nw from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource From 9d95d1d0f47db306c3a422d913cfbcf8c6e0d244 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:12:35 -0700 Subject: [PATCH 07/12] Have server() return proper class with typed methods, instead of dict --- .gitignore | 3 +- python-package/examples/app-database.py | 55 +++++++++ .../examples/{app.py => app-dataframe.py} | 7 +- python-package/querychat/querychat.py | 104 ++++++++++++++++-- 4 files changed, 154 insertions(+), 15 deletions(-) create mode 100644 python-package/examples/app-database.py rename python-package/examples/{app.py => app-dataframe.py} (97%) diff --git a/.gitignore b/.gitignore index 98ab2295..32d0462b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ animation.screenflow/ README_files/ -README.html \ No newline at end of file +README.html +.DS_Store \ No newline at end of file diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py new file mode 100644 index 00000000..cfee136e --- /dev/null +++ b/python-package/examples/app-database.py @@ -0,0 +1,55 @@ +import sqlite3 +from pathlib import Path + +import querychat +from querychat.datasource import SQLAlchemySource +from seaborn import load_dataset +from shiny import App, render, ui 
+from sqlalchemy import create_engine + +# Load titanic data and create SQLite database +db_path = Path(__file__).parent / "titanic.db" +engine = create_engine("sqlite:///" + str(db_path)) +# titanic = load_dataset("titanic") +# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +with open(Path(__file__).parent / "greeting.md", "r") as f: + greeting = f.read() +with open(Path(__file__).parent / "data_description.md", "r") as f: + data_desc = f.read() + +# 1. Configure querychat +querychat_config = querychat.init( + SQLAlchemySource(engine, "titanic"), + greeting=greeting, + data_description=data_desc, +) + +# Create UI +app_ui = ui.page_sidebar( + # 2. Place the chat component in the sidebar + querychat.sidebar("chat"), + # Main panel with data viewer + ui.card( + ui.output_data_frame("data_table"), + fill=True, + ), + title="querychat with Python (SQLite)", + fillable=True, +) + + +# Define server logic +def server(input, output, session): + # 3. Initialize querychat server with the config from step 1 + chat = querychat.server("chat", querychat_config) + + # 4. 
Display the filtered dataframe + @render.data_frame + def data_table(): + # Access filtered data via chat.df() reactive + return chat["df"]() + + +# Create Shiny app +app = App(app_ui, server) diff --git a/python-package/examples/app.py b/python-package/examples/app-dataframe.py similarity index 97% rename from python-package/examples/app.py rename to python-package/examples/app-dataframe.py index 5e628f43..13d224fb 100644 --- a/python-package/examples/app.py +++ b/python-package/examples/app-dataframe.py @@ -1,10 +1,9 @@ from pathlib import Path -from seaborn import load_dataset -from shiny import App, render, ui - import querychat from querychat.datasource import DataFrameSource +from seaborn import load_dataset +from shiny import App, render, ui titanic = load_dataset("titanic") @@ -43,7 +42,7 @@ def server(input, output, session): @render.data_frame def data_table(): # Access filtered data via chat.df() reactive - return chat["df"]() + return chat.df() # Create Shiny app diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed558362..093dec16 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -3,11 +3,13 @@ import os import sys from functools import partial -from typing import Any, Dict, Optional, Protocol +from typing import Any, Callable, Optional, Protocol import chatlas import chevron import narwhals as nw +import pandas as pd +from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui from .datasource import DataSource @@ -35,6 +37,93 @@ def __init__( self.create_chat_callback = create_chat_callback +class QueryChat: + """ + An object representing a query chat session. This is created within a Shiny + server function or Shiny module server function by using + `querychat.server()`. Use this object to bridge the chat interface with the + rest of the Shiny app, for example, by displaying the filtered data. 
+ """ + + def __init__( + self, + chat: chatlas.Chat, + sql: Callable[[], str], + title: Callable[[], str | None], + df: Callable[[], pd.DataFrame], + ): + """ + Initialize a QueryChat object. + + Args: + chat: The chat object for the session + sql: Reactive that returns the current SQL query + title: Reactive that returns the current title + df: Reactive that returns the filtered data frame + """ + self._chat = chat + self._sql = sql + self._title = title + self._df = df + + def chat(self) -> chatlas.Chat: + """ + Get the chat object for this session. + + Returns: + The chat object + """ + return self._chat() + + def sql(self) -> str: + """ + Reactively read the current SQL query that is in effect. + + Returns: + The current SQL query as a string, or `""` if no query has been set. + """ + return self._sql() + + def title(self) -> str | None: + """ + Reactively read the current title that is in effect. The title is a + short description of the current query that the LLM provides to us + whenever it generates a new SQL query. It can be used as a status string + for the data dashboard. + + Returns: + The current title as a string, or `None` if no title has been set + due to no SQL query being set. + """ + return self._title() + + def df(self) -> pd.DataFrame: + """ + Reactively read the current filtered data frame that is in effect. + + Returns: + The current filtered data frame as a pandas DataFrame. If no query + has been set, this will return the unfiltered data frame from the + data source. + """ + return self._df() + + def __getitem__(self, key: str) -> Any: + """ + Allow access to configuration parameters like a dictionary. For + backwards compatibility only; new code should use the attributes + directly instead. 
+ """ + if key == "chat": + return self.chat + elif key == "sql": + return self.sql + elif key == "title": + return self.title + elif key == "df": + return self.df + + def system_prompt( data_source: DataSource, data_description: Optional[str] = None, @@ -190,7 +279,7 @@ def sidebar(id: str, width: int = 400, height: str = "100%", **kwargs) -> ui.Sid @module.server def server( input: Inputs, output: Outputs, session: Session, querychat_config: QueryChatConfig -) -> Dict[str, Any]: +) -> QueryChat: """ Initialize the querychat server. @@ -219,8 +308,8 @@ def _(): create_chat_callback = querychat_config.create_chat_callback # Reactive values to store state - current_title = reactive.Value(None) - current_query = reactive.Value("") + current_title: reactive.Value[str | None] = reactive.Value(None) + current_query: reactive.Value[str] = reactive.Value("") @reactive.Calc def filtered_df(): @@ -326,9 +415,4 @@ async def greet_on_startup(): await chat_ui.append_message_stream(stream) # Return the interface for other components to use - return { - "chat": chat, - "sql": current_query.get, - "title": current_title.get, - "df": filtered_df, - } + return QueryChat(chat, current_query.get, current_title.get, filtered_df) From aeb87dd060fbafb1c973d94c7041ab20ccf71dd8 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:17:43 -0700 Subject: [PATCH 08/12] Auto-create sqlite database for example --- .gitignore | 3 ++- python-package/examples/app-database.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 32d0462b..1639e057 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__/ animation.screenflow/ README_files/ README.html -.DS_Store \ No newline at end of file +.DS_Store +python-package/examples/titanic.db diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index cfee136e..c196b3e7 100644 --- a/python-package/examples/app-database.py +++ 
b/python-package/examples/app-database.py @@ -1,4 +1,3 @@ -import sqlite3 from pathlib import Path import querychat @@ -10,8 +9,12 @@ # Load titanic data and create SQLite database db_path = Path(__file__).parent / "titanic.db" engine = create_engine("sqlite:///" + str(db_path)) -# titanic = load_dataset("titanic") -# titanic.to_sql("titanic", conn, if_exists="replace", index=False) + +if not db_path.exists(): + # For example purposes, we'll create the database if it doesn't exist. Don't + # do this in your app! + titanic = load_dataset("titanic") + titanic.to_sql("titanic", engine, if_exists="replace", index=False) with open(Path(__file__).parent / "greeting.md", "r") as f: greeting = f.read() From c38b567189b73dee742715c5983ae32d57adc6c1 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 16:38:25 -0700 Subject: [PATCH 09/12] Have init() take data frame or sqlalchemy engine directly ...instead of requiring explicit DataSource subclass creation --- python-package/examples/app-database.py | 4 ++-- python-package/examples/app-dataframe.py | 4 ++-- python-package/querychat/querychat.py | 19 ++++++++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py index c196b3e7..9769cc17 100644 --- a/python-package/examples/app-database.py +++ b/python-package/examples/app-database.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import SQLAlchemySource from seaborn import load_dataset from shiny import App, render, ui from sqlalchemy import create_engine @@ -23,7 +22,8 @@ # 1. 
Configure querychat querychat_config = querychat.init( - SQLAlchemySource(engine, "titanic"), + engine, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/examples/app-dataframe.py b/python-package/examples/app-dataframe.py index 13d224fb..1a1fd858 100644 --- a/python-package/examples/app-dataframe.py +++ b/python-package/examples/app-dataframe.py @@ -1,7 +1,6 @@ from pathlib import Path import querychat -from querychat.datasource import DataFrameSource from seaborn import load_dataset from shiny import App, render, ui @@ -14,7 +13,8 @@ # 1. Configure querychat querychat_config = querychat.init( - DataFrameSource(titanic, "titanic"), + titanic, + "titanic", greeting=greeting, data_description=data_desc, ) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index 093dec16..aec6bba7 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -9,10 +9,11 @@ import chevron import narwhals as nw import pandas as pd +import sqlalchemy from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui -from .datasource import DataSource +from .datasource import DataFrameSource, DataSource, SQLAlchemySource class CreateChatCallback(Protocol): @@ -73,7 +74,7 @@ def chat(self) -> chatlas.Chat: Returns: The chat object """ - return self._chat() + return self._chat def sql(self) -> str: """ @@ -187,7 +188,8 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( - data_source: DataSource, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, @@ -207,6 +209,13 @@ def init( Returns: A QueryChatConfig object that can be passed to server() """ + + data_source_obj: DataSource + if isinstance(data_source, sqlalchemy.Engine): + data_source_obj = SQLAlchemySource(data_source, table_name) + else: + 
data_source_obj = DataFrameSource(nw.from_native(data_source).to_pandas(), table_name) + # Process greeting if greeting is None: print( @@ -217,7 +226,7 @@ def init( # Create the system prompt, or use the override _system_prompt = system_prompt_override or system_prompt( - data_source, data_description, extra_instructions + data_source_obj, data_description, extra_instructions ) # Default chat function if none provided @@ -226,7 +235,7 @@ def init( ) return QueryChatConfig( - data_source=data_source, + data_source=data_source_obj, system_prompt=_system_prompt, greeting=greeting, create_chat_callback=create_chat_callback, From 57922b3fe2eeda722f28ca35b60543a9d4223d15 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:11:26 -0700 Subject: [PATCH 10/12] Use GPT-4.1 by default, not GPT-4, yuck --- python-package/querychat/querychat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/querychat/querychat.py b/python-package/querychat/querychat.py index ed558362..167560c5 100644 --- a/python-package/querychat/querychat.py +++ b/python-package/querychat/querychat.py @@ -133,7 +133,7 @@ def init( # Default chat function if none provided create_chat_callback = create_chat_callback or partial( - chatlas.ChatOpenAI, model="gpt-4" + chatlas.ChatOpenAI, model="gpt-4.1" ) return QueryChatConfig( From a08764bf130895a895fdff7c2d535ef40855f156 Mon Sep 17 00:00:00 2001 From: Joe Cheng Date: Mon, 2 Jun 2025 17:23:12 -0700 Subject: [PATCH 11/12] Update README --- python-package/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python-package/README.md b/python-package/README.md index be8057ea..9b29fb19 100644 --- a/python-package/README.md +++ b/python-package/README.md @@ -56,7 +56,7 @@ def server(input, output, session): # chat["df"]() reactive. 
@render.data_frame def data_table(): - return chat["df"]() + return chat.df() # Create Shiny app @@ -171,8 +171,8 @@ which you can then pass via: ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", data_description=Path("data_description.md").read_text() ) ``` @@ -185,8 +185,8 @@ You can add additional instructions of your own to the end of the system prompt, ```python querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", extra_instructions=[ "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! 😃 Emojis everywhere, 🌍 emojis forever. ♾️", @@ -218,8 +218,8 @@ def my_chat_func(system_prompt: str) -> chatlas.Chat: my_chat_func = partial(chatlas.ChatAnthropic, model="claude-3-7-sonnet-latest") querychat_config = querychat.init( - df=titanic, - table_name="titanic", + titanic, + "titanic", create_chat_callback=my_chat_func ) ``` From 43ee05088a8f966471ac352ab70c68490818372e Mon Sep 17 00:00:00 2001 From: Nick Pelikan Date: Fri, 6 Jun 2025 12:50:54 -0700 Subject: [PATCH 12/12] Optimizing Schema Queries (#26) * this should significantly speed up schema generation * another speedup * ruff formatting * updating so formatting checks pass --- python-package/pyproject.toml | 7 +- python-package/src/querychat/datasource.py | 144 ++++++++++----- python-package/src/querychat/querychat.py | 14 +- python-package/tests/__init__.py | 0 python-package/tests/test_datasource.py | 194 +++++++++++++++++++++ 5 files changed, 310 insertions(+), 49 deletions(-) create mode 100644 python-package/tests/__init__.py create mode 100644 python-package/tests/test_datasource.py diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 1ac303bb..7fbfe145 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -43,7 +43,12 @@ packages = ["src/querychat"] include = ["src/querychat", "LICENSE",
"README.md"] [tool.uv] -dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"] +dev-dependencies = [ + "ruff>=0.6.5", + "pyright>=1.1.401", + "tox-uv>=1.11.4", + "pytest>=8.4.0", +] [tool.ruff] src = ["src/querychat"] diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py index d9322ff4..c3c00390 100644 --- a/python-package/src/querychat/datasource.py +++ b/python-package/src/querychat/datasource.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol import duckdb import narwhals as nw import pandas as pd from sqlalchemy import inspect, text -from sqlalchemy.engine import Connection, Engine from sqlalchemy.sql import sqltypes +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + class DataSource(Protocol): db_engine: ClassVar[str] @@ -176,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str): if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") - def get_schema(self, *, categorical_threshold: int) -> str: + def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 """ Generate schema information from database table. 
@@ -189,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: schema = [f"Table: {self._table_name}", "Columns:"] + # Build a single query to get all column statistics + select_parts = [] + numeric_columns = [] + text_columns = [] + for col in columns: - # Get SQL type name - sql_type = self._get_sql_type_name(col["type"]) - column_info = [f"- {col['name']} ({sql_type})"] + col_name = col["name"] - # For numeric columns, try to get range + # Check if column is numeric if isinstance( col["type"], ( @@ -206,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str: sqltypes.DateTime, sqltypes.BigInteger, sqltypes.SmallInteger, - # sqltypes.Interval, ), ): - try: - query = text( - f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}", - ) - with self._get_connection() as conn: - result = conn.execute(query).fetchone() - if result and result[0] is not None and result[1] is not None: - column_info.append(f" Range: {result[0]} to {result[1]}") - except Exception: - pass # Skip range info if query fails - - # For string/text columns, check if categorical + numeric_columns.append(col_name) + select_parts.extend( + [ + f"MIN({col_name}) as {col_name}_min", + f"MAX({col_name}) as {col_name}_max", + ], + ) + + # Check if column is text/string elif isinstance( col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum), ): - try: - count_query = text( - f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}", - ) + text_columns.append(col_name) + select_parts.append( + f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count", + ) + + # Execute single query to get all statistics + column_stats = {} + if select_parts: + try: + stats_query = text( + f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608 + ) + with self._get_connection() as conn: + result = conn.execute(stats_query).fetchone() + if result: + # Convert result to dict for easier access + column_stats = dict(zip(result._fields, 
result)) + except Exception: # noqa: S110 + pass # Fall back to no statistics if query fails + + # Get categorical values for text columns that are below threshold + categorical_values = {} + text_cols_to_query = [] + for col_name in text_columns: + distinct_count_key = f"{col_name}_distinct_count" + if ( + distinct_count_key in column_stats + and column_stats[distinct_count_key] + and column_stats[distinct_count_key] <= categorical_threshold + ): + text_cols_to_query.append(col_name) + + # Get categorical values in a single query if needed + if text_cols_to_query: + try: + # Build UNION query for all categorical columns + union_parts = [ + f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608 + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " + f"GROUP BY {col_name}" + for col_name in text_cols_to_query + ] + + if union_parts: + categorical_query = text(" UNION ALL ".join(union_parts)) with self._get_connection() as conn: - distinct_count = conn.execute(count_query).scalar() - if distinct_count and distinct_count <= categorical_threshold: - values_query = text( - f"SELECT DISTINCT {col['name']} FROM {self._table_name} " - f"WHERE {col['name']} IS NOT NULL", - ) - values = [ - str(row[0]) - for row in conn.execute(values_query).fetchall() - ] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except Exception: - pass # Skip categorical info if query fails + results = conn.execute(categorical_query).fetchall() + for row in results: + col_name, value = row + if col_name not in categorical_values: + categorical_values[col_name] = [] + categorical_values[col_name].append(str(value)) + except Exception: # noqa: S110 + pass # Skip categorical values if query fails + + # Build schema description using collected statistics + for col in columns: + col_name = col["name"] + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col_name} ({sql_type})"] + + # Add range info 
for numeric columns + if col_name in numeric_columns: + min_key = f"{col_name}_min" + max_key = f"{col_name}_max" + if ( + min_key in column_stats + and max_key in column_stats + and column_stats[min_key] is not None + and column_stats[max_key] is not None + ): + column_info.append( + f" Range: {column_stats[min_key]} to {column_stats[max_key]}", + ) + + # Add categorical values for text columns + elif col_name in categorical_values: + values = categorical_values[col_name] + # Remove duplicates and sort + unique_values = sorted(set(values)) + values_str = ", ".join([f"'{v}'" for v in unique_values]) + column_info.append(f" Categorical values: {values_str}") schema.extend(column_info) @@ -271,9 +335,9 @@ def get_data(self) -> pd.DataFrame: The complete dataset as a pandas DataFrame """ - return self.execute_query(f"SELECT * FROM {self._table_name}") + return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608 - def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: + def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911 """Convert SQLAlchemy type to SQL type name.""" if isinstance(type_, sqltypes.Integer): return "INTEGER" diff --git a/python-package/src/querychat/querychat.py b/python-package/src/querychat/querychat.py index 5e693659..94bd93eb 100644 --- a/python-package/src/querychat/querychat.py +++ b/python-package/src/querychat/querychat.py @@ -126,14 +126,12 @@ def __getitem__(self, key: str) -> Any: backwards compatibility only; new code should use the attributes directly instead. 
""" - if key == "chat": - return self.chat - elif key == "sql": - return self.sql - elif key == "title": - return self.title - elif key == "df": - return self.df + return { + "chat": self.chat, + "sql": self.sql, + "title": self.title, + "df": self.df, + }.get(key) def system_prompt( diff --git a/python-package/tests/__init__.py b/python-package/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python-package/tests/test_datasource.py b/python-package/tests/test_datasource.py new file mode 100644 index 00000000..ca5395c2 --- /dev/null +++ b/python-package/tests/test_datasource.py @@ -0,0 +1,194 @@ +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +from sqlalchemy import create_engine + +from src.querychat.datasource import SQLAlchemySource + + +@pytest.fixture +def test_db_engine(): + """Create a temporary SQLite database with test data.""" + # Create temporary database file + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + + # Connect and create test table with various data types + conn = sqlite3.connect(temp_db.name) + cursor = conn.cursor() + + # Create table with different column types + cursor.execute(""" + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL, + is_active BOOLEAN, + join_date DATE, + category TEXT, + score NUMERIC, + description TEXT + ) + """) + + # Insert test data + test_data = [ + (1, "Alice", 30, 75000.50, True, "2023-01-15", "A", 95.5, "Senior developer"), + (2, "Bob", 25, 60000.00, True, "2023-03-20", "B", 87.2, "Junior developer"), + (3, "Charlie", 35, 85000.75, False, "2022-12-01", "A", 92.1, "Team lead"), + (4, "Diana", 28, 70000.25, True, "2023-05-10", "C", 89.8, "Mid-level developer"), + (5, "Eve", 32, 80000.00, True, "2023-02-28", "A", 91.3, "Senior developer"), + (6, "Frank", 26, 62000.50, False, "2023-04-15", "B", 85.7, "Junior developer"), + (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, 
"Developer"), + (8, "Henry", 31, 78000.25, True, "2023-03-05", "C", 88.9, "Senior developer"), + ] + + cursor.executemany(""" + INSERT INTO test_table + (id, name, age, salary, is_active, join_date, category, score, description) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, test_data) + + conn.commit() + conn.close() + + # Create SQLAlchemy engine + engine = create_engine(f"sqlite:///{temp_db.name}") + + yield engine + + # Cleanup + Path(temp_db.name).unlink() + + +def test_get_schema_numeric_ranges(test_db_engine): + """Test that numeric columns include min/max ranges.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Check that numeric columns have range information + assert "- id (INTEGER)" in schema + assert "Range: 1 to 8" in schema + + assert "- age (INTEGER)" in schema + assert "Range: 25 to 35" in schema + + assert "- salary (FLOAT)" in schema + assert "Range: 60000.0 to 85000.75" in schema + + assert "- score (NUMERIC)" in schema + assert "Range: 85.7 to 95.5" in schema + + +def test_get_schema_categorical_values(test_db_engine): + """Test that text columns with few unique values show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Category column should be treated as categorical (3 unique values: A, B, C) + assert "- category (TEXT)" in schema + assert "Categorical values:" in schema + assert "'A'" in schema and "'B'" in schema and "'C'" in schema + + +def test_get_schema_non_categorical_text(test_db_engine): + """Test that text columns with many unique values don't show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=3) + + # Name and description columns should not be categorical (8 and 6 unique values respectively) + lines = schema.split('\n') + name_line_idx = next(i for i, line in enumerate(lines) if "- name 
(TEXT)" in line) + description_line_idx = next(i for i, line in enumerate(lines) if "- description (TEXT)" in line) + + # Check that the next line after name column doesn't contain categorical values + if name_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[name_line_idx + 1] + + # Check that the next line after description column doesn't contain categorical values + if description_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[description_line_idx + 1] + + +def test_get_schema_different_thresholds(test_db_engine): + """Test that categorical_threshold parameter works correctly.""" + source = SQLAlchemySource(test_db_engine, "test_table") + + # With threshold 2, only category column (3 unique) should not be categorical + schema_low = source.get_schema(categorical_threshold=2) + assert "- category (TEXT)" in schema_low + assert "'A'" not in schema_low # Should not show categorical values + + # With threshold 5, category column should be categorical + schema_high = source.get_schema(categorical_threshold=5) + assert "- category (TEXT)" in schema_high + assert "'A'" in schema_high # Should show categorical values + + +def test_get_schema_table_structure(test_db_engine): + """Test the overall structure of the schema output.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + lines = schema.split('\n') + + # Check header + assert lines[0] == "Table: test_table" + assert lines[1] == "Columns:" + + # Check that all columns are present + expected_columns = ["id", "name", "age", "salary", "is_active", "join_date", "category", "score", "description"] + for col in expected_columns: + assert any(f"- {col} (" in line for line in lines), f"Column {col} not found in schema" + + +def test_get_schema_empty_result_handling(test_db_engine): + """Test handling when statistics queries return empty results.""" + # Create empty table + conn = sqlite3.connect(":memory:") + 
cursor = conn.cursor() + cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)") + conn.commit() + + engine = create_engine("sqlite:///:memory:") + # Recreate table in the new engine + with engine.connect() as connection: + from sqlalchemy import text + connection.execute(text("CREATE TABLE empty_table (id INTEGER, name TEXT)")) + connection.commit() + + source = SQLAlchemySource(engine, "empty_table") + schema = source.get_schema(categorical_threshold=5) + + # Should still work but without range/categorical info + assert "Table: empty_table" in schema + assert "- id (INTEGER)" in schema + assert "- name (TEXT)" in schema + # Should not have range or categorical information + assert "Range:" not in schema + assert "Categorical values:" not in schema + + +def test_get_schema_boolean_and_date_types(test_db_engine): + """Test handling of boolean and date column types.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Boolean column should show range + assert "- is_active (BOOLEAN)" in schema + # SQLite stores booleans as integers, so should show 0 to 1 range + + # Date column should show range + assert "- join_date (DATE)" in schema + assert "Range:" in schema + + +def test_invalid_table_name(): + """Test that invalid table name raises appropriate error.""" + engine = create_engine("sqlite:///:memory:") + + with pytest.raises(ValueError, match="Table 'nonexistent' not found in database"): + SQLAlchemySource(engine, "nonexistent") \ No newline at end of file