# 2025-05-16 18:00:22 +04:00
#
# 260 lines
# 8.0 KiB
# Python

from chromadb.config import Settings, System
from chromadb.api import API
from chromadb.ingest import Producer
import chromadb.server.fastapi
from requests.exceptions import ConnectionError
import hypothesis
import tempfile
import os
import uvicorn
import time
import pytest
from typing import (
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Callable,
)
from typing_extensions import Protocol
import shutil
import logging
import socket
import multiprocessing
from chromadb.types import SeqId, SubmitEmbeddingRecord
# Test-run logging: the root logger is lowered to DEBUG so loggers that do not
# set their own level pass debug records through; `logger` is this module's own.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)  # This will only run when testing
logger = logging.getLogger(__name__)

# Hypothesis "dev" profile: per-example deadline of 45000 (milliseconds, per
# Hypothesis' int deadline convention) and three health checks suppressed.
hypothesis.settings.register_profile(
    "dev",
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
    deadline=45000,
)
# HYPOTHESIS_PROFILE may name another registered profile; "dev" is the default.
hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))
def find_free_port() -> int:
    """Return a TCP port number the OS reports as currently unused.

    Binding to port 0 lets the kernel choose a free port.  The probing socket
    is closed on return, so another process could still claim the port before
    the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # SO_REUSEADDR only influences a bind() performed after it is set;
        # setting it after bind() (as before) had no effect at all.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        # AF_INET getsockname() returns (host, port).
        return int(s.getsockname()[1])
def _run_server(
    port: int, is_persistent: bool = False, persist_directory: Optional[str] = None
) -> None:
    """Serve a Chroma FastAPI app on ``port``; blocks inside ``uvicorn.run``.

    Used as the child-process target by the FastAPI fixtures.  On-disk
    persistence is configured only when ``is_persistent`` is true AND a
    non-empty ``persist_directory`` is supplied; any other combination gets
    ``is_persistent=False`` settings without a persist directory.
    """
    if not (is_persistent and persist_directory):
        settings = Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
        )
    else:
        settings = Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=is_persistent,
            persist_directory=persist_directory,
            allow_reset=True,
        )
    app = chromadb.server.fastapi.FastAPI(settings).app()
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="error")
def _await_server(api: API, attempts: int = 0) -> None:
    """Block until ``api.heartbeat()`` succeeds, retrying every 4 seconds.

    ``attempts`` is the number of failed heartbeats already counted.  A
    heartbeat failing with ``ConnectionError`` (``requests.exceptions``) while
    ``attempts`` exceeds 15 is logged and re-raised; with the default start of
    0 that is the 17th heartbeat.

    Implemented as a loop rather than recursion inside the ``except`` block:
    the recursive form chained every earlier failure onto the final
    exception as implicit context and slept while an exception was being
    handled.
    """
    while True:
        try:
            api.heartbeat()
            return
        except ConnectionError:
            if attempts > 15:
                # attempts + 1 is the number of heartbeats tried (from 0).
                logger.error(
                    "Test server failed to start after %d attempts", attempts + 1
                )
                raise
        # Outside the handler, so the next failure carries no stale context.
        logger.info("Waiting for server to start...")
        time.sleep(4)
        attempts += 1
def _fastapi_fixture(is_persistent: bool = False) -> Generator[System, None, None]:
    """Launch a Chroma server in a child process and yield a client System.

    The server (``_run_server``) runs in a "spawn" child process on a free
    port.  The yielded System uses ``chromadb.api.fastapi.FastAPI`` pointed at
    ``localhost:<port>``.  With ``is_persistent`` the server persists into a
    fresh temporary directory.

    Teardown stops the client System, then kills and reaps the server process,
    then deletes the temporary directory.  It runs on normal resumption, when
    the generator is closed at the ``yield``, and when start-up fails (e.g.
    ``_await_server`` giving up), so no process or directory is leaked.
    """
    import contextlib  # local import: only this fixture needs it

    port = find_free_port()
    logger.info("Running test FastAPI server on port %s", port)
    ctx = multiprocessing.get_context("spawn")

    with contextlib.ExitStack() as cleanup:
        # ExitStack runs callbacks in reverse registration order, giving:
        # system.stop -> proc.kill -> proc.join -> rmtree.
        persist_directory: Optional[str] = None
        if is_persistent:
            persist_directory = tempfile.mkdtemp()
            cleanup.callback(shutil.rmtree, persist_directory, ignore_errors=True)
        args: Tuple[int, bool, Optional[str]] = (port, is_persistent, persist_directory)

        proc = ctx.Process(target=_run_server, args=args, daemon=True)
        proc.start()
        cleanup.callback(proc.join)  # reap the killed child (no zombie)
        cleanup.callback(proc.kill)

        settings = Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host="localhost",
            chroma_server_http_port=str(port),
            allow_reset=True,
        )
        system = System(settings)
        api = system.instance(API)
        system.start()
        cleanup.callback(system.stop)

        _await_server(api)
        yield system
def fastapi() -> Generator[System, None, None]:
    """System fixture: client for a FastAPI server backed by non-persistent storage."""
    yield from _fastapi_fixture(is_persistent=False)
def fastapi_persistent() -> Generator[System, None, None]:
    """System fixture: client for a FastAPI server persisting to a temp directory."""
    yield from _fastapi_fixture(is_persistent=True)
def integration() -> Generator[System, None, None]:
    """Fixture generator for a System configured via environment variables.

    Intended for externally configured integration tests: apart from
    ``allow_reset=True``, every setting comes from ``Settings`` defaults /
    the environment.  ``system.stop()`` runs in a ``finally`` so it also
    happens if the generator is closed at the ``yield``.
    """
    settings = Settings(allow_reset=True)
    system = System(settings)
    system.start()
    try:
        yield system
    finally:
        system.stop()
def sqlite() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on non-persistent SQLite.

    ``system.stop()`` runs in a ``finally`` so it also happens if the
    generator is closed at the ``yield`` instead of being resumed.
    """
    settings = Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        is_persistent=False,
        allow_reset=True,
    )
    system = System(settings)
    system.start()
    try:
        yield system
    finally:
        system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on persistent SQLite.

    Data is persisted into a fresh temporary directory.  Teardown is in
    ``finally`` blocks: the System is stopped and the directory removed even
    if the generator is closed at the ``yield`` or start-up fails.
    """
    save_path = tempfile.mkdtemp()
    try:
        settings = Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=save_path,
        )
        system = System(settings)
        system.start()
        try:
            yield system
        finally:
            system.stop()
    finally:
        # ignore_errors covers the directory already being gone.
        shutil.rmtree(save_path, ignore_errors=True)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the System fixture generators that ``system`` is parametrized over.

    - ``CHROMA_INTEGRATION_TEST_ONLY`` set: just ``integration``.
    - ``CHROMA_INTEGRATION_TEST`` set: the four local backends plus ``integration``.
    - neither: the two FastAPI-server and two SQLite backends.
    """
    if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
        return [integration]
    selected: List[Callable[[], Generator[System, None, None]]] = [
        fastapi,
        fastapi_persistent,
        sqlite,
        sqlite_persistent,
    ]
    if "CHROMA_INTEGRATION_TEST" in os.environ:
        selected.append(integration)
    return selected
