How to evaluate LLMs for SQL generation

Jan 23, 2024

LLMs are fundamentally non-deterministic in their responses. This attribute makes them wonderfully creative and dynamic, but it also makes consistency hard to achieve, and consistency is crucial for integrating LLMs into production environments.

The key to harnessing the potential of LLMs in practical applications lies in consistent and systematic evaluation. This enables the identification and rectification of inconsistencies and helps in monitoring progress over time as the application evolves.

Scope of this notebook

This notebook aims to demonstrate a framework for evaluating LLMs, particularly focusing on:

  • Unit Testing: Essential for assessing individual components of the application.
  • Evaluation Metrics: Methods to quantitatively measure the model's effectiveness.
  • Runbook Documentation: A record of historical evaluations to track progress and regression.

This example focuses on a natural language to SQL use case. Code generation use cases fit this approach well because you can combine code validation with code execution, so your application can test the code for real as it is generated to ensure consistency.

Although this notebook uses SQL generation to demonstrate the concept, the approach is generic and can be applied to a wide variety of LLM-driven applications.

We will use two versions of a prompt to perform SQL generation, then use the unit tests and evaluation functions to test the performance of each prompt. Specifically, in this demonstration, we will evaluate:

  1. The consistency of JSON response.
  2. Syntactic correctness of SQL in response.

Table of contents

  1. Setup: Install required libraries, download data consisting of SQL queries and corresponding natural language translations.
  2. Test Development: Create unit tests and define evaluation metrics for the SQL generation process.
  3. Evaluation: Conduct tests using different prompts to assess the impact on performance.
  4. Reporting: Compile a report that succinctly presents the performance differences observed across various tests.
from datasets import load_dataset
from openai import OpenAI
import pandas as pd
import pydantic
import os
import sqlite3
from sqlite3 import Error
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv

# Loads key from local .env file to setup API KEY in env variables
%reload_ext dotenv
%dotenv
    
GPT_MODEL = 'gpt-4o'
dataset = load_dataset("b-mc2/sql-create-context")
cannot find .env file

Looking at the dataset

We use the Hugging Face datasets library to download the sql-create-context dataset (b-mc2/sql-create-context). This dataset consists of:

  1. Question, expressed in natural language
  2. Answer, expressed in SQL designed to answer the question in natural language.
  3. Context, expressed as a CREATE SQL statement, that describes the table that may be used to answer the question.

In this demonstration, we will use an LLM to attempt to answer the natural language question. The LLM is expected to generate a CREATE SQL statement that creates a context suitable for answering the question, and a corresponding SELECT SQL query designed to answer the question completely.
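Because each row pairs a CREATE context with a reference SELECT answer, the ground truth itself can be executed. The sketch below runs the first row of the dataset (copied from the sql_df.head() output that follows) against an in-memory SQLite database using only the standard library; run_pair is a hypothetical helper, not part of the notebook.

```python
import sqlite3


def run_pair(context: str, answer: str) -> list:
    """Hypothetical helper: execute a CREATE context, then a SELECT answer, on a fresh in-memory DB."""
    conn = sqlite3.connect(":memory:")
    try:
        cursor = conn.cursor()
        cursor.execute(context)  # builds the (empty) table described by the dataset row
        cursor.execute(answer)   # runs the reference query against that table
        return cursor.fetchall()
    finally:
        conn.close()


# First row of b-mc2/sql-create-context
context = "CREATE TABLE head (age INTEGER)"
answer = "SELECT COUNT(*) FROM head WHERE age > 56"
rows = run_pair(context, answer)
print(rows)  # [(0,)] because the table has no rows yet
```

The table is empty, so this only shows that the pair is executable, not that the answer is semantically right.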

The dataset looks like this:

sql_df = dataset['train'].to_pandas()
sql_df.head()
answer question context
0 SELECT COUNT(*) FROM head WHERE age > 56 How many heads of the departments are older th... CREATE TABLE head (age INTEGER)
1 SELECT name, born_state, age FROM head ORDER B... List the name, born state and age of the heads... CREATE TABLE head (name VARCHAR, born_state VA...
2 SELECT creation, name, budget_in_billions FROM... List the creation year, name and budget of eac... CREATE TABLE department (creation VARCHAR, nam...
3 SELECT MAX(budget_in_billions), MIN(budget_in_... What are the maximum and minimum budget of the... CREATE TABLE department (budget_in_billions IN...
4 SELECT AVG(num_employees) FROM department WHER... What is the average number of employees of the... CREATE TABLE department (num_employees INTEGER...

Test development

To test the output of the LLM generations, we'll develop two unit tests and an evaluation, which together give us a basic framework for grading the quality of our LLM iterations.

To re-iterate, our purpose is to measure the correctness and consistency of LLM output given our questions.

Unit tests

Unit tests should test the most granular components of your LLM application.

For this section we'll develop unit tests to test the following:

  • test_valid_schema will check that a parseable create and select statement are returned by the LLM.
  • test_llm_sql will execute both the create and select statements on a sqlite database to ensure they are syntactically correct.
from pydantic import BaseModel


class LLMResponse(BaseModel):
    """This simple Class expects to receive a JSON string that can be parsed into a `create` and `select` statement."""
    create: str
    select: str

Prompt

For demonstration purposes, we use a fairly simple prompt asking GPT to generate a pair of SQL statements: a CREATE statement that provides the context and a SELECT query that answers the question. We supply the natural language question as part of the prompt and request the response in JSON format so that it can be parsed easily.

system_prompt = '''Translate this natural language request into a JSON object containing two SQL queries. 
The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question.'''

pprint(system_prompt)
('Translate this natural language request into a JSON object containing two '
 'SQL queries. \n'
 'The first query should be a CREATE statement for a table answering the '
 "user's request, while the second should be a SELECT query answering their "
 'question.')
# Compiling the system prompt and user question into message array

messages = []
messages.append({"role": "system", "content": system_prompt})
messages.append({"role":"user","content": sql_df.iloc[0]['question']})
pprint(messages)
[{'content': 'Translate this natural language request into a JSON object '
             'containing two SQL queries. \n'
             'The first query should be a CREATE statement for a table '
             "answering the user's request, while the second should be a "
             'SELECT query answering their question.',
  'role': 'system'},
 {'content': 'How many heads of the departments are older than 56 ?',
  'role': 'user'}]
# Sending the message array to GPT and requesting a response
# (make sure the OPENAI_API_KEY environment variable is set before this step)

client = OpenAI()

# Structured Outputs: response_format=LLMResponse constrains the reply to our Pydantic schema
completion = client.beta.chat.completions.parse(
    model=GPT_MODEL,
    messages=messages,
    response_format=LLMResponse,
)

Check JSON formatting

Our first simple unit test checks that the LLM response is parseable into the LLMResponse Pydantic class that we've defined.

We'll test that our first response passes, then create a failing example to confirm that the check fails. This logic will be wrapped in a simple function test_valid_schema.

# Viewing the output from GPT

content = completion.choices[0].message.content
pprint(content)
('{"create":"CREATE TABLE department_heads (\\n    id INT PRIMARY KEY,\\n    '
 'name VARCHAR(255),\\n    age INT,\\n    department '
 'VARCHAR(255)\\n);","select":"SELECT COUNT(*) FROM department_heads WHERE age '
 '> 56;"}')

Validating the output schema

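We expect GPT to respond with JSON containing a create and a select statement; we can validate this using the LLMResponse Pydantic model. test_valid_schema is designed to help us validate this. (Whether the SQL itself is valid is tested separately below.)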

def test_valid_schema(content):
    """Tests whether the content provided can be parsed into our Pydantic model."""
    try:
        LLMResponse.model_validate_json(content)
        return True
    # Catch pydantic's validation errors:
    except pydantic.ValidationError as exc:
        print(f"ERROR: Invalid schema: {exc}")
        return False
test_valid_schema(content)
True

Testing negative scenario

To simulate a scenario in which GPT returns an invalid JSON response, we hardcode an invalid JSON string as the response. We expect test_valid_schema to catch the validation error and return False.

failing_query = 'CREATE departments, select * from departments'
test_valid_schema(failing_query)
ERROR: Invalid schema: 1 validation error for LLMResponse
  Invalid JSON: expected value at line 1 column 1 [type=json_invalid, input_value='CREATE departments, select * from departments', input_type=str]
    For further information visit https://errors.pydantic.dev/2.8/v/json_invalid
False

As expected, test_valid_schema catches the validation error, prints it, and returns False.
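Malformed JSON is only one way a response can fail the schema. The stdlib sketch below approximates the checks LLMResponse.model_validate_json performs for this model (a JSON object whose create and select fields are present and are strings) and exercises a few negative cases; has_valid_schema is a hypothetical helper, and Pydantic remains the reference behaviour.

```python
import json

REQUIRED_KEYS = ("create", "select")


def has_valid_schema(content: str) -> bool:
    """Hypothetical stdlib approximation of LLMResponse.model_validate_json for this two-field model."""
    try:
        data = json.loads(content)
    except json.JSONDecodeError as exc:
        print("Invalid JSON:", exc)
        return False
    if not isinstance(data, dict):
        print("Top-level JSON value is not an object")
        return False
    for key in REQUIRED_KEYS:
        # Missing keys and non-string values are both rejected
        if not isinstance(data.get(key), str):
            print(f"Missing or non-string field: {key}")
            return False
    return True


cases = {
    "valid": '{"create": "CREATE TABLE t (a INT)", "select": "SELECT a FROM t"}',
    "not JSON": "CREATE departments, select * from departments",
    "missing select": '{"create": "CREATE TABLE t (a INT)"}',
    "wrong type": '{"create": "CREATE TABLE t (a INT)", "select": 42}',
}
schema_checks = {name: has_valid_schema(text) for name, text in cases.items()}
print(schema_checks)
```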

Test SQL queries

Next, we'll validate the correctness of the SQL. This test is designed to validate that:

  1. The CREATE SQL returned in GPT response is syntactically correct.
  2. The SELECT SQL returned in the GPT response is syntactically correct.

To achieve this, we will use a SQLite instance and send the returned SQL statements to it. If the statements are valid, SQLite will accept and execute them; otherwise it raises an exception.
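Because the check boils down to "does SQLite execute it?", it inherits SQLite's dialect: some SQL written for other databases fails, while some still passes. The sketch below uses a hypothetical, standard-library-only helper executes_in_sqlite on two CREATE statements modelled on ones that appear in the test run later in this notebook.

```python
import sqlite3


def executes_in_sqlite(create: str, select: str) -> bool:
    """Hypothetical helper: True if both statements run on a fresh in-memory SQLite database."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(create)
        conn.execute(select).fetchall()
        return True
    except sqlite3.Error as exc:
        print("SQLite rejected the SQL:", exc)
        return False
    finally:
        conn.close()


# MySQL-style AUTO_INCREMENT is not SQLite syntax, so this pair fails
mysql_style = executes_in_sqlite(
    "CREATE TABLE t (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(50))",
    "SELECT name FROM t",
)
# PostgreSQL-style SERIAL passes: SQLite accepts arbitrary words as column type names
postgres_style = executes_in_sqlite(
    "CREATE TABLE t (id SERIAL PRIMARY KEY, name VARCHAR(50))",
    "SELECT name FROM t",
)
print(mysql_style, postgres_style)  # False True
```

A pass therefore means "valid for SQLite", not "portable SQL"; if you target another database engine, executing against that engine is the more faithful test.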

The create_connection function below sets up a SQLite database (in-memory by default) and returns a connection to be used later.

# Set up SQLite to act as our test database
def create_connection(db_file=":memory:"):
    """create a database connection to a SQLite database"""
    try:
        conn = sqlite3.connect(db_file)
        # print(sqlite3.version)
    except Error as e:
        print(e)
        return None

    return conn

def close_connection(conn):
    """close a database connection"""
    try:
        conn.close()
    except Error as e:
        print(e)


conn = create_connection()
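Each call to sqlite3.connect(":memory:") opens a brand-new, private database. That matters here because generated table names repeat across questions (records, Matches and election_results each appear more than once in the test run later in this notebook), so reusing one connection for every response would make a later CREATE TABLE fail for reasons unrelated to the SQL itself. A quick standard-library check, with a hypothetical table_names helper:

```python
import sqlite3


def table_names(conn: sqlite3.Connection) -> list:
    """List the tables defined in a SQLite database."""
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return [name for (name,) in rows]


# Two ":memory:" connections are two separate, empty databases
first = sqlite3.connect(":memory:")
second = sqlite3.connect(":memory:")
first.execute("CREATE TABLE records (id INT)")
print(table_names(first), table_names(second))  # ['records'] []

# Re-running a CREATE for an existing name on the same connection fails
try:
    first.execute("CREATE TABLE records (id INT)")
    duplicate_accepted = True
except sqlite3.OperationalError as exc:
    print("Second CREATE failed:", exc)
    duplicate_accepted = False

first.close()
second.close()
```

This is why opening a fresh connection per response, as test_llm_sql does, keeps each test independent.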

Next, we will create the following functions to carry out the syntactical correctness checks.

  • test_create: Function testing if the CREATE SQL statement succeeds.
  • test_select: Function testing if the SELECT SQL statement succeeds.
  • test_llm_sql: Wrapper function executing the two tests above.
def test_select(conn, cursor, select):
    """Tests that a SQLite select query can be executed successfully."""
    try:
        print(f"Testing select query: {select}")
        cursor.execute(select)
        record = cursor.fetchall()
        print(record)

        return True

    except sqlite3.Error as error:
        print("Error while executing select query:", error)

        return False


def test_create(conn, cursor, create):
    """Tests that a SQLite create query can be executed successfully"""
    try:
        print(f"Testing create query: {create}")
        cursor.execute(create)
        conn.commit()

        return True

    except sqlite3.Error as error:
        print("Error while creating the SQLite table:", error)

        return False


def test_llm_sql(response):
    """Runs a suite of SQLite tests against a parsed LLMResponse."""
    # Use a fresh in-memory database per response so tables from earlier tests don't leak in
    conn = create_connection()
    if conn is None:
        return False

    try:
        cursor = conn.cursor()

        create_response = test_create(conn, cursor, response.create)

        select_response = test_select(conn, cursor, response.select)

        # Both statements must succeed for the SQL to count as valid
        return create_response and select_response

    except sqlite3.Error as error:
        print("Error while running the SQLite tests:", error)

        return False

    finally:
        close_connection(conn)
# Viewing CREATE and SELECT sqls returned by GPT

test_query = LLMResponse.model_validate_json(content)
print(f"CREATE SQL is: {test_query.create}")
print(f"SELECT SQL is: {test_query.select}")
CREATE SQL is: CREATE TABLE department_heads (
    id INT PRIMARY KEY,
    name VARCHAR(255),
    age INT,
    department VARCHAR(255)
);
SELECT SQL is: SELECT COUNT(*) FROM department_heads WHERE age > 56;
# Testing that the CREATE and SELECT SQL statements are valid (we expect this to be successful)

test_llm_sql(test_query)
Testing create query: CREATE TABLE department_heads (
    id INT PRIMARY KEY,
    name VARCHAR(255),
    age INT,
    department VARCHAR(255)
);
Testing select query: SELECT COUNT(*) FROM department_heads WHERE age > 56;
[(0,)]
True
# Again, we'll perform a negative test to confirm that a failing SELECT makes test_llm_sql return False.

test_failure_query = '{"create": "CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))", "select": "SELECT COUNT(*) FROM departments WHERE age > 56"}'
test_failure_query = LLMResponse.model_validate_json(test_failure_query)
test_llm_sql(test_failure_query)
Testing create query: CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))
Testing select query: SELECT COUNT(*) FROM departments WHERE age > 56
Error while executing select query: no such column: age
False
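test_llm_sql reports a single pass/fail, but the SQLite messages seen so far ("no such column", "syntax error") hint at different root causes. The sketch below labels the first failure of a CREATE/SELECT pair by matching on SQLite's error text; classify_sql_failure is hypothetical, and matching on message strings is a heuristic rather than a stable API. The second case reproduces the backslash-escaped quote that fails in the test run later in this notebook.

```python
import sqlite3


def classify_sql_failure(create: str, select: str) -> str:
    """Hypothetical helper: run CREATE then SELECT in a fresh in-memory DB and label the first failure."""
    conn = sqlite3.connect(":memory:")
    try:
        for stage, statement in (("create", create), ("select", select)):
            try:
                conn.execute(statement).fetchall()
            except sqlite3.Error as exc:
                message = str(exc)
                # Heuristic: relies on SQLite's error wording
                if "syntax error" in message:
                    return f"{stage}: syntax error"
                if message.startswith(("no such column", "no such table")):
                    return f"{stage}: references missing schema"
                return f"{stage}: other error ({message})"
        return "ok"
    finally:
        conn.close()


# The negative test above: valid syntax, but "age" was never created
missing_column = classify_sql_failure(
    "CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))",
    "SELECT COUNT(*) FROM departments WHERE age > 56",
)
# SQLite does not treat \' as an escaped quote inside a string literal
backslash_quote = classify_sql_failure(
    "CREATE TABLE events (name VARCHAR(255), event VARCHAR(50))",
    "SELECT name FROM events WHERE event = 'Women\\'s Half Middleweight'",
)
# Standard SQL escapes a single quote by doubling it
doubled_quote = classify_sql_failure(
    "CREATE TABLE events (name VARCHAR(255), event VARCHAR(50))",
    "SELECT name FROM events WHERE event = 'Women''s Half Middleweight'",
)
print(missing_column, backslash_quote, doubled_quote, sep="\n")
```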

Evaluation

The last component is to evaluate whether the generated SQL actually answers the user's question. This evaluation is performed by gpt-4o-mini, which assesses how relevant the produced SQL queries are to the initial user request.

This simple example adapts the approach outlined in the G-Eval paper, which is also tested in one of our other cookbooks.

EVALUATION_MODEL = "gpt-4o-mini"

EVALUATION_PROMPT_TEMPLATE = """
You will be given a natural language request and the SQL queries written to answer it. Your task is to rate the queries on one metric.
Please make sure you read and understand these instructions very carefully. 
Please keep this document open while reviewing, and refer to it as needed.

Evaluation Criteria:

{criteria}

Evaluation Steps:

{steps}

Example:

Request:

{request}

Queries:

{queries}

Evaluation Form (scores ONLY):

- {metric_name}
"""

# Relevance

RELEVANCY_SCORE_CRITERIA = """
Relevance(1-5) - review of how relevant the produced SQL queries are to the original question. \
The queries should contain all points highlighted in the user's request. \
Annotators were instructed to penalize queries which contained redundancies and excess information.
"""

RELEVANCY_SCORE_STEPS = """
1. Read the request and the queries carefully.
2. Compare the queries to the request document and identify the main points of the request.
3. Assess how well the queries cover the main points of the request, and how much irrelevant or redundant information it contains.
4. Assign a relevance score from 1 to 5.
"""
def get_geval_score(
    criteria: str, steps: str, request: str, queries: str, metric_name: str
):
    """Uses EVALUATION_MODEL to grade the queries against the given criteria and returns the model's raw text reply."""
    prompt = EVALUATION_PROMPT_TEMPLATE.format(
        criteria=criteria,
        steps=steps,
        request=request,
        queries=queries,
        metric_name=metric_name,
    )
    response = client.chat.completions.create(
        model=EVALUATION_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=5,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response.choices[0].message.content
# Test out evaluation on a few records

evaluation_results = []

for _, row in sql_df.head(3).iterrows():
    # Evaluate the dataset's own reference SQL (context + answer) as a sanity check of the evaluator
    queries = row['context'] + '\n' + row['answer']
    score = get_geval_score(RELEVANCY_SCORE_CRITERIA, RELEVANCY_SCORE_STEPS, row['question'], queries, 'relevancy')
    evaluation_results.append((row['question'], queries, score))
for result in evaluation_results:
    print(f"User Question \t: {result[0]}")
    print(f"CREATE SQL Returned \t: {result[1].splitlines()[0]}")
    print(f"SELECT SQL Returned \t: {result[1].splitlines()[1]}")
    print(f"{result[2]}")
    print("*" * 20)
User Question 	: How many heads of the departments are older than 56 ?
CREATE SQL Returned 	: CREATE TABLE head (age INTEGER)
SELECT SQL Returned 	: SELECT COUNT(*) FROM head WHERE age > 56
5
********************
User Question 	: List the name, born state and age of the heads of departments ordered by age.
CREATE SQL Returned 	: CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)
SELECT SQL Returned 	: SELECT name, born_state, age FROM head ORDER BY age
4
********************
User Question 	: List the creation year, name and budget of each department.
CREATE SQL Returned 	: CREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)
SELECT SQL Returned 	: SELECT creation, name, budget_in_billions FROM department
4
********************
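get_geval_score returns the model's raw text, and with max_tokens=5 that text is not guaranteed to be a bare digit. Before averaging scores across a run, it can help to convert them defensively. A minimal sketch, assuming scores are integers on the 1 to 5 scale defined in RELEVANCY_SCORE_CRITERIA; parse_score is hypothetical, and the last two inputs are made-up malformed replies.

```python
import re
from statistics import mean
from typing import Optional


def parse_score(raw: Optional[str], low: int = 1, high: int = 5) -> Optional[int]:
    """Hypothetical helper: return the first integer in [low, high] found in the evaluator's reply."""
    for match in re.finditer(r"\d+", raw or ""):
        value = int(match.group())
        if low <= value <= high:
            return value
    return None  # no usable score: keep it missing instead of guessing


# The three replies printed above, plus two made-up malformed ones
raw_scores = ["5", "4", "4", "- relevancy: 3", "N/A"]
parsed = [parse_score(s) for s in raw_scores]
print(parsed)  # [5, 4, 4, 3, None]

valid = [s for s in parsed if s is not None]
print(round(mean(valid), 2))  # 4.0
```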

Putting it all together

We'll now combine the unit tests and the evaluation to compare two system prompts.

Each iteration of input/output and scores should be stored as a run. Optionally you can add GPT-4 annotation within your evaluations or as a separate step to review an entire run and highlight the reasons for errors.

For this example, the second system prompt will include an extra line of clarification, so we can assess the impact of this for both SQL validity and quality of solution.
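One lightweight way to store each iteration as a run is a small record that keeps the prompt under test alongside its per-question results and can summarize itself for a runbook. The EvalRun class below is a hypothetical sketch, and the two rows are illustrative placeholders, not results from this notebook.

```python
from dataclasses import dataclass, field


@dataclass
class EvalRun:
    """Hypothetical record of one run: the system prompt under test plus per-question results."""
    name: str
    system_prompt: str
    # Each entry: (question, response, format_valid, sql_valid, relevancy_score)
    rows: list = field(default_factory=list)

    def summary(self) -> dict:
        """Aggregate pass rates so runs with different prompts can be compared side by side."""
        n = len(self.rows)
        if n == 0:
            return {"run": self.name, "n": 0, "format_pass_rate": 0.0, "sql_pass_rate": 0.0}
        return {
            "run": self.name,
            "n": n,
            "format_pass_rate": sum(bool(r[2]) for r in self.rows) / n,
            "sql_pass_rate": sum(bool(r[3]) for r in self.rows) / n,
        }


# Illustrative placeholder rows only
run_1 = EvalRun("run_1", "system prompt 1 text")
run_1.rows.append(("question A", "response A", True, True, 5))
run_1.rows.append(("question B", "response B", True, False, 3))
print(run_1.summary())
```

Appending each summary to a persistent log (a CSV or JSON lines file, for example) gives the historical record the runbook needs.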

First run - System Prompt 1

The system under test is the first system prompt as shown below. This run will generate responses for this system prompt and evaluate the responses using the functions we've created so far.

# Set first system prompt
system_prompt = """Translate this natural language request into a JSON object containing two SQL queries. 
The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question. 
"""

pprint(system_prompt, width = 120)
('Translate this natural language request into a JSON object containing two SQL queries. \n'
 "The first query should be a CREATE statement for a table answering the user's request, while the second should be a "
 'SELECT query answering their question. \n')
def get_response(system_prompt, user_message, model=GPT_MODEL):
    """Sends the system prompt and user question to the model and returns the raw JSON content."""
    messages = []
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_message})

    # Structured Outputs: the reply is constrained to the LLMResponse schema
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=LLMResponse,
    )

    return response.choices[0].message.content
def execute_unit_tests(input_df, output_list, system_prompt):
    """Unit testing function that takes in a dataframe and appends test results to output_list.
    The system prompt is configurable so we can compare several prompts with this framework."""

    for _, row in input_df.iterrows():
        model_response = get_response(system_prompt, row['question'])

        # Unit test 1: can the response be parsed into our Pydantic schema?
        format_valid = test_valid_schema(model_response)

        # Unit test 2: do the CREATE and SELECT statements execute in SQLite?
        try:
            test_query = LLMResponse.model_validate_json(model_response)
            sql_valid = test_llm_sql(test_query)

        except pydantic.ValidationError:
            # If the JSON can't be parsed, the SQL can't be tested either
            sql_valid = False

        output_list.append((row['question'], model_response, format_valid, sql_valid))

def evaluate_row(row):
    """Simple evaluation function to categorize unit testing results.
    If the format or SQL are flagged it returns a label, otherwise it is correct."""
    if not row['format']:
        return 'Format incorrect'

    elif not row['sql']:
        return 'SQL incorrect'

    else:
        return 'SQL correct'
# Select 50 unseen queries (the last rows of the dataset) to test this prompt
test_df = sql_df.tail(50)
# Execute unit tests and capture results
results = []

execute_unit_tests(input_df=test_df,output_list=results,system_prompt=system_prompt)
Testing create query: CREATE TABLE cricket_partnerships (
    id INT PRIMARY KEY,
    player1 VARCHAR(50),
    player2 VARCHAR(50),
    venue VARCHAR(100),
    match_date DATE
);
Testing select query: SELECT venue FROM cricket_partnerships WHERE player1 = 'Shoaib Malik' AND player2 = 'Misbah-ul-Haq' OR player1 = 'Misbah-ul-Haq' AND player2 = 'Shoaib Malik';
[]
Testing create query: CREATE TABLE CricketPartnerships (
    id INT PRIMARY KEY,
    player1 VARCHAR(255),
    player2 VARCHAR(255),
    venue VARCHAR(255),
    date DATE,
    runs_scored INT
);
Testing select query: SELECT venue FROM CricketPartnerships WHERE player1 = 'Herschelle Gibbs' AND player2 = 'Justin Kemp';
[]
Testing create query: CREATE TABLE PointsTable (
    NumberPlayed INT,
    Points INT
);
Testing select query: SELECT NumberPlayed FROM PointsTable WHERE Points = 310;
[]
Testing create query: CREATE TABLE sports_stats (
    team_id INTEGER PRIMARY KEY,
    team_name TEXT,
    points_against INTEGER,
    losing_bonus INTEGER
);
Testing select query: SELECT losing_bonus FROM sports_stats WHERE points_against = 588;
[]
Testing create query: CREATE TABLE rugby_points (
  id SERIAL PRIMARY KEY,
  team_name VARCHAR(100),
  tries_against INT,
  losing_bonus INT
);
Testing select query: SELECT * FROM rugby_points WHERE tries_against = 7 AND losing_bonus > 0;
[]
Testing create query: CREATE TABLE rugby_stats (
    team_name VARCHAR(50),
    games_played INT,
    tries_scored INT,
    try_bonus INT,
    points_against INT
);
Testing select query: SELECT try_bonus FROM rugby_stats WHERE points_against = 488;
[]
Testing create query: CREATE TABLE Points (
    id INT PRIMARY KEY,
    description VARCHAR(255),
    try_bonus INT
);
Testing select query: SELECT * FROM Points WHERE try_bonus = 140;
[]
Testing create query: CREATE TABLE Matches (
    MatchID INT PRIMARY KEY,
    TeamName VARCHAR(255),
    Drawn INT,
    TriesAgainst INT
);
Testing select query: SELECT TeamName FROM Matches WHERE Drawn = 1 AND TriesAgainst = 0;
[]
Testing create query: CREATE TABLE Champions (id INT PRIMARY KEY, name VARCHAR(100), reign_days INT, defenses INT);
Testing select query: SELECT reign_days FROM Champions WHERE reign_days > 3 AND defenses = 1;
[]
Testing create query: CREATE TABLE champions (
  id INTEGER PRIMARY KEY,
  name VARCHAR(255),
  reign_days INTEGER,
  defenses INTEGER
);
Testing select query: SELECT reign_days FROM champions WHERE reign_days > 3 AND defenses < 1;
[]
Testing create query: CREATE TABLE ChampionReigns (
    id INT PRIMARY KEY,
    champion_name VARCHAR(255) NOT NULL,
    total_defenses INT NOT NULL,
    days_held INT NOT NULL
);
Testing select query: SELECT AVG(total_defenses) AS average_defenses
FROM ChampionReigns
WHERE days_held = 404 AND TOTAL_REIGNS > 1;
Error while executing select query: no such column: TOTAL_REIGNS
Testing create query: CREATE TABLE Champions (
    id INT PRIMARY KEY,
    name VARCHAR(255),
    days_held INT,
    defense INT
);
Testing select query: SELECT MIN(defense) as lowest_defense FROM Champions WHERE days_held = 345;
[(None,)]
Testing create query: CREATE TABLE games_records (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    game_date DATE NOT NULL,
    team1_score INTEGER NOT NULL,
    team2_score INTEGER NOT NULL
);
Testing select query: SELECT game_date FROM games_records WHERE team1_score = 76 AND team2_score = 72;
[]
Testing create query: CREATE TABLE GameResults (
  game_id INT PRIMARY KEY,
  pitcher_name VARCHAR(50),
  pitcher_record VARCHAR(10),
  attendance INT,
  result VARCHAR(50)
);
Testing select query: SELECT attendance FROM GameResults WHERE pitcher_name = 'Ponson' AND pitcher_record = '1-5' AND result LIKE '%loss%';
[]
Testing create query: CREATE TABLE records (
    id SERIAL PRIMARY KEY,
    event_date DATE NOT NULL,
    record_type TEXT NOT NULL,
    record_value TEXT NOT NULL
);
Testing select query: SELECT event_date FROM records WHERE record_value = '36-39';
[]
Testing create query: CREATE TABLE records (
    record_id INT PRIMARY KEY,
    win_count INT,
    loss_count INT,
    record_date DATE
);
Testing select query: SELECT record_date FROM records WHERE win_count = 30 AND loss_count = 31;
[]
Testing create query: CREATE TABLE games (
  id INT PRIMARY KEY,
  player_name VARCHAR(255),
  opponent_name VARCHAR(255),
  player_score INT,
  opponent_score INT
);
Testing select query: SELECT opponent_name FROM games WHERE player_name = 'Leonard' AND player_score = 7 AND opponent_score = 8;
[]
Testing create query: CREATE TABLE GameScores (
    id INT PRIMARY KEY,
    record VARCHAR(10),
    score VARCHAR(10),
    date_played DATE
);
Testing select query: SELECT score FROM GameScores WHERE record = '18–43';
[]
Testing create query: CREATE TABLE game_scores (
    id INT PRIMARY KEY AUTO_INCREMENT,
    game_date DATE,
    opposing_team VARCHAR(50),
    team_score INT,
    opponent_score INT,
    season_record VARCHAR(10)
);
Error while creating the SQLite table: near "AUTO_INCREMENT": syntax error
Testing select query: SELECT team_score, opponent_score FROM game_scores WHERE opposing_team = 'Royals' AND season_record = '24-52';
Error while executing select query: no such table: game_scores
Testing create query: CREATE TABLE GameRecord (
    id INT PRIMARY KEY,
    record VARCHAR(10),
    score VARCHAR(50),
    date DATE
);
Testing select query: SELECT score FROM GameRecord WHERE record = '22–46';
[]
Testing create query: CREATE TABLE MilitaryPersonnel (
    id INT PRIMARY KEY,
    real_name VARCHAR(255),
    primary_specialty VARCHAR(255)
);
Testing select query: SELECT real_name FROM MilitaryPersonnel WHERE primary_specialty = 'shock paratrooper';
[]
Testing create query: CREATE TABLE Persons (
    PersonID INT PRIMARY KEY,
    FirstName VARCHAR(255) NOT NULL,
    LastName VARCHAR(255) NOT NULL,
    Birthplace VARCHAR(255) NOT NULL
);
Testing select query: SELECT Birthplace FROM Persons WHERE FirstName = 'Pete' AND LastName = 'Sanderson';
[]
Testing create query: CREATE TABLE roles (
    id SERIAL PRIMARY KEY,
    person_name VARCHAR(100) NOT NULL,
    role_title VARCHAR(100) NOT NULL
);
Testing select query: SELECT role_title FROM roles WHERE person_name = 'Jean-Luc Bouvier';
[]
Testing create query: CREATE TABLE KayakPilots (
    id INT PRIMARY KEY,
    real_name VARCHAR(255),
    nickname VARCHAR(255),
    vessel_type VARCHAR(100)
);
Testing select query: SELECT real_name FROM KayakPilots WHERE vessel_type = 'silent attack kayak';
[]
Testing create query: CREATE TABLE people (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    code_name VARCHAR(100),
    city_of_birth VARCHAR(100),
    date_of_birth DATE
);
Testing select query: SELECT code_name FROM people WHERE city_of_birth = 'Liverpool';
[]
Testing create query: CREATE TABLE CanoeingMedalists (
    id INT PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    event VARCHAR(100) NOT NULL,
    medal_type VARCHAR(50) NOT NULL,
    year INT NOT NULL
);
Testing select query: SELECT name FROM CanoeingMedalists WHERE event = 'Canoeing';
[]
Testing create query: CREATE TABLE HalfMiddleweightEvents (
    GameID INT PRIMARY KEY,
    GameName VARCHAR(255),
    Event VARCHAR(50),
    Year INT
);
Testing select query: SELECT GameName, Year FROM HalfMiddleweightEvents WHERE Event = 'Women\'s Half Middleweight';
Error while executing select query: near "s": syntax error
Testing create query: CREATE TABLE OlympicMedalists2000 (
    AthleteID INT PRIMARY KEY,
    Name VARCHAR(255) NOT NULL,
    MedalType VARCHAR(50) NOT NULL,
    Event VARCHAR(255) NOT NULL,
    Country VARCHAR(100) NOT NULL
);
Testing select query: SELECT Name, Event, Country FROM OlympicMedalists2000 WHERE MedalType = 'Bronze' AND Event = 'Specific Event Name';
[]
Testing create query: CREATE TABLE GameAttendance (
    GameID INT PRIMARY KEY,
    Opponent VARCHAR(50),
    Attendance INT
);
Testing select query: SELECT SUM(Attendance) AS Total_Attendance
FROM GameAttendance
WHERE Opponent = 'Twins';
[(None,)]
Testing create query: CREATE TABLE sports_records (
    id SERIAL PRIMARY KEY,
    date DATE NOT NULL,
    record VARCHAR(255) NOT NULL
);
Testing select query: SELECT date FROM sports_records WHERE record = '41-46';
[]
Testing create query: CREATE TABLE Scores (
    id INT PRIMARY KEY,
    score_name VARCHAR(255),
    score_value VARCHAR(10)
);
Testing select query: SELECT score_name FROM Scores WHERE score_value = '48-55';
[]
Testing create query: CREATE TABLE sports_records (
    team_name VARCHAR(255),
    games_won INT,
    games_lost INT,
    PRIMARY KEY (team_name)
);
Testing select query: SELECT team_name FROM sports_records WHERE games_won = 44 AND games_lost = 49;
[]
Testing create query: CREATE TABLE games (
    game_id INT PRIMARY KEY,
    opponent VARCHAR(50),
    record VARCHAR(10),
    score VARCHAR(10)
);
Testing select query: SELECT score FROM games WHERE opponent = 'white sox' AND record = '2-0';
[]
Testing create query: CREATE TABLE election_votes (
  candidate_name VARCHAR(255),
  votes_received INT
);
Testing select query: SELECT votes_received FROM election_votes WHERE candidate_name = 'Candice Sjostrom';
[]
Testing create query: CREATE TABLE election_results (
    candidate_name VARCHAR(100),
    votes_received INT,
    total_votes INT
);
Testing select query: SELECT (votes_received * 100.0) / total_votes AS percentage_received 
FROM election_results 
WHERE candidate_name = 'Chris Wright';
[]
Testing create query: CREATE TABLE election_results (
    election_year INT,
    candidate_name VARCHAR(255),
    vote_count INT,
    vote_percentage DECIMAL(5, 2),
    office VARCHAR(255),
    office_district INT
);
Testing select query: SELECT vote_count
FROM election_results
WHERE election_year > 1992 
  AND vote_percentage = 1.59 
  AND office = 'us representative' 
  AND office_district = 4;
[]
Testing create query: CREATE TABLE Representatives (
  id SERIAL PRIMARY KEY,
  first_name VARCHAR(50),
  last_name VARCHAR(50),
  start_year INT,
  end_year INT
);
Testing select query: SELECT start_year, end_year FROM Representatives WHERE first_name = 'J.' AND last_name = 'Smith Young';
[]
Testing create query: CREATE TABLE Politicians (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    party VARCHAR(100)
);
Testing select query: SELECT party FROM Politicians WHERE name = 'Thomas L. Young';
[]
Testing create query: CREATE TABLE MedalCounts (
    Country VARCHAR(100),
    Gold INT,
    Silver INT,
    Bronze INT,
    Total INT
);
Testing select query: SELECT MIN(Total) AS LowestMedalCount FROM MedalCounts 
WHERE Gold = 0 AND Bronze > 2 AND Silver > 1;
[(None,)]
Testing create query: CREATE TABLE country_medals (
    rank INT,
    country_name VARCHAR(255),
    gold_medals INT,
    silver_medals INT,
    bronze_medals INT,
    total_medals INT
);
Testing select query: SELECT SUM(silver_medals) FROM country_medals WHERE rank = 14 AND total_medals < 1;
[(None,)]
Testing create query: CREATE TABLE player_stats (
    player_id INT PRIMARY KEY,
    player_name VARCHAR(100),
    tackles INT,
    fumble_recoveries INT,
    forced_fumbles INT
);
Testing select query: SELECT tackles FROM player_stats WHERE fumble_recoveries > 0 AND forced_fumbles > 0;
[]
Testing create query: CREATE TABLE DefensiveStats (
    player_id INT PRIMARY KEY,
    player_name VARCHAR(100),
    solo_tackles INT,
    forced_fumbles INT
);
Testing select query: SELECT forced_fumbles FROM DefensiveStats WHERE player_name = 'jim laney' AND solo_tackles < 2;
[]
Testing create query: CREATE TABLE PlayersStats (
    PlayerID INT PRIMARY KEY,
    PlayerName VARCHAR(255),
    SoloTackles INT,
    Total INT
);
Testing select query: SELECT MAX(Total) AS HighTotal FROM PlayersStats WHERE SoloTackles > 15;
[(None,)]
Testing create query: CREATE TABLE PlayerStats (
    player_id INT PRIMARY KEY,
    player_name VARCHAR(100),
    fumble_recoveries INT,
    forced_fumbles INT,
    sacks INT,
    solo_tackles INT
);
Testing select query: SELECT fumble_recoveries FROM PlayerStats WHERE player_name = 'Scott Gajos' AND forced_fumbles = 0 AND sacks = 0 AND solo_tackles < 2;
[]
Testing create query: CREATE TABLE Matches (
    MatchID INT PRIMARY KEY,
    HomeTeam VARCHAR(255),
    OpponentTeam VARCHAR(255),
    MatchTime TIME,
    Stadium VARCHAR(255)
);
Testing select query: SELECT OpponentTeam FROM Matches WHERE MatchTime = '20:00:00' AND Stadium = 'Camp Nou';
[]
Testing create query: CREATE TABLE matches (
    id INT PRIMARY KEY,
    date DATE,
    time TIME,
    score VARCHAR(5)
);
Testing select query: SELECT time FROM matches WHERE score = '3-2';
[]
Testing create query: CREATE TABLE Matches (
    MatchID INT PRIMARY KEY,
    HomeTeam VARCHAR(255),
    AwayTeam VARCHAR(255),
    Ground VARCHAR(255),
    MatchDate DATE
);
Testing select query: SELECT Ground FROM Matches WHERE HomeTeam = 'Aston Villa' OR AwayTeam = 'Aston Villa';
[]
Testing create query: CREATE TABLE CompetitionEvents (
    EventID INT PRIMARY KEY AUTO_INCREMENT,
    CompetitionName VARCHAR(100),
    Location VARCHAR(100),
    EventTime TIME,
    EventDate DATE
);
Error while creating the SQLite table: near "AUTO_INCREMENT": syntax error
Testing select query: SELECT CompetitionName FROM CompetitionEvents 
WHERE Location = 'San Siro' AND EventTime = '18:30:00' 
ORDER BY EventDate DESC 
LIMIT 1;
Error while executing select query: no such table: CompetitionEvents
Testing create query: CREATE TABLE school_locality_deciles (
    locality_id INT PRIMARY KEY,
    locality_name VARCHAR(255),
    total_decile INT
);
Testing select query: SELECT locality_name, total_decile FROM school_locality_deciles WHERE locality_name = 'redwood';
[]
Testing create query: CREATE TABLE racing_reports (
  report_id INT PRIMARY KEY,
  report_name VARCHAR(255) NOT NULL,
  track_name VARCHAR(255),
  event_date DATE
);
Testing select query: SELECT report_name FROM racing_reports WHERE track_name = 'Circuit of Tripoli';
[]
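
The error logged above for CompetitionEvents comes from a SQL dialect mismatch, not a logic error. The model emitted MySQL-style AUTO_INCREMENT, which SQLite rejects, so the table was never created and the follow-up SELECT then failed with "no such table". The sketch below reproduces this with Python's built-in sqlite3 module. The trimmed table definition is an illustration, not the model's exact output. In SQLite the equivalent column definition is INTEGER PRIMARY KEY AUTOINCREMENT.

```python
import sqlite3

def try_execute(conn: sqlite3.Connection, sql: str) -> str:
    # Run one statement and report "ok" or the SQLite error message
    try:
        conn.execute(sql)
        return "ok"
    except sqlite3.OperationalError as exc:
        return f"error: {exc}"

demo_conn = sqlite3.connect(":memory:")

# MySQL-style column definition, as generated in run 1 (trimmed for illustration)
mysql_create = "CREATE TABLE CompetitionEvents (EventID INT PRIMARY KEY AUTO_INCREMENT, Location VARCHAR(100))"
# SQLite-style equivalent: AUTOINCREMENT is only allowed on an INTEGER PRIMARY KEY
sqlite_create = "CREATE TABLE CompetitionEvents (EventID INTEGER PRIMARY KEY AUTOINCREMENT, Location VARCHAR(100))"
select_sql = "SELECT Location FROM CompetitionEvents"

dialect_checks = [
    try_execute(demo_conn, mysql_create),   # syntax error near "AUTO_INCREMENT"
    try_execute(demo_conn, select_sql),     # table was never created
    try_execute(demo_conn, sqlite_create),  # succeeds
    try_execute(demo_conn, select_sql),     # succeeds (returns no rows)
]
for outcome in dialect_checks:
    print(outcome)
demo_conn.close()
```

If every generated query must run on SQLite, stating the target dialect in the system prompt is one way to reduce this class of failure.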

Run Evaluation

Now that we have generated the SQL with system prompt 1 (run 1), we can evaluate the results. We use the pandas apply function to run two evaluations on every generated response: a relevancy score from get_geval_score and a pass/fail unit-test label from evaluate_row.

results_df = pd.DataFrame(results)
results_df.columns = ['question','response','format','sql']

# Execute evaluation
results_df['evaluation_score'] = results_df.apply(lambda x: get_geval_score(RELEVANCY_SCORE_CRITERIA,RELEVANCY_SCORE_STEPS,x['question'],x['response'],'relevancy'),axis=1)
results_df['unit_test_evaluation'] = results_df.apply(lambda x: evaluate_row(x),axis=1)
results_df['unit_test_evaluation'].value_counts()
unit_test_evaluation
SQL correct      46
SQL incorrect     4
Name: count, dtype: int64
results_df['evaluation_score'].value_counts()
evaluation_score
5    33
4    15
3     2
Name: count, dtype: int64
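
The unit-test label comes from applying a row-wise function across the dataframe with DataFrame.apply(..., axis=1). evaluate_row is defined earlier in the notebook. The stand-in below is a hypothetical sketch. It assumes a record counts as "SQL correct" only when both the format flag (the response parsed as JSON) and the sql flag (the queries executed) are True.

```python
import pandas as pd

def evaluate_row_sketch(row: pd.Series) -> str:
    # Hypothetical stand-in for evaluate_row: pass only if the JSON format
    # check and the SQL execution check both succeeded for this record
    return "SQL correct" if row["format"] and row["sql"] else "SQL incorrect"

toy_df = pd.DataFrame({"format": [True, True, False], "sql": [True, False, False]})

# axis=1 passes each row (as a Series) to the function
toy_df["unit_test_evaluation"] = toy_df.apply(evaluate_row_sketch, axis=1)
print(toy_df["unit_test_evaluation"].value_counts())
```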

Second run

We now run the same unit tests and evaluation with a new system prompt. Note that the unit-testing and evaluation functions are unchanged; the only change is the system prompt, which is the component under test.

system_prompt_2 = """Translate this natural language request into a JSON object containing two SQL queries. 
The first query should be a CREATE statement for a table answering the user's request, while the second should be a SELECT query answering their question. 
Ensure the SQL is always generated on one line, never use \n to separate rows."""

pprint(system_prompt_2, width=120)
('Translate this natural language request into a JSON object containing two SQL queries. \n'
 "The first query should be a CREATE statement for a table answering the user's request, while the second should be a "
 'SELECT query answering their question. \n'
 'Ensure the SQL is always generated on one line, never use \n'
 ' to separate rows.')
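
Note in the pprint output that the \n inside the triple-quoted prompt was converted into a real line break. As a result, the model receives "never use" followed by a newline rather than the two characters \n. The results in this notebook were produced with the prompt as written. If the intent is to name the escape sequence itself, a raw string or a doubled backslash keeps it literal, as this quick check shows:

```python
# A normal (non-raw) string turns \n into a single newline character
cooked = "never use \n to separate rows"
# A raw string (or writing "\\n") keeps the two characters: backslash and n
raw = r"never use \n to separate rows"

print("\n" in cooked)    # the cooked string contains a real newline
print("\n" in raw)       # the raw string does not
print("\\n" in raw)      # it contains a literal backslash followed by n
print(raw == "never use \\n to separate rows")
```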
# Execute unit tests
results_2 = []

execute_unit_tests(input_df=test_df,output_list=results_2,system_prompt=system_prompt_2)
Testing create query: CREATE TABLE cricket_partnership (id INT PRIMARY KEY, player1 VARCHAR(50), player2 VARCHAR(50), venue VARCHAR(100));
Testing select query: SELECT venue FROM cricket_partnership WHERE (player1 = 'Shoaib Malik' AND player2 = 'Misbah-ul-Haq') OR (player1 = 'Misbah-ul-Haq' AND player2 = 'Shoaib Malik');
[]
Testing create query: CREATE TABLE cricket_partnerships (id INT PRIMARY KEY, player_a VARCHAR(50), player_b VARCHAR(50), venue VARCHAR(100), date DATE);
Testing select query: SELECT venue FROM cricket_partnerships WHERE (player_a = 'Herschelle Gibbs' AND player_b = 'Justin Kemp') OR (player_a = 'Justin Kemp' AND player_b = 'Herschelle Gibbs');
[]
Testing create query: CREATE TABLE Points_Table (Number_Played INT, Points INT);
Testing select query: SELECT Number_Played FROM Points_Table WHERE Points = 310;
[]
Testing create query: CREATE TABLE TeamResults (TeamID INT PRIMARY KEY, TeamName VARCHAR(100), PointsFor INT, PointsAgainst INT, Wins INT, Losses INT, LosingBonus INT);
Testing select query: SELECT LosingBonus FROM TeamResults WHERE PointsAgainst = 588;
[]
Testing create query: CREATE TABLE RugbyMatches (MatchID INT PRIMARY KEY, TriesAgainst INT, LosingBonus INT);
Testing select query: SELECT * FROM RugbyMatches WHERE LosingBonus = 7;
[]
Testing create query: CREATE TABLE RugbyTeamStats (TeamID INT PRIMARY KEY, TeamName VARCHAR(255), TryBonus INT, PointsAgainst INT);
Testing select query: SELECT TryBonus FROM RugbyTeamStats WHERE PointsAgainst = 488;
[]
Testing create query: CREATE TABLE Points (id INT PRIMARY KEY, bonus_type VARCHAR(50), bonus_value INT);
Testing select query: SELECT * FROM Points WHERE bonus_type = 'Try' AND bonus_value = 140;
[]
Testing create query: CREATE TABLE Matches (Drawn BOOLEAN, Tries_Against INTEGER);
Testing select query: SELECT * FROM Matches WHERE Drawn = TRUE AND Tries_Against = 0;
[]
Testing create query: CREATE TABLE Champions (id INT PRIMARY KEY, name VARCHAR(50), days_held INT, reign INT, defenses INT);
Testing select query: SELECT days_held FROM Champions WHERE reign > 3 AND defenses = 1;
[]
Testing create query: CREATE TABLE champions (id INT, name VARCHAR(255), days_held INT, reign_count INT, defenses INT);
Testing select query: SELECT days_held FROM champions WHERE reign_count > 3 AND defenses < 1;
[]
Testing create query: CREATE TABLE champions (id INT PRIMARY KEY, name VARCHAR(255), days_held INT, reign INT, defenses INT);
Testing select query: SELECT AVG(defenses) AS average_defenses FROM champions WHERE days_held = 404 AND reign > 1;
[(None,)]
Testing create query: CREATE TABLE champions (id INT PRIMARY KEY, name VARCHAR(255), defense INT, days_held INT);
Testing select query: SELECT MIN(defense) AS lowest_defense FROM champions WHERE days_held = 345;
[(None,)]
Testing create query: CREATE TABLE games (date DATE, team1_score INT, team2_score INT);
Testing select query: SELECT date FROM games WHERE (team1_score = 76 AND team2_score = 72) OR (team1_score = 72 AND team2_score = 76);
[]
Testing create query: CREATE TABLE match_results (match_id INT, team_name VARCHAR(255), player_name VARCHAR(255), attendance INT, result VARCHAR(50));
Testing select query: SELECT attendance FROM match_results WHERE player_name = 'Ponson' AND result = 'loss' AND match_id = 1;
[]
Testing create query: CREATE TABLE daily_records (id INT PRIMARY KEY, record_value INT, record_date DATE);
Testing select query: SELECT record_date FROM daily_records WHERE record_value BETWEEN 36 AND 39;
[]
Testing create query: CREATE TABLE records (id INTEGER PRIMARY KEY, record VARCHAR(10), date DATE);
Testing select query: SELECT date FROM records WHERE record = '30-31';
[]
Testing create query: CREATE TABLE baseball_games (game_date DATE, opponent TEXT, player_name TEXT, record TEXT, outcome TEXT);
Testing select query: SELECT opponent FROM baseball_games WHERE player_name = 'Leonard' AND record = '7-8';
[]
Testing create query: CREATE TABLE Games (id INT PRIMARY KEY, record VARCHAR(10), score VARCHAR(10));
Testing select query: SELECT score FROM Games WHERE record = '18–43';
[]
Testing create query: CREATE TABLE baseball_games (id INT PRIMARY KEY, opponent VARCHAR(50), our_score INT, their_score INT, our_record VARCHAR(10));
Testing select query: SELECT our_score, their_score FROM baseball_games WHERE opponent = 'Royals' AND our_record = '24-52';
[]
Testing create query: CREATE TABLE GameScores (id INTEGER PRIMARY KEY, record VARCHAR(10), score VARCHAR(10));
Testing select query: SELECT score FROM GameScores WHERE record = '22–46';
[]
Testing create query: CREATE TABLE MilitarySpecialty (real_name VARCHAR(100), specialty VARCHAR(100));
Testing select query: SELECT real_name FROM MilitarySpecialty WHERE specialty = 'shock paratrooper';
[]
Testing create query: CREATE TABLE People (Name VARCHAR(255), Birthplace VARCHAR(255));
Testing select query: SELECT Birthplace FROM People WHERE Name = 'Pete Sanderson';
[]
Testing create query: CREATE TABLE Roles (person_name VARCHAR(255), role VARCHAR(255));
Testing select query: SELECT role FROM Roles WHERE person_name = 'Jean-Luc Bouvier';
[]
Testing create query: CREATE TABLE KayakPilots (PilotID INT PRIMARY KEY, RealName VARCHAR(255), Alias VARCHAR(255));
Testing select query: SELECT RealName FROM KayakPilots WHERE Alias = 'silent attack kayak';
[]
Testing create query: CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR(100), birth_city VARCHAR(100), code_name VARCHAR(100));
Testing select query: SELECT code_name FROM people WHERE birth_city = 'Liverpool';
[]
Testing create query: CREATE TABLE CanoeingMedalists (ID INT PRIMARY KEY, Name VARCHAR(100), Country VARCHAR(100), MedalType VARCHAR(50));
Testing select query: SELECT Name FROM CanoeingMedalists;
[]
Testing create query: CREATE TABLE WomensHalfMiddleweightEvents (id INT PRIMARY KEY, game_name TEXT, year INT, location TEXT);
Testing select query: SELECT game_name, year, location FROM WomensHalfMiddleweightEvents WHERE game_name IS NOT NULL;
[]
Testing create query: CREATE TABLE BronzeMedals (Year INT, Games VARCHAR(100), Sport VARCHAR(100), Event VARCHAR(100), Athlete VARCHAR(100), Country VARCHAR(100));
Testing select query: SELECT Athlete, Country FROM BronzeMedals WHERE Year = 2000 AND Games = 'Sydney';
[]
Testing create query: CREATE TABLE BaseballMatches (match_id INT PRIMARY KEY, date DATE, opponent VARCHAR(100), attendance INT);
Testing select query: SELECT COUNT(*) FROM BaseballMatches WHERE opponent = 'twins';
[(0,)]
Testing create query: CREATE TABLE Records (Date DATE, WinCount INT, LossCount INT, PRIMARY KEY (Date));
Testing select query: SELECT Date FROM Records WHERE WinCount = 41 AND LossCount = 46;
[]
Testing create query: CREATE TABLE team_records (id INTEGER PRIMARY KEY, team_name VARCHAR(100), wins INTEGER, losses INTEGER, score INTEGER);
Testing select query: SELECT score FROM team_records WHERE wins = 48 AND losses = 55;
[]
Testing create query: CREATE TABLE Scores (record VARCHAR(10), score INT);
Testing select query: SELECT score FROM Scores WHERE record = '44-49';
[]
Testing create query: CREATE TABLE GameScores (Score INT, Opponent VARCHAR(50), Record VARCHAR(5));
Testing select query: SELECT Score FROM GameScores WHERE Opponent = 'white sox' AND Record = '2-0';
[]
Testing create query: CREATE TABLE votes (id INT PRIMARY KEY, candidate_name VARCHAR(255), votes INT);
Testing select query: SELECT votes FROM votes WHERE candidate_name = 'candice sjostrom';
[]
Testing create query: CREATE TABLE Votes(candidate_name VARCHAR(255), received_percentage DECIMAL(5,2));
Testing select query: SELECT received_percentage FROM Votes WHERE candidate_name = 'Chris Wright';
[]
Testing create query: CREATE TABLE election_results (year INT, votes INT, percentage FLOAT, office VARCHAR(50), candidate_id INT);
Testing select query: SELECT votes FROM election_results WHERE year > 1992 AND percentage = 1.59 AND office = 'US Representative 4';
[]
Testing create query: CREATE TABLE Representatives (Name VARCHAR(100), StartYear INT, EndYear INT);
Testing select query: SELECT StartYear, EndYear FROM Representatives WHERE Name = 'J. Smith Young';
[]
Testing create query: CREATE TABLE Politicians (id INT PRIMARY KEY, name VARCHAR(255), party VARCHAR(255), term_start DATE, term_end DATE);
Testing select query: SELECT party FROM Politicians WHERE name = 'Thomas L. Young';
[]
Testing create query: CREATE TABLE medals (country VARCHAR(100), gold INT, silver INT, bronze INT, total INT);
Testing select query: SELECT MIN(total) FROM medals WHERE gold = 0 AND bronze > 2 AND silver > 1;
[(None,)]
Testing create query: CREATE TABLE OlympicStats (Country VARCHAR(100), Rank INT, GoldMedals INT, SilverMedals INT, BronzeMedals INT, TotalMedals INT);
Testing select query: SELECT SUM(SilverMedals) as TotalSilverMedals FROM OlympicStats WHERE Rank = 14 AND TotalMedals < 1;
[(None,)]
Testing create query: CREATE TABLE player_statistics (player_id INT PRIMARY KEY, player_name VARCHAR(100), number_of_tackles INT, fumble_recoveries INT, forced_fumbles INT);
Testing select query: SELECT player_name, number_of_tackles FROM player_statistics WHERE fumble_recoveries > 0 AND forced_fumbles > 0;
[]
Testing create query: CREATE TABLE ForcedFumbles (player_name VARCHAR(100), solo_tackles INT, forced_fumbles INT);
Testing select query: SELECT forced_fumbles FROM ForcedFumbles WHERE player_name = 'Jim Laney' AND solo_tackles < 2;
[]
Testing create query: CREATE TABLE PlayerStatistics (PlayerID INT PRIMARY KEY, PlayerName VARCHAR(100), SoloTackles INT, TotalTackles INT);
Testing select query: SELECT MAX(TotalTackles) AS HighTotal FROM PlayerStatistics WHERE SoloTackles > 15;
[(None,)]
Testing create query: CREATE TABLE PlayerStats (player_name VARCHAR(50), fumble_recoveries INT, forced_fumbles INT, sacks INT, solo_tackles INT);
Testing select query: SELECT fumble_recoveries FROM PlayerStats WHERE player_name = 'Scott Gajos' AND forced_fumbles = 0 AND sacks = 0 AND solo_tackles < 2;
[]
Testing create query: CREATE TABLE matches (id INT PRIMARY KEY, home_team VARCHAR(100), opponent VARCHAR(100), date TIME, location VARCHAR(100));
Testing select query: SELECT opponent FROM matches WHERE date = '20:00:00' AND location = 'Camp Nou';
[]
Testing create query: CREATE TABLE matches (id INT PRIMARY KEY, match_time TIME, team1_score INT, team2_score INT);
Testing select query: SELECT match_time FROM matches WHERE team1_score = 3 AND team2_score = 2;
[]
Testing create query: CREATE TABLE Matches (id INT PRIMARY KEY, home_team VARCHAR(255), away_team VARCHAR(255), ground VARCHAR(255), date DATE);
Testing select query: SELECT ground FROM Matches WHERE away_team = 'Aston Villa';
[]
Testing create query: CREATE TABLE Competitions (id INT PRIMARY KEY, name VARCHAR(255), location VARCHAR(255), start_time TIME, timezone VARCHAR(255));
Testing select query: SELECT name FROM Competitions WHERE location = 'San Siro' AND start_time = '18:30:00' AND timezone = 'GMT';
[]
Testing create query: CREATE TABLE locality_scores (locality_id INT PRIMARY KEY, decile INT NOT NULL, locality_name VARCHAR(100) NOT NULL);
Testing select query: SELECT SUM(decile) AS total_decile FROM locality_scores WHERE locality_name = 'redwood school';
[(None,)]
Testing create query: CREATE TABLE reports (report_id INT PRIMARY KEY, report_name VARCHAR(255), content TEXT);
Testing select query: SELECT report_name FROM reports WHERE content LIKE '%Circuit of Tripoli%';
[]
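
Each "Testing create query" / "Testing select query" pair above is one unit test against an SQLite database. execute_unit_tests is defined earlier in the notebook. The sketch below is a simplified, hypothetical version of a per-record check that yields the two booleans stored in the format and sql columns. It assumes the response JSON uses the keys "create" and "select"; only "create" is visible in the run output table.

```python
import json
import sqlite3

def unit_test_sketch(response: str) -> tuple[bool, bool]:
    # Hypothetical helper returning (format_ok, sql_ok) for one model response.
    # Assumption: the JSON object holds the keys "create" and "select".
    try:
        queries = json.loads(response)
        create_sql, select_sql = queries["create"], queries["select"]
    except (json.JSONDecodeError, KeyError, TypeError):
        return False, False  # not valid JSON, or the expected keys are missing
    conn = sqlite3.connect(":memory:")  # fresh throwaway database per test
    try:
        conn.execute(create_sql)
        conn.execute(select_sql).fetchall()
        return True, True
    except sqlite3.Error:
        return True, False  # well-formed JSON, but the SQL failed to run
    finally:
        conn.close()

good_response = json.dumps({"create": "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)",
                            "select": "SELECT name FROM t WHERE id = 1"})
bad_sql_response = json.dumps({"create": "CREATE TABLE t (id INTEGER)",
                               "select": "SELECT name FROM t"})  # column does not exist

print(unit_test_sketch(good_response))
print(unit_test_sketch(bad_sql_response))
print(unit_test_sketch("not json"))
```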
results_2_df = pd.DataFrame(results_2)
results_2_df.columns = ['question','response','format','sql']

# Execute evaluation
results_2_df['evaluation_score'] = results_2_df.apply(lambda x: get_geval_score(RELEVANCY_SCORE_CRITERIA,RELEVANCY_SCORE_STEPS,x['question'],x['response'],'relevancy'),axis=1)
results_2_df['unit_test_evaluation'] = results_2_df.apply(lambda x: evaluate_row(x),axis=1)
results_2_df['unit_test_evaluation'].value_counts()
unit_test_evaluation
SQL correct    50
Name: count, dtype: int64
results_2_df['evaluation_score'].value_counts()
evaluation_score
5    34
4    15
3     1
Name: count, dtype: int64

Report

We'll build a simple dataframe to store and display the performance of each run. This is where you could use tools like Weights & Biases Prompts or Gantry to store the results and analyse your different iterations.

results_df['run'] = 1
results_df['Evaluating Model'] = 'gpt-4'

results_2_df['run'] = 2
results_2_df['Evaluating Model'] = 'gpt-4'
run_df = pd.concat([results_df,results_2_df])
run_df.head()
question response format sql evaluation_score unit_test_evaluation run Evaluating Model
0 What venue did the parntership of shoaib malik... {"create":"CREATE TABLE cricket_partnerships (... True True 5 SQL correct 1 gpt-4
1 What venue did the partnership of herschelle g... {"create":"CREATE TABLE CricketPartnerships (\... True True 5 SQL correct 1 gpt-4
2 What is the number Played that has 310 Points ... {"create":"CREATE TABLE PointsTable (\n Num... True True 5 SQL correct 1 gpt-4
3 What Losing bonus has a Points against of 588? {"create":"CREATE TABLE sports_stats (\n te... True True 5 SQL correct 1 gpt-4
4 What Tries against has a Losing bonus of 7? {"create":"CREATE TABLE rugby_points (\n id S... True True 4 SQL correct 1 gpt-4
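
To keep the runbook of historical evaluations described at the start, one lightweight option is to append each run's rows to a CSV file. This is an assumption, not what the notebook does. append_to_runbook and the file name are hypothetical.

```python
import os
import tempfile

import pandas as pd

def append_to_runbook(run_results: pd.DataFrame, path: str) -> None:
    # Hypothetical helper: append rows to a CSV runbook, writing the header only once
    write_header = not os.path.exists(path)
    run_results.to_csv(path, mode="a", header=write_header, index=False)

runbook_path = os.path.join(tempfile.mkdtemp(), "sql_eval_runbook.csv")
demo_run_1 = pd.DataFrame({"run": [1, 1], "unit_test_evaluation": ["SQL correct", "SQL incorrect"]})
demo_run_2 = pd.DataFrame({"run": [2], "unit_test_evaluation": ["SQL correct"]})

append_to_runbook(demo_run_1, runbook_path)
append_to_runbook(demo_run_2, runbook_path)

# Reload the full history to compare runs
history = pd.read_csv(runbook_path)
print(history.groupby("run")["unit_test_evaluation"].value_counts())
```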
# Unit test results
unittest_df_pivot = pd.pivot_table(run_df, values='format', index=['run', 'unit_test_evaluation'],
                                   aggfunc='count')
unittest_df_pivot.columns = ['Number of records']
unittest_df_pivot
                          Number of records
run unit_test_evaluation
1   SQL correct                          46
    SQL incorrect                         4
2   SQL correct                          50
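
The pivot table only lists combinations that actually occur, so run 2 has no "SQL incorrect" row at all. By contrast, pandas.crosstab produces a full grid with explicit zeros, which simplifies plotting and run-to-run comparison. Here is a small sketch on toy labels:

```python
import pandas as pd

toy_runs = pd.DataFrame({
    "run": [1, 1, 1, 2, 2],
    "unit_test_evaluation": ["SQL correct", "SQL correct", "SQL incorrect",
                             "SQL correct", "SQL correct"],
})

# Rows are runs, columns are labels; missing combinations become 0
toy_counts = pd.crosstab(toy_runs["run"], toy_runs["unit_test_evaluation"])
print(toy_counts)
```

pandas can also draw grouped bars straight from such a grid. For example, toy_counts.T.plot.bar() groups by evaluation label, with one bar per run.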

Plotting the results

We can create a simple bar chart to visualise the results of unit tests for both runs.

unittest_df_pivot.reset_index(inplace=True)

# Plotting
plt.figure(figsize=(10, 6))

# Set the width of each bar
bar_width = 0.35

# OpenAI brand colors
openai_colors = ['#00D1B2', '#000000']  # Green and Black

# Get unique runs and unit test evaluations
unique_runs = unittest_df_pivot['run'].unique()
unique_unit_test_evaluations = unittest_df_pivot['unit_test_evaluation'].unique()

# Ensure we have enough colors (repeating the pattern if necessary)
colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)

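# Iterate over each run to plot
for i, run in enumerate(unique_runs):
    run_data = unittest_df_pivot[unittest_df_pivot['run'] == run]

    # Align this run's counts to every evaluation label, filling 0 where a label is absent.
    # Run 2 has no 'SQL incorrect' row; without this, its single count would be drawn for both labels.
    counts = (run_data.set_index('unit_test_evaluation')['Number of records']
              .reindex(unique_unit_test_evaluations, fill_value=0))

    # Position of bars for this run
    positions = np.arange(len(unique_unit_test_evaluations)) + i * bar_width

    plt.bar(positions, counts, width=bar_width, label=f'Run {run}', color=colors[i])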

# Setting the x-axis labels to be the unit test evaluations, centered under the groups
plt.xticks(np.arange(len(unique_unit_test_evaluations)) + bar_width / 2, unique_unit_test_evaluations)

plt.xlabel('Unit Test Evaluation')
plt.ylabel('Number of Records')
plt.title('Unit Test Evaluations vs Number of Records for Each Run')
plt.legend()
plt.show()
[Figure: Unit Test Evaluations vs Number of Records for Each Run]
# Evaluation score results
evaluation_df_pivot = pd.pivot_table(run_df, values='format', index=['run', 'evaluation_score'],
                                     aggfunc='count')
evaluation_df_pivot.columns = ['Number of records']
evaluation_df_pivot
                      Number of records
run evaluation_score
1   3                                 2
    4                                15
    5                                33
2   3                                 1
    4                                15
    5                                34
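
The counts above can be condensed into a single number per run, such as the mean relevancy score. The sketch below hard-codes the counts from this pivot table. It gives 4.62 for run 1 and 4.66 for run 2. That gap amounts to a net 2 score points across 50 records, so it is worth treating as a small difference.

```python
import pandas as pd

# Score counts copied from the pivot table above
score_counts = pd.DataFrame({
    "run": [1, 1, 1, 2, 2, 2],
    "evaluation_score": [3, 4, 5, 3, 4, 5],
    "n_records": [2, 15, 33, 1, 15, 34],
})

# Weighted mean: sum(score * count) / sum(count) for each run
score_counts["weighted"] = score_counts["evaluation_score"] * score_counts["n_records"]
totals = score_counts.groupby("run")[["weighted", "n_records"]].sum()
mean_score = totals["weighted"] / totals["n_records"]
print(mean_score.round(2))
```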

Plotting the results

We can create a similar bar chart to visualise the relevancy evaluation scores for both runs.

# Reset index without dropping the 'run' and 'evaluation_score' columns
evaluation_df_pivot.reset_index(inplace=True)

# Plotting
plt.figure(figsize=(10, 6))

bar_width = 0.35

# OpenAI brand colors
openai_colors = ['#00D1B2', '#000000']  # Green, Black

# Identify unique runs and evaluation scores
unique_runs = evaluation_df_pivot['run'].unique()
unique_evaluation_scores = evaluation_df_pivot['evaluation_score'].unique()

# Repeat colors if there are more runs than colors
colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)

for i, run in enumerate(unique_runs):
    # Select rows for this run only
    run_data = evaluation_df_pivot[evaluation_df_pivot['run'] == run].copy()
    
    # Ensure every 'evaluation_score' is present
    run_data.set_index('evaluation_score', inplace=True)
    run_data = run_data.reindex(unique_evaluation_scores, fill_value=0)
    run_data.reset_index(inplace=True)
    
    # Plot each bar
    positions = np.arange(len(unique_evaluation_scores)) + i * bar_width
    plt.bar(
        positions,
        run_data['Number of records'],
        width=bar_width,
        label=f'Run {run}',
        color=colors[i]
    )

# Configure the x-axis to show evaluation scores under the grouped bars
plt.xticks(np.arange(len(unique_evaluation_scores)) + bar_width / 2, unique_evaluation_scores)

plt.xlabel('Evaluation Score')
plt.ylabel('Number of Records')
plt.title('Evaluation Scores vs Number of Records for Each Run')
plt.legend()
plt.show()
[Figure: Evaluation Scores vs Number of Records for Each Run]

Conclusion

Now you have a framework for testing SQL generation with LLMs, and with some tweaks this approach can be extended to many other code generation use cases. By combining GPT-4 as an evaluator with engaged human labellers, you can aim to automate the evaluation of these test cases. The result is an iterative loop: new examples are added to the test set, and this structure detects any performance regressions.
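
A regression check in that loop can be as simple as comparing each metric of a new run against a baseline. The helper below is a hypothetical sketch, not part of the notebook. The metric values are taken from the two runs above: the unit-test pass rate and the mean relevancy score.

```python
def detect_regression(baseline: dict, candidate: dict, tolerance: float = 0.0) -> list[str]:
    # Hypothetical helper: each dict maps metric name -> value, where higher is better.
    # Returns the metrics on which the candidate is worse than baseline by more than tolerance.
    return [name for name, base in baseline.items()
            if candidate.get(name, float("-inf")) < base - tolerance]

run_1_metrics = {"unit_test_pass_rate": 46 / 50, "mean_relevancy": 4.62}
run_2_metrics = {"unit_test_pass_rate": 50 / 50, "mean_relevancy": 4.66}

print(detect_regression(run_1_metrics, run_2_metrics))  # run 2 judged against run 1
print(detect_regression(run_2_metrics, run_1_metrics))  # run 1 judged against run 2
```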

We hope you find this useful; please share any feedback.