We Fine-Tuned GPT-4 to Beat the Industry Standard for Text2SQL
Our machine learning team at Scale has recently fine-tuned GPT-4 to achieve state-of-the-art performance (84% accuracy) for generalized text-to-SQL translation on one of the most popular benchmark datasets, the SpiderDev Set. In this blog post, we will discuss why text2sql is an important use case, why it is hard in practice, where fine-tuning can help, how we implemented a real-world solution, and finally, what our results were.
Why is Text2SQL important?
Most business decisions today are data-driven decisions. This means that organizations collect, aggregate, and interpret large amounts of available information about their business or the market environment with a set of tools and processes that are often summarized as business intelligence or BI. However, obtaining the relevant pieces of information from the vast amounts of available data typically requires analytical expertise (SQL or similar) and knowledge of the relevant databases, dashboards, or related tools. This often creates a massive bottleneck and reliance on data analysts to build these tools, which then proliferate and become hard to navigate. Multi-billion dollar industries have emerged to provide generalist or highly specialized analytics tools to bridge this gap.
The advent of large language models is poised to change this paradigm with the ability to generate SQL queries directly from natural language questions such as “How many vehicles did we sell last year?”. Building generative models that can robustly generate SQL queries for any given set of databases hence has the potential to disrupt an entire industry and truly democratize access to structured data at large.
Why are LLMs still bad at SQL in the real world?
Running some basic tests using models like OpenAI’s ChatGPT provides very promising results:
Also, looking at the leaderboard of benchmark datasets like the infamous SpiderDev makes it appear that the problem is pretty much solved:
However, despite the impressive code generation capabilities of state-of-the-art language models like GPT-4, they are not immediately useful for generating queries that run on custom, real-world databases. First of all, the LLMs do not know the schema of the databases in question out of the box. The most obvious solution is to provide the schema to the model in addition to the prompt.
However, in many cases, real-world databases will have hundreds of columns with custom names. The schema might not fit into the context window of the prompt and even if it does, the model still does not understand the meaning of the column names and how they relate to each other. For example, does a “date” column in a vehicle sales database record the time the sale was recorded or the time the sale took place? A very robust understanding of typical business terms for the given databases and the column contents is essential, especially to correctly apply aggregation and window functions. The relationships between multiple table schemas are also difficult to convey in the prompt, but this is required for slightly more complex operations like JOINs.
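The sketch below shows why the context window becomes a problem. It takes the naive approach: serialize every table and column into CREATE TABLE statements and put them in front of the question. The function names, the prompt wording, and the rule of thumb of about four characters per token are our own assumptions for illustration. They are not the exact format used in the experiments below.

```python
from typing import Dict, List

Schema = Dict[str, List[str]]  # table name -> column definitions

def serialize_schema(schema: Schema) -> str:
    """Render every table as a CREATE TABLE statement."""
    statements = []
    for table, columns in schema.items():
        column_lines = ",\n  ".join(columns)
        statements.append(f"CREATE TABLE {table} (\n  {column_lines}\n);")
    return "\n\n".join(statements)

def full_schema_prompt(schema: Schema, question: str) -> str:
    """Naive prompt: the entire schema followed by the user's question."""
    return (
        "Given the following database schema:\n\n"
        f"{serialize_schema(schema)}\n\n"
        f"Write a SQL query that answers: {question}\nSQL:"
    )

def rough_token_count(text: str) -> int:
    """Crude heuristic (about 4 characters per token for English text);
    use the model's real tokenizer for anything that matters."""
    return len(text) // 4

# A small Spider-sized table versus a wide enterprise-style one (hypothetical).
small_schema: Schema = {
    "vehicle_sales": ["id INTEGER", "model TEXT", "sale_date DATE"],
}
wide_schema: Schema = {
    "vehicle_sales": [f"col_{i} TEXT" for i in range(1000)],
}
question = "How many vehicles did we sell last year?"
small_prompt = full_schema_prompt(small_schema, question)
wide_prompt = full_schema_prompt(wide_schema, question)
```

Even when the wide prompt does fit, a name like `col_17` says nothing about what the column contains.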
How can fine-tuning and retrieval help to resolve these challenges?
At Scale, we are working with enterprise customers across many industries to build customized Generative AI solutions for their respective use cases. In most of these applications, we fine-tune an underlying base model to solve the relevant business problems at the required accuracy level. Fine-tuning not only can improve the performance of a model for a given task, but can also drive model safety and alignment, ensuring a certain tone and behavior. It is also a good way to improve ROI as it can be used to teach smaller (and cheaper) models a very specific skill and eventually even outperform much bigger, generalized models at this task.
Fine-tuning is an effective way to improve the specificity of a certain skill that the model is capable of performing but has not yet mastered. It can be used to teach a model highly specific terms and instructions and improve its capabilities. A good way to figure out if fine-tuning is going to work is by experimenting with prompt engineering. As a rule of thumb, if prompt engineering shows promising results, then fine-tuning will likely be effective for the given task.
Conversely, fine-tuning is not a good way to give the model new data or knowledge, such as the database schema or detailed explanations of columns and their relationships. This kind of context is better supplied to the model with Retrieval Augmented Generation, or RAG (see this recent blog post for a deep dive on RAG). A real-world solution will therefore likely need to combine fine-tuning and RAG to achieve acceptable results.
How did we implement our solution?
Our solution reflects a system intended for enterprise customers. Accordingly, we benchmarked multiple techniques that could be used for real-world use cases against a baseline:
- Full database schema with off-the-shelf model (Baseline)
- Schema RAG with in-context learning (ICL)
- Fine-tuned model with Schema RAG and ICL
We’ll now walk through each of these in more detail.
Database Schema Retrieval
The Spider dataset is the standard benchmark for comparing natural language to SQL models and methods. However, real-world enterprise SQL databases differ from Spider in both size and complexity. Whereas 90% of the databases in Spider’s Train and Dev datasets contain fewer than 50 columns, enterprise databases can contain 1000 or more unique columns. Because of this discrepancy, the common approach of putting the entire database schema in the prompt is infeasible for real-world use cases, given token limits and the “lost in the middle” problem.
Most SQL queries need only a fraction of all columns, so we solve this dilemma with a fine-tuned retrieval system that retrieves only the database features relevant to a user’s question. Given a customer’s database schema, we can fine-tune an embedding model to learn the specific ways that customer refers to their database. Once the embedding model is deployed into the backend of our Enterprise Generative AI Platform (EGP), we can easily create, populate, and query the retrieval Knowledge Base.
from scale_egp.sdk.client import EGPClient
from scale_egp.sdk.models import S3DataSourceConfig, CharacterChunkingStrategyConfig

# Instantiate client
client = EGPClient()

# This is pseudocode
embedding_model = client.models().create(
    model_type="embedding",
    model_template="sentence_transformers_embedding_model",
    model_template_config={
        "weights_uri": "s3://model_weights/presigned_url",
    },
)

# Create a Knowledge Base backed by the fine-tuned embedding model
knowledge_base = client.knowledge_bases().create(
    name="my-knowledge-base",
    embedding_model_id=embedding_model.id,
)

# Configure the knowledge base data source and upload the schema
data_source = S3DataSourceConfig(
    s3_bucket="my-bucket",
    s3_prefix="my-schema",
)
chunking_strategy_config = CharacterChunkingStrategyConfig()
upload = client.knowledge_bases().uploads().create_remote_upload(
    knowledge_base=knowledge_base,
    data_source_config=data_source,
    chunking_strategy_config=chunking_strategy_config,
)

# Query the schema for a given question
query = "What was last month's total expense for service provider X?"
retrieved_schema = client.knowledge_bases().query(
    knowledge_base=knowledge_base,
    query=query,
    top_k=20,
)
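The EGP calls above depend on Scale's platform. The self-contained sketch below shows what the Knowledge Base query does conceptually: it ranks column descriptions by cosine similarity to the question and keeps the top k. For illustration, it assumes a simple bag-of-words vector in place of the fine-tuned Sentence Transformer, and the column descriptions are hypothetical.

```python
import math
import re
from collections import Counter
from typing import List, Tuple

def embed(text: str) -> Counter:
    """Stand-in for an embedding model: lowercase word counts.
    Underscores are not matched by the pattern, so 'provider_name'
    splits into 'provider' and 'name'."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def retrieve_columns(question: str, column_docs: List[str],
                     top_k: int = 20) -> List[Tuple[str, float]]:
    """Score each column description against the question and keep the top_k."""
    q = embed(question)
    scored = [(doc, cosine(q, embed(doc))) for doc in column_docs]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]

# Hypothetical one-line descriptions, one per column.
COLUMN_DOCS = [
    "expenses.amount: total expense amount in USD",
    "expenses.provider_name: name of the service provider",
    "expenses.expense_date: month and day the expense was booked",
    "vehicles.mpg: miles per gallon",
    "employees.hire_date: date the employee was hired",
]
QUESTION = "What was last month's total expense for service provider X?"
top3 = [doc for doc, _ in retrieve_columns(QUESTION, COLUMN_DOCS, top_k=3)]
```

Only the retrieved descriptions go into the prompt, so its size depends on `top_k` rather than on the width of the database.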
In-Context Learning
Because schema retrieval lowers the token count, we can supply more relevant context to resolve mismatches between business terms and the database schema. On initial deployment, we collect terms and information relevant to the SQL logic. Examples include “the term MPG refers to miles per gallon” and “stock means filter on asset_type='equity'”. This initial retrieval mechanism primes our data flywheel. As users interact with the tool, we collect real-world samples that can be retrieved for in-context learning. From this additional corpus, we place the SQL queries of similar past questions in the LLM’s context window.
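The sketch below shows one way to turn the collected business terms and retrieved examples into extra prompt context. The glossary entries reuse the two examples above. Matching terms by exact whole-word lookup and formatting them as SQL comments are our assumptions. In the real system, the similar examples would come from the ICL Knowledge Base rather than a hard-coded list.

```python
import re
from typing import Dict, List, Tuple

# Business terms collected on initial deployment (examples from the text).
GLOSSARY: Dict[str, str] = {
    "mpg": "The term MPG refers to miles per gallon.",
    "stock": "'stock' means filter on asset_type = 'equity'.",
}

def glossary_hints(question: str, glossary: Dict[str, str]) -> List[str]:
    """Return the hints whose term appears as a whole word in the question."""
    words = set(re.findall(r"[a-z0-9]+", question.lower()))
    return [hint for term, hint in glossary.items() if term in words]

def icl_context(question: str, glossary: Dict[str, str],
                examples: List[Tuple[str, str]]) -> str:
    """Format matching hints as SQL comments, followed by similar
    (question, SQL) pairs, ready to prepend to the prompt."""
    lines = [f"-- Note: {hint}" for hint in glossary_hints(question, glossary)]
    for past_question, sql in examples:
        lines.append(f"-- Question: {past_question}")
        lines.append(sql)
    return "\n".join(lines)

# Hypothetical example retrieved for a similar past question.
EXAMPLES = [("How many stocks are in the portfolio?",
             "SELECT COUNT(*) FROM assets WHERE asset_type = 'equity';")]
context = icl_context("Which stock had the highest return?", GLOSSARY, EXAMPLES)
```

Whole-word matching misses plurals such as “stocks”. A production system would more likely retrieve glossary entries semantically, the same way it retrieves the schema.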
Fine-Tuning
Combining prompt engineering with the retrieval mechanisms above maximizes the density of relevant information available to the out-of-the-box model. For the final accuracy boost, we turn to fine-tuning. For customers, this means collecting real user data and fine-tuning an LLM to learn the nuances of their data and terminology. Fine-tuning not only fills gaps in the retrieval data but also homes in on trends or relationships that users are not aware of. Once the model is trained, each step is seamlessly integrated with the EGP SDK.
To test the viability of this system, we used OpenAI’s fine-tuning API to train GPT-3.5 and GPT-4. For each question-query pair in the Spider training set, we used the methodology above to build a prompt with at most 20 schema features and five relevant question-SQL query pairs, and we set the target to the corresponding SQL query. We used the same approach to build a validation set of question-query pairs from the Spider dev set. Once we had packaged the train and validation sets into files, OpenAI’s APIs handled uploading the data, fine-tuning the models, and generating validation predictions.
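The sketch below turns one Spider question-query pair into a fine-tuning record. It assumes OpenAI's chat fine-tuning format, which has one JSON object per line with a `messages` list of system, user, and assistant turns. The prompt layout and system message are illustrative only and are not the prompts we actually used. The caps match the limits described above: at most 20 schema features and five example pairs.

```python
import json
from typing import List, Tuple

MAX_FEATURES = 20  # schema features (columns) per prompt
MAX_EXAMPLES = 5   # retrieved (question, SQL) pairs per prompt

def build_prompt(question: str, features: List[str],
                 examples: List[Tuple[str, str]]) -> str:
    """Illustrative RAG prompt: retrieved schema, retrieved examples, question."""
    schema_block = "\n".join(features[:MAX_FEATURES])
    example_block = "\n\n".join(
        f"Q: {q}\nSQL: {sql}" for q, sql in examples[:MAX_EXAMPLES]
    )
    return (f"Schema:\n{schema_block}\n\n"
            f"Examples:\n{example_block}\n\n"
            f"Q: {question}\nSQL:")

def to_finetune_record(question: str, features: List[str],
                       examples: List[Tuple[str, str]], target_sql: str) -> str:
    """One JSONL line; the assistant turn is the gold SQL query (the target)."""
    record = {"messages": [
        {"role": "system", "content": "Translate the question into a SQLite query."},
        {"role": "user", "content": build_prompt(question, features, examples)},
        {"role": "assistant", "content": target_sql},
    ]}
    return json.dumps(record)

def write_jsonl(path: str, lines: List[str]) -> None:
    """Write records to a .jsonl file for upload to the fine-tuning API."""
    with open(path, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(line + "\n")

# Hypothetical usage with oversized inputs to show the caps being applied.
sample = json.loads(to_finetune_record(
    "How many rows?", [f"t.c{i}" for i in range(30)],
    [(f"q{i}", f"SELECT {i}") for i in range(8)], "SELECT COUNT(*) FROM t"))
```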
from typing import List

# Create a custom LLM
LLM_MODEL = client.models().create(
    model_type="llm",
    model_template="llm_engine_model",
    model_template_config={
        "weights_uri": "s3://model_weights/presigned_url",
    },
)


class Text2SQLApplication:
    name = "Text2SQL"
    description = "Natural language to SQL queries for My Company"
    llm_model = LLM_MODEL.id

    def __init__(self, schema_knowledge_base_id: str, icl_knowledge_base_id: str):
        self.schema_kb_id = schema_knowledge_base_id
        self.icl_kb_id = icl_knowledge_base_id

    def create_prompt(self, question: str, schema_chunks: List[str],
                      icl_chunks: List[str]) -> str:
        # Prompt construction elided; it assigns rag_prompt from the
        # question, the schema chunks, and the ICL chunks
        ...
        return rag_prompt

    def generate(self, question: str, schema_k: int, icl_k: int) -> str:
        # Retrieve relevant schema information
        schema_chunks = client.knowledge_bases().query(
            knowledge_base=self.schema_kb_id,
            query=question,
            top_k=schema_k,
        )
        # Retrieve in-context-learning samples
        icl_chunks = client.knowledge_bases().query(
            knowledge_base=self.icl_kb_id,
            query=question,
            top_k=icl_k,
        )
        # Generate prompt from retrieval
        rag_prompt = self.create_prompt(question, schema_chunks, icl_chunks)
        # Generate SQL
        generate_response = client.completions().create(
            model=self.llm_model,
            prompt=rag_prompt,
        )
        return generate_response.completion.text
Validating Against Spider
We validate the system and benchmark its performance on the Spider dataset. For schema retrieval, we fine-tuned a Sentence Transformer to match questions with their relevant database columns and achieved 97% recall@20. For in-context learning, we use an out-of-the-box Sentence Transformer. For both GPT-3.5 and GPT-4, we measure the execution accuracy of the generated SQL queries in three settings: a baseline (the entire database schema in the prompt), RAG (schema retrieval and in-context learning), and the respective model fine-tuned on the RAG prompts. Performance improves at each stage, reaching a best execution accuracy of 83.6% on the Spider Dev set. We therefore not only reach state-of-the-art performance but also have a system optimized to give enterprise customers the best commercially available natural-language-to-SQL solution.
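Below is a minimal sketch of recall@k for schema retrieval as we interpret it: the share of a question's gold columns that appear among the top k retrieved columns, averaged over questions. The per-question averaging and the convention for questions with no gold columns are assumptions. The gold columns themselves would be taken from the Spider SQL labels.

```python
from typing import List, Sequence, Set, Tuple

def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant columns found among the first k retrieved ones."""
    if not relevant:
        return 1.0  # assumed convention: nothing to find counts as full recall
    hits = len(set(retrieved[:k]) & relevant)
    return hits / len(relevant)

def mean_recall_at_k(runs: List[Tuple[Sequence[str], Set[str]]], k: int) -> float:
    """Average per-question recall@k over an evaluation set."""
    return sum(recall_at_k(retrieved, relevant, k) for retrieved, relevant in runs) / len(runs)
```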
What are the results?
Baseline: Entire Schema in the Prompt
Using the Spider validation data, we can calculate a baseline execution accuracy. For each question-query pair, we put the entire schema of the corresponding SQL database into the prompt and ask GPT-4 to answer the question given that context. With this simple baseline, GPT-4 achieves an execution accuracy of 70% (a D+).
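Execution accuracy compares what the generated and gold queries return, not the query text. The sketch below is a simplified version using Python's built-in sqlite3 module, since Spider databases are SQLite files. It treats results as unordered multisets and scores a query that fails to run as wrong. The official Spider evaluation script is more involved (for example, it respects ORDER BY), so this is only an approximation. The toy table is hypothetical.

```python
import sqlite3
from collections import Counter
from typing import List, Tuple

def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """True if both queries return the same rows, ignoring row order."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False  # invalid or failing SQL counts as a miss
    gold_rows = conn.execute(gold_sql).fetchall()
    return Counter(predicted_rows) == Counter(gold_rows)

def execution_accuracy(conn: sqlite3.Connection, pairs: List[Tuple[str, str]]) -> float:
    """Share of (predicted, gold) pairs whose results match."""
    return sum(execution_match(conn, p, g) for p, g in pairs) / len(pairs)

def sales_demo_db() -> sqlite3.Connection:
    """Hypothetical in-memory database for trying the metric."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE sales (id INTEGER PRIMARY KEY, model TEXT, sale_year INTEGER);
        INSERT INTO sales (model, sale_year) VALUES ('A', 2022), ('B', 2022), ('A', 2023);
    """)
    return conn

PAIRS = [
    # Different text, same result: counts as correct.
    ("SELECT COUNT(*) FROM sales WHERE sale_year = 2022",
     "SELECT COUNT(id) FROM sales WHERE sale_year = 2022"),
    # Duplicate rows versus DISTINCT: counts as wrong.
    ("SELECT model FROM sales", "SELECT DISTINCT model FROM sales"),
    # No such column, so the query fails: counts as wrong.
    ("SELECT price FROM sales", "SELECT model FROM sales"),
]
```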
Adding Schema RAG and In-context learning (ICL)
Using RAG to find the relevant parts of the DB schema in a structured way improves performance consistently across models. For GPT-3.5, we see a 7 ppt improvement from 60% to 66%, and for GPT-4 a slightly smaller bump from 70% to 73%.
Adding Schema RAG, ICL, and fine-tuning
When additionally fine-tuning the model with specific prompt-response pairs, we see consistent further performance improvements both for GPT-3.5 and GPT-4. The final, fine-tuned GPT-4 with schema RAG and ICL achieves 84% accuracy on Spider, up from the 70% in the baseline version, which marks an impressive 14 ppts improvement. For GPT-3.5 the increase is even more pronounced, reaching 82% (almost as good as GPT-4) with RAG and fine-tuning, which is up 22 ppts from the baseline of using only prompt engineering. For GPT-3.5, the biggest increase is from fine-tuning itself, pushing performance from 66% to 82% with this technique alone.
Below is a comparison of the performance across the three different approaches for both GPT-3.5 and GPT-4.
Let’s look at a practical query example to show the difference between using GPT-4 out of the box versus the RAG and fine-tuned version.
We can see that the fine-tuned model displayed on the right-hand side not only interprets the terms for the natural language query correctly but also applies a better and more efficient query structure, using a subquery instead of a left join.
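The query from the comparison above is not reproduced in this text. As a hypothetical illustration of the two query shapes, the sketch below answers “Which vehicle models have never been sold?” twice on a toy SQLite database, once with a LEFT JOIN and once with a NOT IN subquery. Both return the same rows here. Which one runs faster depends on the database engine and its query planner.

```python
import sqlite3

def vehicles_demo_db() -> sqlite3.Connection:
    """Hypothetical in-memory database: three models, two of them sold."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE vehicles (id INTEGER PRIMARY KEY, model TEXT);
        CREATE TABLE sales (id INTEGER PRIMARY KEY, vehicle_id INTEGER NOT NULL);
        INSERT INTO vehicles (id, model) VALUES (1, 'Sedan'), (2, 'Coupe'), (3, 'Truck');
        INSERT INTO sales (vehicle_id) VALUES (1), (1), (3);
    """)
    return conn

# Anti-join written as a LEFT JOIN: keep vehicles with no matching sale.
LEFT_JOIN_SQL = """
    SELECT v.model FROM vehicles AS v
    LEFT JOIN sales AS s ON s.vehicle_id = v.id
    WHERE s.id IS NULL
"""

# The same question as a subquery. NOT IN would return no rows if
# sales.vehicle_id contained a NULL, which is why it is declared NOT NULL here.
SUBQUERY_SQL = """
    SELECT model FROM vehicles
    WHERE id NOT IN (SELECT vehicle_id FROM sales)
"""

conn = vehicles_demo_db()
left_join_rows = conn.execute(LEFT_JOIN_SQL).fetchall()
subquery_rows = conn.execute(SUBQUERY_SQL).fetchall()
```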
What’s next?
Our solution is not quite at the top of the SpiderDev leaderboard, because most of the top submissions rely on prompt engineering and data pre-processing that are heavily tailored to this benchmark. However, our model does reach top-5 performance and, crucially, maintains comparable accuracy when deployed in much more complex, real-world contexts and databases.
If you’re interested in using our Text2SQL model for your business use case or want to learn more about our solutions in fine-tuning and RAG, book a demo below.