πŸš€ Hands-on Tutorial: Fine-tune a Cross-Encoder for Semantic Similarity

πŸ”₯ Why Fine-Tune a Cross-Encoder?

1. More Accurate Semantic Judgments

  • A Cross-Encoder takes both sentences together as input, so BERT (or another Transformer) can directly compare words across sentences using attention.
  • This allows it to align tokens like β€œman” ↔ β€œperson”, β€œguitar” ↔ β€œinstrument”, and reason at a finer level.
  • Result: higher accuracy on tasks like Semantic Textual Similarity (STS), duplicate detection, or answer re-ranking.
  • Example:
    • A bi-encoder (separate embeddings) might give β€œman plays guitar” β‰ˆ β€œguitarist performing” a similarity of around 0.7.
    • A cross-encoder, by jointly encoding the pair, can push that to around 0.95 because it captures the equivalence more precisely (see the sketch below).
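
The contrast is easy to see in code. Here is a minimal sketch using two small off-the-shelf checkpoints (all-MiniLM-L6-v2 as the bi-encoder, cross-encoder/stsb-roberta-base as the cross-encoder); the exact scores you get will differ from the illustrative numbers above.

from sentence_transformers import SentenceTransformer, CrossEncoder, util

s1, s2 = "man plays guitar", "guitarist performing"

# Bi-encoder: embed each sentence independently, then compare the vectors
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
emb1, emb2 = bi_encoder.encode([s1, s2])
print("bi-encoder cosine:", util.cos_sim(emb1, emb2).item())

# Cross-encoder: encode the pair jointly and predict a single similarity score
cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")
print("cross-encoder score:", cross_encoder.predict([(s1, s2)])[0])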

2. Adapting to Domain-Specific Data

  • Pretrained models (BERT, RoBERTa, etc.) are general-purpose.
  • Fine-tuning on your own dataset teaches the cross-encoder to judge similarity in your context.
  • Examples:
    • Legal documents β†’ β€œSection 5.1” vs β€œClause V” might be equivalent only in the legal domain.
    • Medical texts β†’ β€œheart attack” β‰ˆ β€œmyocardial infarction”.
    • Customer support β†’ β€œreset password” β‰ˆ β€œforgot login credentials”.

Without fine-tuning, the model might miss these domain-specific relationships.

3. Optimal for Ranking Tasks

  • In search or retrieval, you often want to re-rank candidates returned by a fast retriever.
  • A cross-encoder excels here:
    • Bi-encoder: retrieves top-100 candidates quickly.
    • Cross-encoder: re-scores those top-100 pairs with higher accuracy.
  • This setup is widely used in open-domain QA (e.g., MS MARCO and ColBERT-style pipelines), recommender systems, and semantic search; a sketch of the two-stage pipeline follows below.
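
As a rough illustration, here is what retrieve-then-rerank looks like with sentence-transformers. The checkpoints (all-MiniLM-L6-v2 for retrieval, cross-encoder/ms-marco-MiniLM-L-6-v2 for re-ranking) and the toy three-document corpus are placeholder choices, not part of this tutorial's training setup.

from sentence_transformers import SentenceTransformer, CrossEncoder, util

corpus = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese.",
]
query = "What is the capital of France?"

# Stage 1: fast bi-encoder retrieval over the whole corpus
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_emb = bi_encoder.encode(corpus, convert_to_tensor=True)
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_emb, corpus_emb, top_k=3)[0]  # top_k=100 in practice

# Stage 2: slower but more accurate cross-encoder re-scoring of the shortlist
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
scores = cross_encoder.predict(pairs)

for (_, cand), score in sorted(zip(pairs, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {cand}")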

4. Regression & Classification Tasks

  • Many tasks are not just β€œsimilar / not similar” but have graded similarity (0–5 in STS-B).
  • A fine-tuned cross-encoder can predict continuous similarity scores.
  • It can also be adapted for classification (duplicate vs. not duplicate, entailment vs. contradiction, etc.); a minimal sketch follows below.
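
In sentence-transformers this is controlled by the num_labels argument on the CrossEncoder head. A minimal sketch (bert-base-uncased is just a placeholder base model):

from sentence_transformers import CrossEncoder

# Regression head: a single output neuron predicts a continuous score
reg_model = CrossEncoder("bert-base-uncased", num_labels=1)

# Classification head: e.g. three NLI classes (contradiction/entailment/neutral)
clf_model = CrossEncoder("bert-base-uncased", num_labels=3)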

5. When Data Labels Matter

  • If you have annotated sentence pairs, fine-tuning a cross-encoder directly optimizes for your target metric (e.g., MSE on similarity scores, accuracy on duplicates).
  • A pretrained model alone will not β€œknow” your specific scoring function.
  • Example: Two sentences could be judged similar by generic BERT, but your dataset might label them as not duplicates because of context.

6. Performance vs Efficiency Tradeoff

  • Cross-encoders are slower because you must run the full Transformer once for every sentence pair.
  • But they’re worth training when:
    • Accuracy is more important than latency (e.g., offline re-ranking, evaluation tasks).
    • Dataset size is manageable (you don’t need to encode millions of pairs at once).
    • You have a candidate shortlist (bi-encoder first, then cross-encoder refine).

🧠 Fine-tune a Cross-Encoder

Now let's get to the training part, where we'll fine-tune a BERT-based cross-encoder on the STS-Benchmark dataset, in which pairs of sentences are scored for semantic similarity (0–5).

Fig. Fine-tuning Cross-Encoders

1. Install Dependencies

pip install torch transformers sentence-transformers datasets accelerate

2. Load Data

We'll use the English split of the multilingual STS-B dataset (stsb_multi_mt) from Hugging Face.

# ========================
# Dataset Loading
# ========================
from datasets import load_dataset

# Load Semantic Textual Similarity Benchmark
# https://huggingface.co/datasets/PhilipMay/stsb_multi_mt
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")

print(dataset)  # Show available splits (train/test/dev)
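
It's worth peeking at one raw example before converting anything; each row contains sentence1, sentence2, and a similarity_score in the 0–5 range:

# Inspect a single training pair
print(dataset["train"][0])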

3. Prepare Training Data

We'll convert each pair into (sentence1, sentence2, score) format, because a cross-encoder operates on paired sentences and needs a supervised similarity score to learn from; this format directly matches both the model's input structure and the training objective. One important detail: STS-B scores range from 0 to 5, but the sentence-transformers CrossEncoder with num_labels=1 uses a binary cross-entropy loss by default, which expects labels in [0, 1], so we divide each score by 5. (Without this normalization, the loss is computed on out-of-range targets and can even turn negative.)

from sentence_transformers import InputExample

# Prepare InputExamples, normalizing the 0-5 scores to [0, 1]
# (the default loss for num_labels=1 expects labels in that range)
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["train"]
]

# Use the dedicated dev split for model selection (keep test for final evaluation)
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["dev"]
]
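
A quick sanity check that the labels now sit in the expected range:

# Labels should lie in [0, 1] after normalization
labels = [ex.label for ex in train_examples]
print(f"{len(train_examples)} training pairs, labels in [{min(labels):.2f}, {max(labels):.2f}]")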

4. Create a Data Loader

from torch.utils.data import DataLoader

# Batch and shuffle the training pairs (BATCH_SIZE = 16 in the full script below)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

5. Model Setup

# ========================
# Model Setup
# ========================
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

print(f"Loading CrossEncoder model: {MODEL_NAME}")
model = CrossEncoder(MODEL_NAME, num_labels=1)  # num_labels=1 -> regression head with sigmoid output

# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")
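
Because bert-base-uncased ships without a trained classification head (note the warning in the sample output below), predictions before model.fit are essentially arbitrary; an optional smoke test makes that visible:

# Untrained head: expect a meaningless score here
print(model.predict([("A man is playing a guitar.", "A person is playing a guitar.")]))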

6. Training

# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR
)

7. Reload Trained Model

# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)

8. Inference Demo

# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch.")
]

scores = model.predict(test_sentences)

print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")

# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese."
]

pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)

ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)

print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")

9. Complete Code

# main.py

# ========================
# Imports & Configuration
# ========================
import os
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

# Config
MODEL_NAME = "bert-base-uncased"
OUTPUT_DIR = "./cross-encoder-stsb"
BATCH_SIZE = 16
EPOCHS = 3
WARMUP_STEPS = 100
EVAL_STEPS = 500
SEED = 42

# Ensure reproducibility
torch.manual_seed(SEED)

# ========================
# Dataset Loading
# ========================
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")

print(dataset)  # Show available splits (train/test/dev)

# Prepare InputExamples, normalizing the 0-5 scores to [0, 1]
# (the default loss for num_labels=1 expects labels in that range)
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["train"]
]

# Use the dedicated dev split for model selection (keep test for final evaluation)
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["dev"]
]

# Create DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

# ========================
# Model Setup
# ========================
print(f"Loading CrossEncoder model: {MODEL_NAME}")
model = CrossEncoder(MODEL_NAME, num_labels=1)  # num_labels=1 -> regression head with sigmoid output

# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")

# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR
)

# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)

# ========================
# Inference Demo
# ========================

# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch.")
]

scores = model.predict(test_sentences)

print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")

# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese."
]

pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)

ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)

print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")

Output:

(env) D:\github\finetune-crossencoder>python main.py
Loading STS-B (multilingual, English split)...
DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 5749
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1379
    })
    dev: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1500
    })
})
Loading CrossEncoder model: bert-base-uncased
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Starting training...
  0%|                                                                                                                                             | 0/1080 [00:00<?, ?it/s]D:\github\finetune-crossencoder\env\Lib\site-packages\torch\utils\data\dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
{'loss': -20.1537, 'grad_norm': 50.69091033935547, 'learning_rate': 1.1832139201637667e-05, 'epoch': 1.39}
{'eval_sts-dev_pearson': 0.4514054666098877, 'eval_sts-dev_spearman': 0.4771302005902, 'eval_runtime': 67.8654, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 1.39}
{'loss': -32.7533, 'grad_norm': 52.87107849121094, 'learning_rate': 1.5967246673490277e-06, 'epoch': 2.78}
{'eval_sts-dev_pearson': 0.5504492763939616, 'eval_sts-dev_spearman': 0.5489895972483916, 'eval_runtime': 91.5175, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 2.78}
{'train_runtime': 5965.8199, 'train_samples_per_second': 2.891, 'train_steps_per_second': 0.181, 'train_loss': -27.04566062644676, 'epoch': 3.0}
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1080/1080 [1:39:25<00:00,  5.52s/it]
Loading trained model from: ./cross-encoder-stsb

Similarity Prediction Demo:
  A man is playing a guitar. <-> A person is playing a guitar. => 1.000
  A dog is running in the park. <-> A cat is sleeping on the couch. => 0.176

Ranking Demo:
  Paris is the capital city of France. => 1.000
  France is known for its wine and cheese. => 1.000
  London is the capital of the UK. => 0.832

βœ… Key Takeaways

  • Cross-encoders model fine-grained token-level interactions, making them highly accurate for semantic similarity, re-ranking, and NLI (Natural Language Inference).
  • Training requires pairs of sentences with labels (scores or categories).
  • They are slower than bi-encoders, so best used for re-ranking top candidates.
  • Libraries like Sentence-Transformers make training straightforward.
