580 lines
25 KiB
Python
Executable File
580 lines
25 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""
|
|
Script to sync messages from all Zulip channels (except sandbox) to ChromaDB.
|
|
This script also excludes messages from IT_Bot and ai_bot users.
|
|
"""
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import logging
|
|
import signal
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
import pickle
|
|
|
|
# Add the current directory to the path so we can import the app module
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Apply NumPy compatibility patch for ChromaDB.
# This runs before app.db.chroma_service is imported below, so ChromaDB
# loads against the patched NumPy from the start.
from app.utils import patch_chromadb_numpy
patch_chromadb_numpy()
|
|
|
|
from app import create_app
|
|
from app.db.zulip_service import ZulipDatabaseService
|
|
from app.db.chroma_service import ChromaDBService
|
|
from app.models.zulip import Message, Stream, Recipient, UserProfile
|
|
from sqlalchemy import and_, not_, or_
|
|
from app.db import get_db_session
|
|
|
|
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("sync_all_channels")

# Global flag for graceful shutdown; checked between batches and between
# messages so the current operation can finish and state can be saved
# before the process exits.
is_shutting_down = False


# Signal handler for CTRL+C / SIGTERM
def signal_handler(sig, frame):
    """Request a graceful shutdown.

    Sets the module-level flag instead of exiting immediately so the
    in-flight sync batch can complete and the resume state is persisted.
    """
    global is_shutting_down
    logger.info("Received shutdown signal, completing current operation before exiting...")
    is_shutting_down = True


# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
|
|
|
class AllChannelSyncService:
    """Service for syncing messages from all channels except sandbox.

    Sync progress (last sync time and highest synced message ID) is persisted
    to a pickle file next to this script so an interrupted run can resume from
    where it stopped instead of re-scanning the whole lookback window.

    NOTE(review): all timestamps use naive ``datetime.now()`` — assumes
    ``Message.date_sent`` is naive in the same timezone; confirm.
    """

    # File to store the last synced message ID (kept as pickle for backward
    # compatibility with state files written by earlier runs).
    _SYNC_STATE_FILE = "all_channels_sync_state.pickle"

    def __init__(self, batch_size=200, include_direct_messages=False):
        """
        Initialize the sync service.

        Args:
            batch_size (int): Number of messages to process in each batch
            include_direct_messages (bool): Whether to include direct messages
        """
        self.batch_size = batch_size
        self.last_sync_time = None    # datetime of the last completed sync pass
        self.last_message_id = None   # resume cursor: highest message ID synced
        self.state_dir = os.path.dirname(os.path.abspath(__file__))
        self.channels_to_sync = []    # cached [(channel_name, recipient_id), ...]
        self.include_direct_messages = include_direct_messages

        # Load the last synced state if available
        self._load_sync_state()

    def _get_state_file_path(self):
        """Get the full path to the sync state file."""
        return os.path.join(self.state_dir, self._SYNC_STATE_FILE)

    def _load_sync_state(self):
        """Load the last sync state from disk; start fresh if unavailable."""
        try:
            state_file = self._get_state_file_path()
            if os.path.exists(state_file):
                with open(state_file, 'rb') as f:
                    state = pickle.load(f)
                    self.last_sync_time = state.get('last_sync_time')
                    self.last_message_id = state.get('last_message_id')
                    logger.info(f"Loaded sync state: last_sync_time={self.last_sync_time}, last_message_id={self.last_message_id}")
            else:
                logger.info("No previous sync state found, starting fresh")
        except Exception as e:
            # A corrupt or unreadable state file is not fatal -- we just resync.
            logger.error(f"Error loading sync state: {e}")

    def _save_sync_state(self, channel_counts=None):
        """Save the current sync state to disk.

        Args:
            channel_counts (dict | None): Optional per-channel synced counts
                stored alongside the resume cursor for diagnostics.
        """
        try:
            state = {
                'last_sync_time': self.last_sync_time,
                'last_message_id': self.last_message_id
            }

            if channel_counts:
                state['channel_counts'] = channel_counts

            state_file = self._get_state_file_path()

            # Save to a temporary file first, then rename, so an interrupt
            # mid-write cannot corrupt the existing state file.
            temp_file = state_file + '.temp'
            with open(temp_file, 'wb') as f:
                pickle.dump(state, f)
                f.flush()
                os.fsync(f.fileno())  # Ensure data is written to disk

            # Rename the temp file to the actual state file (atomic on POSIX)
            os.rename(temp_file, state_file)

            logger.info(f"Saved sync state: {state}")
        except Exception as e:
            logger.error(f"Error saving sync state: {e}")

    def get_excluded_user_ids(self):
        """Get the user IDs of IT_Bot and ai_bot.

        Returns:
            list: IDs of the bot users whose messages are skipped.
        """
        # NOTE(review): sessions from get_db_session() are never closed in this
        # class -- presumably it returns a scoped/shared session; confirm.
        session = get_db_session()
        excluded_users = session.query(UserProfile).filter(
            UserProfile.full_name.in_(['IT_Bot', 'ai_bot'])
        ).all()

        excluded_user_ids = [user.id for user in excluded_users]
        logger.info(f"Excluding messages from users: {[u.full_name for u in excluded_users]} (IDs: {excluded_user_ids})")
        return excluded_user_ids

    def get_sandbox_recipient_id(self):
        """Get the recipient ID for the sandbox channel, or None if absent."""
        session = get_db_session()
        sandbox_stream = session.query(Stream).filter(
            Stream.name == 'sandbox'
        ).first()

        if sandbox_stream:
            logger.info(f"Excluding messages from sandbox channel (recipient_id={sandbox_stream.recipient_id})")
            return sandbox_stream.recipient_id
        else:
            logger.warning("Sandbox channel not found")
            return None

    def get_channels_to_sync(self):
        """Get all active channels except sandbox with their recipient IDs.

        Side effect: caches the (name, recipient_id) pairs on
        ``self.channels_to_sync`` and logs the channel list.

        Returns:
            list: Recipient IDs of the channels to sync.
        """
        session = get_db_session()
        sandbox_recipient_id = self.get_sandbox_recipient_id()

        # Get all active streams (SQLAlchemy needs == False, not `is False`)
        streams = session.query(Stream).filter(
            Stream.deactivated == False
        ).all()

        # Filter out sandbox
        included_streams = [stream for stream in streams
                            if stream.recipient_id != sandbox_recipient_id]

        # Create a list of channels to sync with their recipient IDs
        channels = [(stream.name, stream.recipient_id) for stream in included_streams]

        # Sort by channel name
        channels.sort(key=lambda x: x[0])

        # Log the list of channels
        logger.info(f"Found {len(channels)} channels to sync:")
        for channel_name, recipient_id in channels:
            logger.info(f"- {channel_name} (recipient_id={recipient_id})")

        self.channels_to_sync = channels

        # Return just the recipient IDs for filtering
        recipient_ids = [recipient_id for _, recipient_id in channels]
        return recipient_ids

    def get_messages_newer_than_id(self, message_id, excluded_user_ids, excluded_recipient_id):
        """Get the next batch of messages with ID greater than the given ID.

        Args:
            message_id (int): Resume cursor; only messages above it are returned.
            excluded_user_ids (list): Sender IDs to exclude (bot accounts).
            excluded_recipient_id (int | None): Recipient ID to exclude (sandbox).

        Returns:
            list: Up to self.batch_size Message objects in ascending ID order.
        """
        session = get_db_session()

        # Build filters
        filters = [Message.id > message_id]

        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))

        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)

        messages = session.query(Message).filter(
            and_(*filters)
        ).order_by(Message.id.asc()).limit(self.batch_size).all()

        return messages

    def get_messages_for_timeframe(self, since, excluded_user_ids, excluded_recipient_id, limit=1000, all_messages=False):
        """
        Get messages from the specified timeframe in ascending ID order.

        Args:
            since (datetime): Get messages after this datetime
            excluded_user_ids (list): User IDs to exclude
            excluded_recipient_id (int): Recipient ID to exclude
            limit (int): Maximum number of messages to return
            all_messages (bool): If True, ignore the since parameter and get all messages

        Returns:
            list: List of Message objects (oldest first)
        """
        session = get_db_session()

        # Build filters
        filters = []

        # Add date filter if specified and not getting all messages
        if since and not all_messages:
            filters.append(Message.date_sent >= since)

        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))

        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)

        # Get results
        query = session.query(Message)
        if filters:
            query = query.filter(and_(*filters))

        # BUG FIX: this previously ordered by Message.id.desc(), so the first
        # sync batch returned only the NEWEST `limit` messages; sync_messages
        # then advanced last_message_id to the maximum ID seen and all later
        # batches queried only above it, silently skipping every older message
        # in the window.  Ascending order lets the batch loop walk forward
        # through the window with no gaps (matching get_messages_newer_than_id).
        messages = query.order_by(Message.id.asc()).limit(limit).all()

        return messages

    def get_channel_message_counts(self, since, excluded_user_ids, excluded_recipient_id, all_messages=False):
        """Get (and log) message counts by channel for the specified timeframe.

        NOTE(review): this loads every matching row and counts in Python
        because channel names come from
        ZulipDatabaseService.get_channel_name_for_message; a SQL aggregate
        would scale better if message volume grows.

        Returns:
            dict: {channel_name: count}, ordered by channel name.
        """
        session = get_db_session()

        # Build filters (same exclusion rules as the sync queries)
        filters = []

        # Add date filter if specified and not getting all messages
        if since and not all_messages:
            filters.append(Message.date_sent >= since)

        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))

        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)

        # Get all matching messages
        query = session.query(Message)
        if filters:
            query = query.filter(and_(*filters))

        messages = query.all()

        # Count messages by channel (messages with no resolvable channel are skipped)
        channel_counts = {}
        for message in messages:
            channel_name = ZulipDatabaseService.get_channel_name_for_message(message)
            if channel_name:
                if channel_name not in channel_counts:
                    channel_counts[channel_name] = 0
                channel_counts[channel_name] += 1

        # Sort by channel name
        sorted_counts = {k: channel_counts[k] for k in sorted(channel_counts.keys())}

        # Log the message counts by channel
        logger.info("Message counts by channel:")
        for channel, count in sorted_counts.items():
            logger.info(f"- {channel}: {count} messages")

        return sorted_counts

    def sync_messages(self, days=None, force=False, max_messages=5000, all_messages=False):
        """
        Sync messages from all Zulip channels to ChromaDB.

        Args:
            days (int): Number of days to look back for messages (default: use sync state)
            force (bool): Whether to force sync all messages from the lookback period
            max_messages (int): Maximum total number of messages to sync
            all_messages (bool): If True, ignore date filtering and sync all messages

        Returns:
            dict | None: Summary stats on success, None if the sync failed.
        """
        global is_shutting_down

        try:
            # Get excluded user IDs (IT_Bot and ai_bot)
            excluded_user_ids = self.get_excluded_user_ids()

            # Get sandbox recipient ID to exclude
            excluded_recipient_id = self.get_sandbox_recipient_id()

            # Get all channels to sync and their recipient IDs (logged for visibility)
            self.get_channels_to_sync()

            # Reset the resume cursor if a forced re-sync was requested
            if force:
                if all_messages:
                    self.last_sync_time = None
                    self.last_message_id = None
                    logger.info("Force syncing ALL messages regardless of date")
                elif days:
                    self.last_sync_time = datetime.now() - timedelta(days=days)
                    self.last_message_id = None
                    logger.info(f"Force syncing messages from the last {days} days")

            # Set default sync time if not set yet and not syncing all messages
            if not self.last_sync_time and not all_messages and not force:
                # Start with messages from the last 30 days if no previous sync
                self.last_sync_time = datetime.now() - timedelta(days=30 if not days else days)
                logger.info(f"No previous sync time, starting from {self.last_sync_time}")

            # Count total messages to sync if forcing (informational only)
            total_messages = 0
            if force:
                since_date = None if all_messages else (datetime.now() - timedelta(days=days if days else 30))
                all_messages_count = self.get_messages_for_timeframe(
                    since=since_date,
                    excluded_user_ids=excluded_user_ids,
                    excluded_recipient_id=excluded_recipient_id,
                    limit=max_messages,
                    all_messages=all_messages
                )
                total_messages = len(all_messages_count)
                logger.info(f"Found a total of {total_messages} messages to sync")

                # Get message counts by channel
                self.get_channel_message_counts(since_date, excluded_user_ids, excluded_recipient_id, all_messages=all_messages)

            # Run multiple batches of sync
            total_synced = 0
            already_exists_count = 0
            highest_message_id = self.last_message_id or 0
            batch_count = 0

            # Track synced messages by channel
            channel_sync_counts = {}

            # Persist state at most every `save_interval` seconds while inside a batch
            last_save_time = time.time()
            save_interval = 10  # Save state every 10 seconds

            while not is_shutting_down:
                batch_count += 1
                logger.info(f"Running batch {batch_count}, synced {total_synced} messages so far")

                # Get new messages: resume from the ID cursor when we have one,
                # otherwise fall back to the time-window query.
                messages = []
                if self.last_message_id:
                    messages = self.get_messages_newer_than_id(
                        self.last_message_id,
                        excluded_user_ids,
                        excluded_recipient_id
                    )
                else:
                    messages = self.get_messages_for_timeframe(
                        since=self.last_sync_time,
                        excluded_user_ids=excluded_user_ids,
                        excluded_recipient_id=excluded_recipient_id,
                        limit=self.batch_size,
                        all_messages=all_messages
                    )

                if not messages:
                    logger.info("No new messages found to sync")
                    break

                logger.info(f"Found {len(messages)} new messages to sync in batch {batch_count}")

                # Process each message
                synced_in_batch = 0
                for message in messages:
                    # Bail out between messages on CTRL+C / SIGTERM
                    if is_shutting_down:
                        logger.info("Shutdown requested, saving state and exiting...")
                        break

                    message_id = message.id

                    # Update highest message ID seen
                    if message_id > highest_message_id:
                        highest_message_id = message_id

                    channel_name = ZulipDatabaseService.get_channel_name_for_message(message)
                    sender_name = ZulipDatabaseService.get_sender_name_for_message(message)

                    # Skip excluded channels and users (defence in depth: the
                    # queries above already filter these out).
                    if channel_name == "sandbox":
                        continue

                    if sender_name in ["IT_Bot", "ai_bot"]:
                        continue

                    # Skip direct messages unless explicitly included
                    if not self.include_direct_messages and channel_name in ["Direct Message", "Group Message"]:
                        logger.debug(f"Skipping {channel_name} message {message_id} (use --include-direct-messages to include)")
                        continue

                    # Check if this message already exists in ChromaDB to avoid duplicates
                    if ChromaDBService.message_exists(message_id):
                        already_exists_count += 1
                        logger.debug(f"Message {message_id} already exists in ChromaDB, skipping")
                        continue

                    # Handle None channel names
                    if channel_name is None:
                        channel_name = "Unknown Channel"
                        logger.warning(f"Found message {message_id} with None channel name, using '{channel_name}' instead")

                    # Add the message to ChromaDB
                    try:
                        success = ChromaDBService.add_message(
                            message_id=message_id,
                            content=message.content,
                            channel_name=channel_name,
                            subject=message.subject,
                            sender_name=sender_name,
                            date_sent=message.date_sent
                        )

                        if success:
                            synced_in_batch += 1
                            total_synced += 1

                            # Update channel counts
                            if channel_name not in channel_sync_counts:
                                channel_sync_counts[channel_name] = 0
                            channel_sync_counts[channel_name] += 1

                            # Update the last message ID after each successful addition
                            self.last_message_id = message_id

                            # Save state periodically so a crash loses at most
                            # ~save_interval seconds of progress
                            current_time = time.time()
                            if current_time - last_save_time > save_interval:
                                self.last_sync_time = datetime.now()
                                self._save_sync_state(channel_sync_counts)
                                last_save_time = current_time
                        else:
                            logger.warning(f"Failed to add message {message_id} to ChromaDB")
                    except Exception as e:
                        logger.error(f"Error adding message {message_id} to ChromaDB: {e}")
                        # Continue with next message

                # Update the last sync time and message ID at the end of the batch
                self.last_sync_time = datetime.now()
                if highest_message_id > (self.last_message_id or 0):
                    self.last_message_id = highest_message_id

                # Save the sync state after each batch
                self._save_sync_state(channel_sync_counts)
                last_save_time = time.time()

                logger.info(f"Batch {batch_count} completed. Added {synced_in_batch} new messages to ChromaDB. " +
                            f"Total synced: {total_synced}. Last message ID: {self.last_message_id}")

                # Check if we've reached the max messages limit
                if total_synced >= max_messages:
                    logger.info(f"Reached max messages limit of {max_messages}")
                    break

                # If this batch had fewer messages than the batch size, we're done
                if len(messages) < self.batch_size:
                    logger.info("Fetched fewer messages than batch size, assuming all messages have been processed")
                    break

            # Final state save with channel statistics
            if is_shutting_down:
                logger.info("Shutdown signal received, saving final state...")

            # Log synced messages by channel
            if channel_sync_counts:
                logger.info("Messages synced by channel:")
                try:
                    # Use a safe sorting method that handles None keys
                    sorted_items = sorted(channel_sync_counts.items(),
                                          key=lambda item: item[0] if item[0] is not None else "")

                    for channel, count in sorted_items:
                        channel_name = channel if channel is not None else "Unknown Channel"
                        logger.info(f"- {channel_name}: {count} messages")
                except Exception as e:
                    logger.warning(f"Error displaying channel stats: {e}")
                    # Fallback display without sorting
                    for channel, count in channel_sync_counts.items():
                        channel_name = channel if channel is not None else "Unknown Channel"
                        logger.info(f"- {channel_name}: {count} messages")

            # Return the final stats
            stats = {
                'last_sync_time': self.last_sync_time,
                'last_message_id': self.last_message_id,
                'total_synced': total_synced,
                'batches': batch_count,
                'already_exists': already_exists_count,
                'channel_counts': channel_sync_counts
            }

            logger.info(f"Sync completed. Current state: {stats}")
            return stats

        except Exception as e:
            logger.error(f"Error syncing messages: {e}")
            # Save state on error so the next run can resume
            self._save_sync_state()
            return None
|
|
|
|
def _print_summary(stats):
    """Print the final sync statistics to stdout."""
    print(f"\nSync completed at {datetime.now()}")
    print(f"Last sync time: {stats['last_sync_time']}")
    print(f"Last message ID: {stats['last_message_id']}")
    print(f"Total messages synced: {stats['total_synced']}")
    print(f"Number of batches: {stats['batches']}")
    print(f"Messages already in DB: {stats['already_exists']}")

    channel_counts = stats.get('channel_counts', {})
    if not channel_counts:
        return

    print("\nMessages synced by channel:")
    try:
        # Sort defensively: a None channel key would break the default comparison.
        items = sorted(channel_counts.items(),
                       key=lambda item: item[0] if item[0] is not None else "")
    except Exception:
        # Fall back to unsorted display rather than crashing the summary.
        items = channel_counts.items()
    for channel, count in items:
        channel_name = channel if channel is not None else "Unknown Channel"
        print(f"- {channel_name}: {count} messages")


def main():
    """Command-line entry point: parse arguments and run a full channel sync."""
    parser = argparse.ArgumentParser(description="Sync messages from all Zulip channels to ChromaDB")
    parser.add_argument("--days", type=int, help="Number of days to look back for messages")
    parser.add_argument("--force", action="store_true", help="Force sync all messages from the lookback period")
    parser.add_argument("--batch-size", type=int, default=200, help="Number of messages to process in each batch")
    parser.add_argument("--max-messages", type=int, default=10000, help="Maximum total number of messages to sync")
    parser.add_argument("--include-direct-messages", action="store_true", help="Include direct and group messages in sync")
    parser.add_argument("--all-messages", action="store_true", help="Sync all messages regardless of date")
    args = parser.parse_args()

    # All DB/Chroma access happens inside the Flask application context.
    app = create_app()
    with app.app_context():
        try:
            service = AllChannelSyncService(
                batch_size=args.batch_size,
                include_direct_messages=args.include_direct_messages,
            )
            stats = service.sync_messages(
                days=args.days,
                force=args.force,
                max_messages=args.max_messages,
                all_messages=args.all_messages,
            )
            if stats:
                _print_summary(stats)
        except KeyboardInterrupt:
            # The sync service saves its state as it goes, so an interrupt
            # only needs to be reported, not cleaned up after.
            print("\nSync process interrupted by user. State has been saved.")
            logger.info("Sync process interrupted by user. State has been saved.")


if __name__ == "__main__":
    main()