295 lines
12 KiB
Python
295 lines
12 KiB
Python
import math
|
|
from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
|
|
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast
|
|
from typing_extensions import Literal
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
from chromadb.api import types
|
|
from chromadb.api.models.Collection import Collection
|
|
from hypothesis import note
|
|
from hypothesis.errors import InvalidArgument
|
|
|
|
from chromadb.utils import distance_functions
|
|
|
|
# Generic element type used in the signature of `wrap` below.
T = TypeVar("T")
|
|
|
|
|
|
def wrap(value: Union[T, List[T]]) -> List[T]:
    """Wrap a value in a list if it is not already a list.

    Args:
        value: A single item or a list of items.

    Returns:
        ``value`` itself (the same object) when it is a list, otherwise a new
        one-element list containing ``value``.

    Raises:
        InvalidArgument: If ``value`` is ``None``.
    """
    if value is None:
        raise InvalidArgument("value cannot be None")
    # Check against the builtin type: the ``typing.List`` alias is intended
    # for annotations, not for runtime isinstance checks.
    if isinstance(value, list):
        return value
    return [value]
|
|
|
|
|
|
def wrap_all(record_set: RecordSet) -> NormalizedRecordSet:
    """Normalize a record set so that every populated field is a list.

    ``ids`` is always passed through ``wrap``; ``documents`` and ``metadatas``
    are wrapped when present and stay ``None`` otherwise. ``embeddings`` may
    be ``None``, a list of lists, or a flat list of numbers; the flat form is
    treated as a single embedding and nested inside a new list.

    Raises:
        InvalidArgument: If ``embeddings`` is neither ``None`` nor a list, or
            if it is a non-empty list that is neither all lists nor all
            ints/floats.
    """
    raw_embeddings = record_set["embeddings"]
    embedding_list: Optional[types.Embeddings]

    if raw_embeddings is None:
        embedding_list = None
    elif not isinstance(raw_embeddings, list):
        raise InvalidArgument(
            "embeddings must be a list of lists, a list of numbers, or None"
        )
    elif len(raw_embeddings) == 0 or all(
        isinstance(item, list) for item in raw_embeddings
    ):
        # Empty, or already a list of embeddings: use it unchanged.
        embedding_list = cast(types.Embeddings, raw_embeddings)
    elif all(isinstance(item, (int, float)) for item in raw_embeddings):
        # A flat list of numbers is one embedding, so nest it.
        embedding_list = cast(types.Embeddings, [raw_embeddings])
    else:
        raise InvalidArgument("an embedding must be a list of floats or ints")

    return {
        "ids": wrap(record_set["ids"]),
        "documents": None
        if record_set["documents"] is None
        else wrap(record_set["documents"]),
        "metadatas": None
        if record_set["metadatas"] is None
        else wrap(record_set["metadatas"]),
        "embeddings": embedding_list,
    }
