"""FAISS-backed vector store with persistence and metadata tracking."""
from __future__ import annotations
import json
import logging
import pickle
from pathlib import Path
import faiss
import numpy as np
from paperrag.chunker import Chunk
from paperrag.config import PaperRAGConfig
logger = logging.getLogger(__name__)
INDEX_FILE = "faiss.index"
METADATA_FILE = "metadata.json"
CONFIG_SNAPSHOT_FILE = "config_snapshot.json"
FILE_HASHES_FILE = "file_hashes.json"
VERSION_FILE = "version.json"
[docs]
class VectorStore:
"""Manages a FAISS IndexFlatIP index plus chunk metadata on disk."""
def __init__(self, index_dir: Path, dimension: int) -> None:
self.index_dir = index_dir
self.dimension = dimension
self.index: faiss.IndexFlatIP = faiss.IndexFlatIP(dimension)
self.chunks: list[dict] = []
self.file_hashes: dict[str, str] = {}
self.version: int = 0
# ------------------------------------------------------------------
# Persistence
# ------------------------------------------------------------------
[docs]
def save(self, config: PaperRAGConfig | None = None) -> None:
"""Write index, metadata, hashes, version with atomic operations."""
import shutil
self.index_dir.mkdir(parents=True, exist_ok=True)
# Write FAISS index atomically
index_tmp = self.index_dir / f"{INDEX_FILE}.tmp"
faiss.write_index(self.index, str(index_tmp))
shutil.move(str(index_tmp), str(self.index_dir / INDEX_FILE))
# Write metadata as pickle (much faster than JSON)
metadata_tmp = self.index_dir / f"{METADATA_FILE}.tmp"
with open(metadata_tmp, 'wb') as f:
pickle.dump(self.chunks, f, protocol=pickle.HIGHEST_PROTOCOL)
shutil.move(str(metadata_tmp), str(self.index_dir / METADATA_FILE))
# Write file hashes as pickle
hashes_tmp = self.index_dir / f"{FILE_HASHES_FILE}.tmp"
with open(hashes_tmp, 'wb') as f:
pickle.dump(self.file_hashes, f, protocol=pickle.HIGHEST_PROTOCOL)
shutil.move(str(hashes_tmp), str(self.index_dir / FILE_HASHES_FILE))
# Version file remains JSON (small, human-readable)
version_tmp = self.index_dir / f"{VERSION_FILE}.tmp"
version_tmp.write_text(
json.dumps({"version": self.version, "dimension": self.dimension})
)
shutil.move(str(version_tmp), str(self.index_dir / VERSION_FILE))
# Config snapshot remains JSON (human-readable)
if config is not None:
config_tmp = self.index_dir / f"{CONFIG_SNAPSHOT_FILE}.tmp"
config.save_snapshot(config_tmp)
shutil.move(str(config_tmp), str(self.index_dir / CONFIG_SNAPSHOT_FILE))
logger.info(
"Saved index v%d (%d vectors) to %s",
self.version,
self.index.ntotal,
self.index_dir,
)
[docs]
@classmethod
def load(cls, index_dir: Path) -> VectorStore:
"""Load an existing index from disk."""
version_data = json.loads((index_dir / VERSION_FILE).read_text())
dimension = version_data["dimension"]
version = version_data["version"]
store = cls(index_dir, dimension)
store.index = faiss.read_index(str(index_dir / INDEX_FILE))
# Load metadata - support both pickle (new) and JSON (legacy)
metadata_file = index_dir / METADATA_FILE
try:
# Try pickle first (new format)
with open(metadata_file, 'rb') as f:
store.chunks = pickle.load(f)
logger.info("Loaded pickle metadata format")
except (pickle.UnpicklingError, ValueError, EOFError):
# Fallback to JSON (legacy format)
logger.info("Falling back to JSON metadata format")
store.chunks = json.loads(metadata_file.read_text())
# Load file hashes - support both pickle (new) and JSON (legacy)
hashes_file = index_dir / FILE_HASHES_FILE
try:
# Try pickle first (new format)
with open(hashes_file, 'rb') as f:
store.file_hashes = pickle.load(f)
except (pickle.UnpicklingError, ValueError, EOFError):
# Fallback to JSON (legacy format)
store.file_hashes = json.loads(hashes_file.read_text())
store.version = version
logger.info(
"Loaded index v%d (%d vectors, dim=%d)",
store.version,
store.index.ntotal,
store.dimension,
)
return store
[docs]
@classmethod
def exists(cls, index_dir: Path) -> bool:
return (index_dir / INDEX_FILE).exists() and (index_dir / VERSION_FILE).exists()
# ------------------------------------------------------------------
# Modification helpers
# ------------------------------------------------------------------
[docs]
def add(self, embeddings: np.ndarray, chunks: list[Chunk]) -> None:
"""Add vectors and their corresponding chunk metadata."""
assert embeddings.shape[0] == len(chunks)
assert embeddings.shape[1] == self.dimension
self.index.add(embeddings)
self.chunks.extend([c.to_dict() for c in chunks])
[docs]
def remove_by_file(self, file_path: str) -> None:
"""Remove all vectors belonging to *file_path*.
Because FAISS IndexFlatIP does not support selective removal we
rebuild the index from the remaining vectors.
"""
keep_indices = [
i for i, c in enumerate(self.chunks) if c["file_path"] != file_path
]
if len(keep_indices) == len(self.chunks):
return
if keep_indices:
all_vecs = faiss.rev_swig_ptr(
self.index.get_xb(), self.index.ntotal * self.dimension
)
all_vecs = np.array(all_vecs, dtype=np.float32).reshape(-1, self.dimension)
kept_vecs = all_vecs[keep_indices]
self.index.reset()
self.index.add(kept_vecs)
else:
self.index.reset()
self.chunks = [self.chunks[i] for i in keep_indices]
[docs]
def set_file_hash(self, file_path: str, file_hash: str) -> None:
self.file_hashes[file_path] = file_hash
[docs]
def get_file_hash(self, file_path: str) -> str | None:
return self.file_hashes.get(file_path)
# ------------------------------------------------------------------
# Search
# ------------------------------------------------------------------
[docs]
def search(
self,
query_vec: np.ndarray,
top_k: int = 3,
file_path: str | None = None,
) -> list[tuple[dict, float]]:
"""Return top-k (chunk_metadata, score) pairs, optionally filtered by file_path."""
if self.index.ntotal == 0:
return []
query_vec = np.asarray(query_vec, dtype=np.float32)
if query_vec.ndim == 1:
query_vec = query_vec.reshape(1, -1)
# If filtering by file_path, we fetch more results to increase the
# chance of finding top-k matches for that specific file.
fetch_k = min(self.index.ntotal, 100 if file_path else top_k)
if file_path and fetch_k < top_k * 5:
fetch_k = min(self.index.ntotal, top_k * 10)
scores, indices = self.index.search(query_vec, fetch_k)
results: list[tuple[dict, float]] = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
meta = self.chunks[idx]
if file_path and meta["file_path"] != file_path:
continue
results.append((meta, float(score)))
if len(results) >= top_k:
break
return results