Using Qdrant for embeddings search

Jun 28, 2023

This notebook takes you through a simple flow to download some data, embed it, and then index and search it using the Qdrant vector database. This is a common requirement for customers who want to store and search our embeddings with their own data in a secure environment to support production use cases such as chatbots, topic modelling and more.

What is a Vector Database

A vector database is a database made to store, manage and search embedding vectors. The use of embeddings to encode unstructured data (text, audio, video and more) as vectors for consumption by machine-learning models has exploded in recent years, due to the increasing effectiveness of AI in solving use cases involving natural language, image recognition and other unstructured forms of data. Vector databases have emerged as an effective solution for enterprises to deliver and scale these use cases.
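
To make the core operation concrete, here is a minimal sketch of what a vector search does: rank stored vectors by cosine similarity to a query vector. The three-dimensional vectors and the names `cosine_similarity` and `nearest` are made up for illustration; real embeddings have far more dimensions, and a vector database typically replaces this brute-force scan with an index.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def nearest(query, vectors, top_k=2):
    """Return the top_k (key, score) pairs, best match first."""
    scored = [(key, cosine_similarity(query, vec)) for key, vec in vectors.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]

# Toy "embeddings" (hypothetical values, not produced by any model)
toy_vectors = {
    "cat": [0.9, 0.1, 0.0],
    "dog": [0.8, 0.2, 0.1],
    "car": [0.0, 0.1, 0.9],
}

results = nearest([1.0, 0.0, 0.0], toy_vectors)
print(results)
```

The query points almost entirely along the first axis, so "cat" and "dog" rank ahead of "car".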

Why use a Vector Database

Vector databases enable enterprises to take many of the embeddings use cases we've shared in this repo (question answering, chatbot and recommendation services, for example), and make use of them in a secure, scalable environment. Many of our customers use embeddings to solve their problems at small scale, but performance and security concerns hold them back from going into production. We see vector databases as a key component in solving that, and in this guide we'll walk through the basics of embedding text data, storing it in a vector database and using it for semantic search.

Demo Flow

The demo flow is:

  • Setup: Import packages and set any required variables
  • Load data: Load a dataset and embed it using OpenAI embeddings
  • Qdrant
    • Setup: Here we'll set up the Python client for Qdrant. For more details, see the Qdrant documentation
    • Index Data: We'll create a collection with vectors for titles and content
    • Search Data: We'll run a few searches to confirm it works

Once you've run through this notebook you should have a basic understanding of how to set up and use vector databases, and can move on to more complex use cases making use of our embeddings.

Setup

Import the required libraries and set the embedding model that we'd like to use.

# We'll need to install Qdrant client
!pip install qdrant-client

# Install wget to pull the zip file
!pip install wget
Collecting qdrant-client
  ...
Successfully installed certifi-2023.5.7 grpcio-1.56.0 grpcio-tools-1.56.0 h11-0.14.0 h2-4.1.0 hpack-4.0.0 httpcore-0.17.2 httpx-0.24.1 hyperframe-6.0.1 numpy-1.25.0 portalocker-2.7.0 protobuf-4.23.3 pydantic-1.10.9 qdrant-client-1.3.1 typing-extensions-4.5.0 urllib3-1.26.16
Collecting wget
  Using cached wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... done
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9657 sha256=eb5f15f12150fc304e7b14973424f696fa8d95225772bc0cbc0b318bf92e04b9
  Stored in directory: /home/user/.cache/pip/wheels/04/5f/3e/46cc37c5d698415694d83f607f833f83f0149e49b3af9d0f38
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
import openai

from typing import List, Iterator
import pandas as pd
import numpy as np
import os
import wget
from ast import literal_eval

# Qdrant's client library for Python
import qdrant_client

# Set to our new embeddings model; this can be changed to the embedding model of your choice,
# as long as queries are embedded with the same model that produced the stored vectors
EMBEDDING_MODEL = "text-embedding-3-small"

# Ignore unclosed SSL socket warnings - optional in case you get these errors
import warnings

warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning) 

Load data

In this section we'll load embedded data that we prepared prior to this session.

embeddings_url = 'https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip'

# The file is ~700 MB so this will take some time
wget.download(embeddings_url)
'vector_database_wikipedia_articles_embedded.zip'
import zipfile
with zipfile.ZipFile("vector_database_wikipedia_articles_embedded.zip","r") as zip_ref:
    zip_ref.extractall("../data")
article_df = pd.read_csv('../data/vector_database_wikipedia_articles_embedded.csv')
article_df.head()
   id  url                                       title   text                                               title_vector                                       content_vector                                     vector_id
0   1  https://simple.wikipedia.org/wiki/April   April   April is the fourth month of the year in the J...  [0.001009464613161981, -0.020700545981526375, ...  [-0.011253940872848034, -0.013491976074874401,...  0
1   2  https://simple.wikipedia.org/wiki/August  August  August (Aug.) is the eighth month of the year ...  [0.0009286514250561595, 0.000820168002974242, ...  [0.0003609954728744924, 0.007262262050062418, ...  1
2   6  https://simple.wikipedia.org/wiki/Art     Art     Art is a creative activity that expresses imag...  [0.003393713850528002, 0.0061537534929811954, ...  [-0.004959689453244209, 0.015772193670272827, ...  2
3   8  https://simple.wikipedia.org/wiki/A       A       A or a is the first letter of the English alph...  [0.0153952119871974, -0.013759135268628597, 0....  [0.024894846603274345, -0.022186409682035446, ...  3
4   9  https://simple.wikipedia.org/wiki/Air     Air     Air refers to the Earth's atmosphere. Air is a...  [0.02224554680287838, -0.02044147066771984, -0...  [0.021524671465158463, 0.018522677943110466, -...  4
# Read vectors from strings back into a list
article_df['title_vector'] = article_df.title_vector.apply(literal_eval)
article_df['content_vector'] = article_df.content_vector.apply(literal_eval)

# Set vector_id to be a string
article_df['vector_id'] = article_df['vector_id'].apply(str)
article_df.info(show_counts=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25000 entries, 0 to 24999
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   id              25000 non-null  int64 
 1   url             25000 non-null  object
 2   title           25000 non-null  object
 3   text            25000 non-null  object
 4   title_vector    25000 non-null  object
 5   content_vector  25000 non-null  object
 6   vector_id       25000 non-null  object
dtypes: int64(1), object(6)
memory usage: 1.3+ MB
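
The `literal_eval` step above is needed because a CSV file stores each vector as text. A minimal sketch on a single made-up cell value (the numbers are illustrative, not taken from the dataset):

```python
from ast import literal_eval

# A vector as it would appear in a CSV cell: one long string
raw_cell = "[0.25, -0.5, 0.125]"

# literal_eval parses Python literals only; it does not execute arbitrary code
vector = literal_eval(raw_cell)

print(type(raw_cell).__name__, "->", type(vector).__name__)
print(vector, len(vector))
```

After this conversion each cell holds a list of floats, which is the form the Qdrant client expects for a vector.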

Qdrant

Qdrant is a high-performance vector search database written in Rust. It offers both on-premises and cloud versions, but for the purposes of this example we're going to use the local deployment mode.

Setting everything up will require:

  • Spinning up a local instance of Qdrant
  • Configuring the collection and storing the data in it
  • Trying it out with some queries

Setup

For the local deployment, we are going to use Docker, according to the Qdrant documentation: https://qdrant.tech/documentation/quick_start/. Qdrant requires just a single container, but an example of the docker-compose.yaml file is available at ./qdrant/docker-compose.yaml in this repo.

You can start a Qdrant instance locally by navigating to this directory and running docker-compose up -d

You might need to increase the memory limit for Docker to 8GB or more. Otherwise, Qdrant might fail to execute with an error message like 7 Killed.
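
Before creating the client, it can help to confirm that the container is actually accepting connections. The following is a minimal sketch using only the standard library; `is_port_open` is a hypothetical helper, and the port numbers assume Qdrant's defaults (6333 for HTTP, 6334 for gRPC), which your docker-compose.yaml may map differently.

```python
import socket

def is_port_open(host, port, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts and name-resolution failures
        return False

# Assumed default Qdrant ports; adjust if your compose file maps them differently
for name, port in [("HTTP", 6333), ("gRPC", 6334)]:
    status = "reachable" if is_port_open("localhost", port) else "not reachable"
    print(f"Qdrant {name} port {port}: {status}")
```

An open port only shows that something is listening; the `get_collections()` call that follows is the real check that the service is Qdrant and is healthy.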

qdrant = qdrant_client.QdrantClient(host='localhost', prefer_grpc=True)
qdrant.get_collections()
CollectionsResponse(collections=[CollectionDescription(name='Routines')])

Index data

Qdrant stores data in collections, where each object is described by at least one vector and may contain additional metadata called a payload. Our collection will be called Articles and each object will be described by both title and content vectors.

We'll be using the official qdrant-client package, which has all the utility methods already built in.
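
As a rough mental model, each point can be pictured as a record with three parts: an id, one or more named vectors, and a payload. The sketch below builds such a record as a plain dictionary from a made-up row; `row_to_point` is a hypothetical helper, not part of qdrant-client. It also leaves the vector columns out of the payload, a design choice that avoids storing each vector twice (the notebook's own upsert keeps the full row as payload).

```python
def row_to_point(point_id, row, vector_columns):
    """Split a row into named vectors and a payload holding everything else."""
    vectors = {name: row[column] for name, column in vector_columns.items()}
    excluded = set(vector_columns.values())
    payload = {key: value for key, value in row.items() if key not in excluded}
    return {"id": point_id, "vector": vectors, "payload": payload}

# Made-up row with tiny vectors, shaped like the article data
row = {
    "id": 1,
    "title": "April",
    "title_vector": [0.1, 0.2],
    "content_vector": [0.3, 0.4],
}

point = row_to_point(0, row, {"title": "title_vector", "content": "content_vector"})
print(point)
```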

from qdrant_client.http import models as rest
vector_size = len(article_df['content_vector'][0])

qdrant.recreate_collection(
    collection_name='Articles',
    vectors_config={
        'title': rest.VectorParams(
            distance=rest.Distance.COSINE,
            size=vector_size,
        ),
        'content': rest.VectorParams(
            distance=rest.Distance.COSINE,
            size=vector_size,
        ),
    }
)
True
qdrant.upsert(
    collection_name='Articles',
    points=[
        rest.PointStruct(
            id=k,
            vector={
                'title': v['title_vector'],
                'content': v['content_vector'],
            },
            payload=v.to_dict(),
        )
        for k, v in article_df.iterrows()
    ],
)
UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)
# Check the collection size to make sure all the points have been stored
qdrant.count(collection_name='Articles')
CountResult(count=25000)
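
Sending all 25,000 points in a single upsert call worked here, but larger datasets are often easier to load in batches. A minimal sketch of the batching logic, independent of Qdrant: `chunked` is a hypothetical helper, and the batch size of 1000 is an arbitrary assumption. Each batch could then be passed as the points argument of a separate upsert call.

```python
from itertools import islice

def chunked(iterable, size):
    """Yield successive lists of at most `size` items from `iterable`."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch

# Stand-in for the 25,000 points built from article_df
point_ids = range(25000)

batch_sizes = [len(batch) for batch in chunked(point_ids, 1000)]
print(len(batch_sizes), "batches")
```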

Search Data

Once the data is in Qdrant, we can start querying the collection for the closest vectors. The optional vector_name parameter switches the search from title-based to content-based.

def query_qdrant(query, collection_name, vector_name='title', top_k=20):

    # Creates embedding vector from user query
    embedded_query = openai.Embedding.create(
        input=query,
        model=EMBEDDING_MODEL,
    )['data'][0]['embedding']
    
    query_results = qdrant.search(
        collection_name=collection_name,
        query_vector=(
            vector_name, embedded_query
        ),
        limit=top_k,
    )
    
    return query_results
query_results = query_qdrant('modern art in Europe', 'Articles')
for i, article in enumerate(query_results):
    print(f'{i + 1}. {article.payload["title"]} (Score: {round(article.score, 3)})')
1. Museum of Modern Art (Score: 0.875)
2. Western Europe (Score: 0.867)
3. Renaissance art (Score: 0.864)
4. Pop art (Score: 0.86)
5. Northern Europe (Score: 0.855)
6. Hellenistic art (Score: 0.853)
7. Modernist literature (Score: 0.847)
8. Art film (Score: 0.843)
9. Central Europe (Score: 0.843)
10. European (Score: 0.841)
11. Art (Score: 0.841)
12. Byzantine art (Score: 0.841)
13. Postmodernism (Score: 0.84)
14. Eastern Europe (Score: 0.839)
15. Europe (Score: 0.839)
16. Cubism (Score: 0.839)
17. Impressionism (Score: 0.838)
18. Bauhaus (Score: 0.838)
19. Expressionism (Score: 0.837)
20. Surrealism (Score: 0.837)
# This time we'll query using content vector
query_results = query_qdrant('Famous battles in Scottish history', 'Articles', 'content')
for i, article in enumerate(query_results):
    print(f'{i + 1}. {article.payload["title"]} (Score: {round(article.score, 3)})')
1. Battle of Bannockburn (Score: 0.869)
2. Wars of Scottish Independence (Score: 0.861)
3. 1651 (Score: 0.853)
4. First War of Scottish Independence (Score: 0.85)
5. Robert I of Scotland (Score: 0.846)
6. 841 (Score: 0.844)
7. 1716 (Score: 0.844)
8. 1314 (Score: 0.837)
9. 1263 (Score: 0.836)
10. William Wallace (Score: 0.835)
11. Stirling (Score: 0.831)
12. 1306 (Score: 0.831)
13. 1746 (Score: 0.831)
14. 1040s (Score: 0.828)
15. 1106 (Score: 0.827)
16. 1304 (Score: 0.827)
17. David II of Scotland (Score: 0.825)
18. Braveheart (Score: 0.824)
19. 1124 (Score: 0.824)
20. July 27 (Score: 0.823)
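
The `openai.Embedding.create` call used in `query_qdrant` belongs to versions of the openai Python package before 1.0. In version 1.0 and later, the equivalent request goes through a client object as `client.embeddings.create`. The sketch below shows that call shape; `embed_query` is a hypothetical helper, and a stub client stands in for `openai.OpenAI()` so the example runs without an API key.

```python
from types import SimpleNamespace

def embed_query(client, query, model):
    """Return the embedding vector for `query` using an openai>=1.0 style client."""
    response = client.embeddings.create(input=query, model=model)
    return response.data[0].embedding

# Stub that mimics the response shape; replace with openai.OpenAI() for real calls
def fake_create(input, model):
    return SimpleNamespace(data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])])

stub_client = SimpleNamespace(embeddings=SimpleNamespace(create=fake_create))

vector = embed_query(stub_client, "modern art in Europe", "text-embedding-3-small")
print(vector)
```

Passing the client in as an argument keeps the helper easy to test and leaves API key handling to the caller.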