|
|
|
|
|
|
def count(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the collection's count equals the number of ids in the record set."""
    actual_count = collection.count()
    expected_count = len(wrap_all(record_set)["ids"])
    assert actual_count == expected_count
|
|
|
|
|
|
def _field_matches(
    collection: Collection,
    normalized_record_set: NormalizedRecordSet,
    field_name: Union[
        Literal["documents"], Literal["metadatas"], Literal["embeddings"]
    ],
) -> None:
    """
    Assert that a field stored in the collection equals the expected field.

    The records are fetched by id, re-ordered to the order of the input ids,
    and compared with the record set. Embeddings are compared with
    ``np.allclose``; other fields are compared with ``==``.

    field_name: one of [documents, metadatas, embeddings]
    """
    expected_ids = normalized_record_set["ids"]
    result = collection.get(ids=expected_ids, include=[field_name])
    actual_field = result[field_name]

    if len(expected_ids) == 0:
        assert isinstance(actual_field, list) and len(actual_field) == 0
        return

    # The results are re-ordered to follow the input ids before comparing
    # (the original note refers to test_out_of_order_ids in test_add.py).
    position_of = {record_id: pos for pos, record_id in enumerate(expected_ids)}

    # This assert should never fire: when metadatas/documents are included the
    # field is [None, None, ...] if there is no data, never a bare None.
    assert actual_field is not None

    def _input_position(indexed_value: Tuple[int, object]) -> int:
        # Map the result position -> returned id -> position in the input ids.
        return position_of[result["ids"][indexed_value[0]]]

    in_input_order = sorted(enumerate(actual_field), key=_input_position)
    field_values = [value for _, value in in_input_order]

    expected_field = normalized_record_set[field_name]
    if expected_field is None:
        # A RecordSet is user input and may omit the field entirely, while the
        # API returns a list with one None per entry.
        expected_field = [None] * len(expected_ids)  # type: ignore

    if field_name != "embeddings":
        assert field_values == expected_field
    else:
        assert np.allclose(np.array(field_values), np.array(expected_field))
|
|
|
|
|
|
def ids_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that fetching the record set's ids returns exactly those ids."""
    expected_ids = wrap_all(record_set)["ids"]
    fetched = collection.get(ids=expected_ids, include=[])

    # The returned ids are sorted back into input order before comparing
    # (the original note refers to test_out_of_order_ids in test_add.py).
    position_of = {record_id: pos for pos, record_id in enumerate(expected_ids)}
    returned_in_input_order = sorted(fetched["ids"], key=position_of.__getitem__)

    assert returned_in_input_order == expected_ids
|
|
|
|
|
|
def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the stored metadatas equal the record set's metadatas."""
    _field_matches(collection, wrap_all(record_set), "metadatas")
|
|
|
|
|
|
def documents_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the stored documents equal the record set's documents."""
    _field_matches(collection, wrap_all(record_set), "documents")
|
|
|
|
|
|
def embeddings_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the stored embeddings equal the record set's embeddings."""
    _field_matches(collection, wrap_all(record_set), "embeddings")
|
|
|
|
|
|
def no_duplicates(collection: Collection) -> None:
    """Assert that no id appears more than once in the collection."""
    all_ids = collection.get()["ids"]
    unique_ids = set(all_ids)
    assert len(all_ids) == len(unique_ids)
|
|
|
|
|
|
def _exact_distances(
    query: types.Embeddings,
    targets: types.Embeddings,
    distance_fn: Callable[
        [npt.ArrayLike, npt.ArrayLike], float
    ] = distance_functions.l2,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Compute the distance from every query row to every target row.

    Returns:
        A pair ``(indices, distances)``. ``distances[q][t]`` is
        ``distance_fn(targets[t], query[q])``, and ``indices[q]`` is the
        ``np.argsort`` ordering of ``distances[q]``.
    """
    query_matrix = np.array(query)
    target_matrix = np.array(targets)

    def _distances_to_targets(single_query: npt.ArrayLike) -> npt.ArrayLike:
        # One output row: apply distance_fn to each target row and this query.
        return np.apply_along_axis(distance_fn, 1, target_matrix, single_query)

    # Build the full query-by-target distance matrix, one query row at a time.
    distance_matrix = np.apply_along_axis(_distances_to_targets, 1, query_matrix)

    # Rank the targets of each query by distance.
    ranked_indices = np.argsort(distance_matrix)
    return ranked_indices.tolist(), distance_matrix.tolist()
|
|
|
|
|
|
def is_metadata_valid(normalized_record_set: NormalizedRecordSet) -> bool:
    """Return False if any metadata entry is empty, True otherwise.

    A record set without metadatas (``None``) is considered valid.
    """
    metadatas = normalized_record_set["metadatas"]
    if metadatas is None:
        return True
    # Measure every entry up front (no short-circuit) before checking for an
    # empty one.
    sizes = [len(metadata) for metadata in metadatas]
    return 0 not in sizes
|
|
|
|
|
|
def ann_accuracy(
    collection: Collection,
    record_set: RecordSet,
    n_results: int = 1,
    min_recall: float = 0.99,
    embedding_function: Optional[types.EmbeddingFunction] = None,
    query_indices: Optional[List[int]] = None,
) -> None:
    """Check the collection's nearest-neighbor queries against a brute-force search.

    Args:
        collection: The collection to query.
        record_set: The records expected to be in the collection.
        n_results: Number of neighbors requested per query.
        min_recall: Lowest recall value that passes the final assertion.
        embedding_function: Used to embed the documents when the record set
            carries no embeddings; must be given in that case.
        query_indices: Positions of the records to use as queries; every
            record is used when ``None``.

    The function asserts that returned distances agree with the exact ones,
    that returned embeddings/documents/metadatas match the record set, that
    recall is at least ``min_recall``, and that each result list is sorted by
    distance.
    """
    normalized_record_set = wrap_all(record_set)
    record_ids = normalized_record_set["ids"]
    documents = normalized_record_set["documents"]
    metadatas = normalized_record_set["metadatas"]

    if len(record_ids) == 0:
        return  # nothing to test here

    embeddings: Optional[types.Embeddings] = normalized_record_set["embeddings"]
    have_embeddings = embeddings is not None and len(embeddings) > 0
    if not have_embeddings:
        # No usable embeddings were supplied: derive them from the documents.
        assert embedding_function is not None
        assert documents is not None
        assert isinstance(documents, list)
        embeddings = embedding_function(documents)

    # l2 is the default distance function
    distance_function = distance_functions.l2
    accuracy_threshold = 1e-6
    assert collection.metadata is not None
    assert embeddings is not None
    if "hnsw:space" in collection.metadata:
        space = collection.metadata["hnsw:space"]
        # TODO: ip and cosine are numerically unstable in HNSW.
        # Every float element of a vector carries noise that feeds into the
        # normalization calculations, so higher dimensionality means more
        # error. The tolerance is therefore scaled by the order of magnitude
        # of the dimension.
        assert all(isinstance(e, list) for e in embeddings)
        dim = len(embeddings[0])
        accuracy_threshold *= math.pow(10, int(math.log10(dim)))

        if space == "cosine":
            distance_function = distance_functions.cosine
        elif space == "ip":
            distance_function = distance_functions.ip

    # Select the queries: all records, or only those named by query_indices.
    if query_indices is None:
        query_embeddings = embeddings
        query_documents = documents
    else:
        query_embeddings = [embeddings[i] for i in query_indices]
        query_documents = (
            None if documents is None else [documents[i] for i in query_indices]
        )

    # Perform exact distance computation
    indices, distances = _exact_distances(
        query_embeddings, embeddings, distance_fn=distance_function
    )

    # Query by embedding when embeddings were supplied, by text otherwise.
    query_results = collection.query(
        query_embeddings=query_embeddings if have_embeddings else None,
        query_texts=None if have_embeddings else query_documents,
        n_results=n_results,
        include=["embeddings", "documents", "metadatas", "distances"],
    )

    assert query_results["distances"] is not None
    assert query_results["documents"] is not None
    assert query_results["metadatas"] is not None
    assert query_results["embeddings"] is not None

    # Map each id to its position in the record set.
    position_of = {record_id: pos for pos, record_id in enumerate(record_ids)}
    ids_array = np.array(record_ids)
    missing = 0
    for query_pos, (ranked, exact) in enumerate(zip(indices, distances)):
        expected_ids = ids_array[ranked[:n_results]]
        returned_ids = query_results["ids"][query_pos]
        missing += len(set(expected_ids) - set(returned_ids))

        # For each returned id, locate the record and verify what came back.
        for result_pos, returned_id in enumerate(returned_ids):
            # An id outside the exact top-n may mean the true nth nearest
            # neighbor was not returned by the ANN query.
            is_expected = returned_id in expected_ids
            record_pos = position_of[returned_id]

            distance_ok = np.allclose(
                exact[record_pos],
                query_results["distances"][query_pos][result_pos],
                atol=accuracy_threshold,
            )
            if is_expected:
                assert distance_ok
            elif distance_ok:
                # Unexpected id but correct distance: the data contains a
                # duplicate, which should not reduce recall.
                missing -= 1
            else:
                continue

            assert np.allclose(
                embeddings[record_pos],
                query_results["embeddings"][query_pos][result_pos],
            )
            if documents is not None:
                assert (
                    documents[record_pos]
                    == query_results["documents"][query_pos][result_pos]
                )
            if metadatas is not None:
                assert (
                    metadatas[record_pos]
                    == query_results["metadatas"][query_pos][result_pos]
                )

    # NOTE(review): `missing` accumulates over every query and up to n_results
    # neighbors each, while the denominator is the number of records. Confirm
    # this is the intended recall definition when n_results > 1 or when
    # query_indices selects a subset.
    size = len(record_ids)
    recall = (size - missing) / size

    try:
        note(
            f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}"
        )
    except InvalidArgument:
        pass  # it's ok if we're running outside hypothesis

    assert recall >= min_recall

    # Ensure that the query results are sorted by distance
    for distance_result in query_results["distances"]:
        assert np.allclose(np.sort(distance_result), distance_result)
|