First attempt at genericizing data source #5

Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions python-package/examples/app.py
@@ -4,6 +4,7 @@
from shiny import App, render, ui

import querychat
from querychat.datasource import DataFrameSource

titanic = load_dataset("titanic")

@@ -14,8 +15,7 @@

# 1. Configure querychat
querychat_config = querychat.init(
titanic,
"titanic",
DataFrameSource(titanic, "titanic"),
greeting=greeting,
data_description=data_desc,
)
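
For comparison, here is a minimal sketch of the same configuration pointed at a SQL database instead of an in-memory DataFrame, using the new `SQLAlchemySource`. The SQLite URL and table name are hypothetical, `greeting` and `data_desc` are reused from the example above, and this assumes `querychat.init` accepts any object implementing the `DataSource` protocol, as the `DataFrameSource` change implies:

```python
# Hypothetical variant of the example above: same init call, different data source.
from sqlalchemy import create_engine

import querychat
from querychat.datasource import SQLAlchemySource

engine = create_engine("sqlite:///titanic.db")  # hypothetical SQLite database

querychat_config = querychat.init(
    SQLAlchemySource(engine, "titanic"),  # table assumed to exist in that database
    greeting=greeting,
    data_description=data_desc,
)
```
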
6 changes: 6 additions & 0 deletions python-package/pyproject.toml
@@ -20,6 +20,12 @@ dependencies = [
"htmltools",
"chatlas",
"narwhals",
"chevron",
]

[project.optional-dependencies]
sqlalchemy = [
"sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API
]

[project.urls]
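
Since `sqlalchemy` is declared only as an optional extra, users who want `SQLAlchemySource` install it explicitly. A small sketch of how downstream code might guard for its absence (the guard is illustrative, not something this PR adds):

```python
# The SQLAlchemy backend is opt-in:
#     pip install "querychat[sqlalchemy]"
# Importing the datasource module pulls in sqlalchemy at the top level, so a caller
# that can fall back to the DataFrame path might guard the import:
try:
    from querychat.datasource import SQLAlchemySource
    HAVE_SQLALCHEMY = True
except ImportError:
    HAVE_SQLALCHEMY = False
```
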
271 changes: 271 additions & 0 deletions python-package/querychat/datasource.py
@@ -0,0 +1,271 @@
from __future__ import annotations

from typing import ClassVar, Protocol

import duckdb
import narwhals as nw
import pandas as pd
from sqlalchemy import inspect, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.sql import sqltypes


class DataSource(Protocol):
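    """Interface that querychat data sources are expected to implement."""
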
db_engine: ClassVar[str]

def get_schema(self) -> str:
"""Return schema information about the table as a string.

Returns:
A string containing the schema information in a format suitable for
prompting an LLM about the data structure
"""
...

def execute_query(self, query: str) -> pd.DataFrame:
"""Execute SQL query and return results as DataFrame.

Args:
query: SQL query to execute

Returns:
Query results as a pandas DataFrame
"""
...

def get_data(self) -> pd.DataFrame:
"""Return the unfiltered data as a DataFrame.

Returns:
The complete dataset as a pandas DataFrame
"""
...


class DataFrameSource:
"""A DataSource implementation that wraps a pandas DataFrame using DuckDB."""

db_engine: ClassVar[str] = "DuckDB"

def __init__(self, df: pd.DataFrame, table_name: str):
"""Initialize with a pandas DataFrame.

Args:
df: The DataFrame to wrap
table_name: Name of the table in SQL queries
"""
self._conn = duckdb.connect(database=":memory:")
self._df = df
self._table_name = table_name
self._conn.register(table_name, df)

def get_schema(self, categorical_threshold: int = 10) -> str:
"""Generate schema information from DataFrame.

Args:
categorical_threshold: Maximum number of unique values for a text column
to be considered categorical

Returns:
String describing the schema
"""
ndf = nw.from_native(self._df)

schema = [f"Table: {self._table_name}", "Columns:"]

for column in ndf.columns:
# Map pandas dtypes to SQL-like types
dtype = ndf[column].dtype
if dtype.is_integer():
sql_type = "INTEGER"
elif dtype.is_float():
sql_type = "FLOAT"
elif dtype == nw.Boolean:
sql_type = "BOOLEAN"
            elif dtype == nw.Datetime:
                sql_type = "TIMESTAMP"
elif dtype == nw.Date:
sql_type = "DATE"
else:
sql_type = "TEXT"

column_info = [f"- {column} ({sql_type})"]

# For TEXT columns, check if they're categorical
if sql_type == "TEXT":
unique_values = ndf[column].drop_nulls().unique()
if unique_values.len() <= categorical_threshold:
categories = unique_values.to_list()
categories_str = ", ".join([f"'{c}'" for c in categories])
column_info.append(f" Categorical values: {categories_str}")

# For numeric columns, include range
            elif sql_type in ["INTEGER", "FLOAT", "DATE", "TIMESTAMP"]:
rng = ndf[column].min(), ndf[column].max()
if rng[0] is None and rng[1] is None:
column_info.append(" Range: NULL to NULL")
else:
column_info.append(f" Range: {rng[0]} to {rng[1]}")

schema.extend(column_info)

return "\n".join(schema)

def execute_query(self, query: str) -> pd.DataFrame:
"""Execute query using DuckDB.

Args:
query: SQL query to execute

Returns:
Query results as pandas DataFrame
"""
return self._conn.execute(query).df()

def get_data(self) -> pd.DataFrame:
"""Return the unfiltered data as a DataFrame.

Returns:
The complete dataset as a pandas DataFrame
"""
return self._df.copy()


class SQLAlchemySource:
"""A DataSource implementation that supports multiple SQL databases via SQLAlchemy.

Supports various databases including PostgreSQL, MySQL, SQLite, Snowflake, and Databricks.
"""

db_engine: ClassVar[str] = "SQLAlchemy"

def __init__(self, engine: Engine, table_name: str):
"""Initialize with a SQLAlchemy engine.

Args:
engine: SQLAlchemy engine
table_name: Name of the table to query
"""
self._engine = engine
self._table_name = table_name

# Validate table exists
inspector = inspect(self._engine)
if not inspector.has_table(table_name):
raise ValueError(f"Table '{table_name}' not found in database")

def get_schema(self) -> str:
"""Generate schema information from database table.

Returns:
String describing the schema
"""
inspector = inspect(self._engine)
columns = inspector.get_columns(self._table_name)

schema = [f"Table: {self._table_name}", "Columns:"]

for col in columns:
# Get SQL type name
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col['name']} ({sql_type})"]

# For numeric columns, try to get range
if isinstance(
col["type"],
(
sqltypes.Integer,
sqltypes.Numeric,
sqltypes.Float,
sqltypes.Date,
sqltypes.Time,
sqltypes.DateTime,
sqltypes.BigInteger,
sqltypes.SmallInteger,
# sqltypes.Interval,
),
):
try:
query = text(
f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}"
)
with self._get_connection() as conn:
result = conn.execute(query).fetchone()
if result and result[0] is not None and result[1] is not None:
column_info.append(f" Range: {result[0]} to {result[1]}")
except Exception:
pass # Skip range info if query fails

# For string/text columns, check if categorical
elif isinstance(
col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum)
):
try:
count_query = text(
f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}"
)
with self._get_connection() as conn:
distinct_count = conn.execute(count_query).scalar()
if distinct_count and distinct_count <= 10:
values_query = text(
f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
f"WHERE {col['name']} IS NOT NULL"
)
values = [
str(row[0])
for row in conn.execute(values_query).fetchall()
]
values_str = ", ".join([f"'{v}'" for v in values])
column_info.append(f" Categorical values: {values_str}")
except Exception:
pass # Skip categorical info if query fails

schema.extend(column_info)

return "\n".join(schema)

def execute_query(self, query: str) -> pd.DataFrame:
"""Execute SQL query and return results as DataFrame.

Args:
query: SQL query to execute

Returns:
Query results as pandas DataFrame
"""
with self._get_connection() as conn:
return pd.read_sql_query(text(query), conn)

def get_data(self) -> pd.DataFrame:
"""Return the unfiltered data as a DataFrame.

Returns:
The complete dataset as a pandas DataFrame
"""
return self.execute_query(f"SELECT * FROM {self._table_name}")

def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str:
"""Convert SQLAlchemy type to SQL type name."""
if isinstance(type_, sqltypes.Integer):
return "INTEGER"
elif isinstance(type_, sqltypes.Float):
return "FLOAT"
elif isinstance(type_, sqltypes.Numeric):
return "NUMERIC"
elif isinstance(type_, sqltypes.Boolean):
return "BOOLEAN"
elif isinstance(type_, sqltypes.DateTime):
return "TIMESTAMP"
elif isinstance(type_, sqltypes.Date):
return "DATE"
elif isinstance(type_, sqltypes.Time):
return "TIME"
elif isinstance(type_, (sqltypes.String, sqltypes.Text)):
return "TEXT"
else:
return type_.__class__.__name__.upper()

def _get_connection(self) -> Connection:
"""Get a connection to use for queries."""
return self._engine.connect()
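
Because `DataSource` is a `typing.Protocol`, conformance is structural: any class that provides `db_engine`, `get_schema`, `execute_query`, and `get_data` can back querychat, with no inheritance from this module. A rough sketch of a hypothetical CSV-backed source that reuses DuckDB much as `DataFrameSource` does (the class name, file handling, and schema formatting are all illustrative):

```python
from typing import ClassVar

import duckdb
import pandas as pd


class CsvSource:
    """Hypothetical DataSource backed by a CSV file, queried via DuckDB."""

    db_engine: ClassVar[str] = "DuckDB"

    def __init__(self, path: str, table_name: str):
        self._conn = duckdb.connect(database=":memory:")
        self._table_name = table_name
        # Expose the CSV as a view so queries can refer to it by name.
        self._conn.execute(
            f"CREATE VIEW {table_name} AS SELECT * FROM read_csv_auto('{path}')"
        )

    def get_schema(self) -> str:
        # DESCRIBE returns column names and DuckDB types; format them the same
        # way the other sources do ("Table: ...\nColumns:\n- col (TYPE)").
        rows = self._conn.execute(f"DESCRIBE {self._table_name}").fetchall()
        lines = [f"Table: {self._table_name}", "Columns:"]
        lines += [f"- {name} ({dtype})" for name, dtype, *_ in rows]
        return "\n".join(lines)

    def execute_query(self, query: str) -> pd.DataFrame:
        return self._conn.execute(query).df()

    def get_data(self) -> pd.DataFrame:
        return self.execute_query(f"SELECT * FROM {self._table_name}")
```

Note that the protocol is not marked `@runtime_checkable`, so conformance is only verified by static type checkers (or simply by duck typing at call time).
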
16 changes: 13 additions & 3 deletions python-package/querychat/prompt/prompt.md
@@ -4,13 +4,19 @@ It's important that you get clear, unambiguous instructions from the user, so if

The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance.

You have at your disposal a DuckDB database containing this schema:
You have at your disposal a {{db_engine}} database containing this schema:

{{schema}}

For security reasons, you may only query this specific table.

{{#data_description}}
Additional helpful info about the data:

<data_description>
{{data_description}}
</data_description>
{{/data_description}}

There are several tasks you may be asked to do:

@@ -19,7 +25,7 @@ There are several tasks you may be asked to do:
The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success.

* **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error.
* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions.
* The SQL query must be a SELECT query. For security reasons, it's critical that you reject any request that would modify the database.
* The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`.
* Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed.
* When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't.
@@ -74,7 +80,11 @@ Example of question answering:

If the user provides a vague help request, like "Help" or "Show me instructions", describe your own capabilities in a helpful way, including examples of questions they can ask. Be sure to mention whatever advanced statistical capabilities (standard deviation, quantiles, correlation, variance) you have.

## DuckDB SQL tips
## SQL tips

* The SQL engine is {{db_engine}}.

* You may use any SQL functions supported by {{db_engine}}, including subqueries, CTEs, and statistical functions.

* `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`.
