Retrieval-Augmented Generation (RAG)
This notebook shows how to use Kaval.AI’s RagService to build the retrieval
half of a retrieval-augmented generation pipeline: turning text into embeddings,
storing them in PostgreSQL (with the pgvector extension), and running fast
similarity search over them.
RagService handles the whole loop for you:
Indexing: embed text and persist it in the rag_index table.
Querying: embed a query and return the most similar stored items, ranked by cosine similarity.
Organising: group items into collections and tag them with a source_id and arbitrary JSON metadata so you can filter retrieval.
We work through indexing a small corpus, single and batched queries, collapsing
chunked documents with keep_best, and computing a full similarity matrix.
Setup
The notebook lives in notebooks/, so we load the project .env from the
parent directory. That file supplies KAVALAI_DB_URI (a PostgreSQL instance
with pgvector and the rag_index table already migrated) and
KAVALAI_DB_SCHEMA, which tells the ORM which schema to read and write.
import dotenv
dotenv.load_dotenv("../.env")
True
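If ../.env is missing or incomplete, the problem only surfaces later, when the service tries to connect. The check below fails fast instead. It assumes only that both variables named above must be present and non-empty. missing_settings is a hypothetical helper, not part of Kaval.AI.

```python
import os
from collections.abc import Mapping

# The two settings this notebook expects ../.env to provide.
REQUIRED_SETTINGS = ("KAVALAI_DB_URI", "KAVALAI_DB_SCHEMA")


def missing_settings(env: Mapping[str, str] = os.environ) -> list[str]:
    """Return the required settings that are unset or empty (hypothetical helper)."""
    return [name for name in REQUIRED_SETTINGS if not env.get(name)]


# Usage right after dotenv.load_dotenv("../.env"):
# missing = missing_settings()
# if missing:
#     raise RuntimeError(f"Missing settings in ../.env: {', '.join(missing)}")
```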
Choosing an embedding model
RagService is constructed with an embedding model named as
provider/model. Supported providers are openai, gemini, ollama and
fastembed. We use fastembed here because it runs entirely locally — no
API key or external service is required, which makes the notebook reproducible.
(The model is downloaded from the Hugging Face Hub on first use.)
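A model string in provider/model form can contain more than one slash, because Hugging Face ids such as BAAI/bge-small-en-v1.5 carry their own organisation prefix. The sketch below handles this by splitting only on the first slash. Both the parsing rule and split_model_name are assumptions for illustration, not RagService's code.

```python
# Providers listed in this notebook. The parsing rule below is an assumption,
# not Kaval.AI's actual implementation.
SUPPORTED_PROVIDERS = {"openai", "gemini", "ollama", "fastembed"}


def split_model_name(name: str) -> tuple[str, str]:
    """Split 'provider/model' on the FIRST slash only.

    The model part may contain slashes of its own (e.g. 'BAAI/bge-small-en-v1.5'),
    so splitting on every slash would break it apart.
    """
    provider, sep, model = name.partition("/")
    if not sep or not model:
        raise ValueError(f"expected 'provider/model', got {name!r}")
    if provider not in SUPPORTED_PROVIDERS:
        raise ValueError(f"unsupported provider {provider!r}")
    return provider, model


print(split_model_name("fastembed/BAAI/bge-small-en-v1.5"))
# ('fastembed', 'BAAI/bge-small-en-v1.5')
```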
RagService.from_uri builds the service straight from a database URI. All
indexed items are written to the rag_index table; we keep this tutorial’s data
isolated under a dedicated collection_name so it never mixes with other
collections, and delete it again at the end.
import os
from kavalai.agents.rag_service import RagService
EMBEDDING_MODEL = "fastembed/BAAI/bge-small-en-v1.5"
COLLECTION = "rag_tutorial"
rag = RagService.from_uri(os.environ["KAVALAI_DB_URI"], model=EMBEDDING_MODEL)
Indexing documents
batch_index embeds every text and stores it in one round trip. It takes parallel lists with one entry per item, plus a collection name shared by the whole batch:
texts: the content to embed and store.
metadata_list: a JSON dict per item, for filtering and bookkeeping.
source_ids: an external identifier per item (e.g. a document id). When one logical document is split into several chunks, the chunks share a source_id.
collection_name: the logical group the items belong to.
For a single item there is also the convenience method
index(text, source_metadata=..., collection_name=..., source_id=...).
docs = [
"The kakapo is a flightless, nocturnal parrot native to New Zealand.",
"Espresso is brewed by forcing hot water through finely-ground coffee beans.",
"Mount Everest is the highest mountain above sea level, at 8,849 metres.",
"Python's asyncio library provides infrastructure for writing concurrent code.",
"The Great Barrier Reef is the world's largest coral reef system.",
]
metadata = [
{"topic": "animals"},
{"topic": "food"},
{"topic": "geography"},
{"topic": "programming"},
{"topic": "geography"},
]
source_ids = ["kakapo", "espresso", "everest", "asyncio", "reef"]
items = await rag.batch_index(
texts=docs,
metadata_list=metadata,
source_ids=source_ids,
collection_name=COLLECTION,
)
print(f"Indexed {len(items)} documents into collection {COLLECTION!r}.")
Indexed 5 documents into collection 'rag_tutorial'.
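The three per-item lists passed to batch_index are parallel: entry i of texts, metadata_list and source_ids all describe the same item. The sketch below shows how they pair up, assuming they must have equal lengths. IndexRecord and build_records are hypothetical names, not part of Kaval.AI.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class IndexRecord:
    """One item to embed and store (hypothetical; mirrors the arguments above)."""
    text: str
    source_id: str
    collection_name: str
    metadata: dict[str, Any] = field(default_factory=dict)


def build_records(texts, metadata_list, source_ids, collection_name):
    """Pair up batch_index-style parallel lists, refusing mismatched lengths."""
    if not (len(texts) == len(metadata_list) == len(source_ids)):
        raise ValueError(
            f"length mismatch: {len(texts)} texts, {len(metadata_list)} metadata dicts, "
            f"{len(source_ids)} source ids"
        )
    return [
        IndexRecord(text=t, source_id=s, collection_name=collection_name, metadata=m)
        for t, m, s in zip(texts, metadata_list, source_ids)
    ]


records = build_records(
    ["Saturn is the sixth planet from the Sun.", "Saturn has a prominent ring system."],
    [{"chunk": 0}, {"chunk": 1}],
    ["saturn", "saturn"],  # two chunks of one logical document share a source_id
    "rag_tutorial",
)
print([(r.source_id, r.metadata) for r in records])
# [('saturn', {'chunk': 0}), ('saturn', {'chunk': 1})]
```

The explicit length check is used because a plain zip would silently drop the surplus items.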
Querying
query embeds the query text and returns the top_k most similar items as
RagServiceResult objects, ordered from most to least similar. Each result
exposes similarity (1.0 − cosine distance, so higher is closer), the stored
content, the source_id, the collection_name and rag_metadata.
Note that retrieval is semantic: the query “a bird that cannot fly” matches the kakapo entry even though it shares almost no words with it.
results = await rag.query("a bird that cannot fly", top_k=3, collection_name=COLLECTION)
for r in results:
    print(f"{r.similarity:.3f} [{r.source_id}] {r.content}")
0.663 [kakapo] The kakapo is a flightless, nocturnal parrot native to New Zealand.
0.663 [kakapo] The kakapo is a flightless, nocturnal parrot native to New Zealand.
0.521 [reef] The Great Barrier Reef is the world's largest coral reef system.
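The similarity scores above are 1.0 minus cosine distance. Cosine distance, which is what pgvector's <=> operator computes, is itself 1 minus cosine similarity. The reported number is therefore the cosine of the angle between the query and document embeddings. The plain-Python sketch below uses made-up 3-dimensional vectors; bge-small-en-v1.5 actually produces 384-dimensional embeddings.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """cos(theta) = (a . b) / (|a| |b|); ranges from -1 to 1."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def cosine_distance(a: list[float], b: list[float]) -> float:
    """The quantity pgvector's <=> operator computes: 1 - cosine similarity."""
    return 1.0 - cosine_similarity(a, b)


query = [1.0, 0.0, 1.0]  # toy vectors, not real embeddings
doc = [1.0, 1.0, 0.0]
similarity = 1.0 - cosine_distance(query, doc)  # the score the notebook reports
print(round(similarity, 3))  # 0.5
```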
A query that closely matches one specific document produces a much stronger top hit: 0.839 here, compared with 0.663 for the kakapo query above.
for r in await rag.query("concurrent programming in python", top_k=2, collection_name=COLLECTION):
    print(f"{r.similarity:.3f} [{r.source_id}] {r.content}")
0.839 [asyncio] Python's asyncio library provides infrastructure for writing concurrent code.
0.839 [asyncio] Python's asyncio library provides infrastructure for writing concurrent code.
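Conceptually, query scores the stored embeddings in the collection against the query embedding and keeps the best top_k. The brute-force sketch below does this with toy 2-dimensional vectors. heapq.nlargest avoids fully sorting every score. The database may instead use an approximate pgvector index, which this notebook does not cover.

```python
import heapq
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def top_k(query_vec, items, k):
    """Return the k (similarity, source_id) pairs closest to query_vec, best first.

    items: iterable of (source_id, embedding) pairs. Scoring every item makes
    this an exact (exhaustive) nearest-neighbour search.
    """
    scored = ((cosine(query_vec, emb), sid) for sid, emb in items)
    return heapq.nlargest(k, scored)


store = [  # hypothetical toy "embeddings"
    ("kakapo", [0.9, 0.1]),
    ("reef", [0.6, 0.8]),
    ("everest", [0.0, 1.0]),
]
for sim, sid in top_k([1.0, 0.0], store, k=2):
    print(f"{sim:.3f} [{sid}]")
# 0.994 [kakapo]
# 0.600 [reef]
```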
Chunked documents and keep_best
Long documents are usually split into several chunks that all share one
source_id. A plain query can then return multiple chunks from the same
source, crowding out other documents. Passing keep_best=True collapses the
results to the single highest-scoring chunk per source_id.
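The sketch below shows the collapsing step on a result list that is already sorted best-first: it keeps the first hit seen for each source_id. It illustrates the behaviour described above rather than Kaval.AI's implementation. In particular, the notebook does not say whether the service collapses before or after applying top_k. The scores are hypothetical.

```python
def keep_best(results):
    """Collapse ranked (similarity, source_id, label) hits to one per source_id.

    Assumes `results` is sorted best-first, so the first hit seen for a source
    is its highest-scoring chunk. The order of surviving hits is preserved.
    """
    seen = set()
    best = []
    for hit in results:
        source_id = hit[1]
        if source_id not in seen:
            seen.add(source_id)
            best.append(hit)
    return best


ranked = [  # hypothetical scores, best first
    (0.74, "saturn", "chunk 2"),
    (0.71, "saturn", "chunk 1"),
    (0.62, "jupiter", "chunk 1"),
    (0.59, "saturn", "chunk 3"),
]
print(keep_best(ranked))
# [(0.74, 'saturn', 'chunk 2'), (0.62, 'jupiter', 'chunk 1')]
```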
Here we index three chunks of a “Saturn” document under one source_id and
query them with and without keep_best.
chunks = [
"Chapter 1. Saturn is the sixth planet from the Sun.",
"Chapter 2. Saturn is famous for its prominent ring system.",
"Chapter 3. Saturn has at least 146 known moons, including Titan.",
]
await rag.batch_index(
texts=chunks,
metadata_list=[{"chunk": i} for i in range(len(chunks))],
source_ids=["saturn", "saturn", "saturn"],
collection_name=COLLECTION,
)
q = "rings around a planet"
dup = await rag.query(q, top_k=5, collection_name=COLLECTION, source_ids=["saturn"])
best = await rag.query(q, top_k=5, collection_name=COLLECTION, source_ids=["saturn"], keep_best=True)
print("without keep_best:", [(r.source_id, round(r.similarity, 3)) for r in dup])
print("with keep_best: ", [(r.source_id, round(r.similarity, 3)) for r in best])
without keep_best: [('saturn', 0.744), ('saturn', 0.715), ('saturn', 0.591)]
with keep_best: [('saturn', 0.744)]
The source_ids argument also shows up here: it restricts the search to
specific sources, which is handy when you already know the candidate documents
(for example after a metadata pre-filter).
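The metadata pre-filter pattern can be sketched in plain Python: first select candidate source ids by metadata, then pass them as source_ids. The catalogue dict is a hypothetical stand-in for wherever your metadata lives. Its entries mirror the topic metadata indexed earlier in this notebook, and candidate_source_ids is a hypothetical helper.

```python
def candidate_source_ids(catalogue, **required):
    """Return source ids whose metadata matches every key=value in `required`.

    catalogue: mapping of source_id -> metadata dict (hypothetical stand-in).
    """
    return [
        sid for sid, meta in catalogue.items()
        if all(meta.get(key) == value for key, value in required.items())
    ]


catalogue = {
    "kakapo": {"topic": "animals"},
    "espresso": {"topic": "food"},
    "everest": {"topic": "geography"},
    "asyncio": {"topic": "programming"},
    "reef": {"topic": "geography"},
}
ids = candidate_source_ids(catalogue, topic="geography")
print(ids)  # ['everest', 'reef']

# The ids would then restrict the vector search, e.g.:
# await rag.query("tall mountains", top_k=3, collection_name=COLLECTION, source_ids=ids)
```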
Batched queries
When you have several queries, batch_query embeds and searches them all in a
single database round trip (using a CROSS JOIN LATERAL under the hood). It
returns one result list per query, in input order.
queries = ["coffee preparation", "tall mountains"]
batched = await rag.batch_query(queries, top_k=2, collection_name=COLLECTION)
for q, res in zip(queries, batched):
    print(f"query: {q!r}")
    for r in res:
        print(f" {r.similarity:.3f} [{r.source_id}]")
query: 'coffee preparation'
0.730 [espresso]
0.730 [espresso]
query: 'tall mountains'
0.733 [everest]
0.733 [everest]
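A CROSS JOIN LATERAL runs a separate ORDER BY/LIMIT subquery for each query row, all within one statement. The sketch below shows one way such a statement could be assembled, and how its flat result rows map back to one list per query. The embedding column name, the VALUES-based input and the psycopg-style placeholders are all assumptions for illustration. This is not Kaval.AI's actual SQL.

```python
def build_batch_query_sql(n_queries: int) -> str:
    """Build one statement that runs a top-k search per query vector.

    Hypothetical schema: rag_index(source_id, content, collection_name, embedding).
    Placeholders (psycopg-style %s): n_queries vector literals, then the
    collection name, then top_k.
    """
    if n_queries < 1:
        raise ValueError("need at least one query")
    values = ", ".join(f"({i}, %s::vector)" for i in range(n_queries))
    return f"""
SELECT q.idx, r.source_id, r.content,
       1 - (r.embedding <=> q.vec) AS similarity
FROM (VALUES {values}) AS q(idx, vec)
CROSS JOIN LATERAL (
    SELECT source_id, content, embedding
    FROM rag_index
    WHERE collection_name = %s
    ORDER BY embedding <=> q.vec  -- nearest first for THIS query row
    LIMIT %s
) AS r
ORDER BY q.idx, similarity DESC
"""


def group_by_query(rows, n_queries):
    """Regroup (idx, source_id, content, similarity) rows into one list per query.

    Queries with no hits still get an empty list, so the result zips cleanly
    with the input queries.
    """
    grouped = [[] for _ in range(n_queries)]
    for idx, source_id, content, similarity in rows:
        grouped[idx].append((similarity, source_id, content))
    for hits in grouped:
        hits.sort(key=lambda hit: hit[0], reverse=True)
    return grouped
```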
For the common case of filtering retrieval by attributes that live in
another table (price, category, availability, …), batch_query_with_join
joins the vector search against that table inside a single SQL statement, so you
never have to ship large lists of ids between Python and the database.
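Where the filter is applied also affects how many results you get back. Filtering after the top_k limit can return fewer than top_k rows; filtering before it does not. The plain-Python sketch below applies the filter first, assuming that is what an in-query join achieves. The price table is hypothetical and stands in for the joined table. The signature of batch_query_with_join is not shown in this notebook and is not reproduced here.

```python
import heapq


def filtered_top_k(scored, attributes, predicate, k):
    """Join scored hits with an attribute table, filter, THEN keep the top k.

    scored: (similarity, item_id) pairs; attributes: item_id -> row dict
    (a stand-in for the joined table). Items without an attribute row are
    dropped, as an inner join would do.
    """
    joined = (
        (sim, item_id, attributes[item_id])
        for sim, item_id in scored
        if item_id in attributes and predicate(attributes[item_id])
    )
    return heapq.nlargest(k, joined, key=lambda hit: hit[0])


scored = [(0.91, "a"), (0.88, "b"), (0.80, "c"), (0.75, "d")]  # hypothetical hits
prices = {"a": {"price": 120}, "b": {"price": 15}, "c": {"price": 30}, "d": {"price": 9}}
cheap = filtered_top_k(scored, prices, lambda row: row["price"] < 50, k=2)
print([item_id for _, item_id, _ in cheap])  # ['b', 'c']

# Cutting to the top 2 first ('a', 'b') and then filtering would keep only 'b'.
```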
Similarity matrix
compute_similarity_matrix scores every query against every listed source_id
and returns a 2-D matrix where matrix[i][j] is the similarity between
texts[i] and source_ids[j]. This is useful for ranking, clustering, or
deduplication tasks where you need all pairwise scores at once.
texts = ["ocean wildlife", "high altitude"]
targets = ["reef", "everest", "kakapo"]
matrix = await rag.compute_similarity_matrix(texts=texts, source_ids=targets)
print(" " * 15 + " ".join(f"{t:>8}" for t in targets))
for q, row in zip(texts, matrix):
    print(f"{q:>14} " + " ".join(f"{v:8.3f}" for v in row))
                   reef  everest   kakapo
ocean wildlife    0.623    0.527    0.559
 high altitude    0.547    0.764    0.486
As expected, “ocean wildlife” scores highest against the reef and “high altitude” against Everest — the diagonal of relevant pairs dominates.
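The sketch below computes such a matrix from raw embeddings. Every vector is normalised to unit length, so each entry becomes a plain dot product. A source_id can have several chunks, and the notebook does not say how compute_similarity_matrix aggregates them. This sketch therefore assumes that a source scores as its best chunk (the max). All vectors are toy data.

```python
import math


def normalise(v):
    """Scale v to unit length so that dot products equal cosine similarities."""
    norm = math.hypot(*v)
    return [x / norm for x in v]


def similarity_matrix(query_vecs, source_chunks, source_ids):
    """matrix[i][j] = best cosine similarity between query i and any chunk of source j."""
    chunks = {sid: [normalise(c) for c in source_chunks[sid]] for sid in source_ids}
    matrix = []
    for q in map(normalise, query_vecs):
        row = [
            # Assumption: a source scores as its best-matching chunk (max over chunks).
            max(sum(a * b for a, b in zip(q, c)) for c in chunks[sid])
            for sid in source_ids
        ]
        matrix.append(row)
    return matrix


m = similarity_matrix(
    [[1.0, 0.0], [0.0, 1.0]],  # toy query embeddings
    {"reef": [[1.0, 0.0]], "everest": [[0.0, 2.0], [1.0, 1.0]]},  # toy chunk embeddings
    ["reef", "everest"],
)
print([[round(v, 3) for v in row] for row in m])  # [[1.0, 0.707], [0.0, 1.0]]
```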
Cleanup
Finally, remove the tutorial’s items so the shared rag_index table is left as
we found it. delete_by_source_ids deletes every item in a collection whose
source_id is in the given list.
await rag.delete_by_source_ids(
COLLECTION,
["kakapo", "espresso", "everest", "asyncio", "reef", "saturn"],
)
remaining = await rag.query("anything", top_k=10, collection_name=COLLECTION)
print(f"Remaining items in {COLLECTION!r}: {len(remaining)}")
Remaining items in 'rag_tutorial': 0