zulip_bot/app/utils/contextual_retrieval/reranker_service.py

"""
Reranker Service for improving search results by reranking candidate documents.
This service uses a custom reranking approach combining multiple signals
to improve the relevance of search results.
"""
import re
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
import logging
# Set up logging
logger = logging.getLogger("reranker_service")
class RerankerService:
    """Service for reranking search results using a custom approach."""

    # Cache for reranked results
    _rerank_cache = {}

    @staticmethod
    def rerank(query: str, documents: List[Dict], top_k: int = 20) -> List[Dict]:
        """
        Rerank documents based on relevance to the query using a multi-factor approach.

        Args:
            query (str): Query text
            documents (List[Dict]): List of document dictionaries with 'id' and 'content'
            top_k (int): Number of results to return

        Returns:
            List[Dict]: Reranked documents
        """
        # Return all documents if there are no more than top_k of them
        if len(documents) <= top_k:
            return documents

        # Create cache key
        cache_key = f"{query}_{sorted([doc.get('id', '') for doc in documents])}"
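        # Illustrative example (hypothetical values, not from the original code):
        # for query "project deadline" and documents with ids "doc1", "doc7", "doc2",
        # the key is "project deadline_['doc1', 'doc2', 'doc7']".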

        # Check if we have this reranking cached
        if cache_key in RerankerService._rerank_cache:
            return RerankerService._rerank_cache[cache_key][:top_k]

        try:
            # Prepare query
            query_terms = RerankerService._tokenize(query)
            query_lower = query.lower()

            # Calculate multi-factor relevance score for each document
            scored_docs = []
            for doc in documents:
                content = doc.get('content', '')
                content_lower = content.lower()

                # 1. Term frequency scoring (similar to BM25)
                term_score = RerankerService._calculate_term_score(content_lower, query_terms)
                # 2. Exact phrase matching
                phrase_score = RerankerService._calculate_phrase_score(content_lower, query_lower)
                # 3. Semantic similarity (use existing score if available)
                semantic_score = RerankerService._get_semantic_score(doc)
                # 4. Document position bonus
                position_score = RerankerService._calculate_position_score(content_lower, query_terms)
                # 5. Document length normalization
                length_factor = RerankerService._calculate_length_factor(content)

                # Calculate final combined score
                # Weights can be adjusted based on performance
                final_score = (
                    0.35 * term_score +
                    0.30 * phrase_score +
                    0.25 * semantic_score +
                    0.10 * position_score
                ) * length_factor
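                # Worked example with hypothetical scores (not taken from the original code):
                # term_score=0.02, phrase_score=2.0, semantic_score=0.8,
                # position_score=1.0 and length_factor=1.1 give
                # (0.35*0.02 + 0.30*2.0 + 0.25*0.8 + 0.10*1.0) * 1.1 ≈ 0.998.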

                scored_doc = doc.copy()
                scored_doc['score'] = final_score
                scored_doc['_term_score'] = term_score
                scored_doc['_phrase_score'] = phrase_score
                scored_doc['_semantic_score'] = semantic_score
                scored_doc['_position_score'] = position_score
                scored_docs.append(scored_doc)

            # Sort by final score (highest first)
            scored_docs.sort(key=lambda x: x.get('score', 0), reverse=True)

            # Clean up diagnostic scores before caching and returning
            for doc in scored_docs:
                doc.pop('_term_score', None)
                doc.pop('_phrase_score', None)
                doc.pop('_semantic_score', None)
                doc.pop('_position_score', None)

            # Cache the full sorted list so a later call with a larger top_k
            # is not limited by an earlier, smaller top_k
            RerankerService._rerank_cache[cache_key] = scored_docs

            # Take the top_k
            return scored_docs[:top_k]
        except Exception as e:
            logger.error(f"Error reranking documents: {e}")
            # Fallback: simple sorting based on combined_score if available
            # (use sorted() so the caller's list is not mutated)
            return sorted(documents, key=lambda x: x.get('combined_score', 0), reverse=True)[:top_k]

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """
        Tokenize a string into terms.

        Args:
            text (str): Text to tokenize

        Returns:
            List[str]: List of tokens
        """
        # Simple tokenization: lowercase the text and extract word characters,
        # dropping punctuation
        tokens = re.findall(r'\b\w+\b', text.lower())
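        # Illustrative example (hypothetical input): "How's the Q3 report?"
        # tokenizes to ['how', 's', 'the', 'q3', 'report'].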
        return tokens

    @staticmethod
    def _calculate_term_score(content: str, query_terms: List[str]) -> float:
        """
        Calculate term frequency score.

        Args:
            content (str): Document content
            query_terms (List[str]): Query terms

        Returns:
            float: Term frequency score
        """
        score = 0
        content_tokens = RerankerService._tokenize(content)

        # Simple term frequency calculation
        for term in query_terms:
            term_count = content_tokens.count(term)
            score += term_count

        # Normalize by document length
        if len(content_tokens) > 0:
            score = score / len(content_tokens)
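        # Illustrative example (hypothetical values): for query terms
        # ['project', 'deadline'] and a 50-token document containing 'project'
        # twice and 'deadline' once, the score is (2 + 1) / 50 = 0.06.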
        return score

    @staticmethod
    def _calculate_phrase_score(content: str, query: str) -> float:
        """
        Calculate exact phrase matching score.

        Args:
            content (str): Document content
            query (str): Original query

        Returns:
            float: Phrase matching score
        """
        # Count exact matches of the query in the content
        exact_matches = content.count(query)

        # Score exact matches of the full query
        score = exact_matches * 2.0  # Higher weight for exact matches

        # Fall back to partial matching if there are no exact matches
        if exact_matches == 0 and len(query) > 5:
            # Generate overlapping 4-character query fragments (only for longer queries)
            query_parts = [query[i:i + 4] for i in range(0, len(query) - 3)]
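            # Illustrative example (hypothetical query): for "deadline" the
            # fragments are 'dead', 'eadl', 'adli', 'dlin', 'line'.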
            for part in query_parts:
                if len(part) >= 4:  # Only consider meaningful parts
                    score += 0.2 * content.count(part)

        return min(score, 10.0)  # Cap to avoid extremely high scores

    @staticmethod
    def _get_semantic_score(doc: Dict) -> float:
        """
        Extract semantic similarity score from a document.

        Args:
            doc (Dict): Document

        Returns:
            float: Semantic similarity score
        """
        # Use vector_score if available (from vector search)
        if 'vector_score' in doc:
            return doc['vector_score']

        # Use combined_score as fallback
        if 'combined_score' in doc:
            return doc['combined_score']

        return 0.5  # Default middle value if no scores are available

    @staticmethod
    def _calculate_position_score(content: str, query_terms: List[str]) -> float:
        """
        Calculate a score based on the position of matches in the document.
        Earlier matches often indicate higher relevance.

        Args:
            content (str): Document content
            query_terms (List[str]): Query terms

        Returns:
            float: Position score
        """
        score = 0

        # Check for terms in the first 20% of the document
        first_section = content[:int(len(content) * 0.2)]
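        # Illustrative example (hypothetical document): for a 1,000-character
        # document, only the first 200 characters are scanned here.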
        for term in query_terms:
            if term in first_section:
                score += 0.5

        return min(score, 1.0)  # Cap at a maximum of 1.0

    @staticmethod
    def _calculate_length_factor(content: str) -> float:
        """
        Calculate length normalization factor.
        Prevents extremely short documents from ranking too high.

        Args:
            content (str): Document content

        Returns:
            float: Length normalization factor
        """
        token_count = len(RerankerService._tokenize(content))

        # Penalize very short documents
        if token_count < 10:
            return 0.7

        # Slightly favor mid-sized documents
        if 20 <= token_count <= 300:
            return 1.1

        return 1.0  # Neutral factor for other documents
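

# Minimal usage sketch (illustrative only; the document dicts and their field
# values below are hypothetical and not taken from the original code or data).
if __name__ == "__main__":
    sample_docs = [
        {"id": "doc1", "content": "The project deadline is Friday.", "vector_score": 0.8},
        {"id": "doc2", "content": "Lunch menu for the week.", "vector_score": 0.3},
        {"id": "doc3", "content": "Deadline reminder: project review on Friday afternoon.", "vector_score": 0.7},
    ]
    # Request fewer results than there are candidates so the reranking path runs
    top = RerankerService.rerank("project deadline", sample_docs, top_k=2)
    for d in top:
        print(d["id"], round(d["score"], 3))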