# Domain Adaptation with Generative Pseudo-Labeling (GPL)

In 1999, a concept known as the _semantic web_ was described by the creator of the _World Wide Web_, Tim Berners-Lee. Berners-Lee's dream was the internet we know and love today, but one deeply understood by machines [1].

This futuristic vision once seemed utterly infeasible but, in recent years, it has become much more than a dream. Thanks to techniques like **G**enerative **P**seudo-**L**abeling (GPL) that allow us to fine-tune new or existing models in previously inaccessible domains, machines are ever closer to understanding the meaning behind the content on the web.

> _I have a dream for the Web [in which computers] become capable of analyzing all the data on the Web – the content, links, and transactions between people and computers. A "Semantic Web", which makes this possible, has yet to emerge, but when it does, the day-to-day mechanisms of trade, bureaucracy and our daily lives will be handled by machines talking to machines. The "intelligent agents" people have touted for ages will finally materialize. - **Tim Berners-Lee, 1999**_

Berners-Lee’s vision is fascinating, in particular the _“intelligent agents”_ referred to in his quote above. Generally speaking, these intelligent agents (IAs) perceive their environment, take actions based on that environment to achieve a particular goal, and improve their performance with self-learning.

These IAs sound very much like the ML models we see today. Some models can look at a web page’s content (its environment), scrape and classify the meaning of this content (taking actions to achieve goals), and do this through an iterative learning process.

A model that can read and comprehend the _meaning_ of language from the internet is a vital component of the semantic web. There are already models that can do this within a limited scope.

However, there is a problem. These **L**anguage **M**odels (LMs) need to learn before becoming these autonomous, language-comprehending IAs. They must be _trained_.

Training LMs is hard; they need vast amounts of data. In particular, bi-encoder models (as explored in earlier chapters on [AugSBERT](https://www.pinecone.io/learn/series/nlp/data-augmentation/) and [GenQ](https://www.pinecone.io/learn/series/nlp/genq/)) that can enable a large chunk of this semantic web are notoriously data-hungry.

Sometimes this is okay. We can fine-tune a model easily in places where we have massive amounts of relevant and labeled data. We can use a simple [supervised fine-tuning approach](https://www.pinecone.io/learn/series/nlp/fine-tune-sentence-transformers-mnr/). Unfortunately, these scenarios are few and far between. It is for this reason that existing models have this limited scope. So, what can we do?

We should first consider why it is hard to get data to train these models. On one hand, the internet is full of data, and, on the other, this data is _not_ in the format we need. We usually need to use a supervised training method to train a high-performance bi-encoder model.

Supervised training methods require labeled data. The problem with labeled data is that a human must (almost always) manually create it.

Currently, we have data-hungry models that require supervised training methods. We must find a way to train a model with _little_ labeled data or use _unsupervised_ methods that need nothing more than unstructured text data.

Fortunately, there are _some_ unsupervised (or supervised using very little data) approaches, such as:

- [Multilingual Knowledge Distillation](https://www.pinecone.io/learn/series/nlp/multilingual-transformers/#Training-approaches) for low-resource languages.
- [TSDAE](https://www.pinecone.io/learn/series/nlp/unsupervised-training-sentence-transformers/) for building simple similarity models without labeled data.
- Data augmentation with AugSBERT for [in-domain](https://www.pinecone.io/learn/series/nlp/data-augmentation/) and [out-of-domain](https://www.pinecone.io/learn/series/nlp/domain-transfer/) tasks.
- [GenQ](https://www.pinecone.io/learn/series/nlp/genq/) for asymmetric semantic search without labeled data.

We can apply these approaches in different scenarios with varying degrees of success. As we’ve seen, there is a lot of potential for models trained using unsupervised techniques. These no (or low) resource scenarios cover the vast majority of use cases, many of which are among the most unusual and interesting.

![As the domain (eg topic, language) becomes more niche, the number of available labeled datasets decreases. The vast majority of domains have no labeled datasets.](https://cdn.sanity.io/images/vr8gru94/production/b930b821ba37c4ca9c7be4d1ce58dec49c0089c7-1760x966.png)


For example, we may identify an opportunity to introduce semantic search on internal financial documents with highly technical language specific to our organization, or a use case in a less common language such as Swahili or Dhivehi.

It is infeasible to find labeled data for every topic, language, or format of information found on the internet. Because of this, the dream of the semantic web only becomes a reality once there are techniques that can train or adapt _high-performance_ IAs with nothing more than the text found on the internet, without human-made labels or curation.

Ongoing research is producing techniques that bring us ever closer to this reality. One of the most promising is GPL [2]. GPL is almost a culmination of the techniques listed above. At its core, it allows us to take unstructured text data and use it to build models that can understand this text. These models can then intelligently respond to natural language queries regarding this same text data.

It is a fascinating approach, with massive potential across countless use cases spanning all industries and borders. With that in mind, let’s dive into the details of GPL and how we can implement it to build high-performance LMs with nothing more than plain text.

---

Watch our webinar [Searching Freely: Using GPL for Semantic Search](https://www.youtube.com/watch?v=OQhoi1CabWw) for a rundown of GPL presented by Nils Reimers, the creator of _sentence-transformers_.

---

## GPL Overview

[Video](https://www.youtube.com/watch?v=uEbCXwInnPs)


GPL can be used in two ways: as a technique for fine-tuning a pretrained model (such as the BERT base model), or as a technique for domain adaptation of an already fine-tuned bi-encoder model (such as SBERT).

By _domain adaptation_ we mean the adaptation of an existing [sentence transformer](https://www.pinecone.io/learn/series/nlp/sentence-embeddings/) to new topics (domains). Effective domain adaptation is incredibly helpful in taking models pretrained on large, existing datasets, and helping them understand a new domain that lacks any labeled datasets.

For example, any model trained on data from before 2019 will be blissfully unaware of COVID-19 and everything that comes with it. If we query any of these pre-2019 models about COVID, they will struggle to return relevant information because they simply do not know what it is or what is relevant to it.

```json
{
  "_key": "a94981405b34",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"f5b302ad-85ff-4323-bb46-3bda4dd996dd\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Query: How is COVID-19 transmitted\\n\",\n      \"94.83\\tEbola is transmitted via direct contact with blood\\n\",\n      \"92.86\\tHIV is transmitted via sex or sharing needles\\n\",\n      \"92.31\\tCorona is transmitted via the air\\n\",\n      \"\\n\",\n      \"Query: what is the latest named variant of corona\\n\",\n      \"92.94\\tpeople keep drinking corona lager\\n\",\n      \"91.91\\tthe omicron variant of SARS-CoV-2\\n\",\n      \"85.27\\tCOVID-19 is an illness caused by a virus\\n\",\n      \"\\n\",\n      \"Query: What are the symptoms of COVID-19?\\n\",\n      \"90.34\\tcommon Flu symptoms include a high temperature, aching body, and fatigue\\n\",\n      \"89.36\\tcommon Corona symptoms include a high temperature, cough, and fatigue\\n\",\n      \"87.82\\tsymptoms are a physical or mental features which indicate a condition of disease\\n\",\n      \"\\n\",\n      \"Query: How will most people identify that they have contracted the coronavirus\\n\",\n      \"91.03\\tafter drinking too many bottles of corona beer most people are hungover\\n\",\n      \"86.55\\tthe most common COVID-19 symptoms include a high temperature, cough, and fatigue\\n\",\n      \"84.72\\tcommon symptoms of flu include a high temperature, aching, and exhaustion\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"show_examples(old_model)\\n\",\n    \"# we can see similarity scores below\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"environment\": {\n   \"kernel\": \"python3\",\n   \"name\": \"common-cu110.m91\",\n   \"type\": \"gcloud\",\n   \"uri\": \"gcr.io/deeplearning-platform-release/base-cu110:m91\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": 
\"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
}
```
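The `show_examples` helper is not defined in this chapter. As a rough idea of how a listing like the one above could be produced, here is a minimal sketch that scores every passage against a query with a dot product and sorts the results. The `rank_passages` name, the `encode_fn` parameter, and the toy bag-of-words encoder are all assumptions made for illustration; with a real bi-encoder, `encode_fn` would wrap the model's encoding call instead.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def rank_passages(
    query: str,
    passages: List[str],
    encode_fn: Callable[[List[str]], List[Vector]],
) -> List[Tuple[float, str]]:
    """Score each passage against the query and sort by descending score."""
    query_vec = encode_fn([query])[0]
    passage_vecs = encode_fn(passages)
    scored = [
        # dot product between the query vector and each passage vector
        (sum(q * p for q, p in zip(query_vec, vec)), passage)
        for vec, passage in zip(passage_vecs, passages)
    ]
    return sorted(scored, key=lambda item: item[0], reverse=True)


# toy encoder (hypothetical): counts occurrences of a few hand-picked words
VOCAB = ["covid", "transmitted", "beer"]


def toy_encode(texts: List[str]) -> List[List[float]]:
    return [[float(t.lower().split().count(w)) for w in VOCAB] for t in texts]


results = rank_passages(
    "how is covid transmitted",
    ["corona beer is a beer", "covid is transmitted via the air"],
    toy_encode,
)
for score, passage in results:
    print(f"{score:.2f}\t{passage}")
```

The toy encoder only "knows" the words in `VOCAB`, which loosely mirrors the problem above: a model can only rank well on terms it has learned a useful representation for.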

GPL aims to solve this problem by allowing us to take existing models and _adapt_ them to new domains using nothing more than unlabeled data. Using unlabeled data makes it far easier to find relevant data: all we need is unstructured text.

---

Just looking for a fast implementation of GPL? Skip ahead to [“A Simpler Approach”](https://www.pinecone.io/learn/series/nlp/gpl/#A-Simpler-Approach).

---

As you may have guessed, the same applies to the first scenario of _fine-tuning a pretrained model_. It can be hard to find relevant, labeled data. With GPL we don’t need to. Unstructured text is all you need.

### How it Works

At a high level, GPL consists of _three_ data preparation steps and _one_ fine-tuning step. We will begin by looking at the data preparation portion of GPL. The three steps are:

- Query generation, creating queries from passages.
- Negative mining, retrieving similar passages that do not match (negatives).
- Pseudo labeling, using a cross-encoder model to assign similarity scores to pairs.

![Overview of the GPL process. Beginning with passages P+, we generate queries Q. The passages are indexed and a dense retrieval step is used to find high similarity ‘negative’ passages P-. We then use a cross encoder to produce margin scores.](https://cdn.sanity.io/images/vr8gru94/production/419064e30468f6a4b29afa215bdd7da6e90f48a1-1920x1680.png)


Each of these steps requires the use of a pre-existing model fine-tuned for each task. The team that introduced GPL also provided models that handle each task. We will discuss these models as we introduce each step and note alternative models where relevant.

#### 1. Query Generation

GPL is perfect for scenarios where we have no labeled data. However, it does require a _large amount_ of unstructured text. That could be text data scraped from web pages, PDF documents, etc. The only requirement is that this text data is _in-domain_, meaning it is relevant to our particular use case.

![For a target domain of German financial documents, any data that fits the topic and we would expect our model to encounter is in-domain. Anything else is out-of-domain.](https://cdn.sanity.io/images/vr8gru94/production/65fb375f9316180617c42d07029ed835e8b2213e-1920x1080.png)


In our examples, we will use the _CORD-19_ dataset. CORD-19 can be downloaded using the script [found here](https://gist.github.com/jamescalam/06882ea05a307420c1354481325f08ad#file-00_download_cord_19-ipynb). The script will leave many JSON files in a directory called _document_parses/pdf_json_ that we will be using in our query generation step. We will use a generator function called `get_text` to read in those files.

```json
{
  "_key": "72610cc4fa83",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from pathlib import Path\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"paths = [str(path) for path in Path('document_parses/pdf_json').glob('*.json')]\\n\",\n    \"\\n\",\n    \"def get_text():\\n\",\n    \"    for path in paths:\\n\",\n    \"        with open(path, 'r') as fp:\\n\",\n    \"            doc = json.load(fp)\\n\",\n    \"        # extract the passages of text from each document\\n\",\n    \"        body_text = [line['text'] for line in doc['body_text']]\\n\",\n    \"        # loop through and yield one passage at a time\\n\",\n    \"        for passage in body_text:\\n\",\n    \"            yield passage\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

If you’ve read our previous [chapter on GenQ](https://www.pinecone.io/learn/series/nlp/genq/), this step follows the same query generation process. However, we will outline the process for new readers.

We’re starting with passages of data (the unstructured text data). Generally, these are reasonably long chunks of text, but not always.

```json
{
  "_key": "595fdaf0fb20",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Until recently, seven types of coronaviruses had been reported to cause infections in humans...\\n\",\n      \"An 80-year-old male with a medical history of diabetes, hypertension, dyslipidemia, asthma, coronary...\\n\",\n      \"COVID-19 is a pandemic illness that primarily affects the respiratory system with a wide spectrum of...\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"passages = get_text()\\n\",\n    \"\\n\",\n    \"for i, passage in enumerate(passages):\\n\",\n    \"    print(passage)\\n\",\n    \"    if i == 2: break\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

Given a passage, we pass it through a _query generation_ T5 model. We can initialize this T5 model using _HuggingFace Transformers_.

---

_T5 refers to Google’s Text-to-Text Transfer Transformer. We discuss it in more detail in_ _[the chapter on GenQ](https://www.pinecone.io/learn/series/nlp/genq/)._

---

```json
{
  "_key": "6abef2da0e18",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\\n\",\n    \"\\n\",\n    \"model_name = 'doc2query/msmarco-t5-base-v1'\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\\n\",\n    \"model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We are using the `doc2query/msmarco-t5-base-v1` model that was trained on a pre-COVID dataset. Nonetheless, when generating queries for COVID-related text the model can produce sensible questions by copying words from the passage text.

With this T5 model, we can begin generating queries that we use to produce synthetic _(query, passage) pairs_.

```json
{
  "_key": "068b8877d1bb",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for passage in passages:\\n\",\n    \"    break  # just pull a single example\\n\",\n    \"\\n\",\n    \"# tokenize the passage\\n\",\n    \"inputs = tokenizer(passage, return_tensors='pt')\\n\",\n    \"# generate three queries\\n\",\n    \"outputs = model.generate(\\n\",\n    \"    input_ids=inputs['input_ids'].cuda(),\\n\",\n    \"    attention_mask=inputs['attention_mask'].cuda(),\\n\",\n    \"    max_length=64,\\n\",\n    \"    do_sample=True,\\n\",\n    \"    top_p=0.95,\\n\",\n    \"    num_return_sequences=3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Paragraph:\\n\",\n      \"The surgeons of COVID-19 dedicated hospitals do rarely practice surgery. When ICU patients need mechanical ventilation, percutaneous tracheostomy under endoscopic control is mostly performed...\\n\",\n      \"\\n\",\n      \"Generated Queries:\\n\",\n      \"1: why percutaneous tracheostomy\\n\",\n      \"2: what is percutaneous tracheostomy under endoscopic control\\n\",\n      \"3: what is percutaneous tracheostomy\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(\\\"Paragraph:\\\")\\n\",\n    \"print(passage)\\n\",\n    \"\\n\",\n    \"print(\\\"\\\\nGenerated Queries:\\\")\\n\",\n    \"for i in range(len(outputs)):\\n\",\n    \"    query = tokenizer.decode(outputs[i], skip_special_tokens=True)\\n\",\n    \"    print(f'{i + 1}: {query}')\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   
\"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

Query generation is _not perfect_. It can generate noisy, sometimes nonsensical queries, and this is where GPL improves upon GenQ. GenQ relies heavily on these synthetic queries being high-quality with little noise. With GPL, this is not the case, as the final cross-encoder step labels the similarity of each pair, meaning dissimilar pairs are likely to be labeled as such. GenQ does not have any such labeling step.

We now have _(query, passage) pairs_ and can move onto the next step of identifying _negative_ passages.
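To make the shape of this intermediate data concrete, here is a minimal sketch of turning generated queries into deduplicated _(query, passage)_ pairs and tab-separated lines. The `build_pairs`, `clean`, and `to_tsv` names and the `generate_fn(passage, n)` callable are hypothetical; in practice `generate_fn` would wrap the tokenizer, `model.generate`, and decoding calls shown above.

```python
from typing import Callable, Iterable, List, Tuple


def clean(text: str) -> str:
    # collapse tabs and newlines so each pair fits on a single TSV line
    return " ".join(text.split())


def build_pairs(
    passages: Iterable[str],
    generate_fn: Callable[[str, int], List[str]],
    n_queries: int = 3,
) -> List[Tuple[str, str]]:
    """Create (query, passage) pairs, skipping empty and repeated queries."""
    pairs = []
    for passage in passages:
        passage = clean(passage)
        seen = set()
        for query in generate_fn(passage, n_queries):
            query = clean(query)
            # sampling can return the same query more than once
            if query and query not in seen:
                seen.add(query)
                pairs.append((query, passage))
    return pairs


def to_tsv(pairs: List[Tuple[str, str]]) -> str:
    # one "query<TAB>passage" record per line
    return "\n".join(f"{query}\t{passage}" for query, passage in pairs)
```

Stripping tabs and newlines before writing matters because later steps split each line on the tab character; a stray tab inside a passage would shift the columns.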

[Full script](https://gist.github.com/jamescalam/8498a03b66b0317d21575c0a0ac50f66#file-01_query_gen-ipynb)

#### 2. Negative Mining

The (query, passage) pairs we have now are assumed to be positively similar, written as _(Q, P+)_ where the query is _Q_, and the positive passage is _P+_.

Suppose we fine-tune our bi-encoder on only these positive matches. In that case, our model will struggle to learn more nuanced differences. A good model must learn to distinguish between similar and dissimilar pairs even where the content of these different pairs is very similar.

To fix this, we perform a _negative mining_ step to find passages that are highly similar to existing _P+_ passages. As these new passages will be highly similar but _not_ matches to our query _Q_, our model will need to learn how to distinguish them from genuine matches _P+_. We refer to these non-matches as _negative passages_, written as _P-_.

The negative mining process is a retrieval step where, given a query, we return the _top_k_ most similar results. Excluding the positive passage (if returned), we assume all other returned passages are negatives. We then select one of these _negative passages_ at random to become the negative pair for our query.

It may seem counterintuitive at first. Why would we return the most similar passages and train a model to view these as dissimilar?

Yes, those returned results are the most similar passages to our query, but they are _not_ the correct passage for our query. We are, in essence, increasing the similarity gap between _the_ correct passage and all other passages, no matter how similar they may be.
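The retrieve, exclude, and sample logic described above can be sketched in a few lines. This is a toy version under stated assumptions: a brute-force dot-product search stands in for the vector database, passages are compared by ID rather than by text, and the `top_k_ids` and `mine_negative` names are hypothetical.

```python
import random
from typing import Dict, List, Optional, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def top_k_ids(query_vec: Vector, passage_vecs: Dict[str, Vector], k: int) -> List[str]:
    # brute-force search standing in for a vector database query
    ranked = sorted(
        passage_vecs,
        key=lambda pid: dot(query_vec, passage_vecs[pid]),
        reverse=True,
    )
    return ranked[:k]


def mine_negative(
    query_vec: Vector,
    positive_id: str,
    passage_vecs: Dict[str, Vector],
    k: int = 10,
    rng: Optional[random.Random] = None,
) -> Optional[str]:
    """Pick one random passage from the top-k results that is not the positive."""
    rng = rng or random.Random()
    candidates = [
        pid for pid in top_k_ids(query_vec, passage_vecs, k) if pid != positive_id
    ]
    # None signals that no negative could be found for this query
    return rng.choice(candidates) if candidates else None


# toy 2-dimensional embeddings (made up for illustration)
vecs = {"p0": [1.0, 0.0], "p1": [0.9, 0.1], "p2": [0.8, 0.3], "p3": [-1.0, 0.0]}
negative = mine_negative([1.0, 0.2], positive_id="p0", passage_vecs=vecs, k=3)
print(negative)  # one of the near neighbours of the query, never "p0"
```

Sampling at random from the top results, rather than always taking the single closest passage, gives some variety in how hard the negatives are.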

Adding these _‘negative’_ training examples _(Q, P-)_ is a common approach used in many bi-encoder fine-tuning methods, including multiple negatives ranking _and_ margin MSE loss (the latter of which we will be using). Using hard negatives in particular can significantly improve the performance of our models [3].

![The impact on model performance trained on MSMARCO with and without hard negatives. Model training used margin MSE loss. Adapted from [3].](https://cdn.sanity.io/images/vr8gru94/production/cf65a54c973a99f626a25dfa77ed30f7c4a130f9-1744x1059.png)


When we later tune our model to identify the difference between these positive and negative passages, we are teaching it to determine what are often very nuanced differences.

With all of that in mind, we do need to understand that not every returned passage will be a true negative; some may in fact be relevant to the query. We will explain how that is handled in the pseudo-labeling step later.

Let's move on to the implementation of negative mining. As before, we need an existing model to embed our passages and create searchable [dense vectors](https://www.pinecone.io/learn/vector-database/). We use the `msmarco-distilbert-base-tas-b` bi-encoder, which was fine-tuned on pre-COVID datasets.

```json
{
  "_key": "e234c249681e",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"SentenceTransformer(\\n\",\n       \"  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \\n\",\n       \"  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\\n\",\n       \")\"\n      ]\n     },\n     \"execution_count\": 1,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from sentence_transformers import SentenceTransformer\\n\",\n    \"\\n\",\n    \"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\\n\",\n    \"model.max_seq_length = 256\\n\",\n    \"model\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

In the GPL paper, two retrieval models are used and their results compared. To keep things simple, we will stick with a single model.

We need a [vector database](https://www.pinecone.io/learn/vector-database/) to store the passage embeddings. We will use Pinecone as an incredibly easy-to-use service that can scale to the millions of passage embeddings we’d like to search.

```json
{
  "_key": "748e975b8c58",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pinecone\\n\",\n    \"\\n\",\n    \"with open('secret', 'r') as fp:\\n\",\n    \"    API_KEY = fp.read()  # get api key app.pinecone.io\\n\",\n    \"\\n\",\n    \"pinecone.init(\\n\",\n    \"    api_key=API_KEY,\\n\",\n    \"    environment='YOUR_ENV'  # find next to API key in console\\n\",\n    \")\\n\",\n    \"# create a new genq index if does not already exist\\n\",\n    \"if 'negative-mine' not in pinecone.list_indexes():\\n\",\n    \"    pinecone.create_index(\\n\",\n    \"        'negative-mine',\\n\",\n    \"        dimension=model.get_sentence_embedding_dimension(),\\n\",\n    \"        metric='dotproduct',\\n\",\n    \"        pods=1  # increase for faster mining\\n\",\n    \"    )\\n\",\n    \"# connect\\n\",\n    \"index = pinecone.Index('negative-mine')\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We encode our passages, assign unique IDs to each, and then upload the record to Pinecone. As we later need to match these returned vectors back to their original plaintext format, we will create an ID-to-passage mapping to be stored locally.

```json
{
  "_key": "bfe45e1f8124",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"f123d57309b042eca4ce279ec0aff06e\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"  100%|██████████| 200/200 [44:31<00:00, 0.07it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'dimension': 768,\\n\",\n       \" 'index_fullness': 0.0,\\n\",\n       \" 'namespaces': {'': {'vector_count': 67840}}}\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"pairs_gen = get_text()  # generator that loads (query, passage) pairs\\n\",\n    \"\\n\",\n    \"pairs = []\\n\",\n    \"to_upsert = []\\n\",\n    \"passage_batch = []\\n\",\n    \"id_batch = []\\n\",\n    \"batch_size = 64  # encode and upload size\\n\",\n    \"\\n\",\n    \"for i, (query, passage) in enumerate(pairs_gen):\\n\",\n    \"    pairs.append((query, passage))\\n\",\n    \"    # we do this to avoid passage duplication in the vector DB\\n\",\n    \"    if passage not in passage_batch: \\n\",\n    \"        passage_batch.append(passage)\\n\",\n    \"        id_batch.append(str(i))\\n\",\n    \"    # on reaching batch_size, we encode and upsert\\n\",\n    \"    if len(passage_batch) == batch_size:\\n\",\n    \"        embeds = model.encode(passage_batch).tolist()\\n\",\n    \"        # upload to index\\n\",\n    \"        index.upsert(vectors=list(zip(id_batch, embeds)))\\n\",\n    \"        # refresh batches\\n\",\n    \"        passage_batch = []\\n\",\n    \"        id_batch = []\\n\",\n    \"        \\n\",\n    \"# check number of vectors in the index\\n\",\n    \"index.describe_index_stats()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```
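The loop as shown only checks for duplicate passages within the current batch, and it only uploads when a batch is completely full, so a trailing partial batch would be skipped. A slightly more defensive variant, sketched below, deduplicates across the whole run, flushes the final batch, and returns the ID-to-passage mapping mentioned earlier. The `index_passages` name and the `encode_fn` and `upsert_fn` callables are assumptions for illustration; with the setup above they could be something like `lambda texts: model.encode(texts).tolist()` and `lambda vectors: index.upsert(vectors=vectors)`.

```python
from typing import Callable, Dict, Iterable, List, Tuple

Record = Tuple[str, List[float]]


def index_passages(
    pairs: Iterable[Tuple[str, str]],
    encode_fn: Callable[[List[str]], List[List[float]]],
    upsert_fn: Callable[[List[Record]], None],
    batch_size: int = 64,
) -> Dict[str, str]:
    """Encode and upload each unique passage once; return an ID-to-passage map."""
    id_to_passage: Dict[str, str] = {}
    seen = set()
    id_batch: List[str] = []
    passage_batch: List[str] = []

    def flush() -> None:
        # encode and upload whatever is currently buffered
        if passage_batch:
            embeds = encode_fn(passage_batch)
            upsert_fn(list(zip(id_batch, embeds)))
            id_batch.clear()
            passage_batch.clear()

    for i, (_query, passage) in enumerate(pairs):
        if passage in seen:
            continue  # passage already queued or uploaded earlier
        seen.add(passage)
        # the ID is the pair's position, so pairs[int(id)][1] recovers the text
        id_to_passage[str(i)] = passage
        id_batch.append(str(i))
        passage_batch.append(passage)
        if len(passage_batch) == batch_size:
            flush()
    flush()  # upload the final, partially filled batch
    return id_to_passage
```

The returned dictionary can be written to disk (for example with `json.dump`) so that IDs returned by the vector database can be mapped back to plaintext passages later.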

[Full version here](https://gist.github.com/jamescalam/9b84408d6c7f1fe4bf7eda2ab410c086#file-02_negative_mining-ipynb).



The vector database is set up for us to begin _negative mining_. We loop through each query, returning _10_ of the most similar passages by setting `top_k=10`.

```json
{
  "_key": "0b3754de5cff",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"33e05700fb8a43daadf39f5c2f2166d5\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"  0%|          | 0/2000 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"import random\\n\",\n    \"\\n\",\n    \"batch_size = 100\\n\",\n    \"triplets = []\\n\",\n    \"\\n\",\n    \"for i in tqdm(range(0, len(pairs), batch_size)):\\n\",\n    \"    # embed queries and query pinecone in batches to minimize network latency\\n\",\n    \"    i_end = min(i+batch_size, len(pairs))\\n\",\n    \"    queries = [pair[0] for pair in pairs[i:i_end]]\\n\",\n    \"    pos_passages = [pair[1] for pair in pairs[i:i_end]]\\n\",\n    \"    # create query embeddings\\n\",\n    \"    query_embs = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)\\n\",\n    \"    # search for top_k most similar passages\\n\",\n    \"    res = index.query(query_embs.tolist(), top_k=10)\\n\",\n    \"    # iterate through queries and find negatives\\n\",\n    \"    for query, pos_passage, query_res in zip(queries, pos_passages, res['results']):\\n\",\n    \"        top_results = query_res['matches']\\n\",\n    \"        # shuffle results so they are in random order\\n\",\n    \"        random.shuffle(top_results)\\n\",\n    \"        for hit in top_results:\\n\",\n    \"            neg_passage = pairs[int(hit['id'])][1]\\n\",\n    \"            # check that we're not just returning the positive passage\\n\",\n    \"            if neg_passage != pos_passage:\\n\",\n    \"                # if not we can add this to our (Q, P+, P-) triplets\\n\",\n    \"                
triplets.append(query+'\\\\t'+pos_passage+'\\\\t'+neg_passage)\\n\",\n    \"                break\\n\",\n    \"\\n\",\n    \"with open('data/triplets.tsv', 'w', encoding='utf-8') as fp:\\n\",\n    \"    fp.write('\\\\n'.join(triplets))  # save training data to file\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# delete the index when done to avoid higher charges (if using multiple pods)\\n\",\n    \"pinecone.delete_index('negative-mine')  # when pods == 1, no charges\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We then loop through each set of queries, _P+_ passages, and their negatively mined results. Next, we shuffle those results and take the first that does _not_ match our _P+_ passage; this becomes the _P-_ passage. We write each record to file in the format _(Q, P+, P-)_, ready for the next step.
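
The selection logic can be isolated from the vector search itself as a small helper. This is a minimal sketch: `pick_negative` and the toy passages are hypothetical names used for illustration and are not part of the original notebook.

```python
import random


def pick_negative(candidates, pos_passage, rng=random):
    """Return one mined passage that differs from the positive, or None."""
    shuffled = list(candidates)  # copy so the caller's list is untouched
    rng.shuffle(shuffled)        # random order, as in the mining loop
    for passage in shuffled:
        if passage != pos_passage:
            return passage
    return None  # every candidate was the positive passage


# toy example (hypothetical passages)
positive = "passage about plasmid mutations"
mined = [positive, "passage about vaccines", "passage about masks"]
negative = pick_negative(mined, positive, rng=random.Random(0))
print(negative)
```

Returning `None` when no usable negative exists makes it easy to skip a query instead of writing an incomplete triplet.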

[Full script](https://gist.github.com/jamescalam/9b84408d6c7f1fe4bf7eda2ab410c086#file-02_negative_mining-ipynb)

#### 3. Pseudo-labeling

Pseudo-labeling is the final step in _preparing_ our training data. In this step, we use a cross-encoder model to generate similarity scores for both positive and negative pairs.

$$
sim(Q, P^+) \quad \text{and} \quad sim(Q, P^-)
$$

```json
{
  "_key": "971ed97642ed",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder at 0x16e35356ee0>\"\n      ]\n     },\n     \"execution_count\": 1,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from sentence_transformers import CrossEncoder\\n\",\n    \"\\n\",\n    \"# initialize the cross encoder model first\\n\",\n    \"model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')\\n\",\n    \"model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"label_lines = []\\n\",\n    \"\\n\",\n    \"# triplets is list of Q, P+, and P- tuples\\n\",\n    \"for line in triplets:\\n\",\n    \"    q, p, n = line\\n\",\n    \"    # predict (Q, P+) and (Q, P-) scores\\n\",\n    \"    p_score = model.predict((q, p))\\n\",\n    \"    n_score = model.predict((q, n))\\n\",\n    \"    # calculate the margin score\\n\",\n    \"    margin = p_score - n_score\",\n    \"    # append pairs to label_lines with margin score\\n\",\n    \"    label_lines.append(\\n\",\n    \"        q + '\\\\t' + p + '\\\\t' + n + '\\\\t' + str(margin)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"with open(\\\"data/triplets_margin.tsv\\\", 'w', encoding='utf-8') as fp:\\n\",\n    \"    fp.write('\\\\n'.join(label_lines))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   
\"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

Given the _positive_ and _negative_ query-passage similarity scores (GPL uses dot-product similarity), we take the difference between the two to give the _margin_.

$$
margin = sim(Q, P^+) - sim(Q, P^-)
$$

We need this margin because the bi-encoder is trained with _margin MSE loss_, which uses the margin score as its training target. After generating these scores, our final data format contains the query, both passages, and the margin score.

$$
(Q, P^+, P^-, margin)
$$
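
As a rough sketch of how one labeled row is built, the snippet below computes the margin and serializes the record as a tab-separated line. The `score` function is a hypothetical stand-in for the cross-encoder (a real run would call the cross-encoder's `predict` method on each query-passage pair), and its word-overlap scoring exists only to keep the example self-contained.

```python
def score(query, passage):
    """Hypothetical stand-in for a cross-encoder: counts shared words."""
    shared = set(query.lower().split()) & set(passage.lower().split())
    return float(len(shared))


def label_triplet(query, pos, neg, scorer=score):
    """Return a (Q, P+, P-, margin) row as a tab-separated string."""
    margin = scorer(query, pos) - scorer(query, neg)  # sim(Q, P+) - sim(Q, P-)
    return "\t".join([query, pos, neg, str(margin)])


row = label_triplet(
    "how is covid transmitted",
    "covid is transmitted via the air",
    "people keep drinking corona lager",
)
q, p, n, margin = row.split("\t")
print(float(margin))
```

Passing the scorer as an argument keeps the row-building logic independent of whichever cross-encoder is used.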

This final pseudo-labeling step is very important in ensuring we have high-quality training data. Without it, we would need to assume that all passages returned in the negative mining step are irrelevant to our query and that they are all equally dissimilar when contrasted against our positive passages.

In reality, this is never the case. Some negative passages are more relevant than others. The authors of GPL split these negative passages into three categories [2].

![Three categories of negative passages. Whereas previous methods like GenQ that lack the pseudo-labeling step would view passages as either positive 1 or negative 0, GPL can score passages on a more meaningful scale.](https://cdn.sanity.io/images/vr8gru94/production/5ac1dd39c23524bfc1d5964bccadb2147abd4747-2520x1516.png)


We are likely to return a mix of negative passages, from highly relevant to completely irrelevant. Pseudo-labeling allows us to score passages accordingly. Above we can see three negative categories:

- **False negatives**: we haven’t returned the _exact_ match to our positive passage, but that does not mean we will not return relevant passages (that are in fact _not_ negatives). In this case, our cross-encoder will label the passage as relevant; without a cross-encoder, it would be marked as irrelevant.
- **Easy negatives**: these are passages that are loosely connected to the query (such as containing matching keywords) but are _not_ relevant. The cross-encoder should mark these as having low relevance.
- **Hard negatives**: in this case the passages may be tightly connected to the query, or even contain a partial answer, but still not answer the query. Our cross-encoder should mark these as being more relevant than _easy negatives_ but less so than any positive or false negative passages.
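
To make the three categories concrete, the sketch below computes the margin for one negative of each kind. The scores are hypothetical values chosen purely for illustration, not real cross-encoder outputs.

```python
# hypothetical cross-encoder scores, chosen only for illustration
pos_score = 9.0
neg_scores = {
    "false negative": 8.5,  # actually relevant, scores close to the positive
    "hard negative": 3.0,   # related to the query but does not answer it
    "easy negative": -6.0,  # only loosely connected to the query
}

# margin = sim(Q, P+) - sim(Q, P-)
margins = {name: pos_score - s for name, s in neg_scores.items()}
for name, m in margins.items():
    print(f"{name}: margin = {m}")
```

The more relevant a negative passage is, the smaller its margin, so the bi-encoder is not pushed to separate it as strongly from the positive passage.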

Now that we have our fully prepared data, we can move on to the training portion of GPL.

[Full script](https://gist.github.com/jamescalam/f9609f0e47937c3545364cfbef3ea1b8#file-03_ce_scoring-ipynb)

## Training with Margin MSE

The fine-tuning/training portion of GPL is not anything unique or new. It is a tried and tested bi-encoder training process that optimizes with _margin MSE loss_.

![](https://cdn.sanity.io/images/vr8gru94/production/dd2a4a8740e8e3acbbe30b1868beaad0162a70f0-534x138.png)


We are looking at the sum of squared errors between the predicted margin _𝛿^i_ and the true margin _𝛿i_ for _all samples_ in the training set (from _i=0_ to _i=M-1_). We make it a _mean_ squared error by dividing the summed error by the number of samples in the training set _M_.

Looking back at the generated training data, we have the format _(Q, P+, P-, margin)_. How do these fit into the _margin MSE loss_ function above?

![High-level view of (Q, P+, P-) triplets and how they fit into the margin MSE loss function.](https://cdn.sanity.io/images/vr8gru94/production/6c860ae8f87064a01c043a1230712da67a85aef4-2379x1323.png)


The bi-encoder model creates embeddings for the query _Q_, positive passage _P+_, and negative passage _P-_. We then calculate the dot-product similarity between embeddings for both _sim(Q, P+)_ and _sim(Q, P-)_. These give us the predicted margin:

![](https://cdn.sanity.io/images/vr8gru94/production/2990777398a1cc6ca5dc9eedae11584976c34837-452x73.png)


The true margin _𝛿i_ has already been calculated by our cross-encoder; it is simply _𝛿i_ _= margin_.
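
The calculation described above can be sketched in plain Python. This is an illustration of the arithmetic only, not the _sentence-transformers_ implementation; the function names and toy embeddings are hypothetical.

```python
def dot(u, v):
    """Dot-product similarity between two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def margin_mse(batch):
    """Mean squared error between predicted and true margins.

    Each item is (q_emb, pos_emb, neg_emb, true_margin).
    """
    errors = []
    for q, p, n, true_margin in batch:
        pred_margin = dot(q, p) - dot(q, n)  # sim(Q, P+) - sim(Q, P-)
        errors.append((pred_margin - true_margin) ** 2)
    return sum(errors) / len(errors)


# toy embeddings and margins (hypothetical values)
batch = [
    ([1.0, 0.0], [2.0, 1.0], [0.5, 3.0], 1.5),  # predicted 2.0 - 0.5 = 1.5
    ([0.0, 1.0], [1.0, 3.0], [4.0, 1.0], 4.0),  # predicted 3.0 - 1.0 = 2.0
]
loss = margin_mse(batch)
print(loss)
```

The first sample contributes no error because the bi-encoder margin already matches the cross-encoder margin; the second contributes a squared error of 4.0, giving a mean of 2.0.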

We can use the default _sentence-transformers_ methods for fine-tuning models with margin MSE loss. We begin by loading our labeled triplets into a list of `InputExample` objects.

```json
{
  "_key": "53ac87674064",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"6000\"\n      ]\n     },\n     \"execution_count\": 2,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from sentence_transformers import InputExample\\n\",\n    \"\\n\",\n    \"training_data = []\\n\",\n    \"\\n\",\n    \"for line in label_lines:\\n\",\n    \"    q, p, n, margin = line.split('\\\\t')\\n\",\n    \"    training_data.append(InputExample(\\n\",\n    \"        texts=[q, p, n],\\n\",\n    \"        label=float(margin)\\n\",\n    \"    ))\\n\",\n    \"\\n\",\n    \"len(training_data)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We can see the contents of one of our `InputExample` objects:

```json
{
  "_key": "6584b526173a",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Query: can mutation be introduced to plasmid\\n\",\n      \"Passage +: Mutations were then introduced into the replicon plasmid as described above....\\n\",\n      \"Passage -: It is worth noting that PfSWIB may lead to stage-specific PCD, in the same way as BAF60a of the mamm...\\n\",\n      \"Margin: 17.084930419921875\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(f\\\"\\\"\\\"\\n\",\n    \"Query: {training_data[0].texts[0]}\\n\",\n    \"Passage +: {training_data[0].texts[1][:100]}...\\n\",\n    \"Passage -: {training_data[0].texts[2][:100]}...\\n\",\n    \"Margin: {training_data[0].label}\\n\",\n    \"\\\"\\\"\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We use a generic PyTorch `DataLoader` to load the data into the model during training. One crucial detail is that margin MSE loss works best with large batch sizes. A batch size of _32_ or even _64_ is a good target, but this does require significant GPU memory and may not be feasible. If that is the case, reduce the batch size until it fits within your hardware constraints.

```json
{
  "_key": "155115091bed",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"torch.cuda.empty_cache()  # clear GPU\\n\",\n    \"\\n\",\n    \"batch_size = 32\\n\",\n    \"\\n\",\n    \"loader = torch.utils.data.DataLoader(\\n\",\n    \"    training_data, batch_size=batch_size, shuffle=True\\n\",\n    \")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

Next, we initialize a bi-encoder model using the pre-COVID DistilBERT bi-encoder that we used in the negative mining step. It is this model that we are adapting to better understand COVID-19 related language.

```json
{
  "_key": "cd21948583cb",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"SentenceTransformer(\\n\",\n       \"  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \\n\",\n       \"  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\\n\",\n       \")\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"from sentence_transformers import SentenceTransformer\\n\",\n    \"\\n\",\n    \"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\\n\",\n    \"model.max_seq_length = 256\\n\",\n    \"model\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

We’re ready to initialize the margin MSE loss that will optimize the model later.

```json
{
  "_key": "048042dca992",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sentence_transformers import losses\\n\",\n    \"\\n\",\n    \"loss = losses.MarginMSELoss(model)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

With that, we’re finally ready to begin fine-tuning our model. We used a single epoch; with a training set of 600K samples, this is already a large number of steps. It was found that GPL performance tends to stop improving after around 100K steps [2]. However, this will vary by dataset.
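
To put those numbers in perspective, here is a quick back-of-the-envelope calculation. It assumes the 600K samples mentioned above and a batch size of 32; `steps_per_epoch` is a hypothetical helper name.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of optimizer steps needed to see every sample once."""
    return math.ceil(num_samples / batch_size)


steps = steps_per_epoch(600_000, 32)
epochs_to_100k = 100_000 / steps
print(steps, round(epochs_to_100k, 2))
```

With these assumed numbers, one epoch is 18,750 steps, so the roughly 100K steps reported in [2] would correspond to a little over five epochs.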

![NDCG@10% performance for zero-shot (not adapted), GPL fine-tuned, and GPL fine-tuned + TSDAE pre-trained models. GPL fine-tuning using a model that had previously been pre-trained using TSDAE demonstrates consistently better performance. Model performance seems to level out after 100K training steps. Visual adapted from [2].](https://cdn.sanity.io/images/vr8gru94/production/32dca390f3667aad12274e4d8f9baae6a205b54b-1744x1139.png)
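
The training call itself goes through the bi-encoder's `fit` method. Below is a minimal sketch, assuming the classic `SentenceTransformer.fit` interface from _sentence-transformers_. The `fine_tune` wrapper and the 10% warmup are assumptions made for illustration rather than settings taken from the original notebook.

```python
def fine_tune(model, loader, loss, output_path, epochs=1):
    """Run bi-encoder fine-tuning through the classic `fit` interface.

    `model` is expected to be a SentenceTransformer, `loader` a DataLoader
    of InputExample objects, and `loss` a MarginMSELoss instance.
    """
    # warm up the learning rate over 10% of all steps (an assumption)
    warmup_steps = (len(loader) * epochs) // 10
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        output_path=output_path,
        show_progress_bar=True,
    )
    return warmup_steps
```

With the objects created earlier, this would be called as `fine_tune(model, loader, loss, 'msmarco-distilbert-base-tas-b-covid')`, which also saves the model files to that directory.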



Once training is complete, we will find all of our model files in the _msmarco-distilbert-base-tas-b-covid_ directory. To use our model in the future, we simply load it from the same directory using _sentence-transformers_.

```json
{
  "_key": "104e43309663",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sentence_transformers import SentenceTransformer\\n\",\n    \"\\n\",\n    \"model = SentenceTransformer('msmarco-distilbert-base-tas-b-covid')\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"interpreter\": {\n   \"hash\": \"5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.8 ('ml')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}"
}
```

If you’d like to use the model trained in this article, you can specify the model name `pinecone/msmarco-distilbert-base-tas-b-covid`.

Now let’s return to the COVID-19 queries we asked the initial model (without GPL adaptation).

```json
{
  "_key": "1727b97a10a0",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"id\": \"f5aa1f5b-9d4b-4401-b7aa-5d0db9bc137f\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Query: How is COVID-19 transmitted\\n\",\n      \"101.72\\tCorona is transmitted via the air\\n\",\n      \"101.57\\tEbola is transmitted via direct contact with blood\\n\",\n      \"100.58\\tHIV is transmitted via sex or sharing needles\\n\",\n      \"\\n\",\n      \"Query: what is the latest named variant of corona\\n\",\n      \"99.32\\tpeople keep drinking corona lager\\n\",\n      \"98.77\\tthe omicron variant of SARS-CoV-2\\n\",\n      \"90.97\\tCOVID-19 is an illness caused by a virus\\n\",\n      \"\\n\",\n      \"Query: What are the symptoms of COVID-19?\\n\",\n      \"99.37\\tcommon Corona symptoms include a high temperature, cough, and fatigue\\n\",\n      \"98.75\\tcommon Flu symptoms include a high temperature, aching body, and fatigue\\n\",\n      \"97.51\\tsymptoms are a physical or mental features which indicate a condition of disease\\n\",\n      \"\\n\",\n      \"Query: How will most people identify that they have contracted the coronavirus\\n\",\n      \"97.69\\tafter drinking too many bottles of corona beer most people are hungover\\n\",\n      \"97.35\\tthe most common COVID-19 symptoms include a high temperature, cough, and fatigue\\n\",\n      \"93.53\\tcommon symptoms of flu include a high temperature, aching, and exhaustion\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"show_examples(gpl_model)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"environment\": {\n   \"kernel\": \"python3\",\n   \"name\": \"common-cu110.m91\",\n   \"type\": \"gcloud\",\n   \"uri\": \"gcr.io/deeplearning-platform-release/base-cu110:m91\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  
\"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
}
```

As before, we are asking four questions, each of which has three possible passages. Our model is tasked with scoring the similarity of each passage, and the goal is to rank COVID-19 related sentences higher than any other sentences.

We can see that this has worked for two of our queries. For the two queries it performs worse on, it looks like our GPL-trained model is confusing the drink Corona with _“corona”_ in the context of COVID-19.

One thing we can try is fine-tuning our model for more epochs. If we try again with a model trained for 10 epochs, we get more promising results.

```json
{
  "_key": "7898990b108c",
  "_type": "colabBlock",
  "jsonContent": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"id\": \"e0aca7fc-7be4-4ff1-a07d-b9b283ce4895\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\n\",\n      \"Query: How is COVID-19 transmitted\\n\",\n      \"98.21\\tCorona is transmitted via the air\\n\",\n      \"96.15\\tEbola is transmitted via direct contact with blood\\n\",\n      \"94.77\\tHIV is transmitted via sex or sharing needles\\n\",\n      \"\\n\",\n      \"Query: what is the latest named variant of corona\\n\",\n      \"93.19\\tthe omicron variant of SARS-CoV-2\\n\",\n      \"93.10\\tpeople keep drinking corona lager\\n\",\n      \"86.33\\tCOVID-19 is an illness caused by a virus\\n\",\n      \"\\n\",\n      \"Query: What are the symptoms of COVID-19?\\n\",\n      \"95.02\\tcommon Corona symptoms include a high temperature, cough, and fatigue\\n\",\n      \"94.21\\tcommon Flu symptoms include a high temperature, aching body, and fatigue\\n\",\n      \"93.02\\tsymptoms are a physical or mental features which indicate a condition of disease\\n\",\n      \"\\n\",\n      \"Query: How will most people identify that they have contracted the coronavirus\\n\",\n      \"91.62\\tthe most common COVID-19 symptoms include a high temperature, cough, and fatigue\\n\",\n      \"91.36\\tafter drinking too many bottles of corona beer most people are hungover\\n\",\n      \"87.99\\tcommon symptoms of flu include a high temperature, aching, and exhaustion\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"show_examples(gpl_model10)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"environment\": {\n   \"kernel\": \"python3\",\n   \"name\": \"common-cu110.m91\",\n   \"type\": \"gcloud\",\n   \"uri\": \"gcr.io/deeplearning-platform-release/base-cu110:m91\"\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  
\"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}"
}
```

Now the COVID-19 passage ranks first for every query, so our model is better at differentiating between Corona beer and COVID-19, although the score gaps between the two remain small.

[Full script](https://gist.github.com/jamescalam/15d968bed79b884bf50090a34093f508#file-04_finetune-ipynb)

## A Simpler Approach

We’ve worked through a lot of theory and code to understand GPL, and hopefully, it is now much clearer. However, we don’t need to work through all of that to apply GPL. It is much easier when using the [official GPL library](https://github.com/UKPLab/gpl).

Doing the same as we did before requires little more than a few lines of code. To start, we first `pip install gpl`. Our input data must use the BeIR data format, a single JSON lines (`.jsonl`) file called _corpus.jsonl_. Each sample in the file will look like this:

```json
{
    "_id": "string ID value",
    "title": "string, can be empty",
    "text": "string, the primary content of the sample, like a paragraph of text",
    "metadata": {
        "key1": "metadata is a dictionary containing additional info",
        "key2": "key-value pairs do not need to be str-str",
        "for example": 1337,
        "don't forget": "the dictionary can be empty too"
    }
}
```

Our CORD-19 dataset is not initially in the correct format, so we must first reformat it.

```python
from tqdm.auto import tqdm
import json
import os

# create directory if needed
if not os.path.exists('./cord_data'):
    os.mkdir('./cord_data')

id_count = 0

# `paths` is assumed to hold the CORD-19 JSON file paths collected earlier
with open('./cord_data/corpus.jsonl', 'w', encoding='utf-8') as jsonl:
    for path in tqdm(paths):
        # read each json file in the CORD-19 pdf_json directory
        with open(path, 'r', encoding='utf-8') as fp:
            doc = json.load(fp)
        # extract the passages of text from each document
        for passage in doc['body_text']:
            record = {
                '_id': str(id_count),
                'title': "",
                'text': passage['text'].replace('\n', ' '),
                'metadata': doc['metadata']
            }
            id_count += 1
            # write each passage as one line of the corpus.jsonl file
            jsonl.write(json.dumps(record) + '\n')

Now we will have a new _corpus.jsonl_ file in the _cord_data_ directory. The first sample from the newly formatted CORD-19 dataset looks similar to this:

```json
{
    "_id": "0",
    "title": "",
    "text": "Digital technologies have provided support in diverse policy...",
    "metadata": {
        "title": "Students' Acceptance of Technology...",
        "authors": [
            {"first": "Pilar", "last": "Lacasa"},
            {"first": "Juan", "last": "Nieto"},
            ...
        ]
    }
}
```
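
Before starting a long training run, it can be worth checking that every line of the corpus file parses and carries the four keys shown in the sample format. The sketch below is one way to do that; `check_corpus` is a hypothetical helper and is not part of the GPL library. The demo writes a throwaway file rather than touching the real corpus.

```python
import json
import os
import tempfile

REQUIRED_KEYS = {"_id", "title", "text", "metadata"}


def check_corpus(path):
    """Return the number of records, raising if a record is malformed."""
    count = 0
    with open(path, "r", encoding="utf-8") as fp:
        for line_no, line in enumerate(fp, start=1):
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)  # each line is one JSON object
            missing = REQUIRED_KEYS - record.keys()
            if missing:
                raise ValueError(f"line {line_no} is missing {sorted(missing)}")
            count += 1
    return count


# demo on a temporary file
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "corpus.jsonl")
    sample = {"_id": "0", "title": "", "text": "hello", "metadata": {}}
    with open(demo_path, "w", encoding="utf-8") as fp:
        fp.write(json.dumps(sample) + "\n")
    n_records = check_corpus(demo_path)
print(n_records)
```

For the reformatted CORD-19 data, the same check would be run as `check_corpus('./cord_data/corpus.jsonl')`.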

With that newly formatted dataset, we can run the whole GPL data generation and fine-tuning process with a single, slightly lengthy function call.

```python
import gpl

gpl.train(
    path_to_generated_data='./cord_data',
    base_ckpt='distilbert-base-uncased',
    batch_size_gpl=16,
    gpl_steps=140_000,
    output_dir='./output/cord_model',
    generator='BeIR/query-gen-msmarco-t5-base-v1',
    retrievers=[
        'msmarco-distilbert-base-v3',
        'msmarco-MiniLM-L-6-v3'
    ],
    cross_encoder='cross-encoder/ms-marco-MiniLM-L-6-v2',
    qgen_prefix='qgen',
    do_evaluation=False
)
```

Let’s break all of this down. We have:

- `path_to_generated_data` - the directory containing _corpus.jsonl_.
- `base_ckpt` - the starting point of the bi-encoder model that we will be fine-tuning.
- `batch_size_gpl` - batch size for the margin MSE loss fine-tuning step.
- `gpl_steps` - number of training steps to run for the margin MSE loss fine-tuning.
- `output_dir` - where to save the fine-tuned bi-encoder model.
- `generator` - the query generation model.
- `retrievers` - a list of retriever models to use in the negative mining step.
- `cross_encoder` - the cross-encoder model used for pseudo-labeling.
- `qgen_prefix` - the query generation data files prefix.
- `do_evaluation` - whether to evaluate the model on an evaluation dataset; this requires an evaluation set to be provided.

After running this, which can take some time, we have a bi-encoder fine-tuned using GPL on nothing more than the passages of text passed from the _./cord_data/corpus.jsonl_ file.

Using the GPL library is a great way to apply unsupervised learning. When compared to our more in-depth process, it is _much_ simpler. The one downside is that the negative mining step uses [exhaustive search](https://www.pinecone.io/learn/series/faiss/faiss-tutorial/). This type of search is no problem for smaller corpora but becomes slow for larger datasets (100K–1M+) and, depending on your hardware, impossible for anything too large to be stored in memory.
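
To see why exhaustive search scales poorly, note that every query is compared against every passage, so the cost grows linearly with the size of the corpus. The pure-Python sketch below illustrates the idea; it is not the GPL library's implementation, and the function name and toy vectors are hypothetical.

```python
def exhaustive_search(query_emb, corpus_embs, top_k=2):
    """Score the query against every passage and return the top_k ids."""
    scores = [
        (sum(q * p for q, p in zip(query_emb, emb)), idx)
        for idx, emb in enumerate(corpus_embs)
    ]
    scores.sort(reverse=True)  # highest dot product first
    return [idx for _, idx in scores[:top_k]]


# toy corpus of four 2-dimensional embeddings (hypothetical values)
corpus = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
top = exhaustive_search([1.0, 0.2], corpus, top_k=2)
print(top)
```

Every embedding must also be held in memory at once, which is why very large corpora can become impractical with this approach.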

That’s it for this chapter on **G**enerative **P**seudo-**L**abeling (GPL). Using this impressive approach, we can fine-tune new or existing models in domains that were previously inaccessible due to little or no labeled data.

The research on unsupervised training methods for bi-encoder models continues to progress. GPL is the latest in a series of techniques that extends the performance of these exciting models trained without labeled data.

What is possible with GPL is impressive. Perhaps even more exciting is the possibility of further improvements to GPL or completely new methods that take the performance of these unsupervised training methods to even greater heights.

## References

[Code Notebooks](https://github.com/pinecone-io/examples/tree/master/learn/analytics-and-ml/model-training/gpl)

[1] T. Berners-Lee, M. Fischetti, [Weaving the Web, The Original Design and Ultimate Destiny of the World Wide Web by Its Inventor](https://www.w3.org/People/Berners-Lee/Weaving/) (1999)

[2] K. Wang, et al., [GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval](https://arxiv.org/abs/2112.07577)

[3] Y. Qu, et al., [RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2010.08191) (2021), NAACL