418 lines
17 KiB
Python
418 lines
17 KiB
Python
"""
|
|
Service for storing and retrieving embedded messages in ChromaDB.
|
|
"""
|
|
import json
|
|
from datetime import datetime
|
|
from typing import List, Dict, Any, Optional, Union
|
|
import chromadb
|
|
from chromadb.utils import embedding_functions
|
|
from app.db import get_chroma_collection
|
|
from app.utils.embeddings import EmbeddingService
|
|
from app.utils.contextual_retrieval.context_service import ContextService
|
|
from app.utils.contextual_retrieval.bm25_service import BM25Service
|
|
from app.config import Config
|
|
import logging
|
|
|
|
# Module-level logger, registered under the fixed name "chroma_service"
# (rather than __name__), so logging configuration must target that name.
logger = logging.getLogger(name="chroma_service")
|
|
|
|
class CustomEmbeddingFunction(embedding_functions.EmbeddingFunction):
    """ChromaDB embedding function that delegates to the project's EmbeddingService.

    The embedding backend (Nomic or Ollama) is chosen once, at construction time.
    """

    def __init__(self, use_nomic: bool = True):
        """
        Create the embedding function.

        Args:
            use_nomic: True to embed with Nomic, False to embed with Ollama
        """
        self.use_nomic = use_nomic

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        Args:
            texts: Strings to embed

        Returns:
            One float vector per input text, as produced by
            EmbeddingService.get_embeddings
        """
        backend_is_nomic = self.use_nomic
        vectors = EmbeddingService.get_embeddings(texts, use_nomic=backend_is_nomic)
        return vectors
|
|
|
|
class ChromaDBService:
    """Service for storing and retrieving embedded messages in ChromaDB.

    All methods are static. The only class-level state is ``_embedding_function``.
    It embeds documents in ``add_message`` and is also the first choice for
    embedding queries in vector-only search, so documents and queries are
    embedded the same way.
    """

    # Use Ollama embeddings by default for reliability; replaced by
    # switch_embedding_method().
    _embedding_function = CustomEmbeddingFunction(use_nomic=False)

    @staticmethod
    def _empty_results():
        """Return an empty result set shaped like a single-query ChromaDB ``query`` result."""
        return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}

    @staticmethod
    def format_message_content(content, channel_name, subject, sender_name, date_sent):
        """
        Format message content with a one-line metadata header (no contextual enrichment).

        Args:
            content (str): Original message content
            channel_name (str): Name of the channel
            subject (str): Subject of the message
            sender_name (str): Name of the sender
            date_sent (datetime | None): Date the message was sent

        Returns:
            str: "Channel: ... | Subject: ... | Sent by: ... | Date: ...", a blank line,
            then the content. Missing values are replaced by placeholders.
        """
        # add_message() accepts a missing date_sent (see its timestamp fallback),
        # so never call strftime on None.
        date_str = date_sent.strftime("%Y-%m-%d %H:%M:%S") if date_sent else "Unknown Date"

        # Replace None values with placeholders
        content = content or ""
        channel_name = channel_name or "Unknown Channel"
        subject = subject or "No Subject"
        sender_name = sender_name or "Unknown Sender"

        # Return plain content with minimal metadata prefix
        return f"Channel: {channel_name} | Subject: {subject} | Sent by: {sender_name} | Date: {date_str}\n\n{content}"

    @staticmethod
    def sanitize_metadata(metadata):
        """
        Replace None values in a metadata dict with placeholder values.

        Args:
            metadata (dict): Metadata dictionary

        Returns:
            dict: A new dict with the same keys. "channel", "subject" and "sender"
            get descriptive placeholders. "timestamp" gets the current time in ISO
            format. Any other key gets "". Non-None values are kept unchanged.
        """
        placeholders = {
            "channel": "Unknown Channel",
            "subject": "No Subject",
            "sender": "Unknown Sender",
        }
        sanitized = {}
        for key, value in metadata.items():
            if value is not None:
                sanitized[key] = value
            elif key == "timestamp":
                # Evaluated per call so the fallback is the current time.
                sanitized[key] = datetime.now().isoformat()
            else:
                sanitized[key] = placeholders.get(key, "")
        return sanitized

    @staticmethod
    def add_message(message_id, content, channel_name, subject, sender_name, date_sent):
        """
        Add a message to the ChromaDB collection and to the BM25 index.

        Args:
            message_id (str): ID of the message
            content (str): Content of the message
            channel_name (str): Name of the channel
            subject (str): Subject of the message
            sender_name (str): Name of the sender
            date_sent (datetime | None): Date the message was sent

        Returns:
            bool: True if stored (or already present), False if any error occurred
        """
        try:
            # Check if message already exists to avoid duplicates
            if ChromaDBService.message_exists(message_id):
                logger.info(f"Message ID {message_id} already exists in ChromaDB, skipping")
                return True

            collection = get_chroma_collection()
            doc_id = str(message_id)

            # Build metadata and replace None values, which ChromaDB metadata does not accept
            metadata = ChromaDBService.sanitize_metadata({
                "channel": channel_name,
                "subject": subject,
                "sender": sender_name,
                "timestamp": date_sent.isoformat() if date_sent else datetime.now().isoformat(),
                "source": "zulip",
            })

            # Embed the header+content text so metadata also contributes to similarity
            formatted_content = ChromaDBService.format_message_content(
                content, channel_name, subject, sender_name, date_sent
            )
            embeddings = ChromaDBService._embedding_function([formatted_content])

            collection.add(
                ids=[doc_id],
                documents=[formatted_content],
                metadatas=[metadata],
                # When no embeddings are passed, ChromaDB embeds `documents` itself
                # using the collection's embedding function.
                embeddings=embeddings if embeddings else None,
            )

            # Also add to BM25 index for hybrid search
            BM25Service.add_document(formatted_content, doc_id)

            logger.info(f"Successfully added message ID {message_id} to ChromaDB")
            return True
        except Exception as e:
            logger.error(f"Error adding message to ChromaDB: {e}")
            return False

    @staticmethod
    def _query_by_embedding(query_embedding, n_results, filter_criteria):
        """Query the persisted collection with a precomputed embedding via a separate client.

        No embedding_function is passed to get_collection. That is fine because the
        query supplies query_embeddings directly.
        NOTE(review): chromadb may share one underlying system per persistence path
        and can reject a second client created with different settings. Confirm these
        settings match the ones used by get_chroma_collection().
        """
        temp_client = chromadb.PersistentClient(
            path=Config.CHROMADB_PATH,
            settings=chromadb.Settings(
                anonymized_telemetry=False,
                is_persistent=True,
                allow_reset=False,
            ),
        )
        temp_collection = temp_client.get_collection(
            name=Config.CHROMADB_COLLECTION or "zulip_messages"
        )
        return temp_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=filter_criteria,
            include=["metadatas", "documents", "distances"],
        )

    @staticmethod
    def _vector_search(collection, query_text, n_results, filter_criteria):
        """Run vector-only search with fallbacks and return a ChromaDB ``query``-style dict.

        Tries each query embedder in turn. If every attempt fails, returns up to
        n_results stored documents, unranked, with a placeholder distance of 1.0.
        Exceptions from the last-resort `collection.get` propagate to the caller.
        """
        embedders = (
            # Same function add_message() uses, so queries are embedded like documents
            # (respects switch_embedding_method()).
            ("configured embedding function", ChromaDBService._embedding_function),
            # Direct Ollama call as a second attempt.
            ("Ollama embeddings", EmbeddingService.get_ollama_embeddings),
        )
        for label, embed in embedders:
            try:
                query_embedding = embed([query_text])[0]
                return ChromaDBService._query_by_embedding(query_embedding, n_results, filter_criteria)
            except Exception as e:
                logger.error(f"Error with vector search using {label}: {e}")

        # Last resort: no query embedding could be used. Fetch only what is returned
        # (limit, no embeddings). No text matching is performed; results are unranked.
        logger.info("Falling back to unranked document listing")
        all_docs = collection.get(
            where=filter_criteria,
            limit=n_results,
            include=["metadatas", "documents"],
        )
        if not all_docs or not all_docs.get("ids"):
            return ChromaDBService._empty_results()
        ids = all_docs["ids"][:n_results]
        return {
            "ids": [ids],
            "documents": [all_docs["documents"][:n_results]],
            "metadatas": [all_docs["metadatas"][:n_results]],
            "distances": [[1.0] * len(ids)],
        }

    @staticmethod
    def search_similar(query_text, n_results=5, filter_criteria=None, use_hybrid=True, _internal_call=False):
        """
        Search for messages similar to a query, using hybrid (BM25 + vector) or vector-only search.

        Args:
            query_text (str): Text to search for
            n_results (int): Number of results to return
            filter_criteria (dict): Metadata filter criteria (ChromaDB ``where``)
            use_hybrid (bool): Whether to use hybrid search or just vector search
            _internal_call (bool): Internal parameter to prevent circular calls

        Returns:
            dict: ChromaDB ``query``-style result. Each of "ids", "documents",
            "metadatas" and "distances" is a list holding one list. For hybrid
            results, distance is ``1 - combined_score``. On error, an empty
            result set is returned (never None).
        """
        try:
            logger.info("Using temporary ChromaDB client to prevent duplicate embeddings")
            collection = get_chroma_collection()

            # If hybrid search is disabled, or this is an internal call from
            # HybridSearchService, use vector-only search to prevent circular calls.
            if not use_hybrid or _internal_call:
                return ChromaDBService._vector_search(collection, query_text, n_results, filter_criteria)

            # Imported here rather than at module level to avoid circular imports.
            try:
                from app.utils.contextual_retrieval.hybrid_search import HybridSearchService

                results = HybridSearchService.hybrid_search(
                    query=query_text,
                    n_results=n_results,
                    filter_criteria=filter_criteria,
                    rerank=True,  # Enable reranking
                )
            except ImportError:
                logger.warning("Hybrid search module not available, falling back to vector search")
                return ChromaDBService._vector_search(collection, query_text, n_results, filter_criteria)

            # Convert to ChromaDB query result format
            return {
                "ids": [[doc["id"] for doc in results]],
                "documents": [[doc["content"] for doc in results]],
                "metadatas": [[doc.get("metadata", {}) for doc in results]],
                "distances": [[1.0 - doc.get("combined_score", 0) for doc in results]],
            }
        except Exception as e:
            logger.error(f"Error searching ChromaDB: {e}")
            # Return an empty result set rather than None
            return ChromaDBService._empty_results()

    @staticmethod
    def delete_message(message_id):
        """
        Delete a message from ChromaDB and rebuild the BM25 index from what remains.

        Args:
            message_id (str): ID of the message to delete

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            collection = get_chroma_collection()
            collection.delete(ids=[str(message_id)])

            # Rebuild the BM25 index from ChromaDB. Simple, but it re-reads the whole
            # collection on every delete; a production setup may want an incremental removal.
            remaining = collection.get(include=["documents"])
            if remaining and remaining["ids"]:
                BM25Service.index_documents(remaining["documents"], remaining["ids"])
            else:
                # The collection is now empty. Without a rebuild, the previous BM25
                # index (which may still contain the deleted message) would be left
                # in place. Best-effort: BM25Service may not support an empty corpus,
                # and the ChromaDB delete has already succeeded.
                try:
                    BM25Service.index_documents([], [])
                except Exception as e:
                    logger.warning(f"Could not clear BM25 index after deleting last message: {e}")

            return True
        except Exception as e:
            logger.error(f"Error deleting message from ChromaDB: {e}")
            return False

    @staticmethod
    def get_message_by_id(message_id):
        """
        Get a message from ChromaDB by ID.

        Args:
            message_id (str): ID of the message to retrieve

        Returns:
            dict | None: {"id", "content", "metadata"} for the message, or None if it
            is not found or an error occurred
        """
        try:
            collection = get_chroma_collection()
            result = collection.get(ids=[str(message_id)])

            if result["ids"]:
                return {
                    "id": result["ids"][0],
                    "content": result["documents"][0],
                    "metadata": result["metadatas"][0],
                }
            return None
        except RecursionError:
            logger.error(f"Recursion error when getting message ID {message_id} from ChromaDB")
            return None
        except Exception as e:
            logger.error(f"Error getting message from ChromaDB: {e}")
            return None

    @staticmethod
    def message_exists(message_id):
        """
        Check if a message exists in ChromaDB.

        Args:
            message_id (str): ID of the message to check

        Returns:
            bool: True if it exists, False if not found or an error occurred
        """
        try:
            collection = get_chroma_collection()
            # include=[] fetches only IDs, which is all an existence check needs.
            result = collection.get(ids=[str(message_id)], include=[])
            return len(result["ids"]) > 0
        except Exception as e:
            logger.error(f"Error checking if message exists in ChromaDB: {e}")
            return False

    @staticmethod
    def switch_embedding_method(use_nomic: bool):
        """
        Switch between Nomic and Ollama embeddings.

        This affects documents embedded by add_message() and the first query
        embedding attempt in vector-only search.

        Args:
            use_nomic: Whether to use Nomic (True) or Ollama (False)
        """
        ChromaDBService._embedding_function = CustomEmbeddingFunction(use_nomic=use_nomic)