Retrieval Augmented Classification: Improving Text Classification with External Knowledge
When and how to best use LLMs as text classifiers.

LLMs vs Custom ML models for text classification
Let’s first explore the pros and cons of each of the two approaches to text classification.
Large language models as general-purpose classifiers:
Pros:
- High generalization ability given the vast pre-training corpus and reasoning abilities of the LLM.
- A single general-purpose LLM can handle multiple classification tasks without the need to deploy a model for each.
- As LLMs continue to improve, you can potentially enhance accuracy with minimal effort simply by adopting newer, more powerful models as they become available.
- The availability of most LLMs as managed services significantly reduces the deployment knowledge and effort required to get started.
- LLMs often outperform custom ML models in low-data scenarios where labeled data is limited or costly to obtain.
- LLMs generalize to multiple languages.
- LLMs can be cheaper at low or unpredictable prediction volumes, since you pay per token.
- Class definitions can be changed dynamically without retraining by simply modifying the prompts.
Cons:
- LLMs are prone to hallucinations.
- LLMs can be slow, or at least slower than small custom ML models.
- They require prompt engineering effort.
- High-throughput applications using LLMs-as-a-service may quickly encounter quota limitations.
- This approach becomes less effective with a very large number of potential classes due to context size constraints. Defining all the classes would consume a significant portion of the available and effective input context.
- LLMs usually have worse accuracy than custom models in the high data regime.
Custom Machine Learning models:
Pros:
- Efficient and fast.
- More flexible in architecture choice, training and serving method.
- Ability to add interpretability and uncertainty estimation aspects to the model.
- Higher accuracy in the high data regime.
- You keep control of your model and serving infrastructure.
Cons:
- Requires frequent retraining to adapt to new data or distribution shifts.
- May need significant amounts of labeled data.
- Limited generalization.
- Sensitive to out-of-domain vocabulary or formulations.
- Requires MLOps knowledge for deployment.
Bridging the gap between custom text classifiers and LLMs:
Let’s work on a way to keep the pros of using LLMs for classification while alleviating some of the cons. We will take inspiration from RAG and use a prompting technique called few-shot prompting.
Let’s define both:
RAG
Retrieval Augmented Generation is a popular method that augments the LLM context with external knowledge before asking a question. This reduces the likelihood of hallucination and improves the quality of the responses.
Few-shot prompting
In each classification task, we show the LLM examples of inputs and expected outputs as part of the prompt to help it understand the task.
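For illustration, a minimal few-shot classification prompt might look like this (the documents and categories below are made-up examples, not taken from the dataset used later):

```python
# A minimal few-shot classification prompt; examples and categories are illustrative only.
few_shot_prompt = """Classify the document into one of: Athlete, Company, PrimeMinister.

Document: Lionel Messi is an Argentine professional footballer who plays as a forward.
Category: Athlete

Document: Nokia Corporation is a Finnish multinational telecommunications company.
Category: Company

Document: Winston Churchill served as Prime Minister of the United Kingdom during World War II.
Category:"""

# The LLM is expected to complete the prompt with the category of the last document.
```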
Now, the main idea of this project is to mix both. We dynamically fetch the examples that are most similar to the text query to be classified and inject them into the prompt as few-shot examples. We also dynamically limit the set of candidate classes to those found among the K nearest neighbors. This frees up a significant number of tokens in the input context when working with a classification problem that has a large number of possible classes.
Here is how that would work in practice. Let’s go through the steps of getting this approach to run:
- Building a knowledge base of labeled input text / category pairs. This will be our source of external knowledge for the LLM. We will be using ChromaDB.
```python
from typing import List
from uuid import uuid4

import torch
from chromadb import PersistentClient
from chromadb.config import Settings
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.documents import Document
from tqdm import tqdm

from retrieval_augmented_classification.logger import logger


class DatasetVectorStore:
    """ChromaDB vector store for labeled text documents with SentenceTransformers embeddings."""

    def __init__(
        self,
        db_name: str = "retrieval_augmented_classification",  # Using db_name as collection name in Chroma
        collection_name: str = "classification_dataset",
        persist_directory: str = "chroma_db",  # Directory to persist ChromaDB
    ):
        self.db_name = db_name
        self.collection_name = collection_name
        self.persist_directory = persist_directory

        # Determine if CUDA is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            model_kwargs={"device": device},
            encode_kwargs={
                "device": device,
                "batch_size": 100,
            },  # Adjust batch_size as needed
        )

        # Initialize Chroma vector store
        self.client = PersistentClient(
            path=self.persist_directory, settings=Settings(anonymized_telemetry=False)
        )
        self.vector_store = Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )

    def add_documents(self, documents: List) -> None:
        """
        Add multiple documents to the vector store.

        Args:
            documents: List of dictionaries containing document data. Each dict needs a "text" key.
        """
        local_documents = []
        ids = []
        for doc_data in documents:
            if not doc_data.get("id"):
                doc_data["id"] = str(uuid4())
            local_documents.append(
                Document(
                    page_content=doc_data["text"],
                    metadata={k: v for k, v in doc_data.items() if k != "text"},
                )
            )
            ids.append(doc_data["id"])

        batch_size = 100  # Adjust batch size as needed
        for i in tqdm(range(0, len(documents), batch_size)):
            batch_docs = local_documents[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]
            # Chroma's add_documents doesn't directly support pre-defined IDs. Upsert instead.
            self._upsert_batch(batch_docs, batch_ids)

    def _upsert_batch(self, batch_docs: List[Document], batch_ids: List[str]):
        """Upsert a batch of documents into Chroma. If the ID exists, it updates; otherwise, it creates."""
        texts = [doc.page_content for doc in batch_docs]
        metadatas = [doc.metadata for doc in batch_docs]
        self.vector_store.add_texts(texts=texts, metadatas=metadatas, ids=batch_ids)
```
This class handles creating a collection and embedding each document before inserting it into the vector index. We use BAAI/bge-small-en-v1.5, but any embedding model would work, even those available as a service from Gemini, OpenAI, or Nebius.
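As a usage sketch, populating the store from a labeled dataset could look like the snippet below. The file name and column names ("text", "l3") are assumptions based on the DBPedia Classes dataset layout; the important point is that each record carries a "text" key and a "category" metadata key, which the classifier relies on later.

```python
import pandas as pd

# Hypothetical file/column names; adapt to your copy of the DBPedia Classes dataset.
df = pd.read_csv("DBPEDIA_train.csv")

documents = [
    # "text" becomes the page content; every other key ("category" here) is stored as metadata.
    {"text": row["text"], "category": row["l3"]}
    for _, row in df.iterrows()
]

store = DatasetVectorStore()
store.add_documents(documents)
```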
- Finding the K nearest neighbors for an input text
```python
def search(self, query: str, k: int = 5) -> List[Document]:
    """Search documents by semantic similarity."""
    results = self.vector_store.similarity_search(query, k=k)
    return results
```
This method returns the documents in the vector database that are most similar to our input.
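A quick usage sketch (the query text is illustrative):

```python
# Retrieve the 5 labeled documents most similar to an arbitrary query.
neighbors = store.search("Italian politician and statesman born in Mantua", k=5)
for doc in neighbors:
    # Each neighbor carries its label in metadata["category"], which the classifier uses below.
    print(doc.metadata["category"], "|", doc.page_content[:80])
```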
- Building the Retrieval Augmented Classifier
```python
from collections import Counter
from typing import Optional

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential

from retrieval_augmented_classification.vector_store import DatasetVectorStore


class PredictedCategories(BaseModel):
    """
    Pydantic model for the predicted category from the LLM.
    """

    reasoning: str = Field(description="Explain your reasoning")
    predicted_category: str = Field(description="Category")


class RAC:
    """
    A hybrid classifier combining K-Nearest Neighbors retrieval with an LLM for multi-class prediction.
    Finds the top K neighbors, uses the top few as few-shot context, and uses all neighbor categories
    as potential prediction candidates for the LLM.
    """

    def __init__(
        self,
        vector_store: DatasetVectorStore,
        llm_client,
        knn_k_search: int = 30,
        knn_k_few_shot: int = 5,
    ):
        """
        Initializes the classifier.

        Args:
            vector_store: An instance of DatasetVectorStore with a search method.
            llm_client: An instance of an LLM client capable of structured output.
            knn_k_search: The number of nearest neighbors to retrieve from the vector store.
            knn_k_few_shot: The number of top neighbors to use as few-shot examples for the LLM.
                Must be less than or equal to knn_k_search.
        """
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.knn_k_search = knn_k_search
        self.knn_k_few_shot = knn_k_few_shot

    @retry(
        stop=stop_after_attempt(3),  # Retry the LLM call a few times
        wait=wait_exponential(multiplier=1, min=2, max=5),  # Shorter waits for demo
    )
    def predict(self, document_text: str) -> Optional[str]:
        """
        Predicts the relevant category for a given document text using KNN retrieval and an LLM.

        Args:
            document_text: The text content of the document to classify.

        Returns:
            The predicted category
        """
        neighbors = self.vector_store.search(document_text, k=self.knn_k_search)

        neighbor_categories = []
        valid_neighbors = []  # Keep only neighbors that carry a category in their metadata
        for neighbor in neighbors:
            if (
                hasattr(neighbor, "metadata")
                and isinstance(neighbor.metadata, dict)
                and "category" in neighbor.metadata
            ):
                neighbor_categories.append(neighbor.metadata["category"])
                valid_neighbors.append(neighbor)
            else:
                pass  # Suppress warnings for cleaner demo output

        if not valid_neighbors:
            return None

        # Count how often each category appears among the neighbors and rank them by frequency
        category_counts = Counter(neighbor_categories)
        ranked_categories = [
            category for category, count in category_counts.most_common()
        ]
        if not ranked_categories:
            return None

        few_shot_neighbors = valid_neighbors[: self.knn_k_few_shot]

        messages = []
        system_prompt = f"""You are an expert multi-class classifier. Your task is to analyze the provided document text and assign the most relevant category from the list of allowed categories.
You MUST only return a category that is present in the following list: {ranked_categories}.
If none of the allowed categories is relevant, return an empty string.
Output your prediction as a JSON object matching the Pydantic schema: {PredictedCategories.model_json_schema()}.
"""
        messages.append(SystemMessage(content=system_prompt))

        # Build a synthetic conversation where the LLM "answered" correctly for each few-shot neighbor
        for neighbor in few_shot_neighbors:
            messages.append(
                HumanMessage(content=f"Document: {neighbor.page_content}")
            )
            expected_output_json = PredictedCategories(
                reasoning="Your reasoning here",
                predicted_category=neighbor.metadata["category"],
            ).model_dump_json()
            # Simulate the structure often used with tool calling/structured output
            messages.append(AIMessage(content=expected_output_json))

        # Final user message: the document text to classify
        messages.append(HumanMessage(content=f"Document: {document_text}"))

        # Configure the client for structured output with the Pydantic schema
        structured_client = self.llm_client.with_structured_output(PredictedCategories)
        llm_response: PredictedCategories = structured_client.invoke(messages)

        predicted_category = llm_response.predicted_category

        return predicted_category if predicted_category in ranked_categories else None
```
The first part of the code defines the structure of the output we expect from the LLM. The Pydantic class has two fields: the reasoning, used for chain-of-thought prompting (https://www.promptingguide.ai/techniques/cot), and the predicted category.
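For instance, a parsed response could look like this (the values are illustrative):

```python
PredictedCategories(
    reasoning="The document describes a head of government, which matches PrimeMinister best.",
    predicted_category="PrimeMinister",
)
```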
The predict method first finds the K nearest neighbors and uses the top few as few-shot examples, building a synthetic message history as if the LLM had returned the correct category for each of them. The query text is then injected as the final human message.
Finally, we check that the predicted category belongs to the candidate list and, if so, return it.
- Example prediction:
```python
# Assumes `store` is the DatasetVectorStore populated above and `llm_client` is a
# chat model that supports structured output (e.g. a LangChain chat model).
_rac = RAC(
    vector_store=store,
    llm_client=llm_client,
    knn_k_search=50,
    knn_k_few_shot=10,
)
print(
    f"Initialized rac with knn_k_search={_rac.knn_k_search}, knn_k_few_shot={_rac.knn_k_few_shot}."
)

text = """Ivanoe Bonomi [iˈvaːnoe boˈnɔːmi] (18 October 1873 – 20 April 1951) was an Italian politician and statesman before and after World War II. Bonomi was born in Mantua. He was elected to the Italian Chamber of Deputies in ...
"""

category = _rac.predict(text)
print(text)
print(category)

text = """Michel Rocard, né le 23 août 1930 à Courbevoie et mort le 2 juillet 2016 à Paris, est un haut fonctionnaire et ...
"""

category = _rac.predict(text)
print(text)
print(category)
```
Both inputs return the prediction “PrimeMinister”, even though the second example is in French while the training dataset is fully in English. This illustrates the generalization ability of this approach, even across languages.
- Evaluation:
We use the DBPedia Classes dataset’s l3 categories (https://www.kaggle.com/datasets/danofer/dbpedia-classes, license CC BY-SA 3.0) for our evaluation. This dataset has more than 200 categories and 240,000 training samples.
We benchmark the Retrieval Augmented Classification approach against a simple KNN classifier with majority vote and an LLM-only classifier, and obtain the following results on the DBPedia dataset’s l3 categories:
| Classifier | Accuracy | Average latency | Throughput (multi-threaded) |
|---|---|---|---|
| KNN classifier | 87% | 24 ms | 108 predictions/s |
| LLM-only classifier | 88% | ~600 ms | 47 predictions/s |
| RAC | 96% | ~1 s | 27 predictions/s |
For reference, the best accuracy I found on Kaggle notebooks for this dataset’s l3 level was around 94% using custom ML models.
We note that combining a KNN search with the reasoning abilities of an LLM allows us to gain 9 accuracy points over the KNN classifier, at the cost of lower throughput and higher latency.
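For completeness, the KNN-with-majority-vote baseline used in the comparison can be sketched as follows (a minimal version built on the same DatasetVectorStore, not necessarily the exact benchmark code):

```python
from collections import Counter
from typing import Optional


def knn_majority_vote(store: DatasetVectorStore, text: str, k: int = 30) -> Optional[str]:
    """Predict the most frequent category among the k nearest neighbors."""
    neighbors = store.search(text, k=k)
    categories = [
        doc.metadata["category"]
        for doc in neighbors
        if isinstance(doc.metadata, dict) and "category" in doc.metadata
    ]
    if not categories:
        return None
    return Counter(categories).most_common(1)[0][0]
```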
Conclusion
In this project we built a text classifier that leverages “retrieval” to boost the ability of an LLM to find the correct category of the input content. This approach offers several advantages over traditional ML text classifiers: the ability to change the training dataset dynamically without retraining, higher generalization ability thanks to the reasoning and general knowledge of LLMs, easier deployment when using managed LLM services compared to custom ML models, and the capability to handle multiple classification tasks with a single base LLM. This comes at the cost of higher latency, lower throughput, and a risk of LLM vendor lock-in.
This method should not be your first go-to when working on a classification task, but it is still useful to have in your toolbox when your application can benefit from the flexibility of not having to retrain a classifier every time the data changes, or when working with a small amount of labeled data. It can also help you get a classification service up and running very quickly when a deadline is looming.