- 
                Notifications
    You must be signed in to change notification settings 
- Fork 2.6k
Text2sql tool - e2e evals and fine-tuning #967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 48 commits
ebed0ef
              edcf746
              76a8caf
              ab7df10
              2e8b278
              2c514b1
              5a18b6b
              0033fc9
              3997357
              44aa896
              46d3245
              b89d945
              3731175
              094ab01
              6d76ea0
              cf54eb4
              e182902
              79945b6
              0aa42d8
              4cdd5f6
              5e8a7b0
              71ca0ae
              b17f90b
              6815255
              03ba7d5
              11a4a64
              99ead57
              ee1fc97
              9c294df
              ef6bbb2
              caf98ec
              4107171
              b02334a
              f7c68c1
              4037737
              f07da72
              57ffb74
              a6f7d02
              7a4ae9f
              9ac5dd1
              6b92409
              c4573ba
              7b508ec
              3c23112
              cc93b73
              6269c15
              2cdfbf0
              4bb7faa
              2bd662c
              b574c6d
              58ea6cb
              33ac1ab
              e10ddda
              5baa1e3
              f894d26
              ad48509
              e059899
              f80e7bf
              b630735
              1ac67d9
              1b802d3
              df598c4
              cb8b0bd
              77d3544
              deca42c
              12a6dfa
              799dee6
              6501cf4
              82bb008
              27a23af
              e38abf1
              fc80546
              be4817c
              af3ea4f
              0c7b348
              54e49bc
              c88e10f
              57c0517
              7edf3d8
              8989e69
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,30 +1,24 @@ | ||
| ## Text2SQL: Natural Language to SQL Interface | ||
| # Text2SQL: Evaluating and Fine-tuning Llama Models with CoT | ||
|  | ||
| This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data. | ||
| This folder contains scripts to: | ||
|  | ||
| For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook. | ||
| 1. Evaluate Llama (original and fine-tuned) models on the Text2SQL task using the popular [BIRD](https://bird-bench.github.io) dataset. | ||
|  | ||
| ### Structure: | ||
| 2. Generate two supervised fine-tuning (SFT) datasets (with and without CoT) and fine-tuning Llama 3.1 8B with the datasets, using different SFT options: with or without CoT, using quantization or not, full fine-tuning (FFT) or parameter-efficient fine-tuning (PEFT). The non-quantized PEFT CoT SFT has the most performance gains: from 39.47% of the original Llama 3.1 8B model to 43.35%. (Note: the results are based on 3 epochs of SFT.) | ||
|  | ||
| - quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai. | ||
| - nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes. | ||
| - txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py. | ||
| - csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data. | ||
| - nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface. | ||
| Our end goal is to maximize the accuracy of Llama models on the Text2SQL task. To do so we need to first evaluate the current state of the art Llama models on the task, then apply fine-tuning, agent and other approaches to evaluate and improve Llama's performance. | ||
|  | ||
| ### Detailed steps on running the notebook: | ||
| ## Structure: | ||
|  | ||
| - Before getting started, please make sure to setup Together.ai and get an API key from [here](https://www.together.ai/). | ||
| - data: contains scripts to download the BIRD TRAIN and DEV datasets; | ||
| - eval: contains scripts to evaluate Llama models (original and fine-tuned) on the BIRD dataset; | ||
| - fine-tune: contains scripts to generate non-CoT and CoT datasets based on the BIRD TRAIN set and to supervised fine-tune Llama models using the datasets, with different SFT options (quantization or not, full fine-tuning or parameter-efficient fine-tuning); | ||
| - quickstart: contains a notebook to ask Llama 3.3 to convert natural language queries into SQL queries. | ||
|  | ||
| - First, please install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running inside the folder: | ||
| ## Next Steps | ||
|  | ||
| ``` | ||
| git clone https://github.com/meta-llama/llama-cookbook.git | ||
| cd llama-cookbook/end-to-end-use-cases/coding/text2sql/ | ||
| pip install -r requirements.txt | ||
| ``` | ||
|  | ||
| ### Contributing | ||
| Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas: | ||
| - Adding support for additional databases | ||
| - Developing new interfaces or applications that use the Text2SQL interface | ||
| 1. Hyper-parameter tuning of the current SFT scripts. | ||
| 2. Try GRPO reinforcement learning to further improve the accuracy. | ||
| 3. Fine-tune Llama 3.3 70B and Llama 4 models. | ||
| 4. Try agentic workflow. | ||
| 5. Expand the eval to support other enterprise databases. | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be dataset vs database? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the "database" - to expand the eval to go beyond the current sqlite and include Oracle, etc. | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip | ||
| unzip dev.zip | ||
| rm dev.zip | ||
| rm -rf __MACOSX | ||
| cd dev_20240627 | ||
| unzip dev_databases.zip | ||
| rm dev_databases.zip | ||
| rm -rf __MACOSX | ||
| cd .. | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip | ||
| UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip | ||
| rm train.zip | ||
| rm -rf __MACOSX | ||
| cd train | ||
| unzip train_databases.zip | ||
| rm train_databases.zip | ||
| rm -rf __MACOSX | ||
| cd .. | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| # Llama Text2SQL Evaluation | ||
|  | ||
| We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) to 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), as well as the fine-tuned Llama 3.1 model. | ||
|  | ||
| ## Evaluation Results | ||
|  | ||
| Below are the results of the Llama models we have evaluated on the BIRD DEV dataset: | ||
|  | ||
| | Model | Llama API Accuracy | Together Accuracy | | ||
|          | ||
| |------------------------|--------------------|-------------------| | ||
| | Llama 3.1 8b | - | 35.66% | | ||
| | Llama 3.3 70b | 54.11% | 54.63% | | ||
| | Llama-3.1-405B | - | 55.80% | | ||
| | Llama 4 Scout | 44.39% | 43.94% | | ||
| | Llama 4 Maverick | 44.00% | 41.46% | | ||
|  | ||
| - Llama 3.1 8b on Hugging Face: quantized 14.02%, non-quantized 39.47% | ||
| - Non-quantized FFT with no CoT dataset: 36.31% | ||
| - Non-quantized FFT with CoT dataset: 43.87% | ||
|  | ||
| ## Quick Start | ||
|  | ||
| First, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning: | ||
|  | ||
| ``` | ||
| git clone https://github.com/meta-llama/llama-cookbook | ||
| cd llama-cookbook/end-to-end-use-cases/coding/text2sql | ||
| conda create -n llama-text2sql python=3.10 | ||
| conda activate llama-text2sql | ||
| pip install -r requirements.txt | ||
| ``` | ||
|  | ||
| Then, follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark: | ||
|  | ||
| 1. Get the DEV dataset: | ||
| ``` | ||
| cd data | ||
| sh download_dev_unzip.sh | ||
| cd ../eval | ||
| ``` | ||
|  | ||
| 2. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key or [Together](https://api.together.ai/) API key, then uncomment a line that starts with `model=` to specify the Llama model to use for the text2sql eval. | ||
|  | ||
| 3. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL. | ||
|  | ||
| If your API key or model name is incorrect, the script will exit with an authentication or model not supported error. | ||
|  | ||
| After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'`, or about 35.07% with `YOUR_API_KEY` set to your Together API key and `model=meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo`. | ||
|  | ||
| To compare your evaluated accuracy of your selected Llama model with other results in the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/). | ||
|  | ||
| ## Evaluation Process | ||
|  | ||
| 1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries. | ||
|  | ||
| 2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below. | ||
|  | ||
| 3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L30)) with the results from the ground truth SQL to determine correctness. | ||
|  | ||
| 4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging). | ||
|  | ||
| ## Supported Models for Evaluation | ||
|  | ||
| Llama models supported on Together AI: | ||
| - meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo | ||
| - meta-llama/Llama-3.3-70B-Instruct-Turbo | ||
| - meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 | ||
| - meta-llama/Llama-4-Scout-17B-16E-Instruct | ||
| - other Llama models hosted on Together AI | ||
|  | ||
| Llama models supported on Llama API: | ||
| - Llama-3.3-8B-Instruct | ||
| - Llama-3.3-70B-Instruct | ||
| - Llama-4-Maverick-17B-128E-Instruct-FP8 | ||
| - Llama-4-Scout-17B-16E-Instruct-FP8 | ||
| - other Llama models hosted on Llama API | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| eval_path='../data/dev_20240627/dev.json' | ||
| db_root_path='../data/dev_20240627/dev_databases/' | ||
| ground_truth_path='../data/' | ||
|  | ||
| # Llama models on Together | ||
| #YOUR_API_KEY='YOUR_TOGETHER_API_KEY' | ||
| #model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' | ||
| #model='meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo' | ||
| #model='meta-llama/Llama-3.3-70B-Instruct-Turbo' | ||
| #model='meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8' | ||
| #model='meta-llama/Llama-4-Scout-17B-16E-Instruct' | ||
|  | ||
| # Llama models on Llama API | ||
| YOUR_API_KEY='YOUR_LLAMA_API_KEY' | ||
| model='Llama-3.3-8B-Instruct' | ||
| #model='Llama-3.3-70B-Instruct' | ||
| #model='Llama-4-Maverick-17B-128E-Instruct-FP8' | ||
| #model='Llama-4-Scout-17B-16E-Instruct-FP8' | ||
|  | ||
| # Llama model on Hugging Face Hub | ||
| # https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct | ||
| # YOUR_API_KEY='huggingface' | ||
| # model='meta-llama/Llama-3.1-8B-Instruct' | ||
|  | ||
| # Fine-tuned Llama models locally | ||
| #YOUR_API_KEY='finetuned' | ||
| #model='../fine_tuning/llama31-8b-text2sql-fine-tuned' | ||
|  | ||
| data_output_path="./output/$model/" | ||
|  | ||
| echo "Text2SQL using $model" | ||
| python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \ | ||
| --model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path} | ||
|  | ||
| # Check if llama_text2sql.py exited successfully | ||
| if [ $? -eq 0 ]; then | ||
| echo "llama_text2sql.py completed successfully. Proceeding with evaluation..." | ||
| python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \ | ||
| --ground_truth_path ${ground_truth_path} \ | ||
| --diff_json_path ${eval_path} | ||
|  | ||
| echo "Done evaluating $model." | ||
|  | ||
| else | ||
| echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation." | ||
| exit 1 | ||
| fi | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My read into this was more along these lines:
"This post present a fine-tuning recipe to improve the performance of llama models on text to sql by adding COT to our data. We also provide, an easy way to evaluate llama models on SQL capabilities using BIRD datasets."
We want to center the message around the best method that boosted the perf, which was COT in this case, everything comes complementary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree the end goal is to maximize the perf but different users may have different needs - to some, they just need an easy way to do eval, so a more objective messaging (eval + fine-tuning) may be better. But we can stress CoT for the fine-tuning part.