160 lines
6.8 KiB
Python
160 lines
6.8 KiB
Python
"""
|
|
Hybrid Search Service that combines vector search and BM25 search.

This service implements hybrid search by combining results from vector-based
semantic search and BM25 lexical search using a weighted fusion of their
normalized scores, optionally followed by reranking.
|
|
"""
|
|
import numpy as np
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
from app.db.chroma_service import ChromaDBService
|
|
from app.utils.contextual_retrieval.bm25_service import BM25Service
|
|
from app.utils.contextual_retrieval.reranker_service import RerankerService
|
|
import logging
|
|
|
|
# Logger for this service; records are emitted under the "hybrid_search" name.
logger = logging.getLogger(name="hybrid_search")
|
|
|
|
class HybridSearchService:
    """Service for hybrid search combining vector search and BM25.

    Candidates from ChromaDB vector search and BM25 lexical search are merged by a
    weighted sum of their normalized scores (see ``_fuse_results``) and can then
    optionally be reranked with ``RerankerService``.
    """

    # Candidates fetched from each retriever per requested result, so fusion and
    # reranking have a wider pool than the final n_results to choose from.
    CANDIDATE_MULTIPLIER = 3

    @staticmethod
    def hybrid_search(query: str, n_results: int = 5, filter_criteria: Optional[Dict] = None,
                      rerank: bool = True, semantic_weight: float = 0.7) -> List[Dict]:
        """
        Perform hybrid search using vector search and BM25.

        Args:
            query (str): Query text
            n_results (int): Number of results to return; values <= 0 yield an empty list
            filter_criteria (Dict): Metadata filter criteria, passed to the vector search
            rerank (bool): Whether to apply reranking
            semantic_weight (float): Weight for semantic search (0-1); BM25 scores are
                weighted by 1 - semantic_weight

        Returns:
            List[Dict]: Search results, best first. Without reranking each dict carries
            'id', 'content', 'metadata', 'rank', 'combined_score' and whichever of
            'vector_score' / 'bm25_score' applied; with reranking the list is whatever
            RerankerService.rerank returns. Any unexpected error is logged and an empty
            list is returned.
        """
        try:
            if n_results <= 0:
                return []
            candidate_n = n_results * HybridSearchService.CANDIDATE_MULTIPLIER

            # _internal_call=True prevents circular calls back into hybrid search
            # (presumably search_similar can delegate here otherwise - confirm in
            # ChromaDBService).
            vector_results = ChromaDBService.search_similar(
                query_text=query,
                n_results=candidate_n,
                filter_criteria=filter_criteria,
                _internal_call=True,
            )
            vec_docs = HybridSearchService._extract_vector_docs(vector_results)

            # NOTE(review): filter_criteria is not applied to BM25 candidates, so
            # BM25-only hits may not satisfy the metadata filter - confirm whether
            # BM25Service can filter.
            bm25_results = BM25Service.search(query, top_k=candidate_n)
            bm25_docs = HybridSearchService._build_bm25_docs(bm25_results, vec_docs)

            fused_docs = HybridSearchService._fuse_results(vec_docs, bm25_docs, semantic_weight)

            if rerank and fused_docs:
                try:
                    return RerankerService.rerank(query, fused_docs, top_k=n_results)
                except Exception as e:
                    # Best effort: fall back to the fused ordering below.
                    logger.warning("Reranking failed: %s, returning non-reranked results", e)

            return fused_docs[:n_results]
        except Exception as e:
            logger.exception("Error in hybrid search: %s", e)
            # Return empty results on error
            return []

    @staticmethod
    def _extract_vector_docs(vector_results: Optional[Dict]) -> List[Dict]:
        """Convert a ChromaDB query response into a list of candidate dicts.

        Reads row 0 of the nested 'ids', 'documents', 'metadatas' and 'distances'
        lists (one row per query text; a single query is issued). Each distance d
        becomes vector_score = 1 - d clamped to [0, 1], matching the range of the
        normalized BM25 score. 'rank' is the 1-based position in the vector results.
        A falsy response or missing/empty 'documents' yields an empty list.
        """
        if not vector_results or not vector_results.get('documents'):
            return []

        rows = zip(
            vector_results['ids'][0],
            vector_results['documents'][0],
            vector_results['metadatas'][0],
            vector_results['distances'][0],
        )
        return [
            {
                'id': doc_id,
                'content': content,
                'metadata': metadata,
                'vector_score': min(1.0, max(0.0, 1.0 - distance)),
                'rank': rank,  # 1-based rank
            }
            for rank, (doc_id, content, metadata, distance) in enumerate(rows, start=1)
        ]

    @staticmethod
    def _build_bm25_docs(bm25_results: List[Tuple], vec_docs: List[Dict]) -> List[Dict]:
        """Turn raw BM25 (doc_id, score) pairs into candidate dicts.

        Scores are divided by the maximum BM25 score, so the best hit gets 1.0; all
        scores become 0 when that maximum is not positive. Documents already found by
        the vector search reuse its content and metadata instead of being fetched again
        from ChromaDB, but they are still returned so their BM25 score contributes to
        fusion. Documents that cannot be fetched are logged and skipped. 'rank' is the
        1-based position in the raw BM25 results.
        """
        if not bm25_results:
            return []

        max_score = max(score for _, score in bm25_results)
        known_docs = {doc['id']: doc for doc in vec_docs}

        bm25_docs = []
        for rank, (doc_id, score) in enumerate(bm25_results, start=1):
            known = known_docs.get(doc_id)
            if known is not None:
                # Already have the content from the vector search: no lookup needed.
                content, metadata = known['content'], known['metadata']
            else:
                try:
                    doc_data = ChromaDBService.get_message_by_id(doc_id)
                    if not doc_data:
                        continue
                    content, metadata = doc_data['content'], doc_data['metadata']
                except Exception as e:
                    logger.warning("Error retrieving document %s: %s", doc_id, e)
                    continue

            bm25_docs.append({
                'id': doc_id,
                'content': content,
                'metadata': metadata,
                'bm25_score': score / max_score if max_score > 0 else 0,
                'rank': rank,  # 1-based rank
            })
        return bm25_docs

    @staticmethod
    def _fuse_results(vec_docs: List[Dict], bm25_docs: List[Dict],
                      semantic_weight: float = 0.7) -> List[Dict]:
        """
        Fuse results from vector search and BM25 search.

        This is a weighted linear combination of normalized scores, not reciprocal
        rank fusion:
            combined_score = semantic_weight * vector_score
                             + (1 - semantic_weight) * bm25_score
        where a score missing from a document counts as 0. Documents are merged by
        'id'; the first occurrence (vector results are processed first) supplies the
        'content', 'metadata' and 'rank' fields. The input dicts are shallow-copied,
        never mutated.

        Args:
            vec_docs (List[Dict]): Vector search results
            bm25_docs (List[Dict]): BM25 search results
            semantic_weight (float): Weight for semantic search (0-1)

        Returns:
            List[Dict]: Fused search results sorted by 'combined_score', highest first
        """
        doc_map: Dict = {}
        sources = (
            (vec_docs, 'vector_score', semantic_weight),
            (bm25_docs, 'bm25_score', 1 - semantic_weight),
        )
        for docs, score_key, weight in sources:
            for doc in docs:
                score = doc.get(score_key, 0)
                fused = doc_map.get(doc['id'])
                if fused is None:
                    fused = doc.copy()
                    fused['combined_score'] = 0
                    doc_map[doc['id']] = fused
                else:
                    fused[score_key] = score
                fused['combined_score'] += score * weight

        return sorted(doc_map.values(), key=lambda d: d.get('combined_score', 0), reverse=True)