RAGs to Riches: Using RAG for Solving Zero-Shot Multi-label Classification of Documents ¶
Introduction ¶
Retrieval Augmented Generation (RAG) is one of the most popular techniques for applying large language models (LLMs) and offers an easy win for companies looking to get on board with the latest developments in artificial intelligence (AI). However, apart from productivity tools like chatbots and Q&A services, RAG has struggled to produce use cases that drive value in the same way as traditional machine learning and natural language processing. While Q&A services do improve companies' efficiency by putting more information at employees' fingertips, that value can be hard to quantify.
In this project, I offer a RAG-powered solution that drives value in a more direct and measurable way, by acting as a zero-shot multi-label classifier.
Overview of Retrieval Augmented Generation (RAG) ¶
RAG is, at its core, just a bit of clever prompting where you ask a question of an LLM while simultaneously providing the LLM with the information it needs to answer the question.
You can imagine handing a person a history textbook and asking them "What year was George Washington born?" The person would be able to look the answer up in the textbook, and this task would be even easier for them if we only provided them with the chunks of text from the book that spoke about George Washington. This is exactly what a RAG system does.
The first step of RAG is creating a knowledge bank, or vector store. In the context of the example above, this would mean taking the history book and creating a searchable database from its text. This is achieved by taking chunks of the text, vectorizing the chunks with the encoding side of a transformer model, and storing the embeddings in a vector database, where records can be retrieved based on vector similarity.
Next, when a user asks a question about the source material, the question is embedded using the same model used to embed the source material, and a similarity search is executed against the knowledge bank, retrieving the chunks of text that pertain to the question that's being asked.
The question, along with the context retrieved from the knowledge bank, are then sent to an LLM, which responds with the answer to the question based on the provided context.
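To make the retrieval step concrete, here's a deliberately simplified sketch of the idea. It stands in for the transformer embeddings with a toy bag-of-words vectorizer, purely for illustration; the real pipeline later in this post uses an Ollama embedding model and FAISS.

import numpy as np

def embed(text, vocab):
    """Toy stand-in for a transformer embedding: a bag-of-words count vector."""
    words = text.lower().split()
    return np.array([words.count(w) for w in vocab], dtype=float)

def retrieve(question, chunks, k=2):
    """Return the k chunks most similar to the question by cosine similarity."""
    vocab = sorted({w for chunk in chunks for w in chunk.lower().split()})
    chunk_vecs = [embed(chunk, vocab) for chunk in chunks]
    q_vec = embed(question, vocab)

    def cosine(a, b):
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(a @ b / denom) if denom else 0.0

    scores = [cosine(q_vec, v) for v in chunk_vecs]
    ranked = sorted(range(len(chunks)), key=lambda i: scores[i], reverse=True)
    return [chunks[i] for i in ranked[:k]]

# The retrieved chunks would then be pasted into the LLM prompt as context.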
RAG as a Document Classifier ¶
One of the most popular uses of RAG within industry is providing users with a way of engaging with large repositories of text information: think along the lines of giving employees the ability to ask questions about HR documents, benefits, or standard operating procedures. You could also envision legal departments being able to engage with long and complex documents. Typically, knowledge banks are composed of corpora of many documents, allowing answers to be compiled from diverse sets of sources, but by creating smaller knowledge banks from single documents and asking specific and intentional questions, one can transform RAG into a robust zero-shot classifier.
Imagine a simple classification use case where you want to classify a book as "Romance" or "Sci-Fi." You could create a knowledge bank from the book and then ask a prompt like, "Which genre best applies to this book: Romance or Sci-Fi?" Now, this wouldn't be the best use of LLM resources, since it could easily be done with classical ML (or even by just looking at the cover of the book). But imagine you have a set of questions that require deeper knowledge of the book's text: Is there a plot twist? Is it appropriate for teens? Does it contain any illicit themes? Questions like these would require someone to spend hours reading the book in full to ascertain the answers, but a RAG application could answer them in less than a minute (probably).
Now imagine a similar use case applied to a company's contracts. Does the contract contain a lease agreement? Is the payment structure based on progress milestones or on payment dates? Does it contain any rebates or warranties? These are questions a legal analyst could spend a significant amount of time trying to answer by reading and searching through the contract. Using RAG to address such a use case could reduce this time to essentially nothing. And with a RAG system's ability to provide the sources from which answers were obtained, the legal analyst would have the ability to double-check any answer the LLM provided.
Each of these questions could be framed as a binary classification problem solvable, zero-shot, using RAG.
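Concretely, each yes/no question pairs with one binary label, so the contract example above could be framed as something like this (the labels and wording here are purely illustrative):

# Illustrative only: each binary label maps to a yes/no question about the contract.
contract_label_questions = {
    'Lease Agreement': 'Does the contract contain a lease agreement?',
    'Milestone-based Payments': 'Is the payment structure based on progress milestones?',
    'Rebates or Warranties': 'Does the contract contain any rebates or warranties?',
}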
About the Data ¶
As the data set for this project, we'll use the Jupyter Notebooks of my personal data science projects, rendered as HTML. The task is multi-label classification: assigning to each project, from a list of possible concepts, the machine learning concepts that relate to it.
Packages and Setup ¶
This project will mostly depend on LangChain for orchestrating the document processing, the creation of the knowledge bank, and the prompting. To keep things cheap, we'll use open-source models served locally through Ollama (Llama 3) for our embeddings and LLM, and FAISS as an in-memory vector database.
import os
import multiprocessing
import time
# LLMs
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
# Text prep & vector store
from langchain_community.document_loaders import BSHTMLLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
# RAG
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
Basic RAG Pipeline ¶
Let's start by building a simple RAG pipeline that allows questions to be asked of my Animal Crossing time-series clustering project. We'll read in the HTML document, split it into chunks, and create the knowledge bank.
First, let's set the path for where the project HTML is stored and load in the LLMs.
# projects directory
DOC_PATH = '/home/nastory/repos/nigelstorydata_flask/nigelstorydata/templates/'
# load Ollama models
llm = Ollama(model='llama3', temperature=0)
embedding = OllamaEmbeddings(model='llama3')
Let's preview the text.
test_file = os.path.join(DOC_PATH, 'acnh.html')
with open(test_file, 'r') as f:
txt = f.read()
print(txt[:200])
<title>Animal Crossing New Horizons, the Stalk Market</title> <script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.1.10/require.min.js"><
We'll use LangChain's BeautifulSoup HTML loader to read in and parse the project file. Then we'll split the project text into 1,000-character chunks with a 100-character overlap between chunks, and use the Ollama embedding model to create our FAISS vector database.
loader = BSHTMLLoader(test_file)
docs = loader.load()
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=100
)
documents = splitter.split_documents(docs)
vector_store = FAISS.from_documents(documents, embedding)
Now we'll create the retriever, which will allow us to run similarity searches against our vector database.
retriever = vector_store.as_retriever()
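If we want to sanity-check retrieval before asking the LLM anything, we can query the retriever directly and peek at the chunks it returns (the query below is just an example):

# Preview the chunks the retriever finds for a sample query (illustrative only).
sample_docs = retriever.invoke("What clustering techniques are used in this project?")
for doc in sample_docs:
    print(doc.page_content[:120], '...')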
Next we'll create our prompt template. This is perhaps the most important step in the RAG process, as it allows us to determine the behavior of our LLM and pass through the questions and the context chunks retrieved by the retriever. Through our prompting, we can determine what kind of classifier our application will behave as later on. For now, we'll just have it behave as a Q&A service.
prompt_template = """
You are a question answer service. Given the provided context, which comes
from a machine learning project composed in Jupyter Notebooks by Nigel Story,
answer the question below.
<context>
{context}
</context>
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(prompt_template)
Now we'll compose our retriever, prompt, and LLM model into a pipeline into which we can pass our questions.
rag_chain = (
{'context': retriever, 'question': RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
Let's ask our model a question!
res = rag_chain.invoke("what are some of the ML techniques used in this project?")
print(res)
Based on the provided context from the Jupyter Notebooks, some Machine Learning (ML) techniques used in this project include:

1. **Random Forest Classification**: This is a popular ensemble learning method that combines multiple decision trees to improve predictive accuracy.
2. **Artificial Neural Network (ANN) classifier**: A type of feedforward neural network designed for classification tasks.

Additionally, the context mentions the use of:

3. **LabelEncoder()**: A pre-processing technique used to convert categorical labels into numerical representations.
4. **to_categorical()**: A function used to convert integer labels into one-hot encoded vectors.

These ML techniques are likely used in combination with other tools and libraries (e.g., scikit-learn, TensorFlow) to analyze and predict the trends of turnip prices in Animal Crossing New Horizons.
From this answer, one can definitely see the benefits of RAG. Without having to even look at this project, you could get a good idea of what it's about and which ML techniques were used.
Now let's formalize this pipeline and turn it into a zero-shot classifier.
RAG as a Zero-Shot Document Multi-label Classifier ¶
Using the same process as above, let's rework the prompting to create a multi-label classifier rather than a Q&A service. Multi-label classification is a more challenging use case than typical classification tasks, requiring more data and more advanced ML techniques to develop models. But piggy-backing off of pre-trained LLMs allows us to get reliable multi-label predictions without any additional training or even any additional data, i.e., zero-shot.
In the code cell below, we'll define a RAG pipeline that will allow a user to load in an HTML document along with a list of possible labels to be assigned to the document, and the LLM will decide which labels best apply to the document. For my projects, I'll try to label them according to machine learning concepts that are used within the projects.
def load_html_document(file_path):
"""Load in a project HTML document.
"""
loader = BSHTMLLoader(file_path)
doc = loader.load()
return doc
def chunk_document(doc, chunk_size=1000, chunk_overlap=100):
"""Split a document into text chunks for embedding.
"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
documents = splitter.split_documents(doc)
return documents
def rag_retriever(documents, embedding):
"""Create the vector database and the retriever
to execute RAG.
"""
vector_store = FAISS.from_documents(documents, embedding)
retriever = vector_store.as_retriever()
return retriever
def get_rag_prompt(classifier_prompt):
"""Create LangChain prompt template object.
"""
prompt = ChatPromptTemplate.from_template(classifier_prompt)
return prompt
def generate_questions(labels):
"""Generate questions to pass through the RAG pipeline
based on the user-provided possible labels.
"""
question_template = "Does this project relate to {} in a significant way?"
labels_questions = [(label, question_template.format(label)) for label in labels]
return labels_questions
def zero_shot_multi_label(file_path, labels, llm, embedding):
"""Ingest an HTML document and a list of possible labels and
execute zero-shot multi-label classification.
"""
classifier_prompt = """
You are a helpful yes or no answer service. Given the provided context, which comes
from a machine learning project composed in Jupyter Notebooks by Nigel Story,
answer the question below. Take your time and find the correct answer from
the context. Only respond with "Yes" or "No".
<context>
{context}
</context>
Question: {question}
"""
text = load_html_document(file_path)
documents = chunk_document(text, chunk_size=1000, chunk_overlap=100)
retriever = rag_retriever(documents, embedding)
prompt = get_rag_prompt(classifier_prompt)
rag_chain = (
{'context': retriever, 'question': RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
labels_questions = generate_questions(labels)
start = time.time()
preds = [(q[0], rag_chain.invoke(q[1])) for q in labels_questions]
print(f"time to process prompts: {time.time() - start}s")
return preds
def filter_preds(preds):
    """Filter binary outputs from LLM to only
    return the multi-label predictions.
    """
    # match leniently in case the model pads "Yes" with whitespace or punctuation
    return [p[0] for p in preds if p[1].strip().lower().startswith('yes')]
Now that the pipeline's built, we'll define a somewhat arbitrary list of concepts that could apply to my personal Python projects.
possible_labels = [
'SQL',
'Classification',
'Clustering',
'Regression',
'Web Development',
'Image Analytics',
'Anomaly Detection',
'Simulation',
'NLP'
]
Let's run the prediction pipeline and look at the results.
The project we're running through the pipeline uses a variety of techniques and technologies, but at its core, it's a clustering analysis followed by classification.
acnh_preds = zero_shot_multi_label(test_file, possible_labels, llm, embedding)
time to process prompts: 13.631207942962646s
filter_preds(acnh_preds)
['Classification', 'Clustering', 'Regression', 'Simulation']
The model captured the Classification and Clustering categories, and it added the Regression and Simulation categories. These last two are ones I had forgotten about, but there are aspects of both used in the project. So, a great result that, funnily enough, did better than I might have done just from memory!
However, compared to classical ML, 13 seconds is pretty slow for a single prediction, which is why this technique would mainly benefit complex classification tasks on large documents or corpora. But let's see if we can speed things up a little with parallel processing.
Parallel Prompting Implementation ¶
We can speed up the processing time a bit, and help the approach scale to larger label sets, by using parallel processing to execute the prompt created for each label.
def pool_invoke(label, prompt, rag_chain):
'''Global picklable func for use in Pool processing.
'''
response = rag_chain.invoke(prompt)
return label, response
def parallel_prompt_zero_shot_multi_label(file_path, labels, llm, embedding, n_jobs=1):
"""Parallelized zero-shot multi-label document classifier.
"""
classifier_prompt = """
You are a helpful yes or no answer service. Given the provided context, which comes
from a machine learning project composed in Jupyter Notebooks by Nigel Story,
answer the question below. Take your time and find the correct answer from
the context. Only respond with "Yes" or "No".
<context>
{context}
</context>
Question: {question}
"""
text = load_html_document(file_path)
documents = chunk_document(text, chunk_size=1000, chunk_overlap=100)
retriever = rag_retriever(documents, embedding)
prompt = get_rag_prompt(classifier_prompt)
rag_chain = (
{'context': retriever, 'question': RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
labels_questions = generate_questions(labels)
start = time.time()
    if n_jobs == -1:
        # note: using a nested function as a Process target relies on the fork
        # start method (the Unix default); it can't be pickled under spawn
        def mp_invoke(label, prompt, pred_dict):
            result = rag_chain.invoke(prompt)
            pred_dict[label] = result
# use max possible processes
manager = multiprocessing.Manager()
pred_dict = manager.dict()
jobs = []
for question in labels_questions:
p = multiprocessing.Process(target=mp_invoke, args=(question[0], question[1], pred_dict))
jobs.append(p)
p.start()
for p in jobs:
p.join()
results = pred_dict.items()
    else:
        # use the specified number of worker processes; apply_async submits all
        # prompts up front so they run concurrently (pool.apply would block on each)
        with multiprocessing.Pool(processes=n_jobs) as pool:
            async_results = [
                pool.apply_async(pool_invoke, args=(question[0], question[1], rag_chain))
                for question in labels_questions
            ]
            results = [r.get() for r in async_results]
print(f"time to process prompts: {time.time() - start}s")
return results
parallel_acnh_preds = parallel_prompt_zero_shot_multi_label(test_file, possible_labels, llm, embedding, n_jobs=-1)
time to process prompts: 11.637054204940796s
filter_preds(parallel_acnh_preds)
['Classification', 'Clustering', 'Regression', 'Simulation']
As expected, the time savings from parallel prompting are pretty marginal. A more effective use of parallelization would likely be in creating the vector stores and retrievers, which I hope to explore in future projects.
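As a rough, unbenchmarked sketch of that idea, the per-document vector stores could be built concurrently before any prompting happens. Since the embedding calls are network requests to the Ollama server, a thread pool avoids the pickling issues that worker processes would introduce; build_retriever below is a hypothetical helper that reuses the functions defined earlier.

from concurrent.futures import ThreadPoolExecutor

def build_retriever(file_path):
    """Hypothetical helper: load, chunk, and index a single document."""
    doc = load_html_document(file_path)
    chunks = chunk_document(doc)
    return file_path, rag_retriever(chunks, embedding)

def build_retrievers_concurrently(file_paths, n_workers=4):
    """Build one FAISS retriever per document, overlapping the embedding calls."""
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        return dict(executor.map(build_retriever, file_paths))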
Conclusion ¶
This approach to RAG -- using it as a vehicle for accomplishing complex classification tasks on large documents or corpora -- would provide businesses with value similar to the savings produced by classical ML, but the development time needed to produce those savings is drastically reduced. Rather than months of feature engineering, model selection, hyperparameter tuning, and training, an AI engineer could very quickly spin up a RAG-based zero-shot model that produces reliable results without much effort. This is where I see a lot of the savings from AI coming from in the future.
In follow-up projects, I'd like to examine cross-contamination of source materials when creating larger vector stores and classifying documents within the store, rather than creating a new vector store for each document. This could produce significant time savings and would be worth comparing to the parallel creation of vector stores for individual documents.