PyLate Late Interaction Retrieval
PyLate makes ColBERT-style late interaction retrieval practical in Python. It is more precise than bi-encoders but costlier to store. Here is when it is worth it.
There is a gap in how most retrieval systems work today. Bi-encoders are fast and cheap, but they lose a lot of semantic detail by compressing each text into a single vector. Cross-encoders are accurate but too slow to run over thousands of documents per query. Late interaction sits between the two, and PyLate is the cleanest way to use it in Python.
The basic idea behind late interaction is that you should not have to trade precision for speed at encoding time. You encode your documents once, keep the individual token embeddings, and do the fine-grained comparison at scoring time. That single change addresses much of what goes wrong with single-vector retrieval.
The Problem With Single Vector Retrieval
When a bi-encoder processes “the patient has no history of cardiac disease”, it produces one vector for the whole sentence. That vector captures the general topic, but the word “no” barely changes it. Run a query like “patients with cardiac conditions” and the similarity score is likely to be high, even though the document explicitly says the opposite.
This is not a failure of the model. It is a consequence of compressing a sentence into a fixed-size vector. You are asking 768 numbers (a common embedding size for BERT-base encoders) to represent everything about a piece of text, and some of the nuance does not survive the compression.
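The toy sketch below makes the dilution concrete. It assumes made-up token vectors that are mutually orthogonal (one-hot), and it uses mean pooling as a stand-in for a single-vector encoder. Real contextual embeddings are not orthogonal, so treat the numbers as illustrative only. The eight "content" tokens and the extra "no" token are hypothetical.

```python
import torch
import torch.nn.functional as F

dim = 16
# Toy, hypothetical token vectors: one-hot, so every token is orthogonal to the others.
tokens = torch.eye(dim)

content = tokens[:8]    # stand-ins for "patient has history of cardiac disease ..."
no_token = tokens[8:9]  # stand-in for the word "no"

doc_without_no = content
doc_with_no = torch.cat([content, no_token])


def mean_pool(token_embs: torch.Tensor) -> torch.Tensor:
    """Single-vector representation: average the token vectors."""
    return token_embs.mean(dim=0)


# How similar are the two single-vector representations?
pooled_cos = F.cosine_similarity(
    mean_pool(doc_without_no), mean_pool(doc_with_no), dim=0
).item()
print(f"pooled cosine: {pooled_cos:.3f}")  # 0.943, i.e. sqrt(8/9)

# Token-level view: does any document token match the "no" token?
best_match_without = (no_token @ doc_without_no.T).max().item()
best_match_with = (no_token @ doc_with_no.T).max().item()
print(best_match_without, best_match_with)  # 0.0 1.0
```

In this toy case the pooled vectors stay about 94% similar even though one of them carries the negation. The token-level comparison, by contrast, flips from 0 to 1.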
Cross-encoders solve this by processing the query and document together. The model attends across both at once, so negation and precise phrasing actually matter. The problem is that you cannot precompute anything: every query requires a forward pass over every candidate document, which is not practical at scale.
Late interaction keeps the main advantage of bi-encoders, which is that documents are encoded once, ahead of time. It also preserves token-level detail, which recovers much of the precision that cross-encoders get. You pay for it in storage, and PyLate is what makes the tradeoff manageable.
How Late Interaction Actually Works
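The ColBERT model that PyLate builds on produces one embedding per token rather than one embedding per document. A 40-token document becomes roughly 40 vectors. Queries are handled slightly differently: ColBERT pads every query to a fixed length (32 tokens in the original ColBERTv2 setup), so a 12-token query still produces a full set of 32 vectors.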
Scoring works through a MaxSim operation. For each query token, you find its maximum cosine similarity to any token in the document, then sum those maximum scores across all query tokens.
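Here is the same scoring rule as a formula. Write $\mathbf{q}_i$ for the $i$-th query token embedding and $\mathbf{d}_j$ for the $j$-th document token embedding. ColBERT L2-normalizes both, so each dot product is a cosine similarity:

$$
S(q, d) = \sum_{i=1}^{|q|} \max_{1 \le j \le |d|} \mathbf{q}_i^{\top} \mathbf{d}_j
$$

Here $|q|$ counts every query position, including any padding positions added to reach the fixed query length.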
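```python
import torch
import torch.nn.functional as F

def maxsim(query_embs, doc_embs):
    # query_embs: (num_query_tokens, dim)
    # doc_embs:   (num_doc_tokens, dim)
    # Normalize so each dot product is a cosine similarity
    # (ColBERT embeddings are already normalized, so this changes nothing for them)
    query_embs = F.normalize(query_embs, dim=-1)
    doc_embs = F.normalize(doc_embs, dim=-1)
    sims = torch.matmul(query_embs, doc_embs.T)  # (q_tokens, d_tokens)
    # Best document match for each query token, summed over query tokens
    return sims.max(dim=1).values.sum()
```
What this means in practice: if your query contains the token “not”, MaxSim looks for its best match across every token in the document. ColBERT's token embeddings are contextual, so a document where “not” appears in a matching context scores higher than one without it. The scoring step is sensitive to individual words and their context in a way that single-vector similarity cannot be.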
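In practice you score one query against many documents of different lengths at once. The minimal batched sketch below assumes documents are zero-padded to a common length and that a boolean mask marks the real tokens. The function name and mask convention are illustrative, not PyLate's API.

```python
import torch


def batched_maxsim(
    query_embs: torch.Tensor,  # (q_tokens, dim), assumed L2-normalized
    doc_embs: torch.Tensor,    # (num_docs, max_d_tokens, dim), zero-padded
    doc_mask: torch.Tensor,    # (num_docs, max_d_tokens), True for real tokens
) -> torch.Tensor:
    # Similarity of every query token to every document token: (num_docs, q_tokens, max_d_tokens)
    sims = torch.einsum("qh,ndh->nqd", query_embs, doc_embs)
    # Padding positions must never win the max, so push them to -inf.
    sims = sims.masked_fill(~doc_mask[:, None, :], float("-inf"))
    # Max over document tokens, then sum over query tokens: (num_docs,)
    return sims.max(dim=2).values.sum(dim=1)


# Toy example: 2 query tokens; documents of length 3 and 1 (padded to 3).
query = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
docs = torch.tensor([
    [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]],
    [[0.6, 0.8], [0.0, 0.0], [0.0, 0.0]],
])
mask = torch.tensor([[True, True, True], [True, False, False]])

scores = batched_maxsim(query, docs, mask)
print(scores)  # tensor([2.0000, 1.4000])
```

The mask matters whenever real similarities can be negative. Without it, a zero padding vector (similarity 0) could beat every genuine token.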
Getting Started
```bash
pip install pylate
```
PyLate is built on top of Sentence Transformers, so the API should feel familiar.
```python
from pylate import models

model = models.ColBERT(
    model_name_or_path="lightonai/colbertv2.0",
)

# Encoding documents and queries returns token-level embeddings
query_embeddings = model.encode(
    ["what causes transformer overfitting"],
    is_query=True,
)

doc_embeddings = model.encode(
    [
        "Overfitting in transformers is caused by insufficient regularization.",
        "Transformer architecture uses self attention mechanisms.",
    ],
    is_query=False,
)

print(query_embeddings[0].shape)  # (query_length, 128): fixed by query padding
print(doc_embeddings[0].shape)    # (num_doc_tokens, 128)
```
Note that queries and documents are encoded differently, which is why `is_query` matters. PyLate handles ColBERT's query augmentation internally: each query is padded to a fixed length with [MASK] tokens. In the original ColBERT work these extra positions act as a learned form of query expansion. The fixed length also gives query tensors a uniform shape for batching.
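The padding step itself is simple. The sketch below works on token IDs. The special-token IDs are placeholders modeled on BERT-style vocabularies, so check your tokenizer for the real values; the 32-token length mirrors the ColBERTv2 default. The function name is hypothetical.

```python
from typing import List

# Placeholder token IDs for illustration only; real values come from the tokenizer.
CLS_ID, QUERY_MARKER_ID, SEP_ID, MASK_ID = 101, 1, 102, 103
QUERY_LENGTH = 32  # mirrors the ColBERTv2 default query length


def augment_query(token_ids: List[int], query_length: int = QUERY_LENGTH) -> List[int]:
    """Build a fixed-length ColBERT-style query: [CLS] [Q] tokens [SEP], then [MASK] padding."""
    # Reserve room for [CLS], [Q] and [SEP]; truncate overly long queries.
    body = token_ids[: query_length - 3]
    ids = [CLS_ID, QUERY_MARKER_ID, *body, SEP_ID]
    # Fill the remaining positions with [MASK] instead of ordinary padding.
    return ids + [MASK_ID] * (query_length - len(ids))


short = augment_query([2054, 5320, 10938])  # a hypothetical 3-token query
print(len(short), short[:6])  # 32 [101, 1, 2054, 5320, 10938, 102]
```

Every query yields the same number of positions, so query embeddings stack into one batch tensor without ragged shapes.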
Retrieval with PLAID
Storing and searching raw token embeddings is not practical for large corpora without an index. PyLate integrates with PLAID, ColBERT’s indexing and retrieval engine.
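```python
from pylate import indexes, models, retrieve

model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")

index = indexes.PLAID(
    index_folder="./my_index",
    index_name="docs",
    override=True,  # overwrite an existing index with the same name
)

# Encode and index your documents
documents_ids = ["doc1", "doc2", "doc3"]
documents = ["Document one text.", "Document two text.", "Document three text."]
documents_embeddings = model.encode(documents, batch_size=32, is_query=False)

index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)

retriever = retrieve.ColBERT(index=index)

# Retrieve: queries are encoded with is_query=True
queries_embeddings = model.encode(["what is document one about"], is_query=True)
results = retriever.retrieve(queries_embeddings=queries_embeddings, k=5)
```
PLAID works in stages. It first shortlists candidates cheaply using centroid (cluster) scores and prunes that shortlist further. Only then does it decompress the surviving documents' embeddings for full MaxSim scoring. In practice this gives you most of the accuracy of exhaustive search at a fraction of the compute cost.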
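The two-stage idea can be sketched in plain PyTorch. This version is deliberately simplified:

- Centroids are sampled document tokens rather than k-means output.
- There is no residual compression or centroid pruning.
- Every name is hypothetical.

It shows the shape of the approach, not PLAID's actual implementation.

```python
import torch
import torch.nn.functional as F


def build_index(doc_embs, num_centroids, seed=0):
    """doc_embs: list of (num_tokens, dim) tensors, assumed L2-normalized."""
    all_tokens = torch.cat(doc_embs)
    # Which document each token belongs to.
    owner = torch.cat([torch.full((len(d),), i, dtype=torch.long) for i, d in enumerate(doc_embs)])
    # Simplification: sample token vectors as centroids instead of running k-means.
    gen = torch.Generator().manual_seed(seed)
    centroids = all_tokens[torch.randperm(len(all_tokens), generator=gen)[:num_centroids]]
    # Assign every token to its nearest centroid and build inverted lists: centroid -> doc ids.
    assignment = (all_tokens @ centroids.T).argmax(dim=1)
    inverted = {c: set(owner[assignment == c].tolist()) for c in range(len(centroids))}
    return centroids, inverted


def maxsim(query_embs, doc_embs):
    return (query_embs @ doc_embs.T).max(dim=1).values.sum()


def search(query_embs, doc_embs, centroids, inverted, nprobe=2, k=3):
    # Stage 1: each query token probes its nprobe closest centroids; union the documents found.
    nprobe = min(nprobe, len(centroids))
    probed = (query_embs @ centroids.T).topk(nprobe, dim=1).indices
    candidates = set()
    for c in probed.flatten().tolist():
        candidates |= inverted[c]
    # Stage 2: full MaxSim, but only on the shortlist.
    scored = [(doc_id, maxsim(query_embs, doc_embs[doc_id]).item()) for doc_id in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy corpus of random normalized "token embeddings".
torch.manual_seed(0)
docs = [F.normalize(torch.randn(n, 8), dim=-1) for n in (5, 7, 4, 6, 9)]
query = F.normalize(torch.randn(3, 8), dim=-1)

centroids, inverted = build_index(docs, num_centroids=4)
print(search(query, docs, centroids, inverted, nprobe=2, k=3))
```

If `nprobe` equals the number of centroids, every document becomes a candidate and the result matches exhaustive MaxSim. Smaller values trade recall for speed.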
Fine Tuning on Your Own Data
Many of the ColBERT checkpoints on the Hugging Face Hub, including lightonai/colbertv2.0, are trained on MS MARCO, a general web search dataset. For domain-specific work, fine-tuning on your own data can make a noticeable difference.
PyLate's contrastive training uses triplet data: an anchor query, a positive document, and a hard negative document.
```python
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from pylate import losses, models, utils

model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")

# Expected format, one JSON object per line:
# {"query": "...", "positive": "...", "negative": "..."}
train_dataset = load_dataset("json", data_files="train_triplets.jsonl", split="train")

loss = losses.Contrastive(model=model)

args = SentenceTransformerTrainingArguments(
    output_dir="./colbert_finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=3e-6,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    # PyLate's collator applies ColBERT-specific tokenization to queries and documents
    data_collator=utils.ColBERTCollator(model.tokenize),
)

trainer.train()
```
The learning rate is worth paying attention to, because ColBERT models are sensitive to it. Values around 3e-6 are a reasonable starting point, but validate on a held-out set from your own domain before going to production.
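To see what training on triplets optimizes, here is a simplified contrastive objective over MaxSim scores. For each query it applies cross-entropy over the positive and the hard negative, with the positive as the target. This is an illustrative sketch, not PyLate's `losses.Contrastive` implementation, which also uses the other documents in the batch as negatives. The function names and toy data are assumptions.

```python
import torch
import torch.nn.functional as F


def maxsim(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    # Sum over query tokens of each token's best document-token similarity
    return (query_embs @ doc_embs.T).max(dim=1).values.sum()


def triplet_contrastive_loss(queries, positives, negatives, temperature: float = 1.0):
    """Each argument is a list of (num_tokens, dim) tensors, one entry per triplet."""
    # (batch, 2) logits: column 0 is the positive, column 1 the hard negative
    logits = torch.stack([
        torch.stack([maxsim(q, p), maxsim(q, n)])
        for q, p, n in zip(queries, positives, negatives)
    ]) / temperature
    # The positive is always the correct "class"
    targets = torch.zeros(len(queries), dtype=torch.long)
    return F.cross_entropy(logits, targets)


def random_tokens(num_tokens: int, dim: int = 16) -> torch.Tensor:
    # Toy, normalized stand-ins for ColBERT token embeddings
    return F.normalize(torch.randn(num_tokens, dim), dim=-1)


torch.manual_seed(0)
queries = [random_tokens(4) for _ in range(3)]
positives = [q.clone() for q in queries]  # perfect matches, for illustration
negatives = [random_tokens(6) for _ in range(3)]

loss = triplet_contrastive_loss(queries, positives, negatives)
print(loss.item())
```

Minimizing this loss pushes each query's MaxSim score toward its positive and away from its hard negative. This is why hard negatives that look superficially relevant are the useful ones.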
The Storage Cost Is Real
Here is the tradeoff that does not get mentioned enough. A bi-encoder stores one 768-dimensional vector per document. A ColBERT model stores one 128-dimensional vector per token of every document.
For a document averaging 200 tokens, with embeddings stored as uncompressed float32:
| Model type | Vectors per document | Storage per 1M docs (float32) |
|---|---|---|
| Bi-encoder (768d) | 1 | ~3 GB |
| ColBERT (128d, 200 tokens) | 200 | ~100 GB |
That is roughly a 33x increase in storage (25,600 values per document versus 768). At a million documents that is manageable. At a billion documents you are looking at around 100 TB of index storage, and the engineering problem becomes as much about storage systems as about the model.
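The table is straightforward arithmetic. The small helper below reproduces it, assuming uncompressed float32 (4-byte) values and decimal gigabytes. The function name is illustrative.

```python
BYTES_PER_FLOAT32 = 4


def index_size_bytes(num_docs: int, vectors_per_doc: int, dim: int,
                     bytes_per_value: int = BYTES_PER_FLOAT32) -> int:
    """Raw, uncompressed storage for a vector index."""
    return num_docs * vectors_per_doc * dim * bytes_per_value


bi_encoder = index_size_bytes(1_000_000, vectors_per_doc=1, dim=768)
colbert = index_size_bytes(1_000_000, vectors_per_doc=200, dim=128)

print(f"bi-encoder: {bi_encoder / 1e9:.1f} GB")    # 3.1 GB
print(f"ColBERT:    {colbert / 1e9:.1f} GB")       # 102.4 GB
print(f"ratio:      {colbert / bi_encoder:.1f}x")  # 33.3x
print(f"ColBERT at 1B docs: {index_size_bytes(10**9, 200, 128) / 1e12:.1f} TB")  # 102.4 TB
```

With float16 storage both figures halve, but the ratio stays the same.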
PLAID reduces the effective storage by storing each token vector as a centroid ID plus a quantized residual. The cost still scales with the number of tokens rather than the number of documents, so if you are running lean infrastructure, this matters.
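Here is a rough estimate of how much compression helps, in the style of ColBERTv2's residual compression: each token vector is stored as a centroid ID plus a residual quantized to a few bits per dimension. The 2 bits per dimension and the 4-byte centroid ID are assumptions for illustration. Actual PLAID index sizes depend on its configuration and on extra overheads this ignores.

```python
def compressed_bytes_per_token(dim: int = 128, bits_per_dim: int = 2,
                               centroid_id_bytes: int = 4) -> float:
    """Approximate storage for one token: quantized residual plus a centroid ID."""
    return dim * bits_per_dim / 8 + centroid_id_bytes


def compressed_index_gb(num_docs: int, tokens_per_doc: int, **kwargs) -> float:
    """Approximate compressed index size in decimal gigabytes."""
    return num_docs * tokens_per_doc * compressed_bytes_per_token(**kwargs) / 1e9


print(compressed_bytes_per_token())  # 36.0 bytes, versus 512 for a float32 token vector
print(f"{compressed_index_gb(1_000_000, 200):.1f} GB per 1M docs")     # 7.2 GB
print(f"{compressed_index_gb(10**9, 200) / 1000:.1f} TB per 1B docs")  # 7.2 TB
```

Under these assumptions, the compressed ColBERT index is still larger than the uncompressed 768d bi-encoder index in the table (about 3 GB per million documents).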
When to Actually Use It
Late interaction is worth the complexity in situations where retrieval precision directly affects downstream quality and storage is not a blocker.
Cases where it tends to pay off:
- Legal or medical document retrieval, where negation and precise phrasing can change the meaning completely.
- Technical support, where users phrase questions differently from how the documentation is written.
- Any domain where fine-tuning on specific terminology makes a meaningful difference.
Cases where it is overkill: general content recommendation, semantic similarity between short texts, or any use case where a well-tuned bi-encoder already retrieves the right documents. If your current pipeline retrieves the right answer in the top five results most of the time, late interaction is unlikely to move the needle enough to justify the storage cost.
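Checking "the right answer in the top five results most of the time" takes a labeled sample of queries and one metric. Below is a minimal sketch of hit rate at k: the share of queries with at least one relevant document in the top k. The dictionary layout and the sample data are hypothetical.

```python
from typing import Dict, List, Set


def hit_rate_at_k(rankings: Dict[str, List[str]],
                  relevant: Dict[str, Set[str]], k: int = 5) -> float:
    """Fraction of queries whose top-k ranked doc IDs contain at least one relevant doc."""
    hits = sum(
        1 for query_id, ranked in rankings.items()
        if relevant.get(query_id, set()) & set(ranked[:k])
    )
    return hits / len(rankings) if rankings else 0.0


# Hypothetical rankings from an existing pipeline, plus hand-labeled relevance.
rankings = {
    "q1": ["d3", "d1", "d7", "d2", "d9", "d4"],
    "q2": ["d5", "d8", "d2", "d6", "d1", "d3"],
    "q3": ["d4", "d2", "d1", "d9", "d8", "d7"],
}
relevant = {"q1": {"d1"}, "q2": {"d3"}, "q3": {"d4", "d6"}}

print(hit_rate_at_k(rankings, relevant, k=5))  # 2 of 3 queries hit
```

In this sample, q2's relevant document sits at rank six, which is exactly the kind of near miss a better reranking stage could fix.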
A common practical pattern is a bi-encoder for first-pass retrieval at scale, followed by a cross-encoder reranker on the top 50 results. PyLate gives you a third option that sits between those two stages. Late interaction can re-score a wider bi-encoder shortlist, so fewer relevant documents fall outside the 50 the cross-encoder ever sees. That makes it genuinely useful when the existing two-stage pipeline still misses too many relevant documents.
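The following minimal sketch shows where a late-interaction stage slots into that pipeline. The bi-encoder vectors, token embeddings, and `cross_encoder_score` callable are all hypothetical stand-ins. In a real system the token embeddings would come from a PyLate model, stage one would use an approximate nearest-neighbor index, and the cutoffs are tuning choices.

```python
from typing import Callable, List, Sequence

import torch


def maxsim(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> float:
    return (query_embs @ doc_embs.T).max(dim=1).values.sum().item()


def cascade(
    query_vec: torch.Tensor,                      # (dim,) bi-encoder query vector
    doc_vecs: torch.Tensor,                       # (num_docs, dim) bi-encoder document vectors
    query_tokens: torch.Tensor,                   # (q_tokens, token_dim) late-interaction query embeddings
    doc_tokens: Sequence[torch.Tensor],           # per-document (d_tokens, token_dim) embeddings
    cross_encoder_score: Callable[[int], float],  # hypothetical: scores one doc id for this query
    first_k: int = 1000,
    middle_k: int = 50,
    final_k: int = 10,
) -> List[int]:
    # Stage 1: cheap single-vector scoring over the whole corpus, keep a wide shortlist.
    first = (doc_vecs @ query_vec).topk(min(first_k, len(doc_vecs))).indices.tolist()
    # Stage 2: late-interaction rerank of the wide shortlist.
    middle = sorted(first, key=lambda i: maxsim(query_tokens, doc_tokens[i]), reverse=True)[:middle_k]
    # Stage 3: expensive cross-encoder only on the narrow shortlist.
    return sorted(middle, key=cross_encoder_score, reverse=True)[:final_k]


# Toy data; random tensors stand in for real embeddings and cross-encoder scores.
torch.manual_seed(0)
num_docs = 200
doc_vecs = torch.randn(num_docs, 32)
doc_tokens = [torch.randn(20, 16) for _ in range(num_docs)]
query_vec = torch.randn(32)
query_tokens = torch.randn(8, 16)
fake_scores = torch.rand(num_docs)

top = cascade(query_vec, doc_vecs, query_tokens, doc_tokens,
              cross_encoder_score=lambda i: fake_scores[i].item(),
              first_k=100, middle_k=20, final_k=5)
print(top)
```

The cross-encoder still runs on only `middle_k` documents per query. The late-interaction stage decides which of the wider shortlist reach it.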
What PyLate Gets Right
A year ago, using ColBERT in Python meant either working directly with the original Stanford research code or stitching together pieces from several libraries. PyLate wraps all of it cleanly: model loading, query and document encoding, PLAID indexing, retrieval, and training through the Sentence Transformers trainer interface.
If you want to try late interaction retrieval without spending a week reading research code, PyLate is the right starting point. The storage cost and the more involved indexing pipeline are real constraints, but within those constraints it works well.