From 1b683259b8feed277701b74012ae1aff6215e91d Mon Sep 17 00:00:00 2001 From: Daniel Chen Date: Thu, 12 Jun 2025 11:32:45 -0700 Subject: [PATCH] feat(py): re deploy #26 from Nick chore: add pytest in pyproject toml, but uses current tool.uv table tests: oops put test code in wrong file --- pkg-py/src/querychat/datasource.py | 136 ++++++++++---- pkg-py/src/querychat/querychat.py | 19 +- pkg-py/tests/__init__.py | 0 pkg-py/tests/test_datasource.py | 287 +++++++++++++++++++++++++++++ pyproject.toml | 7 +- 5 files changed, 398 insertions(+), 51 deletions(-) create mode 100644 pkg-py/tests/__init__.py create mode 100644 pkg-py/tests/test_datasource.py diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index 24c3a30f..c3c00390 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -178,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str): if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") - def get_schema(self, *, categorical_threshold: int) -> str: + def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 """ Generate schema information from database table. @@ -191,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: schema = [f"Table: {self._table_name}", "Columns:"] + # Build a single query to get all column statistics + select_parts = [] + numeric_columns = [] + text_columns = [] + for col in columns: - # Get SQL type name - sql_type = self._get_sql_type_name(col["type"]) - column_info = [f"- {col['name']} ({sql_type})"] + col_name = col["name"] - # For numeric columns, try to get range + # Check if column is numeric if isinstance( col["type"], ( @@ -208,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str: sqltypes.DateTime, sqltypes.BigInteger, sqltypes.SmallInteger, - # sqltypes.Interval, ), ): - try: - query = text( - f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}", - ) - with self._get_connection() as conn: - result = conn.execute(query).fetchone() - if result and result[0] is not None and result[1] is not None: - column_info.append(f" Range: {result[0]} to {result[1]}") - except Exception: # noqa: S110 - pass # Silently skip range info if query fails - - # For string/text columns, check if categorical + numeric_columns.append(col_name) + select_parts.extend( + [ + f"MIN({col_name}) as {col_name}_min", + f"MAX({col_name}) as {col_name}_max", + ], + ) + + # Check if column is text/string elif isinstance( col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum), ): - try: - count_query = text( - f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}", - ) + text_columns.append(col_name) + select_parts.append( + f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count", + ) + + # Execute single query to get all statistics + column_stats = {} + if select_parts: + try: + stats_query = text( + f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608 + ) + with self._get_connection() as conn: + result = conn.execute(stats_query).fetchone() + if result: + # Convert result to dict for easier access + column_stats = dict(zip(result._fields, result)) + except Exception: # noqa: S110 + pass # Fall back to no statistics if query fails + + # Get categorical values for text columns that are below threshold + categorical_values = {} + text_cols_to_query = [] + for col_name in text_columns: + distinct_count_key = f"{col_name}_distinct_count" + if ( + distinct_count_key in column_stats + and column_stats[distinct_count_key] + and column_stats[distinct_count_key] <= categorical_threshold + ): + text_cols_to_query.append(col_name) + + # Get categorical values in a single query if needed + if text_cols_to_query: + try: + # Build UNION query for all categorical columns + union_parts = [ + f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608 + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " + f"GROUP BY {col_name}" + for col_name in text_cols_to_query + ] + + if union_parts: + categorical_query = text(" UNION ALL ".join(union_parts)) with self._get_connection() as conn: - distinct_count = conn.execute(count_query).scalar() - if distinct_count and distinct_count <= categorical_threshold: - values_query = text( - f"SELECT DISTINCT {col['name']} FROM {self._table_name} " - f"WHERE {col['name']} IS NOT NULL", - ) - values = [ - str(row[0]) - for row in conn.execute(values_query).fetchall() - ] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except Exception: # noqa: S110 - pass # Silently skip categorical info if query fails + results = conn.execute(categorical_query).fetchall() + for row in results: + col_name, value = row + if col_name not in categorical_values: + categorical_values[col_name] = [] + categorical_values[col_name].append(str(value)) + except Exception: # noqa: S110 + pass # Skip categorical values if query fails + + # Build schema description using collected statistics + for col in columns: + col_name = col["name"] + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col_name} ({sql_type})"] + + # Add range info for numeric columns + if col_name in numeric_columns: + min_key = f"{col_name}_min" + max_key = f"{col_name}_max" + if ( + min_key in column_stats + and max_key in column_stats + and column_stats[min_key] is not None + and column_stats[max_key] is not None + ): + column_info.append( + f" Range: {column_stats[min_key]} to {column_stats[max_key]}", + ) + + # Add categorical values for text columns + elif col_name in categorical_values: + values = categorical_values[col_name] + # Remove duplicates and sort + unique_values = sorted(set(values)) + values_str = ", ".join([f"'{v}'" for v in unique_values]) + column_info.append(f" Categorical values: {values_str}") schema.extend(column_info) @@ -273,7 +335,7 @@ def get_data(self) -> pd.DataFrame: The complete dataset as a pandas DataFrame """ - return self.execute_query(f"SELECT * FROM {self._table_name}") + return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608 def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911 """Convert SQLAlchemy type to SQL type name.""" diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index 9eba2c47..8c0c768a 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -125,19 +125,12 @@ def __getitem__(self, key: str) -> Any: backwards compatibility only; new code should use the attributes directly instead. """ - if key == "chat": # noqa: SIM116 - return self.chat - elif key == "sql": - return self.sql - elif key == "title": - return self.title - elif key == "df": - return self.df - - raise KeyError( - f"`QueryChat` does not have a key `'{key}'`. " - "Use the attributes `chat`, `sql`, `title`, or `df` instead.", - ) + return { + "chat": self.chat, + "sql": self.sql, + "title": self.title, + "df": self.df, + }.get(key) def system_prompt( diff --git a/pkg-py/tests/__init__.py b/pkg-py/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py new file mode 100644 index 00000000..08b0ebb8 --- /dev/null +++ b/pkg-py/tests/test_datasource.py @@ -0,0 +1,287 @@ +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +from sqlalchemy import create_engine +from src.querychat.datasource import SQLAlchemySource + + +@pytest.fixture +def test_db_engine(): + """Create a temporary SQLite database with test data.""" + # Create temporary database file + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + temp_db.close() + + # Connect and create test table with various data types + conn = sqlite3.connect(temp_db.name) + cursor = conn.cursor() + + # Create table with different column types + cursor.execute(""" + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL, + is_active BOOLEAN, + join_date DATE, + category TEXT, + score NUMERIC, + description TEXT + ) + """) + + # Insert test data + test_data = [ + ( + 1, + "Alice", + 30, + 75000.50, + True, + "2023-01-15", + "A", + 95.5, + "Senior developer", + ), + ( + 2, + "Bob", + 25, + 60000.00, + True, + "2023-03-20", + "B", + 87.2, + "Junior developer", + ), + ( + 3, + "Charlie", + 35, + 85000.75, + False, + "2022-12-01", + "A", + 92.1, + "Team lead", + ), + ( + 4, + "Diana", + 28, + 70000.25, + True, + "2023-05-10", + "C", + 89.8, + "Mid-level developer", + ), + ( + 5, + "Eve", + 32, + 80000.00, + True, + "2023-02-28", + "A", + 91.3, + "Senior developer", + ), + ( + 6, + "Frank", + 26, + 62000.50, + False, + "2023-04-15", + "B", + 85.7, + "Junior developer", + ), + (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, "Developer"), + ( + 8, + "Henry", + 31, + 78000.25, + True, + "2023-03-05", + "C", + 88.9, + "Senior developer", + ), + ] + + cursor.executemany( + """ + INSERT INTO test_table + (id, name, age, salary, is_active, join_date, category, score, description) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + test_data, + ) + + conn.commit() + conn.close() + + # Create SQLAlchemy engine + engine = create_engine(f"sqlite:///{temp_db.name}") + + yield engine + + # Cleanup + Path(temp_db.name).unlink() + + +def test_get_schema_numeric_ranges(test_db_engine): + """Test that numeric columns include min/max ranges.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Check that numeric columns have range information + assert "- id (INTEGER)" in schema + assert "Range: 1 to 8" in schema + + assert "- age (INTEGER)" in schema + assert "Range: 25 to 35" in schema + + assert "- salary (FLOAT)" in schema + assert "Range: 60000.0 to 85000.75" in schema + + assert "- score (NUMERIC)" in schema + assert "Range: 85.7 to 95.5" in schema + + +def test_get_schema_categorical_values(test_db_engine): + """Test that text columns with few unique values show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Category column should be treated as categorical (3 unique values: A, B, C) + assert "- category (TEXT)" in schema + assert "Categorical values:" in schema + assert "'A'" in schema and "'B'" in schema and "'C'" in schema + + +def test_get_schema_non_categorical_text(test_db_engine): + """Test that text columns with many unique values don't show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=3) + + # Name and description columns should not be categorical (8 and 6 unique values respectively) + lines = schema.split("\n") + name_line_idx = next( + i for i, line in enumerate(lines) if "- name (TEXT)" in line + ) + description_line_idx = next( + i for i, line in enumerate(lines) if "- description (TEXT)" in line + ) + + # Check that the next line after name column doesn't contain categorical values + if name_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[name_line_idx + 1] + + # Check that the next line after description column doesn't contain categorical values + if description_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[description_line_idx + 1] + + +def test_get_schema_different_thresholds(test_db_engine): + """Test that categorical_threshold parameter works correctly.""" + source = SQLAlchemySource(test_db_engine, "test_table") + + # With threshold 2, only category column (3 unique) should not be categorical + schema_low = source.get_schema(categorical_threshold=2) + assert "- category (TEXT)" in schema_low + assert "'A'" not in schema_low # Should not show categorical values + + # With threshold 5, category column should be categorical + schema_high = source.get_schema(categorical_threshold=5) + assert "- category (TEXT)" in schema_high + assert "'A'" in schema_high # Should show categorical values + + +def test_get_schema_table_structure(test_db_engine): + """Test the overall structure of the schema output.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + lines = schema.split("\n") + + # Check header + assert lines[0] == "Table: test_table" + assert lines[1] == "Columns:" + + # Check that all columns are present + expected_columns = [ + "id", + "name", + "age", + "salary", + "is_active", + "join_date", + "category", + "score", + "description", + ] + for col in expected_columns: + assert any(f"- {col} (" in line for line in lines), ( + f"Column {col} not found in schema" + ) + + +def test_get_schema_empty_result_handling(test_db_engine): + """Test handling when statistics queries return empty results.""" + # Create empty table + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)") + conn.commit() + + engine = create_engine("sqlite:///:memory:") + # Recreate table in the new engine + with engine.connect() as connection: + from sqlalchemy import text + + connection.execute( + text("CREATE TABLE empty_table (id INTEGER, name TEXT)") + ) + connection.commit() + + source = SQLAlchemySource(engine, "empty_table") + schema = source.get_schema(categorical_threshold=5) + + # Should still work but without range/categorical info + assert "Table: empty_table" in schema + assert "- id (INTEGER)" in schema + assert "- name (TEXT)" in schema + # Should not have range or categorical information + assert "Range:" not in schema + assert "Categorical values:" not in schema + + +def test_get_schema_boolean_and_date_types(test_db_engine): + """Test handling of boolean and date column types.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Boolean column should show range + assert "- is_active (BOOLEAN)" in schema + # SQLite stores booleans as integers, so should show 0 to 1 range + + # Date column should show range + assert "- join_date (DATE)" in schema + assert "Range:" in schema + + +def test_invalid_table_name(): + """Test that invalid table name raises appropriate error.""" + engine = create_engine("sqlite:///:memory:") + + with pytest.raises( + ValueError, match="Table 'nonexistent' not found in database" + ): + SQLAlchemySource(engine, "nonexistent") diff --git a/pyproject.toml b/pyproject.toml index 3ce33dc4..89a3bff6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,12 @@ packages = ["pkg-py/src/querychat"] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [tool.uv] -dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"] +dev-dependencies = [ + "ruff>=0.6.5", + "pyright>=1.1.401", + "tox-uv>=1.11.4", + "pytest>=8.4.0", +] [tool.ruff] src = ["pkg-py/src/querychat"]