Building Privacy-Preserving Machine Learning Applications in Python with Homomorphic Encryption

Data privacy is one of the biggest challenges in deploying AI systems. From healthcare to finance, sensitive datasets are often required to train or run machine learning models — but sharing raw data with cloud providers or third-party services can lead to regulatory, security, and trust issues.

What if we could train and run models directly on encrypted data?
That’s the promise of Homomorphic Encryption (HE) — a cryptographic technique that allows computations on ciphertexts without ever decrypting them.

In this blog, we’ll build a series of demo applications in Python that showcase how homomorphic encryption can power privacy-preserving machine learning:

  • 🔑 Introduction to homomorphic encryption
  • 🧮 Linear regression on encrypted data
  • 🌐 FastAPI-based encrypted inference service
  • ✅ Logistic regression classification with encrypted medical data
  • 🚀 Limitations, challenges, and the road ahead

1. What is Homomorphic Encryption?

Traditional encryption protects data at rest (storage) and in transit (network), but not during computation. Once data is processed, it must be decrypted — exposing it to whoever is running the computation.

Homomorphic encryption changes this paradigm. It enables computation on encrypted values such that when decrypted, the result matches the computation as if it were done on plaintext.

For example:

  • Client encrypts 5 and 7
  • Server computes (enc_5 + enc_7)
  • Client decrypts → gets 12

The server never saw the numbers 5 or 7, but still produced a meaningful result.
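
Here is a minimal sketch of that exchange using TenSEAL (the library we introduce in the next section). It uses the BFV scheme because BFV works on exact integers, which matches this toy example; the parameters are illustrative defaults, not a recommendation.

import tenseal as ts

# BFV scheme: exact integer arithmetic on ciphertexts (illustrative parameters)
context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)

enc_5 = ts.bfv_vector(context, [5])   # client encrypts 5
enc_7 = ts.bfv_vector(context, [7])   # client encrypts 7
enc_sum = enc_5 + enc_7               # server adds the two ciphertexts
print(enc_sum.decrypt())              # client decrypts -> [12]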

This opens the door for privacy-preserving AI services where cloud providers can run models on user data without ever seeing the raw inputs.

2. Python Libraries for Homomorphic Encryption

Several libraries bring HE to Python developers:

  • Pyfhel → general-purpose HE (wrapper around Microsoft SEAL)
  • TenSEAL → optimized for machine learning, supports encrypted vectors & tensors
  • HElib → C++ library with Python bindings

For our demos, we’ll use TenSEAL, which is designed for encrypted machine learning use cases.

Install it:

pip install tenseal

3. Demo: Linear Regression with Encrypted Data

Let’s start with a toy regression task: predict house price from house size using encrypted training data.

Step 1: Setup TenSEAL Context

import tenseal as ts
import numpy as np

def create_context():
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60]
    )
    context.global_scale = 2**40
    context.generate_galois_keys()
    return context

This function creates a CKKS homomorphic encryption context with polynomial modulus degree 8192, precision scale 2^40, and Galois keys enabled. This context is the foundation for performing encrypted computations (like addition, multiplication, or rotations) on encrypted real numbers.

Step 2: Sample Training Data

X = np.array([1, 2, 3, 4, 5], dtype=float)
y = np.array([15, 30, 45, 60, 75], dtype=float)  # price = 15 * size

This creates a toy dataset where price is directly proportional to size, with a multiplier of 15.

Step 3: Encrypt Data

context = create_context()
enc_X = [ts.ckks_vector(context, [val]) for val in X]
enc_y = [ts.ckks_vector(context, [val]) for val in y]

This snippet takes the plaintext training data (X and y) and converts each number into an encrypted vector using CKKS. After this step, you can do computations (like addition, multiplication, scaling) directly on the encrypted data without ever decrypting it.

Step 4: Training (Simplified Gradient Descent)

For demo purposes, we decrypt inside gradient computation — but in a real HE setup, all computations could remain encrypted.

def train_linear_regression(enc_X, enc_y, lr=0.1, epochs=20):
    w, b = 0.0, 0.0
    n = len(enc_X)

    for epoch in range(epochs):
        grad_w, grad_b = 0, 0
        for xi, yi in zip(enc_X, enc_y):
            y_pred = xi * w + b
            error = y_pred - yi
            grad_w += (xi * error).decrypt()[0]
            grad_b += error.decrypt()[0]

        grad_w /= n
        grad_b /= n
        w -= lr * grad_w
        b -= lr * grad_b

        print(f"Epoch {epoch+1}: w={w:.4f}, b={b:.4f}")

    return w, b

The code trains a simple linear regression model using gradient descent. It starts with weight and bias set to zero, then for each epoch it computes predictions, calculates the error, and derives gradients with respect to the weight and bias. These gradients are averaged, then used to update the parameters by stepping in the opposite direction of the gradient. Although the inputs are encrypted, the gradients are decrypted during computation (for demo purposes). Finally, the function prints progress each epoch and returns the learned weight and bias.

Step 5: Train and Predict

w, b = train_linear_regression(enc_X, enc_y)
print(f"Final model: price = {w:.2f} * size + {b:.2f}")

enc_input = ts.ckks_vector(context, [6.0])
enc_pred = enc_input * w + b
print("Prediction for size=6:", enc_pred.decrypt()[0])

The code trains the model, prints the learned equation, and demonstrates making a prediction on new encrypted data.

Output:

(env) root@81eb33810340:/workspace/he-ml# python lin-reg-enc-data.py
Epoch 1: w=16.5000, b=4.5000
Epoch 2: w=13.5000, b=3.6000
Epoch 3: w=14.0700, b=3.6900
Epoch 4: w=13.9860, b=3.6000
Epoch 5: w=14.0214, b=3.5442
Epoch 6: w=14.0346, b=3.4834
Epoch 7: w=14.0515, b=3.4246
Epoch 8: w=14.0675, b=3.3667
Epoch 9: w=14.0832, b=3.3098
Epoch 10: w=14.0987, b=3.2539
Epoch 11: w=14.1140, b=3.1989
Epoch 12: w=14.1289, b=3.1448
Epoch 13: w=14.1437, b=3.0916
Epoch 14: w=14.1581, b=3.0394
Epoch 15: w=14.1724, b=2.9880
Epoch 16: w=14.1864, b=2.9375
Epoch 17: w=14.2001, b=2.8878
Epoch 18: w=14.2136, b=2.8390
Epoch 19: w=14.2269, b=2.7910
Epoch 20: w=14.2400, b=2.7438
Final model: price = 14.24 * size + 2.74
Prediction for size=6: 88.1838602661561

✅ We trained a regression model on encrypted inputs (decrypting only the gradients for this demo) and ran inference on encrypted data.

4. Challenges and Limitations

While homomorphic encryption (HE) makes it possible to run machine learning on encrypted data, there are several practical challenges that must be understood before deploying such systems at scale:

4.1 Performance Overhead

  • Problem: HE computations are significantly slower compared to traditional machine learning on plaintext data.
    • For example, a single encrypted addition or multiplication can take milliseconds, while the same operation on plaintext takes microseconds or less (see the timing sketch after this list).
    • Complex models that involve thousands or millions of operations (like deep neural networks) can become prohibitively slow.
  • Why it happens: Encryption schemes like CKKS or BFV encode values as large polynomials. Each multiplication or addition involves expensive polynomial arithmetic, number-theoretic transforms (NTT), and modulus switching.
  • Impact: HE is currently more suitable for smaller models (linear regression, logistic regression, decision trees) than large-scale deep learning, unless heavily optimized.
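
To make the overhead concrete, here is a rough timing sketch that compares one encrypted multiplication with its plaintext counterpart, reusing the CKKS parameters from our demo. Absolute numbers depend heavily on hardware and parameters; only the orders of magnitude matter.

import time
import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60]
)
context.global_scale = 2**40

a, b = 3.14, 2.71
enc_a = ts.ckks_vector(context, [a])
enc_b = ts.ckks_vector(context, [b])

start = time.perf_counter()
_ = enc_a * enc_b                      # ciphertext * ciphertext
encrypted_ms = (time.perf_counter() - start) * 1e3

start = time.perf_counter()
_ = a * b                              # plain float multiplication
plain_us = (time.perf_counter() - start) * 1e6

print(f"Encrypted multiply: {encrypted_ms:.3f} ms")
print(f"Plaintext multiply: {plain_us:.3f} µs")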

4.2 Ciphertext Size & Memory Consumption

  • Problem: Encrypted data (ciphertexts) are much larger than plaintext data.
    • For example, a single encrypted floating-point number might take a few kilobytes, whereas the raw value is just 8 bytes (see the measurement sketch after this list).
  • Why it happens: HE ciphertexts must include redundancy and structure (e.g., modulus, polynomial coefficients) to allow encrypted computations.
  • Impact:
    • Storing large datasets in encrypted form can require 10–100× more space.
    • Network communication between client and server becomes bandwidth-heavy.
    • Memory usage on the server can be a bottleneck if too many encrypted vectors are processed simultaneously.
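
A quick way to see this blow-up for yourself is to serialize a single encrypted value and compare it with the 8 bytes of a raw float. This sketch reuses the demo's CKKS parameters; the exact byte count will vary with the parameters you choose.

import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60]
)
context.global_scale = 2**40

enc = ts.ckks_vector(context, [42.0])
print("Plaintext size : 8 bytes (one float64)")
print("Ciphertext size:", len(enc.serialize()), "bytes")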

4.3 Limited Supported Operations

  • Problem: Homomorphic encryption schemes support only a restricted set of operations efficiently.
    • Linear operations (addition, multiplication) are natural.
    • Non-linear functions like sigmoid, tanh, softmax, ReLU are not directly supported.
  • Workaround: Use polynomial approximations of non-linear functions.
    • Example: Replace the logistic sigmoid with a simple polynomial (see the sketch after this list).
    • These approximations work well in limited ranges but reduce accuracy.
  • Impact:
    • High-accuracy deep learning models cannot be fully ported to HE without approximation losses.
    • Research is ongoing into better polynomial or piecewise approximations that preserve accuracy while being HE-friendly.
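
As a concrete illustration, the sketch below evaluates a widely used degree-3 approximation of the sigmoid, sigmoid(x) ≈ 0.5 + 0.197·x − 0.004·x³, directly on an encrypted vector via TenSEAL's polyval. The deeper CKKS parameters are deliberate: a degree-3 polynomial needs more multiplicative depth than the parameters in our regression demo provide.

import tenseal as ts

# Deeper modulus chain to leave room for the extra multiplications
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=16384,
    coeff_mod_bit_sizes=[60, 40, 40, 40, 40, 60]
)
context.global_scale = 2**40

# Coefficients in increasing degree: 0.5 + 0.197*x + 0*x^2 - 0.004*x^3
sigmoid_coeffs = [0.5, 0.197, 0.0, -0.004]

enc_x = ts.ckks_vector(context, [-2.0, 0.0, 2.0])
enc_sig = enc_x.polyval(sigmoid_coeffs)   # evaluated entirely on ciphertexts

print(enc_sig.decrypt())  # ≈ [0.14, 0.5, 0.86] vs. true sigmoid [0.12, 0.5, 0.88]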

4.4 Training on Encrypted Data

  • Problem: Training machine learning models entirely on encrypted data is computationally very expensive.
    • Gradient descent requires repeated multiplications, non-linear activations, and updates across many iterations.
    • Even a small logistic regression trained under HE can take hours or days.
  • Practical Approach:
    • Federated Learning + HE:
      • Clients keep data locally.
      • They compute model updates (gradients) on plaintext but encrypt them before sending to a central server.
      • The server aggregates encrypted updates (without seeing individual contributions) and updates the global model.
    • This hybrid approach combines efficiency with privacy, making it more realistic than fully HE-based training (a minimal aggregation sketch follows this list).
  • Impact: End-to-end encrypted training is still an active research area, with most production-ready solutions focusing on encrypted inference or encrypted aggregation of updates.
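
The sketch below shows the core of this hybrid idea under simplifying assumptions: every client shares one CKKS context (in practice the secret key stays with the clients and the server only receives a public copy), gradients are computed locally in plaintext, and the server averages ciphertexts it cannot read.

import numpy as np
import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60]
)
context.global_scale = 2**40

# Each client computes its local gradient in plaintext, then encrypts it
client_gradients = [
    np.array([0.10, -0.20]),
    np.array([0.30, 0.05]),
    np.array([-0.10, 0.15]),
]
encrypted_updates = [ts.ckks_vector(context, g.tolist()) for g in client_gradients]

# Server side: sum and average the ciphertexts without ever decrypting them
aggregate = encrypted_updates[0]
for update in encrypted_updates[1:]:
    aggregate = aggregate + update
aggregate = aggregate * (1.0 / len(encrypted_updates))

# Client side (secret key holder): decrypt the averaged global update
print(aggregate.decrypt())  # ≈ [0.10, 0.0]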

Homomorphic encryption is a breakthrough for privacy-preserving machine learning, but it comes with trade-offs: slower computations, larger ciphertexts, limited function support, and impracticality for large-scale training. For now, HE is most effective in encrypted inference and in combination with federated learning for training.

5. Future Directions

Homomorphic encryption for machine learning is still in its early stages, but the pace of research and applied innovation is accelerating. The next few years will likely bring advancements that address today’s limitations and open new possibilities for privacy-preserving AI. Here are some promising directions:

5.1 Federated Learning with Homomorphic Encryption

  • What it is:
    • In federated learning, multiple clients (e.g., hospitals, banks, mobile devices) train a shared model collaboratively without centralizing raw data.
    • Each client computes local updates (gradients or weights) and sends them to a central server for aggregation.
    • With HE, these updates can be encrypted before transmission. The server aggregates encrypted updates and sends back an improved global model — all without ever seeing the clients’ raw data or gradients.
  • Why it matters:
    • Protects sensitive datasets such as medical records, financial transactions, or user behavior logs.
    • Prevents the server or malicious insiders from inferring private information from model updates.
    • Enables cross-organization collaboration — e.g., pharmaceutical companies jointly training models on encrypted clinical trial data.
  • Challenges ahead:
    • Scaling to millions of clients while keeping training efficient.
    • Handling non-IID data (when different clients’ data distributions differ significantly).
    • Balancing HE’s computational overhead with the real-time needs of federated learning.

5.2 Encrypted Deep Learning

  • The vision: Run full-scale deep learning models like Convolutional Neural Networks (CNNs) for image classification or Transformers for natural language processing directly on encrypted inputs.
  • Progress so far:
    • Research prototypes have shown CNNs running on encrypted images for tasks like digit recognition (MNIST) or medical imaging.
    • Transformers under HE are being studied for privacy-preserving NLP, where users can query encrypted documents without revealing their text.
  • Why it’s hard:
    • Deep models rely heavily on non-linear functions (ReLU, softmax, attention mechanisms), which HE does not natively support.
    • Even polynomial approximations for these functions become unstable as model depth increases.
    • The ciphertext growth and computational cost scale rapidly with network complexity.
  • The future:
    • Research into HE-friendly neural architectures — custom-designed layers that avoid costly operations.
    • Use of bootstrapping optimizations (refreshing ciphertexts) to enable deeper computations.
    • Hybrid models where only the most privacy-sensitive layers are run under HE, while less critical parts run in plaintext.

5.3 Hybrid Privacy Technologies

Homomorphic encryption is powerful, but it isn’t a silver bullet. The most promising direction is combining HE with other privacy-preserving technologies to build robust, end-to-end secure ML systems:

  • HE + Differential Privacy (DP):
    • HE ensures the data remains encrypted during computation.
    • DP adds statistical noise to outputs or gradients to prevent leakage about individual records.
    • Together, they provide both cryptographic security and formal privacy guarantees.
  • HE + Secure Multi-Party Computation (SMPC):
    • SMPC splits data across multiple parties who jointly compute without revealing their shares.
    • HE can accelerate or simplify SMPC protocols by reducing communication rounds.
    • This hybrid approach is useful for high-stakes collaborations (e.g., banks jointly detecting fraud without revealing customer data).
  • HE + Trusted Execution Environments (TEE):
    • TEE (like Intel SGX) provides hardware-based secure enclaves.
    • HE can complement TEEs by reducing the trust required in hardware vendors — even if an enclave is compromised, the data remains encrypted.

5.4 Looking Ahead

The long-term vision is fully private AI pipelines, where:

  1. Data is encrypted at collection.
  2. Training happens across multiple entities without any party seeing the raw data.
  3. Inference is run on encrypted queries, producing encrypted outputs.
  4. Clients alone decrypt results, ensuring data confidentiality, model confidentiality, and output confidentiality.

If today’s limitations are addressed, such pipelines could transform industries like:

  • Healthcare: AI diagnosis on encrypted medical scans without hospitals sharing raw images.
  • Finance: Fraud detection on encrypted transaction streams.
  • Government & Defense: Secure intelligence sharing across agencies.
  • Consumer Tech: Voice assistants or chatbots that process encrypted user inputs without “listening in.”

6. Conclusion

The future of homomorphic encryption in machine learning is not about HE alone, but about ecosystems of privacy technologies — federated learning for collaboration, HE for encrypted computation, DP for statistical privacy, SMPC for secure multi-party workflows, and TEEs for hardware-level isolation. Together, these will bring us closer to a world where AI can learn from everyone, without exposing anyone.

Provenance in AI: Auto-Capturing Provenance with MLflow and W3C PROV-O in PyTorch Pipelines – Part 4

AI engineers spend a lot of time building, training, and iterating on models. But as pipelines grow more complex, it becomes difficult to answer simple but crucial questions:

  • Which dataset version trained this model?
  • Which parameters were used?
  • Who triggered this training job?
  • Can I reproduce this run six months later?

Without structured provenance tracking, reproducibility and compliance become almost impossible. In regulated domains, this is not optional — it’s mandatory.

In this article, we’ll show how to integrate W3C PROV-O (a standard for provenance modeling) with MLflow (a popular experiment tracking framework) in a PyTorch pipeline. The result: every training run not only logs metrics and artifacts but also generates a machine-readable provenance graph for accountability, auditability, and governance.

🔎 Background: Why PROV-O + MLflow?

  • MLflow is widely used for experiment tracking. It records metrics, parameters, and artifacts like models and logs. However, MLflow’s logs are application-specific and not standardized for knowledge sharing across systems.
  • W3C PROV-O is a semantic ontology (built on RDF/OWL2) that provides a standardized vocabulary for describing provenance: Entities, Activities, and Agents, and their relationships (prov:used, prov:wasGeneratedBy, prov:wasAttributedTo).

By combining the two:

  • MLflow provides the data source of truth for training runs.
  • PROV-O provides the interoperable representation of lineage, useful for audits, governance, and integration into knowledge graphs.

🏗️ Architecture Overview

Our integration maps MLflow concepts to PROV-O concepts:

  • MLflow Run → prov:Activity (e.g., a training job with run ID f4a22)
  • MLflow Artifact (model) → prov:Entity (e.g., model_v1.pth)
  • Dataset (input) → prov:Entity (e.g., dataset.csv)
  • Metrics (loss, accuracy) → prov:Entity (e.g., metrics.json)
  • MLflow User/System → prov:Agent (e.g., the engineer triggering the run)

⚙️ Step 1: Setup

We need a combination of MLflow (for tracking) and rdflib (for provenance graph generation).

pip install mlflow torch rdflib prov

  • mlflow → tracks experiments, models, metrics, and artifacts.
  • torch → used for building the PyTorch model.
  • rdflib → builds and serializes RDF/PROV-O graphs.
  • prov → utilities for working with W3C PROV specifications.

🧑‍💻 Step 2: PyTorch Training with MLflow Logging

We start with a simple PyTorch script that trains a small neural network while logging to MLflow.

import torch
import torch.nn as nn
import torch.optim as optim
import mlflow
import mlflow.pytorch

# Fake dataset
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))

# Simple NN
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

with mlflow.start_run() as run:
    for epoch in range(5):
        optimizer.zero_grad()
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        mlflow.log_metric("loss", loss.item(), step=epoch)

    mlflow.log_param("lr", 0.001)
    mlflow.pytorch.log_model(model, "model")

At this point, MLflow is recording metrics (loss), params (lr), and the trained model artifact. But it doesn’t capture semantic provenance — for example, which dataset was used, who ran this job, and how results are connected.

🔗 Step 3: Provenance Tracker for MLflow

Here’s where PROV-O comes in. We build a Provenance Tracker that:

  1. Defines entities (datasets, models, metrics).
  2. Defines activities (the MLflow run).
  3. Defines agents (engineer, system).
  4. Links them using PROV-O relations.
  5. Serializes into Turtle (.ttl) or JSON-LD.

from rdflib import Graph, Namespace, URIRef, Literal
from rdflib.namespace import RDF, FOAF
import mlflow

PROV = Namespace("http://www.w3.org/ns/prov#")
EX = Namespace("http://example.org/")

def log_provenance(run):
    g = Graph()
    g.bind("prov", PROV)
    g.bind("ex", EX)

    # Agent (engineer/system)
    user = EX["engineer"]
    g.add((user, RDF.type, PROV.Agent))
    g.add((user, FOAF.name, Literal("AI Engineer")))

    # Activity (the MLflow run)
    activity = EX[f"run_{run.info.run_id}"]
    g.add((activity, RDF.type, PROV.Activity))

    # Input dataset
    dataset = EX["dataset.csv"]
    g.add((dataset, RDF.type, PROV.Entity))
    g.add((activity, PROV.used, dataset))

    # Model entity
    model = EX[f"model_{run.info.run_id}.pth"]
    g.add((model, RDF.type, PROV.Entity))
    g.add((model, PROV.wasGeneratedBy, activity))
    g.add((model, PROV.wasAttributedTo, user))

    # Metrics entity
    metrics = EX[f"metrics_{run.info.run_id}.json"]
    g.add((metrics, RDF.type, PROV.Entity))
    g.add((metrics, PROV.wasGeneratedBy, activity))
    g.add((metrics, PROV.wasAttributedTo, user))

    # Serialize + store
    prov_file = f"prov_{run.info.run_id}.ttl"
    g.serialize(prov_file, format="turtle")
    mlflow.log_artifact(prov_file, artifact_path="provenance")
    print(f"✅ Provenance logged in {prov_file}")

📦 Step 4: Integrate Tracker

Modify the training script to call log_provenance(run) after training completes.

with mlflow.start_run() as run:
    # Training loop (as above) ...
    mlflow.pytorch.log_model(model, "model")

    # Capture provenance
    log_provenance(run)

Now every MLflow run will automatically create a provenance graph and store it alongside model artifacts.

Final script train-small-nn-pytorch.py:

import torch
import torch.nn as nn
import torch.optim as optim
import mlflow
import mlflow.pytorch
from rdflib import Graph, Namespace, URIRef, Literal
from rdflib.namespace import RDF, FOAF

# Fake dataset
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))

# Simple NN
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Provenance Tracker for MLflow
PROV = Namespace("http://www.w3.org/ns/prov#")
EX = Namespace("http://example.org/")

def log_provenance(run):
    g = Graph()
    g.bind("prov", PROV)
    g.bind("ex", EX)

    # Agent (engineer/system)
    user = EX["engineer"]
    g.add((user, RDF.type, PROV.Agent))
    g.add((user, FOAF.name, Literal("AI Engineer")))

    # Activity (the MLflow run)
    activity = EX[f"run_{run.info.run_id}"]
    g.add((activity, RDF.type, PROV.Activity))

    # Input dataset
    dataset = EX["dataset.csv"]
    g.add((dataset, RDF.type, PROV.Entity))
    g.add((activity, PROV.used, dataset))

    # Model entity
    model = EX[f"model_{run.info.run_id}.pth"]
    g.add((model, RDF.type, PROV.Entity))
    g.add((model, PROV.wasGeneratedBy, activity))
    g.add((model, PROV.wasAttributedTo, user))

    # Metrics entity
    metrics = EX[f"metrics_{run.info.run_id}.json"]
    g.add((metrics, RDF.type, PROV.Entity))
    g.add((metrics, PROV.wasGeneratedBy, activity))
    g.add((metrics, PROV.wasAttributedTo, user))

    # Serialize + store
    prov_file = f"prov_{run.info.run_id}.ttl"
    g.serialize(prov_file, format="turtle")
    mlflow.log_artifact(prov_file, artifact_path="provenance")
    print(f"✅ Provenance logged in {prov_file}")

# MLflow
with mlflow.start_run() as run:
    # Training loop
    for epoch in range(5):
        optimizer.zero_grad()
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        mlflow.log_metric("loss", loss.item(), step=epoch)

    mlflow.log_param("lr", 0.001)
    mlflow.pytorch.log_model(model, "model")

    # Capture provenance after training completes
    log_provenance(run)

📂 Step 5: Example Output

Provenance graph (Turtle format) prov_70d8b46c6451416d92a0ae7cac4c8602.ttl:

@prefix ex: <http://example.org/> .
@prefix foaf: <http://xmlns.com/foaf/0.1/> .
@prefix prov: <http://www.w3.org/ns/prov#> .

ex:metrics_70d8b46c6451416d92a0ae7cac4c8602.json a prov:Entity ;
    prov:wasAttributedTo ex:engineer ;
    prov:wasGeneratedBy ex:run_70d8b46c6451416d92a0ae7cac4c8602 .

ex:model_70d8b46c6451416d92a0ae7cac4c8602.pth a prov:Entity ;
    prov:wasAttributedTo ex:engineer ;
    prov:wasGeneratedBy ex:run_70d8b46c6451416d92a0ae7cac4c8602 .

ex:dataset.csv a prov:Entity .

ex:engineer a prov:Agent ;
    foaf:name "AI Engineer" .

ex:run_70d8b46c6451416d92a0ae7cac4c8602 a prov:Activity ;
    prov:used ex:dataset.csv .

This graph is machine-readable and interoperable with semantic web tools, knowledge graphs, and governance platforms.

🔍 Step 6: Query Provenance

Since PROV-O is RDF-based, we can load graphs into a triple store and query with SPARQL. The following are a few example queries:

1️⃣ Which dataset was used to generate a given model?

SELECT ?dataset WHERE {
  ex:model_70d8b46c6451416d92a0ae7cac4c8602.pth prov:wasGeneratedBy ?activity .
  ?activity prov:used ?dataset .
}

This query returns dataset.csv as the dataset used by the run that generated model_70d8b46c6451416d92a0ae7cac4c8602.pth.

The SPARQL queries can be run using the following Python script:

import rdflib

# Create a Graph object
g = rdflib.Graph()

# Parse the TTL file into the graph
g.parse("prov_70d8b46c6451416d92a0ae7cac4c8602.ttl", format='turtle')

# Define your SPARQL query
sparql_query = """
SELECT ?dataset WHERE {
  ex:model_70d8b46c6451416d92a0ae7cac4c8602.pth prov:wasGeneratedBy ?activity .
  ?activity prov:used ?dataset .
}
"""

# Execute the query
results = g.query(sparql_query)

# Process the results
for row in results:
	print(row)

2️⃣ All models generated by a given engineer

SELECT ?model
WHERE {
  ?model a prov:Entity ;
         prov:wasAttributedTo ex:engineer .
}

👉 Returns all model URIs that were attributed to the engineer ex:engineer.

3️⃣ All datasets used in the last month

If your provenance tracker adds prov:generatedAtTime or similar timestamps on entities/activities, you can filter by date. Example:

SELECT ?dataset ?time
WHERE {
  ?activity a prov:Activity ;
            prov:used ?dataset ;
            prov:endedAtTime ?time .
  ?dataset a prov:Entity .
  FILTER (?time >= "2025-07-28T00:00:00Z"^^xsd:dateTime && 
          ?time <= "2025-08-28T23:59:59Z"^^xsd:dateTime)
}

👉 This finds all prov:Entity datasets used by any activity that ended in the last month.

4️⃣ Provenance chains across multiple runs (for auditing)

Here we want to trace lineage from dataset → activity → model → metrics.

SELECT ?dataset ?activity ?model ?metrics
WHERE {
  ?dataset a prov:Entity .
  ?activity a prov:Activity ;
            prov:used ?dataset ;
            prov:generated ?model, ?metrics .
  ?model a prov:Entity .
  ?metrics a prov:Entity .
}

👉 This gives a table of full provenance chains, so you can audit multiple runs together.

5️⃣ Find all runs that reused the same dataset

Useful for detecting data reuse:

SELECT ?dataset (GROUP_CONCAT(?model; separator=", ") AS ?models)
WHERE {
  ?activity prov:used ?dataset ;
            prov:generated ?model .
}
GROUP BY ?dataset
HAVING (COUNT(?model) > 1)

👉 Returns datasets that were reused in multiple model generations.

⚡ These queries assume you have prov:used, prov:generated, prov:wasAttributedTo, and timestamps (prov:endedAtTime or prov:generatedAtTime) in your TTL logs.
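
If you want those extra triples, one possible extension of the log_provenance() function shown earlier is sketched below (add_time_and_generation is a hypothetical helper name, not part of any library): it attaches prov:endedAtTime to the activity and adds prov:generated links, the inverse of prov:wasGeneratedBy.

from datetime import datetime, timezone
from rdflib import Literal, Namespace
from rdflib.namespace import XSD

PROV = Namespace("http://www.w3.org/ns/prov#")  # same namespace as in log_provenance

def add_time_and_generation(g, activity, model, metrics):
    # Timestamp the activity so date-range SPARQL filters work
    ended = Literal(datetime.now(timezone.utc).isoformat(), datatype=XSD.dateTime)
    g.add((activity, PROV.endedAtTime, ended))

    # prov:generated points from the activity to each output entity
    g.add((activity, PROV.generated, model))
    g.add((activity, PROV.generated, metrics))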

✅ Why This Matters

By extending MLflow with PROV-O, AI engineers gain:

  • Reproducibility → Every model is linked to the exact data and parameters that generated it.
  • Auditability → Regulators and compliance teams can trace how outputs were produced.
  • Transparency → Business stakeholders can understand lineage without relying on tribal knowledge.
  • Interoperability → Since PROV-O is a W3C standard, provenance metadata can be integrated into external governance, data catalog, and knowledge graph systems.

🚀 What We Learnt

We’ve seen how to:

  1. Train a PyTorch model with MLflow.
  2. Capture provenance automatically using PROV-O.
  3. Serialize provenance graphs as RDF/Turtle.
  4. Query lineage with SPARQL.

Navigating AI Risks with NIST’s AI Risk Management Framework (AI RMF)

Practical Guide for AI Engineers with Supporting Tools

Artificial Intelligence (AI) is no longer a research curiosity—it powers critical systems in healthcare, finance, transportation, and defense. But as AI adoption grows, so do the risks: bias, security vulnerabilities, lack of transparency, and unintended consequences.

To help organizations manage these challenges, the U.S. National Institute of Standards and Technology (NIST) introduced the AI Risk Management Framework (AI RMF 1.0) in January 2023.

For AI engineers, this framework is more than high-level governance—it can be operationalized with existing open-source libraries, MLOps pipelines, and monitoring tools. Let’s break it down.

What is the NIST AI RMF?

The AI RMF is a voluntary, flexible, and sector-agnostic framework designed to help organizations manage risks across the AI lifecycle.

Its ultimate goal is to foster trustworthy AI systems by emphasizing principles like fairness, robustness, explainability, privacy, and accountability.

Think of it as the AI-equivalent of DevSecOps best practices—a structured way to integrate risk thinking into design, development, deployment, and monitoring. Instead of retrofitting ethical or legal concerns at the end, engineers can bake them directly into code, pipelines, and testing.

The Core Pillars of AI RMF and Supporting Tools

NIST organizes the framework around four core functions, known as the AI RMF Core. For engineers, these map nicely onto the ML lifecycle.

1. Govern – Organizational Structures & Accountability

What it means:
Governance is about who owns risk and how it is tracked. Without clear accountability, even the best fairness metrics or privacy protections won’t scale. This function ensures leadership commitment, defined responsibilities, and enforceable processes.

How engineers can implement it:

  • Standardize documentation for datasets and models.
  • Track lineage and provenance of data and experiments (see the sketch after the tool list below).
  • Build reproducible ML pipelines so decisions can be audited later.

Supporting Tools:

  • Model Cards (Google) → lightweight docs describing model purpose, limitations, and ethical considerations.
  • Datasheets for Datasets (MIT/Google) → dataset documentation to capture origin, bias, and quality.
  • Weights & Biases / MLflow → experiment tracking, versioning, and governance metadata.
  • Great Expectations → data quality validation built into ETL/ML pipelines.
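
As a small illustration of the lineage bullet above, governance metadata can be attached to every run as MLflow tags; this is only a sketch, and the tag names and values here are placeholders rather than any standard schema.

import mlflow

with mlflow.start_run():
    # Illustrative governance metadata recorded alongside the run
    mlflow.set_tags({
        "dataset.name": "customer_churn_v3",
        "dataset.sha256": "a41be7b96f...",  # hash of the exact data file used
        "model_card.url": "https://example.org/model-cards/churn-v2",
        "owner": "ml-platform-team",
        "approved_by": "risk-review-board",
    })
    mlflow.log_param("lr", 0.001)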

2. Map – Understanding Context & Identifying Risks

What it means:
Before writing a line of model code, engineers need to understand context, stakeholders, and risks. Mapping ensures the AI system is aligned with real-world use cases and surfaces risks early.

How engineers can implement it:

  • Identify who the AI system impacts (users, communities, regulators).
  • Document the intended use vs. possible misuse.
  • Anticipate risks (bias, adversarial threats, performance in edge cases).

Supporting Tools:

  • NIST AI RMF Playbook → companion guide with risk templates and examples.
  • Data Provenance Tools:
    • Pachyderm → versioned data pipelines.
    • DVC → Git-like data and model versioning.
    • LakeFS → Git-style object store for ML data.
  • Risk Taxonomy Checklists → resources from Partnership on AI and OECD for structured risk mapping.

3. Measure – Quantifying and Testing Risks

What it means:
Mapping risks isn’t enough—you need to quantify them with metrics. This includes fairness, robustness, explainability, privacy leakage, and resilience to adversarial attacks.

How engineers can implement it:

  • Integrate fairness and bias checks into CI/CD pipelines (see the sketch after the tool list below).
  • Run explainability tests to ensure interpretability across stakeholders.
  • Stress-test robustness with adversarial attacks and edge cases.

Supporting Tools:

  • Fairness & Bias:
    • IBM AIF360 → 70+ fairness metrics and mitigation strategies.
    • Microsoft Fairlearn → fairness dashboards and post-processing.
  • Explainability:
    • SHAP, LIME, Captum (PyTorch) → feature attribution and local/global explainability.
    • Evidently AI → interpretability reports integrated with drift monitoring.
  • Robustness & Security:
    • Adversarial Robustness Toolbox (ART) → adversarial testing, poisoning, and defense.
    • Foolbox → adversarial attack generation for benchmarking model resilience.
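
As a concrete example of the CI/CD bullet above, here is a minimal sketch of a fairness gate that could run inside a test suite, assuming Fairlearn is installed; the labels, predictions, sensitive attribute, and the 0.1 threshold are all illustrative placeholders that your evaluation step would supply.

import numpy as np
from fairlearn.metrics import demographic_parity_difference

# Illustrative labels, predictions, and sensitive attribute from a test set
y_true = np.array([1, 0, 1, 1, 0, 1, 0, 0])
y_pred = np.array([1, 0, 1, 0, 0, 1, 1, 0])
sensitive = np.array(["A", "A", "A", "A", "B", "B", "B", "B"])

dpd = demographic_parity_difference(y_true, y_pred, sensitive_features=sensitive)
print(f"Demographic parity difference: {dpd:.3f}")

# Fail the pipeline if the gap exceeds the team's agreed threshold
assert dpd <= 0.1, "Fairness gate failed: demographic parity difference too large"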

4. Manage – Continuous Monitoring & Mitigation

What it means:
AI risks don’t stop at deployment—they evolve as data shifts, adversaries adapt, and systems scale. Managing risk means establishing feedback loops, monitoring dashboards, and incident response plans.

How engineers can implement it:

  • Treat models as living systems that require continuous health checks.
  • Monitor for data drift, bias drift, and performance decay.
  • Set up incident management protocols for when AI fails.

Supporting Tools:

  • MLOps Platforms: Kubeflow, Seldon, MLflow for deployment and lifecycle tracking.
  • Continuous Monitoring:
    • Neptune.ai → experiment tracking with risk-aware metrics.
    • Evidently AI, Arize AI, WhyLabs → production-grade drift, bias, and observability dashboards.
  • Incident Management: Adapt SRE playbooks (PagerDuty, Opsgenie) for ML-specific failures like data poisoning or unexpected bias spikes.

Characteristics of Trustworthy AI (and Tools to Support Them)

The AI RMF identifies seven key characteristics of trustworthy AI. These are cross-cutting qualities every AI system should strive for:

  1. Valid & Reliable → Testing frameworks (pytest, pytest-ml) + continuous evaluation.
  2. Safe → Simulation environments (e.g., CARLA for self-driving AI).
  3. Secure & Resilient → Adversarial robustness tools (ART, Foolbox).
  4. Accountable & Transparent → Model Cards, version control (MLflow, DVC).
  5. Explainable & Interpretable → SHAP, LIME, Captum.
  6. Privacy-Enhanced → TensorFlow Privacy, PySyft, Opacus.
  7. Fair with Harm Mitigation → Fairlearn, AIF360, Evidently AI bias dashboards.

For engineers, these aren’t abstract principles—they map directly to unit tests, pipelines, and monitoring dashboards you can implement today.

Why It Matters for Engineers

Traditionally, “risk” in engineering meant downtime or performance degradation. But in AI, risk is multi-dimensional:

  • A biased recommendation engine → unfair economic impact.
  • A misclassified medical image → patient safety risks.
  • An adversarial attack on a financial model → systemic security threat.

The AI RMF helps engineers broaden their definition of risk and integrate safeguards across the lifecycle.

By adopting the framework with supporting tools, engineers can:

  • Automate fairness, robustness, and privacy checks in CI/CD.
  • Log provenance for datasets and models.
  • Build dashboards for continuous risk monitoring.
  • Collaborate with legal and policy teams using standardized documentation.

Getting Started (Actionable Steps)

  1. Integrate Provenance Tracking → Use DVC or Pachyderm in your ML pipeline.
  2. Automate Fairness & Robustness Tests → Add Fairlearn and ART checks into CI/CD.
  3. Adopt Transparency Practices → Publish Model Cards for all deployed models.
  4. Monitor in Production → Deploy Evidently AI or WhyLabs for drift & bias detection.
  5. Collaborate Cross-Functionally → Align engineering practices with governance and compliance teams.

Final Thoughts

The NIST AI RMF is not a compliance checklist—it’s a living guide to building trustworthy AI. For engineers, it bridges the gap between technical implementation and organizational responsibility.

By embedding Govern, Map, Measure, Manage into your workflow—and leveraging open-source tools like AIF360, Fairlearn, ART, MLflow, and Evidently AI—you don’t just ship models, you ship trustworthy models.

As regulation around AI tightens globally (EU AI Act, U.S. AI Executive Orders, ISO/IEC standards), frameworks like NIST’s AI RMF will help engineers stay ahead of the curve.

👉 Takeaway for AI Engineers: Use the NIST AI RMF as your north star, and operationalize it with today’s open-source and enterprise tools. Trustworthy AI isn’t just theory—it’s code, pipelines, and monitoring.

    Provenance in AI: Building a Provenance Graph with Neo4j – Part 3

    In Part 2, we built a ProvenanceTracker that generates signed, schema-versioned lineage logs for datasets, models, and inferences. That ensures trust at the data level — but provenance becomes truly valuable when we can query and reason about it.

    In this post, we’ll import the signed logs into Neo4j, the leading graph database, and show how to query provenance directly using Cypher.

    Why Neo4j for Provenance?

    AI lineage is fundamentally a graph:

    • A Dataset can be used to train many Models.
    • A Model can generate thousands of Inferences.
    • An Inference must be traceable back to the model and dataset(s).

    Representing this as a graph gives us a natural way to answer questions like:

    • “Which datasets were used to train this model?”
    • “Which inferences came from this model version?”
    • “What is the complete lineage of an inference?”

    Step 1. Provenance Importer with Signature Verification

    The importer reads signed JSONL logs, verifies signatures, and inserts data into Neo4j with constraints.

    # ProvenanceImporter.py
    import json
    import base64
    from typing import Dict, Any
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    from cryptography.hazmat.primitives.serialization import load_pem_public_key
    from neo4j import GraphDatabase
    
    
    EXPECTED_SCHEMA = "1.1"
    
    
    class ProvenanceImporter:
        def __init__(self, uri, user, password, public_key_path: str):
            self.driver = GraphDatabase.driver(uri, auth=(user, password))
    
            # Load public key for verifying signatures
            with open(public_key_path, "rb") as f:
                self.public_key = load_pem_public_key(f.read())
    
        def close(self):
            self.driver.close()
    
        def _verify_signature(self, signed_data: str, signature_b64: str) -> bool:
            try:
                signature = base64.b64decode(signature_b64)
                self.public_key.verify(
                    signature,
                    signed_data.encode("utf-8"),
                    padding.PSS(
                        mgf=padding.MGF1(hashes.SHA256()),
                        salt_length=padding.PSS.MAX_LENGTH,
                    ),
                    hashes.SHA256(),
                )
                return True
            except Exception:
                return False
    
        def _validate_jsonl(self, jsonl_path: str):
            """
            Validate schema + signatures before import.
            Returns list of verified payloads (dicts).
            """
            valid_records = []
            with open(jsonl_path, "r") as f:
                for line_no, line in enumerate(f, start=1):
                    try:
                        envelope = json.loads(line.strip())
                    except json.JSONDecodeError:
                        raise ValueError(f"Line {line_no}: invalid JSON")
    
                    schema = envelope.get("schema_version")
                    signed_data = envelope.get("signed_data")
                    signature = envelope.get("signature")
    
                    if schema != EXPECTED_SCHEMA:
                        raise ValueError(f"Line {line_no}: schema version mismatch ({schema})")
    
                    if not signed_data or not signature:
                        raise ValueError(f"Line {line_no}: missing signed_data/signature")
    
                    if not self._verify_signature(signed_data, signature):
                        raise ValueError(f"Line {line_no}: signature verification failed")
    
                    # Verified, safe to parse
                    valid_records.append(json.loads(signed_data))
    
            return valid_records
    
        def import_from_jsonl(self, jsonl_path: str):
            # Validate before importing
            print("🔍 Validating provenance log file...")
            valid_records = self._validate_jsonl(jsonl_path)
            print(f"✅ Validation successful: {len(valid_records)} records")
    
            with self.driver.session() as session:
                self._ensure_constraints(session)
                for record in valid_records:
                    self._process_record(session, record)
    
        def _process_record(self, session, record: Dict[str, Any]):
            if record["type"] == "dataset":
                session.run(
                    """
                    MERGE (d:Dataset {hash: $hash})
                    SET d.path = $path, d.description = $desc, d.timestamp = $ts
                    """,
                    hash=record["hash"],
                    path=record["path"],
                    desc=record.get("description", ""),
                    ts=record["timestamp"],
                )
    
            elif record["type"] == "model":
                session.run(
                    """
                    MERGE (m:Model {name: $name, commit: $commit})
                    SET m.hyperparameters = $hyperparams,
                        m.environment = $env,
                        m.timestamp = $ts
                    """,
                    name=record["model_name"],
                    commit=record.get("git_commit", "N/A"),
                    hyperparams=json.dumps(record.get("hyperparameters", {})),
                    env=json.dumps(record.get("environment", {})),
                    ts=record["timestamp"],
                )
    
                # Multiple dataset links
                for d_hash in record.get("dataset_hashes", []):
                    session.run(
                        """
                        MATCH (d:Dataset {hash: $hash})
                        MATCH (m:Model {name: $name})
                        MERGE (d)-[:USED_IN]->(m)
                        """,
                        hash=d_hash,
                        name=record["model_name"],
                    )
    
            elif record["type"] == "inference":
                session.run(
                    """
                    MERGE (i:Inference {id: $id})
                    SET i.input = $input,
                        i.output = $output,
                        i.timestamp = $ts
                    WITH i
                    MATCH (m:Model {name: $name})
                    MERGE (m)-[:GENERATED]->(i)
                    """,
                    id=record.get("id"),
                    name=record["model_name"],
                    input=json.dumps(record.get("input", {})),
                    output=json.dumps(record.get("output", {})),
                    ts=record["timestamp"],
                )
    
        def _ensure_constraints(self, session):
            """Create uniqueness constraints (idempotent)."""
            session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (d:Dataset) REQUIRE d.hash IS UNIQUE")
            session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (m:Model) REQUIRE (m.name, m.commit) IS UNIQUE")
            session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (i:Inference) REQUIRE i.id IS UNIQUE")
    
    
    # ---------------------------
    # Example usage
    # ---------------------------
    if __name__ == "__main__":
        importer = ProvenanceImporter(
            "bolt://localhost:7687", "neo4j", "password@1234", "public_key.pem"
        )
        try:
            importer.import_from_jsonl("provenance_logs.jsonl")
            print("✅ Imported signed provenance logs into Neo4j with constraints")
        except Exception as e:
            print(f"❌ Import aborted: {e}")
        finally:
            importer.close()

    Step 2. Running Neo4j and Importer

    1. Start Neo4j via Docker:
      docker run --publish=7474:7474 --publish=7687:7687 neo4j:latest
    2. Access the Neo4j Browser at http://localhost:7474
      Default user/pass: neo4j/neo4j (change the password after first login).
    3. Run the importer:
      python ProvenanceImporter.py

    Step 3. Querying Provenance with Cypher

    Fig: Schema Diagram

    3.1 List all datasets

      MATCH (d:Dataset)
      RETURN d.hash AS hash, d.path AS path, d.description AS description, d.timestamp AS logged_at
      ORDER BY logged_at DESC;

      3.2 List all models and their hyperparameters

      MATCH (m:Model)
      RETURN m.name AS model, m.commit AS git_commit,
             m.hyperparameters AS hyperparams,
             m.environment AS env,
             m.timestamp AS logged_at
      ORDER BY logged_at DESC;

      3.3 Show which datasets were used for each model

      MATCH (d:Dataset)-[:USED_IN]->(m:Model)
      RETURN d.hash AS dataset_hash, d.path AS dataset_path,
             m.name AS model, m.commit AS commit
      ORDER BY model;

      3.4 List all inferences with input/output

      MATCH (m:Model)-[:GENERATED]->(i:Inference)
      RETURN i.id AS inference_id, m.name AS model,
             i.input AS input_data, i.output AS output_data, i.timestamp AS ts
      ORDER BY ts DESC;

      3.5 Full provenance lineage (dataset → model → inference)

      MATCH (d:Dataset)-[:USED_IN]->(m:Model)-[:GENERATED]->(i:Inference)
      RETURN d.hash AS dataset_hash, m.name AS model, i.id AS inference_id, i.timestamp AS ts
      ORDER BY ts DESC;

      3.6 Visualize provenance graph

      MATCH (d:Dataset)-[:USED_IN]->(m:Model)-[:GENERATED]->(i:Inference)
      RETURN d, m, i;

      👉 Run this in Neo4j Browser and click the graph view (circle-node visualization).
      You’ll see the chain of custody: Datasets → Models → Inferences.

      3.7 Find models trained on multiple datasets

      MATCH (m:Model)<-[:USED_IN]-(d:Dataset)
      WITH m, collect(d.hash) AS datasets
      WHERE size(datasets) > 1
      RETURN m.name AS model, datasets, size(datasets) AS dataset_count;

      3.8 Check if all models have dataset provenance

      MATCH (m:Model)
      WHERE NOT (m)<-[:USED_IN]-(:Dataset)
      RETURN m.name AS model_without_provenance;

      If every model has dataset provenance, this query returns no records.

      ⚡ With these queries, you can:

      • Audit which dataset versions were used
      • Trace from inference results back to datasets
      • Verify reproducibility and compliance

      What We Achieved

      By combining signed JSONL provenance logs with Neo4j:

      • Schema constraints ensure data integrity.
      • Every record is tamper-resistant (signatures verified before import).
      • Relationships are explicit (USED_IN, GENERATED).
      • Provenance queries are expressive (thanks to Cypher).

      Takeaway: With Neo4j as the provenance store, AI engineers can query, audit, and explain the complete lineage of any model or inference — a vital step toward trustworthy and compliant AI systems.

      Provenance in AI: Tracking AI Lineage with Signed Provenance Logs in Python – Part 2

      In modern AI pipelines, provenance — the lineage of datasets, models, and inferences — is becoming as important as accuracy metrics. Regulators, auditors, and even downstream consumers increasingly demand answers to questions like:

      • Which dataset was this model trained on?
      • What code commit produced this artifact?
      • How do we know logs weren’t tampered with after training?

      To learn more about provenance in AI, read my previous article: Provenance in AI: Why It Matters for AI Engineers – Part 1

      To answer these questions, let’s walk through a Python-based provenance tracker that logs lineage events, cryptographically signs them, and maintains schema versioning for forward compatibility.

      1. The Provenance Tracker: Key Features

      The ProvenanceTracker implements three important ideas:

      1. Multiple dataset support
        • Models often train on more than one dataset (train + validation + test).
        • This tracker keeps a list of dataset hashes (dataset_hashes) and auto-links them to model logs.
      2. Signed JSONL envelopes
        • Each log entry is wrapped in an envelope:
          {
            "schema_version": "1.1",
            "signed_data": "{…}",
            "signature": "<base64-encoded RSA signature>"
          }
        • signed_data is serialized with stable JSON (sort_keys=True).
        • A digital signature (RSA + PSS padding + SHA-256) is generated using a private key.
      3. Schema versioning
        • schema_version = "1.1" is embedded in every record.

      2. The Provenance Tracker: Source Code

      Before we get to the provenance tracker code, let’s see a companion script generate_keys.py that creates the RSA keypair (private_key.pem, public_key.pem). This is used by the ProvenanceTracker.py to sign the JSONL logs.

      # generate_keys.py
      from cryptography.hazmat.primitives.asymmetric import rsa
      from cryptography.hazmat.primitives import serialization
      
      # Generate RSA private key (2048 bits)
      private_key = rsa.generate_private_key(
          public_exponent=65537,
          key_size=2048,
      )
      
      # Save private key (PEM)
      with open("private_key.pem", "wb") as f:
          f.write(
              private_key.private_bytes(
                  encoding=serialization.Encoding.PEM,
                  format=serialization.PrivateFormat.PKCS8,
                  encryption_algorithm=serialization.NoEncryption(),
              )
          )
      
      # Save public key (PEM)
      public_key = private_key.public_key()
      with open("public_key.pem", "wb") as f:
          f.write(
              public_key.public_bytes(
                  encoding=serialization.Encoding.PEM,
                  format=serialization.PublicFormat.SubjectPublicKeyInfo,
              )
          )
      
      print("✅ RSA keypair generated: private_key.pem & public_key.pem")

      Run once to create your keypair:

      python generate_keys.py

      Here’s a secure ProvenanceTracker (schema version 1.1) that:

      • Supports multiple datasets
      • Includes schema version
      • Signs JSONL using RSA private key
      # ProvenanceTracker.py
      import hashlib
      import json
      import os
      import platform
      import socket
      import subprocess
      import base64
      from datetime import datetime
      from typing import Any, Dict, List
      from cryptography.hazmat.primitives import hashes, serialization
      from cryptography.hazmat.primitives.asymmetric import padding
      
      
      class ProvenanceTracker:
          SCHEMA_VERSION = "1.1"
      
          def __init__(self, storage_path: str = "provenance_logs.jsonl", private_key_path: str = "private_key.pem"):
              self.storage_path = storage_path
              self._dataset_hashes: List[str] = []  # track datasets used
              self.private_key = self._load_private_key(private_key_path)
      
          def _load_private_key(self, path: str):
              with open(path, "rb") as f:
                  return serialization.load_pem_private_key(f.read(), password=None)
      
          def _get_git_commit(self) -> str:
              try:
                  return subprocess.check_output(
                      ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
                  ).decode("utf-8").strip()
              except Exception:
                  return "N/A"
      
          def _hash_file(self, file_path: str) -> str:
              h = hashlib.sha256()
              with open(file_path, "rb") as f:
                  while chunk := f.read(8192):
                      h.update(chunk)
              return h.hexdigest()
      
          def _sign(self, payload: str) -> str:
              signature = self.private_key.sign(
                  payload.encode("utf-8"),
                  padding.PSS(
                      mgf=padding.MGF1(hashes.SHA256()),
                      salt_length=padding.PSS.MAX_LENGTH,
                  ),
                  hashes.SHA256(),
              )
              return base64.b64encode(signature).decode("utf-8")
      
          def _log(self, record: Dict[str, Any]):
              record["timestamp"] = datetime.utcnow().isoformat()
              record["schema_version"] = self.SCHEMA_VERSION
      
              # Serialize signed_data separately (stable JSON encoding)
              signed_data = json.dumps(record, sort_keys=True)
              signature = self._sign(signed_data)
      
              envelope = {
                  "schema_version": self.SCHEMA_VERSION,
                  "signed_data": signed_data,
                  "signature": signature,
              }
      
              with open(self.storage_path, "a") as f:
                  f.write(json.dumps(envelope) + "\n")
      
          def log_dataset(self, dataset_path: str, description: str = ""):
              dataset_hash = self._hash_file(dataset_path)
              self._dataset_hashes.append(dataset_hash)
      
              record = {
                  "type": "dataset",
                  "path": dataset_path,
                  "hash": dataset_hash,
                  "description": description,
              }
              self._log(record)
              return dataset_hash
      
          def log_model(self, model_name: str, hyperparams: Dict[str, Any]):
              record = {
                  "type": "model",
                  "model_name": model_name,
                  "hyperparameters": hyperparams,
                  "git_commit": self._get_git_commit(),
                  "environment": {
                      "python_version": platform.python_version(),
                      "platform": platform.system(),
                      "hostname": socket.gethostname(),
                  },
                  "dataset_hashes": self._dataset_hashes,  # link all datasets
              }
              self._log(record)
      
          def log_inference(self, model_name: str, input_data: Any, output_data: Any):
              record = {
                  "type": "inference",
                  "id": f"inf-{hashlib.sha1(json.dumps(input_data).encode()).hexdigest()[:12]}",  # deterministic ID
                  "model_name": model_name,
                  "input": input_data,
                  "output": output_data,
              }
              self._log(record)
      
      
      if __name__ == "__main__":
          tracker = ProvenanceTracker()
      
          # 1. Log datasets
          ds1 = tracker.log_dataset("data/training.csv", "Customer churn dataset")
          ds2 = tracker.log_dataset("data/validation.csv", "Validation set")
      
          # 2. Log model (linked to all datasets seen so far)
          tracker.log_model("churn-predictor-v2", {
              "algorithm": "XGBoost",
              "n_estimators": 200,
              "max_depth": 12,
          })
      
          # 3. Log inference
          tracker.log_inference(
              "churn-predictor-v2",
              {"customer_id": 54321, "features": [0.4, 1.7, 0.2]},
              {"churn_risk": 0.42}
          )
      
          print("✅ Signed provenance logs recorded in provenance_logs.jsonl")

      3. Under the Hood

      3.1 Datasets

      Datasets are logged with a SHA-256 file hash, ensuring that even if file names change, the integrity check remains stable.

      ds1 = tracker.log_dataset("data/training.csv", "Customer churn dataset")
      ds2 = tracker.log_dataset("data/validation.csv", "Validation set")

      Resulting record (inside signed_data):

      {
        "type": "dataset",
        "path": "data/training.csv",
        "hash": "a41be7b96f...",
        "description": "Customer churn dataset",
        "timestamp": "2025-08-28T10:12:34.123456",
        "schema_version": "1.1"
      }

      3.2 Models

      When logging a model, the tracker attaches:

      • Model metadata (name, hyperparameters)
      • Git commit hash (if available)
      • Runtime environment (Python version, OS, hostname)
      • All dataset hashes seen so far
      tracker.log_model("churn-predictor-v2", {
          "algorithm": "XGBoost",
          "n_estimators": 200,
          "max_depth": 12,
      })

      This creates a strong lineage link:
      Dataset(s) → Model

      3.3 Inferences

      Every inference is logged with a deterministic ID, computed as a SHA-1 hash of the input payload. This ensures repeat queries generate the same inference ID (helpful for deduplication).

      tracker.log_inference(
          "churn-predictor-v2",
          {"customer_id": 54321, "features": [0.4, 1.7, 0.2]},
          {"churn_risk": 0.42}
      )

      Graphically:
      Model → Inference

      4. Signed Envelopes for Tamper-Proofing

      Each record is not stored raw but wrapped in a signed envelope:

      {
        "schema_version": "1.1",
        "signed_data": "{\"description\": \"Validation set\", \"hash\": \"c62...\"}",
        "signature": "MEUCIQDgtd...xyz..."
      }

      To verify:

      • Load the public key.
      • Verify the signature against the serialized signed_data.
      • If modified, verification fails → tampering detected.

      This is exactly the mechanism PKI systems and blockchain protocols use for immutability.

      5. Example End-to-End Run

      When running ProvenanceTracker.py:

      $ python ProvenanceTracker.py
✅ Signed provenance logs recorded in provenance_logs.jsonl

      The log file (provenance_logs.jsonl) will contain three signed envelopes — one each for datasets, the model, and an inference.

Following is provenance_logs.jsonl after the run:

      {"schema_version": "1.1", "signed_data": "{\"description\": \"Customer churn dataset\", \"hash\": \"a41be7b96fb85110521bf03d1530879e9ca94b9f5de19866757f6d184300fff7\", \"path\": \"data/training.csv\", \"schema_version\": \"1.1\", \"timestamp\": \"2025-08-28T01:06:31.062695\", \"type\": \"dataset\"}", "signature": "MnCRJ+Acg0F1UledjnMwQMp24wAIPmLPaZonI7hvdFvdi7d8CaZDOIamNq0KnRgcZgttJnI1L675tqT1O1M5N2FRNuy/Wj6elzpyM9w56Kd2mBcQLFumhVHiGZHtwKj2wQtXND0SCqWo5jxQPLPl0dSFClA+FKzpxfazwMtrHAE7aoUmyt2cv1Wiv9uZxsE+Il226J91rBk03lpLcArqqxTtfstkayOK5AON9ETXs65ERf26oURngR/0HS9jnO0IH1DxZOcHcfWZMrLwGqdjRF1sSDYcH70XV61yeYzSeIb8KDODttuxxfzsIlb0897tv/ZZ/X4tv/FFICei7LeAuw=="}
      {"schema_version": "1.1", "signed_data": "{\"description\": \"Validation set\", \"hash\": \"330af932f2dc1cae917f3bd0fb29395c4021319dd906189b7dc257d0ad58a617\", \"path\": \"data/validation.csv\", \"schema_version\": \"1.1\", \"timestamp\": \"2025-08-28T01:06:31.070827\", \"type\": \"dataset\"}", "signature": "pu8IvzPriN6eP9HTQGlIog8nfXV0FOEw818aw6uJS8oPKiQPjN3odzbP9zaeB+ZW4Nu9bBL5fm1btiiOSm9ziWUJWUzFRoHwlYTv2rgp/IXR0oWfTpXsdVeBj7NYVjUywLPofTeEE1C4J7XzZmusuCU9ZiKJzXU442E6Gsrj6tjRJxZoylONuekxegdTot4LwIcmCRtgigi1t3rQYBGdknmTFdW/I2h1Gguh+Shc/WG/jVuMq10vFNNM8iUJJAxAEktbpfhGw0of6lrZu9yn4wAmxvq0DFICKMEJlsyvEZ/mDaPkR4D55xuJh+dLlFbzNZvyw0woMII0hbIarNmG+w=="}
      {"schema_version": "1.1", "signed_data": "{\"dataset_hashes\": [\"a41be7b96fb85110521bf03d1530879e9ca94b9f5de19866757f6d184300fff7\", \"330af932f2dc1cae917f3bd0fb29395c4021319dd906189b7dc257d0ad58a617\"], \"environment\": {\"hostname\": \"GlamorPC\", \"platform\": \"Windows\", \"python_version\": \"3.10.11\"}, \"git_commit\": \"N/A\", \"hyperparameters\": {\"algorithm\": \"XGBoost\", \"max_depth\": 12, \"n_estimators\": 200}, \"model_name\": \"churn-predictor-v2\", \"schema_version\": \"1.1\", \"timestamp\": \"2025-08-28T01:06:31.117627\", \"type\": \"model\"}", "signature": "tq/y6Blz04u2iYZh5OqfyZChADA+osNIzwb9Z2g++AZjFu2hkywazf19rbTMsdx9J5s4BDz6rglfcFczRW/TXMECD3k91ZmAds/e0I+Xw42xeTnr7+jHKq5kPdV6Pan8yFVd9ikGso93ZDatX72rx+orIg41BggFN7ifYlKNnGD87zCypahI7Eev0frnD6w8GybmPcBMnCVLYlIo2nWpLgJELkVpwwagQ9rKA+WOlBbLe41ZizooSL/hhGJOXTuwYrkJpBZ69TIwCzihINr+joZBqYrPF+0E+CFohdc03b0SFv1OuNTo7dyqL9qpWdCMSi1iK0LfCukCO41Bvr2yHA=="}
      {"schema_version": "1.1", "signed_data": "{\"id\": \"inf-0276b2064ad0\", \"input\": {\"customer_id\": 54321, \"features\": [0.4, 1.7, 0.2]}, \"model_name\": \"churn-predictor-v2\", \"output\": {\"churn_risk\": 0.42}, \"schema_version\": \"1.1\", \"timestamp\": \"2025-08-28T01:06:31.118634\", \"type\": \"inference\"}", "signature": "Lf9r1vcXOaCxSc11UKNvuDjx7FophWXBxAobYlixIJgNIk2toFtEdjB2zzJtQI5cYEAImhNHB8hdssKUv3Dths0SpKeMQjpb0x0aKvXolnNsJMnEnGP443IRfMTpkcHpRjCVjIfEvP8EtAh58z4yHE77cy2IlSUFu3exwSEcRFVqBXvIKlojQTEneERUwEDZjfniluomSCLXiVFYMIB+LefPHGkChCVVulmyFJ9ITquD4Wymp2/c2/knopqXSP00EFON4SBOD9/RyQAXAl5UxP0s6faD7NeZxAdJWh3CY31+5V3Vv8b9y/jroAvxWjbpuCZT20gkHemArawDae3s0w=="}

      The following is the standalone verification code validate_logs.py:

      #!/usr/bin/env python3
      """
      Usage:
          python validate_logs.py provenance_logs.jsonl public_key.pem
      """
      
      import json
      import base64
      import sys
      from cryptography.hazmat.primitives import hashes
      from cryptography.hazmat.primitives.asymmetric import padding
      from cryptography.hazmat.primitives.serialization import load_pem_public_key
      
      
      EXPECTED_SCHEMA = "1.1"
      
      
      def load_public_key(path: str):
          with open(path, "rb") as f:
              return load_pem_public_key(f.read())
      
      
      def verify_signature(public_key, signed_data: str, signature_b64: str) -> bool:
          try:
              signature = base64.b64decode(signature_b64)
              public_key.verify(
                  signature,
                  signed_data.encode("utf-8"),
                  padding.PSS(
                      mgf=padding.MGF1(hashes.SHA256()),
                      salt_length=padding.PSS.MAX_LENGTH,
                  ),
                  hashes.SHA256(),
              )
              return True
          except Exception:
              return False
      
      
      def validate_file(jsonl_path: str, pubkey_path: str):
          public_key = load_public_key(pubkey_path)
      
          valid_count = 0
          failed_count = 0
          schema_mismatch = 0
      
          with open(jsonl_path, "r") as f:
              for line_no, line in enumerate(f, start=1):
                  try:
                      envelope = json.loads(line.strip())
                  except json.JSONDecodeError:
                      print(f"❌ Line {line_no}: invalid JSON")
                      failed_count += 1
                      continue
      
                  schema = envelope.get("schema_version")
                  signed_data = envelope.get("signed_data")
                  signature = envelope.get("signature")
      
                  if schema != EXPECTED_SCHEMA:
                      print(f"⚠️  Line {line_no}: schema version mismatch ({schema})")
                      schema_mismatch += 1
                      continue
      
                  if not signed_data or not signature:
                      print(f"❌ Line {line_no}: missing signed_data/signature")
                      failed_count += 1
                      continue
      
                  if verify_signature(public_key, signed_data, signature):
                      valid_count += 1
                  else:
                      print(f"❌ Line {line_no}: signature verification failed")
                      failed_count += 1
      
          print("\n--- Validation Report ---")
          print(f"✅ Valid entries      : {valid_count}")
          print(f"❌ Signature failures : {failed_count}")
          print(f"⚠️  Schema mismatches : {schema_mismatch}")
          print(f"📄 Total lines        : {valid_count + failed_count + schema_mismatch}")
      
      
      if __name__ == "__main__":
          if len(sys.argv) != 3:
              print("Usage: python validate_logs.py provenance_logs.jsonl public_key.pem")
              sys.exit(1)
      
          jsonl_file = sys.argv[1]
          pubkey_file = sys.argv[2]
      
          validate_file(jsonl_file, pubkey_file)

Output (illustrative; assuming the matching public_key.pem is supplied and all four envelopes are intact):
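
--- Validation Report ---
✅ Valid entries      : 4
❌ Signature failures : 0
⚠️  Schema mismatches : 0
📄 Total lines        : 4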

      6. Extending This for Real-World AI Workflows

      An AI engineer could extend this design in several directions:

      • Neo4j Importer: Build a provenance graph to visualize dataset → model → inference lineage.
      • Metrics integration: Log evaluation metrics (AUC, F1) into the model record.
      • MLOps pipelines: Integrate into training jobs so every experiment auto-generates signed lineage logs.
      • Cloud KMS keys: Replace PEM private key with keys from AWS KMS, GCP KMS, or HashiCorp Vault.
      • Verification service: Deploy a microservice that validates provenance logs on ingestion.

      7. Why This Matters for You

      As AI systems leave the lab and enter regulated domains (finance, healthcare, insurance), being able to say:

      • “This prediction came from Model X at commit Y, trained on Dataset Z, verified cryptographically.”

      …will be non-negotiable.

      Implementing provenance today sets you ahead of compliance requirements tomorrow.

      This ProvenanceTracker is a blueprint for trustworthy AI engineering — versioned, signed, and reproducible lineage for every dataset, model, and inference in your pipeline.

Note: The customer churn dataset can be downloaded from Kaggle, then renamed and placed in the data directory.

      Provenance in AI: Why It Matters for AI Engineers – Part 1

      1. Introduction: Why AI Needs a Paper Trail

      Imagine debugging a complex AI pipeline without knowing which version of the dataset was used, how the features were preprocessed, or which checkpoint your model came from.

      It feels like trying to fix a car engine blindfolded.

      This is where provenance comes in. In everyday life, provenance means “the origin and history of an object”—like how art collectors care about where a painting was created, who owned it, and how it changed hands.

      In AI, provenance plays the same role: it provides the paper trail of data, models, and inference processes. For engineers, it’s not just a compliance buzzword—it’s the difference between flying blind and having full visibility into your system.

      2. What Do We Mean by Provenance in AI?

      At its core, provenance answers two questions:

      • Where did this come from?
      • What happened to it along the way?

      Breaking it down:

      • Data Provenance – Where the dataset originated (source system, sensor, scraper), how it was cleaned, annotated, or transformed.
      • Model Provenance – Which algorithm, architecture, hyperparameters, code commits, and training checkpoints were used.
      • Inference Provenance – Which input went into the system, which version of the model handled it, and what external knowledge (e.g., retrieved documents for LLMs) influenced the output.

      Think of it like Git for AI systems, but not just code—it’s Git for data, models, and decisions.

      3. Why Engineers Should Care About Provenance

Let’s be honest—engineers already juggle versioning, monitoring, and debugging. Why add another layer? Because provenance directly impacts the things engineers care about most, such as:

      🔄 Reproducibility

      Ever had a model perform brilliantly during training but fail miserably in production? Without provenance, you won’t know if the issue was due to different data, missing preprocessing, or a silent dependency update.

      🛠 Debugging Failures

      When a fraud detection model misses a case, or an LLM hallucinates, provenance lets you retrace the steps:

      • Was the input preprocessed correctly?
      • Did the model drift due to newer data?
      • Was the wrong model version deployed?

      ✅ Trust and Compliance

      In regulated industries, provenance is not optional. Imagine telling a regulator:

      “We don’t know which dataset our AI was trained on, but trust us—it works.”

      That’s a career-ending statement. Provenance provides the audit trail to show decision accountability.

      👩‍💻 Team Collaboration

      Large AI teams often face the “who changed what?” problem. Provenance provides a shared source of truth, just like version control did for software engineering.

      4. Best Practices: How to Build Provenance into Your AI Stack

      Here’s how engineers can start today:

      1. Data Lineage Tracking

      • Store dataset hashes, schema versions, and preprocessing scripts.
      • Tools: Pachyderm, Delta Lake, Weights & Biases.

      2. Model Lineage

      • Version every model artifact.
      • Log hyperparameters, training environment (Docker image, dependencies), and code commit hash.
      • Tools: MLflow, DVC, Hugging Face Hub.

      3. Inference Logging

      • Record input queries, model version, and outputs.
      • For LLMs: capture prompt templates and retrieved context documents (this is sometimes called Retrieval Provenance).

      4. Cryptographic Provenance (Next Frontier)

      • Use hashing and digital signatures to verify datasets and models.
      • Standards like W3C PROV-O and NIST AI RMF are moving toward cryptographic provenance.

      5. Automate It

      Don’t rely on engineers remembering to log everything. Instead:

      • Make provenance tracking a default part of pipelines (Airflow, Kubeflow).
      • Integrate it into CI/CD for ML (MLOps pipelines).

      6. Open-Source Tools for AI Provenance & Metadata Tracking

| Tool / Platform | Type | Description |
|---|---|---|
| MLflow | Open-source | Experiment tracking, model registry, lifecycle metadata |
| DVC | Open-source | Data/model versioning with Git integration |
| AiiDA | Open-source | Provenance graph for end-to-end workflows (scientific) |
| OpenMetadata + Marquez | Open-source | Data lineage with UI and API; supports column-level tracking |
| Tribuo | Open-source | Java ML library with built-in provenance |
| Atlas | Open-source | Transparency and verifiable ML pipelines |
| PROV-AGENT | Open-source | Provenance tracking for AI agent workflows |
| ProML | Open-source | Blockchain-backed ML provenance platform |
| Vamsa | Open-source | Automated feature/data usage provenance in Python scripts |
| Collective Knowledge | Open-source | Reproducible experiment packaging, FAIR workflows |
| Neptune.ai | Commercial | Collaboration-focused experiment tracking with lineage |
| Weights & Biases | Commercial | Rich dashboards, experiment tracking, lineage, auditability |
| Fiddler / IBM OpenScale | Commercial | Model monitoring, explainability, and auditability |

      7. Real-World Examples

      • Google’s Model Cards – Provide structured metadata about a model’s context, limitations, and evaluation.
      • OpenAI’s System Cards – Disclose training data categories, design choices, and safety mitigations.
      • Financial Services – Provenance helps auditors verify that a credit-scoring model wasn’t biased due to faulty data.
      • Healthcare AI – Every step from raw clinical data → feature engineering → model inference must be traceable for FDA compliance.

      8. Challenges in Provenance

      Of course, provenance isn’t free. Engineers face:

      • Storage Overhead – Lineage metadata can grow faster than the datasets themselves.
      • Standardization Gaps – No single accepted way to store provenance across frameworks.
      • Privacy Risks – Detailed provenance may unintentionally expose sensitive information (e.g., training data sources).

      9. The Road Ahead

      The future of provenance in AI looks a lot like the early days of DevOps:

      • Standardization – Expect industry-wide adoption of W3C PROV-O, NIST RMF, and EU AI Act requirements.
      • Framework Integration – PyTorch, TensorFlow, and Hugging Face will likely include built-in provenance logging.
      • Verification – Blockchain and cryptographic fingerprints may guarantee tamper-proof provenance trails.

      In short: provenance will become a first-class engineering practice, just like CI/CD, monitoring, and version control.

      10. Closing Thoughts

      For AI engineers, provenance isn’t academic jargon—it’s the foundation for trustworthy, reproducible, and maintainable AI systems.

      Think of it this way:

      • In software engineering, we wouldn’t dream of working without Git.
      • In AI engineering, provenance will play the same role—giving us visibility, accountability, and control over increasingly complex systems.

      LLMs for SMEs – 001: How Small Businesses Can Leverage AI Without Cloud Costs

      1. Introduction

      Ravi runs a small auto parts shop in Navi Mumbai. His day starts at 8 AM, but even before he lifts the shutter, his phone is already buzzing. Customers want to know if a specific part is in stock. A supplier has sent an invoice that needs checking. A potential buyer has emailed asking for a quote — marked urgent.

      By the time Ravi responds to everyone, he’s drained — and the shop hasn’t even opened.

      For many small business owners like him, this is daily life: endless tasks, limited hands, tight margins. Hiring more staff isn’t feasible. Outsourcing feels expensive. And AI? That’s something only massive corporations with Silicon Valley budgets could afford — or so Ravi thought.

      What if he could have his own digital assistant — one that never sleeps, never complains, and works at a fraction of the cost?

      This is where Large Language Models (LLMs) come in. Once the playground of tech giants, LLMs are now accessible, affordable, and practical for small and medium enterprises (SMEs). Even better: they don’t always need the cloud.

      This is Ravi’s story — and the story of thousands of SMEs discovering how AI can help them grow without burning holes in their pockets.

      2. Why SMEs Need LLMs

      Ravi isn’t alone.

      • Meera, who runs a boutique travel agency in Jaipur, spends hours daily answering the same visa questions on WhatsApp.
      • Arjun, who owns a logistics firm in Pune, is buried under compliance paperwork.
      • Neha, who manages a clothing boutique in Delhi, struggles to keep up with customer queries across Instagram, WhatsApp, and email.

      Different businesses. Same problem: limited people, unlimited expectations.

      Customers today demand instant replies, 24/7 support, and professional service. SMEs can’t afford large teams or call centers, leading to lost sales and unhappy customers.

      LLMs flip this equation. They act as digital force multipliers by:

      • Handling FAQs instantly
      • Drafting emails and replies
      • Translating into local languages
      • Summarizing lengthy documents
      • Helping staff find knowledge quickly

      It’s not about replacing people. It’s about amplifying small teams so they can focus on growth, not grunt work.

      3. Breaking the Myth: AI Isn’t Just for Big Companies

      When Ravi first heard of AI chatbots, he imagined giant servers, complicated code, and lakhs of rupees in cloud bills. “AI is for Tatas and Birlas, not a six-person shop like mine,” he thought.

      But that’s a myth.

      Today, open-source LLMs like LLaMA, Qwen, Phi, and Mistral are lightweight and efficient. With the right setup, they can run on a mid-range workstation or even a laptop. No massive infrastructure required.

      Even better, local deployment means data stays private. Ravi’s customer information never leaves his shop — unlike cloud services that often raise data concerns.

      AI is no longer just for big players. SMEs can play too — and win.

      4. Practical Use Cases for SMEs

      a) Customer Support Chatbot for FAQs

      Every day Ravi’s shop gets the same questions:
      “Do you deliver outside Navi Mumbai?”
      “What’s the warranty on this clutch plate?”
      “Can I return a faulty part?”

      Earlier, Ravi or his assistants had to stop mid-task to reply — sometimes late at night.

      Now, an LLM-powered chatbot (trained on his product catalog and policies) answers instantly, politely, and accurately. Ravi only steps in when a query is complex, like bulk orders. His team saves energy for meaningful interactions.

      b) Writing Product Descriptions & Marketing Content

      Ravi always struggled with writing product listings. Manufacturer descriptions were too technical, and leaving blanks made his catalog look unprofessional.

      With LLMs, he simply uploads product specs, and in seconds gets customer-friendly text:

      • Before: “Voltage: 220V, RPM: 1000, Plastic body.”
      • After: “A lightweight 220V drill machine designed for everyday use. Perfect for DIY projects, with a sturdy body and reliable performance.”

      The same tool drafts Facebook posts and promotional SMS messages, helping him market like a pro without hiring an agency.

      c) Translating Offers into Local Languages

      One day a customer said, “Bhaiya, sab English mein likha hai. Hindi mein batao na.”

      That’s when Ravi realized half his customers weren’t comfortable with English. With an LLM, he translated offers into Hindi and Marathi, making messages inclusive and relatable.

      Result? Customers felt understood. Competitors still sent everything in English.

      Meera, the travel agent, does the same — sending brochures in Hindi, Gujarati, and Bengali to expand her customer base.

      d) Summarizing Compliance & Legal Documents

      Arjun, the logistics owner, used to spend evenings wrestling with GST notices and government circulars. Now he uploads PDFs to an LLM and asks simple questions like:

      • “What’s the penalty if I miss the deadline?”
      • “Which rules apply for turnover under ₹5 crore?”

      The AI explains in plain language, cutting dependency on costly consultants. Ravi uses the same approach with supplier contracts, finally understanding terms before signing.

      e) Training New Employees with Company Knowledge

      Every new hire meant hours of Ravi’s time explaining policies:

      • Fast-moving products
      • Discount rules
      • Return process

      Now, Ravi loads this knowledge into an LLM assistant. New employees ask the AI instead of interrupting him 20 times a day.

      Onboarding is faster, consistent, and less stressful. Meera also uses this to train interns at her travel agency.

      5. The Road Ahead for Ravi and SMEs

      Ravi’s journey is just beginning. His auto parts shop still has the same tight space, same six people, same crowded Navi Mumbai street. But with AI, he’s no longer drowning in repetitive tasks. He spends more time negotiating with suppliers, building customer relationships, and planning how to expand.

      For SMEs everywhere, the message is clear: AI is no longer a luxury — it’s a necessity.

      The road ahead won’t be without challenges:

      • Choosing the right tools
      • Training staff to use them
      • Balancing automation with human touch

      But SMEs that embrace AI early will stand out — more efficient, more responsive, and more competitive.

      And for Ravi, the tired shopkeeper who once thought AI was out of reach, the future suddenly feels a lot more manageable — and exciting.

      LLM-Powered Chatbots: A Practical Guide to User Input Classification and Intent Handling

      1. Introduction

      If you’ve ever built a chatbot that confidently answered the wrong question, you know the pain of poor intent detection. Imagine a user typing:

      “Block my debit card immediately.”

      If your chatbot treats that as a generic banking query instead of an urgent fraud request, the experience goes from frustrating to dangerous.

This is where intent classification comes in. Whether you’re building a banking assistant for Dummy Bank, a customer service bot, or an internal support tool, correctly classifying user input before handing it off to a Large Language Model (LLM) is key to delivering fast, accurate, and safe responses.

      In this guide, we’ll break down how to:

      • Detect user intent using three practical approaches — Fine-tuned models, Zero-shot LLMs, and Few-shot LLMs.
      • Route each intent to the right handler function for execution.
      • Apply these methods to a banking domain example that developers can adapt for their own projects.

1.1 Chatbot Intent Classification Pipeline

      Here’s the high-level workflow you’ll implement:

      1. Input Reception – The chatbot receives the raw user message.
      2. Preprocessing – Normalize text (lowercasing, punctuation handling, tokenization).
      3. Intent Classification – Use ML or LLM to predict the most likely intent (e.g., check_balance, block_card).
      4. Handler Mapping – Map the predicted intent to a specific function in your codebase.
      5. Response Generation – Call the handler, optionally using an LLM to format or elaborate the output.

      Below is a simplified diagram of the pipeline:

      Flow of Intent Classification + Handler in LLM-Based Chatbot

      By the end of this article, you’ll not only understand the theory but also have ready-to-run code for all three approaches, along with tips for choosing the right one for your use case.

      2. Why Intent Classification is Important for Chatbots

      Banking customers expect fast and accurate responses. A chatbot without intent classification would behave like a generic Q&A bot—it might give unrelated or vague answers.

      With intent classification, the chatbot can:

      1. Identify the exact customer need (e.g., “Check account balance”)
      2. Route the request to the right handler
      3. Provide accurate, domain-specific responses

      Example:

      • Query: “What’s my savings account balance?”
      • Without intent classification → Might return a random banking FAQ answer
      • With intent classification → Identifies as “Check_Balance” and fetches live balance

      3. Flow of Intent Classification + Handler in LLM-Based Chatbot

Let’s walk through the pipeline step by step:

      3.1 User Input

      Example: “Transfer ₹5000 to my savings account”

      What to consider:

• Input may come from different channels: web chat, mobile app, or voice (for voice, convert the ASR result to text).
      • Record metadata (user_id, session_id, channel, timestamp) for auditing and debugging.

Following is an example message envelope (JSON):

      {
        "user_id": "user-123",
        "session_id": "sess-456",
        "channel": "mobile",
        "text": "Transfer ₹5000 to my savings account",
        "timestamp": "2025-08-12T09:10:00+05:30"
      }

      3.2 Preprocessing (cleaning & normalization)

      Goals: reduce noise, normalize currency/amounts, expand abbreviations, correct obvious typos.

      Common steps:

      • Trim/normalize whitespace, unicode, punctuation.
      • Normalize currency tokens → ₹5000 → numeric 5000.00 plus currency field.
• Mask or redact PII (Personally Identifiable Information) for logs (partial redaction), but keep the full data for the handler (in secure memory); see the redaction sketch at the end of this subsection.
      • Language detection / transliteration (if supporting multi-lingual inputs).

      Example amount normalization:

import re

def parse_amount(text):
    # Very small heuristic example: match a rupee amount like "₹5000" or "₹ 5,000"
    match = re.search(r'₹\s?([\d,]+)', text)
    if match:
        return float(match.group(1).replace(',', ''))
    return None

      If preprocessing discovers ambiguity (e.g., no amount present), mark for clarification.
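
For the PII-masking step above, a lightweight regex redactor is often enough for log output. A minimal sketch (production systems typically rely on a dedicated PII-detection service):

import re

def redact_for_logs(text: str) -> str:
    # Mask long digit runs (account / card numbers), keeping only the last 4 digits
    return re.sub(r"\d{6,}", lambda m: "*" * (len(m.group()) - 4) + m.group()[-4:], text)

# redact_for_logs("Send 2000 to account 1234567890") -> "Send 2000 to account ******7890"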

      3.3 LLM-based Intent Classification

      You use an LLM (zero-shot, few-shot, or fine-tuned) to predict intent. Important production details:

      • Return both predicted_intent and confidence_score.
      • Thresholds: If confidence < threshold (e.g., 0.6), ask a clarifying question or fallback to a smaller model / human.
      • Entity hints: LLM can also return entities (amount, target_account, account_type) to speed up pipeline.

      Example classifier output:

      {
        "predicted_intent": "Fund_Transfer",
        "confidence": 0.92,
        "entities": {
          "amount": 5000.0,
          "currency": "INR",
          "target_account": "savings",
          "recipient_name": null
        }
      }

      Confidence handling:

      if confidence < 0.6:
          ask_clarification("Do you want to transfer money? Please confirm amount and recipient.")

      3.4 Intent Validation & Slot / Entity Extraction

      Before routing to the handler, validate entities and fill missing slots.

      Steps:

      • Validate amount > 0 and within user limits.
      • Resolve ambiguous targets (“my savings account” → which account id?).
      • Extract target account number or nickname from user profile.
      • Run fraud checks and quick policy validations (transfer limits, blocked status).

      Entity extraction strategy:

      • Use combined approach: regex rules for amounts/IFSC, lightweight NER model for names/locations, and LLM for tricky phrasings.

      Example check:

      if amount > user.available_balance:
          return "Insufficient funds. Your available balance is ₹X."

      3.5 Handler Mapping (Router)

      Map predicted_intent → handler function. Keep router simple and deterministic.

      intent_router = {
        "Check_Balance": handle_check_balance,
        "Fund_Transfer": handle_fund_transfer,
        "Open_Account": handle_open_account,
        "Loan_Enquiry": handle_loan_enquiry,
        "Card_Block": handle_card_block,
        "Branch_Location": handle_branch_location,
      }
      handler = intent_router[predicted_intent]

      Before calling handler, ensure required slots are present. If not, the handler should initiate a slot-filling dialog (ask for missing info).
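
A simple way to drive that slot-filling check is a per-intent table of required slots. A small sketch (the slot names are illustrative):

REQUIRED_SLOTS = {
    "Fund_Transfer": ["amount", "target_account"],
    "Card_Block": ["card_type"],
}

def missing_slots(intent: str, entities: dict) -> list:
    # Names of required slots the classifier did not extract
    return [slot for slot in REQUIRED_SLOTS.get(intent, []) if not entities.get(slot)]

# missing_slots("Fund_Transfer", {"amount": 5000.0}) -> ["target_account"]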

      3.6 Handler Execution (example: handle_fund_transfer)

      This is the business logic layer that must be secure, idempotent, auditable, and often synchronous with backend banking APIs.

      Key steps inside handle_fund_transfer:

      1. Authenticate/authorize user (session token, MFA (Multi-Factor Authentication) status).
      2. Validate inputs (amount limits, beneficiary verification).
      3. Pre-checks: AML(Anti-Money Laundering)/fraud checks, transaction velocity checks.
      4. Confirm: If required, ask the user to confirm (show transfer summary).
      5. OTP / 2FA: Request OTP or biometric verification for high-risk transfers.
      6. Call core banking API (use idempotency key).
      7. Handle API errors (retry/backoff, rollback where applicable).
      8. Log & audit: Write transaction record to secure audit trail.
      9. Return structured result (success/fail, transaction id, timestamp).

      Simplified handler:

from uuid import uuid4  # used for the idempotency key below

def handle_fund_transfer(user_id, amount, target_account):
    # 1. Auth check (helpers like is_user_authenticated are placeholders for real services)
    if not is_user_authenticated(user_id):
        return require_login()
      
          # 2. Validate amount and beneficiary
          if amount <= 0 or amount > get_transfer_limit(user_id):
              return "Transfer amount invalid or exceeds limit."
      
          # 3. Sufficient balance
          if amount > get_available_balance(user_id):
              return "Insufficient funds."
      
          # 4. Confirmation & OTP flow
          confirmation = ask_user_confirmation(amount, target_account)
          if not confirmation:
              return "Transfer cancelled."
      
          if requires_otp(amount):
              otp_ok = verify_otp(user_id)
              if not otp_ok:
                  return "OTP validation failed."
      
          # 5. Call bank API with idempotency_key
          tx = call_core_banking_transfer(user_id, amount, target_account, idempotency_key=uuid4())
          if tx.success:
              audit_log("transfer", user_id, amount, target_account, tx.id)
              return f"₹{amount} transferred successfully. Transaction ID: {tx.id}"
          else:
              handle_failure(tx)
              return "Transfer failed. Please try again or contact support."

Idempotency: generate one idempotency key per logical transfer and reuse it on retries, so the core banking API can deduplicate and avoid double transfers.

      3.7 Response Generation

The handler returns a structured response. The response generator formats it for the user, optionally using an LLM to produce friendlier wording (a formatting sketch follows below).

      Example final message:

      • "₹5000 transferred successfully to your savings account. Transaction ID TXN12345. Would you like a receipt via SMS?"

      Make sure the message:

      • Avoids leaking sensitive data (full account numbers).
      • Provides transaction reference and next steps.
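
A formatting helper can turn the handler's structured result into the final user message. A sketch assuming hypothetical result fields (status, amount, target_account, tx_id):

def format_transfer_response(result: dict) -> str:
    # Build a user-facing message without exposing full account numbers
    if result.get("status") == "success":
        return (
            f"₹{result['amount']} transferred successfully to your "
            f"{result['target_account']} account. Transaction ID {result['tx_id']}. "
            "Would you like a receipt via SMS?"
        )
    return "The transfer could not be completed. Please try again or contact support."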

      3.8 Auditing, Logging & Compliance

      Banking requires strict logs and retention policies.

      • Log: user_id, session_id, intent, entities (redacted in logs), handler invoked, API responses, timestamps, geolocation if relevant.
      • Audit trail must be tamper-resistant (write-once logs or append-only store).
      • GDPR/RBI compliance: minimize PII storage; use encryption-at-rest & in-transit.

      Audit record example:

      {
        "event":"fund_transfer",
        "user_id":"user-123",
        "amount":5000,
        "target":"savings",
        "tx_id":"TXN12345",
        "timestamp":"2025-08-12T09:10:15+05:30"
      }

      3.9 Error Handling & Fallbacks

      • Low classifier confidence → ask clarifying question or route to human agent.
• API failures → retry with exponential backoff (see the sketch after this list), provide a user-friendly error, log the incident.
      • Security checks fail → escalate to fraud queue, block transaction if necessary.
      • Unrecognized intent → route to fallback intent or handover to live agent.
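
For the retry path, a small exponential-backoff wrapper around backend calls is usually enough to absorb transient failures. A sketch (real deployments often add jitter and circuit breaking):

import time

def call_with_retries(fn, max_attempts=3, base_delay=0.5):
    # Retry a backend call, doubling the wait between attempts
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise
            time.sleep(base_delay * (2 ** (attempt - 1)))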

      3.10 Monitoring & Metrics

      Track these to measure health and improve models:

      • Intent classification accuracy, confusion matrix
      • Avg pipeline latency (preprocessing → final response)
      • Handler success rate (e.g., transfer success %)
      • Human-handover rate
      • False positives for high-risk intents

      Use these logs to improve training data and to retrain periodically.

      3.11 Continuous Learning & Retraining

      • Capture misclassifications and ambiguous interactions; add them to a labeled dataset.
      • Schedule periodic retraining for the fine-tuned model or update few-shot examples for LLM prompts.
      • A/B test classifier changes in a staging environment before rolling to production.

      3.12 Security & Privacy Checklist (banking)

      • Enforce strong authentication (session tokens, MFA) before sensitive handlers.
      • Mask or avoid logging full account numbers/PINs.
      • Use secure channels & encryption for all backend calls.
      • Implement rate limits & anomaly detection to prevent abuse.

      3.13 Quick end-to-end example (summary)

      1. User: "Transfer ₹5000 to my savings account"
      2. Preprocess → extract amount=5000, target=savings
      3. LLM classifier → Fund_Transfer (confidence 0.93)
      4. Router → handle_fund_transfer()
      5. Handler validates, asks OTP, calls bank API with idempotency key
      6. Response → "₹5000 transferred successfully. TXN12345."
      7. Audit log written and user notified
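
Stitched together, the whole flow can be expressed as a short dispatcher. This is a sketch using hypothetical preprocess, classify, format_response, and handle_fallback helpers standing in for the components described above:

def handle_message(user_id: str, text: str) -> str:
    cleaned = preprocess(text)                            # 3.2 cleaning & normalization
    intent, confidence, entities = classify(cleaned)      # 3.3 LLM-based classification
    if confidence < 0.6:
        return "Could you confirm what you'd like to do?"  # clarification fallback
    handler = intent_router.get(intent, handle_fallback)   # 3.5 routing
    result = handler(user_id, **entities)                  # 3.6 business logic
    return format_response(result)                         # 3.7 response generation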

      4. Banking Intent Dataset Example

      To train or evaluate an intent classification system for a banking chatbot, you will need a well-structured dataset that captures the variety of ways users might express their requests. Below is a sample dataset for training/testing your banking chatbot intent classifier.

| Intent Name | Example Queries |
|---|---|
| Check_Balance | “What is my account balance?”, “Show my savings account balance”, “Check my current balance” |
| Fund_Transfer | “Transfer ₹5000 to my savings account”, “Send ₹2000 to John”, “Make a transfer to account 123456789” |
| Open_Account | “How can I open a savings account?”, “Start new account application”, “I want to open an account” |
| Loan_Enquiry | “Tell me about home loan interest rates”, “Apply for personal loan”, “Loan eligibility for ₹10 lakh” |
| Card_Block | “Block my debit card”, “My ATM card is lost”, “Stop transactions from my credit card” |
| Branch_Location | “Nearest Dummy Bank branch”, “Where is the closest Dummy Bank ATM?”, “Find a branch near me” |

      5. Intent Handlers for Banking Chatbot

      Once an intent is correctly identified by the classifier, the chatbot needs to decide what to do next. This is where intent handlers come into play. An intent handler is a function or module responsible for executing the specific action linked to an intent. In a banking chatbot, each intent can have a dedicated handler that connects to backend services (like Dummy Bank’s core banking system), retrieves or updates data, and formats the response for the user.

      Example handlers:

      • handle_check_balance() – Connects to the user’s account system, fetches the latest balance, and presents it in a friendly message.
      • handle_fund_transfer() – Validates account details, initiates the transfer, confirms the transaction status, and logs it for auditing.
      • handle_open_account() – Guides the user through the required KYC steps, generates a reference number, and schedules a branch visit if needed.
      • handle_loan_enquiry() – Checks loan eligibility, fetches applicable loan rates, and provides repayment schedules.
      • handle_card_block() – Immediately blocks the reported card, sends confirmation via SMS/email, and prompts the user for reissue options.
      • handle_branch_location() – Uses a geolocation API to find the nearest branch or ATM based on the user’s location.

      In well-structured chatbots, these handlers are modular and reusable. They can also be enriched with context awareness (e.g., remembering the user’s last transaction) and security layers (e.g., OTP verification before fund transfer). This separation of intent detection and intent handling ensures that the chatbot remains scalable, secure, and easy to maintain.

Following is sample simulated code for the handlers mentioned above:

      def handle_check_balance(user_id):
          # Simulated balance fetch
          return f"Your account balance is ₹25,340."
      
      def handle_fund_transfer(user_id, amount, target_account):
          # Simulated transfer
          return f"₹{amount} has been transferred to account {target_account}."
      
      def handle_open_account():
          return "To open a savings account, please visit your nearest Dummy Bank branch or apply online at dummy.bank.co.in."
      
      def handle_loan_enquiry(loan_type="home"):
          return f"The current {loan_type} loan interest rate is 8.25% p.a. You can apply via the Dummy Bank website."
      
      def handle_card_block(card_type="debit"):
          return f"Your {card_type} card has been blocked. A replacement will be sent to your registered address."
      
      def handle_branch_location(pincode):
          return f"The nearest Dummy Bank branch to pincode {pincode} is at Main Market Road, Sector 15."

      6. Training the Intent Classifier

      Training an intent classifier involves teaching a model to correctly identify a user’s goal from their query. This process starts with collecting representative training data for each intent category, followed by preprocessing the text for tokenization. The model is then trained on these labeled examples, learning patterns and keywords associated with each intent. Once trained, the classifier can quickly and accurately predict intents for new, unseen queries, enabling downstream applications like chatbots and virtual assistants to respond appropriately. Regular retraining with fresh data helps maintain accuracy as user behavior and language evolve.

      6.1 Fine-tune a smaller model like distilbert-base-uncased for intent classification

      Fine-tuning a lightweight model such as distilbert-base-uncased is an efficient way to build a high-performance intent classifier without the computational overhead of large LLMs. DistilBERT retains much of BERT’s language understanding capability while being faster and more resource-friendly, making it ideal for deployment in production environments with limited hardware. By training it on domain-specific data—such as banking-related queries for Dummy Bank—it can achieve high accuracy in recognizing intents like Check_Balance, Fund_Transfer, or Card_Block. This approach combines speed, cost-effectiveness, and adaptability.

      Example code:

      import pandas as pd
      from sklearn.model_selection import train_test_split
      from datasets import Dataset
      from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
      import torch
      
      # ---------------------------
      # 1. Example Dataset
      # ---------------------------
      data = [
          # Check_Balance
          ("What is my current account balance?", "Check_Balance"),
          ("Show me my savings balance", "Check_Balance"),
          ("How much money is in my account?", "Check_Balance"),
      
          # Fund_Transfer
          ("Transfer ₹5000 to my brother's account", "Fund_Transfer"),
          ("Send 2000 rupees to account 1234567890", "Fund_Transfer"),
          ("Make a payment to Ramesh", "Fund_Transfer"),
      
          # Open_Account
          ("I want to open a new savings account", "Open_Account"),
          ("How can I open a current account?", "Open_Account"),
          ("Open an account for me", "Open_Account"),
      
          # Loan_Enquiry
          ("Tell me about home loan interest rates", "Loan_Enquiry"),
          ("What is the EMI for a 5 lakh personal loan?", "Loan_Enquiry"),
          ("How can I apply for a car loan?", "Loan_Enquiry"),
      
          # Card_Block
          ("Block my debit card immediately", "Card_Block"),
          ("I lost my credit card, please block it", "Card_Block"),
          ("Block my ATM card", "Card_Block"),
      
          # Branch_Location
          ("Where is the nearest Dummy Bank branch?", "Branch_Location"),
          ("Find me a branch near Andheri", "Branch_Location"),
          ("Locate the closest ATM", "Branch_Location"),
      ]
      
      df = pd.DataFrame(data, columns=["text", "label"])
      
      # ---------------------------
      # 2. Encode Labels
      # ---------------------------
      label_list = df["label"].unique().tolist()
      label2id = {label: idx for idx, label in enumerate(label_list)}
      id2label = {idx: label for label, idx in label2id.items()}
      
      df["label_id"] = df["label"].map(label2id)
      
      # ---------------------------
      # 3. Train-Test Split
      # ---------------------------
      train_texts, val_texts, train_labels, val_labels = train_test_split(
          df["text"], df["label_id"], test_size=0.2, random_state=42
      )
      
      train_df = pd.DataFrame({"text": train_texts, "label": train_labels})
      val_df = pd.DataFrame({"text": val_texts, "label": val_labels})
      
      # ---------------------------
      # 4. Convert to Hugging Face Dataset
      # ---------------------------
      train_dataset = Dataset.from_pandas(train_df)
      val_dataset = Dataset.from_pandas(val_df)
      
      # ---------------------------
      # 5. Tokenization
      # ---------------------------
      tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
      
      def tokenize(batch):
          return tokenizer(batch["text"], padding=True, truncation=True, max_length=64)
      
      train_dataset = train_dataset.map(tokenize, batched=True)
      val_dataset = val_dataset.map(tokenize, batched=True)
      
      # ---------------------------
      # 6. Load Model
      # ---------------------------
      model = DistilBertForSequenceClassification.from_pretrained(
          "distilbert-base-uncased",
          num_labels=len(label_list),
          id2label=id2label,
          label2id=label2id
      )
      
      # ---------------------------
      # 7. Training Arguments
      # ---------------------------
      training_args = TrainingArguments(
          output_dir="./intent_classifier_model",
          evaluation_strategy="epoch",
          save_strategy="epoch",
          learning_rate=5e-5,
          per_device_train_batch_size=8,
          per_device_eval_batch_size=8,
          num_train_epochs=5,
          weight_decay=0.01,
          logging_dir="./logs",
          logging_steps=10,
          load_best_model_at_end=True
      )
      
      # ---------------------------
      # 8. Trainer
      # ---------------------------
      def compute_metrics(eval_pred):
          from sklearn.metrics import accuracy_score, f1_score
          logits, labels = eval_pred
          preds = logits.argmax(axis=-1)
          return {
              "accuracy": accuracy_score(labels, preds),
              "f1": f1_score(labels, preds, average="weighted")
          }
      
      trainer = Trainer(
          model=model,
          args=training_args,
          train_dataset=train_dataset,
          eval_dataset=val_dataset,
          tokenizer=tokenizer,
          compute_metrics=compute_metrics
      )
      
      # ---------------------------
      # 9. Train
      # ---------------------------
      trainer.train()
      
      # ---------------------------
      # 10. Test Prediction
      # ---------------------------
      test_queries = [
          "Please transfer 1000 rupees to my son's account",
          "Find me the nearest dummy bank branch in Pune",
          "I lost my ATM card",
          "Show me my account balance"
      ]
      
      tokens = tokenizer(test_queries, padding=True, truncation=True, return_tensors="pt")
      outputs = model(**tokens)
      predictions = torch.argmax(outputs.logits, dim=-1)
      
      for query, pred_id in zip(test_queries, predictions):
          print(f"Query: {query} -> Intent: {id2label[pred_id.item()]}")

      Expected output:

      Query: Please transfer 1000 rupees to my son's account -> Intent: Fund_Transfer
Query: Find me the nearest dummy bank branch in Pune -> Intent: Branch_Location
      Query: I lost my ATM card -> Intent: Card_Block
      Query: Show me my account balance -> Intent: Check_Balance

      6.2 LLM-based Intent Classification (Zero-shot classification) using Hugging Face pipeline

      Zero-shot intent classification leverages the language understanding power of large language models to identify user intents without any task-specific training data. Using Hugging Face’s pipeline API, we can provide the model with a query and a list of possible intent labels, and it will determine the most likely match based on its vast pre-trained knowledge. This approach is especially useful for quickly deploying chatbots in domains like banking, where intents (e.g., Check_Balance, Fund_Transfer, Card_Block) can be recognized instantly, even if no historical data is available for those categories.

      Example Code:

      from transformers import pipeline
      
      # Banking intents
      intents = [
          "Check_Balance",
          "Fund_Transfer",
          "Open_Account",
          "Loan_Enquiry",
          "Card_Block",
          "Branch_Location"
      ]
      
      classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
      
      query = "Transfer ₹5000 to my savings account"
      result = classifier(query, candidate_labels=intents)
      
      predicted_intent = result['labels'][0]
      print("Predicted Intent:", predicted_intent)

      Sample Output

      Predicted Intent: Fund_Transfer

      6.3 LLM-based Intent Classification (Few-shot classification) using Hugging Face pipeline

      Few-shot classification leverages the power of Large Language Models to accurately predict intents with only a handful of labeled examples per category. Instead of training a model from scratch, we simply provide the LLM with sample queries for each intent along with the user’s new query. Using the Hugging Face pipeline, the LLM applies its vast language understanding to match the query to the closest intent, even if the wording is unfamiliar. This approach is fast to implement, requires minimal data, and works particularly well for domains like banking where intent categories are clearly defined.

      Example Code:

      from transformers import pipeline
      
      # Banking intents
      intents = [
          "Check_Balance",
          "Fund_Transfer",
          "Open_Account",
          "Loan_Enquiry",
          "Card_Block",
          "Branch_Location"
      ]
      
      # Few-shot examples for better classification
      examples = [
          ("Show me my account balance", "Check_Balance"),
          ("Please transfer ₹2000 to Ramesh's account", "Fund_Transfer"),
          ("I want to apply for a home loan", "Loan_Enquiry"),
          ("I lost my debit card, please block it", "Card_Block"),
          ("Where is the nearest Dummy bank branch in Delhi?", "Branch_Location"),
      ]
      
      # Create the few-shot prompt
      def build_few_shot_prompt(query):
          prompt = "Classify the following customer queries into one of these intents:\n"
          prompt += ", ".join(intents) + "\n\n"
          prompt += "Examples:\n"
          for ex_query, ex_intent in examples:
              prompt += f"Query: {ex_query}\nIntent: {ex_intent}\n\n"
          prompt += f"Query: {query}\nIntent:"
          return prompt
      
      query = "Transfer ₹5000 to my savings account"
      prompt = build_few_shot_prompt(query)
      
      # Using a text-generation pipeline (could be GPT-like model)
      generator = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf", device_map="auto")
      
response = generator(prompt, max_new_tokens=10, do_sample=False)  # greedy decoding for a deterministic intent label
      predicted_intent = response[0]['generated_text'].split("Intent:")[-1].strip()
      
      print("Predicted Intent:", predicted_intent)

6.4 Comparison of LLM-based Intent Classification (Zero-shot vs. Few-shot Classification)

| Zero-Shot | Few-Shot |
|---|---|
| No examples given; the model must guess purely from intent names. | Provides a few labeled examples so the model learns the style and meaning of intents before predicting. |
| Works okay for common phrasing but may fail on domain-specific terms. | More accurate for banking-specific terms (e.g., RD account, cheque book). |
| Simpler but less controlled. | Slightly more work to prepare, but boosts accuracy. |

6.5 Comparison of Fine-Tuning a Smaller Model vs. LLM-Based Intent Classification

| Feature / Criteria | LLM-Based Intent Classification | Fine-Tuned Smaller Model (e.g., DistilBERT) |
|---|---|---|
| Training Data Requirement | Can work zero-shot (no training data needed for new intents). | Requires labeled training data for all intents. |
| Flexibility | Handles multiple phrasings and unseen variations well. | Performs best on phrasings seen during training; less robust to unexpected inputs. |
| Domain Adaptability | Adapts quickly to new banking terms without retraining. | Needs retraining to add or modify intents. |
| Inference Speed | Slower (especially large models like GPT or LLaMA); may need GPU. | Fast (can run on CPU), ideal for real-time responses. |
| Hosting Cost | High; requires GPU or expensive API usage. | Low; can run on inexpensive servers or on-premise hardware. |
| Privacy & Compliance | Often cloud-hosted → possible compliance issues unless using on-prem LLM. | Easy on-prem deployment, ensuring customer data never leaves the bank’s network. |
| Accuracy for Fixed Intents | May misclassify if intent phrasing is too vague or similar to others. | Very high accuracy for trained intents (e.g., Check_Balance, Card_Block). |
| Hallucination Risk | Higher; might output unrelated intents or responses. | Lower; restricted to a predefined set of intents. |
| Maintenance | Easy to add new intents without retraining. | Adding new intents requires retraining the model. |

7. Conclusion

      In the fast-paced world of digital banking, a chatbot’s ability to accurately identify customer intent is the foundation for delivering seamless, human-like support. Our exploration of intent classification — from fine-tuning smaller models to leveraging powerful LLMs — shows that there’s no one-size-fits-all solution.

      Fine-tuned smaller models like DistilBERT excel in speed, cost-efficiency, and privacy, making them a strong choice for banks that deal with fixed sets of intents and require on-premises deployment. LLM-based approaches, on the other hand, offer unmatched flexibility, adaptability to new domains, and zero-shot capabilities — perfect for scenarios where customer queries evolve quickly or domain-specific terms frequently emerge.

      Ultimately, the best approach depends on your priorities:

      • If cost, privacy, and speed are paramount, go for a fine-tuned smaller model.
      • If adaptability, reduced training overhead, and rapid intent expansion are more important, LLM-based classification is the way forward.

      By choosing the right intent classification strategy, banks can ensure their chatbots not only respond faster but also understand customers better — building trust, improving satisfaction, and making every digital interaction as smooth as talking to a trusted branch representative.

      Reranking for RAG: Boosting Answer Quality in Retrieval-Augmented Generation

      Retrieval-Augmented Generation (RAG) is one of the most effective techniques for making large language models (LLMs) answer accurately using external knowledge.
      The idea is straightforward:

      1. Retrieve relevant documents from your knowledge base.
      2. Augment your LLM prompt with those documents.
      3. Generate an answer using the LLM.

      Sounds simple, right? The problem is:

      Even the best vector search algorithms sometimes return documents that are only loosely related to the query — or miss subtle but highly relevant matches.

      This is where Reranking enters the scene — the “quality filter” for your retrieved documents.

      What is Reranking in RAG?

      Reranking is a second-stage filtering process that reorders retrieved documents by actual relevance to the user query, often using a more sophisticated model than the one used for the initial retrieval.

      Think of it as precision tuning:

      • Stage 1 (vector retrieval) → Fast and broad: retrieve 30–100 potentially relevant docs.
      • Stage 2 (reranking) → Slow but sharp: deeply score these docs for true relevance.

      This two-stage approach mirrors real-world search engines like Google, which first retrieve a broad set of results (recall-focused) and then apply a more precise ranking model (precision-focused).

      This is especially important because standard retrieval models (like BM25, dense embeddings) often prioritize speed over deep contextual matching. Reranking uses more advanced models (like cross-encoders) that compare the query and each document together for higher precision.

      Why Reranking Matters in RAG

      Without reranking, your RAG model might answer from a less relevant document simply because it was retrieved higher by the retriever’s default scoring.

      Example:
      Imagine a customer of the State Bank of India (SBI) asks:
      “What is the minimum balance required for an SBI savings account in a metro city?”

      Without Reranking:

      • Retriever might pull in documents about fixed deposit interest rates, ATM withdrawal limits, and minimum balance rules for rural branches.
      • The first retrieved document might mention “minimum balance” but for rural accounts, not metro city accounts.

      With Reranking:

      • The reranker analyzes the exact query and re-scores documents so that the top-ranked one specifically contains:
        • Metro city rules
        • SBI’s updated minimum balance criteria
        • Correct fee details if balance is below the limit

      This ensures the generator receives the right context and produces a correct answer.

      Common Reranking Techniques

      Here are the most common approaches used in production RAG systems:

      1. Cross-Encoder Models

      • Takes the query and document together as input.
      • Outputs a single relevance score.
      • Pros: Very accurate.
      • Cons: Slower, since each document is scored independently.
Python Example:

      from sentence_transformers import CrossEncoder
      
      # Load a cross-encoder model
      model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
      
      # Example query
      query = "What is the minimum balance required for an SBI savings account in a metro city?"
      
      # Retrieved documents
      documents = [
          "SBI savings account in metro cities requires a minimum balance of Rs. 3,000 to avoid penalties.",
          "SBI fixed deposit interest rates vary between 3% and 6% depending on tenure.",
          "In rural areas, SBI savings accounts require a minimum balance of Rs. 1,000."
      ]
      
      # Prepare pairs for scoring
      pairs = [(query, doc) for doc in documents]
      
      # Score each document for relevance
      scores = model.predict(pairs)
      
      # Sort by score (descending)
      reranked_docs = [doc for _, doc in sorted(zip(scores, documents), reverse=True)]
      
      print("Reranked Documents:")
      for doc in reranked_docs:
          print(doc)

      Sample Output:

      Reranked Documents:
      SBI savings account in metro cities requires a minimum balance of Rs. 3,000 to avoid penalties.
      In rural areas, SBI savings accounts require a minimum balance of Rs. 1,000.
      SBI fixed deposit interest rates vary between 3% and 6% depending on tenure.

      2. Bi-Encoder + Cross-Encoder Hybrid

      • First, a fast bi-encoder retrieves candidates.
      • Then, a cross-encoder reranks the top results.
      • Best of both worlds — speed and accuracy.
      Python Example
      from sentence_transformers import SentenceTransformer, CrossEncoder, util
      import torch
      
      # Step 1: Create SBI corpus
      corpus = [
          "The minimum balance required for SBI savings account is ₹1000 in metro cities.",
          "SBI provides 7.5% interest rate for senior citizen fixed deposits.",
          "You can link your Aadhaar to your SBI account through the YONO app.",
          "SBI charges ₹20 per transaction for ATM withdrawals beyond the free limit.",
          "The SBI home loan interest rate starts from 8.5% per annum.",
          "SBI credit cards offer reward points on every transaction."
      ]
      
      # Step 2: Load Bi-Encoder and Cross-Encoder
      bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')  # For retrieval
      cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # For reranking
      
      # Step 3: Encode corpus for Bi-Encoder retrieval
      corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)
      
      # Step 4: User query
      query = "What is the interest rate for senior citizen FD in SBI?"
      query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
      
      # Step 5: Retrieve top N candidates using Bi-Encoder
      top_k = 3
      bi_encoder_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
      
      # Step 6: Prepare for Cross-Encoder reranking
      cross_inp = [(query, corpus[hit['corpus_id']]) for hit in bi_encoder_hits]
      cross_scores = cross_encoder.predict(cross_inp)
      
      # Step 7: Combine results and sort by Cross-Encoder score
      reranked_results = sorted(
          zip(cross_inp, cross_scores),
          key=lambda x: x[1],
          reverse=True
      )
      
      # Step 8: Print results
      print(f"Query: {query}\n")
      print("Top Results after Reranking:")
      for (q, passage), score in reranked_results:
          print(f"Score: {score:.4f} | {passage}")

      Sample Output:

      Query: What is the interest rate for senior citizen FD in SBI?
      
      Top Results after Reranking:
      Score: 8.5123 | SBI provides 7.5% interest rate for senior citizen fixed deposits.
      Score: 5.9012 | The SBI home loan interest rate starts from 8.5% per annum.
      Score: 3.2710 | SBI credit cards offer reward points on every transaction.

      3. LLM-based Reranking

      • Uses large language models (e.g., GPT, LLaMA) to rate document relevance.
      • Can understand nuanced and multi-step queries.
      • Higher cost, but sometimes worth it for complex domains.
      Python Example
      from transformers import AutoModelForCausalLM, AutoTokenizer
      import torch
      
      # 1. SBI Corpus
      corpus = [
          "The minimum balance required for SBI savings account in metro cities is ₹3000.",
          "SBI offers a 3.5% interest rate for savings accounts up to ₹1 lakh.",
          "SBI home loan interest rate starts from 8.5% per annum.",
          "SBI fixed deposit for senior citizens offers 7.5% per annum interest."
      ]
      
      # 2. Simulated Retrieval Output
      retrieved_docs = [
          corpus[1],  # savings account interest
          corpus[3],  # senior citizen FD
          corpus[0]   # minimum balance
      ]
      
      query = "What interest rate does SBI offer for fixed deposits for senior citizens?"
      
      # 3. Load Phi-3-Mini-Instruct Model from Hugging Face
      # Supports chat-style prompts with system, user, and assistant roles
model_name = "microsoft/Phi-3-mini-128k-instruct"
      tokenizer = AutoTokenizer.from_pretrained(model_name)
      model = AutoModelForCausalLM.from_pretrained(
          model_name,
          device_map="auto",
          torch_dtype="auto",
          trust_remote_code=True
      )
      
      # 4. Build prompt for reranking
      prompt_prefix = "<|system|>You are an assistant that ranks documents by relevance.<|end|>\n"
      prompt_prefix += f"<|user|>Query: {query}\nDocuments:\n"
      
      for idx, doc in enumerate(retrieved_docs):
          prompt_prefix += f"{idx}: {doc}\n"
prompt_prefix += "Provide the ranking as a list of indexes [most relevant first], plus a brief explanation.<|end|>\n<|assistant|>"
      
      # 5. Tokenize and generate
      inputs = tokenizer(prompt_prefix, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False  # greedy decoding gives a deterministic ranking
)
# Decode only the newly generated tokens, not the echoed prompt
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
      
      print("=== Reranking Response ===")
      print(response)

      Sample Output:

      === Reranking Response ===
      [1, 2, 0]
      The most relevant document is index 1: "SBI fixed deposit for senior citizens offers 7.5% per annum interest." 
      It directly answers the query about FD interest for senior citizens. 
      Next is index 2: "The minimum balance required for SBI savings account in metro cities is ₹3000." 
      While not about fixed deposits, it mentions account-related terms. 
      Index 0: "SBI offers a 3.5% interest rate for savings accounts up to ₹1 lakh." 
      This is least relevant because it talks about savings account rates, not fixed deposit rates.
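
Note that the LLM’s reply is free-form text, so you still have to turn it back into a document ordering before passing context to the generator. Below is a minimal parsing sketch that continues from the snippet above; the helper is hand-rolled (not part of transformers) and assumes the model leads with a bracketed index list like the “[1, 2, 0]” in the sample output, falling back to the original order if nothing parsable is found.

import re

def parse_ranking(response_text, docs):
    # Find the first bracketed list of indexes, e.g. "[1, 2, 0]"
    match = re.search(r"\[([\d,\s]+)\]", response_text)
    if not match:
        return docs  # fall back to the retriever's order if parsing fails
    indexes = [int(i) for i in match.group(1).split(",")]
    # Keep valid, unique indexes, then append anything the model skipped
    kept = [i for i in dict.fromkeys(indexes) if 0 <= i < len(docs)]
    remaining = [i for i in range(len(docs)) if i not in kept]
    return [docs[i] for i in kept + remaining]

reranked_docs = parse_ranking(response, retrieved_docs)
print(reranked_docs[0])  # best candidate to feed the generator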

      Best Practices for Reranking in RAG

      1. Limit the candidate pool — Avoid reranking all retrieved results; rerank only the top N (e.g., 50).
      2. Use domain-specific fine-tuning — Fine-tune reranker models on your domain data for better accuracy.
      3. Cache results — For frequent queries, store reranked results to save computation.
      4. Balance speed vs accuracy — In real-time applications, choose models that meet your latency requirements.
5. Continuously evaluate — Track metrics like MRR (Mean Reciprocal Rank) and nDCG to measure impact; a small MRR example follows this list.
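
As a quick illustration of practice 5, here is a minimal MRR sketch. The document IDs and relevance labels below are made-up placeholders; in a real pipeline you would log the reranked IDs per query and compare them against a labeled evaluation set (nDCG can be computed similarly, for example with scikit-learn's ndcg_score).

def mean_reciprocal_rank(ranked_ids_per_query, relevant_ids_per_query):
    # For each query, take 1/rank of the first relevant document (0 if none is found)
    reciprocal_ranks = []
    for ranked_ids, relevant_ids in zip(ranked_ids_per_query, relevant_ids_per_query):
        rr = 0.0
        for rank, doc_id in enumerate(ranked_ids, start=1):
            if doc_id in relevant_ids:
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

# Hypothetical reranker output and ground-truth relevance judgments
ranked = [["doc3", "doc1", "doc7"], ["doc2", "doc9", "doc4"]]
relevant = [{"doc1"}, {"doc2"}]
print(f"MRR: {mean_reciprocal_rank(ranked, relevant):.2f}")  # (1/2 + 1/1) / 2 = 0.75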

      Conclusion

      Reranking acts as a precision filter for RAG pipelines. By ensuring that the right documents make it to the generation stage, you can drastically reduce irrelevant or partially correct answers.

      For any production-grade RAG system — whether it’s for banking FAQs, legal document search, or technical support — reranking can be the key differentiator in delivering high-quality, trustworthy AI answers.

      ChatML: The Structured Language Behind Conversational AI

      If you’ve interacted with ChatGPT or built your own conversational AI, you might have wondered — how exactly does the AI know which parts of a message are from the user, which are from the system, and which are from the assistant?

      Behind the scenes, OpenAI uses a simple but powerful markup format called ChatML (Chat Markup Language) to structure conversations. While it originated with OpenAI’s models, similar role-based message formatting is now used or adapted by other large language models as well — for example, Anthropic Claude, Qwen, Mistral, and various open-source chat models have implemented ChatML-compatible or inspired prompt formats to maintain clear conversation context.

      In this article, we’ll explore what ChatML is, how it works, and why it matters for building smarter AI systems.

Want to go deep into ChatML? Explore my new book on the topic, “🚀 The ChatML (Chat Markup Language) Handbook”.

      What is ChatML?

      ChatML is a lightweight, plain-text markup format designed to give large language models a clear, structured way to understand conversation history.

      Instead of sending raw text, developers wrap messages with special tokens that identify the role of the speaker (system, user, assistant, or tool) and the message content.

      For example:

      <|im_start|>system
      You are a helpful assistant.
      <|im_end|>
      <|im_start|>user
      What's the capital of France?
      <|im_end|>
<|im_start|>assistant

      Here’s what’s happening:

      • system → Sets rules, instructions, or context for the AI.
      • user → Represents a message from the end-user.
      • assistant → Represents the AI’s reply.
      • <|im_start|> & <|im_end|> → Special tokens to mark message boundaries.

      Why Does ChatML Exist?

      In early LLM implementations, prompts were often long strings with no strict structure. This made them fragile — minor wording changes could break expected behavior.

      ChatML solves this by:

      • Separating roles clearly → The model knows who said what.
      • Making multi-turn conversations stable → No guessing where one message ends and another begins.
      • Supporting system-level control → Developers can enforce guidelines (e.g., tone, style, or restrictions).

      Roles in ChatML

Role      | Purpose
system    | Defines the AI’s personality, constraints, and instructions.
user      | The actual human input.
assistant | The AI’s output in the conversation.
tool      | For calling or simulating API/tool outputs (in some implementations).

      Building a ChatML Prompt in Python

      Here’s a quick helper function to convert a list of messages into ChatML format:

      def to_chatml(messages):
          chatml = ""
          for m in messages:
        chatml += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
          chatml += "<|im_start|>assistant\n"  # Leave open for AI's reply
          return chatml
      
      messages = [
          {"role": "system", "content": "You are a helpful assistant."},
          {"role": "user", "content": "Tell me a joke."}
      ]
      
      print(to_chatml(messages))

      This produces a properly formatted ChatML string ready for the model.
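
If you are working with Hugging Face transformers, many ChatML-trained models also ship a chat template, so you often don’t need to assemble the string by hand. The sketch below uses a Qwen2.5 instruct checkpoint purely as an example of a ChatML-style model; apply_chat_template renders the conversation in the model’s own format and, with add_generation_prompt=True, leaves the assistant turn open just like to_chatml() does.

from transformers import AutoTokenizer

# Example: a ChatML-trained model (Qwen2.5 instruct), chosen for illustration
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."}
]

# Render the conversation with the model's own template
# (ChatML-style <|im_start|> / <|im_end|> markers for Qwen)
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True  # leave the assistant turn open
)
print(prompt)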

      Advantages of Using ChatML

      1. Consistency – Prevents prompt breakage due to formatting errors.
      2. Flexibility – Works for single-turn and multi-turn conversations.
      3. Control – Gives developers fine-grained control over model behavior.
      4. Scalability – Easy to extend for new roles or system instructions.

      When to Use ChatML

      • Custom LLM Applications – If you’re building a chatbot with models like GPT-3.5, GPT-4, or Qwen.
      • Multi-Turn Conversations – Where keeping track of roles is important.
      • Prompt Engineering – For reliable, repeatable outputs.

      ChatML Beyond OpenAI: How Other LLMs Use It

      Although ChatML began as an OpenAI-specific format, its structure has proven so practical that many other large language models have adopted either direct compatibility or ChatML-inspired variations.

      Here’s how some popular LLMs approach it:

      1. Qwen (Alibaba Cloud)

      Qwen models (including Qwen2 and Qwen2.5) support ChatML-style formatting directly. They use the same <|im_start|> and <|im_end|> tokens with roles like system, user, and assistant. This makes it easy for developers to swap prompts between OpenAI models and Qwen without heavy modifications.

      2. Anthropic Claude

      Claude doesn’t use ChatML syntax literally, but it follows the same role-based conversation pattern — separating system instructions, user messages, and assistant replies. Developers often wrap Claude prompts in ChatML-like structures for internal consistency in multi-model applications.

      3. Mistral / Mixtral

      Some Mistral-based chat models on Hugging Face have fine-tunes that understand ChatML, especially in the open-source community. This helps standardize multi-turn conversations without reinventing formatting rules.

      4. Open-Source Fine-Tunes

      Many open-source LLaMA 2/3 fine-tunes — such as Vicuna, Alpaca, and WizardLM — adopt ChatML or similar message separation schemes. Even if the tokens differ, the concept of “role + message boundary” comes directly from ChatML’s influence.

      ChatML Compatibility Across LLMs

LLM / Model Family                              | ChatML Support                   | Notes on Usage
OpenAI GPT-3.5 / GPT-4                          | ✅ Full support                  | Native format, uses <|im_start|> / <|im_end|> tokens with roles (system, user, assistant).
Qwen / Qwen2 / Qwen2.5                          | ✅ Full support                  | ChatML-compatible; directly understands OpenAI-style role markup.
Anthropic Claude                                | ⚠️ Partial / Adapted             | Doesn’t use ChatML tokens but follows the same role/message separation; can be adapted easily.
Mistral / Mixtral Chat Models                   | ⚠️ Partial / Fine-tune dependent | Some fine-tunes understand ChatML, others require a different role separator format.
LLaMA-based Fine-Tunes (Vicuna, WizardLM, etc.) | ⚠️ Partial / Inspired            | Often trained with similar role-based prompts but token formats may differ.
Gemini (Google)                                 | ❌ No native support             | Uses its own structured prompt format, but conceptually similar in role separation.
Falcon Chat Models                              | ⚠️ Partial / Inspired            | Many fine-tunes replicate ChatML-style conversations for compatibility.

      Why This Matters for Developers

      By understanding ChatML’s role-based design, you can:

      • Switch between models with minimal prompt changes.
      • Standardize multi-model pipelines using one consistent conversation format.
      • Avoid prompt fragility when moving from prototyping to production.

      In short, ChatML isn’t just an OpenAI thing anymore — it’s becoming a de facto standard for structuring chatbot conversations across the LLM ecosystem.

      Summary

      ChatML might look like a simple markup, but it plays a huge role in making conversations with AI structured, predictable, and controllable. If you’re building an app that needs to work across multiple LLMs, it’s smart to create a prompt formatting layer in your code. This layer can output true ChatML for models that support it and convert it to a role-based equivalent for those that don’t.
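
As a concrete starting point, here is a minimal sketch of such a formatting layer. The function name and the two target styles (a raw ChatML string versus a plain role/content message list for APIs that accept structured messages) are illustrative choices, not a fixed standard.

def format_prompt(messages, target="chatml"):
    """Render one conversation for different model families.

    messages: list of {"role": ..., "content": ...} dicts
    target:   "chatml"   -> ChatML string for models that consume raw text
              "messages" -> pass-through list for chat APIs that accept role dicts
    """
    if target == "chatml":
        prompt = ""
        for m in messages:
            prompt += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
        prompt += "<|im_start|>assistant\n"  # leave the assistant turn open
        return prompt
    if target == "messages":
        return messages  # e.g. OpenAI- or Claude-style chat APIs
    raise ValueError(f"Unknown target format: {target}")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the capital of France?"}
]

print(format_prompt(conversation, target="chatml"))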