Training Sentence Transformers the OG Way (with Softmax Loss)


Our article introducing sentence embeddings and transformers explained that these models can be used across a range of applications, such as semantic textual similarity (STS), semantic clustering, or information retrieval (IR) using concepts rather than words.

This article dives deeper into the training process of the first sentence transformer, sentence-BERT, more commonly known as SBERT. We will explore the Natural Language Inference (NLI) training approach, which uses softmax loss to fine-tune models for producing sentence embeddings.

Be aware that softmax loss is no longer the preferred approach to training sentence transformers and has been superseded by other methods such as MSE margin and multiple negatives ranking (MNR) loss. We are still covering this training method because it is an important milestone in the development of ever-improving sentence embeddings.

This article also covers two approaches to fine-tuning. The first shows how NLI training with softmax loss works. The second uses the excellent training utilities provided by the sentence-transformers library — it’s more abstracted, making building good sentence transformer models much easier.

NLI Training

There are several ways of training sentence transformers. One of the most popular (and the approach we will cover) is using Natural Language Inference (NLI) datasets.

NLI focuses on identifying whether one sentence in a pair can be inferred from the other. We will use two NLI datasets: the Stanford Natural Language Inference (SNLI) and Multi-Genre NLI (MNLI) corpora.

Merging these two corpora gives us 943K sentence pairs (550K from SNLI, 393K from MNLI). All pairs include a premise and a hypothesis, and each pair is assigned a label:

  • 0: entailment, i.e. the premise suggests the hypothesis.
  • 1: neutral, i.e. the premise and hypothesis could both be true, but they are not necessarily related.
  • 2: contradiction, i.e. the premise and hypothesis contradict each other.

When training the model, we will be feeding sentence A (the premise) into BERT, followed by sentence B (the hypothesis) on the next step.

From there, the model is optimized with softmax loss, using the label field as the target. We will explain this in more depth soon.

For now, let’s download and merge the two datasets. We will use the datasets library from Hugging Face, which can be downloaded using !pip install datasets. To download and merge, we write:
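A minimal sketch of the download-and-merge step follows. It assumes the Hub identifiers `snli` and `glue`/`mnli` and the `train` splits, none of which are named in the text, and the function name `load_nli_train` is purely illustrative. The GLUE copy of MNLI carries an extra `idx` column, which we drop so that both datasets share the same columns.

```python
# Approximate pair counts quoted in the text above.
QUOTED_PAIRS = {"snli": 550_000, "mnli": 393_000}


def load_nli_train():
    """Download the SNLI and MNLI training splits and merge them into one dataset."""
    # Imported lazily so that defining the function needs neither the library nor a network.
    from datasets import concatenate_datasets, load_dataset

    snli = load_dataset("snli", split="train")          # assumed Hub identifier
    mnli = load_dataset("glue", "mnli", split="train")  # assumed Hub identifier
    # The GLUE copy of MNLI has an extra `idx` column that SNLI lacks.
    mnli = mnli.remove_columns(["idx"])
    # Give both datasets identical feature types so they can be concatenated.
    snli = snli.cast(mnli.features)
    return concatenate_datasets([snli, mnli])
```

Calling `dataset = load_nli_train()` triggers the download, and the merged dataset should hold roughly the 943K pairs mentioned above.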

Both datasets contain -1 values in the label feature where no confident class could be assigned. We remove them using the filter method.
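The filtering step could look like the sketch below. The helper names `has_label` and `drop_unlabelled` are our own, and we assume `dataset` is a Hugging Face `Dataset` whose `filter` method keeps the rows for which the function returns `True`.

```python
def has_label(row):
    """Return True when the pair carries a real class (0, 1 or 2) rather than -1."""
    return row["label"] != -1


def drop_unlabelled(dataset):
    """Remove every pair for which no confident class could be assigned."""
    # `Dataset.filter` calls `has_label` once per row and keeps the rows that pass.
    return dataset.filter(has_label)
```

The same predicate also works with plain Python, for example `[row for row in rows if has_label(row)]`, which is handy for checking the logic on a few hand-written rows.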

We must convert our human-readable sentences into transformer-readable tokens, so we go ahead and tokenize our sentences. Both premise and hypothesis features must be split into their own input_ids and attention_mask tensors.
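One way to tokenize both columns while keeping their outputs separate is sketched below. The column names (`premise_input_ids` and so on), the helper names, and the default `max_length=128` are assumptions made for illustration; `bert-base-uncased` is the checkpoint used throughout this article.

```python
def tokenize_batch(batch, tokenizer, max_length=128):
    """Tokenize both sentence columns of a batch, keeping their outputs separate."""
    out = {}
    for part in ("premise", "hypothesis"):
        enc = tokenizer(
            batch[part],
            max_length=max_length,
            padding="max_length",  # pad every sentence to the same length
            truncation=True,       # cut sentences longer than max_length
        )
        out[f"{part}_input_ids"] = enc["input_ids"]
        out[f"{part}_attention_mask"] = enc["attention_mask"]
    return out


def tokenize_dataset(dataset, max_length=128):
    """Add the four token columns to a Hugging Face dataset."""
    # Imported lazily so tokenize_batch can be tried out without transformers installed.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    return dataset.map(
        lambda batch: tokenize_batch(batch, tokenizer, max_length), batched=True
    )
```

Prefixing the column names with the sentence role means that one row holds everything needed for a training step: two sets of token IDs, two masks, and the label.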

Now, all we need to do is prepare the data to be read into the model. To do this, we first convert the dataset features into PyTorch tensors and then initialize a data loader which will feed data into our model during training.
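A sketch of this step is shown below. The column names and the batch size of 16 are assumptions (the text does not state them), and `make_loader` is an illustrative helper rather than part of any library.

```python
import torch

# Assumed column names for the label and the tokenized premise/hypothesis features.
TENSOR_COLUMNS = [
    "label",
    "premise_input_ids",
    "premise_attention_mask",
    "hypothesis_input_ids",
    "hypothesis_attention_mask",
]


def make_loader(dataset, batch_size=16):
    """Expose the tokenized columns as tensors and wrap them in a shuffling DataLoader."""
    # Hugging Face datasets provide `set_format`; other map-style datasets are used as-is.
    if hasattr(dataset, "set_format"):
        dataset.set_format(type="torch", columns=TENSOR_COLUMNS)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
```

Each batch drawn from the loader is a dictionary of tensors keyed by column name, so `batch["premise_input_ids"]` has shape `(batch_size, sequence_length)`.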

And we’re done with data preparation. Let’s move on to the training approach.

Softmax Loss

Optimizing with softmax loss was the primary method used by Reimers and Gurevych in the original SBERT paper [1].

Although this was used to train the first sentence transformer model, it is no longer the go-to training approach. Instead, the MNR loss approach is most common today. We will cover this method in another article.

However, we hope that explaining softmax loss will help demystify the different approaches applied to training sentence transformers. We include a comparison to MNR loss at the end of the article.

Model Preparation

When we train an SBERT model, we don’t need to start from scratch. We begin with an already pretrained BERT model (and tokenizer).

We will be using what is called a ‘siamese’-BERT architecture during training. All this means is that given a sentence pair, we feed sentence A into BERT first, then feed sentence B once BERT has finished processing the first.

This has the effect of creating a siamese-like network where we can imagine two identical BERTs are being trained in parallel on sentence pairs. In reality, there is just a single model processing two sentences one after the other.

Figure: Siamese-BERT processing a sentence pair and then pooling the large token embeddings tensor into a single dense vector.

BERT outputs one 768-dimensional embedding per token position, which means 512 of them when inputs are padded to BERT’s maximum sequence length. We will convert these into an average embedding using mean pooling. This pooled output is our sentence embedding. We will have two per step: one for sentence A that we call u, and one for sentence B, called v.

To perform this mean pooling operation, we will define a function called mean_pool.

Here we take BERT’s token embeddings output (we’ll see this all in full soon) and the sentence’s attention_mask tensor. We then resize the attention_mask to align to the higher 768-dimensionality of the token embeddings.

We apply this resized mask, in_mask, to the token embeddings to exclude padding tokens from the mean pooling operation. Mean pooling then averages the remaining token embeddings along the sequence axis, producing a single value for each of the 768 dimensions. This brings our tensor size for each sentence from (512, 768) to (1, 768).
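The description above can be sketched in PyTorch as follows. The names `mean_pool` and `in_mask` come from the text; the clamp on the token count is our own addition to guard against division by zero.

```python
import torch


def mean_pool(token_embeds, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padding tokens."""
    # Resize the mask from (batch, seq_len) to (batch, seq_len, hidden).
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    # Zero out padding positions, then sum over the sequence axis.
    summed = torch.sum(token_embeds * in_mask, dim=1)
    # Count the real tokens per sentence; the clamp avoids dividing by zero.
    counts = torch.clamp(in_mask.sum(dim=1), min=1e-9)
    return summed / counts
```

For a batch of token embeddings of shape `(batch, 512, 768)`, the result has shape `(batch, 768)`: one sentence embedding per input sentence.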

The next step is to concatenate these embeddings. Several different approaches to this were presented in the paper:

Concatenation        | NLI Performance
(u, v)               | 66.04
(|u-v|, u*v)         | 78.37
(u, v, u*v)          | 77.44
(u, v, |u-v|)        | 80.78
(u, v, |u-v|, u*v)   | 80.44

Concatenation methods for sentence embeddings u and v and their performance on STS benchmarks.

Of these, the best performing is built by concatenating vectors u, v, and |u-v|, which produces a vector three times the length of each original vector. We label this concatenated vector (u, v, |u-v|), where |u-v| is the element-wise absolute difference between vectors u and v.

Figure: We concatenate (u, v, |u-v|) to merge the sentence embeddings from sentence A and B.

We will perform this concatenation operation using PyTorch. Once we have our mean-pooled sentence vectors u and v we concatenate with:
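A small sketch of the concatenation is given below. The helper name `concat_uv` is illustrative, and we assume u and v are tensors of shape `(batch, 768)`.

```python
import torch


def concat_uv(u, v):
    """Build the (u, v, |u-v|) vector from two batches of sentence embeddings."""
    # Element-wise absolute difference between the two sentence embeddings.
    uv_abs = torch.abs(u - v)
    # Join along the last axis: (batch, 768) x 3 -> (batch, 2304).
    return torch.cat([u, v, uv_abs], dim=-1)
```

With 768-dimensional embeddings, the output has 2,304 values per sentence pair, which is the input size that the classification layer must accept.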

Vector (u, v, |u-v|) is fed into a feed-forward neural network (FFNN). The FFNN processes the vector and outputs three activation values, one for each of our label classes: entailment, neutral, and contradiction.

As these activations and label classes are aligned, we now calculate the softmax loss between them.

Figure: The final steps of training. The concatenated (u, v, |u-v|) vector is fed through a feed-forward NN to produce three output activations. Then we calculate the softmax loss between these predictions and the true labels.

Softmax loss is calculated by applying a softmax function across the three activation values (or nodes), producing a predicted probability for each label. We then use cross-entropy loss to measure the difference between this predicted distribution and the true label.
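To make the calculation concrete, here is a sketch that uses a single linear layer as the FFNN (an assumption on our part) and random stand-in inputs. PyTorch's `CrossEntropyLoss` applies the softmax internally, so it is given the raw activations; the manual version underneath spells out the same calculation.

```python
import torch

torch.manual_seed(0)

# FFNN: maps the 3 * 768 concatenated vector to one activation per label class.
ffnn = torch.nn.Linear(768 * 3, 3)
# CrossEntropyLoss combines the softmax and the cross-entropy step.
loss_func = torch.nn.CrossEntropyLoss()

x = torch.randn(4, 768 * 3)          # stand-in for four (u, v, |u-v|) vectors
labels = torch.tensor([0, 1, 2, 1])  # entailment, neutral, contradiction, neutral

logits = ffnn(x)
loss = loss_func(logits, labels)

# The same loss written out by hand: softmax, then negative log-probability
# of the true class, averaged over the batch.
probs = torch.softmax(logits, dim=-1)
manual_loss = -torch.log(probs[torch.arange(len(labels)), labels]).mean()
```

Here `loss` and `manual_loss` agree up to floating-point error, which shows that softmax loss is simply cross-entropy over softmax-normalized activations.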

The model is then optimized using this loss. We use an Adam optimizer with a learning rate of 2e-5 and a linear warmup period covering the first 10% of the total training steps. To set that up, we use the standard PyTorch Adam optimizer alongside a learning rate scheduler provided by HF transformers:
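A sketch of that setup is shown below. The helper names `warmup_steps` and `build_optimizer` are illustrative; `get_linear_schedule_with_warmup` is the scheduler from HF transformers, and we assume that `total_steps` equals the number of batches multiplied by the number of epochs.

```python
import torch


def warmup_steps(total_steps, fraction=0.1):
    """Number of scheduler steps spent linearly warming up the learning rate."""
    return int(total_steps * fraction)


def build_optimizer(params, total_steps, lr=2e-5):
    """Create Adam plus a schedule that warms up linearly, then decays linearly."""
    # Imported lazily so warmup_steps can be used without transformers installed.
    from transformers import get_linear_schedule_with_warmup

    optim = torch.optim.Adam(params, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optim,
        num_warmup_steps=warmup_steps(total_steps),
        num_training_steps=total_steps,
    )
    return optim, scheduler
```

Note that `params` should include the parameters of both BERT and the FFNN, for example `list(model.parameters()) + list(ffnn.parameters())`, so that the classification layer is trained as well.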

Now let’s put all of that together in a PyTorch training loop.
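The loop below is a sketch of how the pieces fit together for one epoch, not a definitive implementation. It assumes batches are dictionaries with the column names used here (our own naming), that `model(...)[0]` returns the token embeddings as a Hugging Face `BertModel` does, and that `optim` covers the parameters of both `model` and `ffnn`.

```python
import torch


def mean_pool(token_embeds, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padding tokens."""
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    return (token_embeds * in_mask).sum(1) / torch.clamp(in_mask.sum(1), min=1e-9)


def train_epoch(model, ffnn, loader, optim, scheduler=None, device="cpu"):
    """Run one epoch of softmax-loss training and return the mean batch loss."""
    loss_func = torch.nn.CrossEntropyLoss()
    model.to(device).train()
    ffnn.to(device).train()
    total = 0.0
    for batch in loader:
        optim.zero_grad()
        labels = batch["label"].to(device)
        ids_a = batch["premise_input_ids"].to(device)
        mask_a = batch["premise_attention_mask"].to(device)
        ids_b = batch["hypothesis_input_ids"].to(device)
        mask_b = batch["hypothesis_attention_mask"].to(device)
        # The same weights process sentence A, then sentence B (the siamese setup).
        u = mean_pool(model(ids_a, attention_mask=mask_a)[0], mask_a)
        v = mean_pool(model(ids_b, attention_mask=mask_b)[0], mask_b)
        # Concatenate (u, v, |u-v|) and classify into the three NLI labels.
        x = torch.cat([u, v, torch.abs(u - v)], dim=-1)
        loss = loss_func(ffnn(x), labels)
        loss.backward()
        optim.step()
        if scheduler is not None:
            scheduler.step()
        total += loss.item()
    return total / max(len(loader), 1)
```

With a real setup, `model` would be a pretrained `bert-base-uncased` loaded through transformers and `ffnn` a `torch.nn.Linear(768 * 3, 3)` layer.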

We only train for a single epoch here. Realistically this should be enough (and mirrors what was described in the original SBERT paper). The last thing we need to do is save the model.

Now let’s compare everything we’ve done so far with sentence-transformers training utilities. We will compare this and other sentence transformer models at the end of the article.

Fine-Tuning With Sentence Transformers

As we already mentioned, the sentence-transformers library has excellent support for those of us just wanting to train a model without worrying about the underlying training mechanisms.

We don’t need to do much beyond a little data preprocessing (but less than what we did above). So let’s go ahead and put together the same fine-tuning process, but using sentence-transformers.

Training Data

Again we’re using the same SNLI and MNLI corpora, but this time we will be transforming them into the format required by sentence-transformers using their InputExample class. Before that, we need to download and merge the two datasets just like before.

Now we’re ready to format our data for sentence-transformers. All we do is convert the current premise, hypothesis, and label format into an almost matching format with the InputExample class.
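A sketch of the conversion follows. The helper names `pair_fields` and `build_train_loader` and the batch size of 16 are assumptions; `InputExample` takes the sentence pair as `texts` and the class as `label`.

```python
def pair_fields(row):
    """Map one premise/hypothesis/label row onto InputExample keyword arguments."""
    return {"texts": [row["premise"], row["hypothesis"]], "label": row["label"]}


def build_train_loader(dataset, batch_size=16):
    """Convert every row to an InputExample and wrap the list in a DataLoader."""
    # Imported lazily so pair_fields can be tried out without the libraries installed.
    from sentence_transformers import InputExample
    from torch.utils.data import DataLoader

    train_samples = [InputExample(**pair_fields(row)) for row in dataset]
    return DataLoader(train_samples, batch_size=batch_size, shuffle=True)
```

No tokenization is needed at this stage because sentence-transformers tokenizes the raw text itself when batches are assembled during training.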

We’ve also initialized a DataLoader just as we did before. From here, we want to begin setting up the model. In sentence-transformers we build models using different modules.

All we need is the transformer model module, followed by a mean pooling module. The transformer models are loaded from HF, so we define bert-base-uncased as before.

We have our data, the model, and now we define how to optimize our model. Softmax loss is very easy to initialize.

Now we’re ready to train the model. We train for a single epoch and warm up for 10% of training as before.
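The model, loss, and training steps described above can be sketched together as below. The function names are illustrative, and the `fit` call reflects the classic sentence-transformers training API, so treat the exact arguments as assumptions to check against the version you have installed. By default `SoftmaxLoss` concatenates (u, v, |u-v|), matching the approach covered earlier.

```python
def warmup_steps(num_batches, epochs=1, fraction=0.1):
    """Warm up for the first 10% of all training steps by default."""
    return int(num_batches * epochs * fraction)


def train_sbert(train_loader, output_path="./sbert_test_b", epochs=1):
    """Build a BERT + mean pooling model and fine-tune it with softmax loss."""
    # Imported lazily so warmup_steps can be used without the library installed.
    from sentence_transformers import SentenceTransformer, losses, models

    # Transformer module followed by a mean pooling module.
    bert = models.Transformer("bert-base-uncased")
    pooler = models.Pooling(
        bert.get_word_embedding_dimension(), pooling_mode_mean_tokens=True
    )
    model = SentenceTransformer(modules=[bert, pooler])

    # Softmax loss over the three NLI labels.
    loss = losses.SoftmaxLoss(
        model=model,
        sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
        num_labels=3,
    )

    model.fit(
        train_objectives=[(train_loader, loss)],
        epochs=epochs,
        warmup_steps=warmup_steps(len(train_loader), epochs),
        output_path=output_path,
        show_progress_bar=True,
    )
    return model
```

Here `train_loader` is the DataLoader of `InputExample` objects built in the previous step.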

With that, we’re done. The new model is saved to ./sbert_test_b, and we can load it from that location using either SentenceTransformer or HF’s from_pretrained method. Let’s move on to comparing this to other SBERT models.

Compare SBERT Models

We’re going to test the models on a set of random sentences. We will build our mean-pooled embeddings for each sentence using four models: softmax-loss SBERT, multiple-negatives-ranking-loss SBERT, the original SBERT (sentence-transformers/bert-base-nli-mean-tokens), and vanilla BERT (bert-base-uncased).

After producing sentence embeddings, we will calculate the cosine similarity between all possible sentence pairs, producing a simple but insightful semantic textual similarity (STS) test.

We define two new functions: sts_process, which builds the sentence embeddings and compares them with cosine similarity, and sim_matrix, which constructs a similarity matrix from all possible pairs.
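A possible shape for these two functions is sketched below. We assume `model` exposes an `encode` method that returns one embedding per sentence, as `SentenceTransformer` does, and that the sentences are passed in explicitly; both signatures are our own guesses.

```python
import torch


def sts_process(sentence_a, sentence_b, model):
    """Embed two sentences with `model.encode` and return their cosine similarity."""
    vecs = torch.as_tensor(model.encode([sentence_a, sentence_b]), dtype=torch.float32)
    return torch.nn.functional.cosine_similarity(vecs[0], vecs[1], dim=0).item()


def sim_matrix(model, sentences):
    """Score every sentence pair and collect the results in a symmetric matrix."""
    n = len(sentences)
    sim = torch.zeros(n, n)
    for i in range(n):
        for j in range(i, n):
            score = sts_process(sentences[i], sentences[j], model)
            # Cosine similarity is symmetric, so fill both halves at once.
            sim[i, j] = sim[j, i] = score
    return sim
```

This version re-encodes each sentence for every pair, which keeps the code simple. For larger sentence sets it would be cheaper to encode everything once and compare the stored vectors.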

Then we just run each model through the sim_matrix function.

After processing all pairs, we visualize the results as heatmaps.

Figure: Similarity score heatmaps for four BERT/SBERT models.

In these heatmaps, we ideally want all dissimilar pairs to have very low scores (near white) and similar pairs to produce distinctly higher scores.

Let’s talk through these results. The bottom-left and top-right models produce the correct top three pairs, whereas BERT and softmax loss SBERT return only two of the three correct pairs.

If we focus on the standard BERT model, we see minimal variation in square color. This is because almost every pair produces a similarity score between 0.6 and 0.7. This lack of variation makes it challenging to distinguish between more-or-less similar pairs. This is to be expected, though, as BERT has not been fine-tuned for semantic similarity.

Our PyTorch softmax loss SBERT (top-left) misses the 9-1 sentence pair. Nonetheless, the pairs it produces are much more distinct from dissimilar pairs than the vanilla BERT model, so it’s an improvement. The sentence-transformers version is better still and did not miss the 9-1 pair.

Next up, we have the SBERT model trained by Reimers and Gurevych in the 2019 paper (bottom-left) [1]. It produces better performance than our SBERT models but still has little variation between similar and dissimilar pairs.

And finally, we have an SBERT model trained using MNR loss. This model is easily the highest performing. Most dissimilar pairs produce a score very close to zero. The highest non-pair returns 0.28 — roughly half of the true-pair scores.

From these results, the SBERT MNR model seems to be the highest performing. It produces much higher activations (relative to the average) for true pairs than any other model, making similarity much easier to identify. SBERT with softmax loss is clearly an improvement over BERT, but it is unlikely to offer any benefit over SBERT trained with MNR loss.

That’s it for this article on fine-tuning BERT for building sentence embeddings! We delved into the details of preprocessing SNLI and MNLI datasets for NLI training and how to fine-tune BERT using the softmax loss approach.

Finally, we compared this softmax-loss SBERT against vanilla BERT, the original SBERT, and an MNR loss SBERT using a simple STS task. We found that although fine-tuning with softmax loss does produce valuable sentence embeddings, it still lacks quality compared to more recent training approaches.

We hope this has been an insightful and exciting exploration of how transformers can be fine-tuned for building sentence embeddings.


[1] N. Reimers, I. Gurevych, Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (2019), ACL
