🔥 Why Fine-Tune a Cross-Encoder?
1. More Accurate Semantic Judgments
- A Cross-Encoder takes both sentences together as input, so BERT (or another Transformer) can directly compare words across sentences using attention.
- This allows it to align tokens like "man" ↔ "person" or "guitar" ↔ "instrument" and reason at a finer level.
- Result: higher accuracy on tasks like Semantic Textual Similarity (STS), duplicate detection, or answer re-ranking.
- Example (see the sketch below):
- A bi-encoder (separate embeddings) might give "man plays guitar" vs "guitarist performing" a similarity of 0.7.
- A cross-encoder, by jointly encoding the pair, can push it to 0.95 because it captures the equivalence more precisely.
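To make the contrast concrete, here is a minimal sketch that scores the same pair both ways. The checkpoints all-MiniLM-L6-v2 and cross-encoder/stsb-roberta-base are illustrative pretrained models, and the exact numbers will differ from the figures above.

from sentence_transformers import SentenceTransformer, CrossEncoder, util

# Bi-encoder: encode each sentence independently, then compare embeddings
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
emb = bi_encoder.encode(["man plays guitar", "guitarist performing"])
print("bi-encoder cosine:", util.cos_sim(emb[0], emb[1]).item())

# Cross-encoder: encode the pair jointly and predict one similarity score
cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")
print("cross-encoder score:", cross_encoder.predict([("man plays guitar", "guitarist performing")])[0])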
2. Adapting to Domain-Specific Data
- Pretrained models (BERT, RoBERTa, etc.) are general-purpose.
- Fine-tuning on your own dataset teaches the cross-encoder to judge similarity in your context.
- Examples:
- Legal documents → "Section 5.1" vs "Clause V" might be synonymous only in the legal domain.
- Medical texts → "heart attack" ≈ "myocardial infarction".
- Customer support → "reset password" ≈ "forgot login credentials".
Without fine-tuning, the model might miss these domain-specific relationships.
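The fine-tuning data for such a domain is just labeled sentence pairs, in the same format used later in this tutorial. A tiny hypothetical sketch (the pairs and scores are invented for illustration):

from sentence_transformers import InputExample

# Hypothetical domain-labeled pairs, with similarity scores normalized to [0, 1]
domain_examples = [
    InputExample(texts=["heart attack", "myocardial infarction"], label=1.0),
    InputExample(texts=["reset password", "forgot login credentials"], label=0.9),
    InputExample(texts=["heart attack", "reset password"], label=0.0),
]
# These plug into the same DataLoader + model.fit pipeline shown below.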
3. Optimal for Ranking Tasks
- In search or retrieval, you often want to re-rank candidates returned by a fast retriever.
- Cross-encoder excels here:
- Bi-encoder: retrieves top-100 candidates quickly.
- Cross-encoder: re-scores those top-100 pairs with higher accuracy.
- This setup is widely used in open-domain QA (e.g., MS MARCO-style re-ranking, ColBERT pipelines), recommender systems, and semantic search; a minimal sketch of the two-stage pattern follows.
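A minimal sketch of that two-stage pipeline, using the illustrative pretrained checkpoints all-MiniLM-L6-v2 (retriever) and cross-encoder/ms-marco-MiniLM-L-6-v2 (re-ranker) on a toy corpus:

from sentence_transformers import SentenceTransformer, CrossEncoder, util

bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")                  # fast first-stage retriever
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # accurate second-stage re-ranker

corpus = [
    "Paris is the capital city of France.",
    "Berlin is the capital of Germany.",
    "The Eiffel Tower is in Paris.",
]
corpus_emb = bi_encoder.encode(corpus, convert_to_tensor=True)

query = "What is the capital of France?"
# Stage 1: fast embedding search retrieves a shortlist (top_k would be ~100 in practice)
hits = util.semantic_search(bi_encoder.encode(query, convert_to_tensor=True), corpus_emb, top_k=2)[0]
# Stage 2: the cross-encoder re-scores each (query, candidate) pair jointly
pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
for pair, score in sorted(zip(pairs, cross_encoder.predict(pairs)), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {pair[1]}")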
4. Regression & Classification Tasks
- Many tasks are not just "similar / not similar" but have graded similarity (0–5 in STS-B).
- A fine-tuned cross-encoder can predict continuous similarity scores.
- It can also be adapted for classification (duplicate vs not duplicate, entailment vs contradiction, etc.), as sketched below.
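In sentence-transformers, the output head is controlled by num_labels. A minimal sketch of the two setups (the example pairs and label conventions are illustrative assumptions):

from sentence_transformers import CrossEncoder, InputExample

# Regression head: a single output unit trained against continuous scores
reg_model = CrossEncoder("bert-base-uncased", num_labels=1)
reg_example = InputExample(texts=["A man plays guitar.", "A guitarist performs."], label=0.9)

# Classification head: one output unit per class, trained against integer class ids
# (hypothetical convention: 0 = not duplicate, 1 = duplicate)
clf_model = CrossEncoder("bert-base-uncased", num_labels=2)
clf_example = InputExample(texts=["How do I reset my password?", "I forgot my login credentials."], label=1)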
5. When Data Labels Matter
- If you have annotated sentence pairs, fine-tuning a cross-encoder directly optimizes for your target metric (e.g., MSE on similarity scores, accuracy on duplicates).
- A pretrained model alone will not "know" your specific scoring function.
- Example: Two sentences could be judged similar by generic BERT, but your dataset might label them as not duplicates because of context.
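If your sentence-transformers version exposes the classic CrossEncoder.fit API used in this tutorial, it also accepts a loss_fct override, letting training optimize the loss that matches your labels. A hedged sketch, assuming the model and train_dataloader built later in this tutorial:

import torch.nn as nn

# Default loss for num_labels=1 is BCEWithLogitsLoss (labels in [0, 1]).
# To optimize mean squared error on raw similarity scores instead:
model.fit(
    train_dataloader=train_dataloader,
    loss_fct=nn.MSELoss(),
    epochs=1,
)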
6. Performance vs Efficiency Tradeoff
- Cross-encoders are slower because you must run the Transformer per sentence pair.
- But they're worth training when:
- Accuracy is more important than latency (e.g., offline re-ranking, evaluation tasks).
- Dataset size is manageable (you don't need to encode millions of pairs at once).
- You have a candidate shortlist (bi-encoder retrieves first, then the cross-encoder refines).
🧠 Fine-Tune a Cross-Encoder
Now to the training part: we'll fine-tune a cross-encoder (BERT-based) on the STS-Benchmark dataset, in which sentence pairs are scored for semantic similarity (0–5).

Fig. Fine-tuning Cross-Encoders
1. Install Dependencies
pip install torch transformers sentence-transformers datasets accelerate
2. Load Data
We'll use the STS-B dataset from Hugging Face.
# ========================
# Dataset Loading
# ========================
from datasets import load_dataset
# Load Semantic Textual Similarity Benchmark
# https://huggingface.co/datasets/PhilipMay/stsb_multi_mt
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")
print(dataset) # Show available splits (train/test)
3. Prepare Training Data
We'll convert each pair into (sentence1, sentence2, score) format: a cross-encoder operates on paired sentences and needs a supervised similarity score to learn from, so this format aligns directly with the model's input structure and training objective. STS-B gold scores range from 0 to 5; we normalize them to [0, 1], since the default loss for a single-output CrossEncoder (BCEWithLogitsLoss) expects labels in that range.
# Prepare InputExamples
from sentence_transformers import InputExample

# Normalize gold scores from [0, 5] to [0, 1] for the default single-output loss
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["train"]
]
# Note: stsb_multi_mt also ships a "dev" split; we evaluate on "test" here
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["test"]
]
4. Create a Data Loader
# Create DataLoader
from torch.utils.data import DataLoader

# BATCH_SIZE (16) is defined in the config of the complete script below
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
5. Model Setup
# ========================
# Model Setup
# ========================
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

print(f"Loading CrossEncoder model: {MODEL_NAME}")  # MODEL_NAME = "bert-base-uncased" in the config
model = CrossEncoder(MODEL_NAME, num_labels=1)  # num_labels=1 -> a single similarity output
# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")
6. Training
# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR,
)
7. Reload Trained Model
# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)
8. Inference Demo
# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch."),
]
scores = model.predict(test_sentences)
print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")
# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese.",
]
pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")
9. Complete Code
# main1.py
# ========================
# Imports & Configuration
# ========================
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
# Config
MODEL_NAME = "bert-base-uncased"
OUTPUT_DIR = "./cross-encoder-stsb"
BATCH_SIZE = 16
EPOCHS = 3
WARMUP_STEPS = 100
EVAL_STEPS = 500
SEED = 42
# Ensure reproducibility
torch.manual_seed(SEED)
# ========================
# Dataset Loading
# ========================
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")
print(dataset) # Show available splits (train/test)
# Prepare InputExamples
# Normalize gold scores from [0, 5] to [0, 1]; the default loss for a
# single-output CrossEncoder (BCEWithLogitsLoss) expects labels in that range
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["train"]
]
# Note: stsb_multi_mt also ships a "dev" split; we evaluate on "test" here
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) / 5.0)
    for row in dataset["test"]
]
# Create DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
# ========================
# Model Setup
# ========================
print(f"Loading CrossEncoder model: {MODEL_NAME}")
model = CrossEncoder(MODEL_NAME, num_labels=1)
# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")
# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR,
)
# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)
# ========================
# Inference Demo
# ========================
# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch."),
]
scores = model.predict(test_sentences)
print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")
# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese.",
]
pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")
Output:
(env) D:\github\finetune-crossencoder>python main1.py
Loading STS-B (multilingual, English split)...
DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 5749
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1379
    })
    dev: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1500
    })
})
Loading CrossEncoder model: bert-base-uncased
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Starting training...
0%| | 0/1080 [00:00<?, ?it/s]D:\github\finetune-crossencoder\env\Lib\site-packages\torch\utils\data\dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
warnings.warn(warn_msg)
{'loss': -20.1537, 'grad_norm': 50.69091033935547, 'learning_rate': 1.1832139201637667e-05, 'epoch': 1.39}
{'eval_sts-dev_pearson': 0.4514054666098877, 'eval_sts-dev_spearman': 0.4771302005902, 'eval_runtime': 67.8654, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 1.39}
{'loss': -32.7533, 'grad_norm': 52.87107849121094, 'learning_rate': 1.5967246673490277e-06, 'epoch': 2.78}
{'eval_sts-dev_pearson': 0.5504492763939616, 'eval_sts-dev_spearman': 0.5489895972483916, 'eval_runtime': 91.5175, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 2.78}
{'train_runtime': 5965.8199, 'train_samples_per_second': 2.891, 'train_steps_per_second': 0.181, 'train_loss': -27.04566062644676, 'epoch': 3.0}
100%|████████████████████████████████████████████████████████████████| 1080/1080 [1:39:25<00:00, 5.52s/it]
Loading trained model from: ./cross-encoder-stsb
Similarity Prediction Demo:
A man is playing a guitar. <-> A person is playing a guitar. => 1.000
A dog is running in the park. <-> A cat is sleeping on the couch. => 0.176
Ranking Demo:
Paris is the capital city of France. => 1.000
France is known for its wine and cheese. => 1.000
London is the capital of the UK. => 0.832
✅ Key Takeaways
- Cross-encoders model fine-grained token-level interactions, making them highly accurate for semantic similarity, re-ranking, and NLI (Natural Language Inference).
- Training requires pairs of sentences with labels (scores or categories).
- They are slower than bi-encoders, so best used for re-ranking top candidates.
- Libraries like Sentence-Transformers make training straightforward.