Extreme Classification with Similarity Search

This demo shows how to label new texts automatically when the number of possible labels is enormous. This scenario is known as extreme classification, a supervised learning variant of multi-class and multi-label classification in which the label space is extremely large.

Examples of extreme classification include labeling a new article with Wikipedia’s topical categories, matching web content with a set of relevant advertisements, tagging product descriptions with catalog labels, and classifying a resume into a collection of pertinent job titles.

[Figure: Article labeling with extreme classification.]

Here’s how we’ll perform extreme classification:

  1. We’ll transform 250,000 labels into vector embeddings using a publicly available embedding model and upload them into a managed vector index.
  2. Then we’ll take an article that requires labeling and transform it into a vector embedding using the same model.
  3. We’ll use that article’s vector embedding as the query to search the vector index. In effect, this will retrieve the most similar labels to the article’s semantic content.
  4. With the most relevant labels retrieved, we can automatically apply them to the article.
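
In miniature, the core idea looks like the sketch below. This is only an illustration, not part of the notebook: it uses random vectors as stand-ins for the real embeddings and a brute-force cosine search in place of the managed vector index.

import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for real embeddings: in the notebook, these come from a
# sentence-transformers model, and the label vectors live in a Pinecone
# index rather than a NumPy array.
num_labels, dim = 1000, 300
label_vectors = rng.normal(size=(num_labels, dim))   # one vector per label
article_vector = rng.normal(size=dim)                # the article to classify

# Cosine similarity between the article and every label vector
label_unit = label_vectors / np.linalg.norm(label_vectors, axis=1, keepdims=True)
article_unit = article_vector / np.linalg.norm(article_vector)
scores = label_unit @ article_unit

# The most similar labels become the predicted labels for the article
top_k = 10
predicted_label_ids = np.argsort(-scores)[:top_k]
print(predicted_label_ids, scores[predicted_label_ids])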

Let’s get started!


Dependencies

!pip install -qU pinecone-client ipywidgets "setuptools>=36.2.1" wikitextparser
!pip install -qU sentence-transformers --no-cache-dir
import os
import re
import gzip
import json
import pandas as pd
import numpy as np
from wikitextparser import remove_markup, parse
from sentence_transformers import SentenceTransformer

Setting up Pinecone’s Similarity Search Service

Here we set up our similarity search service. We assume you are familiar with Pinecone’s quick start tutorial.

import pinecone
# Load Pinecone API key
api_key = os.getenv("PINECONE_API_KEY") or "YOUR_API_KEY"
pinecone.init(api_key=api_key)

Get a Pinecone API key if you don’t have one.

# Pick a name for the new index
index_name = 'extreme-ml'
# Check whether the index with the same name already exists
if index_name in pinecone.list_indexes():
    pinecone.delete_index(index_name)
# Create a new vector index
pinecone.create_index(name=index_name, metric='cosine', shards=1)
# Connect to the created index
index = pinecone.Index(name=index_name, response_timeout=300)

# Print info
index.info()
InfoResult(index_size=0)

Data Preparation

In this demo, we classify Wikipedia articles using a standard dataset from an extreme classification benchmarking resource. The dataset used in this example is Wikipedia-500K, which contains around 500,000 labels. Here, we download the raw data and prepare it for the classification task.

# Download train dataset
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=10RBSf6nC9C38wUMwqWur2Yd8mCwtup5K" -O 'trn.raw.json.gz' && rm -rf /tmp/cookies.txt 

# Download test dataset
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1pEyKXtkwHhinuRxmARhtwEQ39VIughDf" -O 'tst.raw.json.gz' && rm -rf /tmp/cookies.txt

# Download categories labels file
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ZYTZPlnkPBCMcNqRRO-gNx8EPgtV-GL3" -O 'Yf.txt' && rm -rf /tmp/cookies.txt

# Create and move downloaded files to data folder
!mkdir data
!mv 'trn.raw.json.gz' 'tst.raw.json.gz' 'Yf.txt' data
# Define paths
ROOT_PATH = os.getcwd()
TRAIN_DATA_PATH = (os.path.join(ROOT_PATH, './data/trn.raw.json.gz'))
TEST_DATA_PATH = (os.path.join(ROOT_PATH, './data/tst.raw.json.gz'))
# Load categories
with open('./data/Yf.txt',  encoding='utf-8') as f:
    categories = f.readlines()

# Clean values
categories = [cat.split('->')[1].strip('\n') for cat in categories]

# Show first few categories
categories[:3]
['!!!_albums', '+/-_(band)_albums', '+44_(band)_songs']

Using a Subset of the Data

For this example, we will select and use a subset of the Wikipedia articles. This saves processing time and consumes much less memory than working with the complete dataset.

We will select a sample of 200,000 articles that contains around 250,000 different labels.

Feel free to run the notebook with more data.

# Take every 5th article from the first 1,000,000 lines (~200,000 articles)
WIKI_ARTICLES_INDEX = range(0, 1000000, 5)

lines = []

with gzip.open(TRAIN_DATA_PATH) as f:
    for e, line in enumerate(f):
        if e >= 1000000:
            break
        if e in WIKI_ARTICLES_INDEX:
            lines.append(json.loads(line))
        
df = pd.DataFrame.from_dict(lines)
df = df[['title', 'content', 'target_ind']]
df.head()
                           title                                            content                                         target_ind
0                      Anarchism  {{redirect2|anarchist|anarchists|the fictional...  [81199, 83757, 83805, 193030, 368811, 368937, ...
1                 Academy_Awards  {{redirect2|oscars|the oscar|the film|the osca...  [19080, 65864, 78208, 96051]
2                   Anthropology  {{about|the social science}} {{use dmy dates|d...  [83605, 423943]
3  American_Football_Conference   {{refimprove|date=september 2014}} {{use dmy d...  [76725, 314198, 334093]
4           Analysis_of_variance  {{use dmy dates|date=june 2013}} '''analysis o...  [81170, 168516, 338198, 441529]
print(df.shape)
(200000, 3)

Remove Wikipedia Markup Format

We are going to use only the first part of each article to make the articles comparable in length. Also, Wikipedia articles are written in wiki markup, which is not very readable, so we will strip the markup to make the content as clean as possible.

# Reduce content to first 3000 characters
df['content_short'] = df.content.apply(lambda x: x[:3000])

# Remove wiki articles markup
df['content_cleaned'] = df.content_short.apply(lambda x: remove_markup(x))

# Keep only certain columns
df = df[['title', 'content_cleaned', 'target_ind']]

# Show data
df.head()
                           title                                    content_cleaned                                         target_ind
0                      Anarchism     anarchism is a political philosophy that a...  [81199, 83757, 83805, 193030, 368811, 368937, ...
1                 Academy_Awards     the academy awards or the oscars (the offi...  [19080, 65864, 78208, 96051]
2                   Anthropology     anthropology is the scientific study of hu...  [83605, 423943]
3  American_Football_Conference     the american football conference (afc) is o...  [76725, 314198, 334093]
4           Analysis_of_variance    analysis of variance (anova) is a collection ...  [81170, 168516, 338198, 441529]
# Keep all labels in a single list
all_categories = []
for i, row in df.iterrows():
    all_categories.extend(row.target_ind)
print('Number of labels: ',len(list(set(all_categories))))
Number of labels:  256899

Create Article Vector Embeddings

Recall that we want to index and search all (roughly 250,000) possible labels. We represent each label by averaging the vector embeddings of the articles tagged with that label.
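
In other words, writing $A_\ell$ for the set of articles tagged with label $\ell$ and $v_a$ for the embedding of article $a$ (this notation is ours, just to make the aggregation explicit), each label embedding is

$$v_\ell = \frac{1}{|A_\ell|} \sum_{a \in A_\ell} v_a$$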

Let’s first create the article vector embeddings. Here we use one of the Average Word Embeddings models from sentence-transformers (average_word_embeddings_komninos). In the next section, we will aggregate these vectors into the final label embeddings.

# Load the model
model = SentenceTransformer('average_word_embeddings_komninos')

# Create embeddings
encoded_articles = model.encode(df['content_cleaned'], show_progress_bar=True)
df['content_vector'] = pd.Series(encoded_articles.tolist())

Create and Upload Label Embeddings

It turns out that indexing the article embeddings themselves does not provide good enough accuracy, so we index and search the labels directly.

The label embedding is simply the average of all its corresponding article embeddings.

# Explode so that each row holds a single (article, label) pair
df_explode = df.explode('target_ind')

# Group by label and average the article embeddings to get one vector per label
result = df_explode.groupby('target_ind').agg(
    mean=('content_vector', lambda x: np.vstack(x).mean(axis=0).tolist())
)
result['target_ind'] = result.index
result.columns = ['content_vector', 'ind']

result.head()
                                               content_vector  ind
target_ind
2           [0.0704750344157219, -0.007719345390796661, 0....    2
3           [0.05894148722290993, -0.03119848482310772, 0....    3
5           [0.18302207440137863, 0.061663837544620036, 0....    5
6           [0.1543595753610134, 0.03904660418629646, 0.03...    6
9           [0.22310754656791687, 0.1524289846420288, 0.09...    9
# Create a list of (id, vector) items to upsert.
# The label name serves as the item ID, truncated to 64 characters.
items_to_upsert = [(categories[int(row.ind)][:64], row.content_vector) for i, row in result.iterrows()]
# Upsert data
acks = index.upsert(items=items_to_upsert)
acks[:3]

Let’s validate the number of indexed labels.

index.info()
InfoResult(index_size=256552)

Query

Now, let’s test the vector index and examine the classifier results. Observe that here we retrieve a fixed number of labels (top_k=10). In an actual application, you might want to determine the size of the retrieved label set dynamically.
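
One simple way to make the cutoff dynamic is to keep every retrieved label whose score falls within some margin of the best match. The helper below is only an illustrative sketch (the function name and the margin value are ours, not part of the demo); it relies on the per-query `ids` and `scores` fields that we also use further down.

def select_labels(result, margin=0.003):
    """Keep retrieved labels whose score is within `margin` of the best score.

    `result` is a single per-query result with parallel `ids` and `scores`
    lists. The margin value is arbitrary and would need tuning in practice.
    """
    if len(result.ids) == 0:
        return []
    best = max(result.scores)
    return [label for label, score in zip(result.ids, result.scores)
            if score >= best - margin]

For instance, with the Discrimination results shown further below, a margin of 0.003 would keep only the top three labels.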

NUM_OF_WIKI_ARTICLES = 3
WIKI_ARTICLES_INDEX = range(1111, 100000, 57)[:NUM_OF_WIKI_ARTICLES]

lines = []

with gzip.open(TEST_DATA_PATH) as f:
    for e, line in enumerate(f):
        if e in  WIKI_ARTICLES_INDEX:
            lines.append(json.loads(line)) 
        if e > max(WIKI_ARTICLES_INDEX):
            break
            
df_test = pd.DataFrame.from_dict(lines)
df_test = df_test[['title', 'content', 'target_ind']]
df_test.head()
            title                                            content                                         target_ind
0  Discrimination  {{otheruses}} {{discrimination sidebar}} '''di...  [170479, 423902]
1          Erfurt  {{refimprove|date=june 2014}} {{use dmy dates|...  [142638, 187156, 219262, 294479, 329185, 38243...
2             ETA  {{about|the basque organization|other uses|eta...  [83681, 100838, 100849, 100868, 176034, 188979...
# Reduce content to first 3000 characters
df_test['content_short'] = df_test.content.apply(lambda x: x[:3000])

# Remove wiki articles markup
df_test['content_cleaned'] = df_test.content_short.apply(lambda x: remove_markup(x))

# Keep only certain columns
df_test = df_test[['title', 'content_cleaned', 'target_ind']]

# Show data
df_test.head()
            title                                    content_cleaned                                         target_ind
0  Discrimination   discrimination is action that denies social ...  [170479, 423902]
1          Erfurt      erfurt () is the capital city of thuringia ...  [142638, 187156, 219262, 294479, 329185, 38243...
2             ETA     eta (, ), an acronym for euskadi ta askatas...  [83681, 100838, 100849, 100868, 176034, 188979...
# Create embeddings for test articles
test_vectors = model.encode(df_test['content_cleaned'], show_progress_bar=True)
# Query the vector index
query_results = index.query(queries=test_vectors, top_k=10)
# Show results
for term, labs, res in zip(df_test.title.tolist(), df_test.target_ind.tolist(), query_results):
    print()
    print('Term queried: ',term)
    print('Original labels: ')
    for l in labs:
        if l in all_categories:
            print('\t', categories[l])
    print('Predicted: ')
    df_result = pd.DataFrame({
                'id': list(res.ids),
                'score': list(res.scores)})
    display(df_result)    
Term queried:  Discrimination
Original labels: 
   Discrimination
   Social_justice
Predicted: 
                         id     score
0            Discrimination  0.972957
1  Sociological_terminology  0.971605
2         Identity_politics  0.970097
3           Social_concepts  0.967534
4                    Sexism  0.967476
5        Affirmative_action  0.967288
6     Political_correctness  0.966926
7            Human_behavior  0.966475
8               Persecution  0.965421
9          Social_movements  0.964393
Term queried:  Erfurt
Original labels: 
   Erfurt
   German_state_capitals
   Members_of_the_Hanseatic_League
   Oil_Campaign_of_World_War_II
   Province_of_Saxony
   University_towns_in_Germany
Predicted: 
                                    id     score
0          University_towns_in_Germany  0.966058
1                   Province_of_Saxony  0.959731
2        Populated_places_on_the_Rhine  0.958738
3                 Imperial_free_cities  0.957159
4                Hildesheim_(district)  0.956928
5  History_of_the_Electoral_Palatinate  0.956800
6               Towns_in_Saxony-Anhalt  0.956501
7                Towns_in_Lower_Saxony  0.955259
8                        Halle_(Saale)  0.954934
9              Cities_in_Saxony-Anhalt  0.954934
Term queried:  ETA
Original labels: 
   Anti-Francoism
   Basque_conflict
   Basque_history
   Basque_politics
   ETA
   European_Union_designated_terrorist_organizations
   Far-left_politics
   Francoist_Spain
   Government_of_Canada_designated_terrorist_organizations
   Irregular_military
   Military_wings_of_political_parties
   National_liberation_movements
   Nationalist_terrorism
   Organizations_designated_as_terrorist_by_the_United_States_government
   Organizations_designated_as_terrorist_in_Europe
   Organizations_established_in_1959
   Politics_of_Spain
   Resistance_movements
   Secession_in_Spain
   Secessionist_organizations_in_Europe
   Terrorism_in_Spain
   United_Kingdom_Home_Office_designated_terrorist_groups
Predicted: 
                                                  id     score
0    Organizations_designated_as_terrorist_in_Europe  0.948875
1                                  Terrorism_in_Spain  0.948431
2                                     Basque_politics  0.942670
3                                   Politics_of_Spain  0.941830
4  European_Union_designated_terrorist_organizations  0.940194
5                                  Irregular_military  0.938163
6            Political_parties_disestablished_in_1977  0.936437
7                                  Algerian_Civil_War  0.936311
8                              Republicanism_in_Spain  0.935577
9                             Guerrilla_organizations  0.935507

Summary

We demonstrated a similarity-search approach to extreme classification of texts. We took a simple approach, representing each label as the average of its corresponding articles’ vector embeddings. At classification time, we match a new article’s embedding against its nearest label embeddings. The example results indicate the usefulness of this approach.

You can take this forward by exploring more advanced ideas. For example, you could utilize the hierarchical relationships between labels or improve the label representations, as in the sketch below. Just have fun, and feel free to share your thoughts.
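
As one illustrative direction for better label representations (purely a sketch with a hypothetical helper, not something evaluated in this demo), you could replace the plain mean with a weighted mean that down-weights articles carrying many labels, so that very broad articles contribute less to each label’s embedding:

import numpy as np

def weighted_label_embedding(article_vectors, labels_per_article):
    """Weighted mean of article embeddings for a single label.

    article_vectors: embeddings of the articles tagged with this label.
    labels_per_article: the number of labels each of those articles carries;
    articles tagged with many labels contribute less to the result.
    """
    vectors = np.vstack(article_vectors)
    weights = 1.0 / np.asarray(labels_per_article, dtype=float)
    weights /= weights.sum()
    return (weights[:, None] * vectors).sum(axis=0)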

Turn off the service

Turn off the service once you are sure you do not want to use it anymore. Once the index is deleted, you cannot use it or recover its data.

pinecone.delete_index(index_name)