#!/usr/bin/env python
"""
Script to sync messages from all Zulip channels (except sandbox) to ChromaDB.
This script also excludes messages from IT_Bot and ai_bot users.
"""
import os
import sys
import argparse
import logging
import signal
import time
from datetime import datetime, timedelta
import pickle
# Add the current directory to the path so we can import the app module
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Apply NumPy compatibility patch for ChromaDB
from app.utils import patch_chromadb_numpy
patch_chromadb_numpy()
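# NOTE: the patch runs before app.db.chroma_service is imported below,
# presumably so the fix is in place by the time chromadb itself is imported.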
from app import create_app
from app.db.zulip_service import ZulipDatabaseService
from app.db.chroma_service import ChromaDBService
from app.models.zulip import Message, Stream, UserProfile
from sqlalchemy import and_, not_
from app.db import get_db_session
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("sync_all_channels")
# Global flag for graceful shutdown
is_shutting_down = False

# Signal handler for CTRL+C (SIGINT) and termination (SIGTERM) requests
def signal_handler(sig, frame):
    global is_shutting_down
    logger.info("Received shutdown signal, completing current operation before exiting...")
    is_shutting_down = True


# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
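
# Shutdown is cooperative: the handlers above only set a flag; the sync loop
# in sync_messages() checks is_shutting_down between messages so that state
# can be saved before the process exits.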


class AllChannelSyncService:
    """Service for syncing messages from all channels except sandbox."""

    # File to store the last synced message ID
    _SYNC_STATE_FILE = "all_channels_sync_state.pickle"

    def __init__(self, batch_size=200, include_direct_messages=False):
        """
        Initialize the sync service.

        Args:
            batch_size (int): Number of messages to process in each batch
            include_direct_messages (bool): Whether to include direct messages
        """
        self.batch_size = batch_size
        self.last_sync_time = None
        self.last_message_id = None
        self.state_dir = os.path.dirname(os.path.abspath(__file__))
        self.channels_to_sync = []
        self.include_direct_messages = include_direct_messages
        # Load the last synced state if available
        self._load_sync_state()

    def _get_state_file_path(self):
        """Get the full path to the sync state file."""
        return os.path.join(self.state_dir, self._SYNC_STATE_FILE)

    def _load_sync_state(self):
        """Load the last sync state from disk."""
        try:
            state_file = self._get_state_file_path()
            if os.path.exists(state_file):
                with open(state_file, 'rb') as f:
                    state = pickle.load(f)
                self.last_sync_time = state.get('last_sync_time')
                self.last_message_id = state.get('last_message_id')
                logger.info(f"Loaded sync state: last_sync_time={self.last_sync_time}, last_message_id={self.last_message_id}")
            else:
                logger.info("No previous sync state found, starting fresh")
        except Exception as e:
            logger.error(f"Error loading sync state: {e}")

    def _save_sync_state(self, channel_counts=None):
        """Save the current sync state to disk."""
        try:
            state = {
                'last_sync_time': self.last_sync_time,
                'last_message_id': self.last_message_id
            }
            if channel_counts:
                state['channel_counts'] = channel_counts
            state_file = self._get_state_file_path()
            # Save to a temporary file first, then rename to avoid corruption if interrupted
            temp_file = state_file + '.temp'
            with open(temp_file, 'wb') as f:
                pickle.dump(state, f)
                f.flush()
                os.fsync(f.fileno())  # Ensure data is written to disk
            # Rename the temp file to the actual state file (atomic operation)
            os.rename(temp_file, state_file)
            logger.info(f"Saved sync state: {state}")
        except Exception as e:
            logger.error(f"Error saving sync state: {e}")

    def get_excluded_user_ids(self):
        """Get the user IDs of IT_Bot and ai_bot."""
        session = get_db_session()
        excluded_users = session.query(UserProfile).filter(
            UserProfile.full_name.in_(['IT_Bot', 'ai_bot'])
        ).all()
        excluded_user_ids = [user.id for user in excluded_users]
        logger.info(f"Excluding messages from users: {[u.full_name for u in excluded_users]} (IDs: {excluded_user_ids})")
        return excluded_user_ids

    def get_sandbox_recipient_id(self):
        """Get the recipient ID for the sandbox channel."""
        session = get_db_session()
        sandbox_stream = session.query(Stream).filter(
            Stream.name == 'sandbox'
        ).first()
        if sandbox_stream:
            logger.info(f"Excluding messages from sandbox channel (recipient_id={sandbox_stream.recipient_id})")
            return sandbox_stream.recipient_id
        else:
            logger.warning("Sandbox channel not found")
            return None

    def get_channels_to_sync(self):
        """Get all active channels except sandbox with their recipient IDs."""
        session = get_db_session()
        sandbox_recipient_id = self.get_sandbox_recipient_id()
        # Get all active streams
        streams = session.query(Stream).filter(
            Stream.deactivated == False
        ).all()
        # Filter out sandbox
        included_streams = [stream for stream in streams
                            if stream.recipient_id != sandbox_recipient_id]
        # Create a list of (name, recipient_id) pairs, sorted by channel name
        channels = [(stream.name, stream.recipient_id) for stream in included_streams]
        channels.sort(key=lambda x: x[0])
        # Log the list of channels
        logger.info(f"Found {len(channels)} channels to sync:")
        for channel_name, recipient_id in channels:
            logger.info(f"- {channel_name} (recipient_id={recipient_id})")
        self.channels_to_sync = channels
        # Return just the recipient IDs for filtering
        recipient_ids = [recipient_id for _, recipient_id in channels]
        return recipient_ids

    def get_messages_newer_than_id(self, message_id, excluded_user_ids, excluded_recipient_id):
        """Get messages with ID greater than the specified ID."""
        session = get_db_session()
        # Build filters
        filters = [Message.id > message_id]
        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))
        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)
        messages = session.query(Message).filter(
            and_(*filters)
        ).order_by(Message.id.asc()).limit(self.batch_size).all()
        return messages

    def get_messages_for_timeframe(self, since, excluded_user_ids, excluded_recipient_id, limit=1000, all_messages=False):
        """
        Get messages from the specified timeframe.

        Args:
            since (datetime): Get messages after this datetime
            excluded_user_ids (list): User IDs to exclude
            excluded_recipient_id (int): Recipient ID to exclude
            limit (int): Maximum number of messages to return
            all_messages (bool): If True, ignore the since parameter and get all messages

        Returns:
            list: List of Message objects
        """
        session = get_db_session()
        # Build filters
        filters = []
        # Add date filter if specified and not getting all messages
        if since and not all_messages:
            filters.append(Message.date_sent >= since)
        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))
        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)
        # Get results in ascending ID order so that ID-based paging in
        # sync_messages continues after this batch instead of stranding
        # older messages behind an already-high last_message_id
        query = session.query(Message)
        if filters:
            query = query.filter(and_(*filters))
        messages = query.order_by(Message.id.asc()).limit(limit).all()
        return messages

    def get_channel_message_counts(self, since, excluded_user_ids, excluded_recipient_id, all_messages=False):
        """Get message counts by channel for the specified timeframe."""
        session = get_db_session()
        # Build filters
        filters = []
        # Add date filter if specified and not getting all messages
        if since and not all_messages:
            filters.append(Message.date_sent >= since)
        # Add filter for excluded users
        if excluded_user_ids:
            filters.append(not_(Message.sender_id.in_(excluded_user_ids)))
        # Add filter for excluded recipient (sandbox)
        if excluded_recipient_id:
            filters.append(Message.recipient_id != excluded_recipient_id)
        # Get all matching messages
        query = session.query(Message)
        if filters:
            query = query.filter(and_(*filters))
        messages = query.all()
        # Count messages by channel
        channel_counts = {}
        for message in messages:
            channel_name = ZulipDatabaseService.get_channel_name_for_message(message)
            if channel_name:
                channel_counts[channel_name] = channel_counts.get(channel_name, 0) + 1
        # Sort by channel name
        sorted_counts = {k: channel_counts[k] for k in sorted(channel_counts.keys())}
        # Log the message counts by channel
        logger.info("Message counts by channel:")
        for channel, count in sorted_counts.items():
            logger.info(f"- {channel}: {count} messages")
        return sorted_counts
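
    # Sync strategy: the first batch is selected by timestamp (last_sync_time)
    # or, with all_messages, with no date filter at all; every later batch
    # pages forward by message ID via get_messages_newer_than_id, so an
    # interrupted run can resume exactly where it stopped.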

    def sync_messages(self, days=None, force=False, max_messages=5000, all_messages=False):
        """
        Sync messages from all Zulip channels to ChromaDB.

        Args:
            days (int): Number of days to look back for messages (default: use sync state)
            force (bool): Whether to force sync all messages from the lookback period
            max_messages (int): Maximum total number of messages to sync
            all_messages (bool): If True, ignore date filtering and sync all messages
        """
        global is_shutting_down
        try:
            # Get excluded user IDs (IT_Bot and ai_bot)
            excluded_user_ids = self.get_excluded_user_ids()
            # Get sandbox recipient ID to exclude
            excluded_recipient_id = self.get_sandbox_recipient_id()
            # Get all channels to sync and their recipient IDs
            self.get_channels_to_sync()
            # Reset sync state if forced
            if force:
                if all_messages:
                    self.last_sync_time = None
                    self.last_message_id = None
                    logger.info("Force syncing ALL messages regardless of date")
                elif days:
                    self.last_sync_time = datetime.now() - timedelta(days=days)
                    self.last_message_id = None
                    logger.info(f"Force syncing messages from the last {days} days")
            # Set default sync time if not set yet and not syncing all messages
            if not self.last_sync_time and not all_messages and not force:
                # Start with messages from the last 30 days if no previous sync
                self.last_sync_time = datetime.now() - timedelta(days=days or 30)
                logger.info(f"No previous sync time, starting from {self.last_sync_time}")
            # Count total messages to sync if forcing
            total_messages = 0
            if force:
                since_date = None if all_messages else (datetime.now() - timedelta(days=days or 30))
                # Note: this holds the matching Message objects, not just a count
                candidate_messages = self.get_messages_for_timeframe(
                    since=since_date,
                    excluded_user_ids=excluded_user_ids,
                    excluded_recipient_id=excluded_recipient_id,
                    limit=max_messages,
                    all_messages=all_messages
                )
                total_messages = len(candidate_messages)
                logger.info(f"Found a total of {total_messages} messages to sync")
                # Get message counts by channel
                self.get_channel_message_counts(since_date, excluded_user_ids, excluded_recipient_id, all_messages=all_messages)
            # Run multiple batches of sync
            total_synced = 0
            already_exists_count = 0
            highest_message_id = self.last_message_id or 0
            batch_count = 0
            # Track synced messages by channel
            channel_sync_counts = {}
            # Periodically persist state while processing
            last_save_time = time.time()
            save_interval = 10  # Save state every 10 seconds

            while not is_shutting_down:
                batch_count += 1
                logger.info(f"Running batch {batch_count}, synced {total_synced} messages so far")
                # Get new messages
                messages = []
                if self.last_message_id:
                    # Get messages with ID greater than the last processed message ID
                    messages = self.get_messages_newer_than_id(
                        self.last_message_id,
                        excluded_user_ids,
                        excluded_recipient_id
                    )
                else:
                    # Get messages since the last sync time, or all messages
                    messages = self.get_messages_for_timeframe(
                        since=self.last_sync_time,
                        excluded_user_ids=excluded_user_ids,
                        excluded_recipient_id=excluded_recipient_id,
                        limit=self.batch_size,
                        all_messages=all_messages
                    )
                if not messages:
                    logger.info("No new messages found to sync")
                    break
                logger.info(f"Found {len(messages)} new messages to sync in batch {batch_count}")
                # Process each message
                synced_in_batch = 0
                for message in messages:
                    # Check whether a shutdown was requested
                    if is_shutting_down:
                        logger.info("Shutdown requested, saving state and exiting...")
                        break
                    message_id = message.id
                    # Update highest message ID seen
                    if message_id > highest_message_id:
                        highest_message_id = message_id
                    channel_name = ZulipDatabaseService.get_channel_name_for_message(message)
                    sender_name = ZulipDatabaseService.get_sender_name_for_message(message)
                    # Skip excluded channels and users
                    if channel_name == "sandbox":
                        continue
                    if sender_name in ["IT_Bot", "ai_bot"]:
                        continue
                    # Skip direct messages unless explicitly included
                    if not self.include_direct_messages and channel_name in ["Direct Message", "Group Message"]:
                        logger.debug(f"Skipping {channel_name} message {message_id} (use --include-direct-messages to include)")
                        continue
                    # Check if this message already exists in ChromaDB to avoid duplicates
                    if ChromaDBService.message_exists(message_id):
                        already_exists_count += 1
                        logger.debug(f"Message {message_id} already exists in ChromaDB, skipping")
                        continue
                    # Handle None channel names
                    if channel_name is None:
                        channel_name = "Unknown Channel"
                        logger.warning(f"Found message {message_id} with None channel name, using '{channel_name}' instead")
                    # Add the message to ChromaDB
                    try:
                        success = ChromaDBService.add_message(
                            message_id=message_id,
                            content=message.content,
                            channel_name=channel_name,
                            subject=message.subject,
                            sender_name=sender_name,
                            date_sent=message.date_sent
                        )
                        if success:
                            synced_in_batch += 1
                            total_synced += 1
                            # Update channel counts
                            channel_sync_counts[channel_name] = channel_sync_counts.get(channel_name, 0) + 1
                            # Update the last message ID after each successful addition
                            self.last_message_id = message_id
                            # Save state periodically
                            current_time = time.time()
                            if current_time - last_save_time > save_interval:
                                self.last_sync_time = datetime.now()
                                self._save_sync_state(channel_sync_counts)
                                last_save_time = current_time
                        else:
                            logger.warning(f"Failed to add message {message_id} to ChromaDB")
                    except Exception as e:
                        logger.error(f"Error adding message {message_id} to ChromaDB: {e}")
                        # Continue with the next message

                # Update the last sync time and message ID at the end of the batch
                self.last_sync_time = datetime.now()
                if highest_message_id > (self.last_message_id or 0):
                    self.last_message_id = highest_message_id
                # Save the sync state after each batch
                self._save_sync_state(channel_sync_counts)
                last_save_time = time.time()
logger.info(f"Batch {batch_count} completed. Added {synced_in_batch} new messages to ChromaDB. " +
f"Total synced: {total_synced}. Last message ID: {self.last_message_id}")
# Check if we've reached the max messages limit
if total_synced >= max_messages:
logger.info(f"Reached max messages limit of {max_messages}")
break
# If this batch had fewer messages than the batch size, we're done
if len(messages) < self.batch_size:
logger.info("Fetched fewer messages than batch size, assuming all messages have been processed")
break
# Final state save with channel statistics
if is_shutting_down:
logger.info("Shutdown signal received, saving final state...")
# Print synced messages by channel
if channel_sync_counts:
logger.info("Messages synced by channel:")
try:
# Use a safe sorting method that handles None keys
sorted_items = sorted(channel_sync_counts.items(),
key=lambda item: item[0] if item[0] is not None else "")
for channel, count in sorted_items:
channel_name = channel if channel is not None else "Unknown Channel"
logger.info(f"- {channel_name}: {count} messages")
except Exception as e:
logger.warning(f"Error displaying channel stats: {e}")
# Fallback display without sorting
for channel, count in channel_sync_counts.items():
channel_name = channel if channel is not None else "Unknown Channel"
logger.info(f"- {channel_name}: {count} messages")
# Return the final stats
stats = {
'last_sync_time': self.last_sync_time,
'last_message_id': self.last_message_id,
'total_synced': total_synced,
'batches': batch_count,
'already_exists': already_exists_count,
'channel_counts': channel_sync_counts
}
logger.info(f"Sync completed. Current state: {stats}")
return stats
except Exception as e:
logger.error(f"Error syncing messages: {e}")
# Save state on error
self._save_sync_state()
return None


def main():
    """Main entry point."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Sync messages from all Zulip channels to ChromaDB")
    parser.add_argument("--days", type=int, help="Number of days to look back for messages")
    parser.add_argument("--force", action="store_true", help="Force sync all messages from the lookback period")
    parser.add_argument("--batch-size", type=int, default=200, help="Number of messages to process in each batch")
    parser.add_argument("--max-messages", type=int, default=10000, help="Maximum total number of messages to sync")
    parser.add_argument("--include-direct-messages", action="store_true", help="Include direct and group messages in sync")
    parser.add_argument("--all-messages", action="store_true", help="Sync all messages regardless of date")
    args = parser.parse_args()
    # Create the Flask app
    app = create_app()
    with app.app_context():
        try:
            # Initialize sync service
            sync_service = AllChannelSyncService(
                batch_size=args.batch_size,
                include_direct_messages=args.include_direct_messages
            )
            # Sync messages
            stats = sync_service.sync_messages(
                days=args.days,
                force=args.force,
                max_messages=args.max_messages,
                all_messages=args.all_messages
            )
            if stats:
                channel_counts = stats.get('channel_counts', {})
                print(f"\nSync completed at {datetime.now()}")
                print(f"Last sync time: {stats['last_sync_time']}")
                print(f"Last message ID: {stats['last_message_id']}")
                print(f"Total messages synced: {stats['total_synced']}")
                print(f"Number of batches: {stats['batches']}")
                print(f"Messages already in DB: {stats['already_exists']}")
                if channel_counts:
                    print("\nMessages synced by channel:")
                    try:
                        # Use a safe sorting method that handles None keys
                        sorted_items = sorted(channel_counts.items(),
                                              key=lambda item: item[0] if item[0] is not None else "")
                        for channel, count in sorted_items:
                            channel_name = channel if channel is not None else "Unknown Channel"
                            print(f"- {channel_name}: {count} messages")
                    except Exception:
                        # Fallback display without sorting
                        for channel, count in channel_counts.items():
                            channel_name = channel if channel is not None else "Unknown Channel"
                            print(f"- {channel_name}: {count} messages")
        except KeyboardInterrupt:
            print("\nSync process interrupted by user. State has been saved.")
            logger.info("Sync process interrupted by user. State has been saved.")


if __name__ == "__main__":
    main()