import asyncio
import logging
import os
import typing
from concurrent.futures import ThreadPoolExecutor

import httpx
from tokenizers import Tokenizer  # type: ignore

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from .config import embed_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils

logger = logging.getLogger(__name__)

run_overrides()

# Use NoReturn as the Never type for compatibility (typing.Never requires Python 3.11+).
Never = typing.NoReturn


def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[..., typing.Any]) -> None:
    """Rebinds `obj.<method_name>` so that `check_fn(*args, **kwargs)` runs before every call."""
    method = getattr(obj, method_name)

    def _wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    async def _async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        # The `return await` looks redundant, but it's necessary to ensure that the return type is correct.
        check_fn(*args, **kwargs)
        return await method(*args, **kwargs)

    wrapped = _wrapped
    if asyncio.iscoroutinefunction(method):
        wrapped = _async_wrapped

    wrapped.__name__ = method.__name__
    wrapped.__doc__ = method.__doc__
    setattr(obj, method_name, wrapped)


def throw_if_stream_is_true(*args: typing.Any, **kwargs: typing.Any) -> None:
    if kwargs.get("stream") is True:
        raise ValueError(
            "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
        )


def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """Returns a stub that raises, pointing users of a moved method at its new location."""

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn


def deprecated_function(fn_name: str) -> typing.Any:
    """Returns a stub that raises, informing users that the method has been deprecated."""

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn

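
# A minimal usage sketch for `validate_args`: it rebinds a method on a single
# instance so that `check_fn` runs before every call. The `DemoClient` class
# below is hypothetical, for illustration only.
#
#     class DemoClient:
#         def chat(self, message: str, **kwargs) -> str:
#             return f"reply to {message}"
#
#     demo = DemoClient()
#     validate_args(demo, "chat", throw_if_stream_is_true)
#     demo.chat("hi")               # fine
#     demo.chat("hi", stream=True)  # raises ValueError
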

# Logs a warning when a user calls a function with an experimental parameter (a kwarg in our case).
# `deprecated_kwarg` is the name of the experimental parameter, which may be a dot-separated path
# for nested parameters.
def experimental_kwarg_decorator(func, deprecated_kwarg):
    # Recursive utility function to check if a (possibly nested) kwarg is present in the kwargs.
    def check_kwarg(deprecated_kwarg: str, kwargs: typing.Dict[str, typing.Any]) -> bool:
        if "." in deprecated_kwarg:
            key, rest = deprecated_kwarg.split(".", 1)
            if key in kwargs:
                return check_kwarg(rest, kwargs[key])
        return deprecated_kwarg in kwargs

    def _wrapped(*args, **kwargs):
        if check_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return func(*args, **kwargs)

    async def _async_wrapped(*args, **kwargs):
        if check_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return await func(*args, **kwargs)

    wrap = _wrapped
    if asyncio.iscoroutinefunction(func):
        wrap = _async_wrapped

    wrap.__name__ = func.__name__
    wrap.__doc__ = func.__doc__
    return wrap


def fix_base_url(base_url: typing.Optional[str]) -> typing.Optional[str]:
    # Strip any "/v1" path segment from Cohere-hosted base URLs.
    if base_url is not None:
        if "cohere.com" in base_url or "cohere.ai" in base_url:
            return base_url.replace("/v1", "")
        return base_url
    return None


class Client(BaseCohere, CacheMixin):
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()

        base_url = fix_base_url(base_url)
        self._executor = thread_pool_executor

        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = SyncSdkUtils()

    # Support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._client_wrapper.httpx_client.httpx_client.close()

    wait = wait

    def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        # Skip batching for image inputs for now.
        if batching is False or images is not OMIT:
            return BaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        responses = list(
            self._executor.map(
                lambda text_batch: BaseCohere.embed(
                    self,
                    texts=text_batch,
                    model=model,
                    input_type=input_type,
                    embedding_types=embedding_types,
                    truncate=truncate,
                    request_options=request_options,
                ),
                texts_batches,
            )
        )

        return merge_embed_responses(responses)

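    # The batched path in `embed` above splits `texts` into chunks of
    # `embed_batch_size` (from .config) and merges the per-batch responses
    # with `merge_embed_responses`. A standalone sketch of the chunking (the
    # batch size of 96 here is an assumption for illustration):
    #
    #     texts = [f"doc {i}" for i in range(250)]
    #     batch_size = 96
    #     batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    #     assert [len(b) for b in batches] == [96, 96, 58]
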
    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")

    def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: bool = True,
    ) -> TokenizeResponse:
        # When `offline` is True (the default), the tokenizer config is downloaded (and cached) and the
        # text is tokenized locally; when False, or if local tokenization fails, the request goes to the API.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                tokens = local_tokenizers.local_tokenize(self, text=text, model=model)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to calling the API, flagging the failure for the server.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return super().tokenize(text=text, model=model, request_options=opts)

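    # Usage sketch for the offline-first tokenize path. The model name below
    # is an assumption for illustration; CO_API_KEY is read from the
    # environment when no key is passed.
    #
    #     co = Client()
    #     resp = co.tokenize(text="Hello world", model="command-r")
    #     print(resp.tokens)  # ids from the locally cached tokenizer
    #
    #     # Force the API instead (it also populates token_strings):
    #     resp = co.tokenize(text="Hello world", model="command-r", offline=False)
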
    def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # When `offline` is True (the default), the tokenizer config is downloaded (and cached) and the
        # tokens are detokenized locally; when False, or if local detokenization fails, the request goes to the API.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                text = local_tokenizers.local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to calling the API, flagging the failure for the server.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return super().detokenize(tokens=tokens, model=model, request_options=opts)

    def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """Returns a Hugging Face tokenizer for the given model name."""
        return local_tokenizers.get_hf_tokenizer(self, model)


class AsyncClient(AsyncBaseCohere, CacheMixin):
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()

        base_url = fix_base_url(base_url)
        self._executor = thread_pool_executor

        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = AsyncSdkUtils()

    # Support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._client_wrapper.httpx_client.httpx_client.aclose()

    wait = async_wait

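    # Usage sketch: the async client supports `async with`, which closes the
    # underlying httpx.AsyncClient on exit. The model and input_type values
    # below are assumptions for illustration.
    #
    #     async def main() -> None:
    #         async with AsyncClient() as co:
    #             resp = await co.embed(
    #                 texts=["hello"],
    #                 model="embed-english-v3.0",
    #                 input_type="search_document",
    #             )
    #
    #     asyncio.run(main())
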
    async def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        # Skip batching for image inputs for now.
        if batching is False or images is not OMIT:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(
                *[
                    AsyncBaseCohere.embed(
                        self,
                        texts=text_batch,
                        model=model,
                        input_type=input_type,
                        embedding_types=embedding_types,
                        truncate=truncate,
                        request_options=request_options,
                    )
                    for text_batch in texts_batches
                ]
            ),
        )

        return merge_embed_responses(responses)

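    # The async `embed` above fans batches out concurrently with
    # `asyncio.gather` instead of a thread pool. A minimal sketch of the same
    # fan-out pattern, with a stand-in coroutine in place of the HTTP call:
    #
    #     async def fake_embed(batch: typing.List[str]) -> int:
    #         await asyncio.sleep(0)  # stand-in for the network request
    #         return len(batch)
    #
    #     async def main() -> None:
    #         sizes = await asyncio.gather(*[fake_embed(b) for b in [["a", "b"], ["c"]]])
    #         assert sizes == [2, 1]
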
    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")

    async def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> TokenizeResponse:
        # When `offline` is True (the default), the tokenizer config is downloaded (and cached) and the
        # text is tokenized locally; when False, or if local tokenization fails, the request goes to the API.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to calling the API, flagging the failure for the server.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return await super().tokenize(text=text, model=model, request_options=opts)

    async def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # When `offline` is True (the default), the tokenizer config is downloaded (and cached) and the
        # tokens are detokenized locally; when False, or if local detokenization fails, the request goes to the API.
        opts: RequestOptions = request_options or {}  # type: ignore
        if offline:
            try:
                text = await local_tokenizers.async_local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to calling the API, flagging the failure for the server.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"
        return await super().detokenize(tokens=tokens, model=model, request_options=opts)

    async def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """Returns a Hugging Face tokenizer for the given model name."""
        return await local_tokenizers.async_get_hf_tokenizer(self, model)


def _get_api_key_from_environment() -> typing.Optional[str]:
    """
    Retrieves the Cohere API key from specific environment variables.

    CO_API_KEY is preferred (and documented).
    COHERE_API_KEY is accepted (but not documented).
    """
    return os.getenv("CO_API_KEY", os.getenv("COHERE_API_KEY"))
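
# Resolution-order sketch for `_get_api_key_from_environment`: CO_API_KEY wins
# when both variables are set.
#
#     os.environ["COHERE_API_KEY"] = "secondary"
#     os.environ["CO_API_KEY"] = "primary"
#     assert _get_api_key_from_environment() == "primary"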