2025-05-16 18:00:22 +04:00

160 lines
6.8 KiB
Python

"""
Hybrid Search Service that combines vector search and BM25 search.
This service implements hybrid search by combining results from vector-based
semantic search and BM25 lexical search using weighted score fusion.
"""
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
from app.db.chroma_service import ChromaDBService
from app.utils.contextual_retrieval.bm25_service import BM25Service
from app.utils.contextual_retrieval.reranker_service import RerankerService
import logging
# Module-level logger used by HybridSearchService; it is registered under the
# fixed name "hybrid_search" rather than the module's import path.
logger = logging.getLogger("hybrid_search")
class HybridSearchService:
    """Service for hybrid search combining vector search and BM25.

    Candidates are collected from both retrievers, their scores are
    normalized to a comparable range, and the final ordering is a weighted
    sum of the two scores (optionally followed by a reranking pass).
    """

    @staticmethod
    def hybrid_search(query: str, n_results: int = 5, filter_criteria: Optional[Dict] = None,
                      rerank: bool = True, semantic_weight: float = 0.7) -> List[Dict]:
        """
        Perform hybrid search using vector search and BM25.

        Args:
            query (str): Query text
            n_results (int): Number of results to return; values <= 0 yield
                an empty list
            filter_criteria (Dict): Metadata filter criteria. Only the vector
                search receives this filter; BM25 is queried with the text alone.
            rerank (bool): Whether to apply reranking
            semantic_weight (float): Weight for semantic search (0-1); the
                BM25 score is weighted by ``1 - semantic_weight``

        Returns:
            List[Dict]: Search results, best first. An empty list is returned
            if any unexpected error occurs (the error is logged).
        """
        # A non-positive count can never produce results; bail out before
        # touching any backend (a negative value would also mis-slice below).
        if n_results <= 0:
            return []

        try:
            # Over-fetch from both retrievers so fusion has enough candidates.
            candidate_n = n_results * 3

            # _internal_call=True is passed through as in the original code,
            # whose note says it prevents circular calls back into this service.
            vector_results = ChromaDBService.search_similar(
                query_text=query,
                n_results=candidate_n,
                filter_criteria=filter_criteria,
                _internal_call=True,
            )
            vec_docs = HybridSearchService._extract_vector_docs(vector_results)

            bm25_results = BM25Service.search(query, top_k=candidate_n)
            bm25_docs = HybridSearchService._extract_bm25_docs(bm25_results, vec_docs)

            # Combine results using a weighted sum of the normalized scores.
            fused_docs = HybridSearchService._fuse_results(vec_docs, bm25_docs, semantic_weight)

            if rerank and fused_docs:
                try:
                    return RerankerService.rerank(query, fused_docs, top_k=n_results)
                except Exception as e:
                    # Reranking is best-effort: fall back to the fused order.
                    logger.warning("Reranking failed: %s, returning non-reranked results", e)

            return fused_docs[:n_results]
        except Exception as e:
            logger.error("Error in hybrid search: %s", e)
            # Return empty results on error
            return []

    @staticmethod
    def _extract_vector_docs(vector_results: Optional[Dict]) -> List[Dict]:
        """Convert a vector-search response into a list of document dicts.

        The response is read as parallel lists under 'ids', 'documents',
        'metadatas' and 'distances', using the first entry of each (the
        results for the single query issued).

        Args:
            vector_results (Dict): Raw response from the vector search, or None

        Returns:
            List[Dict]: One dict per hit with 'id', 'content', 'metadata',
            'vector_score' and a 1-based 'rank'
        """
        if not vector_results or not vector_results.get('documents'):
            return []

        rows = zip(
            vector_results['ids'][0],
            vector_results['documents'][0],
            vector_results['metadatas'][0],
            vector_results['distances'][0],
        )
        return [
            {
                'id': doc_id,
                'content': content,
                'metadata': metadata,
                # Distance -> similarity; distances above 1.0 are capped so
                # the score never drops below 0.
                'vector_score': 1.0 - min(distance, 1.0),
                'rank': rank,  # 1-based rank
            }
            for rank, (doc_id, content, metadata, distance) in enumerate(rows, start=1)
        ]

    @staticmethod
    def _extract_bm25_docs(bm25_results: Optional[List[Tuple]], vec_docs: List[Dict],
                           fetch_document=None) -> List[Dict]:
        """Convert BM25 hits into document dicts with max-normalized scores.

        Hits that also appear in ``vec_docs`` reuse the content and metadata
        already fetched by the vector search, so no extra lookup is needed,
        but they are still emitted so their BM25 score takes part in fusion.

        Args:
            bm25_results (List[Tuple]): ``(doc_id, score)`` pairs, best first
            vec_docs (List[Dict]): Documents returned by the vector search
            fetch_document: Callable mapping a doc id to a dict with
                'content' and 'metadata' (or a falsy value when missing).
                Defaults to ``ChromaDBService.get_message_by_id``.

        Returns:
            List[Dict]: One dict per resolvable hit with 'id', 'content',
            'metadata', 'bm25_score' and a 1-based 'rank'
        """
        if not bm25_results:
            return []
        if fetch_document is None:
            fetch_document = ChromaDBService.get_message_by_id

        # Normalize against the best score so values are comparable with
        # the vector scores.
        max_score = max(score for _, score in bm25_results)
        known_docs = {doc['id']: doc for doc in vec_docs}

        bm25_docs = []
        for rank, (doc_id, score) in enumerate(bm25_results, start=1):
            known = known_docs.get(doc_id)
            if known is not None:
                content, metadata = known['content'], known['metadata']
            else:
                # Retrieval is best-effort: a failing or missing document is
                # skipped rather than aborting the whole search.
                try:
                    doc_data = fetch_document(doc_id)
                    if not doc_data:
                        continue
                    content, metadata = doc_data['content'], doc_data['metadata']
                except Exception as e:
                    logger.warning("Error retrieving document %s: %s", doc_id, e)
                    continue

            bm25_docs.append({
                'id': doc_id,
                'content': content,
                'metadata': metadata,
                'bm25_score': score / max_score if max_score > 0 else 0,
                'rank': rank,  # 1-based rank
            })
        return bm25_docs

    @staticmethod
    def _fuse_results(vec_docs: List[Dict], bm25_docs: List[Dict],
                      semantic_weight: float = 0.7) -> List[Dict]:
        """
        Fuse results from vector search and BM25 search.

        Each document's 'combined_score' is
        ``vector_score * semantic_weight + bm25_score * (1 - semantic_weight)``,
        where a score missing from one source contributes 0.

        Args:
            vec_docs (List[Dict]): Vector search results
            bm25_docs (List[Dict]): BM25 search results
            semantic_weight (float): Weight for semantic search (0-1)

        Returns:
            List[Dict]: Fused search results sorted by 'combined_score',
            highest first. Input dicts are copied, not mutated.
        """
        sources = (
            (vec_docs, 'vector_score', semantic_weight),
            (bm25_docs, 'bm25_score', 1 - semantic_weight),
        )

        doc_map = {}
        for docs, score_key, weight in sources:
            for doc in docs:
                doc_id = doc['id']
                score = doc.get(score_key, 0)
                entry = doc_map.get(doc_id)
                if entry is None:
                    # First sighting: copy so the caller's dict stays untouched.
                    entry = doc_map[doc_id] = doc.copy()
                    entry['combined_score'] = score * weight
                else:
                    # Already seen (other source or a repeated id): record
                    # this source's score and add its weighted share.
                    entry[score_key] = score
                    entry['combined_score'] = entry.get('combined_score', 0) + score * weight

        return sorted(doc_map.values(), key=lambda d: d.get('combined_score', 0), reverse=True)