
Hands-on Tutorial: Fine-tune a Cross-Encoder for Semantic Similarity

Categories: ai_ml, deep-learning, genai, natural-language-processing-nlp, unstructured-data
Tags: #bert #cross-encoder #data-science #deep-learning #fine-tuning #huggingface #machine-learning #natural-language-processing #nlp

🔥 Why Fine-Tune a Cross-Encoder?

1. More Accurate Semantic Judgments

  • A Cross-Encoder takes both sentences together as input, so BERT (or another Transformer) can directly compare words across sentences using attention.

  • This allows it to align tokens like “man” ↔ “person”, “guitar” ↔ “instrument”, and reason at a finer level.

  • Result: higher accuracy on tasks like Semantic Textual Similarity (STS), duplicate detection, or answer re-ranking.

  • Example:

    • A bi-encoder (separate embeddings) might give “man plays guitar” vs “guitarist performing” a similarity of 0.7.

    • A cross-encoder, by jointly encoding the pair, can push it to 0.95 because it captures the equivalence more precisely (see the sketch after this list).
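
To make the contrast concrete, here is a minimal sketch of the two scoring styles. It assumes the off-the-shelf checkpoints all-MiniLM-L6-v2 (bi-encoder) and cross-encoder/stsb-roberta-base (cross-encoder); actual scores will differ from the illustrative numbers above.

code
# Sketch: bi-encoder vs cross-encoder scoring of a single sentence pair.
# Model names are pretrained checkpoints, not the model trained in this tutorial.
from sentence_transformers import CrossEncoder, SentenceTransformer, util

bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")

s1, s2 = "man plays guitar", "guitarist performing"

# Bi-encoder: encode each sentence independently, then compare embeddings.
emb1, emb2 = bi_encoder.encode([s1, s2])
print("bi-encoder cosine:", float(util.cos_sim(emb1, emb2)))

# Cross-encoder: one joint forward pass over the concatenated pair.
print("cross-encoder score:", float(cross_encoder.predict([(s1, s2)])[0]))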

2. Adapting to Domain-Specific Data

  • Pretrained models (BERT, RoBERTa, etc.) are general-purpose.

  • Fine-tuning on your own dataset teaches the cross-encoder to judge similarity in your context.

  • Examples:

    • Legal documents → “Section 5.1” vs “Clause V” might be synonyms only in the legal domain.

    • Medical texts → “heart attack” ≈ “myocardial infarction”.

    • Customer support → “reset password” ≈ “forgot login credentials”.

Without fine-tuning, the model might miss these domain-specific relationships.

3. Optimal for Ranking Tasks

  • In search or retrieval, you often want to re-rank candidates returned by a fast retriever.

  • Cross-encoder excels here:

    • Bi-encoder: retrieves top-100 candidates quickly.

    • Cross-encoder: re-scores those top-100 pairs with higher accuracy.

  • This setup is widely used in open-domain QA (like MS MARCO, ColBERT pipelines), recommender systems, and semantic search; a minimal sketch of the pattern follows this list.
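
Here is a minimal sketch of that two-stage pattern, assuming the pretrained checkpoints all-MiniLM-L6-v2 (retriever) and cross-encoder/ms-marco-MiniLM-L-6-v2 (re-ranker); any bi-/cross-encoder pair works:

code
# Sketch: retrieve candidates with a bi-encoder, re-rank with a cross-encoder.
from sentence_transformers import CrossEncoder, SentenceTransformer, util

bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

corpus = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese.",
]
corpus_emb = bi_encoder.encode(corpus, convert_to_tensor=True)

query = "What is the capital of France?"

# Stage 1: fast retrieval of a shortlist (top-k by cosine similarity).
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_emb, corpus_emb, top_k=3)[0]

# Stage 2: accurate joint re-scoring of the shortlisted pairs.
pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
for (_, cand), score in sorted(
    zip(pairs, cross_encoder.predict(pairs)), key=lambda x: x[1], reverse=True
):
    print(f"{score:.3f}  {cand}")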

4. Regression & Classification Tasks

  • Many tasks are not just “similar / not similar” but have graded similarity (0–5 in STS-B).

  • A fine-tuned cross-encoder can predict continuous similarity scores.

  • It can also be adapted for classification (duplicate vs not duplicate, entailment vs contradiction, etc.), as the snippet below shows.
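
In Sentence-Transformers, the size of the classification head selects the task type; a sketch (model name reused from this tutorial, label counts illustrative):

code
from sentence_transformers import CrossEncoder

# Regression head: one output -> continuous similarity score (STS-style).
reg_model = CrossEncoder("bert-base-uncased", num_labels=1)

# Classification head: e.g. three NLI classes
# (contradiction / entailment / neutral); train with integer class labels.
nli_model = CrossEncoder("bert-base-uncased", num_labels=3)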

5. When Data Labels Matter

  • If you have annotated sentence pairs, fine-tuning a cross-encoder directly optimizes for your target metric (e.g., MSE on similarity scores, accuracy on duplicates).

  • A pretrained model alone will not “know” your specific scoring function.

  • Example: Two sentences could be judged similar by generic BERT, but your dataset might label them as not duplicates because of context.

6. Performance vs Efficiency Tradeoff

  • Cross-encoders are slower because every sentence pair requires its own full Transformer forward pass.

  • But they’re worth training when:

    • Accuracy is more important than latency (e.g., offline re-ranking, evaluation tasks).

    • Dataset size is manageable (you don’t need to encode millions of pairs at once).

    • You have a candidate shortlist (bi-encoder first, then cross-encoder refine).

🧠 Fine-tune a Cross-Encoder

Now to the training part: we’ll fine-tune a BERT-based cross-encoder on the STS-Benchmark dataset, in which sentence pairs are scored for semantic similarity on a 0–5 scale.

Fig. Fine-tuning Cross-Encoders

1. Install Dependencies

code
pip install torch transformers sentence-transformers datasets accelerate

2. Load Data

We’ll use the STS-B dataset from Hugging Face.

code
# ========================
# Dataset Loading
# ========================
from datasets import load_dataset

# Load Semantic Textual Similarity Benchmark
# https://huggingface.co/datasets/PhilipMay/stsb_multi_mt
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")
print(dataset)  # Show available splits (train/dev/test)
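
A quick sanity check on the schema (the same field names appear in the training output further down):

code
# Inspect one example; each row has sentence1, sentence2, and a 0-5 similarity_score.
print(dataset["train"][0])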

3. Prepare Training Data

We’ll convert each pair into the (sentence1, sentence2, score) format. A cross-encoder operates on paired sentences and needs a supervised similarity score to learn from, so this format aligns directly with both the model’s input structure and the training objective.

code
from sentence_transformers import InputExample

# Prepare InputExamples
# Note: stsb_multi_mt also ships a "dev" split; this tutorial evaluates on "test".
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]))
    for row in dataset["train"]
]
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]))
    for row in dataset["test"]
]
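
One caveat worth flagging: stsb_multi_mt labels are on a 0–5 scale, while the reference Sentence-Transformers STS training scripts scale them to [0, 1] before fitting a num_labels=1 cross-encoder (whose default regression loss expects that range). The negative training loss visible in the output further down is likely a symptom of skipping that step. A minimal variant with scaled labels:

code
# Variant: scale the 0-5 gold scores into [0, 1], as in the upstream
# Sentence-Transformers STS examples (label = score / 5.0).
train_examples = [
    InputExample(
        texts=[row["sentence1"], row["sentence2"]],
        label=float(row["similarity_score"]) / 5.0,
    )
    for row in dataset["train"]
]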

4. Create a Data Loader

code
from torch.utils.data import DataLoader

BATCH_SIZE = 16  # matches the config in the complete script below

# Create DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

5. Model Setup

code
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

MODEL_NAME = "bert-base-uncased"

# ========================
# Model Setup
# ========================
print(f"Loading CrossEncoder model: {MODEL_NAME}")
model = CrossEncoder(MODEL_NAME, num_labels=1)

# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")
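
The evaluator is also callable on its own, so you can record an untrained baseline before fitting. (The return type varies by library version: older releases return the correlation as a float, newer ones a metrics dict.)

code
# Baseline: dev-set correlation of the untrained head.
# Expect a low value, since the classifier weights are freshly initialized.
baseline = evaluator(model)
print("Dev correlation before fine-tuning:", baseline)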

6. Training

code
EPOCHS = 3
WARMUP_STEPS = 100
EVAL_STEPS = 500
OUTPUT_DIR = "./cross-encoder-stsb"

# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR
)
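
fit checkpoints the best model to output_path whenever the evaluator score improves. An explicit save at the end is also possible via CrossEncoder.save:

code
# Optional explicit save (fit already wrote the best checkpoint to OUTPUT_DIR).
model.save(OUTPUT_DIR)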

7. Reload Trained Model

code
# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)

8. Inference Demo

code
# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch.")
]
scores = model.predict(test_sentences)
print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")

# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese."
]
pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)

print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")
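
For longer candidate lists, predict supports batching and a progress bar (the batch size here is arbitrary):

code
# Score many pairs in batches rather than in one oversized call.
scores = model.predict(pairs, batch_size=64, show_progress_bar=True)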

9. Complete Code

code
# main.py
# ========================
# Imports & Configuration
# ========================
import os
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

# Config
MODEL_NAME = "bert-base-uncased"
OUTPUT_DIR = "./cross-encoder-stsb"
BATCH_SIZE = 16
EPOCHS = 3
WARMUP_STEPS = 100
EVAL_STEPS = 500
SEED = 42

# Ensure reproducibility
torch.manual_seed(SEED)

# ========================
# Dataset Loading
# ========================
print("Loading STS-B (multilingual, English split)...")
dataset = load_dataset("stsb_multi_mt", "en")
print(dataset)  # Show available splits (train/dev/test)

# Prepare InputExamples
train_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]))
    for row in dataset["train"]
]
dev_examples = [
    InputExample(texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]))
    for row in dataset["test"]
]

# Create DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

# ========================
# Model Setup
# ========================
print(f"Loading CrossEncoder model: {MODEL_NAME}")
model = CrossEncoder(MODEL_NAME, num_labels=1)

# Evaluator (Spearman/Pearson correlation between predicted & true scores)
evaluator = CECorrelationEvaluator.from_input_examples(dev_examples, name="sts-dev")

# ========================
# Training
# ========================
print("Starting training...")
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR
)

# ========================
# Reload Trained Model
# ========================
print("Loading trained model from:", OUTPUT_DIR)
model = CrossEncoder(OUTPUT_DIR)

# ========================
# Inference Demo
# ========================
# --- Pairwise similarity
test_sentences = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch.")
]
scores = model.predict(test_sentences)
print("\nSimilarity Prediction Demo:")
for (s1, s2), score in zip(test_sentences, scores):
    print(f"  {s1} <-> {s2} => {score:.3f}")

# --- Information retrieval style (ranking)
query = "What is the capital of France?"
candidates = [
    "Paris is the capital city of France.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese."
]
pairs = [(query, cand) for cand in candidates]
scores = model.predict(pairs)
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)

print("\nRanking Demo:")
for cand, score in ranked:
    print(f"  {cand} => {score:.3f}")

Output:

code
(env) D:\github\finetune-crossencoder>python main1.py
Loading STS-B (multilingual, English split)...
DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 5749
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1379
    })
    dev: Dataset({
        features: ['sentence1', 'sentence2', 'similarity_score'],
        num_rows: 1500
    })
})
Loading CrossEncoder model: bert-base-uncased
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Starting training...
  0%|          | 0/1080 [00:00<?, ?it/s]
D:\github\finetune-crossencoder\env\Lib\site-packages\torch\utils\data\dataloader.py:666: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
{'loss': -20.1537, 'grad_norm': 50.69091033935547, 'learning_rate': 1.1832139201637667e-05, 'epoch': 1.39}
{'eval_sts-dev_pearson': 0.4514054666098877, 'eval_sts-dev_spearman': 0.4771302005902, 'eval_runtime': 67.8654, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 1.39}
{'loss': -32.7533, 'grad_norm': 52.87107849121094, 'learning_rate': 1.5967246673490277e-06, 'epoch': 2.78}
{'eval_sts-dev_pearson': 0.5504492763939616, 'eval_sts-dev_spearman': 0.5489895972483916, 'eval_runtime': 91.5175, 'eval_samples_per_second': 0.0, 'eval_steps_per_second': 0.0, 'epoch': 2.78}
{'train_runtime': 5965.8199, 'train_samples_per_second': 2.891, 'train_steps_per_second': 0.181, 'train_loss': -27.04566062644676, 'epoch': 3.0}
100%|██████████| 1080/1080 [1:39:25<00:00,  5.52s/it]
Loading trained model from: ./cross-encoder-stsb

Similarity Prediction Demo:
  A man is playing a guitar. <-> A person is playing a guitar. => 1.000
  A dog is running in the park. <-> A cat is sleeping on the couch. => 0.176

Ranking Demo:
  Paris is the capital city of France. => 1.000
  France is known for its wine and cheese. => 1.000
  London is the capital of the UK. => 0.832

✅ Key Takeaways

  • Cross-encoders model fine-grained token-level interactions, making them highly accurate for semantic similarity, re-ranking, and NLI (Natural Language Inference).

  • Training requires pairs of sentences with labels (scores or categories).

  • They are slower than bi-encoders, so best used for re-ranking top candidates.

  • Libraries like Sentence-Transformers make training straightforward.

Related articles: Searching / Indexing / RAG Series

  1. BM25-Based Searching: A Developer’s Comprehensive Guide
  2. BM25 vs Dense Retrieval for RAG Engineers
  3. Building a Full-Stack Hybrid Search System (BM25 + Vectors + Cross-Encoders) with Docker
  4. Hands-on Tutorial: Fine-tune a Cross-Encoder for Semantic Similarity (This Article)