diff --git a/.gitignore b/.gitignore
index 06e00a5c..8229dee0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,8 @@ __pycache__/
 animation.screenflow/
 README_files/
 README.html
+.DS_Store
+python-package/examples/titanic.db
 .quarto
 
 # Byte-compiled / optimized / DLL files
diff --git a/python-package/README.md b/python-package/README.md
index be8057ea..9b29fb19 100644
--- a/python-package/README.md
+++ b/python-package/README.md
@@ -56,7 +56,7 @@ def server(input, output, session):
     # chat["df"]() reactive.
     @render.data_frame
     def data_table():
-        return chat["df"]()
+        return chat.df()
 
 
 # Create Shiny app
@@ -171,8 +171,8 @@ which you can then pass via:
 
 ```python
 querychat_config = querychat.init(
-    df=titanic,
-    table_name="titanic",
+    titanic,
+    "titanic",
     data_description=Path("data_description.md").read_text()
 )
 ```
@@ -185,8 +185,8 @@ You can add additional instructions of your own to the end of the system prompt,
 
 ```python
 querychat_config = querychat.init(
-    df=titanic,
-    table_name="titanic",
+    titanic,
+    "titanic",
     extra_instructions=[
         "You're speaking to a British audience--please use appropriate spelling conventions.",
         "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️",
@@ -218,8 +218,8 @@ def my_chat_func(system_prompt: str) -> chatlas.Chat:
 my_chat_func = partial(chatlas.ChatAnthropic, model="claude-3-7-sonnet-latest")
 
 querychat_config = querychat.init(
-    df=titanic,
-    table_name="titanic",
+    titanic,
+    "titanic",
     create_chat_callback=my_chat_func
 )
 ```
diff --git a/python-package/examples/app-database.py b/python-package/examples/app-database.py
new file mode 100644
index 00000000..ac7066d9
--- /dev/null
+++ b/python-package/examples/app-database.py
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+from seaborn import load_dataset
+from shiny import App, render, ui
+from sqlalchemy import create_engine
+
+import querychat
+
+# Load titanic data and create SQLite database
+db_path = Path(__file__).parent / "titanic.db"
+engine = create_engine("sqlite:///" + str(db_path))
+
+if not db_path.exists():
+    # For example purposes, we'll create the database if it doesn't exist. Don't
+    # do this in your app!
+    titanic = load_dataset("titanic")
+    titanic.to_sql("titanic", engine, if_exists="replace", index=False)
+
+greeting = (Path(__file__).parent / "greeting.md").read_text()
+data_desc = (Path(__file__).parent / "data_description.md").read_text()
+
+# 1. Configure querychat
+querychat_config = querychat.init(
+    engine,
+    "titanic",
+    greeting=greeting,
+    data_description=data_desc,
+)
+
+# Create UI
+app_ui = ui.page_sidebar(
+    # 2. Place the chat component in the sidebar
+    querychat.sidebar("chat"),
+    # Main panel with data viewer
+    ui.card(
+        ui.output_data_frame("data_table"),
+        fill=True,
+    ),
+    title="querychat with Python (SQLite)",
+    fillable=True,
+)
+
+
+# Define server logic
+def server(input, output, session):
+    # 3. Initialize querychat server with the config from step 1
+    chat = querychat.server("chat", querychat_config)
+
+    # 4. Display the filtered dataframe
+    @render.data_frame
+    def data_table():
+        # Access filtered data via chat.df() reactive
+        return chat.df()
+
+
+# Create Shiny app
+app = App(app_ui, server)
diff --git a/python-package/examples/app.py b/python-package/examples/app-dataframe.py
similarity index 82%
rename from python-package/examples/app.py
rename to python-package/examples/app-dataframe.py
index 926622ce..1966900f 100644
--- a/python-package/examples/app.py
+++ b/python-package/examples/app-dataframe.py
@@ -7,10 +7,8 @@
 
 titanic = load_dataset("titanic")
 
-with open(Path(__file__).parent / "greeting.md", "r") as f:
-    greeting = f.read()
-with open(Path(__file__).parent / "data_description.md", "r") as f:
-    data_desc = f.read()
+greeting = (Path(__file__).parent / "greeting.md").read_text()
+data_desc = (Path(__file__).parent / "data_description.md").read_text()
 
 # 1. Configure querychat
 querychat_config = querychat.init(
@@ -43,7 +41,7 @@ def server(input, output, session):
     @render.data_frame
     def data_table():
         # Access filtered data via chat.df() reactive
-        return chat["df"]()
+        return chat.df()
 
 
 # Create Shiny app
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index acd772a2..7fbfe145 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -18,8 +18,9 @@ dependencies = [
     "htmltools",
     "chatlas",
     "narwhals",
+    "chevron",
+    "sqlalchemy>=2.0.0" # Using 2.0+ for improved type hints and API
 ]
-
 classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Python :: 3.8",
@@ -42,7 +43,12 @@ packages = ["src/querychat"]
 include = ["src/querychat", "LICENSE", "README.md"]
 
 [tool.uv]
-dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"]
+dev-dependencies = [
+    "ruff>=0.6.5",
+    "pyright>=1.1.401",
+    "tox-uv>=1.11.4",
+    "pytest>=8.4.0",
+]
 
 [tool.ruff]
 src = ["src/querychat"]
diff --git a/python-package/src/querychat/datasource.py b/python-package/src/querychat/datasource.py
new file mode 100644
index 00000000..c3c00390
--- /dev/null
+++ b/python-package/src/querychat/datasource.py
@@ -0,0 +1,363 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar, Protocol
+
+import duckdb
+import narwhals as nw
+import pandas as pd
+from sqlalchemy import inspect, text
+from sqlalchemy.sql import sqltypes
+
+if TYPE_CHECKING:
+    from sqlalchemy.engine import Connection, Engine
+
+
+class DataSource(Protocol):
+    db_engine: ClassVar[str]
+
+    def get_schema(self, *, categorical_threshold: int) -> str:
+        """
+        Return schema information about the table as a string.
+
+        Args:
+            categorical_threshold: Maximum number of unique values for a text
+                column to be considered categorical
+
+        Returns:
+            A string containing the schema information in a format suitable for
+            prompting an LLM about the data structure
+
+        """
+        ...
+
+    def execute_query(self, query: str) -> pd.DataFrame:
+        """
+        Execute SQL query and return results as DataFrame.
+
+        Args:
+            query: SQL query to execute
+
+        Returns:
+            Query results as a pandas DataFrame
+
+        """
+        ...
+
+    def get_data(self) -> pd.DataFrame:
+        """
+        Return the unfiltered data as a DataFrame.
+
+        Returns:
+            The complete dataset as a pandas DataFrame
+
+        """
+        ...
+
+
+class DataFrameSource:
+    """A DataSource implementation that wraps a pandas DataFrame using DuckDB."""
+
+    db_engine: ClassVar[str] = "DuckDB"
+
+    def __init__(self, df: pd.DataFrame, table_name: str):
+        """
+        Initialize with a pandas DataFrame.
+
+        Args:
+            df: The DataFrame to wrap
+            table_name: Name of the table in SQL queries
+
+        """
+        self._conn = duckdb.connect(database=":memory:")
+        self._df = df
+        self._table_name = table_name
+        self._conn.register(table_name, df)
+
+    def get_schema(self, *, categorical_threshold: int) -> str:
+        """
+        Generate schema information from DataFrame.
+
+        Args:
+            table_name: Name to use for the table in schema description
+            categorical_threshold: Maximum number of unique values for a text column
+                to be considered categorical
+
+        Returns:
+            String describing the schema
+
+        """
+        ndf = nw.from_native(self._df)
+
+        schema = [f"Table: {self._table_name}", "Columns:"]
+
+        for column in ndf.columns:
+            # Map pandas dtypes to SQL-like types
+            dtype = ndf[column].dtype
+            if dtype.is_integer():
+                sql_type = "INTEGER"
+            elif dtype.is_float():
+                sql_type = "FLOAT"
+            elif dtype == nw.Boolean:
+                sql_type = "BOOLEAN"
+            elif dtype == nw.Datetime:
+                sql_type = "TIME"
+            elif dtype == nw.Date:
+                sql_type = "DATE"
+            else:
+                sql_type = "TEXT"
+
+            column_info = [f"- {column} ({sql_type})"]
+
+            # For TEXT columns, check if they're categorical
+            if sql_type == "TEXT":
+                unique_values = ndf[column].drop_nulls().unique()
+                if unique_values.len() <= categorical_threshold:
+                    categories = unique_values.to_list()
+                    categories_str = ", ".join([f"'{c}'" for c in categories])
+                    column_info.append(f" Categorical values: {categories_str}")
+
+            # For numeric columns, include range
+            elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]:
+                rng = ndf[column].min(), ndf[column].max()
+                if rng[0] is None and rng[1] is None:
+                    column_info.append(" Range: NULL to NULL")
+                else:
+                    column_info.append(f" Range: {rng[0]} to {rng[1]}")
+
+            schema.extend(column_info)
+
+        return "\n".join(schema)
+
+    def execute_query(self, query: str) -> pd.DataFrame:
+        """
+        Execute query using DuckDB.
+
+        Args:
+            query: SQL query to execute
+
+        Returns:
+            Query results as pandas DataFrame
+
+        """
+        return self._conn.execute(query).df()
+
+    def get_data(self) -> pd.DataFrame:
+        """
+        Return the unfiltered data as a DataFrame.
+
+        Returns:
+            The complete dataset as a pandas DataFrame
+
+        """
+        return self._df.copy()
+
+
+class SQLAlchemySource:
+    """
+    A DataSource implementation that supports multiple SQL databases via SQLAlchemy.
+
+    Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks.
+    """
+
+    db_engine: ClassVar[str] = "SQLAlchemy"
+
+    def __init__(self, engine: Engine, table_name: str):
+        """
+        Initialize with a SQLAlchemy engine.
+
+        Args:
+            engine: SQLAlchemy engine
+            table_name: Name of the table to query
+
+        """
+        self._engine = engine
+        self._table_name = table_name
+
+        # Validate table exists
+        inspector = inspect(self._engine)
+        if not inspector.has_table(table_name):
+            raise ValueError(f"Table '{table_name}' not found in database")
+
+    def get_schema(self, *, categorical_threshold: int) -> str:  # noqa: PLR0912
+        """
+        Generate schema information from database table.
+
+        Returns:
+            String describing the schema
+
+        """
+        inspector = inspect(self._engine)
+        columns = inspector.get_columns(self._table_name)
+
+        schema = [f"Table: {self._table_name}", "Columns:"]
+
+        # Build a single query to get all column statistics
+        select_parts = []
+        numeric_columns = []
+        text_columns = []
+
+        for col in columns:
+            col_name = col["name"]
+
+            # Check if column is numeric
+            if isinstance(
+                col["type"],
+                (
+                    sqltypes.Integer,
+                    sqltypes.Numeric,
+                    sqltypes.Float,
+                    sqltypes.Date,
+                    sqltypes.Time,
+                    sqltypes.DateTime,
+                    sqltypes.BigInteger,
+                    sqltypes.SmallInteger,
+                ),
+            ):
+                numeric_columns.append(col_name)
+                select_parts.extend(
+                    [
+                        f"MIN({col_name}) as {col_name}_min",
+                        f"MAX({col_name}) as {col_name}_max",
+                    ],
+                )
+
+            # Check if column is text/string
+            elif isinstance(
+                col["type"],
+                (sqltypes.String, sqltypes.Text, sqltypes.Enum),
+            ):
+                text_columns.append(col_name)
+                select_parts.append(
+                    f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count",
+                )
+
+        # Execute single query to get all statistics
+        column_stats = {}
+        if select_parts:
+            try:
+                stats_query = text(
+                    f"SELECT {', '.join(select_parts)} FROM {self._table_name}",  # noqa: S608
+                )
+                with self._get_connection() as conn:
+                    result = conn.execute(stats_query).fetchone()
+                    if result:
+                        # Convert result to dict for easier access
+                        column_stats = dict(zip(result._fields, result))
+            except Exception:  # noqa: S110
+                pass  # Fall back to no statistics if query fails
+
+        # Get categorical values for text columns that are below threshold
+        categorical_values = {}
+        text_cols_to_query = []
+        for col_name in text_columns:
+            distinct_count_key = f"{col_name}_distinct_count"
+            if (
+                distinct_count_key in column_stats
+                and column_stats[distinct_count_key]
+                and column_stats[distinct_count_key] <= categorical_threshold
+            ):
+                text_cols_to_query.append(col_name)
+
+        # Get categorical values in a single query if needed
+        if text_cols_to_query:
+            try:
+                # Build UNION query for all categorical columns
+                union_parts = [
+                    f"SELECT '{col_name}' as column_name, {col_name} as value "  # noqa: S608
+                    f"FROM {self._table_name} WHERE {col_name} IS NOT NULL "
+                    f"GROUP BY {col_name}"
+                    for col_name in text_cols_to_query
+                ]
+
+                if union_parts:
+                    categorical_query = text(" UNION ALL ".join(union_parts))
+                    with self._get_connection() as conn:
+                        results = conn.execute(categorical_query).fetchall()
+                        for row in results:
+                            col_name, value = row
+                            if col_name not in categorical_values:
+                                categorical_values[col_name] = []
+                            categorical_values[col_name].append(str(value))
+            except Exception:  # noqa: S110
+                pass  # Skip categorical values if query fails
+
+        # Build schema description using collected statistics
+        for col in columns:
+            col_name = col["name"]
+            sql_type = self._get_sql_type_name(col["type"])
+            column_info = [f"- {col_name} ({sql_type})"]
+
+            # Add range info for numeric columns
+            if col_name in numeric_columns:
+                min_key = f"{col_name}_min"
+                max_key = f"{col_name}_max"
+                if (
+                    min_key in column_stats
+                    and max_key in column_stats
+                    and column_stats[min_key] is not None
+                    and column_stats[max_key] is not None
+                ):
+                    column_info.append(
+                        f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
+                    )
+
+            # Add categorical values for text columns
+            elif col_name in categorical_values:
+                values = categorical_values[col_name]
+                # Remove duplicates and sort
+                unique_values = sorted(set(values))
+                values_str = ", ".join([f"'{v}'" for v in unique_values])
+                column_info.append(f" Categorical values: {values_str}")
+
+            schema.extend(column_info)
+
+        return "\n".join(schema)
+
+    def execute_query(self, query: str) -> pd.DataFrame:
+        """
+        Execute SQL query and return results as DataFrame.
+
+        Args:
+            query: SQL query to execute
+
+        Returns:
+            Query results as pandas DataFrame
+
+        """
+        with self._get_connection() as conn:
+            return pd.read_sql_query(text(query), conn)
+
+    def get_data(self) -> pd.DataFrame:
+        """
+        Return the unfiltered data as a DataFrame.
+
+        Returns:
+            The complete dataset as a pandas DataFrame
+
+        """
+        return self.execute_query(f"SELECT * FROM {self._table_name}")  # noqa: S608
+
+    def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str:  # noqa: PLR0911
+        """Convert SQLAlchemy type to SQL type name."""
+        if isinstance(type_, sqltypes.Integer):
+            return "INTEGER"
+        elif isinstance(type_, sqltypes.Float):
+            return "FLOAT"
+        elif isinstance(type_, sqltypes.Numeric):
+            return "NUMERIC"
+        elif isinstance(type_, sqltypes.Boolean):
+            return "BOOLEAN"
+        elif isinstance(type_, sqltypes.DateTime):
+            return "TIMESTAMP"
+        elif isinstance(type_, sqltypes.Date):
+            return "DATE"
+        elif isinstance(type_, sqltypes.Time):
+            return "TIME"
+        elif isinstance(type_, (sqltypes.String, sqltypes.Text)):
+            return "TEXT"
+        else:
+            return type_.__class__.__name__.upper()
+
+    def _get_connection(self) -> Connection:
+        """Get a connection to use for queries."""
+        return self._engine.connect()
diff --git a/python-package/src/querychat/prompt/prompt.md b/python-package/src/querychat/prompt/prompt.md
index fcb00a5e..7acf8066 100644
--- a/python-package/src/querychat/prompt/prompt.md
+++ b/python-package/src/querychat/prompt/prompt.md
@@ -4,13 +4,19 @@ It's important that you get clear, unambiguous instructions from the user, so if
 
 The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance.
 
-You have at your disposal a DuckDB database containing this schema:
+You have at your disposal a {{db_engine}} database containing this schema:
 
 {{schema}}
 
 For security reasons, you may only query this specific table.
 
+{{#data_description}}
+Additional helpful info about the data:
+
+{{data_description}}
+
+{{/data_description}}
 
 There are several tasks you may be asked to do:
 
@@ -19,7 +25,7 @@
 The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success.
 
 * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error.
-* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions.
+* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database.
 * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`.
 * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed.
 * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't.
@@ -84,6 +90,12 @@ If you find yourself offering example questions to the user as part of your resp
 * Suggestion 3.
 ```
 
+## SQL tips
+
+* The SQL engine is {{db_engine}}.
+
+* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions.
+
 ## DuckDB SQL tips
 
 * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`.
diff --git a/python-package/src/querychat/querychat.py b/python-package/src/querychat/querychat.py
index f8763e0d..94bd93eb 100644
--- a/python-package/src/querychat/querychat.py
+++ b/python-package/src/querychat/querychat.py
@@ -4,11 +4,14 @@ import sys
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Protocol, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union
 
 import chatlas
-import duckdb
+import chevron
 import narwhals as nw
+import pandas as pd
+import sqlalchemy
+from narwhals.typing import IntoFrame
 from shiny import Inputs, Outputs, Session, module, reactive, ui
 
 if TYPE_CHECKING:
@@ -16,29 +19,143 @@
     from narwhals.typing import IntoFrame
 
+from .datasource import DataFrameSource, DataSource, SQLAlchemySource
+
+
+class CreateChatCallback(Protocol):
+    def __call__(self, system_prompt: str) -> chatlas.Chat: ...
+
+
+class QueryChatConfig:
+    """
+    Configuration class for querychat.
+    """
+
+    def __init__(
+        self,
+        data_source: DataSource,
+        system_prompt: str,
+        greeting: Optional[str],
+        create_chat_callback: CreateChatCallback,
+    ):
+        self.data_source = data_source
+        self.system_prompt = system_prompt
+        self.greeting = greeting
+        self.create_chat_callback = create_chat_callback
+
+
+class QueryChat:
+    """
+    An object representing a query chat session. This is created within a Shiny
+    server function or Shiny module server function by using
+    `querychat.server()`. Use this object to bridge the chat interface with the
+    rest of the Shiny app, for example, by displaying the filtered data.
+    """
+
+    def __init__(
+        self,
+        chat: chatlas.Chat,
+        sql: Callable[[], str],
+        title: Callable[[], Union[str, None]],
+        df: Callable[[], pd.DataFrame],
+    ):
+        """
+        Initialize a QueryChat object.
+
+        Args:
+            chat: The chat object for the session
+            sql: Reactive that returns the current SQL query
+            title: Reactive that returns the current title
+            df: Reactive that returns the filtered data frame
+
+        """
+        self._chat = chat
+        self._sql = sql
+        self._title = title
+        self._df = df
+
+    def chat(self) -> chatlas.Chat:
+        """
+        Get the chat object for this session.
+
+        Returns:
+            The chat object
+
+        """
+        return self._chat
+
+    def sql(self) -> str:
+        """
+        Reactively read the current SQL query that is in effect.
+
+        Returns:
+            The current SQL query as a string, or `""` if no query has been set.
+
+        """
+        return self._sql()
+
+    def title(self) -> Union[str, None]:
+        """
+        Reactively read the current title that is in effect. The title is a
+        short description of the current query that the LLM provides to us
+        whenever it generates a new SQL query. It can be used as a status string
+        for the data dashboard.
+
+        Returns:
+            The current title as a string, or `None` if no title has been set
+            due to no SQL query being set.
+
+        """
+        return self._title()
+
+    def df(self) -> pd.DataFrame:
+        """
+        Reactively read the current filtered data frame that is in effect.
+
+        Returns:
+            The current filtered data frame as a pandas DataFrame. If no query
+            has been set, this will return the unfiltered data frame from the
+            data source.
+
+        """
+        return self._df()
+
+    def __getitem__(self, key: str) -> Any:
+        """
+        Allow access to configuration parameters like a dictionary. For
+        backwards compatibility only; new code should use the attributes
+        directly instead.
+        """
+        return {
+            "chat": self.chat,
+            "sql": self.sql,
+            "title": self.title,
+            "df": self.df,
+        }.get(key)
+
+
 def system_prompt(
-    df: IntoFrame,
-    table_name: str,
+    data_source: DataSource,
     data_description: Optional[str] = None,
     extra_instructions: Optional[str] = None,
     categorical_threshold: int = 10,
 ) -> str:
     """
-    Create a system prompt for the chat model based on a DataFrame's
-    schema and optional context and instructions.
+    Create a system prompt for the chat model based on a data source's schema
+    and optional additional context and instructions.
 
     Parameters
     ----------
-    df : IntoFrame
-        Input data to generate schema information from.
-    table_name : str
-        Name of the table to be used in SQL queries.
+    data_source : DataSource
+        A data source to generate schema information from
     data_description : str, optional
-        Description of the data, in plain text or Markdown format.
+        Optional description of the data, in plain text or Markdown format
     extra_instructions : str, optional
-        Additional instructions for the chat model, in plain text or Markdown format.
+        Optional additional instructions for the chat model, in plain text or
+        Markdown format
     categorical_threshold : int, default=10
-        Maximum number of unique values for a text column to be considered categorical.
+        Threshold for determining if a column is categorical based on number of
+        unique values
 
     Returns
     -------
@@ -46,95 +163,22 @@ def system_prompt(
         The system prompt for the chat model.
 
     """
-    schema = df_to_schema(df, table_name, categorical_threshold)
-
     # Read the prompt file
     prompt_path = Path(__file__).parent / "prompt" / "prompt.md"
     prompt_text = prompt_path.read_text()
 
-    # Simple template replacement (a more robust template engine could be used)
-    if data_description:
-        data_description_section = (
-            "Additional helpful info about the data:\n\n"
-            "<data_description>\n"
-            f"{data_description}\n"
-            "</data_description>"
-        )
-    else:
-        data_description_section = ""
-
-    # Replace variables in the template
-    prompt_text = prompt_text.replace("{{schema}}", schema)
-    prompt_text = prompt_text.replace("{{data_description}}", data_description_section)
-    prompt_text = prompt_text.replace(
-        "{{extra_instructions}}",
-        extra_instructions or "",
+    return chevron.render(
+        prompt_text,
+        {
+            "db_engine": data_source.db_engine,
+            "schema": data_source.get_schema(
+                categorical_threshold=categorical_threshold,
+            ),
+            "data_description": data_description,
+            "extra_instructions": extra_instructions,
+        },
     )
 
-    return prompt_text
-
-
-def df_to_schema(df: IntoFrame, table_name: str, categorical_threshold: int) -> str:
-    """
-    Convert a DataFrame schema to a string representation for the system prompt.
-
-    Parameters
-    ----------
-    df : IntoFrame
-        The DataFrame to extract schema from
-    table_name : str
-        The name of the table in SQL queries
-    categorical_threshold : int
-        The maximum number of unique values for a text column to be considered categorical
-
-    Returns
-    -------
-    str
-        A string containing the schema information.
-
-    """
-    ndf = nw.from_native(df)
-
-    schema = [f"Table: {table_name}", "Columns:"]
-
-    for column in ndf.columns:
-        # Map pandas dtypes to SQL-like types
-        dtype = ndf[column].dtype
-        if dtype.is_integer():
-            sql_type = "INTEGER"
-        elif dtype.is_float():
-            sql_type = "FLOAT"
-        elif dtype == nw.Boolean:
-            sql_type = "BOOLEAN"
-        elif dtype == nw.Datetime:
-            sql_type = "TIME"
-        elif dtype == nw.Date:
-            sql_type = "DATE"
-        else:
-            sql_type = "TEXT"
-
-        column_info = [f"- {column} ({sql_type})"]
-
-        # For TEXT columns, check if they're categorical
-        if sql_type == "TEXT":
-            unique_values = ndf[column].drop_nulls().unique()
-            if unique_values.len() <= categorical_threshold:
-                categories = unique_values.to_list()
-                categories_str = ", ".join([f"'{c}'" for c in categories])
-                column_info.append(f" Categorical values: {categories_str}")
-
-        # For numeric columns, include range
-        elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIME"]:
-            rng = ndf[column].min(), ndf[column].max()
-            if rng[0] is None and rng[1] is None:
-                column_info.append(" Range: NULL to NULL")
-            else:
-                column_info.append(f" Range: {rng[0]} to {rng[1]}")
-
-        schema.extend(column_info)
-
-    return "\n".join(schema)
-
 
 def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
     """
@@ -173,32 +217,8 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
     return table_html + rows_notice
 
 
-class CreateChatCallback(Protocol):
-    def __call__(self, system_prompt: str) -> chatlas.Chat: ...
-
-
-class QueryChatConfig:
-    """
-    Configuration class for querychat.
-    """
-
-    def __init__(
-        self,
-        df: pd.DataFrame,
-        conn: duckdb.DuckDBPyConnection,
-        system_prompt: str,
-        greeting: Optional[str],
-        create_chat_callback: CreateChatCallback,
-    ):
-        self.df = df
-        self.conn = conn
-        self.system_prompt = system_prompt
-        self.greeting = greeting
-        self.create_chat_callback = create_chat_callback
-
-
 def init(
-    df: pd.DataFrame,
+    data_source: IntoFrame | sqlalchemy.Engine,
     table_name: str,
     greeting: Optional[str] = None,
     data_description: Optional[str] = None,
@@ -207,14 +227,18 @@
     system_prompt_override: Optional[str] = None,
 ) -> QueryChatConfig:
     """
-    Call this once outside of any server function to initialize querychat.
+    Initialize querychat with any compliant data source.
 
     Parameters
     ----------
-    df : pd.DataFrame
-        A data frame
+    data_source : IntoFrame | sqlalchemy.Engine
+        Either a Narwhals-compatible data frame (e.g., Polars or Pandas) or a
+        SQLAlchemy engine containing the table to query against.
     table_name : str
-        A string containing a valid table name for the data frame
+        If a data_source is a data frame, a name to use to refer to the table in
+        SQL queries (usually the variable name of the data frame, but it doesn't
+        have to be). If a data_source is a SQLAlchemy engine, the table_name is
+        the name of the table in the database to query against.
     greeting : str, optional
         A string in Markdown format, containing the initial message
    data_description : str, optional
@@ -238,6 +262,14 @@ def init(
             "Table name must begin with a letter and contain only letters, numbers, and underscores",
         )
 
+    data_source_obj: DataSource
+    if isinstance(data_source, sqlalchemy.Engine):
+        data_source_obj = SQLAlchemySource(data_source, table_name)
+    else:
+        data_source_obj = DataFrameSource(
+            nw.from_native(data_source).to_pandas(),
+            table_name,
+        )
     # Process greeting
     if greeting is None:
         print(
@@ -246,30 +278,21 @@
             file=sys.stderr,
         )
-    # Create the system prompt
-    if system_prompt_override is None:
-        _system_prompt = system_prompt(
-            df,
-            table_name,
-            data_description,
-            extra_instructions,
-        )
-    else:
-        _system_prompt = system_prompt_override
-
-    # Set up DuckDB connection and register the data frame
-    conn = duckdb.connect(database=":memory:")
-    conn.register(table_name, df)
+    # Create the system prompt, or use the override
+    _system_prompt = system_prompt_override or system_prompt(
+        data_source_obj,
+        data_description,
+        extra_instructions,
+    )
 
     # Default chat function if none provided
     create_chat_callback = create_chat_callback or partial(
         chatlas.ChatOpenAI,
-        model="gpt-4o",
+        model="gpt-4.1",
     )
 
     return QueryChatConfig(
-        df=df,
-        conn=conn,
+        data_source=data_source_obj,
         system_prompt=_system_prompt,
         greeting=greeting,
         create_chat_callback=create_chat_callback,
     )
@@ -338,7 +361,7 @@ def server(  # noqa: D417
     input: Inputs,
     output: Outputs,
     session: Session,
     querychat_config: QueryChatConfig,
-) -> dict[str, Any]:
+) -> QueryChat:
     """
     Initialize the querychat server.
@@ -365,8 +388,7 @@ def _():
         pass
 
     # Extract config parameters
-    df = querychat_config.df
-    conn = querychat_config.conn
+    data_source = querychat_config.data_source
     system_prompt = querychat_config.system_prompt
     greeting = querychat_config.greeting
     create_chat_callback = querychat_config.create_chat_callback
@@ -378,9 +400,9 @@ def _():
     @reactive.calc
     def filtered_df():
         if current_query.get() == "":
-            return df
+            return data_source.get_data()
         else:
-            return conn.execute(current_query.get()).fetch_df()
+            return data_source.execute_query(current_query.get())
 
     # This would handle appending messages to the chat UI
     async def append_output(text):
@@ -405,7 +427,7 @@ async def update_dashboard(query: str, title: str):
 
         try:
            # Try the query to see if it errors
-            conn.execute(query)
+            data_source.execute_query(query)
        except Exception as e:
            error_msg = str(e)
            await append_output(f"> Error: {error_msg}\n\n")
@@ -430,7 +452,7 @@ async def query(query: str):
         await append_output(f"\n```sql\n{query}\n```\n\n")
 
         try:
-            result_df = conn.execute(query).fetch_df()
+            result_df = data_source.execute_query(query)
         except Exception as e:
             error_msg = str(e)
             await append_output(f"> Error: {error_msg}\n\n")
@@ -480,9 +502,4 @@ async def greet_on_startup():
             await chat_ui.append_message_stream(stream)
 
     # Return the interface for other components to use
-    return {
-        "chat": chat,
-        "sql": current_query.get,
-        "title": current_title.get,
-        "df": filtered_df,
-    }
+    return QueryChat(chat, current_query.get, current_title.get, filtered_df)
diff --git a/python-package/tests/__init__.py b/python-package/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python-package/tests/test_datasource.py b/python-package/tests/test_datasource.py
new file mode 100644
index 00000000..ca5395c2
--- /dev/null
+++ b/python-package/tests/test_datasource.py
@@ -0,0 +1,194 @@
+import sqlite3
+import tempfile
+from pathlib import Path
+
+import pytest
+from sqlalchemy import create_engine
+
+from src.querychat.datasource import SQLAlchemySource
+
+
+@pytest.fixture
+def test_db_engine():
+    """Create a temporary SQLite database with test data."""
+    # Create temporary database file
+    temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_db.close()
+
+    # Connect and create test table with various data types
+    conn = sqlite3.connect(temp_db.name)
+    cursor = conn.cursor()
+
+    # Create table with different column types
+    cursor.execute("""
+        CREATE TABLE test_table (
+            id INTEGER PRIMARY KEY,
+            name TEXT,
+            age INTEGER,
+            salary REAL,
+            is_active BOOLEAN,
+            join_date DATE,
+            category TEXT,
+            score NUMERIC,
+            description TEXT
+        )
+    """)
+
+    # Insert test data
+    test_data = [
+        (1, "Alice", 30, 75000.50, True, "2023-01-15", "A", 95.5, "Senior developer"),
+        (2, "Bob", 25, 60000.00, True, "2023-03-20", "B", 87.2, "Junior developer"),
+        (3, "Charlie", 35, 85000.75, False, "2022-12-01", "A", 92.1, "Team lead"),
+        (4, "Diana", 28, 70000.25, True, "2023-05-10", "C", 89.8, "Mid-level developer"),
+        (5, "Eve", 32, 80000.00, True, "2023-02-28", "A", 91.3, "Senior developer"),
+        (6, "Frank", 26, 62000.50, False, "2023-04-15", "B", 85.7, "Junior developer"),
+        (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, "Developer"),
+        (8, "Henry", 31, 78000.25, True, "2023-03-05", "C", 88.9, "Senior developer"),
+    ]
+
+    cursor.executemany("""
+        INSERT INTO test_table
+        (id, name, age, salary, is_active, join_date, category, score, description)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+    """, test_data)
+
+    conn.commit()
+    conn.close()
+
+    # Create SQLAlchemy engine
+    engine = create_engine(f"sqlite:///{temp_db.name}")
+
+    yield engine
+
+    # Cleanup
+    Path(temp_db.name).unlink()
+
+
+def test_get_schema_numeric_ranges(test_db_engine):
+    """Test that numeric columns include min/max ranges."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+    schema = source.get_schema(categorical_threshold=5)
+
+    # Check that numeric columns have range information
+    assert "- id (INTEGER)" in schema
+    assert "Range: 1 to 8" in schema
+
+    assert "- age (INTEGER)" in schema
+    assert "Range: 25 to 35" in schema
+
+    assert "- salary (FLOAT)" in schema
+    assert "Range: 60000.0 to 85000.75" in schema
+
+    assert "- score (NUMERIC)" in schema
+    assert "Range: 85.7 to 95.5" in schema
+
+
+def test_get_schema_categorical_values(test_db_engine):
+    """Test that text columns with few unique values show categorical values."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+    schema = source.get_schema(categorical_threshold=5)
+
+    # Category column should be treated as categorical (3 unique values: A, B, C)
+    assert "- category (TEXT)" in schema
+    assert "Categorical values:" in schema
+    assert "'A'" in schema and "'B'" in schema and "'C'" in schema
+
+
+def test_get_schema_non_categorical_text(test_db_engine):
+    """Test that text columns with many unique values don't show categorical values."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+    schema = source.get_schema(categorical_threshold=3)
+
+    # Name and description columns should not be categorical (8 and 5 unique values respectively)
+    lines = schema.split('\n')
+    name_line_idx = next(i for i, line in enumerate(lines) if "- name (TEXT)" in line)
+    description_line_idx = next(i for i, line in enumerate(lines) if "- description (TEXT)" in line)
+
+    # Check that the next line after name column doesn't contain categorical values
+    if name_line_idx + 1 < len(lines):
+        assert "Categorical values:" not in lines[name_line_idx + 1]
+
+    # Check that the next line after description column doesn't contain categorical values
+    if description_line_idx + 1 < len(lines):
+        assert "Categorical values:" not in lines[description_line_idx + 1]
+
+
+def test_get_schema_different_thresholds(test_db_engine):
+    """Test that categorical_threshold parameter works correctly."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+
+    # With threshold 2, the category column (3 unique values) should not be categorical
+    schema_low = source.get_schema(categorical_threshold=2)
+    assert "- category (TEXT)" in schema_low
+    assert "'A'" not in schema_low  # Should not show categorical values
+
+    # With threshold 5, category column should be categorical
+    schema_high = source.get_schema(categorical_threshold=5)
+    assert "- category (TEXT)" in schema_high
+    assert "'A'" in schema_high  # Should show categorical values
+
+
+def test_get_schema_table_structure(test_db_engine):
+    """Test the overall structure of the schema output."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+    schema = source.get_schema(categorical_threshold=5)
+
+    lines = schema.split('\n')
+
+    # Check header
+    assert lines[0] == "Table: test_table"
+    assert lines[1] == "Columns:"
+
+    # Check that all columns are present
+    expected_columns = ["id", "name", "age", "salary", "is_active", "join_date", "category", "score", "description"]
+    for col in expected_columns:
+        assert any(f"- {col} (" in line for line in lines), f"Column {col} not found in schema"
+
+
+def test_get_schema_empty_result_handling(test_db_engine):
+    """Test handling when statistics queries return empty results."""
+    # Create empty table
+    conn = sqlite3.connect(":memory:")
+    cursor = conn.cursor()
+    cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)")
+    conn.commit()
+
+    engine = create_engine("sqlite:///:memory:")
+    # Recreate table in the new engine
+    with engine.connect() as connection:
+        from sqlalchemy import text
+        connection.execute(text("CREATE TABLE empty_table (id INTEGER, name TEXT)"))
+        connection.commit()
+
+    source = SQLAlchemySource(engine, "empty_table")
+    schema = source.get_schema(categorical_threshold=5)
+
+    # Should still work but without range/categorical info
+    assert "Table: empty_table" in schema
+    assert "- id (INTEGER)" in schema
+    assert "- name (TEXT)" in schema
+    # Should not have range or categorical information
+    assert "Range:" not in schema
+    assert "Categorical values:" not in schema
+
+
+def test_get_schema_boolean_and_date_types(test_db_engine):
+    """Test handling of boolean and date column types."""
+    source = SQLAlchemySource(test_db_engine, "test_table")
+    schema = source.get_schema(categorical_threshold=5)
+
+    # Boolean column should show range
+    assert "- is_active (BOOLEAN)" in schema
+    # SQLite stores booleans as integers, so should show 0 to 1 range
+
+    # Date column should show range
+    assert "- join_date (DATE)" in schema
+    assert "Range:" in schema
+
+
+def test_invalid_table_name():
+    """Test that invalid table name raises appropriate error."""
+    engine = create_engine("sqlite:///:memory:")
+
+    with pytest.raises(ValueError, match="Table 'nonexistent' not found in database"):
+        SQLAlchemySource(engine, "nonexistent")
\ No newline at end of file
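
A minimal usage sketch (illustrative only, not part of the diff): the `QueryChat` object returned by the new `querychat.server()` exposes `chat.sql()`, `chat.title()`, and `chat.df()` reactives; the surrounding app, output IDs, and `querychat_config` below are assumed to exist as in the bundled examples.

```python
# Illustrative sketch only. Assumes `querychat_config` was built with querychat.init()
# and that the app UI defines text outputs "title" and "sql" plus a data frame "data_table".
import querychat
from shiny import render


def server(input, output, session):
    chat = querychat.server("chat", querychat_config)

    @render.text
    def title():
        # LLM-supplied description of the current filter; None when no query is set
        return chat.title() or "All rows"

    @render.text
    def sql():
        # The SQL currently applied to the dashboard ("" when unfiltered)
        return chat.sql()

    @render.data_frame
    def data_table():
        # Filtered data frame; falls back to the full table when no query is set
        return chat.df()
```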