1364 lines
40 KiB
Python
1364 lines
40 KiB
Python
# type: ignore
|
|
import chromadb
|
|
from chromadb.api.types import QueryResult
|
|
from chromadb.config import Settings
|
|
import chromadb.server.fastapi
|
|
import pytest
|
|
import tempfile
|
|
import numpy as np
|
|
import os
|
|
import shutil
|
|
from datetime import datetime, timedelta
|
|
from chromadb.utils.embedding_functions import (
|
|
DefaultEmbeddingFunction,
|
|
)
|
|
|
|
persist_dir = tempfile.mkdtemp()
|
|
|
|
|
|
@pytest.fixture
|
|
def local_persist_api():
|
|
yield chromadb.Client(
|
|
Settings(
|
|
chroma_api_impl="chromadb.api.segment.SegmentAPI",
|
|
chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
|
|
allow_reset=True,
|
|
is_persistent=True,
|
|
persist_directory=persist_dir,
|
|
),
|
|
)
|
|
if os.path.exists(persist_dir):
|
|
shutil.rmtree(persist_dir, ignore_errors=True)
|
|
|
|
|
|
# https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached
|
|
@pytest.fixture
|
|
def local_persist_api_cache_bust():
|
|
yield chromadb.Client(
|
|
Settings(
|
|
chroma_api_impl="chromadb.api.segment.SegmentAPI",
|
|
chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
|
|
chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
|
|
allow_reset=True,
|
|
is_persistent=True,
|
|
persist_directory=persist_dir,
|
|
),
|
|
)
|
|
if os.path.exists(persist_dir):
|
|
shutil.rmtree(persist_dir, ignore_errors=True)
|
|
|
|
|
|
def approx_equal(a, b, tolerance=1e-6) -> bool:
|
|
return abs(a - b) < tolerance
|
|
|
|
|
|
def vector_approx_equal(a, b, tolerance: float = 1e-6) -> bool:
|
|
if len(a) != len(b):
|
|
return False
|
|
return all([approx_equal(a, b, tolerance) for a, b in zip(a, b)])
|
|
|
|
|
|
@pytest.mark.parametrize("api_fixture", [local_persist_api])
|
|
def test_persist_index_loading(api_fixture, request):
|
|
api = request.getfixturevalue("local_persist_api")
|
|
api.reset()
|
|
collection = api.create_collection("test")
|
|
collection.add(ids="id1", documents="hello")
|
|
|
|
api2 = request.getfixturevalue("local_persist_api_cache_bust")
|
|
collection = api2.get_collection("test")
|
|
|
|
nn = collection.query(
|
|
query_texts="hello",
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
|
|
@pytest.mark.parametrize("api_fixture", [local_persist_api])
|
|
def test_persist_index_loading_embedding_function(api_fixture, request):
|
|
embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731
|
|
api = request.getfixturevalue("local_persist_api")
|
|
api.reset()
|
|
collection = api.create_collection("test", embedding_function=embedding_function)
|
|
collection.add(ids="id1", documents="hello")
|
|
|
|
api2 = request.getfixturevalue("local_persist_api_cache_bust")
|
|
collection = api2.get_collection("test", embedding_function=embedding_function)
|
|
|
|
nn = collection.query(
|
|
query_texts="hello",
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
|
|
@pytest.mark.parametrize("api_fixture", [local_persist_api])
|
|
def test_persist_index_get_or_create_embedding_function(api_fixture, request):
|
|
embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731
|
|
api = request.getfixturevalue("local_persist_api")
|
|
api.reset()
|
|
collection = api.get_or_create_collection(
|
|
"test", embedding_function=embedding_function
|
|
)
|
|
collection.add(ids="id1", documents="hello")
|
|
|
|
api2 = request.getfixturevalue("local_persist_api_cache_bust")
|
|
collection = api2.get_or_create_collection(
|
|
"test", embedding_function=embedding_function
|
|
)
|
|
|
|
nn = collection.query(
|
|
query_texts="hello",
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
assert nn["ids"] == [["id1"]]
|
|
assert nn["embeddings"] == [[[1, 2, 3]]]
|
|
assert nn["documents"] == [["hello"]]
|
|
assert nn["distances"] == [[0]]
|
|
|
|
|
|
@pytest.mark.parametrize("api_fixture", [local_persist_api])
|
|
def test_persist(api_fixture, request):
|
|
api = request.getfixturevalue(api_fixture.__name__)
|
|
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
|
|
collection.add(**batch_records)
|
|
|
|
assert collection.count() == 2
|
|
|
|
api = request.getfixturevalue(api_fixture.__name__)
|
|
collection = api.get_collection("testspace")
|
|
assert collection.count() == 2
|
|
|
|
api.delete_collection("testspace")
|
|
|
|
api = request.getfixturevalue(api_fixture.__name__)
|
|
assert api.list_collections() == []
|
|
|
|
|
|
def test_heartbeat(api):
|
|
heartbeat_ns = api.heartbeat()
|
|
assert isinstance(heartbeat_ns, int)
|
|
|
|
heartbeat_s = heartbeat_ns // 10**9
|
|
heartbeat = datetime.fromtimestamp(heartbeat_s)
|
|
assert heartbeat > datetime.now() - timedelta(seconds=10)
|
|
|
|
|
|
batch_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"ids": ["https://example.com/1", "https://example.com/2"],
|
|
}
|
|
|
|
|
|
def test_add(api):
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
|
|
collection.add(**batch_records)
|
|
|
|
assert collection.count() == 2
|
|
|
|
|
|
def test_get_or_create(api):
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
|
|
collection.add(**batch_records)
|
|
|
|
assert collection.count() == 2
|
|
|
|
with pytest.raises(Exception):
|
|
collection = api.create_collection("testspace")
|
|
|
|
collection = api.get_or_create_collection("testspace")
|
|
|
|
assert collection.count() == 2
|
|
|
|
|
|
minimal_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"ids": ["https://example.com/1", "https://example.com/2"],
|
|
}
|
|
|
|
|
|
def test_add_minimal(api):
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
|
|
collection.add(**minimal_records)
|
|
|
|
assert collection.count() == 2
|
|
|
|
|
|
def test_get_from_db(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
records = collection.get(include=["embeddings", "documents", "metadatas"])
|
|
for key in records.keys():
|
|
assert len(records[key]) == 2
|
|
|
|
|
|
def test_reset_db(api):
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
|
|
api.reset()
|
|
assert len(api.list_collections()) == 0
|
|
|
|
|
|
def test_get_nearest_neighbors(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
|
|
nn = collection.query(
|
|
query_embeddings=[1.1, 2.3, 3.2],
|
|
n_results=1,
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
nn = collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]],
|
|
n_results=1,
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
nn = collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]],
|
|
n_results=1,
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 2
|
|
|
|
|
|
def test_delete(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
|
|
collection.delete()
|
|
assert collection.count() == 0
|
|
|
|
|
|
def test_delete_with_index(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1)
|
|
|
|
|
|
def test_count(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
assert collection.count() == 0
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
|
|
|
|
def test_modify(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.modify(name="testspace2")
|
|
|
|
# collection name is modify
|
|
assert collection.name == "testspace2"
|
|
|
|
|
|
def test_modify_error_on_existing_name(api):
|
|
api.reset()
|
|
|
|
api.create_collection("testspace")
|
|
c2 = api.create_collection("testspace2")
|
|
|
|
with pytest.raises(Exception):
|
|
c2.modify(name="testspace")
|
|
|
|
|
|
def test_metadata_cru(api):
|
|
api.reset()
|
|
metadata_a = {"a": 1, "b": 2}
|
|
# Test create metatdata
|
|
collection = api.create_collection("testspace", metadata=metadata_a)
|
|
assert collection.metadata is not None
|
|
assert collection.metadata["a"] == 1
|
|
assert collection.metadata["b"] == 2
|
|
|
|
# Test get metatdata
|
|
collection = api.get_collection("testspace")
|
|
assert collection.metadata is not None
|
|
assert collection.metadata["a"] == 1
|
|
assert collection.metadata["b"] == 2
|
|
|
|
# Test modify metatdata
|
|
collection.modify(metadata={"a": 2, "c": 3})
|
|
assert collection.metadata["a"] == 2
|
|
assert collection.metadata["c"] == 3
|
|
assert "b" not in collection.metadata
|
|
|
|
# Test get after modify metatdata
|
|
collection = api.get_collection("testspace")
|
|
assert collection.metadata is not None
|
|
assert collection.metadata["a"] == 2
|
|
assert collection.metadata["c"] == 3
|
|
assert "b" not in collection.metadata
|
|
|
|
# Test name exists get_or_create_metadata
|
|
collection = api.get_or_create_collection("testspace")
|
|
assert collection.metadata is not None
|
|
assert collection.metadata["a"] == 2
|
|
assert collection.metadata["c"] == 3
|
|
|
|
# Test name exists create metadata
|
|
collection = api.get_or_create_collection("testspace2")
|
|
assert collection.metadata is None
|
|
|
|
# Test list collections
|
|
collections = api.list_collections()
|
|
for collection in collections:
|
|
if collection.name == "testspace":
|
|
assert collection.metadata is not None
|
|
assert collection.metadata["a"] == 2
|
|
assert collection.metadata["c"] == 3
|
|
elif collection.name == "testspace2":
|
|
assert collection.metadata is None
|
|
|
|
|
|
def test_increment_index_on(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
|
|
# increment index
|
|
nn = collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]],
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
|
|
def test_add_a_collection(api):
|
|
api.reset()
|
|
api.create_collection("testspace")
|
|
|
|
# get collection does not throw an error
|
|
collection = api.get_collection("testspace")
|
|
assert collection.name == "testspace"
|
|
|
|
# get collection should throw an error if collection does not exist
|
|
with pytest.raises(Exception):
|
|
collection = api.get_collection("testspace2")
|
|
|
|
|
|
def test_list_collections(api):
|
|
api.reset()
|
|
api.create_collection("testspace")
|
|
api.create_collection("testspace2")
|
|
|
|
# get collection does not throw an error
|
|
collections = api.list_collections()
|
|
assert len(collections) == 2
|
|
|
|
|
|
def test_reset(api):
|
|
api.reset()
|
|
api.create_collection("testspace")
|
|
api.create_collection("testspace2")
|
|
|
|
# get collection does not throw an error
|
|
collections = api.list_collections()
|
|
assert len(collections) == 2
|
|
|
|
api.reset()
|
|
collections = api.list_collections()
|
|
assert len(collections) == 0
|
|
|
|
|
|
def test_peek(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**batch_records)
|
|
assert collection.count() == 2
|
|
|
|
# peek
|
|
peek = collection.peek()
|
|
for key in peek.keys():
|
|
assert len(peek[key]) == 2
|
|
|
|
|
|
# TEST METADATA AND METADATA FILTERING
|
|
# region
|
|
|
|
metadata_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"ids": ["id1", "id2"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001},
|
|
{"int_value": 2},
|
|
],
|
|
}
|
|
|
|
|
|
def test_metadata_add_get_int_float(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
items = collection.get(ids=["id1", "id2"])
|
|
assert items["metadatas"][0]["int_value"] == 1
|
|
assert items["metadatas"][0]["float_value"] == 1.001
|
|
assert items["metadatas"][1]["int_value"] == 2
|
|
assert type(items["metadatas"][0]["int_value"]) == int
|
|
assert type(items["metadatas"][0]["float_value"]) == float
|
|
|
|
|
|
def test_metadata_add_query_int_float(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
items: QueryResult = collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]], n_results=1
|
|
)
|
|
assert items["metadatas"] is not None
|
|
assert items["metadatas"][0][0]["int_value"] == 1
|
|
assert items["metadatas"][0][0]["float_value"] == 1.001
|
|
assert type(items["metadatas"][0][0]["int_value"]) == int
|
|
assert type(items["metadatas"][0][0]["float_value"]) == float
|
|
|
|
|
|
def test_metadata_get_where_string(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
items = collection.get(where={"string_value": "one"})
|
|
assert items["metadatas"][0]["int_value"] == 1
|
|
assert items["metadatas"][0]["string_value"] == "one"
|
|
|
|
|
|
def test_metadata_get_where_int(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
items = collection.get(where={"int_value": 1})
|
|
assert items["metadatas"][0]["int_value"] == 1
|
|
assert items["metadatas"][0]["string_value"] == "one"
|
|
|
|
|
|
def test_metadata_get_where_float(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
items = collection.get(where={"float_value": 1.001})
|
|
assert items["metadatas"][0]["int_value"] == 1
|
|
assert items["metadatas"][0]["string_value"] == "one"
|
|
assert items["metadatas"][0]["float_value"] == 1.001
|
|
|
|
|
|
def test_metadata_update_get_int_float(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_int")
|
|
collection.add(**metadata_records)
|
|
|
|
collection.update(
|
|
ids=["id1"],
|
|
metadatas=[{"int_value": 2, "string_value": "two", "float_value": 2.002}],
|
|
)
|
|
items = collection.get(ids=["id1"])
|
|
assert items["metadatas"][0]["int_value"] == 2
|
|
assert items["metadatas"][0]["string_value"] == "two"
|
|
assert items["metadatas"][0]["float_value"] == 2.002
|
|
|
|
|
|
bad_metadata_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"ids": ["id1", "id2"],
|
|
"metadatas": [{"value": {"nested": "5"}}, {"value": [1, 2, 3]}],
|
|
}
|
|
|
|
|
|
def test_metadata_validation_add(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_metadata_validation")
|
|
with pytest.raises(ValueError, match="metadata"):
|
|
collection.add(**bad_metadata_records)
|
|
|
|
|
|
def test_metadata_validation_update(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_metadata_validation")
|
|
collection.add(**metadata_records)
|
|
with pytest.raises(ValueError, match="metadata"):
|
|
collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}})
|
|
|
|
|
|
def test_where_validation_get(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_validation")
|
|
with pytest.raises(ValueError, match="where"):
|
|
collection.get(where={"value": {"nested": "5"}})
|
|
|
|
|
|
def test_where_validation_query(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_validation")
|
|
with pytest.raises(ValueError, match="where"):
|
|
collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}})
|
|
|
|
|
|
operator_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"ids": ["id1", "id2"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001},
|
|
{"int_value": 2, "float_value": 2.002, "string_value": "two"},
|
|
],
|
|
}
|
|
|
|
|
|
def test_where_lt(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lt")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"int_value": {"$lt": 2}})
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
|
|
def test_where_lte(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lte")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"int_value": {"$lte": 2.0}})
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
|
|
def test_where_gt(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lte")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"float_value": {"$gt": -1.4}})
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
|
|
def test_where_gte(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lte")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"float_value": {"$gte": 2.002}})
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
|
|
def test_where_ne_string(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lte")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"string_value": {"$ne": "two"}})
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
|
|
def test_where_ne_eq_number(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_lte")
|
|
collection.add(**operator_records)
|
|
items = collection.get(where={"int_value": {"$ne": 1}})
|
|
assert len(items["metadatas"]) == 1
|
|
items = collection.get(where={"float_value": {"$eq": 2.002}})
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
|
|
def test_where_valid_operators(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_valid_operators")
|
|
collection.add(**operator_records)
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"int_value": {"$invalid": 2}})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"int_value": {"$lt": "2"}})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"int_value": {"$lt": 2, "$gt": 1}})
|
|
|
|
# Test invalid $and, $or
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"$and": {"int_value": {"$lt": 2}}})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(
|
|
where={"int_value": {"$lt": 2}, "$or": {"int_value": {"$gt": 1}}}
|
|
)
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(
|
|
where={"$gt": [{"int_value": {"$lt": 2}}, {"int_value": {"$gt": 1}}]}
|
|
)
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"$or": [{"int_value": {"$lt": 2}}]})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"$or": []})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where={"a": {"$contains": "test"}})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(
|
|
where={
|
|
"$or": [
|
|
{"a": {"$contains": "first"}}, # invalid
|
|
{"$contains": "second"}, # valid
|
|
]
|
|
}
|
|
)
|
|
|
|
|
|
# TODO: Define the dimensionality of these embeddingds in terms of the default record
|
|
bad_dimensionality_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
|
|
"ids": ["id1", "id2"],
|
|
}
|
|
|
|
bad_dimensionality_query = {
|
|
"query_embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
|
|
}
|
|
|
|
bad_number_of_results_query = {
|
|
"query_embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"n_results": 100,
|
|
}
|
|
|
|
|
|
def test_dimensionality_validation_add(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_dimensionality_validation")
|
|
collection.add(**minimal_records)
|
|
|
|
with pytest.raises(Exception) as e:
|
|
collection.add(**bad_dimensionality_records)
|
|
assert "dimensionality" in str(e.value)
|
|
|
|
|
|
def test_dimensionality_validation_query(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_dimensionality_validation_query")
|
|
collection.add(**minimal_records)
|
|
|
|
with pytest.raises(Exception) as e:
|
|
collection.query(**bad_dimensionality_query)
|
|
assert "dimensionality" in str(e.value)
|
|
|
|
|
|
def test_query_document_valid_operators(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_where_valid_operators")
|
|
collection.add(**operator_records)
|
|
with pytest.raises(ValueError, match="where document"):
|
|
collection.get(where_document={"$lt": {"$nested": 2}})
|
|
|
|
with pytest.raises(ValueError, match="where document"):
|
|
collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2})
|
|
|
|
with pytest.raises(ValueError, match="where document"):
|
|
collection.get(where_document={"$contains": []})
|
|
|
|
# Test invalid $and, $or
|
|
with pytest.raises(ValueError):
|
|
collection.get(where_document={"$and": {"$unsupported": "doc"}})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(
|
|
where_document={"$or": [{"$unsupported": "doc"}, {"$unsupported": "doc"}]}
|
|
)
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where_document={"$or": [{"$contains": "doc"}]})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(where_document={"$or": []})
|
|
|
|
with pytest.raises(ValueError):
|
|
collection.get(
|
|
where_document={
|
|
"$or": [{"$and": [{"$contains": "doc"}]}, {"$contains": "doc"}]
|
|
}
|
|
)
|
|
|
|
|
|
contains_records = {
|
|
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
|
|
"documents": ["this is doc1 and it's great!", "doc2 is also great!"],
|
|
"ids": ["id1", "id2"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001},
|
|
{"int_value": 2, "float_value": 2.002, "string_value": "two"},
|
|
],
|
|
}
|
|
|
|
|
|
def test_get_where_document(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_get_where_document")
|
|
collection.add(**contains_records)
|
|
|
|
items = collection.get(where_document={"$contains": "doc1"})
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
items = collection.get(where_document={"$contains": "great"})
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
items = collection.get(where_document={"$contains": "bad"})
|
|
assert len(items["metadatas"]) == 0
|
|
|
|
|
|
def test_query_where_document(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_query_where_document")
|
|
collection.add(**contains_records)
|
|
|
|
items = collection.query(
|
|
query_embeddings=[1, 0, 0], where_document={"$contains": "doc1"}, n_results=1
|
|
)
|
|
assert len(items["metadatas"][0]) == 1
|
|
|
|
items = collection.query(
|
|
query_embeddings=[0, 0, 0], where_document={"$contains": "great"}, n_results=2
|
|
)
|
|
assert len(items["metadatas"][0]) == 2
|
|
|
|
with pytest.raises(Exception) as e:
|
|
items = collection.query(
|
|
query_embeddings=[0, 0, 0], where_document={"$contains": "bad"}, n_results=1
|
|
)
|
|
assert "datapoints" in str(e.value)
|
|
|
|
|
|
def test_delete_where_document(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_delete_where_document")
|
|
collection.add(**contains_records)
|
|
|
|
collection.delete(where_document={"$contains": "doc1"})
|
|
assert collection.count() == 1
|
|
|
|
collection.delete(where_document={"$contains": "bad"})
|
|
assert collection.count() == 1
|
|
|
|
collection.delete(where_document={"$contains": "great"})
|
|
assert collection.count() == 0
|
|
|
|
|
|
logical_operator_records = {
|
|
"embeddings": [
|
|
[1.1, 2.3, 3.2],
|
|
[1.2, 2.24, 3.2],
|
|
[1.3, 2.25, 3.2],
|
|
[1.4, 2.26, 3.2],
|
|
],
|
|
"ids": ["id1", "id2", "id3", "id4"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001, "is": "doc"},
|
|
{"int_value": 2, "float_value": 2.002, "string_value": "two", "is": "doc"},
|
|
{"int_value": 3, "float_value": 3.003, "string_value": "three", "is": "doc"},
|
|
{"int_value": 4, "float_value": 4.004, "string_value": "four", "is": "doc"},
|
|
],
|
|
"documents": [
|
|
"this document is first and great",
|
|
"this document is second and great",
|
|
"this document is third and great",
|
|
"this document is fourth and great",
|
|
],
|
|
}
|
|
|
|
|
|
def test_where_logical_operators(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_logical_operators")
|
|
collection.add(**logical_operator_records)
|
|
|
|
items = collection.get(
|
|
where={
|
|
"$and": [
|
|
{"$or": [{"int_value": {"$gte": 3}}, {"float_value": {"$lt": 1.9}}]},
|
|
{"is": "doc"},
|
|
]
|
|
}
|
|
)
|
|
assert len(items["metadatas"]) == 3
|
|
|
|
items = collection.get(
|
|
where={
|
|
"$or": [
|
|
{
|
|
"$and": [
|
|
{"int_value": {"$eq": 3}},
|
|
{"string_value": {"$eq": "three"}},
|
|
]
|
|
},
|
|
{
|
|
"$and": [
|
|
{"int_value": {"$eq": 4}},
|
|
{"string_value": {"$eq": "four"}},
|
|
]
|
|
},
|
|
]
|
|
}
|
|
)
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
items = collection.get(
|
|
where={
|
|
"$and": [
|
|
{
|
|
"$or": [
|
|
{"int_value": {"$eq": 1}},
|
|
{"string_value": {"$eq": "two"}},
|
|
]
|
|
},
|
|
{
|
|
"$or": [
|
|
{"int_value": {"$eq": 2}},
|
|
{"string_value": {"$eq": "one"}},
|
|
]
|
|
},
|
|
]
|
|
}
|
|
)
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
|
|
def test_where_document_logical_operators(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_document_logical_operators")
|
|
collection.add(**logical_operator_records)
|
|
|
|
items = collection.get(
|
|
where_document={
|
|
"$and": [
|
|
{"$contains": "first"},
|
|
{"$contains": "doc"},
|
|
]
|
|
}
|
|
)
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
items = collection.get(
|
|
where_document={
|
|
"$or": [
|
|
{"$contains": "first"},
|
|
{"$contains": "second"},
|
|
]
|
|
}
|
|
)
|
|
assert len(items["metadatas"]) == 2
|
|
|
|
items = collection.get(
|
|
where_document={
|
|
"$or": [
|
|
{"$contains": "first"},
|
|
{"$contains": "second"},
|
|
]
|
|
},
|
|
where={
|
|
"int_value": {"$ne": 2},
|
|
},
|
|
)
|
|
assert len(items["metadatas"]) == 1
|
|
|
|
|
|
# endregion
|
|
|
|
records = {
|
|
"embeddings": [[0, 0, 0], [1.2, 2.24, 3.2]],
|
|
"ids": ["id1", "id2"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001},
|
|
{"int_value": 2},
|
|
],
|
|
"documents": ["this document is first", "this document is second"],
|
|
}
|
|
|
|
|
|
def test_query_include(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_query_include")
|
|
collection.add(**records)
|
|
|
|
items = collection.query(
|
|
query_embeddings=[0, 0, 0],
|
|
include=["metadatas", "documents", "distances"],
|
|
n_results=1,
|
|
)
|
|
assert items["embeddings"] is None
|
|
assert items["ids"][0][0] == "id1"
|
|
assert items["metadatas"][0][0]["int_value"] == 1
|
|
|
|
items = collection.query(
|
|
query_embeddings=[0, 0, 0],
|
|
include=["embeddings", "documents", "distances"],
|
|
n_results=1,
|
|
)
|
|
assert items["metadatas"] is None
|
|
assert items["ids"][0][0] == "id1"
|
|
|
|
items = collection.query(
|
|
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
|
|
include=[],
|
|
n_results=2,
|
|
)
|
|
assert items["documents"] is None
|
|
assert items["metadatas"] is None
|
|
assert items["embeddings"] is None
|
|
assert items["distances"] is None
|
|
assert items["ids"][0][0] == "id1"
|
|
assert items["ids"][0][1] == "id2"
|
|
|
|
|
|
def test_get_include(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_get_include")
|
|
collection.add(**records)
|
|
|
|
items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
|
|
assert items["embeddings"] is None
|
|
assert items["ids"][0] == "id1"
|
|
assert items["metadatas"][0]["int_value"] == 1
|
|
assert items["documents"][0] == "this document is first"
|
|
|
|
items = collection.get(include=["embeddings", "documents"])
|
|
assert items["metadatas"] is None
|
|
assert items["ids"][0] == "id1"
|
|
assert approx_equal(items["embeddings"][1][0], 1.2)
|
|
|
|
items = collection.get(include=[])
|
|
assert items["documents"] is None
|
|
assert items["metadatas"] is None
|
|
assert items["embeddings"] is None
|
|
assert items["ids"][0] == "id1"
|
|
|
|
with pytest.raises(ValueError, match="include"):
|
|
items = collection.get(include=["metadatas", "undefined"])
|
|
|
|
with pytest.raises(ValueError, match="include"):
|
|
items = collection.get(include=None)
|
|
|
|
|
|
# make sure query results are returned in the right order
|
|
|
|
|
|
def test_query_order(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_query_order")
|
|
collection.add(**records)
|
|
|
|
items = collection.query(
|
|
query_embeddings=[1.2, 2.24, 3.2],
|
|
include=["metadatas", "documents", "distances"],
|
|
n_results=2,
|
|
)
|
|
|
|
assert items["documents"][0][0] == "this document is second"
|
|
assert items["documents"][0][1] == "this document is first"
|
|
|
|
|
|
# test to make sure add, get, delete error on invalid id input
|
|
|
|
|
|
def test_invalid_id(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_invalid_id")
|
|
# Add with non-string id
|
|
with pytest.raises(ValueError) as e:
|
|
collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}])
|
|
assert "ID" in str(e.value)
|
|
|
|
# Get with non-list id
|
|
with pytest.raises(ValueError) as e:
|
|
collection.get(ids=1)
|
|
assert "ID" in str(e.value)
|
|
|
|
# Delete with malformed ids
|
|
with pytest.raises(ValueError) as e:
|
|
collection.delete(ids=["valid", 0])
|
|
assert "ID" in str(e.value)
|
|
|
|
|
|
def test_index_params(api):
|
|
EPS = 1e-12
|
|
# first standard add
|
|
api.reset()
|
|
collection = api.create_collection(name="test_index_params")
|
|
collection.add(**records)
|
|
items = collection.query(
|
|
query_embeddings=[0.6, 1.12, 1.6],
|
|
n_results=1,
|
|
)
|
|
assert items["distances"][0][0] > 4
|
|
|
|
# cosine
|
|
api.reset()
|
|
collection = api.create_collection(
|
|
name="test_index_params",
|
|
metadata={"hnsw:space": "cosine", "hnsw:construction_ef": 20, "hnsw:M": 5},
|
|
)
|
|
collection.add(**records)
|
|
items = collection.query(
|
|
query_embeddings=[0.6, 1.12, 1.6],
|
|
n_results=1,
|
|
)
|
|
assert items["distances"][0][0] > 0 - EPS
|
|
assert items["distances"][0][0] < 1 + EPS
|
|
|
|
# ip
|
|
api.reset()
|
|
collection = api.create_collection(
|
|
name="test_index_params", metadata={"hnsw:space": "ip"}
|
|
)
|
|
collection.add(**records)
|
|
items = collection.query(
|
|
query_embeddings=[0.6, 1.12, 1.6],
|
|
n_results=1,
|
|
)
|
|
assert items["distances"][0][0] < -5
|
|
|
|
|
|
def test_invalid_index_params(api):
|
|
api.reset()
|
|
|
|
with pytest.raises(Exception):
|
|
collection = api.create_collection(
|
|
name="test_index_params", metadata={"hnsw:foobar": "blarg"}
|
|
)
|
|
collection.add(**records)
|
|
|
|
with pytest.raises(Exception):
|
|
collection = api.create_collection(
|
|
name="test_index_params", metadata={"hnsw:space": "foobar"}
|
|
)
|
|
collection.add(**records)
|
|
|
|
|
|
def test_persist_index_loading_params(api, request):
|
|
api = request.getfixturevalue("local_persist_api")
|
|
api.reset()
|
|
collection = api.create_collection(
|
|
"test",
|
|
metadata={"hnsw:space": "ip"},
|
|
)
|
|
collection.add(ids="id1", documents="hello")
|
|
|
|
api2 = request.getfixturevalue("local_persist_api_cache_bust")
|
|
collection = api2.get_collection(
|
|
"test",
|
|
)
|
|
|
|
assert collection.metadata["hnsw:space"] == "ip"
|
|
|
|
nn = collection.query(
|
|
query_texts="hello",
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in nn.keys():
|
|
assert len(nn[key]) == 1
|
|
|
|
|
|
def test_add_large(api):
|
|
api.reset()
|
|
|
|
collection = api.create_collection("testspace")
|
|
|
|
# Test adding a large number of records
|
|
large_records = np.random.rand(2000, 512).astype(np.float32).tolist()
|
|
|
|
collection.add(
|
|
embeddings=large_records,
|
|
ids=[f"http://example.com/{i}" for i in range(len(large_records))],
|
|
)
|
|
|
|
assert collection.count() == len(large_records)
|
|
|
|
|
|
# test get_version
|
|
def test_get_version(api):
|
|
api.reset()
|
|
version = api.get_version()
|
|
|
|
# assert version matches the pattern x.y.z
|
|
import re
|
|
|
|
assert re.match(r"\d+\.\d+\.\d+", version)
|
|
|
|
|
|
# test delete_collection
|
|
def test_delete_collection(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_delete_collection")
|
|
collection.add(**records)
|
|
|
|
assert len(api.list_collections()) == 1
|
|
api.delete_collection("test_delete_collection")
|
|
assert len(api.list_collections()) == 0
|
|
|
|
|
|
# test default embedding function
|
|
def test_default_embedding():
|
|
embedding_function = DefaultEmbeddingFunction()
|
|
docs = ["this is a test" for _ in range(64)]
|
|
embeddings = embedding_function(docs)
|
|
assert len(embeddings) == 64
|
|
|
|
|
|
def test_multiple_collections(api):
|
|
embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist()
|
|
embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist()
|
|
ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))]
|
|
ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))]
|
|
|
|
api.reset()
|
|
coll1 = api.create_collection("coll1")
|
|
coll1.add(embeddings=embeddings1, ids=ids1)
|
|
|
|
coll2 = api.create_collection("coll2")
|
|
coll2.add(embeddings=embeddings2, ids=ids2)
|
|
|
|
assert len(api.list_collections()) == 2
|
|
assert coll1.count() == len(embeddings1)
|
|
assert coll2.count() == len(embeddings2)
|
|
|
|
results1 = coll1.query(query_embeddings=embeddings1[0], n_results=1)
|
|
results2 = coll2.query(query_embeddings=embeddings2[0], n_results=1)
|
|
|
|
assert results1["ids"][0][0] == ids1[0]
|
|
assert results2["ids"][0][0] == ids2[0]
|
|
|
|
|
|
def test_update_query(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_update_query")
|
|
collection.add(**records)
|
|
|
|
updated_records = {
|
|
"ids": [records["ids"][0]],
|
|
"embeddings": [[0.1, 0.2, 0.3]],
|
|
"documents": ["updated document"],
|
|
"metadatas": [{"foo": "bar"}],
|
|
}
|
|
|
|
collection.update(**updated_records)
|
|
|
|
# test query
|
|
results = collection.query(
|
|
query_embeddings=updated_records["embeddings"],
|
|
n_results=1,
|
|
include=["embeddings", "documents", "metadatas"],
|
|
)
|
|
assert len(results["ids"][0]) == 1
|
|
assert results["ids"][0][0] == updated_records["ids"][0]
|
|
assert results["documents"][0][0] == updated_records["documents"][0]
|
|
assert results["metadatas"][0][0]["foo"] == "bar"
|
|
assert vector_approx_equal(
|
|
results["embeddings"][0][0], updated_records["embeddings"][0]
|
|
)
|
|
|
|
|
|
def test_get_nearest_neighbors_where_n_results_more_than_element(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**records)
|
|
|
|
results1 = collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]],
|
|
n_results=5,
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
for key in results1.keys():
|
|
assert len(results1[key][0]) == 2
|
|
|
|
|
|
def test_invalid_n_results_param(api):
|
|
api.reset()
|
|
collection = api.create_collection("testspace")
|
|
collection.add(**records)
|
|
with pytest.raises(TypeError) as exc:
|
|
collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]],
|
|
n_results=-1,
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
assert "Number of requested results -1, cannot be negative, or zero." in str(
|
|
exc.value
|
|
)
|
|
assert exc.type == TypeError
|
|
|
|
with pytest.raises(ValueError) as exc:
|
|
collection.query(
|
|
query_embeddings=[[1.1, 2.3, 3.2]],
|
|
n_results="one",
|
|
where={},
|
|
include=["embeddings", "documents", "metadatas", "distances"],
|
|
)
|
|
assert "int" in str(exc.value)
|
|
assert exc.type == ValueError
|
|
|
|
|
|
initial_records = {
|
|
"embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]],
|
|
"ids": ["id1", "id2", "id3"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one", "float_value": 1.001},
|
|
{"int_value": 2},
|
|
{"string_value": "three"},
|
|
],
|
|
"documents": [
|
|
"this document is first",
|
|
"this document is second",
|
|
"this document is third",
|
|
],
|
|
}
|
|
|
|
new_records = {
|
|
"embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
|
|
"ids": ["id1", "id4"],
|
|
"metadatas": [
|
|
{"int_value": 1, "string_value": "one_of_one", "float_value": 1.001},
|
|
{"int_value": 4},
|
|
],
|
|
"documents": [
|
|
"this document is even more first",
|
|
"this document is new and fourth",
|
|
],
|
|
}
|
|
|
|
|
|
def test_upsert(api):
|
|
api.reset()
|
|
collection = api.create_collection("test")
|
|
|
|
collection.add(**initial_records)
|
|
assert collection.count() == 3
|
|
|
|
collection.upsert(**new_records)
|
|
assert collection.count() == 4
|
|
|
|
get_result = collection.get(
|
|
include=["embeddings", "metadatas", "documents"], ids=new_records["ids"][0]
|
|
)
|
|
assert vector_approx_equal(
|
|
get_result["embeddings"][0], new_records["embeddings"][0]
|
|
)
|
|
assert get_result["metadatas"][0] == new_records["metadatas"][0]
|
|
assert get_result["documents"][0] == new_records["documents"][0]
|
|
|
|
query_result = collection.query(
|
|
query_embeddings=get_result["embeddings"],
|
|
n_results=1,
|
|
include=["embeddings", "metadatas", "documents"],
|
|
)
|
|
assert vector_approx_equal(
|
|
query_result["embeddings"][0][0], new_records["embeddings"][0]
|
|
)
|
|
assert query_result["metadatas"][0][0] == new_records["metadatas"][0]
|
|
assert query_result["documents"][0][0] == new_records["documents"][0]
|
|
|
|
collection.delete(ids=initial_records["ids"][2])
|
|
collection.upsert(
|
|
ids=initial_records["ids"][2],
|
|
embeddings=[[1.1, 0.99, 2.21]],
|
|
metadatas=[{"string_value": "a new string value"}],
|
|
)
|
|
assert collection.count() == 4
|
|
|
|
get_result = collection.get(
|
|
include=["embeddings", "metadatas", "documents"], ids=["id3"]
|
|
)
|
|
assert vector_approx_equal(get_result["embeddings"][0], [1.1, 0.99, 2.21])
|
|
assert get_result["metadatas"][0] == {"string_value": "a new string value"}
|
|
assert get_result["documents"][0] is None
|
|
|
|
|
|
# test to make sure add, query, update, upsert error on invalid embeddings input
|
|
|
|
|
|
def test_invalid_embeddings(api):
|
|
api.reset()
|
|
collection = api.create_collection("test_invalid_embeddings")
|
|
|
|
# Add with string embeddings
|
|
invalid_records = {
|
|
"embeddings": [["0", "0", "0"], ["1.2", "2.24", "3.2"]],
|
|
"ids": ["id1", "id2"],
|
|
}
|
|
with pytest.raises(ValueError) as e:
|
|
collection.add(**invalid_records)
|
|
assert "embedding" in str(e.value)
|
|
|
|
# Query with invalid embeddings
|
|
with pytest.raises(ValueError) as e:
|
|
collection.query(
|
|
query_embeddings=[["1.1", "2.3", "3.2"]],
|
|
n_results=1,
|
|
)
|
|
assert "embedding" in str(e.value)
|
|
|
|
# Update with invalid embeddings
|
|
invalid_records = {
|
|
"embeddings": [[[0], [0], [0]], [[1.2], [2.24], [3.2]]],
|
|
"ids": ["id1", "id2"],
|
|
}
|
|
with pytest.raises(ValueError) as e:
|
|
collection.update(**invalid_records)
|
|
assert "embedding" in str(e.value)
|
|
|
|
# Upsert with invalid embeddings
|
|
invalid_records = {
|
|
"embeddings": [[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]],
|
|
"ids": ["id1", "id2"],
|
|
}
|
|
with pytest.raises(ValueError) as e:
|
|
collection.upsert(**invalid_records)
|
|
assert "embedding" in str(e.value)
|