Fine-Tuned Q&A - Collect Data

Mar 10, 2022

Note: To answer questions based on text documents, we recommend the procedure in Question Answering using Embeddings. Some of the code below may rely on deprecated API endpoints.

1. Collect Wikipedia data about Olympic Games 2020

The idea of this project is to create a question answering model based on a few paragraphs of provided text. Base GPT-3 models do a good job of answering questions when the answer is contained within the paragraph. However, if the answer isn't contained, the base models tend to try their best to answer anyway, often leading to confabulated answers.

To create a model which answers questions only if there is sufficient context for doing so, we first create a dataset of questions and answers based on paragraphs of text. In order to train the model to answer only when the answer is present, we also add adversarial examples, where the question doesn't match the context. In those cases, we ask the model to output "No sufficient context for answering the question".
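
As a rough illustration of the dataset shape described above, the sketch below pairs each question with a mismatched context to form an adversarial example. The record fields (`context`, `question`, `answer`), the placeholder texts, and the helper `make_adversarial` are assumptions made for this example only; they are not taken from the later notebooks.

```python
NO_CONTEXT_ANSWER = "No sufficient context for answering the question"

# Hypothetical (context, question, answer) records with placeholder text.
examples = [
    {"context": "Placeholder paragraph about topic A.",
     "question": "What is topic A?",
     "answer": "Placeholder answer A"},
    {"context": "Placeholder paragraph about topic B.",
     "question": "What is topic B?",
     "answer": "Placeholder answer B"},
]

def make_adversarial(records):
    """Pair each question with the context of the next record (cyclically).

    Needs at least two records, otherwise the context would not be mismatched.
    """
    n = len(records)
    adversarial = []
    for i, rec in enumerate(records):
        other = records[(i + 1) % n]
        adversarial.append({
            "context": other["context"],   # context the question was NOT written for
            "question": rec["question"],
            "answer": NO_CONTEXT_ANSWER,   # target output for mismatched pairs
        })
    return adversarial

dataset = examples + make_adversarial(examples)
```

Cycling to the next record is only one simple way to guarantee a mismatch; random sampling of a different context would serve the same purpose.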

We will perform this task in three notebooks:

  1. The first (this) notebook focuses on collecting recent data, which GPT-3 didn't see during its pre-training. We picked the topic of Olympic Games 2020 (which actually took place in the summer of 2021), and downloaded 713 unique pages. We organized the dataset by individual sections, which will serve as context for asking and answering the questions.
  2. The second notebook will utilize Davinci-instruct to ask a few questions based on a Wikipedia section, as well as answer those questions, based on that section.
  3. The third notebook will utilize the dataset of context, question and answer pairs to additionally create adversarial questions and context pairs, where the question was not generated on that context. In those cases the model will be prompted to answer "No sufficient context for answering the question". We will also train a discriminator model, which predicts whether the question can be answered based on the context or not.
import pandas as pd
import wikipedia

def filter_olympic_2020_titles(titles):
    """
    Get the titles which are related to Olympic games hosted in 2020, given a list of titles
    """
    titles = [title for title in titles if '2020' in title and 'olympi' in title.lower()]
    return titles

def get_wiki_page(title):
    """
    Get the wikipedia page given a title
    """
    try:
        return wikipedia.page(title)
    except wikipedia.exceptions.DisambiguationError as e:
        # ambiguous title: fall back to the first suggested option
        return wikipedia.page(e.options[0])
    except wikipedia.exceptions.PageError as e:
        # no page exists for this title
        return None

def recursively_find_all_pages(titles, titles_so_far=None):
    """
    Recursively find all the pages that are linked to the Wikipedia titles in the list
    """
    if titles_so_far is None:
        titles_so_far = set()  # avoid sharing a mutable default between calls
    all_pages = []
    titles = list(set(titles) - titles_so_far)
    titles = filter_olympic_2020_titles(titles)
    titles_so_far.update(titles)  # mark as visited so the recursion terminates
    for title in titles:
        page = get_wiki_page(title)
        if page is None:
            continue
        all_pages.append(page)

        new_pages = recursively_find_all_pages(page.links, titles_so_far)
        for pg in new_pages:
            if pg.title not in [p.title for p in all_pages]:
                all_pages.append(pg)
    return all_pages

pages = recursively_find_all_pages(["2020 Summer Olympics"])
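
The crawl logic can be tried offline before hitting Wikipedia. The sketch below is an assumption-laden stand-in: `MOCK_LINKS` is a hypothetical in-memory link graph, `crawl_titles` is a hypothetical helper, and it walks the graph iteratively with a stack instead of recursively. It applies the same title test as `filter_olympic_2020_titles` and visits each title once.

```python
# Hypothetical link graph standing in for Wikipedia; the links are illustrative only.
MOCK_LINKS = {
    "2020 Summer Olympics": [
        "Tokyo",
        "Athletics at the 2020 Summer Olympics",
        "2016 Summer Olympics",
    ],
    "Athletics at the 2020 Summer Olympics": [
        "2020 Summer Olympics",
        "Bermuda at the 2020 Summer Olympics",
    ],
    "Bermuda at the 2020 Summer Olympics": ["2020 Summer Olympics"],
}

def is_olympic_2020_title(title):
    """Same test as the title filter: '2020' and 'olympi' must both appear."""
    return "2020" in title and "olympi" in title.lower()

def crawl_titles(start_titles, links):
    """Collect reachable titles that pass the filter, visiting each one once."""
    seen = set()
    ordered = []
    stack = list(start_titles)
    while stack:
        title = stack.pop()
        if title in seen or not is_olympic_2020_title(title):
            continue
        seen.add(title)
        ordered.append(title)
        # Titles missing from the mock graph behave like pages with no links.
        stack.extend(links.get(title, []))
    return ordered

found = crawl_titles(["2020 Summer Olympics"], MOCK_LINKS)
```

The `seen` set plays the role of `titles_so_far`: without it, pages that link back to each other would be revisited indefinitely.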

import re
from typing import Set
from transformers import GPT2TokenizerFast

import numpy as np
from nltk.tokenize import sent_tokenize

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def count_tokens(text: str) -> int:
    """count the number of tokens in a string"""
    return len(tokenizer.encode(text))

def reduce_long(
    long_text: str, long_text_tokens: bool = False, max_len: int = 590
) -> str:
    """
    Reduce a long text to a maximum of `max_len` tokens by potentially cutting at a sentence end
    """
    if not long_text_tokens:
        long_text_tokens = count_tokens(long_text)
    if long_text_tokens > max_len:
        sentences = sent_tokenize(long_text.replace("\n", " "))
        ntokens = 0
        for i, sentence in enumerate(sentences):
            ntokens += 1 + count_tokens(sentence)
            if ntokens > max_len:
                # sent_tokenize keeps each sentence's end punctuation, so join with spaces only
                return " ".join(sentences[:i])

    return long_text
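
The truncation idea can be sketched without the tokenizer or NLTK. Everything below is a simplified stand-in under stated assumptions: `count_words` counts whitespace-separated words rather than GPT-2 tokens, `split_sentences` is a naive regex splitter rather than `sent_tokenize`, and `truncate_at_sentence` is a hypothetical name. Unlike `reduce_long`, it does not add one extra unit per sentence.

```python
import re

def count_words(text):
    """Crude stand-in for a tokenizer: count whitespace-separated words."""
    return len(text.split())

def split_sentences(text):
    """Naive splitter: break after '.', '!' or '?' when followed by whitespace."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]

def truncate_at_sentence(text, max_units):
    """Keep whole leading sentences while the running count stays within max_units."""
    if count_words(text) <= max_units:
        return text  # already short enough, return unchanged
    kept, total = [], 0
    for sentence in split_sentences(text):
        total += count_words(sentence)
        if total > max_units:
            break  # this sentence would exceed the budget
        kept.append(sentence)
    return " ".join(kept)

sample = "One two three. Four five six seven. Eight nine."
short = truncate_at_sentence(sample, 7)
```

Cutting at sentence boundaries keeps the retained context readable, at the cost of sometimes using noticeably less than the full budget.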

discard_categories = ['See also', 'References', 'External links', 'Further reading', "Footnotes",
    "Bibliography", "Sources", "Citations", "Literature", "Footnotes", "Notes and references",
    "Photo gallery", "Works cited", "Photos", "Gallery", "Notes", "References and sources",
    "References and notes",]

def extract_sections(
    wiki_text: str,
    title: str,
    max_len: int = 1500,
    discard_categories: Set[str] = discard_categories,
) -> list:
    """
    Extract the sections of a Wikipedia page, discarding the references and other low information sections
    """
    if len(wiki_text) == 0:
        return []

    # find all headings and the corresponding contents
    headings = re.findall("==+ .* ==+", wiki_text)
    for heading in headings:
        wiki_text = wiki_text.replace(heading, "==+ !! ==+")
    contents = wiki_text.split("==+ !! ==+")
    contents = [c.strip() for c in contents]
    assert len(headings) == len(contents) - 1

    cont = contents.pop(0).strip()
    outputs = [(title, "Summary", cont, count_tokens(cont)+4)]
    # discard the discard categories, accounting for a tree structure
    max_level = 100
    keep_group_level = max_level
    remove_group_level = max_level
    nheadings, ncontents = [], []
    for heading, content in zip(headings, contents):
        plain_heading = " ".join(heading.split(" ")[1:-1])
        num_equals = len(heading.split(" ")[0])
        if num_equals <= keep_group_level:
            keep_group_level = max_level

        if num_equals > remove_group_level:
            if (
                num_equals <= keep_group_level
            ):
                continue  # subsection of a discarded heading
        keep_group_level = max_level
        if plain_heading in discard_categories:
            remove_group_level = num_equals
            keep_group_level = max_level
            continue  # drop the discarded heading itself
        nheadings.append(heading.replace("=", "").strip())
        ncontents.append(content)
        remove_group_level = max_level

    # count the tokens of each section
    ncontent_ntokens = [
        count_tokens(c)
        + 3
        + count_tokens(" ".join(h.split(" ")[1:-1]))
        - (1 if len(c) == 0 else 0)
        for h, c in zip(nheadings, ncontents)
    ]
    # Create a tuple of (title, section_name, content, number of tokens)
    # pass max_len by keyword: positionally it would be read as long_text_tokens
    outputs += [(title, h, c, t) if t<max_len 
                else (title, h, reduce_long(c, max_len=max_len), count_tokens(reduce_long(c, max_len=max_len))) 
                    for h, c, t in zip(nheadings, ncontents, ncontent_ntokens)]
    return outputs
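
To see the heading handling in isolation, the sketch below runs the same heading pattern on a small hypothetical text. It is a simplification under stated assumptions: `sample_text` and the one-item `discard` set are made up for illustration, `re.split` is used in place of the replace-then-split approach above, and the subtree skipping is a reduced version of the level bookkeeping in `extract_sections`.

```python
import re

# Hypothetical page text using "== Heading ==" markers, for illustration only.
sample_text = """Intro paragraph.

== History ==
Some history.

=== Early years ===
More detail.

== References ==
A citation."""

heading_pattern = "==+ .* ==+"
headings = re.findall(heading_pattern, sample_text)
bodies = [part.strip() for part in re.split(heading_pattern, sample_text)]

summary = bodies[0]  # text before the first heading
sections = []
for heading, body in zip(headings, bodies[1:]):
    level = len(heading.split(" ")[0])  # number of '=' signs gives the depth
    name = heading.strip("= ").strip()
    sections.append((level, name, body))

# Drop discarded headings together with any deeper subsections below them.
discard = {"References"}
kept = []
skip_below = None  # level of the discarded heading currently being skipped
for level, name, body in sections:
    if skip_below is not None and level > skip_below:
        continue  # subsection of a discarded heading
    skip_below = None
    if name in discard:
        skip_below = level
        continue
    kept.append((level, name, body))
```

Tracking the level of the last discarded heading is what lets a whole subtree, such as subsections under "References", be removed in one pass.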

# Example page being processed into sections
bermuda_page = get_wiki_page('Bermuda at the 2020 Summer Olympics')
ber = extract_sections(bermuda_page.content, bermuda_page.title)

# Example section
('Bermuda at the 2020 Summer Olympics',
 "Bermuda entered one dressage rider into the Olympic competition by finishing in the top four, outside the group selection, of the individual FEI Olympic Rankings for Groups D and E (North, Central, and South America), marking the country's recurrence to the sport after an eight-year absence. The quota was later withdrawn, following an injury of Annabelle Collins' main horse Joyero and a failure to obtain minimum eligibility requirements (MER) aboard a new horse Chuppy Checker.",
res = []
for page in pages:
    res += extract_sections(page.content, page.title)
df = pd.DataFrame(res, columns=["title", "heading", "content", "tokens"])
df = df[df.tokens>40]
df = df.drop_duplicates(['title','heading'])
df = df.reset_index().drop('index',axis=1) # reset index
Token indices sequence length is longer than the specified maximum sequence length for this model (1060 > 1024). Running this sequence through the model will result in indexing errors
   title                 heading                                         content                                             tokens
0  2020 Summer Olympics  Summary                                         The 2020 Summer Olympics (Japanese: 2020年夏季オリン...  713
1  2020 Summer Olympics  Host city selection                             The International Olympic Committee (IOC) vote...   126
2  2020 Summer Olympics  Impact of the COVID-19 pandemic                 In January 2020, concerns were raised about th...   369
3  2020 Summer Olympics  Qualifying event cancellation and postponement  Concerns about the pandemic began to affect qu...   298
4  2020 Summer Olympics  Effect on doping tests                          Mandatory doping tests were being severely res...   163
df.to_csv('olympics-data/olympics_sections.csv', index=False)
df.title.value_counts().head()
Concerns and controversies at the 2020 Summer Olympics    51
United States at the 2020 Summer Olympics                 46
Great Britain at the 2020 Summer Olympics                 42
Canada at the 2020 Summer Olympics                        39
Olympic Games                                             39
Name: title, dtype: int64

The dataset appears to contain pages about both Winter and Summer Olympics for 2020. We chose to leave this ambiguity and noise in the dataset, even though we were interested only in the 2020 Summer Olympics.

df.title.str.contains('Summer').value_counts()
True     3567
False     305
Name: title, dtype: int64
df.title.str.contains('Winter').value_counts()
False    3774
True       98
Name: title, dtype: int64
import pandas as pd
from matplotlib import pyplot as plt

df = pd.read_csv('olympics-data/olympics_sections.csv')
# plot a histogram of the token counts per section
df[['tokens']].hist()
# add axis descriptions and title
plt.xlabel('Number of tokens')
plt.ylabel('Number of Wikipedia sections')
plt.title('Distribution of number of tokens in Wikipedia sections')
plt.show()
image generated by notebook

We can see that the majority of sections are fairly short (less than 500 tokens).
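
A numeric check can complement the visual reading of the histogram. The sketch below is a minimal example with a hypothetical helper, `share_below`; it uses only the five token counts visible in the table preview above, so the resulting share describes those five rows and not the full dataset.

```python
def share_below(token_counts, threshold):
    """Return the fraction of sections whose token count is below threshold."""
    if not token_counts:
        return 0.0
    return sum(1 for n in token_counts if n < threshold) / len(token_counts)

# Token counts of the five rows shown in the table preview above.
preview_tokens = [713, 126, 369, 298, 163]
preview_share = share_below(preview_tokens, 500)  # 4 of the 5 rows are below 500
```

Applied to the full `tokens` column, the same calculation would quantify how large the "majority" of short sections actually is.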