NLP-Powered GIF Search

We will use the Tumblr GIF Description Dataset, which contains over 100K animated GIFs and over 120K sentences describing their visual content. By combining this data with a retriever and a vector database, we can build an NLP-powered GIF search tool.

There are a few packages that must be installed for this notebook to run:

pip install -U pandas "pinecone-client<3" sentence-transformers tqdm

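We also set the following IPython option so that every expression in a cell is displayed, not just the last one. This lets the HTML objects we create inside loops render the GIF images we will be working with.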

from IPython.display import HTML
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

Download and Extract Dataset

First, let's download and extract the dataset, which is available on GitHub in the raingo/TGIF-Release repository. The command below downloads it directly. Opening the same link in a browser also downloads the files.

# Use wget to download the master.zip file which contains the dataset
!wget https://github.com/raingo/TGIF-Release/archive/master.zip
# Use unzip to extract the master.zip file
!unzip master.zip

Explore the Dataset

Now let's explore the downloaded files. The data we want is in the tgif-v1.0.tsv file in the data folder, and we can open it with the pandas library. Because the file contains tab-separated values, we set the delimiter to \t.

import pandas as pd
# Load dataset to a pandas dataframe
df = pd.read_csv(
    "./TGIF-Release-master/data/tgif-v1.0.tsv",
    delimiter="\t",
    names=['url', 'description']
)
df.head()
url description
0 https://38.media.tumblr.com/9f6c25cc350f12aa74... a man is glaring, and someone with sunglasses ...
1 https://38.media.tumblr.com/9ead028ef62004ef6a... a cat tries to catch a mouse on a tablet
2 https://38.media.tumblr.com/9f43dc410be85b1159... a man dressed in red is dancing.
3 https://38.media.tumblr.com/9f659499c8754e40cf... an animal comes close to another in the jungle
4 https://38.media.tumblr.com/9ed1c99afa7d714118... a man in a hat adjusts his tie and makes a wei...
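To show what delimiter="\t" and names=[...] do, here is a minimal standard-library sketch of the same parsing. It runs on a made-up two-row sample; the example.com URLs are placeholders, not real dataset rows. When names is passed without a header argument, pandas treats the first line as data rather than as column labels. That is what we want here, because the file has no header row.

```python
import csv
import io

# Made-up sample in the same layout as tgif-v1.0.tsv:
# no header row, then a URL column and a description column separated by a tab.
sample = (
    "https://example.com/a.gif\ta cat tries to catch a mouse on a tablet\n"
    "https://example.com/b.gif\ta man dressed in red is dancing.\n"
)

def read_tgif_rows(text):
    """Parse headerless tab-separated text into dicts with 'url' and 'description' keys."""
    reader = csv.reader(io.StringIO(text), delimiter="\t")
    return [{"url": url, "description": desc} for url, desc in reader]

rows = read_tgif_rows(sample)
print(rows[0]["description"])  # a cat tries to catch a mouse on a tablet
```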

Note that the dataset does not contain the GIF files themselves, only URLs we can use to access them. This is convenient: we do not need to download or store any GIFs. Instead, we can load the ones we need directly from their URLs when displaying the search results.

Some GIFs appear more than once in the dataset, each time with a different description.

len(df)
125782
# Number of *unique* GIFs in the dataset
len(df["url"].unique())
102068
dupes = df['url'].value_counts().sort_values(ascending=False)
dupes.head()
https://38.media.tumblr.com/ddbfe51aff57fd8446f49546bc027bd7/tumblr_nowv0v6oWj1uwbrato1_500.gif    4
https://33.media.tumblr.com/46c873a60bb8bd97bdc253b826d1d7a1/tumblr_nh7vnlXEvL1u6fg3no1_500.gif    4
https://38.media.tumblr.com/b544f3c87cbf26462dc267740bb1c842/tumblr_n98uooxl0K1thiyb6o1_250.gif    4
https://33.media.tumblr.com/88235b43b48e9823eeb3e7890f3d46ef/tumblr_nkg5leY4e21sof15vo1_500.gif    4
https://31.media.tumblr.com/69bca8520e1f03b4148dde2ac78469ec/tumblr_npvi0kW4OD1urqm0mo1_400.gif    4
Name: url, dtype: int64
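The value_counts call counts how many times each URL occurs. The same counts can be sketched with collections.Counter from the standard library. The URLs below are made-up placeholders, with one GIF appearing three times.

```python
from collections import Counter

# Made-up stand-ins for the "url" column; a.gif has three descriptions.
urls = [
    "https://example.com/a.gif",
    "https://example.com/b.gif",
    "https://example.com/a.gif",
    "https://example.com/c.gif",
    "https://example.com/a.gif",
]

counts = Counter(urls)        # analogous to df["url"].value_counts()
print(len(urls))              # 5 rows (descriptions), like len(df)
print(len(counts))            # 3 unique GIFs, like len(df["url"].unique())
print(counts.most_common(1))  # [('https://example.com/a.gif', 3)]
```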

Let's take a look at one of these duplicated URLs and its descriptions.

dupe_url = "https://33.media.tumblr.com/88235b43b48e9823eeb3e7890f3d46ef/tumblr_nkg5leY4e21sof15vo1_500.gif"
dupe_df = df[df['url'] == dupe_url]

# show this GIF alongside each of its descriptions
for _, gif in dupe_df.iterrows():
    HTML(f"<img src={gif['url']} style='width:120px; height:90px'>")
    print(gif["description"])
two girls are singing music pop in a concert
a woman sings sang girl on a stage singing
two girls on a stage sing into microphones.
two girls dressed in black are singing.

There is no reason to remove these duplicates: as shown here, every description is accurate. Spot-checking a few of the other duplicated URLs shows the same pattern, with several accurate descriptions for a single GIF.

We therefore keep all 125,782 descriptions, covering 102,068 unique GIFs. We will use these descriptions to create context vectors and index them in a vector database to build our GIF search tool. Let's take a look at a few more examples of GIFs and their descriptions.
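One consequence of indexing every description separately is that a single query can match several descriptions of the same GIF. The same URL may therefore show up more than once in the search results. If that is unwanted, a hypothetical post-processing helper (dedupe_urls, not part of the original tool) can drop repeats while keeping the highest-ranked occurrence:

```python
def dedupe_urls(urls):
    """Drop repeated URLs, keeping the first (highest-ranked) occurrence of each."""
    seen = set()
    unique = []
    for url in urls:
        if url not in seen:
            seen.add(url)
            unique.append(url)
    return unique

# Made-up ranked results in which a.gif and b.gif each matched two descriptions.
ranked = ["a.gif", "b.gif", "a.gif", "c.gif", "b.gif"]
print(dedupe_urls(ranked))  # ['a.gif', 'b.gif', 'c.gif']
```

To still show 10 distinct GIFs after deduplicating, we would need to request somewhat more than 10 matches from the index. In Python 3.7+, list(dict.fromkeys(urls)) gives the same order-preserving result in one line.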

for _, gif in df[:5].iterrows():
    HTML(f"<img src={gif['url']} style='width:120px; height:90px'>")
    print(gif["description"])
a man is glaring, and someone with sunglasses appears.
a cat tries to catch a mouse on a tablet
a man dressed in red is dancing.
an animal comes close to another in the jungle
a man in a hat adjusts his tie and makes a weird face.

Each description accurately describes what is happening in its GIF, so we can use these descriptions to search through our GIFs.

Using this data, we can build the GIF search tool with just two components:

  • a retriever to embed GIF descriptions and search queries
  • a vector database to store GIF description embeddings and retrieve relevant GIFs
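To make the roles of these two components concrete, here is a toy, in-memory stand-in for both. It illustrates the idea under simplifying assumptions and is not the tool we build below. The retriever is a bag-of-words counter rather than a neural model, and the vector database is a Python dict scanned in full rather than Pinecone. The names tokenize, embed and InMemoryVectorStore are hypothetical.

```python
import math
import re

def tokenize(text):
    """Lowercase word tokens; punctuation is dropped."""
    return re.findall(r"[a-z]+", text.lower())

def embed(text, vocab):
    """Toy retriever: bag-of-words counts over a fixed vocabulary, scaled to unit length."""
    counts = {}
    for word in tokenize(text):
        counts[word] = counts.get(word, 0) + 1
    vec = [float(counts.get(word, 0)) for word in vocab]
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero
    return [v / norm for v in vec]

class InMemoryVectorStore:
    """Toy vector database: keeps id -> (vector, metadata) and ranks by dot product."""

    def __init__(self):
        self.records = {}

    def upsert(self, records):
        # As with an upsert, writing an existing id overwrites the old entry.
        for rid, vec, meta in records:
            self.records[rid] = (vec, meta)

    def query(self, vector, top_k=3):
        # Stored vectors are unit length, so the dot product equals cosine similarity.
        scored = [
            (sum(a * b for a, b in zip(vector, vec)), rid, meta)
            for rid, (vec, meta) in self.records.items()
        ]
        scored.sort(key=lambda item: item[0], reverse=True)
        return scored[:top_k]

descriptions = {
    "0": "a cat tries to catch a mouse on a tablet",
    "1": "a man dressed in red is dancing.",
    "2": "an animal comes close to another in the jungle",
}
vocab = sorted({word for text in descriptions.values() for word in tokenize(text)})

store = InMemoryVectorStore()
store.upsert([(rid, embed(text, vocab), {"description": text}) for rid, text in descriptions.items()])

best = store.query(embed("a man dancing", vocab), top_k=1)
print(best[0][2]["description"])  # a man dressed in red is dancing.
```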

Initialize Pinecone Index

The vector database stores vector representations of our GIF descriptions, which we can retrieve using another vector (the query vector). We will use Pinecone, a fully managed vector database that can store and search through billions of records in milliseconds. You could instead build this tool with a vector search library such as FAISS, but then you would need to host and manage the index yourself.

To initialize the database, we sign up for a free Pinecone API key and pip install pinecone-client. Once ready, we initialize our index with:

import pinecone

# Connect to pinecone environment
pinecone.init(
    api_key="<<YOUR_API_KEY>>",
    environment="us-west1-gcp"
)

index_name = 'gif-search'

# check whether the gif-search index already exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    pinecone.create_index(
        index_name,
        dimension=384,
        metric="cosine"
    )

# Connect to the gif-search index
index = pinecone.Index(index_name)

Here we specify the name of the index where we will store our GIF descriptions and their URLs, the similarity metric, and the embedding dimension of the vectors. Both the metric and the dimension depend on the embedding model. The retriever we use below, all-MiniLM-L6-v2, produces 384-dimensional embeddings, so we set dimension=384. Many other popular retrievers use "cosine" with 768-dimensional embeddings.
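The dimension is fixed when the index is created, so every vector we upsert must have exactly that length. Embeddings from a 768-dimensional model would not fit this 384-dimensional index. Below is a hypothetical guard that could run before upserting; check_dimension is our own helper, not a Pinecone function. Once the retriever below is loaded, the expected size can also be read from retriever.get_sentence_embedding_dimension().

```python
def check_dimension(vectors, index_dimension):
    """Raise ValueError if any vector's length differs from the index dimension."""
    for i, vec in enumerate(vectors):
        if len(vec) != index_dimension:
            raise ValueError(
                f"vector {i} has {len(vec)} dimensions, index expects {index_dimension}"
            )
    return True

print(check_dimension([[0.0] * 384, [1.0] * 384], 384))  # True
```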

Initialize Retriever

Next, we need to initialize our retriever. The retriever will mainly do two things:

  1. Generate embeddings for all the GIF descriptions (context vectors/embeddings)
  2. Generate embeddings for the query (query vector/embedding)

The retriever embeds queries and GIF descriptions so that texts with similar meanings end up close together in the same vector space. We can then use cosine similarity to measure how similar the query and context embeddings are, and find the GIFs most relevant to our query.
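Cosine similarity compares the direction of two vectors and ignores their length: cos(a, b) = (a · b) / (|a| |b|). Here is a minimal pure-Python sketch with made-up three-dimensional vectors; the real embeddings in this tool have 384 dimensions.

```python
import math

def cosine_similarity(a, b):
    """cos(a, b) = (a . b) / (|a| |b|): 1 for the same direction, 0 for orthogonal vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

query = [1.0, 1.0, 0.0]
close = [2.0, 2.0, 0.0]  # same direction as query, twice the length
far = [0.0, 0.0, 3.0]    # orthogonal to query

print(cosine_similarity(query, close))  # approximately 1.0: length does not matter
print(cosine_similarity(query, far))    # 0.0
```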

We will use all-MiniLM-L6-v2, a SentenceTransformer model based on Microsoft's MiniLM, as our retriever. This model performs well out of the box for search based on generic semantic similarity.

from sentence_transformers import SentenceTransformer
# Initialize retriever with SentenceTransformer model 
retriever = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
retriever
SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)
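The final Normalize() module scales every embedding to unit length. For unit vectors, the denominator |a| |b| in cosine similarity equals 1, so cosine similarity reduces to a plain dot product. Here is a small sketch of that equivalence, using our own l2_normalize helper rather than sentence-transformers code:

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length, which is what a Normalize() step does to each embedding."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

a = l2_normalize([3.0, 4.0])  # [0.6, 0.8]
b = l2_normalize([4.0, 3.0])  # [0.8, 0.6]

# Dot product of the unit vectors; equals the cosine similarity of the
# original vectors, (3*4 + 4*3) / (5 * 5) = 0.96.
dot = sum(x * y for x, y in zip(a, b))
print(dot)  # approximately 0.96
```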

Generate Embeddings and Upsert

Now that our retriever and the Pinecone index are initialized, we need to generate embeddings for the GIF descriptions. We do this in batches of 64: the retriever encodes 64 descriptions at once instead of one at a time, and we send a single upsert API call per batch instead of one per description. Both changes make the process much faster.
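For our 125,782 rows, the batch arithmetic works out as follows. 1,965 full batches of 64 cover 125,760 rows, and one final partial batch holds the remaining 22, so the upsert loop runs 1,966 times. A quick check in Python mirrors the loop's range(0, len(df), batch_size) and min(i + batch_size, len(df)):

```python
import math

num_rows = 125_782  # len(df) reported earlier
batch_size = 64

num_batches = math.ceil(num_rows / batch_size)
last_batch = num_rows - (num_batches - 1) * batch_size
print(num_batches)  # 1966
print(last_batch)   # 22

# The same (start, end) pairs the upsert loop iterates over.
bounds = [(i, min(i + batch_size, num_rows)) for i in range(0, num_rows, batch_size)]
print(bounds[-1])   # (125760, 125782)
```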

When passing documents to Pinecone, each record needs an ID (a unique value), an embedding (the vector we generated for the GIF description), and metadata. The metadata is a dictionary of data relevant to the embedding; for the GIF search tool, we only need the URL and description.
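Each record we upsert is an (id, values, metadata) tuple, and the IDs only need to be unique strings. The loop below simply uses each row's position in the DataFrame. Here is a hypothetical helper (make_records, not part of the Pinecone client) that assembles one batch of such tuples from rows and precomputed embeddings, using made-up data and tiny 3-dimensional vectors:

```python
def make_records(rows, embeddings, start):
    """Build (id, values, metadata) tuples for one batch.

    rows: dicts with 'url' and 'description', like batch.to_dict(orient='records')
    embeddings: one vector per row, in the same order
    start: row offset of the batch, used to create unique string IDs
    """
    if len(rows) != len(embeddings):
        raise ValueError("need exactly one embedding per row")
    return [
        (str(start + offset), list(vec), {"url": row["url"], "description": row["description"]})
        for offset, (row, vec) in enumerate(zip(rows, embeddings))
    ]

rows = [
    {"url": "https://example.com/a.gif", "description": "a cat tries to catch a mouse"},
    {"url": "https://example.com/b.gif", "description": "a man dressed in red is dancing."},
]
embeddings = [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1]]  # made-up; real embeddings have 384 dims
records = make_records(rows, embeddings, start=64)  # second batch starts at row 64
print(records[1][0])  # 65
```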

from tqdm.auto import tqdm

# we will use batches of 64
batch_size = 64

for i in tqdm(range(0, len(df), batch_size)):
    # find end of batch
    i_end = min(i+batch_size, len(df))
    # extract batch
    batch = df.iloc[i:i_end]
    # generate embeddings for batch
    emb = retriever.encode(batch['description'].tolist()).tolist()
    # get metadata
    meta = batch.to_dict(orient='records')
    # create IDs
    ids = [f"{idx}" for idx in range(i, i_end)]
    # add all to upsert list
    to_upsert = list(zip(ids, emb, meta))
    # upsert/insert these records to pinecone
    _ = index.upsert(vectors=to_upsert)

    
# check that we have all vectors in index
index.describe_index_stats()
{'dimension': 384,
 'index_fullness': 0.05,
 'namespaces': {'': {'vector_count': 125782}}}
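Rather than reading the stats by eye, we could compare them against the expected row count and dimension. Below is a minimal sketch of such a check, written against a plain dict copied from the output above. We assume the client's stats response can be accessed the same way, which may depend on the pinecone-client version. Because an upsert with an existing ID overwrites that vector, the count matches len(df) only when every ID is unique, as it is here.

```python
def upsert_complete(stats, expected_count, expected_dim, namespace=""):
    """Return True if the index has the expected dimension and vector count."""
    ns = stats["namespaces"].get(namespace, {"vector_count": 0})
    return stats["dimension"] == expected_dim and ns["vector_count"] == expected_count

stats = {
    "dimension": 384,
    "index_fullness": 0.05,
    "namespaces": {"": {"vector_count": 125782}},
}
print(upsert_complete(stats, expected_count=125782, expected_dim=384))  # True
```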

The vector count matches the 125,782 rows in our dataset, so all our documents are now in the Pinecone index. Let's run some queries to test our GIF search tool.

Querying

We define two functions: search_gif, which handles our search query, and display_gif, which displays the search results.

The search_gif function generates a vector embedding for the search query using the retriever model and then queries the Pinecone index. Because we set the metric to "cosine" when creating the index, index.query ranks the GIF description embeddings by their cosine similarity to the query embedding. The function returns the URLs of the 10 GIFs most relevant to our search query.

def search_gif(query):
    # Generate embeddings for the query
    xq = retriever.encode(query).tolist()
    # Find the 10 description embeddings most cosine-similar to the query and return their URLs
    xc = index.query(vector=xq, top_k=10, include_metadata=True)
    result = []
    for context in xc['matches']:
        url = context['metadata']['url']
        result.append(url)
    return result
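index.query performs the ranking on Pinecone's side. The sketch below illustrates what the query computes, not how Pinecone implements it at scale. It is a brute-force, in-memory version that scores every record by cosine similarity and returns the metadata URLs of the top k. The records and URLs are made up, and top_k_urls is a hypothetical name.

```python
import heapq
import math

def top_k_urls(query_vec, records, k=10):
    """Return metadata URLs of the k records most cosine-similar to query_vec.

    records: iterable of (id, vector, metadata) tuples whose metadata contains 'url'.
    """
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    # nlargest keeps only the k best-scoring records, highest score first.
    best = heapq.nlargest(k, records, key=lambda rec: cosine(query_vec, rec[1]))
    return [meta["url"] for _, _, meta in best]

records = [
    ("0", [1.0, 0.0], {"url": "https://example.com/cat.gif"}),
    ("1", [0.0, 1.0], {"url": "https://example.com/dog.gif"}),
    ("2", [0.7, 0.7], {"url": "https://example.com/both.gif"}),
]
print(top_k_urls([1.0, 0.1], records, k=2))
# ['https://example.com/cat.gif', 'https://example.com/both.gif']
```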

The display_gif function displays multiple GIFs from their URLs as a grid in the Jupyter notebook. We use it to show the top 10 GIFs returned by search_gif.

def display_gif(urls):
    figures = []
    for url in urls:
        figures.append(f'''
            <figure style="margin: 5px !important;">
              <img src="{url}" style="width: 120px; height: 90px" >
            </figure>
        ''')
    return HTML(data=f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {''.join(figures)}
        </div>
    ''')
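display_gif returns an IPython HTML object, which only renders inside a notebook. One possible refactor, sketched here as a hypothetical gif_grid_html function, builds the same flex-grid markup as a plain string so that it can be tested outside Jupyter. It also escapes each URL with html.escape, a safeguard the original does not include.

```python
import html

def gif_grid_html(urls, width=120, height=90):
    """Build flex-grid markup like display_gif's, with each URL HTML-escaped."""
    figures = "".join(
        f'<figure style="margin: 5px !important;">'
        f'<img src="{html.escape(url, quote=True)}" style="width: {width}px; height: {height}px">'
        f"</figure>"
        for url in urls
    )
    return f'<div style="display: flex; flex-flow: row wrap; text-align: center;">{figures}</div>'

markup = gif_grid_html(["https://example.com/a.gif?x=1&y=2"])  # made-up URL
print(markup.count("<figure"))  # 1
```

In the notebook, HTML(gif_grid_html(search_gif("a dog being confused"))) would render the grid as before.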

Let's begin testing some queries.

gifs = search_gif("a dog being confused")
display_gif(gifs)
gifs = search_gif("animals being cute")
display_gif(gifs)
gifs = search_gif("an animal dancing")
display_gif(gifs)

Let's take the third GIF in these results, a ginger dog dancing on its hind legs, and search for it using our own description.

gifs = search_gif("a fluffy dog being cute and dancing like a person")
display_gif(gifs)

These results look good and interesting.