"""
BM25 Service for exact keyword matching in retrieval.

This service implements the BM25 algorithm for better lexical search,
complementing the semantic search provided by vector embeddings.
"""
|
|
import os
|
|
import pickle
|
|
import numpy as np
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
from rank_bm25 import BM25Okapi
|
|
import re
|
|
import nltk
|
|
from nltk.tokenize import word_tokenize
|
|
from nltk.corpus import stopwords
|
|
|
|
# Make sure the NLTK resources we rely on are present; download them
# quietly on first run if they are missing.
for _resource_path, _resource_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_resource_name, quiet=True)
|
|
|
|
class BM25Service:
    """Service for BM25-based lexical (keyword) search.

    All state is class-level and shared by every caller: a single BM25
    index, the raw corpus, and a parallel list of document IDs.  The index
    is persisted to disk with pickle so it survives process restarts.
    """

    # Shared BM25 index (rank_bm25.BM25Okapi), or None when not yet built.
    _bm25 = None
    # Raw document texts, kept parallel to _doc_ids.
    _corpus: List[str] = []
    _doc_ids: List[str] = []
    # Location of the pickled index on disk.
    _index_path = os.path.join("chromadb", "bm25_index.pkl")
    # Lazily-built English stopword set, cached so it is not rebuilt on
    # every preprocess_text() call.
    _stop_words: Optional[set] = None

    @staticmethod
    def preprocess_text(text: str) -> List[str]:
        """
        Preprocess text for BM25 indexing.

        Lowercases, strips punctuation and digits, tokenizes with NLTK,
        then drops English stopwords and single-character tokens.

        Args:
            text (str): Text to preprocess

        Returns:
            List[str]: List of preprocessed tokens
        """
        # Convert to lowercase
        text = text.lower()

        # Replace punctuation and digit runs with spaces so adjacent words
        # are not fused together.
        text = re.sub(r'[^\w\s]', ' ', text)
        text = re.sub(r'\d+', ' ', text)

        # Tokenize
        tokens = word_tokenize(text)

        # Build the stopword set once and cache it at class level (the
        # previous version rebuilt it on every call).
        if BM25Service._stop_words is None:
            BM25Service._stop_words = set(stopwords.words('english'))
        stop_words = BM25Service._stop_words

        return [token for token in tokens
                if token not in stop_words and len(token) > 1]

    @staticmethod
    def index_documents(documents: List[str], doc_ids: List[str]) -> None:
        """
        Create a BM25 index for a list of documents, replacing any
        existing index.

        Args:
            documents (List[str]): List of document contents
            doc_ids (List[str]): List of document IDs (parallel to documents)
        """
        # BM25Okapi cannot be built from an empty corpus (it divides by the
        # corpus size), so reset the in-memory state instead of crashing.
        if not documents:
            BM25Service._bm25 = None
            BM25Service._corpus = []
            BM25Service._doc_ids = []
            return

        # Preprocess documents
        tokenized_corpus = [BM25Service.preprocess_text(doc) for doc in documents]

        # Create BM25 index; copy the inputs so later caller-side mutation
        # cannot desynchronize corpus and ids.
        BM25Service._bm25 = BM25Okapi(tokenized_corpus)
        BM25Service._corpus = list(documents)
        BM25Service._doc_ids = list(doc_ids)

        # Save index to disk
        BM25Service.save_index()

    @staticmethod
    def add_document(document: str, doc_id: str) -> None:
        """
        Add a single document to the BM25 index.

        NOTE: BM25Okapi is immutable, so the whole corpus is re-tokenized
        and the index rebuilt on each call — O(corpus size) per document.

        Args:
            document (str): Document content
            doc_id (str): Document ID
        """
        # Lazily load the index on first use; bootstrap a fresh one-document
        # index if none exists on disk either.
        if BM25Service._bm25 is None:
            BM25Service.load_index()
            if BM25Service._bm25 is None:
                BM25Service.index_documents([document], [doc_id])
                return

        # Add document to corpus
        BM25Service._corpus.append(document)
        BM25Service._doc_ids.append(doc_id)

        # Rebuild the index over the full corpus.  (The previous version
        # also tokenized the new document separately and discarded the
        # result — that dead work is removed.)
        tokenized_corpus = [BM25Service.preprocess_text(doc)
                            for doc in BM25Service._corpus]
        BM25Service._bm25 = BM25Okapi(tokenized_corpus)

        # Save index to disk
        BM25Service.save_index()

    @staticmethod
    def search(query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """
        Search for documents using BM25.

        Args:
            query (str): Query text
            top_k (int): Number of results to return

        Returns:
            List[Tuple[str, float]]: (doc_id, score) tuples, best first;
            empty list when no index is available.
        """
        # Load index if it doesn't exist; give up gracefully if there is
        # nothing on disk either.
        if BM25Service._bm25 is None:
            BM25Service.load_index()
            if BM25Service._bm25 is None:
                return []

        # Preprocess query
        tokenized_query = BM25Service.preprocess_text(query)

        # Get scores
        scores = BM25Service._bm25.get_scores(tokenized_query)

        # Indices of the top-k scores, highest first.
        top_indices = np.argsort(scores)[::-1][:top_k]

        # Guard against a stale pickled index whose score vector is longer
        # than the current id list.
        n_ids = len(BM25Service._doc_ids)
        return [(BM25Service._doc_ids[idx], scores[idx])
                for idx in top_indices if idx < n_ids]

    @staticmethod
    def save_index() -> None:
        """Save the BM25 index, corpus and doc IDs to disk (best-effort)."""
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(BM25Service._index_path), exist_ok=True)

            # Save index
            with open(BM25Service._index_path, 'wb') as f:
                pickle.dump({
                    'bm25': BM25Service._bm25,
                    'corpus': BM25Service._corpus,
                    'doc_ids': BM25Service._doc_ids
                }, f)
        except Exception as e:
            # Best-effort persistence: an in-memory index still works even
            # if it cannot be saved.
            print(f"Error saving BM25 index: {e}")

    @staticmethod
    def load_index() -> None:
        """Load the BM25 index from disk, resetting state on failure.

        SECURITY NOTE: pickle.load executes arbitrary code from the file;
        the index path must never be writable by untrusted parties.
        """
        try:
            if os.path.exists(BM25Service._index_path):
                with open(BM25Service._index_path, 'rb') as f:
                    data = pickle.load(f)
                    BM25Service._bm25 = data.get('bm25')
                    BM25Service._corpus = data.get('corpus', [])
                    BM25Service._doc_ids = data.get('doc_ids', [])
        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            # Initialize with empty index so callers see a consistent state.
            BM25Service._bm25 = None
            BM25Service._corpus = []
            BM25Service._doc_ids = []