Retriever Models for Open Domain Question-Answering
It’s a sci-fi staple. A vital component of the legendary Turing test. The dream of many across the world. And, until recently, impossible.
We are talking about the ability to ask a machine a question and receive a genuinely intelligent, insightful answer.
Until recently, technology like this existed only in books, Hollywood, and our collective imagination. Now, it is everywhere. Most of us use this technology every day, and we often don’t even notice it.
Google is just one example. Over the last few years, Google has gradually introduced an intelligent question-answering angle to search. When we now ask “how do I tie my shoelaces?”, Google gives us the ‘exact answer’ alongside the context or video this answer came from.
In response to our question, Google finds the exact (audio-to-text) answer to be “Start by taking the first lace. And place it behind the second one…”, and highlights the exact part of the video that contains this extracted answer.
We can ask other questions like “Is Google Skynet?” and this time receive an even more precise answer: “Yes”.
At least Google is honest.
In this example, Google returns an exact answer and the context (paragraph) from where the answer is extracted.
How does Google do this? And more importantly, why should we care?
This search style emulates a human-like interaction. We’re asking a question in natural language as if we were speaking to another person. This natural language Q&A creates a very different search experience to traditional search.
Imagine you find yourself in the world’s biggest warehouse. You have no idea how the place is organized. All you know is that your task is to find some round marble-like objects.
Where do you start? Well, we need to figure out how the warehouse is organized. Maybe everything is stored alphabetically, categorized by industry, or intended use. The traditional search interface requires that we understand how the warehouse is structured before we begin searching. Often, there is a specific ‘query language’ such as:
SELECT * WHERE 'material' == 'marble' or ("marble" | "stone") & "product"
Our first task is to learn this query language so we can search. Once we understand how the warehouse is structured, we use that knowledge to begin our search. How do we find “round marble-like objects”? We can narrow our search down using similar queries to those above, but we are in the world’s biggest warehouse, so this will take a very long time.
Without a natural Q&A-style interface, this is your search. Unless your users know the ins and outs of the warehouse and its contents, they’re going to struggle.
What happens if we add a natural Q&A-style interface to the warehouse? Imagine we now have people in the warehouse whose entire purpose is to guide us through the warehouse. These people know exactly where everything is.
Those people can understand our question of “where can I find the round marble-like objects?". It may take a few tries until we find the exact object we’re looking for, but we now have a guide that understands our question. There is no longer the need to understand how the warehouse is organized nor to know the exact name of what it is we’re trying to find.
With this natural Q&A-style interface, your users now have a guide. They just need to be able to ask a question.
How can we design these natural, human-like Q&A interfaces? The answer is open-domain question-answering (ODQA). ODQA allows us to use natural language to query a database.
That means that, given a dataset like a set of internal company documents, online documentation, or as is the case with Google, everything on the world’s internet, we can retrieve relevant information in a natural, more human way.
However, ODQA is not a single model. It is a more complex pipeline built from three primary components:
- A vector database to store information-rich vectors that numerically represent the meaning of contexts (paragraphs that we use to extract answers to our questions).
- A retriever model that encodes questions and contexts into the same vector space. The context vectors are what we later store in the vector database; at query time, the retriever encodes the question so it can be compared against those context vectors and retrieve the most relevant contexts.
- A reader model takes a question and context and attempts to identify a span (sub-section) from the context which answers the question.
Building a retriever model is our focus here. Without it, there is no ODQA; it is arguably the most critical component in the whole process. We need our retriever model to return relevant results; otherwise, the reader model will receive and output garbage.
If we instead had a mediocre reader model, it may still return garbage to us, but it has a much smaller negative impact on the ODQA pipeline. A good retriever means we can at least retrieve relevant contexts, therefore successfully returning relevant information to the user. A paragraph-long context isn’t as clean-cut as a perfectly framed two or three-word answer, but it’s better than nothing.
Our focus in this article is on building a retriever model, of which the vector database is a crucial component, as we will see later.
Train or Not?
Do we need to fine-tune our retriever models? Or can we use pretrained models like those in the HuggingFace model hub?
The answer is: it depends. An excellent concept from Nils Reimers describes the difficulty of benchmarking models when the use case sits within a niche domain that very few people understand. The idea is that most benchmarks and datasets focus on the short head of knowledge (topics most people understand), whereas the most exciting use cases belong in the long-tail portion of the graph.
Nils Reimers’ long tail of semantic relatedness. The more people know about something (y-axis), the easier it is to find benchmarks and labeled data (x-axis), but the most interesting use cases belong in the long-tail region.
We can take the same idea and modify the x-axis to indicate whether we should be able to take a pretrained model or fine-tune our own.
The more something is common knowledge (y-axis), the easier it is to find pretrained models that excel in the broader, more general scope. However, as before, most interesting use cases belong in the long-tail, and here is where we would need to fine-tune our own model.
Imagine you are walking down your local high street. You pick a stranger at random and ask them the sort of question that you would expect from your use case. Do you think they would get the answer? If there’s a good chance they will, you might be able to get away with a pretrained model.
On the other hand, if you ask this stranger what the difference is between RoBERTa and DeBERTa, there is a very high chance that they will have no idea what you’re asking. In this case, you will probably need to fine-tune a retriever model.
Fine-Tuning a Retriever
Let’s assume the strangers on the street have no chance of answering our questions. Most likely, a custom retriever model is our best bet. But, how do we train/fine-tune a custom retriever model?
The very first ingredient is data. Our retriever consumes a question and returns relevant contexts to us. For it to do this, it must learn to encode similar question-context pairs into the same vector space.
The retriever model must learn to encode similar question-context pairs into a similar vector space.
Our first task is to find and create a set of question-context pairs. One of the best-known datasets for this is the Stanford Question Answering Dataset (SQuAD).
Step One: Data
SQuAD is a reading comprehension dataset built from question, context, and answers with information from Wikipedia articles. Let’s take a look at an example.
We first download the squad_v2 dataset via 🤗 Datasets. In the first sample, we can see:
- the title (or topic) of Beyoncé
- the context, a short paragraph from Wikipedia about Beyoncé
- the question, “When did Beyonce start becoming popular?”
- the answers field, containing:
  - text, “in the late 1990s”, which is extracted from the context
  - answer_start, the starting position of the answer within the context string
The SQuAD v2 dataset contains 130,319 of these samples, more than enough for us to train a good retriever model.
We will be using the Sentence Transformers library to train our retriever model. When using this library, we must format our training data into a list of InputExample objects. After creating this list, we need to load it into a data loader. Data loaders are a standard component of PyTorch, which Sentence Transformers uses under the hood, so we could often use the default PyTorch DataLoader.
However, we need to do something slightly different. Our training data consists of positive question-context pairs only; every sample in our dataset can be viewed as having a positive or high similarity, and there are no negative or dissimilar pairs. When our data looks like this, one of the most effective training approaches uses the Multiple Negatives Ranking (MNR) loss function. We will not explain MNR loss in this article, but you can learn about it here.
One crucial property of training with MNR loss is that each training batch must not contain duplicate questions or contexts. This is a problem, as the SQuAD data includes several questions for each context. If we used the standard DataLoader, there is a high probability that we would find duplicate contexts in our batches.
Screenshot from HuggingFace’s dataset viewer for the squad_v2 dataset. Each row represents a different question, but they all map to the same context.
Fortunately, there is an easy solution to this. Sentence Transformers provides a set of modified data loaders. One of those is the NoDuplicatesDataLoader, which ensures our batches contain no duplicates.
With that, our training data is fully prepared, and we can move on to initializing and training our retriever model.
Step Two: Initialize and Train
Before training our model, we need to initialize it. For this, we begin with a pretrained transformer model from the HuggingFace model hub. A popular choice for sentence transformers is Microsoft’s MPNet, accessed via the microsoft/mpnet-base model ID.
There is one problem with our pretrained transformer model. It outputs many word/token-level vector embeddings. We don’t want token vectors; we need sentence vectors.
We need a way to transform the many token vectors output by the model into a single sentence vector.
Transformation of the many token vectors output by a transformer model into a single sentence vector.
To perform this transformation, we add a pooling layer to process the outputs of the transformer model. There are a few different pooling techniques; the one we will use is mean pooling, which takes the many token vectors output by the model and averages the activations across each vector dimension to create a single sentence vector.
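As an illustration of the idea (not the library’s actual implementation), mean pooling over a padded token sequence can be sketched in NumPy:

```python
import numpy as np

def mean_pool(token_vecs: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors into one sentence vector, ignoring padding tokens."""
    mask = attention_mask[:, None].astype(float)   # (seq_len, 1)
    summed = (token_vecs * mask).sum(axis=0)       # sum over real tokens only
    count = max(mask.sum(), 1e-9)                  # avoid division by zero
    return summed / count

# Three token vectors of dimension 2; the last one is padding (mask = 0)
tokens = np.array([[1.0, 3.0], [3.0, 5.0], [99.0, 99.0]])
mask = np.array([1, 1, 0])
print(mean_pool(tokens, mask))  # → [2. 4.]
```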
We can do this via the SentenceTransformer utilities of the Sentence Transformers library. We now have a SentenceTransformer object: a pretrained microsoft/mpnet-base model followed by a mean pooling layer.
With our model defined, we can initialize our MNR loss function.
That is everything we need for fine-tuning the model. We set the number of training epochs to 1; anything more for sentence transformers often leads to overfitting. Another way to reduce the likelihood of overfitting is adding a learning rate warmup. Here, we warm up over the first 10% of our training steps (10% is the usual go-to figure for warmup; if you find the model is overfitting, try increasing it).
We now have an ODQA retriever model saved to the local ./mpnet-mnr-squad2 directory. That’s great, but we have no idea how well the model performs, so our next step is to evaluate model performance.
Evaluation of retriever models is slightly different from the evaluation of most language models. Typically, we input some text and calculate the error between clearly defined predicted and true values.
For information retrieval (IR), we need a metric that measures the rate of successful vs. unsuccessful retrievals. A popular choice is mAP@K (mean average precision at K). In short, this is an averaged precision value (the fraction of retrieved contexts that are relevant) computed over the top K retrieved results.
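To make the metric concrete, here is a toy AP@K computation for a single query (mAP@K is this value averaged over all queries); the document IDs here are invented for illustration:

```python
def average_precision_at_k(retrieved, relevant, k):
    """AP@K: average the precision at each rank where a relevant item appears."""
    hits, score = 0, 0.0
    for rank, doc_id in enumerate(retrieved[:k], start=1):
        if doc_id in relevant:
            hits += 1
            score += hits / rank  # precision at this rank
    return score / min(len(relevant), k) if relevant else 0.0

# Relevant contexts appear at ranks 1 and 3 of the four retrieved results
ap = average_precision_at_k(["c1", "c9", "c4", "c7"], {"c1", "c4"}, k=4)
print(round(ap, 4))  # → 0.8333
```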
The setup for IR evaluation is a little more involved than with other evaluators in the Sentence Transformers library. We will be using the InformationRetrievalEvaluator, which requires three inputs:
- ir_queries, a dictionary mapping question IDs to question text
- ir_corpus, mapping context IDs to context text
- ir_relevant_docs, mapping question IDs to their relevant context IDs
Before we initialize the evaluator, we need to download a new set of samples that our model has not seen before and format them into the three dictionaries above. We will use the SQuAD validation set.
To create the dictionary objects required by the InformationRetrievalEvaluator, we must assign unique IDs to both contexts and questions, and we need to ensure that duplicate contexts are not assigned different IDs. To handle this, we first convert our dataset object into a Pandas dataframe. From here, we can quickly drop duplicate contexts with the drop_duplicates method. With the duplicates removed, we append 'con' to each context ID, giving each unique context a unique ID distinct from any question IDs.
We now have unique question IDs in the squad_df dataframe and unique context IDs in the no_dupe dataframe. Next, we perform an inner join on the context feature to bring these two sets of IDs together and find our question-ID-to-context-ID mappings.
We’re now ready to build the three mapping dictionaries for the InformationRetrievalEvaluator. First, we map question/context IDs to questions/contexts.
And then map question IDs to a set of relevant context IDs. For the SQuAD data, we only have many-to-one or one-to-one question ID to context ID mappings, but we will write our code to additionally handle one-to-many mappings (so we can handle other, non-SQuAD datasets).
Our evaluator inputs are ready, so we initialize the evaluator and then evaluate our model. We return a mAP@K score of 0.74, where K is 100 by default. This performance is comparable to other state-of-the-art retriever models: performing the same evaluation with the pretrained multi-qa-mpnet-base-cos-v1 returns a mAP@K score of 0.76, just two percentage points greater than our custom model.
Of course, if your target domain were SQuAD data, the pretrained multi-qa-mpnet-base-cos-v1 model would be the better choice. But if you have your own unique dataset and domain, a custom model fine-tuned on that domain will very likely outperform existing models like multi-qa-mpnet-base-cos-v1.
Storing the Vectors
We have our retriever model, we’ve evaluated it, and we’re happy with its performance. But we don’t know how to use it.
When you perform a Google search, Google does not look at the whole internet, encode all of that information into vector embeddings, and then compare all of those vectors to your query vector. We would be waiting a very long time to return results if that were the case.
Instead, Google has already searched for, collected, and encoded all of that data. Google then stores those encoded vectors in some sort of vector database. When you query now, the only thing Google needs to encode is your question.
Taking this a step further, comparing your query vector to all vectors indexed by Google (which represent the entire Google-accessible internet) would still take an incredibly long time. We refer to this accurate but inefficient comparison of every single vector as an exhaustive search.
For big datasets, an exhaustive search is too slow. The solution is an approximate search, which massively reduces our search scope to a smaller but (hopefully) more relevant sub-section of the index, making our search times much more manageable.
The Pinecone vector database is a straightforward and robust solution that allows us to (1) store our context vectors and (2) perform an accurate and fast approximate search. These are the two elements we need for a promising ODQA pipeline.
Again, we need to work through a few steps to set up our vector database.
Steps from retriever and context preparation (top-right) that allow us to encode contexts into context vectors. After initializing a vector database index, we can populate the index with the context vectors.
After working through each of those steps, we will be ready to begin retrieving relevant contexts.
We have already created our retriever model, and during the earlier evaluation step, we downloaded the SQuAD validation data. We can use this same validation data and encode all unique contexts.
After removing duplicate contexts, we’re left with 1,204 samples. It is a tiny dataset but large enough for our example.
Initializing the Index
Before adding the context vectors to our index, we need to initialize it. Fortunately, Pinecone makes this very easy. We start by installing the Pinecone client if required:
!pip install pinecone-client
Then we initialize a connection to Pinecone. For this, we need a free API key.
We then create a new index with pinecone.create_index. Before creating the index, we should check that the index name does not already exist (which it will not if this is your first time creating it). When creating a new index, we need to specify the index name and the dimensionality of the vectors to be added. We can either check our encoded context vectors’ dimensions directly or read the dimension attribute of the retriever model.
Populating the Index
After creating both our index and the context vectors, we can go ahead and upsert (upload) the vectors into our index.
Pinecone expects us to upsert data in the format:
vectors = [ (id_0, vector_0, metadata_0), (id_1, vector_1, metadata_1) ]
Our IDs are the unique alphanumeric identifiers that we saw earlier in the SQuAD data. The vectors are our encoded context vectors formatted as lists; the metadata is a dictionary that allows us to store extra information in a key-value format.
Using the metadata field, Pinecone allows us to create simple or complex metadata filters that narrow our search scope to specific numeric ranges, categories, and more.
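As a sketch of the upsert formatting (random vectors stand in for real model.encode outputs, and the live index.upsert call is left commented since it needs a connected index):

```python
import numpy as np

# Stand-ins: in the real pipeline the vectors come from model.encode(contexts)
contexts = ["The Normans were in Normandy in the 10th and 11th centuries.",
            "Beyoncé is an American singer."]
ids = ["q0con", "q2con"]
embeddings = np.random.rand(len(contexts), 768)

# Pinecone expects (id, vector, metadata) tuples
to_upsert = [
    (cid, emb.tolist(), {"text": text})
    for cid, emb, text in zip(ids, embeddings, contexts)
]

# With a live index, upsert in batches of 100:
# for i in range(0, len(to_upsert), 100):
#     index.upsert(vectors=to_upsert[i:i + 100])
```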
Once the upsert is complete, the retrieval components of our ODQA pipeline are ready to go, and we can begin asking questions.
With everything set up, querying our retriever-vector database pipeline is pretty straightforward. We first define a question and encode it as we did for our context vectors before.
After creating our query vector, we pass it to Pinecone via the index.query method, specify how many results we’d like to return with top_k, and set include_metadata so that we can see the text associated with each returned vector.
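A sketch of the query step; the FakeIndex stub below only illustrates the shape of a Pinecone query response so the snippet runs offline (in the real pipeline, xq comes from model.encode(question) and index is a live Pinecone index):

```python
import numpy as np

class FakeIndex:
    """Offline stand-in mimicking the shape of a Pinecone query response."""
    def query(self, vector, top_k, include_metadata):
        matches = [{"id": "q0con", "score": 0.91, "metadata": {
            "text": "The Normans were in Normandy in the 10th and 11th centuries."}}]
        return {"matches": matches[:top_k]}

index = FakeIndex()

question = "When were the Normans in Normandy?"
xq = np.random.rand(768).tolist()  # stand-in for model.encode(question).tolist()

results = index.query(vector=xq, top_k=5, include_metadata=True)
for match in results["matches"]:
    print(round(match["score"], 2), "|", match["metadata"]["text"])
```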
We return the correct context as our second top result in this example. The first result is relevant in the context of Normans and Normandy, but it does not answer the specific question of when the Normans were in Normandy.
Let’s try a couple more questions.
For this question, we return the correct context as the highest result with a much higher score than the remaining samples.
We return the correct context in the first position. Again, there is a good separation between sample scores of the correct context and other contexts.
That’s it for this guide to fine-tuning and implementing a custom retriever model in an ODQA pipeline. We can now implement two of the most crucial components of ODQA, enabling a more human and natural approach to information retrieval.
One of the most incredible things about ODQA is how widely applicable it is. Organizations across almost every industry have the opportunity to benefit from more intelligent and efficient information retrieval.
Any organization that handles unstructured information such as word documents, PDFs, emails, and more has a clear use case: freeing this information and enabling easy and natural access through QA systems.
Although this is the most apparent use case, there are many more, whether it be an internal efficiency speedup or a key component in a product (as with Google search). The opportunities are both broad and highly impactful.
N. Reimers, Neural Search for Low Resource Scenarios (2021), YouTube
S. Sawtelle, Mean Average Precision (MAP) For Recommender Systems (2016), GitHub