"""
|
|
Reranker Service for improving search results by reranking candidate documents.
|
|
|
|
This service uses a custom reranking approach combining multiple signals
|
|
to improve the relevance of search results.
|
|
"""
|
|
import logging
import re
from typing import Dict, List

# Set up logging
logger = logging.getLogger("reranker_service")


class RerankerService:
    """Service for reranking search results using a custom approach."""

    # Cache for reranked results
    _rerank_cache = {}

    @staticmethod
    def rerank(query: str, documents: List[Dict], top_k: int = 20) -> List[Dict]:
        """
        Rerank documents based on relevance to query using a multi-factor approach.

        Args:
            query (str): Query text
            documents (List[Dict]): List of document dictionaries with 'id' and 'content'
            top_k (int): Number of results to return

        Returns:
            List[Dict]: Reranked documents
        """
        # Return all documents if there are fewer than top_k
        if len(documents) <= top_k:
            return documents

        # Create cache key
        cache_key = f"{query}_{sorted([doc.get('id', '') for doc in documents])}"
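        # e.g. "neural search_['doc1', 'doc2', 'doc7']" (illustrative ids)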

        # Check if we have this reranking cached
        if cache_key in RerankerService._rerank_cache:
            return RerankerService._rerank_cache[cache_key][:top_k]

        try:
            # Prepare query
            query_terms = RerankerService._tokenize(query)
            query_lower = query.lower()

            # Calculate multi-factor relevance score for each document
            scored_docs = []
            for doc in documents:
                content = doc.get('content', '')
                content_lower = content.lower()

                # 1. Term frequency scoring (similar to BM25)
                term_score = RerankerService._calculate_term_score(content_lower, query_terms)

                # 2. Exact phrase matching
                phrase_score = RerankerService._calculate_phrase_score(content_lower, query_lower)

                # 3. Semantic similarity (use existing score if available)
                semantic_score = RerankerService._get_semantic_score(doc)

                # 4. Document position bonus
                position_score = RerankerService._calculate_position_score(content_lower, query_terms)

                # 5. Document length normalization
                length_factor = RerankerService._calculate_length_factor(content)

                # Calculate final combined score
                # Weights can be adjusted based on performance
                final_score = (
                    0.35 * term_score +
                    0.30 * phrase_score +
                    0.25 * semantic_score +
                    0.10 * position_score
                ) * length_factor
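
                # Worked example (hypothetical numbers): term_score=0.4,
                # phrase_score=2.0, semantic_score=0.8, position_score=0.5,
                # length_factor=1.1 gives
                # (0.35*0.4 + 0.30*2.0 + 0.25*0.8 + 0.10*0.5) * 1.1 = 0.99 * 1.1 = 1.089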

                scored_doc = doc.copy()
                scored_doc['score'] = final_score
                scored_doc['_term_score'] = term_score
                scored_doc['_phrase_score'] = phrase_score
                scored_doc['_semantic_score'] = semantic_score
                scored_doc['_position_score'] = position_score

                scored_docs.append(scored_doc)

            # Sort by final score (highest first)
            scored_docs.sort(key=lambda x: x.get('score', 0), reverse=True)

            # Take the top_k
            result = scored_docs[:top_k]

            # Clean up diagnostic scores before returning
            for doc in result:
                doc.pop('_term_score', None)
                doc.pop('_phrase_score', None)
                doc.pop('_semantic_score', None)
                doc.pop('_position_score', None)

            # Cache the results
            RerankerService._rerank_cache[cache_key] = result

            return result

        except Exception as e:
            logger.error(f"Error reranking documents: {e}")

            # Fallback: sort a copy by combined_score if available,
            # without mutating the caller's list
            return sorted(documents, key=lambda x: x.get('combined_score', 0), reverse=True)[:top_k]

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """
        Tokenize a string into terms.

        Args:
            text (str): Text to tokenize

        Returns:
            List[str]: List of tokens
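
        Example (illustrative):
            >>> RerankerService._tokenize("Hello, World!")
            ['hello', 'world']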
        """
        # Simple tokenization: extract word characters, dropping punctuation
        tokens = re.findall(r'\b\w+\b', text.lower())
        return tokens

    @staticmethod
    def _calculate_term_score(content: str, query_terms: List[str]) -> float:
        """
        Calculate term frequency score.

        Args:
            content (str): Document content
            query_terms (List[str]): Query terms

        Returns:
            float: Term frequency score
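
        Example (illustrative):
            >>> round(RerankerService._calculate_term_score(
            ...     "the cat sat on the mat", ["cat", "mat"]), 3)
            0.333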
        """
        score = 0.0
        content_tokens = RerankerService._tokenize(content)

        # Simple term frequency calculation
        for term in query_terms:
            score += content_tokens.count(term)

        # Normalize by document length
        if len(content_tokens) > 0:
            score = score / len(content_tokens)

        return score

    @staticmethod
    def _calculate_phrase_score(content: str, query: str) -> float:
        """
        Calculate exact phrase matching score.

        Args:
            content (str): Document content
            query (str): Original query

        Returns:
            float: Phrase matching score
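
        Example (illustrative):
            >>> RerankerService._calculate_phrase_score(
            ...     "machine learning improves search", "machine learning")
            2.0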
        """
        # Count exact matches of the query in the content
        exact_matches = content.count(query)

        # Weight exact phrase matches heavily
        score = exact_matches * 2.0

        # Fall back to partial matching if there are no exact matches
        if exact_matches == 0 and len(query) > 5:
            # Generate character 4-grams of the query (only for longer queries)
            query_parts = [query[i:i + 4] for i in range(len(query) - 3)]
            for part in query_parts:
                score += 0.2 * content.count(part)

        return min(score, 10.0)  # Cap to avoid extremely high scores

    @staticmethod
    def _get_semantic_score(doc: Dict) -> float:
        """
        Extract semantic similarity score from document.

        Args:
            doc (Dict): Document

        Returns:
            float: Semantic similarity score
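
        Example (illustrative):
            >>> RerankerService._get_semantic_score({'vector_score': 0.8})
            0.8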
        """
        # Use vector_score if available (from vector search)
        if 'vector_score' in doc:
            return doc['vector_score']

        # Use combined_score as fallback
        if 'combined_score' in doc:
            return doc['combined_score']

        return 0.5  # Default middle value if no scores available

    @staticmethod
    def _calculate_position_score(content: str, query_terms: List[str]) -> float:
        """
        Calculate score based on position of match in document.
        Earlier matches often indicate higher relevance.

        Args:
            content (str): Document content
            query_terms (List[str]): Query terms

        Returns:
            float: Position score
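
        Example (illustrative):
            >>> RerankerService._calculate_position_score(
            ...     "python tutorial " + "x" * 100, ["python"])
            0.5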
        """
        score = 0.0
        # Check for terms in the first 20% of the document
        first_section = content[:int(len(content) * 0.2)]

        for term in query_terms:
            if term in first_section:
                score += 0.5

        return min(score, 1.0)  # Cap at a maximum of 1.0

    @staticmethod
    def _calculate_length_factor(content: str) -> float:
        """
        Calculate length normalization factor.
        Prevents extremely short documents from ranking too high.

        Args:
            content (str): Document content

        Returns:
            float: Length normalization factor
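
        Example (illustrative):
            >>> RerankerService._calculate_length_factor("too short")
            0.7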
        """
        token_count = len(RerankerService._tokenize(content))

        # Penalize very short documents
        if token_count < 10:
            return 0.7

        # Slightly favor mid-sized documents
        if 20 <= token_count <= 300:
            return 1.1

        return 1.0  # Neutral factor for other documents
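

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical documents, not part of the service):
# rerank three candidates for a query and keep the top two.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    candidates = [
        {"id": "a", "content": "Reranking improves search relevance."},
        {"id": "b", "content": "Unrelated text about cooking pasta."},
        {"id": "c", "content": "Search results benefit from reranking signals."},
    ]
    for doc in RerankerService.rerank("reranking search", candidates, top_k=2):
        print(doc["id"], round(doc["score"], 4))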