Embedding Wikipedia articles for search

Nov 26, 2024

This notebook shows how we prepared a dataset of Wikipedia articles for search, used in Question_answering_using_embeddings.ipynb.


  0. Prerequisites: Import libraries, set API key (if needed)
  1. Collect: We download a few hundred Wikipedia articles about the 2022 Olympics
  2. Chunk: Documents are split into short, semi-self-contained sections to be embedded
  3. Embed: Each section is embedded with the OpenAI API
  4. Store: Embeddings are saved in a CSV file (for large datasets, use a vector database)
# imports
import mwclient  # for downloading example Wikipedia articles
import mwparserfromhell  # for splitting Wikipedia articles into sections
from openai import OpenAI  # for generating embeddings
import os  # for environment variables
import pandas as pd  # for DataFrames to store article sections and embeddings
import re  # for cutting <ref> links out of Wikipedia articles
import tiktoken  # for counting tokens

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))

Install any missing libraries with pip install in your terminal. E.g.,

pip install openai

(You can also do this in a notebook cell with !pip install openai.)

If you install any libraries, be sure to restart the notebook kernel.

Set API key (if needed)

Note that the OpenAI library will try to read your API key from the OPENAI_API_KEY environment variable. If you haven't already, set this environment variable by following these instructions.
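
A minimal sketch of a pre-flight check, assuming the key is supplied through the OPENAI_API_KEY environment variable as described above. The helper name require_api_key is hypothetical and is not part of the OpenAI library; it only inspects a mapping and makes no network call.

```python
import os
from typing import Mapping, Optional


def require_api_key(
    env: Optional[Mapping[str, str]] = None, name: str = "OPENAI_API_KEY"
) -> str:
    """Return the API key stored under `name`, or raise a clear error if it is missing."""
    env = os.environ if env is None else env
    key = env.get(name, "").strip()
    if not key:
        raise RuntimeError(f"{name} is not set; export it before running the notebook.")
    return key


# Example with an explicit mapping (hypothetical placeholder value, not a real key)
print(require_api_key({"OPENAI_API_KEY": "sk-example"}))
```

Calling require_api_key() with no arguments reads from os.environ, so a missing key fails early with a readable message rather than at the first API request.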

1. Collect documents

In this example, we'll download a few hundred Wikipedia articles related to the 2022 Winter Olympics.

# get Wikipedia pages about the 2022 Winter Olympics

CATEGORY_TITLE = "Category:2022 Winter Olympics"
WIKI_SITE = "en.wikipedia.org"

def titles_from_category(
    category: mwclient.listing.Category, max_depth: int
) -> set[str]:
    """Return a set of page titles in a given Wiki category and its subcategories."""
    titles = set()
    for cm in category.members():
        if type(cm) == mwclient.page.Page:
            # ^type() used instead of isinstance() to catch match w/ no inheritance
            titles.add(cm.name)
        elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:
            deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)
            titles.update(deeper_titles)
    return titles

site = mwclient.Site(WIKI_SITE)
category_page = site.pages[CATEGORY_TITLE]
titles = titles_from_category(category_page, max_depth=1)
# ^note: max_depth=1 means we go one level deep in the category tree
print(f"Found {len(titles)} article titles in {CATEGORY_TITLE}.")
Found 179 article titles in Category:2022 Winter Olympics.
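
To make the effect of max_depth concrete, here is a small sketch of the same depth-limited traversal over an in-memory tree. TOY_TREE and toy_titles are hypothetical stand-ins for the Wikipedia category API, so the example runs without network access.

```python
# Toy category tree (hypothetical data, not real Wikipedia content)
TOY_TREE = {
    "Category:Root": ["Page A", "Category:Child"],
    "Category:Child": ["Page B", "Category:Grandchild"],
    "Category:Grandchild": ["Page C"],
}


def toy_titles(category: str, max_depth: int) -> set[str]:
    """Collect page titles, descending at most `max_depth` levels into subcategories."""
    titles: set[str] = set()
    for member in TOY_TREE.get(category, []):
        if not member.startswith("Category:"):
            titles.add(member)  # a page: keep its title
        elif max_depth > 0:
            titles.update(toy_titles(member, max_depth - 1))  # a subcategory: recurse
    return titles


for depth in range(3):
    print(depth, sorted(toy_titles("Category:Root", depth)))
```

With depth 0 only pages directly in the root category are returned; each extra level adds the pages of one more layer of subcategories.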

2. Chunk documents

Now that we have our reference documents, we need to prepare them for search.

Because GPT can only read a limited amount of text at once, we'll split each document into chunks short enough to be read.

For this specific example on Wikipedia articles, we'll:

  • Discard less relevant-looking sections like External Links and Footnotes
  • Clean up the text by removing reference tags (e.g., <ref>), whitespace, and super short sections
  • Split each article into sections
  • Prepend titles and subtitles to each section's text, to help GPT understand the context
  • If a section is long (say, > 1,600 tokens), we'll recursively split it into smaller sections, trying to split along semantic boundaries like paragraphs
# define functions to split Wikipedia pages into sections

SECTIONS_TO_IGNORE = {
    "See also",
    "External links",
    "Further reading",
    "Notes and references",
    "Photo gallery",
    "Works cited",
    "References and sources",
    "References and notes",
}

def all_subsections_from_section(
    section: mwparserfromhell.wikicode.Wikicode,
    parent_titles: list[str],
    sections_to_ignore: set[str],
) -> list[tuple[list[str], str]]:
    """
    From a Wikipedia section, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    headings = [str(h) for h in section.filter_headings()]
    title = headings[0]
    if title.strip("=" + " ") in sections_to_ignore:
        # ^wiki headings are wrapped like "== Heading =="
        return []
    titles = parent_titles + [title]
    full_text = str(section)
    section_text = full_text.split(title)[1]
    if len(headings) == 1:
        # no nested headings: the whole section is a single leaf
        return [(titles, section_text)]
    else:
        # keep only the text before the first nested heading, then recurse into children
        first_subtitle = headings[1]
        section_text = section_text.split(first_subtitle)[0]
        results = [(titles, section_text)]
        for subsection in section.get_sections(levels=[len(titles) + 1]):
            results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))
        return results

def all_subsections_from_title(
    title: str,
    sections_to_ignore: set[str] = SECTIONS_TO_IGNORE,
    site_name: str = WIKI_SITE,
) -> list[tuple[list[str], str]]:
    """From a Wikipedia page title, return a flattened list of all nested subsections.
    Each subsection is a tuple, where:
        - the first element is a list of parent subtitles, starting with the page title
        - the second element is the text of the subsection (but not any children)
    """
    site = mwclient.Site(site_name)
    page = site.pages[title]
    text = page.text()
    parsed_text = mwparserfromhell.parse(text)
    headings = [str(h) for h in parsed_text.filter_headings()]
    if headings:
        summary_text = str(parsed_text).split(headings[0])[0]
    else:
        summary_text = str(parsed_text)
    results = [([title], summary_text)]
    for subsection in parsed_text.get_sections(levels=[2]):
        results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))
    return results
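
The heading check used above can be tried in isolation. This is a minimal sketch: IGNORED is a small illustrative subset of the ignore list, and heading_name and is_ignored are hypothetical helper names, not functions from mwparserfromhell.

```python
IGNORED = {"See also", "External links"}  # small illustrative subset


def heading_name(raw_heading: str) -> str:
    """Strip the '=' markers and surrounding spaces from a raw wikitext heading."""
    return raw_heading.strip("= ")


def is_ignored(raw_heading: str, ignored: set[str] = IGNORED) -> bool:
    """Return True if the normalized heading appears in the ignore set."""
    return heading_name(raw_heading) in ignored


for raw in ["==See also==", "== External links ==", "===Cost and climate==="]:
    print(repr(heading_name(raw)), is_ignored(raw))
```

The comparison is case-sensitive, so a heading spelled "External Links" would not match the entry "External links".
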
# split pages into sections
# may take ~1 minute per 100 articles
wikipedia_sections = []
for title in titles:
    wikipedia_sections.extend(all_subsections_from_title(title))
print(f"Found {len(wikipedia_sections)} sections in {len(titles)} pages.")
Found 1838 sections in 179 pages.
# clean text
def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:
    """
    Return a cleaned up section with:
        - <ref>xyz</ref> patterns removed
        - leading/trailing whitespace removed
    """
    titles, text = section
    text = re.sub(r"<ref.*?</ref>", "", text)
    text = text.strip()
    return (titles, text)

wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]

# filter out short/blank sections
def keep_section(section: tuple[list[str], str]) -> bool:
    """Return True if the section should be kept, False otherwise."""
    titles, text = section
    if len(text) < 16:
        return False
    else:
        return True

original_num_sections = len(wikipedia_sections)
wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]
print(f"Filtered out {original_num_sections-len(wikipedia_sections)} sections, leaving {len(wikipedia_sections)} sections.")
Filtered out 89 sections, leaving 1749 sections.
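
The pattern in clean_section targets paired <ref>...</ref> tags and, without re.DOTALL, does not match references that span several lines. The sketch below is one possible variant, not the notebook's method: it also handles self-closing tags such as <ref name="n" /> and multi-line references. REF_PATTERN and strip_refs are hypothetical names.

```python
import re

# Self-closing tags are tried first, then paired tags; DOTALL lets '.' span line breaks.
REF_PATTERN = re.compile(r"<ref\b[^>]*/>|<ref\b.*?</ref>", flags=re.DOTALL)


def strip_refs(text: str) -> str:
    """Remove <ref>...</ref> and <ref ... /> tags, then trim surrounding whitespace."""
    return REF_PATTERN.sub("", text).strip()


sample = 'Gold<ref>Source one</ref> medal<ref name="n" /> count<ref>multi\nline</ref>. '
print(strip_refs(sample))  # prints: Gold medal count.
```

The order of the alternatives matters: if the paired-tag branch came first, a self-closing tag could be matched together with everything up to the next closing </ref>.
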
# print example data
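for ws in wikipedia_sections[:5]:
    print(ws[0])  # list of titles, starting with the page title
    display(ws[1][:77] + "...")  # first 77 characters of the section text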
['Concerns and controversies at the 2022 Winter Olympics']
'{{Short description|Overview of concerns and controversies surrounding the Ga...'
['Concerns and controversies at the 2022 Winter Olympics', '==Criticism of host selection==']
'American sportscaster [[Bob Costas]] criticized the [[International Olympic C...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Cost and climate===']
'Several cities withdrew their applications during [[Bids for the 2022 Winter ...'
['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Promotional song===']
'Some commentators alleged that one of the early promotional songs for the [[2...'
['Concerns and controversies at the 2022 Winter Olympics', '== Diplomatic boycotts or non-attendance ==']
'<section begin=boycotts />\n[[File:2022 Winter Olympics (Beijing) diplomatic b...'

Next, we'll recursively split long sections into smaller sections.

There's no perfect recipe for splitting text into sections.

Some tradeoffs include:

  • Longer sections may be better for questions that require more context
  • Longer sections may be worse for retrieval, as they may have more topics muddled together
  • Shorter sections are better for reducing costs (which are proportional to the number of tokens)
  • Shorter sections allow more sections to be retrieved, which may help with recall
  • Overlapping sections may help prevent answers from being cut by section boundaries
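
The last tradeoff, overlapping sections, is not used in this notebook. As an illustration only, here is a sketch of grouping paragraphs into windows that share some paragraphs with their neighbors; overlapping_windows is a hypothetical helper and it counts paragraphs rather than tokens.

```python
def overlapping_windows(paragraphs: list[str], size: int, overlap: int) -> list[list[str]]:
    """Group paragraphs into windows of `size`, each sharing `overlap` paragraphs with the previous one."""
    if size <= 0 or not 0 <= overlap < size:
        raise ValueError("need size > 0 and 0 <= overlap < size")
    step = size - overlap
    windows = []
    for start in range(0, len(paragraphs), step):
        windows.append(paragraphs[start : start + size])
        if start + size >= len(paragraphs):
            break  # the last window already reaches the end
    return windows


paras = ["p1", "p2", "p3", "p4", "p5"]
print(overlapping_windows(paras, size=3, overlap=1))
# prints: [['p1', 'p2', 'p3'], ['p3', 'p4', 'p5']]
```

Overlap increases the number of tokens embedded, so it trades higher cost for a lower chance that an answer is cut at a boundary.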

Here, we'll use a simple approach and limit sections to 1,600 tokens each, recursively halving any sections that are too long. To avoid cutting in the middle of useful sentences, we'll split along paragraph boundaries when possible.
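
The halving step can be sketched without a tokenizer. The example below assumes a crude word count (approx_tokens) as a stand-in for real token counts, and balanced_halves is a hypothetical name; it illustrates choosing the paragraph boundary that best balances the two sides.

```python
def approx_tokens(text: str) -> int:
    """Crude stand-in for a tokenizer: count whitespace-separated words."""
    return len(text.split())


def balanced_halves(text: str, delimiter: str = "\n\n") -> tuple[str, str]:
    """Split `text` once on `delimiter`, choosing the cut that best balances both sides."""
    chunks = text.split(delimiter)
    if len(chunks) == 1:
        return text, ""  # no delimiter found
    total = approx_tokens(text)
    best_cut, best_diff = 1, None
    for cut in range(1, len(chunks)):
        left_size = approx_tokens(delimiter.join(chunks[:cut]))
        diff = abs(total / 2 - left_size)
        if best_diff is None or diff < best_diff:
            best_cut, best_diff = cut, diff
    return delimiter.join(chunks[:best_cut]), delimiter.join(chunks[best_cut:])


text = "one two three\n\nfour five\n\nsix seven eight nine ten"
left, right = balanced_halves(text)
print(repr(left))   # prints: 'one two three\n\nfour five'
print(repr(right))  # prints: 'six seven eight nine ten'
```

Applying the same split recursively to each half, until every piece fits the token limit, gives the overall strategy described above.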

GPT_MODEL = "gpt-4o-mini"  # only matters insofar as it selects which tokenizer to use

def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))

def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str]:
    """Split a string in two, on a delimiter, trying to balance tokens on each side."""
    chunks = string.split(delimiter)
    if len(chunks) == 1:
        return [string, ""]  # no delimiter found
    elif len(chunks) == 2:
        return chunks  # no need to search for halfway point
    else:
        total_tokens = num_tokens(string)
        halfway = total_tokens // 2
        best_diff = halfway
        for i, chunk in enumerate(chunks):
            left = delimiter.join(chunks[: i + 1])
            left_tokens = num_tokens(left)
            diff = abs(halfway - left_tokens)
            if diff >= best_diff:
                # adding this chunk no longer moves us closer to the halfway point
                break
            else:
                best_diff = diff
        left = delimiter.join(chunks[:i])
        right = delimiter.join(chunks[i:])
        return [left, right]

def truncated_string(
    string: str,
    model: str,
    max_tokens: int,
    print_warning: bool = True,
) -> str:
    """Truncate a string to a maximum number of tokens."""
    encoding = tiktoken.encoding_for_model(model)
    encoded_string = encoding.encode(string)
    truncated_string = encoding.decode(encoded_string[:max_tokens])
    if print_warning and len(encoded_string) > max_tokens:
        print(f"Warning: Truncated string from {len(encoded_string)} tokens to {max_tokens} tokens.")
    return truncated_string

def split_strings_from_subsection(
    subsection: tuple[list[str], str],
    max_tokens: int = 1000,
    model: str = GPT_MODEL,
    max_recursion: int = 5,
) -> list[str]:
    """
    Split a subsection into a list of subsections, each with no more than max_tokens.
    Each subsection is a tuple of parent titles [H1, H2, ...] and text (str).
    """
    titles, text = subsection
    string = "\n\n".join(titles + [text])
    num_tokens_in_string = num_tokens(string)
    # if length is fine, return string
    if num_tokens_in_string <= max_tokens:
        return [string]
    # if recursion hasn't found a split after X iterations, just truncate
    elif max_recursion == 0:
        return [truncated_string(string, model=model, max_tokens=max_tokens)]
    # otherwise, split in half and recurse
    else:
        titles, text = subsection
        for delimiter in ["\n\n", "\n", ". "]:
            left, right = halved_by_delimiter(text, delimiter=delimiter)
            if left == "" or right == "":
                # if either half is empty, retry with a more fine-grained delimiter
                continue
            else:
                # recurse on each half
                results = []
                for half in [left, right]:
                    half_subsection = (titles, half)
                    half_strings = split_strings_from_subsection(
                        half_subsection,
                        max_tokens=max_tokens,
                        model=model,
                        max_recursion=max_recursion - 1,
                    )
                    results.extend(half_strings)
                return results
    # otherwise no split was found, so just truncate (should be very rare)
    return [truncated_string(string, model=model, max_tokens=max_tokens)]
# split sections into chunks
MAX_TOKENS = 1600
wikipedia_strings = []
for section in wikipedia_sections:
    wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))

print(f"{len(wikipedia_sections)} Wikipedia sections split into {len(wikipedia_strings)} strings.")
1749 Wikipedia sections split into 2052 strings.
# print example data
print(wikipedia_strings[1])
Concerns and controversies at the 2022 Winter Olympics

==Criticism of host selection==

American sportscaster [[Bob Costas]] criticized the [[International Olympic Committee]]'s (IOC) decision to award the games to China saying "The IOC deserves all of the disdain and disgust that comes their way for going back to China yet again" referencing China's human rights record.

After winning two gold medals and returning to his home country of Sweden skater [[Nils van der Poel]] criticized the IOC's selection of China as the host saying "I think it is extremely irresponsible to give it to a country that violates human rights as blatantly as the Chinese regime is doing." He had declined to criticize China before leaving for the games saying "I don't think it would be particularly wise for me to criticize the system I'm about to transition to, if I want to live a long and productive life."

3. Embed document chunks

Now that the articles are split into shorter, semi-self-contained strings, we can compute an embedding for each one with the OpenAI API.

EMBEDDING_MODEL = "text-embedding-3-small"
BATCH_SIZE = 1000  # you can submit up to 2048 embedding inputs per request

embeddings = []
for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):
    batch_end = batch_start + BATCH_SIZE
    batch = wikipedia_strings[batch_start:batch_end]
    print(f"Batch {batch_start} to {batch_end-1}")
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
    for i, be in enumerate(response.data):
        assert i == be.index  # double check embeddings are in same order as input
    batch_embeddings = [e.embedding for e in response.data]
    embeddings.extend(batch_embeddings)

df = pd.DataFrame({"text": wikipedia_strings, "embedding": embeddings})
Batch 0 to 999
Batch 1000 to 1999
Batch 2000 to 2999
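
Note that the log line reports the nominal end of each batch: with 2,052 strings, the last slice stops at index 2051 even though the message says 2999. A small sketch of computing the actual inclusive ranges, assuming nothing beyond the list length and batch size (batch_ranges is a hypothetical helper):

```python
def batch_ranges(n_items: int, batch_size: int) -> list[tuple[int, int]]:
    """Return inclusive (first, last) index pairs covering n_items in batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [
        (start, min(start + batch_size, n_items) - 1)
        for start in range(0, n_items, batch_size)
    ]


print(batch_ranges(2052, 1000))
# prints: [(0, 999), (1000, 1999), (2000, 2051)]
```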

4. Store document chunks and embeddings

Because this example only uses a few thousand strings, we'll store them in a CSV file.

(For larger datasets, use a vector database, which will be more performant.)

# save document chunks and embeddings

SAVE_PATH = "data/winter_olympics_2022.csv"

df.to_csv(SAVE_PATH, index=False)
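
One thing to keep in mind when reading the file back: a list stored in a CSV cell is saved as its string form, so the embedding column has to be parsed before use. The sketch below shows the round trip with the standard library only; the row contents are hypothetical and the three-number vector stands in for a real embedding.

```python
import ast
import csv
import io

# Write one row the way a list-valued cell ends up in a CSV file: as its string form.
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["text", "embedding"])
writer.writerow(["example chunk", str([0.1, -0.2, 0.3])])  # hypothetical 3-dimensional vector

# Read it back: the embedding column is a string until it is parsed.
buffer.seek(0)
rows = list(csv.DictReader(buffer))
raw = rows[0]["embedding"]
vector = ast.literal_eval(raw)  # safely parse the list literal

print(type(raw).__name__, type(vector).__name__, vector)
# prints: str list [0.1, -0.2, 0.3]
```

With pandas, the same parsing can be applied to the column after pd.read_csv, for example with df["embedding"].apply(ast.literal_eval).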