Image Similarity Search

In this tutorial notebook, we will walk you through how to create an image similarity search backend service. You will learn how to use the pre-trained SqueezeNet embedding model from torchvision to transform image data into vector embeddings. You will build an index with Pinecone to store these vector embeddings. Lastly, you will learn how to send a new image as a query and retrieve similar images from the index.


The structure of this notebook is as follows:

  1. Install dependencies and set up Pinecone.
  2. Download an image dataset (e.g., Tiny ImageNet).
  3. Download a pre-trained computer vision (CV) model (e.g., SqueezeNet).
  4. Using the CV model, convert images to their vector embeddings.
  5. Create a Pinecone vector index, which serves as the backend service.
  6. Upload and index your image vectors.
  7. Perform image similarity search and review the results.
  8. Delete the index.

Prerequisites

Install dependencies.

!pip install -qU pip pinecone-client torchvision pandas matplotlib

Set up Pinecone. You can get your Pinecone API key from the Pinecone console.

import pinecone
import os

api_key = os.getenv("PINECONE_API_KEY") or "YOUR_API_KEY"
pinecone.init(api_key=api_key)

Check Pinecone version compatibility.

import pinecone.info

version_info = pinecone.info.version()
server_version = ".".join(version_info.server.split(".")[:2])
client_version = ".".join(version_info.client.split(".")[:2])
notebook_version = "0.8"

assert (
    notebook_version == server_version
), "This notebook is outdated. Consider using the latest version of the notebook."
assert client_version == server_version, "Please upgrade pinecone-client."

Data Preparation

In this example, we will use the Tiny ImageNet dataset, which has 100,000 training images across 200 classes. Each class has 500 training images.

import requests, os, zipfile

DATA_DIR = "tmp"
IMAGENET_DIR = f"{DATA_DIR}/tiny-imagenet-200"
IMAGENET_ZIP = f"{DATA_DIR}/tiny-imagenet-200.zip"
IMAGENET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"


def download_data():
    os.makedirs(DATA_DIR, exist_ok=True)

    if not os.path.exists(IMAGENET_DIR):
        if not os.path.exists(IMAGENET_ZIP):
            r = requests.get(IMAGENET_URL)  # download the dataset archive
            with open(IMAGENET_ZIP, "wb") as f:
                f.write(r.content)

        with zipfile.ZipFile(IMAGENET_ZIP, "r") as zip_ref:
            zip_ref.extractall(DATA_DIR)


download_data()

torchvision is a popular library from PyTorch (torch) that provides convenient functions for loading and transforming images. Here we use torchvision to select images that belong to a random subset of the image classes.

from torchvision import datasets
import random

random_seed = 123
random.seed(random_seed)

# Select a random sample of image classes
image_classes = set(random.sample(range(200), 5))

# Get the corresponding image file names
image_file_names = [
    file_name
    for file_name, label in datasets.ImageFolder(f"{IMAGENET_DIR}/train").imgs
    if label in image_classes
]

Let’s look at some of the selected images. We’ll write two utility functions to show the images, which will also be useful later.

import matplotlib.pyplot as plt
from PIL import Image


def show_images_horizontally(file_names):
    m = len(file_names)
    fig, ax = plt.subplots(1, m)
    fig.set_figwidth(1.5 * m)
    for a, f in zip(ax, file_names):
        a.imshow(Image.open(f))
        a.axis("off")
    plt.show()


def show_image(file_name):
    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(1.3)
    ax.imshow(Image.open(file_name))
    ax.axis("off")
    plt.show()


for i in range(5):
    show_images_horizontally(random.sample(image_file_names, 10))

Sample images from the dataset (five rows of ten random images each)

Convert Images to Embeddings

Vector embeddings for images are representations of images as high-dimensional vectors.

A typical image file consists of pixel intensity values in three channels: red, green, and blue. A naïve way of converting an image file to a vector is by “flattening” the image, that is, arranging all pixel values in one long row. Sadly, this is practically useless for most machine learning applications.

A much better way of creating a vector representation of an image is to apply a computer vision (CV) model to the image. The output of the CV model is a vector embedding of the image optimized for assessing image similarity. In other words, the CV model maps similar images to vectors that are close to each other, and dissimilar images to vectors that are far apart. By converting images to vector embeddings, we reduce the problem of image similarity search to that of vector proximity search.
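To see why raw pixel vectors fail, consider what a tiny, visually meaningless change does to them. Here is a minimal sketch (using a synthetic random image rather than the dataset) showing that shifting an image by a single pixel pushes two near-identical images far apart in flattened-pixel space:

import numpy as np

# Toy illustration: shift a synthetic image right by one pixel and
# compare the flattened-pixel vectors under euclidean distance.
rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))            # fake 64x64 RGB image
shifted = np.roll(image, shift=1, axis=1)  # near-identical content

a = image.reshape(-1)  # "flattened" vector, length 64*64*3 = 12288
b = shifted.reshape(-1)
print(np.linalg.norm(a - b))  # large distance despite similar images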

In practice, we don’t always have to train a new CV model. In this example, we will use a pre-trained model from torchvision called SqueezeNet to convert the images. Feel free to experiment with other models provided by torchvision to see which one best fits your application; an example of swapping in a different backbone follows the code below. See the torchvision models page.

from torchvision import transforms as ts
import torchvision.models as models


class ImageEmbedder:
    def __init__(self):
        # normalization constants for models pre-trained on ImageNet
        self.normalize = ts.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        # see https://pytorch.org/vision/0.8/models.html for many more model options
        self.model = models.squeezenet1_0(pretrained=True, progress=False)  # squeezenet
        self.model.eval()  # disable dropout so embeddings are deterministic

    def embed(self, image_file_name):
        image = Image.open(image_file_name).convert("RGB")
        image = ts.Resize(256)(image)
        image = ts.CenterCrop(224)(image)
        tensor = ts.ToTensor()(image)
        tensor = self.normalize(tensor).reshape(1, 3, 224, 224)
        vector = self.model(tensor).cpu().detach().numpy().flatten()
        return vector


image_embedder = ImageEmbedder()
Downloading: "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_0-a815701f.pth
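As mentioned above, other torchvision models can be swapped in. As one hedged example (resnet18 is our own choice here, not part of this tutorial), you could take a ResNet’s penultimate activations as a 512-dimensional embedding by replacing its classification head with an identity layer:

import torch
import torchvision.models as models

# Hypothetical alternative backbone: resnet18 with its final fully
# connected layer replaced, so the 512-d penultimate features become
# the embedding instead of squeezenet's 1000-d class logits.
backbone = models.resnet18(pretrained=True, progress=False)
backbone.fc = torch.nn.Identity()  # drop the classification head
backbone.eval()

with torch.no_grad():
    dummy = torch.zeros(1, 3, 224, 224)  # one normalized image tensor
    print(backbone(dummy).shape)         # torch.Size([1, 512])

Note that changing the model changes the embedding dimensionality, which in turn changes the vectors you upsert and query below.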

Once we have the image embedder class, we can use it to convert image files to vector embeddings.

from tqdm import tqdm
import pandas as pd

df = pd.DataFrame()
df["image_file_name"] = image_file_names
df["embedding_id"] = [
    file_name.split(IMAGENET_DIR)[-1] for file_name in image_file_names
]
df["embedding"] = [
    image_embedder.embed(file_name) for file_name in tqdm(image_file_names)
]
df = df.sample(frac=1)  # shuffle the images
100%|██████████| 2500/2500 [01:17<00:00, 32.25it/s]
print(df.head(2))
                                        image_file_name  \
2249  tmp/tiny-imagenet-200/train/n09256479/images/n...
219   tmp/tiny-imagenet-200/train/n01910747/images/n...

                                    embedding_id  \
2249  /train/n09256479/images/n09256479_322.JPEG
219   /train/n01910747/images/n01910747_296.JPEG

                                              embedding
2249  [12.009511, 9.01633, 4.372963, 7.2551413, 6.49...
219   [8.683792, 11.956106, 8.437065, 11.03078, 9.75...

Split the data into the items dataset and the queries dataset. Because the data is already shuffled, we can directly split the data.

cutoff = int(len(df) * 0.8)
item_df, query_df = df[:cutoff], df[cutoff:]
(item_df.shape, query_df.shape)
((2000, 3), (500, 3))

Create a Pinecone Vector Index Service

Now that we have our vector embeddings we can create our Pinecone service and upload the data to it.

Below we create a Pinecone index. Here we use two arguments:

  • metric="euclidean" means the vector search will use euclidean distance as the measure of similarity.
  • shards=1 means the service will run on a single node. Since we only index at most 100k vectors in 1000 dimensions, a single shard is more than enough. If you have more than 1M vectors, you should increase the number of shards. Refer to the documentation for guidelines on how to choose the number of shards.

# Choosing an arbitrary name for my index
index_name = "simple-pytorch-image-search"

# Checking whether the index already exists.
if index_name not in pinecone.list_indexes():
    pinecone.create_index(index_name, metric="euclidean", shards=1)
100%|██████████| 2/2 [00:15<00:00,  7.74s/it]

Now that the index is created, we can upload vectorized images. To do that, we connect to the index. You only need to do this once.

index = pinecone.Index(name=index_name)

upsert will upload and index items into your remote vector similarity search service. If an item with the same ID already exists, it will be overwritten with the newly provided value.

acks = index.upsert(items=zip(item_df.embedding_id, item_df.embedding))
2000it [00:11, 175.42it/s]
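To illustrate the overwrite semantics, here is a small sketch using a hypothetical ID ("demo-item", not part of the dataset); upserting the same ID twice leaves only the second vector in the index. It is purely illustrative, so you may prefer not to run it against the tutorial index:

import numpy as np

# "demo-item" is a made-up ID for illustration only.
v1 = np.random.rand(1000).astype("float32")  # squeezenet output is 1000-d
v2 = np.random.rand(1000).astype("float32")

index.upsert(items=[("demo-item", v1)])
index.upsert(items=[("demo-item", v2)])  # overwrites: "demo-item" -> v2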

Querying

Let’s use the held-out query dataset as test queries.

len(query_df)
500

Pinecone supports running multiple queries in a minibatch.

import time


start = time.perf_counter()

res = index.query(query_df.embedding, batch_size=100)  # issuing queries

end = time.perf_counter()
print("Run this test on a fast network to get the best performance.")
# print(f"Service QPS is {int(len(query_df)/(end-start))} queries/second.")
500it [00:02, 167.07it/s]

Run this test on a fast network to get the best performance.

Of course, running single queries one at a time is also an option.

start = time.perf_counter()

for query in query_df.embedding:
    _ = index.unary_query(query)  # issuing one query

end = time.perf_counter()
print("Run this test on a fast network to get the best performance.")
# print(f"Service QPS is {int(len(query_df)/(end-start))} queries/second.")
Run this test on a fast network to get the best performance.
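Each result exposes the IDs of the nearest neighbors; the notebook below reads them via res[i].ids. As a quick sketch for a single query (the top_k parameter and the scores attribute are assumptions about this client version, so verify against your installed pinecone-client):

# Inspect the nearest neighbors for one query vector.
result = index.unary_query(query_df.embedding.iloc[0], top_k=5)
print(result.ids)     # embedding IDs of the five closest images
print(result.scores)  # their distances to the query (assumed attribute)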

Now that we have our results, we can look at them.

for i in range(100, 110):
    print(f"Query {i+1} and search results")
    show_image(query_df.image_file_name.iloc[i])
    show_images_horizontally(
        [IMAGENET_DIR + embedding_id for embedding_id in res[i].ids]
    )

    print("-" * 80)
Query 101 and search results

Example of an image search query

Example of an image search result

--------------------------------------------------------------------------------

[Queries 102 through 110 follow the same pattern: each query image is shown, followed by its retrieved results from the index.]

As shown above, image similarity search works relatively well with the pre-trained CV model: when we query with images that are not in the index, most similarity searches return images from the same class as the query image.

Wrap up

You can use pinecone.delete_index to delete the index and free all resources dedicated to it. Once an index is deleted, all of its resources must be recreated before it can be used again. We suggest deleting an index only if no application is using it.

pinecone.delete_index(index_name)
100%|██████████| 1/1 [00:13<00:00, 13.07s/it]

({'success': True}, <pinecone.utils.progressbar.ProgressBar at 0x7f1a8df5f7b8>)