6 changes: 5 additions & 1 deletion .gitignore
@@ -30,4 +30,8 @@ build

 *.xlsx
 
-lib
+lib
+
+hf_cache/
+downloaded_databases/
+downloaded_databases_vec/
6 changes: 6 additions & 0 deletions dataflow/operators/text2sql/__init__.py
@@ -4,17 +4,23 @@
     # filter
     from filter.sql_consistency_filter import SQLConsistencyFilter
     from filter.sql_execution_filter import SQLExecutionFilter
+    from filter.vecsql_execution_filter import VecSQLExecutionFilter
 
     # generate
     from generate.sql_generator import SQLGenerator
+    from generate.vecsql_generator import VecSQLGenerator
     from generate.sql_variation_generator import SQLVariationGenerator
     from generate.text2sql_cot_generator import Text2SQLCoTGenerator
     from generate.text2sql_prompt_generator import Text2SQLPromptGenerator
+    from generate.text2vecsql_prompt_generator import Text2VecSQLPromptGenerator
     from generate.text2sql_question_generator import Text2SQLQuestionGenerator
+    from generate.text2vecsql_question_generator import Text2VecSQLQuestionGenerator
 
     # eval
     from eval.sql_component_classifier import SQLComponentClassifier
+    from eval.sql_component_classifier import VecSQLComponentClassifier
     from eval.sql_execution_classifier import SQLExecutionClassifier
+    from eval.vecsql_execution_classifier import VecSQLExecutionClassifier
 
 else:
     import sys
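
For context, here is a minimal sketch of how the newly exported VecSQLComponentClassifier might be wired up once this lands. The FileStorage import, the file paths, and the step() call are assumptions for illustration (they follow the usual DataFlow pipeline pattern, not anything shown in this diff); only the constructor and run() signatures come from the classifier added in the next file.

# Hypothetical usage sketch -- storage backend, paths, and threshold values are illustrative only.
from dataflow.operators.text2sql import VecSQLComponentClassifier
from dataflow.utils.storage import FileStorage  # assumed storage backend

storage = FileStorage(
    first_entry_file_name="vecsql_samples.jsonl",  # assumed input file
    cache_path="./cache",
    file_name_prefix="vecsql",
    cache_type="jsonl",
)

classifier = VecSQLComponentClassifier(
    difficulty_thresholds=[2, 4, 6],                        # example values
    difficulty_labels=["easy", "medium", "hard", "extra"],  # must be one longer than thresholds
)
classifier.run(
    storage=storage.step(),  # assumed FileStorage API
    input_sql_key="SQL",
    output_difficulty_key="sql_component_difficulty",
)
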
93 changes: 92 additions & 1 deletion dataflow/operators/text2sql/eval/sql_component_classifier.py
@@ -770,4 +770,95 @@ def run(self, storage: DataFlowStorage,
         output_file = storage.write(dataframe)
         self.logger.info(f"Extracted answers saved to {output_file}")
 
-        return [self.output_difficulty_key]
+        return [self.output_difficulty_key]
+
+@OPERATOR_REGISTRY.register()
+class VecSQLComponentClassifier(OperatorABC):
+    def __init__(self,
+                 difficulty_thresholds: list,
+                 difficulty_labels: list
+                 ):
+        self.difficulty_config = {
+            'thresholds': difficulty_thresholds,
+            'labels': difficulty_labels
+        }
+        self.logger = get_logger()
+        if len(self.difficulty_config['thresholds']) != len(self.difficulty_config['labels']) - 1:
+            raise ValueError("Thresholds and labels configuration mismatch")
+
+    def check_column(self, dataframe):
+        required_columns = [self.input_sql_key]
+        missing_columns = [col for col in required_columns if col not in dataframe.columns]
+        if missing_columns:
+            raise ValueError(f"Missing required columns: {missing_columns}")
+
+    @staticmethod
+    def get_desc(lang):
+        if lang == "zh":
+            return (
+                "根据VecSQL的组件数量和复杂度,评估SQL的难度。\n\n"
+                "输入参数:\n"
+                "- input_sql_key: 输入SQL列名\n\n"
+                "输出参数:\n"
+                "- output_difficulty_key: 输出难度列名"
+            )
+        elif lang == "en":
+            return (
+                "This operator evaluates the difficulty of VecSQL queries based on the number and complexity of their components.\n\n"
+                "Input parameters:\n"
+                "- input_sql_key: The name of the input SQL column\n\n"
+                "Output parameters:\n"
+                "- output_difficulty_key: The name of the output difficulty column"
+            )
+        else:
+            return "VecSQL component difficulty evaluator for Text2SQL tasks."
+
+    def get_schema(self, db):
+        try:
+            import sqlite_vec
+            import sqlite_lembed
+        except ImportError:
+            self.logger.error("'sqlite_vec' or 'sqlite_lembed' is not installed. Please install with 'pip install sqlite-vec sqlite-lembed'")
+            raise
+        schema = {}
+        conn = sqlite3.connect(db)
+        cursor = conn.cursor()
+        # load the sqlite-vec and sqlite-lembed extensions
+        conn.enable_load_extension(True)
+        sqlite_vec.load(conn)
+        sqlite_lembed.load(conn)
+
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+        tables = [str(table[0].lower()) for table in cursor.fetchall()]
+
+        for table in tables:
+            cursor.execute("PRAGMA table_info({})".format(table))
+            schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
+
+        return schema
+
+    def report_statistics(self, dataframe: pd.DataFrame):
+        counts = dataframe[self.output_difficulty_key].value_counts()
+        self.logger.info("SQL Difficulty Statistics")
+        difficulty_counts = {d: counts.get(d, 0) for d in self.difficulty_config['labels']}
+        self.logger.info(" | ".join([f"{d.title()}: {v}" for d, v in difficulty_counts.items()]))
+
+    def run(self, storage: DataFlowStorage,
+            input_sql_key: str = "SQL",
+            output_difficulty_key: str = "sql_component_difficulty"):
+
+        self.input_sql_key = input_sql_key
+        self.output_difficulty_key = output_difficulty_key
+        dataframe = storage.read("dataframe")
+        self.check_column(dataframe)
+        for idx, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Processing"):
+            sql = row.get(self.input_sql_key)
+            sql_hardness = EvalHardnessLite(sql, self.difficulty_config)
+            hardness = sql_hardness.run()
+            dataframe.at[idx, self.output_difficulty_key] = hardness
+
+        self.report_statistics(dataframe)
+        output_file = storage.write(dataframe)
+        self.logger.info(f"Difficulty labels saved to {output_file}")
+
+        return [self.output_difficulty_key]
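
The constructor added above enforces len(difficulty_thresholds) == len(difficulty_labels) - 1, and EvalHardnessLite (not shown in this diff) is what turns a SQL string plus that config into a label. As a rough guide to what that mapping amounts to, here is a self-contained bucketing sketch; the function name and the idea of a single numeric component score are assumptions, not the actual EvalHardnessLite logic.

from bisect import bisect_right

def bucket_difficulty(score: float, thresholds: list, labels: list) -> str:
    # Hypothetical helper: the label index is the number of thresholds the score
    # has reached, so a score below thresholds[0] maps to labels[0] and a score
    # at or above thresholds[-1] maps to labels[-1].
    if len(thresholds) != len(labels) - 1:
        raise ValueError("Thresholds and labels configuration mismatch")
    return labels[bisect_right(thresholds, score)]

# With thresholds [2, 4, 6] and labels ['easy', 'medium', 'hard', 'extra']:
print(bucket_difficulty(1, [2, 4, 6], ['easy', 'medium', 'hard', 'extra']))  # easy
print(bucket_difficulty(5, [2, 4, 6], ['easy', 'medium', 'hard', 'extra']))  # hard
print(bucket_difficulty(9, [2, 4, 6], ['easy', 'medium', 'hard', 'extra']))  # extra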
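
get_schema() above has to load the sqlite-vec and sqlite-lembed extensions before it can read table metadata from a vector-enabled database. Below is a standalone sketch of that connection setup for anyone reviewing locally; the database path is a placeholder, and apart from vec_version() the calls are the same ones used in the diff.

import sqlite3

import sqlite_vec     # pip install sqlite-vec
import sqlite_lembed  # pip install sqlite-lembed

def open_vec_connection(db_path: str) -> sqlite3.Connection:
    # Open a SQLite connection with both vector extensions loaded (illustrative helper).
    conn = sqlite3.connect(db_path)
    conn.enable_load_extension(True)   # loading is only allowed while this flag is on
    sqlite_vec.load(conn)
    sqlite_lembed.load(conn)
    conn.enable_load_extension(False)  # turn it back off once the extensions are in
    return conn

if __name__ == "__main__":
    conn = open_vec_connection("example.db")  # placeholder path
    print(conn.execute("SELECT vec_version();").fetchone())
    # Same metadata walk as VecSQLComponentClassifier.get_schema:
    tables = [row[0].lower() for row in
              conn.execute("SELECT name FROM sqlite_master WHERE type='table';")]
    for table in tables:
        cols = [col[1].lower() for col in conn.execute(f"PRAGMA table_info({table})")]
        print(table, cols)
    conn.close()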