@pytest.fixture(scope="module", params=system_fixtures())
def system(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped fixture parametrized over every configured System backend.

    ``request.param`` is one of the generator functions from
    ``system_fixtures()``.  Delegating with ``yield from`` lets pytest resume
    the inner generator at teardown so its cleanup runs (stopping the System,
    killing server processes, deleting temp dirs).  The previous
    ``yield next(...)`` took only the first value and dropped the generator,
    so that cleanup never executed.
    """
    yield from request.param()
@pytest.fixture(scope="function")
def api(system: System) -> Generator[API, None, None]:
    """Yield the API component of the shared ``system`` for a single test.

    ``system.reset_state()`` is called before every test; presumably this
    clears data left by earlier tests sharing the module-scoped System (the
    fixtures above all configure ``allow_reset=True``).
    """
    system.reset_state()
    yield system.instance(API)
# Producer / Consumer fixtures #
class ProducerFn(Protocol):
    """Structural type for the producer helpers (``produce_n_single``, ``produce_n_batch``).

    A conforming callable takes ``n`` records from ``embeddings``, submits
    them to ``topic`` through ``producer``, and returns the submitted records
    together with the sequence ids the producer returned for them.
    """

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        ...
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records from ``embeddings`` with one producer call each.

    Each record goes through ``producer.submit_embedding(topic, record)``
    before the next one is pulled.  Returns the records in submission order
    and the sequence id returned for each.  If ``embeddings`` runs out early,
    the ``StopIteration`` from ``next()`` propagates.
    """
    records: List[SubmitEmbeddingRecord] = []
    ids: List[SeqId] = []
    for _ in range(n):
        record = next(embeddings)
        ids.append(producer.submit_embedding(topic, record))
        records.append(record)
    return records, ids
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Pull ``n`` records from ``embeddings`` and submit them in one batch call.

    All records are collected first, then passed together to
    ``producer.submit_embeddings(topic, records)`` (also when ``n`` is 0, in
    which case the list is empty).  Returns the records and the sequence ids
    that call returned.
    """
    batch = [next(embeddings) for _ in range(n)]
    seq_ids: Sequence[SeqId] = producer.submit_embeddings(topic, batch)
    return batch, seq_ids
def produce_fn_fixtures() -> List[ProducerFn]:
    """Return the producer strategies that ``produce_fns`` is parametrized over."""
    return list((produce_n_single, produce_n_batch))
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Module-scoped fixture yielding each producer strategy in turn."""
    producer_fn: ProducerFn = request.param
    yield producer_fn