386 lines
12 KiB
Python
386 lines
12 KiB
Python
from typing import Optional, cast
|
|
from chromadb.api import API
|
|
from chromadb.config import Settings, System
|
|
from chromadb.api.types import (
|
|
Documents,
|
|
Embeddings,
|
|
EmbeddingFunction,
|
|
IDs,
|
|
Include,
|
|
Metadatas,
|
|
Where,
|
|
WhereDocument,
|
|
GetResult,
|
|
QueryResult,
|
|
CollectionMetadata,
|
|
)
|
|
import chromadb.utils.embedding_functions as ef
|
|
import requests
|
|
import json
|
|
from typing import Sequence
|
|
from chromadb.api.models.Collection import Collection
|
|
import chromadb.errors as errors
|
|
from uuid import UUID
|
|
from chromadb.telemetry import Telemetry
|
|
from overrides import override
|
|
|
|
|
|
class FastAPI(API):
    """Implementation of ``API`` that talks to a Chroma server over HTTP.

    Each operation is sent as a request to a URL under ``self._api_url``
    through one shared ``requests.Session``. Every response is checked with
    ``raise_chroma_error`` so that non-OK responses are raised as exceptions.
    """

    # Settings object captured at construction time; returned by get_settings().
    _settings: Settings

    def __init__(self, system: System):
        """Build the base API URL and the HTTP session from ``system.settings``.

        Args:
            system: The system whose settings supply the server host, port,
                SSL flag and optional extra request headers.
        """
        super().__init__(system)
        settings = system.settings

        url_prefix = "https" if settings.chroma_server_ssl_enabled else "http"
        # NOTE(review): presumably `require` raises when the setting is
        # missing -- confirm against Settings.require.
        settings.require("chroma_server_host")
        settings.require("chroma_server_http_port")

        self._telemetry_client = self.require(Telemetry)
        self._settings = settings

        # Only append ":<port>" when the port setting is truthy.
        port = settings.chroma_server_http_port
        port_suffix = f":{port}" if port else ""
        self._api_url = (
            f"{url_prefix}://{settings.chroma_server_host}{port_suffix}/api/v1"
        )

        self._header = settings.chroma_server_headers
        self._session = requests.Session()
        if self._header is not None:
            # Configured headers are attached to every request of the session.
            self._session.headers.update(self._header)

    @override
    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp = self._session.get(self._api_url)
        raise_chroma_error(resp)
        return int(resp.json()["nanosecond heartbeat"])

    @override
    def list_collections(self) -> Sequence[Collection]:
        """Returns a list of all collections"""
        resp = self._session.get(f"{self._api_url}/collections")
        raise_chroma_error(resp)
        # Each JSON object in the response is expanded into Collection kwargs.
        return [
            Collection(self, **json_collection) for json_collection in resp.json()
        ]

    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Creates a collection

        Args:
            name: Name sent to the server for the collection.
            metadata: Optional collection metadata sent to the server.
            embedding_function: Attached to the returned ``Collection``
                object only; it is not part of the request.
            get_or_create: Forwarded to the server as ``get_or_create``.

        Returns:
            A ``Collection`` built from the ``id``, ``name`` and ``metadata``
            fields of the server response.
        """
        payload = {"name": name, "metadata": metadata, "get_or_create": get_or_create}
        resp = self._session.post(
            f"{self._api_url}/collections",
            data=json.dumps(payload),
        )
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            id=resp_json["id"],
            name=resp_json["name"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    @override
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Returns a collection

        Args:
            name: Name of the collection, used as the last URL path segment.
            embedding_function: Attached to the returned ``Collection``
                object only; it is not part of the request.
        """
        resp = self._session.get(f"{self._api_url}/collections/{name}")
        raise_chroma_error(resp)
        resp_json = resp.json()
        return Collection(
            client=self,
            name=resp_json["name"],
            id=resp_json["id"],
            embedding_function=embedding_function,
            metadata=resp_json["metadata"],
        )

    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Delegate to ``create_collection`` with ``get_or_create=True``."""
        return self.create_collection(
            name, metadata, embedding_function, get_or_create=True
        )

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Updates a collection

        Sends ``new_name`` and ``new_metadata`` (either may be ``None``) in a
        PUT request addressed by the collection ``id``.
        """
        resp = self._session.put(
            f"{self._api_url}/collections/{id}",
            data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
        )
        raise_chroma_error(resp)

    @override
    def delete_collection(self, name: str) -> None:
        """Deletes a collection"""
        resp = self._session.delete(f"{self._api_url}/collections/{name}")
        raise_chroma_error(resp)

    @override
    def _count(self, collection_id: UUID) -> int:
        """Return the count reported by the server for ``collection_id``."""
        resp = self._session.get(
            f"{self._api_url}/collections/{collection_id}/count"
        )
        raise_chroma_error(resp)
        return cast(int, resp.json())

    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Fetch up to ``n`` items via ``_get``, including embeddings."""
        return self._get(
            collection_id,
            limit=n,
            include=["embeddings", "documents", "metadatas"],
        )

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents"],
    ) -> GetResult:
        """Post a ``get`` request for a collection and wrap the response.

        When both ``page`` and ``page_size`` are truthy they take precedence
        over ``limit``/``offset``: ``page`` is treated as 1-based and converted
        to ``offset = (page - 1) * page_size`` with ``limit = page_size``.
        ``page`` and ``page_size`` themselves are not sent to the server.

        Returns:
            A ``GetResult`` whose ``ids`` come from the response body;
            ``embeddings``, ``metadatas`` and ``documents`` are ``None`` when
            absent from the body.
        """
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        payload = {
            "ids": ids,
            "where": where,
            "sort": sort,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
            "include": include,
        }
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/get",
            data=json.dumps(payload),
        )

        raise_chroma_error(resp)
        body = resp.json()
        return GetResult(
            ids=body["ids"],
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> IDs:
        """Deletes embeddings from the database

        Returns:
            The JSON body of the server response, cast to ``IDs``.
        """
        payload = {"where": where, "ids": ids, "where_document": where_document}
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/delete",
            data=json.dumps(payload),
        )

        raise_chroma_error(resp)
        return cast(IDs, resp.json())

    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """
        Adds a batch of embeddings to the database
        - pass in column oriented data lists

        Returns ``True`` when the server answered with an OK status;
        otherwise ``raise_chroma_error`` raises.
        """
        payload = {
            "ids": ids,
            "embeddings": embeddings,
            "metadatas": metadatas,
            "documents": documents,
        }
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/add",
            data=json.dumps(payload),
        )

        raise_chroma_error(resp)
        return True

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """
        Updates a batch of embeddings in the database
        - pass in column oriented data lists

        Returns ``True`` when the server answered with an OK status;
        otherwise ``raise_chroma_error`` raises.
        """
        payload = {
            "ids": ids,
            "embeddings": embeddings,
            "metadatas": metadatas,
            "documents": documents,
        }
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/update",
            data=json.dumps(payload),
        )

        # Use the shared error translation (as every other method does) so the
        # server's error type and message are not lost behind a bare HTTPError.
        raise_chroma_error(resp)
        return True

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """
        Upserts a batch of embeddings in the database
        - pass in column oriented data lists

        Returns ``True`` when the server answered with an OK status;
        otherwise ``raise_chroma_error`` raises.
        """
        payload = {
            "ids": ids,
            "embeddings": embeddings,
            "metadatas": metadatas,
            "documents": documents,
        }
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/upsert",
            data=json.dumps(payload),
        )

        # Use the shared error translation (as every other method does) so the
        # server's error type and message are not lost behind a bare HTTPError.
        raise_chroma_error(resp)
        return True

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents", "distances"],
    ) -> QueryResult:
        """Gets the nearest neighbors of a single embedding

        Returns:
            A ``QueryResult`` whose ``ids`` come from the response body;
            ``distances``, ``embeddings``, ``metadatas`` and ``documents`` are
            ``None`` when absent from the body.
        """
        payload = {
            "query_embeddings": query_embeddings,
            "n_results": n_results,
            "where": where,
            "where_document": where_document,
            "include": include,
        }
        resp = self._session.post(
            f"{self._api_url}/collections/{collection_id}/query",
            data=json.dumps(payload),
        )

        raise_chroma_error(resp)
        body = resp.json()

        return QueryResult(
            ids=body["ids"],
            distances=body.get("distances", None),
            embeddings=body.get("embeddings", None),
            metadatas=body.get("metadatas", None),
            documents=body.get("documents", None),
        )

    @override
    def reset(self) -> bool:
        """Resets the database"""
        resp = self._session.post(f"{self._api_url}/reset")
        raise_chroma_error(resp)
        return cast(bool, resp.json())

    @override
    def get_version(self) -> str:
        """Returns the version of the server"""
        resp = self._session.get(f"{self._api_url}/version")
        raise_chroma_error(resp)
        return cast(str, resp.json())

    @override
    def get_settings(self) -> Settings:
        """Returns the settings of the client"""
        return self._settings
|
|
|
|
|
|
def raise_chroma_error(resp: requests.Response) -> None:
    """Raises an error if the response is not ok, using a ChromaError if possible

    Args:
        resp: The HTTP response to inspect.

    Raises:
        An instance of a class from ``errors.error_types`` when the response
        body is JSON containing an ``"error"`` key that names one of those
        types; it is constructed with the body's ``"message"`` value.
        Exception: Otherwise, carrying ``resp.text`` as its message.
    """
    if resp.ok:
        return

    chroma_error = None
    try:
        body = resp.json()
        if "error" in body and body["error"] in errors.error_types:
            chroma_error = errors.error_types[body["error"]](body["message"])
    except Exception:
        # Best effort only: a body that is not JSON, or not shaped as expected,
        # falls through to the generic error below. Deliberately not
        # BaseException, so KeyboardInterrupt/SystemExit still propagate.
        pass

    if chroma_error is not None:
        raise chroma_error

    try:
        resp.raise_for_status()
    except requests.HTTPError as http_err:
        # Surface the response text (the HTTPError message alone omits it)
        # while keeping the HTTPError attached as the cause.
        raise Exception(resp.text) from http_err
|