Caching for RAG¶
RAG pipelines make expensive external calls on every request. Caching the right layers reduces cost and latency with little or no change to your system's observable behavior.
Three cache layers¶
| Layer | What you cache | Staleness risk | Recommendation |
|---|---|---|---|
| Embedding cache | `embed_texts(text)` → vector | Very low — same text always gives same embedding | Always use |
| Query result cache | `retrieve(query)` → chunks | Low — only stale when docs change | Use with TTL |
| LLM response cache | `answer_question(q, ctx)` → text | High — answers become stale, personalized, or date-sensitive | Usually skip |
Start with embedding caching. It's zero-risk and can eliminate 80–90% of embedding API calls during ingestion.
Install¶
uv pip install openai cachetools psycopg[binary] python-dotenv
Layer 1: EmbeddingCache (content-addressed file cache)¶
This cache is content-addressed: the same (model, text) always maps to the same file. Safe to use across runs and processes.
```python
# rag/embedding_cache.py
from __future__ import annotations

import hashlib
import json
from pathlib import Path


class EmbeddingCache:
    """File-based embedding cache keyed by sha256(model + text)."""

    def __init__(self, cache_dir: str = ".cache/embeddings") -> None:
        self._dir = Path(cache_dir)
        self._dir.mkdir(parents=True, exist_ok=True)

    def _key(self, model: str, text: str) -> str:
        return hashlib.sha256(f"{model}:{text}".encode()).hexdigest()

    def get(self, model: str, text: str) -> list[float] | None:
        path = self._dir / f"{self._key(model, text)}.json"
        if path.exists():
            return json.loads(path.read_text())
        return None

    def set(self, model: str, text: str, embedding: list[float]) -> None:
        path = self._dir / f"{self._key(model, text)}.json"
        path.write_text(json.dumps(embedding))
```
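A quick roundtrip shows the content-addressing in action. The vector here is an illustrative stand-in, not a real embedding:

```python
from rag.embedding_cache import EmbeddingCache

cache = EmbeddingCache()
model = "text-embedding-3-small"

assert cache.get(model, "hello") is None          # first sight: miss
cache.set(model, "hello", [0.1, 0.2, 0.3])
assert cache.get(model, "hello") == [0.1, 0.2, 0.3]  # hit, even from a new process
```

Because the key is derived from the content itself, a second process (or a rerun tomorrow) computes the same key and finds the same file.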
Integrate the cache into embed_texts()¶
```python
# rag/embed.py
from __future__ import annotations

from openai import OpenAI

from rag.embedding_cache import EmbeddingCache

client = OpenAI()
_cache = EmbeddingCache()


def embed_texts(
    texts: list[str],
    *,
    model: str = "text-embedding-3-small",
) -> list[list[float]]:
    """Embed texts, using the file cache for already-seen texts."""
    results: list[list[float] | None] = [None] * len(texts)
    uncached: list[tuple[int, str]] = []

    # Check the cache first
    for i, text in enumerate(texts):
        cached = _cache.get(model, text)
        if cached is not None:
            results[i] = cached
        else:
            uncached.append((i, text))

    # Batch-embed only the uncached texts
    if uncached:
        indices, to_embed = zip(*uncached)
        resp = client.embeddings.create(model=model, input=list(to_embed))
        for idx, item in zip(indices, resp.data):
            _cache.set(model, texts[idx], item.embedding)
            results[idx] = item.embedding

    return results  # type: ignore[return-value]
```
On a re-ingest run, all previously seen chunks hit the cache and zero API calls are made.
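You can see this from any fresh process. A sketch, assuming the module-level cache above (the chunk strings are illustrative):

```python
from rag.embed import embed_texts

chunks = ["Chunk one of the corpus.", "Chunk two of the corpus."]

embed_texts(chunks)  # one embeddings API call for the whole batch
embed_texts(chunks)  # zero API calls: both texts hit .cache/embeddings
```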
Layer 2: TTLCache for query results (in-process)¶
Cache retrieval results in memory for a short window. Use cachetools.TTLCache:
```python
# rag/query_cache.py
from __future__ import annotations

from collections.abc import Callable

from cachetools import TTLCache

# 512 entries, 60-second TTL
_query_cache: TTLCache = TTLCache(maxsize=512, ttl=60)


def retrieve_cached(
    query: str,
    *,
    k: int = 8,
    retrieve_fn: Callable[..., list[dict]],
) -> list[dict]:
    """Return cached results if available, else call retrieve_fn and cache."""
    cache_key = f"{query}::{k}"
    if cache_key in _query_cache:
        return _query_cache[cache_key]
    chunks = retrieve_fn(query, k=k)
    _query_cache[cache_key] = chunks
    return chunks
```
Usage:

```python
from rag.retrieve import retrieve
from rag.query_cache import retrieve_cached

chunks = retrieve_cached(query, k=8, retrieve_fn=retrieve)
```
TTL guidance: 60 seconds is safe for most documentation corpora that change at most daily. Reduce it to 5–10 seconds if your documents change frequently, or invalidate explicitly on re-ingest, as sketched below.
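When you control the ingestion pipeline, explicit invalidation is more reliable than tuning TTLs. A minimal sketch, added to `rag/query_cache.py`:

```python
def invalidate_query_cache() -> None:
    """Drop all cached retrieval results, e.g. after re-ingesting documents."""
    _query_cache.clear()
```

Call it at the end of your ingestion job so cached retrievals never outlive a re-index by more than one in-flight request.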
Layer 1 (multi-process): Postgres embedding_cache table¶
The in-memory TTLCache is per-process: under uvicorn --workers 4, each worker keeps its own copy. The file-based embedding cache is shared between workers on one machine, but not across hosts or containers without a shared volume. To share the embedding cache across an entire deployment, use a Postgres table:
```sql
-- Requires the pgvector extension for the VECTOR type:
-- CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE embedding_cache (
    cache_key  TEXT PRIMARY KEY,
    model      TEXT NOT NULL,
    embedding  VECTOR(1536),  -- dimension of text-embedding-3-small
    created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
```
```python
# rag/pg_embedding_cache.py
from __future__ import annotations

import hashlib
import json

import psycopg


class PgEmbeddingCache:
    def __init__(self, conn: psycopg.Connection) -> None:
        self._conn = conn

    def _key(self, model: str, text: str) -> str:
        return hashlib.sha256(f"{model}:{text}".encode()).hexdigest()

    def get(self, model: str, text: str) -> list[float] | None:
        # Cast to text so the vector round-trips without a pgvector client
        # adapter; the text form "[0.1,0.2,...]" is valid JSON.
        row = self._conn.execute(
            "SELECT embedding::text FROM embedding_cache WHERE cache_key = %s AND model = %s",
            (self._key(model, text), model),
        ).fetchone()
        return json.loads(row[0]) if row else None

    def set(self, model: str, text: str, embedding: list[float]) -> None:
        self._conn.execute(
            """
            INSERT INTO embedding_cache (cache_key, model, embedding)
            VALUES (%s, %s, %s::vector)
            ON CONFLICT (cache_key) DO NOTHING
            """,
            # Compact JSON ("[0.1,0.2]") is also valid pgvector input syntax
            (self._key(model, text), model, json.dumps(embedding, separators=(",", ":"))),
        )
        self._conn.commit()
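```

Wiring it up looks like this; the DSN is a placeholder for your own connection string, and `PgEmbeddingCache` is a drop-in for `EmbeddingCache` in `embed_texts()` since both expose the same `get`/`set` interface:

```python
import psycopg

from rag.pg_embedding_cache import PgEmbeddingCache

conn = psycopg.connect("postgresql://localhost/ragdb")  # placeholder DSN
cache = PgEmbeddingCache(conn)

vec = cache.get("text-embedding-3-small", "hello")
if vec is None:
    # ... call the embeddings API, then store the result
    # (the all-zeros vector here is only a stand-in):
    cache.set("text-embedding-3-small", "hello", [0.0] * 1536)
```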
When NOT to cache LLM responses¶
Caching answer_question() output seems appealing but has serious risks:
| Risk | Example |
|---|---|
| Stale facts | Price cached on Monday; price changes Tuesday |
| Personalized content | User A's answer cached; shown to User B |
| Date-sensitive answers | "What are the latest releases?" cached for 24 hours |
| Different contexts | Same question → different retrieved chunks → different answer |
Only cache LLM responses if ALL of the following are true:

1. The content is static (policy docs, FAQs)
2. Answers are not personalized
3. You use the exact chunk IDs as part of the cache key
4. You invalidate the cache when source documents change

A sketch of a cache key that satisfies condition 3 follows this list.
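This is only a sketch: it assumes each retrieved chunk dict carries a stable `id` field, and the helper name `answer_cache_key` is illustrative, not part of the pipeline above.

```python
import hashlib


def answer_cache_key(question: str, chunks: list[dict], model: str) -> str:
    """Cache key that binds an answer to its exact retrieval context."""
    # Sort so the same set of chunks always yields the same key,
    # regardless of retrieval order (condition 3).
    chunk_ids = ",".join(sorted(str(c["id"]) for c in chunks))
    raw = f"{model}|{question}|{chunk_ids}"
    return hashlib.sha256(raw.encode()).hexdigest()
```

If re-ingestion assigns new chunk IDs, stale entries stop matching on their own; if IDs are stable across re-ingests, you still need explicit invalidation (condition 4).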
Next steps¶
- Track cache hit rates in production: Monitoring & Observability
- Add caching to your deployed API: Serving RAG as an API