import asyncio
import logging
import os
import typing
from concurrent.futures import ThreadPoolExecutor

import httpx
from tokenizers import Tokenizer  # type: ignore

from cohere.types.detokenize_response import DetokenizeResponse
from cohere.types.tokenize_response import TokenizeResponse

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from .config import embed_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
from .manually_maintained import tokenizers as local_tokenizers
from .overrides import run_overrides
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils

logger = logging.getLogger(__name__)

run_overrides()

# Use NoReturn as Never type for compatibility
Never = typing.NoReturn


def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[..., typing.Any]) -> None:
    method = getattr(obj, method_name)

    def _wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    async def _async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        # The `return await` looks redundant, but it's necessary to ensure that the return type is correct.
        return await method(*args, **kwargs)

    wrapped = _wrapped
    if asyncio.iscoroutinefunction(method):
        wrapped = _async_wrapped

    wrapped.__name__ = method.__name__
    wrapped.__doc__ = method.__doc__
    setattr(obj, method_name, wrapped)


def throw_if_stream_is_true(*args, **kwargs) -> None:
    if kwargs.get("stream") is True:
        raise ValueError(
            "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
        )
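
# Illustrative sketch (not executed): `validate_args` rebinds a bound method so
# that `check_fn` runs before every call. `FakeClient` is a hypothetical
# stand-in used only for illustration.
#
#   class FakeClient:
#       def chat(self, **kwargs):
#           return "ok"
#
#   client = FakeClient()
#   validate_args(client, "chat", throw_if_stream_is_true)
#   client.chat()             # returns "ok"
#   client.chat(stream=True)  # raises ValueError before any request is made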

def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """
    Returns a stub that raises, pointing callers of a moved method to its new location.
    """

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn


def deprecated_function(fn_name: str) -> typing.Any:
    """
    Returns a stub that raises, informing callers that the method has been deprecated.
    """

    def fn(*args, **kwargs):
        raise ValueError(
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )

    return fn


# Logs a warning when a user calls a function with an experimental parameter (a kwarg in our case).
# `deprecated_kwarg` is the name of the experimental parameter; it may be a dot-separated string for nested parameters.
def experimental_kwarg_decorator(func, deprecated_kwarg):
    # Recursive utility function to check if a kwarg is present in the kwargs.
    def check_kwarg(deprecated_kwarg: str, kwargs: typing.Dict[str, typing.Any]) -> bool:
        if "." in deprecated_kwarg:
            key, rest = deprecated_kwarg.split(".", 1)
            if key in kwargs:
                return check_kwarg(rest, kwargs[key])
        return deprecated_kwarg in kwargs

    def _wrapped(*args, **kwargs):
        if check_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return func(*args, **kwargs)

    async def _async_wrapped(*args, **kwargs):
        if check_kwarg(deprecated_kwarg, kwargs):
            logger.warning(
                f"The `{deprecated_kwarg}` parameter is an experimental feature and may change in future releases.\n"
                "To suppress this warning, set `log_warning_experimental_features=False` when initializing the client."
            )
        return await func(*args, **kwargs)

    wrap = _wrapped
    if asyncio.iscoroutinefunction(func):
        wrap = _async_wrapped

    wrap.__name__ = func.__name__
    wrap.__doc__ = func.__doc__

    return wrap
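
# Illustrative sketch (not executed): how the dot-separated lookup in
# `check_kwarg` walks nested kwargs. `some_fn` and the values are hypothetical.
#
#   decorated = experimental_kwarg_decorator(some_fn, "response_format.schema")
#   decorated(response_format={"schema": {}})  # logs the experimental-feature warning
#   decorated(response_format={})              # no warning: "schema" key absent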

def fix_base_url(base_url: typing.Optional[str]) -> typing.Optional[str]:
    if base_url is not None:
        if "cohere.com" in base_url or "cohere.ai" in base_url:
            return base_url.replace("/v1", "")
        return base_url
    return None
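
# Illustrative sketch (not executed): `fix_base_url` strips "/v1" only from
# Cohere-hosted URLs, leaving custom endpoints untouched. The proxy URL below
# is hypothetical.
#
#   fix_base_url("https://api.cohere.com/v1")  # -> "https://api.cohere.com"
#   fix_base_url("https://my-proxy.local/v1")  # -> "https://my-proxy.local/v1" (unchanged)
#   fix_base_url(None)                         # -> None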

class Client(BaseCohere, CacheMixin):
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.Client] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()

        base_url = fix_base_url(base_url)

        self._executor = thread_pool_executor

        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = SyncSdkUtils()

    # Support the context-manager protocol until Fern upstreams it:
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._client_wrapper.httpx_client.httpx_client.close()

    wait = wait

    def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        # Skip batching for images for now.
        if batching is False or images is not OMIT:
            return BaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        responses = [
            response
            for response in self._executor.map(
                lambda text_batch: BaseCohere.embed(
                    self,
                    texts=text_batch,
                    model=model,
                    input_type=input_type,
                    embedding_types=embedding_types,
                    truncate=truncate,
                    request_options=request_options,
                ),
                texts_batches,
            )
        ]

        return merge_embed_responses(responses)
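
    # Illustrative sketch (not executed): with batching enabled (the default),
    # a large `texts` list is split into `embed_batch_size` chunks, embedded in
    # parallel on the thread pool, and merged into a single response. The model
    # name below is only an example.
    #
    #   co = Client()
    #   resp = co.embed(
    #       texts=["hello", "world"] * 1000,
    #       model="embed-english-v3.0",
    #       input_type="search_document",
    #   )  # transparently issues one API call per batch, then merges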

    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """

    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")

    def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: bool = True,
    ) -> TokenizeResponse:
        # The `offline` parameter controls whether to use an offline tokenizer. If True (the default),
        # the tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is processed by the API.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                tokens = local_tokenizers.local_tokenize(self, text=text, model=model)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to calling the API.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return super().tokenize(text=text, model=model, request_options=opts)
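
    # Illustrative sketch (not executed): the offline-first flow above. On any
    # local-tokenizer failure the call silently falls back to the API, tagging
    # the request with an "sdk-api-warning-message" header. The model name is
    # only an example.
    #
    #   co = Client()
    #   co.tokenize(text="hello world", model="command-r")                 # local, cached tokenizer
    #   co.tokenize(text="hello world", model="command-r", offline=False)  # forces the API path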

    def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # The `offline` parameter controls whether to use an offline tokenizer. If True (the default),
        # the tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is processed by the API.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                text = local_tokenizers.local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to calling the API.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return super().detokenize(tokens=tokens, model=model, request_options=opts)

    def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """
        Returns a Hugging Face tokenizer from a given model name.
        """
        return local_tokenizers.get_hf_tokenizer(self, model)


class AsyncClient(AsyncBaseCohere, CacheMixin):
    _executor: ThreadPoolExecutor

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
        thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
        log_warning_experimental_features: bool = True,
    ):
        if api_key is None:
            api_key = _get_api_key_from_environment()

        base_url = fix_base_url(base_url)

        self._executor = thread_pool_executor

        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        validate_args(self, "chat", throw_if_stream_is_true)
        if log_warning_experimental_features:
            self.chat = experimental_kwarg_decorator(self.chat, "response_format.schema")  # type: ignore
            self.chat_stream = experimental_kwarg_decorator(self.chat_stream, "response_format.schema")  # type: ignore

    utils = AsyncSdkUtils()

    # Support the context-manager protocol until Fern upstreams it:
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._client_wrapper.httpx_client.httpx_client.aclose()
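
    # Illustrative sketch (not executed): AsyncClient mirrors Client but exposes
    # coroutines, so it is typically used as an async context manager. The model
    # name is only an example.
    #
    #   async def main() -> None:
    #       async with AsyncClient() as co:
    #           resp = await co.tokenize(text="hello", model="command-r")
    #
    #   asyncio.run(main())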

    wait = async_wait

    async def embed(
        self,
        *,
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        # Skip batching for images for now.
        if batching is False or images is not OMIT:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                images=images,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
        texts_batches = [textsarr[i : i + embed_batch_size] for i in range(0, len(textsarr), embed_batch_size)]

        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(
                *[
                    AsyncBaseCohere.embed(
                        self,
                        texts=text_batch,
                        model=model,
                        input_type=input_type,
                        embedding_types=embedding_types,
                        truncate=truncate,
                        request_options=request_options,
                    )
                    for text_batch in texts_batches
                ]
            ),
        )

        return merge_embed_responses(responses)

    """
    The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    """

    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")

    async def tokenize(
        self,
        *,
        text: str,
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> TokenizeResponse:
        # The `offline` parameter controls whether to use an offline tokenizer. If True (the default),
        # the tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is processed by the API.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                tokens = await local_tokenizers.async_local_tokenize(self, model=model, text=text)
                return TokenizeResponse(tokens=tokens, token_strings=[])
            except Exception:
                # Fall back to calling the API.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return await super().tokenize(text=text, model=model, request_options=opts)

    async def detokenize(
        self,
        *,
        tokens: typing.Sequence[int],
        model: str,
        request_options: typing.Optional[RequestOptions] = None,
        offline: typing.Optional[bool] = True,
    ) -> DetokenizeResponse:
        # The `offline` parameter controls whether to use an offline tokenizer. If True (the default),
        # the tokenizer config is downloaded (and cached) and the request is processed locally;
        # if False, the request is processed by the API.
        opts: RequestOptions = request_options or {}  # type: ignore

        if offline:
            try:
                text = await local_tokenizers.async_local_detokenize(self, model=model, tokens=tokens)
                return DetokenizeResponse(text=text)
            except Exception:
                # Fall back to calling the API.
                opts["additional_headers"] = opts.get("additional_headers", {})
                opts["additional_headers"]["sdk-api-warning-message"] = "offline_tokenizer_failed"

        return await super().detokenize(tokens=tokens, model=model, request_options=opts)

    async def fetch_tokenizer(self, *, model: str) -> Tokenizer:
        """
        Returns a Hugging Face tokenizer from a given model name.
        """
        return await local_tokenizers.async_get_hf_tokenizer(self, model)


def _get_api_key_from_environment() -> typing.Optional[str]:
    """
    Retrieves the Cohere API key from specific environment variables.
    CO_API_KEY is preferred (and documented); COHERE_API_KEY is accepted (but not documented).
    """
    return os.getenv("CO_API_KEY", os.getenv("COHERE_API_KEY"))
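
# Illustrative sketch (not executed): environment-variable precedence for the
# API key lookup above.
#
#   os.environ["COHERE_API_KEY"] = "key-b"
#   _get_api_key_from_environment()  # -> "key-b"
#   os.environ["CO_API_KEY"] = "key-a"
#   _get_api_key_from_environment()  # -> "key-a" (CO_API_KEY wins)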