Skip to content

feat(py): generic datasources improvements re-submit #26 from Nick #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 99 additions & 37 deletions pkg-py/src/querychat/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
if not inspector.has_table(table_name):
raise ValueError(f"Table '{table_name}' not found in database")

def get_schema(self, *, categorical_threshold: int) -> str:
def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
"""
Generate schema information from database table.

Expand All @@ -191,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:

schema = [f"Table: {self._table_name}", "Columns:"]

# Build a single query to get all column statistics
select_parts = []
numeric_columns = []
text_columns = []

for col in columns:
# Get SQL type name
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col['name']} ({sql_type})"]
col_name = col["name"]

# For numeric columns, try to get range
# Check if column is numeric
if isinstance(
col["type"],
(
Expand All @@ -208,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
sqltypes.DateTime,
sqltypes.BigInteger,
sqltypes.SmallInteger,
# sqltypes.Interval,
),
):
try:
query = text(
f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}",
)
with self._get_connection() as conn:
result = conn.execute(query).fetchone()
if result and result[0] is not None and result[1] is not None:
column_info.append(f" Range: {result[0]} to {result[1]}")
except Exception: # noqa: S110
pass # Silently skip range info if query fails

# For string/text columns, check if categorical
numeric_columns.append(col_name)
select_parts.extend(
[
f"MIN({col_name}) as {col_name}_min",
f"MAX({col_name}) as {col_name}_max",
],
)

# Check if column is text/string
elif isinstance(
col["type"],
(sqltypes.String, sqltypes.Text, sqltypes.Enum),
):
try:
count_query = text(
f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}",
)
text_columns.append(col_name)
select_parts.append(
f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count",
)

# Execute single query to get all statistics
column_stats = {}
if select_parts:
try:
stats_query = text(
f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608
)
with self._get_connection() as conn:
result = conn.execute(stats_query).fetchone()
if result:
# Convert result to dict for easier access
column_stats = dict(zip(result._fields, result))
except Exception: # noqa: S110
pass # Fall back to no statistics if query fails

# Get categorical values for text columns that are below threshold
categorical_values = {}
text_cols_to_query = []
for col_name in text_columns:
distinct_count_key = f"{col_name}_distinct_count"
if (
distinct_count_key in column_stats
and column_stats[distinct_count_key]
and column_stats[distinct_count_key] <= categorical_threshold
):
text_cols_to_query.append(col_name)

# Get categorical values in a single query if needed
if text_cols_to_query:
try:
# Build UNION query for all categorical columns
union_parts = [
f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608
f"FROM {self._table_name} WHERE {col_name} IS NOT NULL "
f"GROUP BY {col_name}"
for col_name in text_cols_to_query
]

if union_parts:
categorical_query = text(" UNION ALL ".join(union_parts))
with self._get_connection() as conn:
distinct_count = conn.execute(count_query).scalar()
if distinct_count and distinct_count <= categorical_threshold:
values_query = text(
f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
f"WHERE {col['name']} IS NOT NULL",
)
values = [
str(row[0])
for row in conn.execute(values_query).fetchall()
]
values_str = ", ".join([f"'{v}'" for v in values])
column_info.append(f" Categorical values: {values_str}")
except Exception: # noqa: S110
pass # Silently skip categorical info if query fails
results = conn.execute(categorical_query).fetchall()
for row in results:
col_name, value = row
if col_name not in categorical_values:
categorical_values[col_name] = []
categorical_values[col_name].append(str(value))
except Exception: # noqa: S110
pass # Skip categorical values if query fails

# Build schema description using collected statistics
for col in columns:
col_name = col["name"]
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col_name} ({sql_type})"]

# Add range info for numeric columns
if col_name in numeric_columns:
min_key = f"{col_name}_min"
max_key = f"{col_name}_max"
if (
min_key in column_stats
and max_key in column_stats
and column_stats[min_key] is not None
and column_stats[max_key] is not None
):
column_info.append(
f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
)

# Add categorical values for text columns
elif col_name in categorical_values:
values = categorical_values[col_name]
# Remove duplicates and sort
unique_values = sorted(set(values))
values_str = ", ".join([f"'{v}'" for v in unique_values])
column_info.append(f" Categorical values: {values_str}")

schema.extend(column_info)

Expand Down Expand Up @@ -273,7 +335,7 @@ def get_data(self) -> pd.DataFrame:
The complete dataset as a pandas DataFrame

"""
return self.execute_query(f"SELECT * FROM {self._table_name}")
return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608

def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911
"""Convert SQLAlchemy type to SQL type name."""
Expand Down
19 changes: 6 additions & 13 deletions pkg-py/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,12 @@ def __getitem__(self, key: str) -> Any:
backwards compatibility only; new code should use the attributes
directly instead.
"""
if key == "chat": # noqa: SIM116
return self.chat
elif key == "sql":
return self.sql
elif key == "title":
return self.title
elif key == "df":
return self.df

raise KeyError(
f"`QueryChat` does not have a key `'{key}'`. "
"Use the attributes `chat`, `sql`, `title`, or `df` instead.",
)
return {
"chat": self.chat,
"sql": self.sql,
"title": self.title,
"df": self.df,
}.get(key)


def system_prompt(
Expand Down
Empty file added pkg-py/tests/__init__.py
Empty file.
Loading
Loading