Text2SQL tool - e2e evals and fine-tuning #967

base: main · Changes from 48 commits
`end-to-end-use-cases/coding/text2sql/README.md` (updated):

````diff
@@ -1,30 +1,24 @@
-## Text2SQL: Natural Language to SQL Interface
+# Text2SQL: Evaluating and Fine-tuning Llama Models with CoT

-This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data.
+This folder contains scripts to:

-For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook.
+1. Evaluate Llama (original and fine-tuned) models on the Text2SQL task using the popular [BIRD](https://bird-bench.github.io) dataset.

-### Structure:
+2. Generate two supervised fine-tuning (SFT) datasets (with and without CoT) and fine-tune Llama 3.1 8B on them, using different SFT options: with or without CoT, with or without quantization, and full fine-tuning (FFT) or parameter-efficient fine-tuning (PEFT). The non-quantized PEFT CoT SFT shows the largest gain: from 39.47% for the original Llama 3.1 8B model to 43.35%. (Note: the results are based on 3 epochs of SFT.)

-- quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai.
-- nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes.
-- txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py.
-- csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data.
-- nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface.
+Our end goal is to maximize the accuracy of Llama models on the Text2SQL task. To do so, we first evaluate current state-of-the-art Llama models on the task, then apply fine-tuning, agents, and other approaches to evaluate and improve Llama's performance.

-### Detailed steps on running the notebook:
+## Structure:

-- Before getting started, please make sure to setup Together.ai and get an API key from [here](https://www.together.ai/).
+- data: scripts to download the BIRD TRAIN and DEV datasets;
+- eval: scripts to evaluate Llama models (original and fine-tuned) on the BIRD dataset;
+- fine-tune: scripts to generate non-CoT and CoT datasets based on the BIRD TRAIN set and to supervised fine-tune Llama models on them, with different SFT options (with or without quantization, full fine-tuning or parameter-efficient fine-tuning);
+- quickstart: a notebook that asks Llama 3.3 to convert natural language queries into SQL queries.

-- First, please install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running inside the folder:
+## Next Steps

-```
-git clone https://github.com/meta-llama/llama-cookbook.git
-cd llama-cookbook/end-to-end-use-cases/coding/text2sql/
-pip install -r requirements.txt
-```

-### Contributing
-Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas:
-- Adding support for additional databases
-- Developing new interfaces or applications that use the Text2SQL interface
+1. Hyper-parameter tuning of the current SFT scripts.
+2. Try GRPO reinforcement learning to further improve accuracy.
+3. Fine-tune Llama 3.3 70B and Llama 4 models.
+4. Try an agentic workflow.
+5. Expand the eval to support other enterprise databases.
````
> **Review comment** (on "Expand the eval to support other enterprise databases"): would it be dataset vs database?
>
> **Reply:** It's "database" - the plan is to expand the eval beyond the current SQLite and include Oracle, etc.
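For the fine-tuning path described above, the non-quantized PEFT CoT SFT (39.47% to 43.35%) could look roughly like the sketch below, using Hugging Face `trl` and `peft`. This is a hedged illustration, not the PR's actual `fine-tune` scripts: the dataset path, hyperparameters, and the assumption that the generated CoT dataset stores one prompt-plus-CoT-response string per example in a `text` field are all illustrative.

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Illustrative CoT SFT dataset: assumed JSON with a "text" field holding the
# question + schema prompt followed by the CoT reasoning and final SQL.
dataset = load_dataset("json", data_files="train_text2sql_cot_dataset.json", split="train")

trainer = SFTTrainer(
    model="meta-llama/Llama-3.1-8B-Instruct",  # base model to fine-tune
    train_dataset=dataset,
    # 3 epochs, matching the SFT setup the README reports results for.
    args=SFTConfig(output_dir="llama31-8b-text2sql-fine-tuned", num_train_epochs=3),
    # Passing a LoRA config (with the model loaded unquantized) gives the
    # "non-quantized PEFT" variant; omit peft_config for full fine-tuning.
    peft_config=LoraConfig(r=16, lora_alpha=32, target_modules="all-linear",
                           task_type="CAUSAL_LM"),
)
trainer.train()
```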
`data/download_dev_unzip.sh` (new) - downloads and unpacks the BIRD DEV set:

```sh
# Download the BIRD DEV set, then unpack it and its databases.
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip
unzip dev.zip
rm dev.zip
rm -rf __MACOSX
cd dev_20240627
unzip dev_databases.zip
rm dev_databases.zip
rm -rf __MACOSX
cd ..
```
Its TRAIN counterpart in `data/` (new) - downloads and unpacks the BIRD TRAIN set:

```sh
wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip
# Works around unzip's zip-bomb detection, which can falsely trigger on this large archive.
UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip
rm train.zip
rm -rf __MACOSX
cd train
unzip train_databases.zip
rm train_databases.zip
rm -rf __MACOSX
cd ..
```
`eval/README.md` (new):

# Llama Text2SQL Evaluation

We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) into 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), as well as for the fine-tuned Llama 3.1 model.

## Evaluation Results

Below are the results of the Llama models we have evaluated on the BIRD DEV dataset:
| Model            | Llama API Accuracy | Together Accuracy |
|------------------|--------------------|-------------------|
| Llama 3.1 8B     | -                  | 35.66%            |
| Llama 3.3 70B    | 54.11%             | 54.63%            |
| Llama 3.1 405B   | -                  | 55.80%            |
| Llama 4 Scout    | 44.39%             | 43.94%            |
| Llama 4 Maverick | 44.00%             | 41.46%            |

- Llama 3.1 8B on Hugging Face: quantized 14.02%, non-quantized 39.47%
- Non-quantized FFT with no CoT dataset: 36.31%
- Non-quantized FFT with CoT dataset: 43.87%
## Quick Start

First, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning:

```
git clone https://github.com/meta-llama/llama-cookbook
cd llama-cookbook/end-to-end-use-cases/coding/text2sql
conda create -n llama-text2sql python=3.10
conda activate llama-text2sql
pip install -r requirements.txt
```
Then, follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark:

1. Get the DEV dataset:

```
cd data
sh download_dev_unzip.sh
cd ../eval
```
2. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key or [Together](https://api.together.ai/) API key, then uncomment one of the lines starting with `model=` to pick the Llama model for the Text2SQL eval.
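For example, to evaluate Llama 3.3 70B on Together, the two relevant lines of `llama_eval.sh` (shown in full further below) would be:

```sh
YOUR_API_KEY='YOUR_TOGETHER_API_KEY'   # replace with your actual Together API key
model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
```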
3. Run the evaluation script with `sh llama_eval.sh`. It runs the Llama model on each text question in the BIRD DEV dataset (1534 examples in total, with external knowledge turned on) and compares the generated SQL with the gold SQL.

If your API key or model name is incorrect, the script will exit with an authentication or model-not-supported error.
After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV Text2SQL task. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'`, or about 35.07% with `YOUR_API_KEY` set to your Together API key and `model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'`.

To compare the accuracy of your selected Llama model with other results on the BIRD DEV leaderboard, see [here](https://bird-bench.github.io/).
## Evaluation Process

1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries (a sketch of such an API call appears after the Supported Models section below).

2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and the ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below.

3. **Result Comparison**: The results of executing the generated SQL are compared ([source code](text2sql_eval.py#L30)) with the results of the ground truth SQL to determine correctness; see the sketch right after this list.

4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty level (simple, moderate, challenging).
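The comparison in step 3 is execution-based: a prediction counts as correct if its result set matches the gold query's result set. A minimal sketch of the idea (illustrative names, assuming SQLite databases and order-insensitive row comparison; see `text2sql_eval.py` for the actual implementation):

```python
import sqlite3

def execute_sql(db_path: str, sql: str):
    # Run a query against a SQLite database and return all result rows.
    conn = sqlite3.connect(db_path)
    try:
        return conn.cursor().execute(sql).fetchall()
    finally:
        conn.close()

def is_match(predicted_sql: str, ground_truth_sql: str, db_path: str) -> bool:
    # Execution accuracy: correct iff the predicted query's result set equals
    # the gold query's result set (order-insensitive via set comparison).
    try:
        predicted = set(execute_sql(db_path, predicted_sql))
    except sqlite3.Error:
        return False  # SQL that fails to execute counts as incorrect
    return predicted == set(execute_sql(db_path, ground_truth_sql))

def accuracy(examples):
    # examples: iterable of (predicted_sql, gold_sql, db_path, difficulty).
    # Returns per-difficulty and overall accuracy percentages.
    by_level = {}
    for pred, gold, db_path, level in examples:
        by_level.setdefault(level, []).append(is_match(pred, gold, db_path))
    report = {level: 100 * sum(rs) / len(rs) for level, rs in by_level.items()}
    all_results = [r for rs in by_level.values() for r in rs]
    report["overall"] = 100 * sum(all_results) / len(all_results)
    return report
```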
## Supported Models for Evaluation

Llama models supported on Together AI:
- meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
- meta-llama/Llama-3.3-70B-Instruct-Turbo
- meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
- meta-llama/Llama-4-Scout-17B-16E-Instruct
- other Llama models hosted on Together AI

Llama models supported on Llama API:
- Llama-3.3-8B-Instruct
- Llama-3.3-70B-Instruct
- Llama-4-Maverick-17B-128E-Instruct-FP8
- Llama-4-Scout-17B-16E-Instruct-FP8
- other Llama models hosted on Llama API
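As a hedged illustration of how step 1 (SQL Generation) can talk to these hosted models, a Together-hosted model can be called through its OpenAI-compatible endpoint. This is not the actual `llama_text2sql.py` code; the base URL and the prompt layout are assumptions to check against the provider docs:

```python
from openai import OpenAI

# Together exposes an OpenAI-compatible API; the Llama API offers a similar
# compatibility endpoint (check provider docs for the exact base_url).
client = OpenAI(api_key="YOUR_TOGETHER_API_KEY",
                base_url="https://api.together.xyz/v1")

def generate_sql(question: str, schema: str, evidence: str) -> str:
    # BIRD prompts typically include the database schema and, with external
    # knowledge turned on, the per-question "evidence" string.
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
        messages=[{"role": "user", "content":
                   f"-- Schema:\n{schema}\n-- External knowledge: {evidence}\n"
                   f"-- Question: {question}\nReturn only a single SQLite query."}],
        temperature=0,  # deterministic output for reproducible evals
    )
    return response.choices[0].message.content.strip()
```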
`eval/llama_eval.sh` (new):

```sh
eval_path='../data/dev_20240627/dev.json'
db_root_path='../data/dev_20240627/dev_databases/'
ground_truth_path='../data/'

# Llama models on Together
#YOUR_API_KEY='YOUR_TOGETHER_API_KEY'
#model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
#model='meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo'
#model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
#model='meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
#model='meta-llama/Llama-4-Scout-17B-16E-Instruct'

# Llama models on Llama API
YOUR_API_KEY='YOUR_LLAMA_API_KEY'
model='Llama-3.3-8B-Instruct'
#model='Llama-3.3-70B-Instruct'
#model='Llama-4-Maverick-17B-128E-Instruct-FP8'
#model='Llama-4-Scout-17B-16E-Instruct-FP8'

# Llama model on Hugging Face Hub
# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
#YOUR_API_KEY='huggingface'
#model='meta-llama/Llama-3.1-8B-Instruct'

# Fine-tuned Llama models locally
#YOUR_API_KEY='finetuned'
#model='../fine_tuning/llama31-8b-text2sql-fine-tuned'

# Output directory is derived from the selected model name
data_output_path="./output/$model/"

echo "Text2SQL using $model"
python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \
  --model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}

# Capture the exit code immediately: $? is reset by every subsequent command,
# so reading it twice (in the test and again in the error message) would report
# the [ command's status instead of llama_text2sql.py's.
status=$?
if [ ${status} -eq 0 ]; then
  echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."
  python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
    --ground_truth_path ${ground_truth_path} \
    --diff_json_path ${eval_path}

  echo "Done evaluating $model."
else
  echo "Error: llama_text2sql.py failed with exit code ${status}. Skipping evaluation."
  exit 1
fi
```
> **Review comment:** My read of this was more along these lines: "This post presents a fine-tuning recipe to improve the performance of Llama models on Text2SQL by adding CoT to our data. We also provide an easy way to evaluate Llama models' SQL capabilities using the BIRD datasets." We want to center the message around the best method that boosted the perf, which was CoT in this case; everything else comes as complementary.
>
> **Reply:** I agree the end goal is to maximize the perf, but different users may have different needs - some just need an easy way to run the eval, so more objective messaging (eval + fine-tuning) may be better. But we can stress CoT for the fine-tuning part.