# This file was auto-generated by Fern from our API Definition.
import typing

from ..core.client_wrapper import SyncClientWrapper
from ..types.chat_messages import ChatMessages
from ..types.tool_v2 import ToolV2
from .types.v2chat_stream_request_documents_item import V2ChatStreamRequestDocumentsItem
from ..types.citation_options import CitationOptions
from ..types.response_format_v2 import ResponseFormatV2
from .types.v2chat_stream_request_safety_mode import V2ChatStreamRequestSafetyMode
from .types.v2chat_stream_request_tool_choice import V2ChatStreamRequestToolChoice
from ..core.request_options import RequestOptions
from ..types.streamed_chat_response_v2 import StreamedChatResponseV2
from ..core.serialization import convert_and_respect_annotation_metadata
import httpx_sse
from ..core.unchecked_base_model import construct_type
import json
from ..errors.bad_request_error import BadRequestError
from ..errors.unauthorized_error import UnauthorizedError
from ..errors.forbidden_error import ForbiddenError
from ..errors.not_found_error import NotFoundError
from ..errors.unprocessable_entity_error import UnprocessableEntityError
from ..errors.too_many_requests_error import TooManyRequestsError
from ..errors.invalid_token_error import InvalidTokenError
from ..errors.client_closed_request_error import ClientClosedRequestError
from ..errors.internal_server_error import InternalServerError
from ..errors.not_implemented_error import NotImplementedError
from ..errors.service_unavailable_error import ServiceUnavailableError
from ..errors.gateway_timeout_error import GatewayTimeoutError
from json.decoder import JSONDecodeError
from ..core.api_error import ApiError
from .types.v2chat_request_documents_item import V2ChatRequestDocumentsItem
from .types.v2chat_request_safety_mode import V2ChatRequestSafetyMode
from .types.v2chat_request_tool_choice import V2ChatRequestToolChoice
from ..types.chat_response import ChatResponse
from ..types.embed_input_type import EmbedInputType
from ..types.embedding_type import EmbeddingType
from ..types.embed_input import EmbedInput
from .types.v2embed_request_truncate import V2EmbedRequestTruncate
from ..types.embed_by_type_response import EmbedByTypeResponse
from .types.v2rerank_response import V2RerankResponse
from ..core.client_wrapper import AsyncClientWrapper

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)

class V2Client:
|
|
def __init__(self, *, client_wrapper: SyncClientWrapper):
|
|
self._client_wrapper = client_wrapper
|
|
|
|
def chat_stream(
|
|
self,
|
|
*,
|
|
model: str,
|
|
messages: ChatMessages,
|
|
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
|
|
strict_tools: typing.Optional[bool] = OMIT,
|
|
documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
|
|
citation_options: typing.Optional[CitationOptions] = OMIT,
|
|
response_format: typing.Optional[ResponseFormatV2] = OMIT,
|
|
safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
|
|
max_tokens: typing.Optional[int] = OMIT,
|
|
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
temperature: typing.Optional[float] = OMIT,
|
|
seed: typing.Optional[int] = OMIT,
|
|
frequency_penalty: typing.Optional[float] = OMIT,
|
|
presence_penalty: typing.Optional[float] = OMIT,
|
|
k: typing.Optional[float] = OMIT,
|
|
p: typing.Optional[float] = OMIT,
|
|
return_prompt: typing.Optional[bool] = OMIT,
|
|
logprobs: typing.Optional[bool] = OMIT,
|
|
tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> typing.Iterator[StreamedChatResponseV2]:
|
|
"""
|
|
Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming, follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
|
|
|
|
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
|
|
|
|
Parameters
|
|
----------
|
|
model : str
|
|
The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/v2/docs/chat-fine-tuning) model.
|
|
|
|
messages : ChatMessages
|
|
|
|
tools : typing.Optional[typing.Sequence[ToolV2]]
|
|
A list of available tools (functions) that the model may suggest invoking before producing a text response.
|
|
|
|
When `tools` is passed (without `tool_results`), the `text` content in the response will be empty and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
|
|
|
|
|
|
strict_tools : typing.Optional[bool]
|
|
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
|
|
|
|
**Note**: The first few requests with a new set of tools will take longer to process.
|
|
|
|
|
|
documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
|
|
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
|
|
|
|
|
|
citation_options : typing.Optional[CitationOptions]
|
|
|
|
response_format : typing.Optional[ResponseFormatV2]
|
|
|
|
safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
|
|
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
|
|
When `OFF` is specified, the safety instruction will be omitted.
|
|
|
|
Safety modes are not yet configurable in combination with the `tools`, `tool_results` and `documents` parameters.

**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
|
|
|
|
|
|
max_tokens : typing.Optional[int]
|
|
The maximum number of tokens the model will generate as part of the response.
|
|
|
|
**Note**: Setting a low value may result in incomplete generations.
|
|
|
|
|
|
stop_sequences : typing.Optional[typing.Sequence[str]]
|
|
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point, not including the stop sequence.
|
|
|
|
|
|
temperature : typing.Optional[float]
|
|
Defaults to `0.3`.
|
|
|
|
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
|
|
|
|
Randomness can be further maximized by increasing the value of the `p` parameter.
|
|
|
|
|
|
seed : typing.Optional[int]
|
|
If specified, the backend will make a best effort to sample tokens
|
|
deterministically, such that repeated requests with the same
|
|
seed and parameters should return the same result. However,
|
|
determinism cannot be totally guaranteed.
|
|
|
|
|
|
frequency_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
|
|
|
|
|
|
presence_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
|
|
|
|
|
|
k : typing.Optional[float]
|
|
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
|
|
Defaults to `0`, min value of `0`, max value of `500`.
|
|
|
|
|
|
p : typing.Optional[float]
|
|
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
|
|
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
|
|
|
|
|
|
return_prompt : typing.Optional[bool]
|
|
Whether to return the prompt in the response.
|
|
|
|
logprobs : typing.Optional[bool]
|
|
Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
|
|
|
|
|
|
tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
|
|
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If `tool_choice` isn't specified, then the model is free to choose whether to use the specified tools or not.

**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

**Note**: The same functionality can be achieved in `/v1/chat` using the `force_single_step` parameter. If `force_single_step=true`, this is equivalent to specifying `REQUIRED`; if `force_single_step=true` and `tool_results` are passed, this is equivalent to specifying `NONE`.
|
|
|
|
|
|
request_options : typing.Optional[RequestOptions]
|
|
Request-specific configuration.
|
|
|
|
Yields
|
|
------
|
|
typing.Iterator[StreamedChatResponseV2]
|
|
|
|
|
|
Examples
|
|
--------
|
|
from cohere import Client, ToolChatMessageV2
|
|
|
|
client = Client(
|
|
client_name="YOUR_CLIENT_NAME",
|
|
token="YOUR_TOKEN",
|
|
)
|
|
response = client.v2.chat_stream(
|
|
model="model",
|
|
messages=[
|
|
ToolChatMessageV2(
|
|
tool_call_id="messages",
|
|
content="messages",
|
|
)
|
|
],
|
|
)
|
|
for chunk in response:
    print(chunk)
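
A minimal sketch of collecting the streamed reply into a single string.
It assumes `content-delta` events expose the new text at
`chunk.delta.message.content.text` (the documented v2 streaming event
shape; attribute paths may differ across SDK versions):

full_text = ""
for chunk in client.v2.chat_stream(
    model="model",
    messages=[
        ToolChatMessageV2(
            tool_call_id="messages",
            content="messages",
        )
    ],
):
    if chunk.type == "content-delta":
        full_text += chunk.delta.message.content.text
print(full_text)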
|
|
"""
|
|
with self._client_wrapper.httpx_client.stream(
|
|
"v2/chat",
|
|
method="POST",
|
|
json={
|
|
"model": model,
|
|
"messages": convert_and_respect_annotation_metadata(
|
|
object_=messages, annotation=ChatMessages, direction="write"
|
|
),
|
|
"tools": convert_and_respect_annotation_metadata(
|
|
object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
|
|
),
|
|
"strict_tools": strict_tools,
|
|
"documents": convert_and_respect_annotation_metadata(
|
|
object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write"
|
|
),
|
|
"citation_options": convert_and_respect_annotation_metadata(
|
|
object_=citation_options, annotation=CitationOptions, direction="write"
|
|
),
|
|
"response_format": convert_and_respect_annotation_metadata(
|
|
object_=response_format, annotation=ResponseFormatV2, direction="write"
|
|
),
|
|
"safety_mode": safety_mode,
|
|
"max_tokens": max_tokens,
|
|
"stop_sequences": stop_sequences,
|
|
"temperature": temperature,
|
|
"seed": seed,
|
|
"frequency_penalty": frequency_penalty,
|
|
"presence_penalty": presence_penalty,
|
|
"k": k,
|
|
"p": p,
|
|
"return_prompt": return_prompt,
|
|
"logprobs": logprobs,
|
|
"tool_choice": tool_choice,
|
|
"stream": True,
|
|
},
|
|
headers={
|
|
"content-type": "application/json",
|
|
},
|
|
request_options=request_options,
|
|
omit=OMIT,
|
|
) as _response:
|
|
try:
|
|
if 200 <= _response.status_code < 300:
|
|
_event_source = httpx_sse.EventSource(_response)
|
|
for _sse in _event_source.iter_sse():
|
|
try:
|
|
yield typing.cast(
|
|
StreamedChatResponseV2,
|
|
construct_type(
|
|
type_=StreamedChatResponseV2, # type: ignore
|
|
object_=json.loads(_sse.data),
|
|
),
|
|
)
|
|
except Exception:
    # skip SSE events whose payload fails to parse, rather than aborting the stream
    pass
|
|
return
|
|
_response.read()
|
|
if _response.status_code == 400:
|
|
raise BadRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 401:
|
|
raise UnauthorizedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 403:
|
|
raise ForbiddenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 404:
|
|
raise NotFoundError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 422:
|
|
raise UnprocessableEntityError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 429:
|
|
raise TooManyRequestsError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 498:
|
|
raise InvalidTokenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 499:
|
|
raise ClientClosedRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 500:
|
|
raise InternalServerError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 501:
|
|
raise NotImplementedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 503:
|
|
raise ServiceUnavailableError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 504:
|
|
raise GatewayTimeoutError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
_response_json = _response.json()
|
|
except JSONDecodeError:
|
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
|
raise ApiError(status_code=_response.status_code, body=_response_json)
|
|
|
|
def chat(
|
|
self,
|
|
*,
|
|
model: str,
|
|
messages: ChatMessages,
|
|
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
|
|
strict_tools: typing.Optional[bool] = OMIT,
|
|
documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
|
|
citation_options: typing.Optional[CitationOptions] = OMIT,
|
|
response_format: typing.Optional[ResponseFormatV2] = OMIT,
|
|
safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
|
|
max_tokens: typing.Optional[int] = OMIT,
|
|
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
temperature: typing.Optional[float] = OMIT,
|
|
seed: typing.Optional[int] = OMIT,
|
|
frequency_penalty: typing.Optional[float] = OMIT,
|
|
presence_penalty: typing.Optional[float] = OMIT,
|
|
k: typing.Optional[float] = OMIT,
|
|
p: typing.Optional[float] = OMIT,
|
|
return_prompt: typing.Optional[bool] = OMIT,
|
|
logprobs: typing.Optional[bool] = OMIT,
|
|
tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> ChatResponse:
|
|
"""
|
|
Generates a text response to a user message. To learn how to use the Chat API and RAG, follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
|
|
|
|
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
|
|
|
|
Parameters
|
|
----------
|
|
model : str
|
|
The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/v2/docs/chat-fine-tuning) model.
|
|
|
|
messages : ChatMessages
|
|
|
|
tools : typing.Optional[typing.Sequence[ToolV2]]
|
|
A list of available tools (functions) that the model may suggest invoking before producing a text response.
|
|
|
|
When `tools` is passed (without `tool_results`), the `text` content in the response will be empty and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
|
|
|
|
|
|
strict_tools : typing.Optional[bool]
|
|
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
|
|
|
|
**Note**: The first few requests with a new set of tools will take longer to process.
|
|
|
|
|
|
documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
|
|
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
|
|
|
|
|
|
citation_options : typing.Optional[CitationOptions]
|
|
|
|
response_format : typing.Optional[ResponseFormatV2]
|
|
|
|
safety_mode : typing.Optional[V2ChatRequestSafetyMode]
|
|
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
|
|
When `OFF` is specified, the safety instruction will be omitted.
|
|
|
|
Safety modes are not yet configurable in combination with the `tools`, `tool_results` and `documents` parameters.

**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
|
|
|
|
|
|
max_tokens : typing.Optional[int]
|
|
The maximum number of tokens the model will generate as part of the response.
|
|
|
|
**Note**: Setting a low value may result in incomplete generations.
|
|
|
|
|
|
stop_sequences : typing.Optional[typing.Sequence[str]]
|
|
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point, not including the stop sequence.
|
|
|
|
|
|
temperature : typing.Optional[float]
|
|
Defaults to `0.3`.
|
|
|
|
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
|
|
|
|
Randomness can be further maximized by increasing the value of the `p` parameter.
|
|
|
|
|
|
seed : typing.Optional[int]
|
|
If specified, the backend will make a best effort to sample tokens
|
|
deterministically, such that repeated requests with the same
|
|
seed and parameters should return the same result. However,
|
|
determinism cannot be totally guaranteed.
|
|
|
|
|
|
frequency_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
|
|
|
|
|
|
presence_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
|
|
|
|
|
|
k : typing.Optional[float]
|
|
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
|
|
Defaults to `0`, min value of `0`, max value of `500`.
|
|
|
|
|
|
p : typing.Optional[float]
|
|
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
|
|
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
|
|
|
|
|
|
return_prompt : typing.Optional[bool]
|
|
Whether to return the prompt in the response.
|
|
|
|
logprobs : typing.Optional[bool]
|
|
Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
|
|
|
|
|
|
tool_choice : typing.Optional[V2ChatRequestToolChoice]
|
|
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If `tool_choice` isn't specified, then the model is free to choose whether to use the specified tools or not.

**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

**Note**: The same functionality can be achieved in `/v1/chat` using the `force_single_step` parameter. If `force_single_step=true`, this is equivalent to specifying `REQUIRED`; if `force_single_step=true` and `tool_results` are passed, this is equivalent to specifying `NONE`.
|
|
|
|
|
|
request_options : typing.Optional[RequestOptions]
|
|
Request-specific configuration.
|
|
|
|
Returns
|
|
-------
|
|
ChatResponse
|
|
|
|
|
|
Examples
|
|
--------
|
|
from cohere import Client, ToolChatMessageV2
|
|
|
|
client = Client(
|
|
client_name="YOUR_CLIENT_NAME",
|
|
token="YOUR_TOKEN",
|
|
)
|
|
client.v2.chat(
|
|
model="model",
|
|
messages=[
|
|
ToolChatMessageV2(
|
|
tool_call_id="messages",
|
|
content="messages",
|
|
)
|
|
],
|
|
)
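
A sketch of a more typical call with a single user turn, reading the
reply text back. It assumes `UserChatMessageV2` accepts plain-string
content and that the reply text lives at
`response.message.content[0].text`, per the v2 response shape:

from cohere import UserChatMessageV2

response = client.v2.chat(
    model="command-r-plus-08-2024",
    messages=[UserChatMessageV2(content="Hello!")],
)
print(response.message.content[0].text)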
|
|
"""
|
|
_response = self._client_wrapper.httpx_client.request(
|
|
"v2/chat",
|
|
method="POST",
|
|
json={
|
|
"model": model,
|
|
"messages": convert_and_respect_annotation_metadata(
|
|
object_=messages, annotation=ChatMessages, direction="write"
|
|
),
|
|
"tools": convert_and_respect_annotation_metadata(
|
|
object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
|
|
),
|
|
"strict_tools": strict_tools,
|
|
"documents": convert_and_respect_annotation_metadata(
|
|
object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write"
|
|
),
|
|
"citation_options": convert_and_respect_annotation_metadata(
|
|
object_=citation_options, annotation=CitationOptions, direction="write"
|
|
),
|
|
"response_format": convert_and_respect_annotation_metadata(
|
|
object_=response_format, annotation=ResponseFormatV2, direction="write"
|
|
),
|
|
"safety_mode": safety_mode,
|
|
"max_tokens": max_tokens,
|
|
"stop_sequences": stop_sequences,
|
|
"temperature": temperature,
|
|
"seed": seed,
|
|
"frequency_penalty": frequency_penalty,
|
|
"presence_penalty": presence_penalty,
|
|
"k": k,
|
|
"p": p,
|
|
"return_prompt": return_prompt,
|
|
"logprobs": logprobs,
|
|
"tool_choice": tool_choice,
|
|
"stream": False,
|
|
},
|
|
headers={
|
|
"content-type": "application/json",
|
|
},
|
|
request_options=request_options,
|
|
omit=OMIT,
|
|
)
|
|
try:
|
|
if 200 <= _response.status_code < 300:
|
|
return typing.cast(
|
|
ChatResponse,
|
|
construct_type(
|
|
type_=ChatResponse, # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
if _response.status_code == 400:
|
|
raise BadRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 401:
|
|
raise UnauthorizedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 403:
|
|
raise ForbiddenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 404:
|
|
raise NotFoundError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 422:
|
|
raise UnprocessableEntityError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 429:
|
|
raise TooManyRequestsError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 498:
|
|
raise InvalidTokenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 499:
|
|
raise ClientClosedRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 500:
|
|
raise InternalServerError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 501:
|
|
raise NotImplementedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 503:
|
|
raise ServiceUnavailableError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 504:
|
|
raise GatewayTimeoutError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
_response_json = _response.json()
|
|
except JSONDecodeError:
|
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
|
raise ApiError(status_code=_response.status_code, body=_response_json)
|
|
|
|
def embed(
|
|
self,
|
|
*,
|
|
model: str,
|
|
input_type: EmbedInputType,
|
|
embedding_types: typing.Sequence[EmbeddingType],
|
|
texts: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
images: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
|
|
max_tokens: typing.Optional[int] = OMIT,
|
|
output_dimension: typing.Optional[int] = OMIT,
|
|
truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> EmbedByTypeResponse:
|
|
"""
|
|
This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
|
|
|
|
Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
|
|
|
|
If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).
|
|
|
|
Parameters
|
|
----------
|
|
model : str
|
|
Defaults to `embed-english-v2.0`.
|
|
|
|
The identifier of the model. Smaller "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.
|
|
|
|
Available models and corresponding embedding dimensions:
|
|
|
|
* `embed-english-v3.0` 1024
|
|
* `embed-multilingual-v3.0` 1024
|
|
* `embed-english-light-v3.0` 384
|
|
* `embed-multilingual-light-v3.0` 384
|
|
|
|
* `embed-english-v2.0` 4096
|
|
* `embed-english-light-v2.0` 1024
|
|
* `embed-multilingual-v2.0` 768
|
|
|
|
input_type : EmbedInputType
|
|
|
|
embedding_types : typing.Sequence[EmbeddingType]
|
|
Specifies the types of embeddings you want to get back. Can be one or more of the following types.
|
|
|
|
* `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
|
|
* `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
|
|
* `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for only v3 models.
|
|
* `"binary"`: Use this when you want to get back signed binary embeddings. Valid for only v3 models.
|
|
* `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for only v3 models.
|
|
|
|
texts : typing.Optional[typing.Sequence[str]]
|
|
An array of strings for the model to embed. Maximum number of texts per call is `96`. We recommend reducing the length of each text to be under `512` tokens for optimal quality.
|
|
|
|
images : typing.Optional[typing.Sequence[str]]
|
|
An array of image data URIs for the model to embed. Maximum number of images per call is `1`.
|
|
|
|
The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg` or `image/png` format and have a maximum size of 5MB.
|
|
|
|
inputs : typing.Optional[typing.Sequence[EmbedInput]]
|
|
An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.
|
|
|
|
max_tokens : typing.Optional[int]
|
|
The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.
|
|
|
|
output_dimension : typing.Optional[int]
|
|
The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
|
|
Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.
|
|
|
|
truncate : typing.Optional[V2EmbedRequestTruncate]
|
|
One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
|
|
|
|
Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
|
|
|
|
If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.
|
|
|
|
request_options : typing.Optional[RequestOptions]
|
|
Request-specific configuration.
|
|
|
|
Returns
|
|
-------
|
|
EmbedByTypeResponse
|
|
OK
|
|
|
|
Examples
|
|
--------
|
|
from cohere import Client
|
|
|
|
client = Client(
|
|
client_name="YOUR_CLIENT_NAME",
|
|
token="YOUR_TOKEN",
|
|
)
|
|
client.v2.embed(
|
|
model="model",
|
|
input_type="search_document",
|
|
embedding_types=["float"],
|
|
)
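
A sketch of embedding two texts and reading the float vectors back. It
assumes float embeddings are exposed as `response.embeddings.float_`,
the SDK's usual rename of the reserved word `float`:

response = client.v2.embed(
    model="embed-english-v3.0",
    input_type="search_document",
    embedding_types=["float"],
    texts=["hello", "goodbye"],
)
vectors = response.embeddings.float_
print(len(vectors), len(vectors[0]))  # 2 vectors, 1024 dimensions each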
|
|
"""
|
|
_response = self._client_wrapper.httpx_client.request(
|
|
"v2/embed",
|
|
method="POST",
|
|
json={
|
|
"texts": texts,
|
|
"images": images,
|
|
"model": model,
|
|
"input_type": input_type,
|
|
"inputs": convert_and_respect_annotation_metadata(
|
|
object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write"
|
|
),
|
|
"max_tokens": max_tokens,
|
|
"output_dimension": output_dimension,
|
|
"embedding_types": embedding_types,
|
|
"truncate": truncate,
|
|
},
|
|
headers={
|
|
"content-type": "application/json",
|
|
},
|
|
request_options=request_options,
|
|
omit=OMIT,
|
|
)
|
|
try:
|
|
if 200 <= _response.status_code < 300:
|
|
return typing.cast(
|
|
EmbedByTypeResponse,
|
|
construct_type(
|
|
type_=EmbedByTypeResponse, # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
if _response.status_code == 400:
|
|
raise BadRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 401:
|
|
raise UnauthorizedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 403:
|
|
raise ForbiddenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 404:
|
|
raise NotFoundError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 422:
|
|
raise UnprocessableEntityError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 429:
|
|
raise TooManyRequestsError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 498:
|
|
raise InvalidTokenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 499:
|
|
raise ClientClosedRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 500:
|
|
raise InternalServerError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 501:
|
|
raise NotImplementedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 503:
|
|
raise ServiceUnavailableError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 504:
|
|
raise GatewayTimeoutError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
_response_json = _response.json()
|
|
except JSONDecodeError:
|
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
|
raise ApiError(status_code=_response.status_code, body=_response_json)
|
|
|
|
def rerank(
|
|
self,
|
|
*,
|
|
model: str,
|
|
query: str,
|
|
documents: typing.Sequence[str],
|
|
top_n: typing.Optional[int] = OMIT,
|
|
return_documents: typing.Optional[bool] = OMIT,
|
|
max_tokens_per_doc: typing.Optional[int] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> V2RerankResponse:
|
|
"""
|
|
This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
|
|
|
|
Parameters
|
|
----------
|
|
model : str
|
|
The identifier of the model to use, e.g. `rerank-v3.5`.
|
|
|
|
query : str
|
|
The search query.
|
|
|
|
documents : typing.Sequence[str]
|
|
A list of texts that will be compared to the `query`.
|
|
For optimal performance we recommend against sending more than 1,000 documents in a single request.
|
|
|
|
**Note**: Long documents will automatically be truncated to the value of `max_tokens_per_doc`.

**Note**: Structured data should be formatted as YAML strings for best performance.
|
|
|
|
top_n : typing.Optional[int]
|
|
Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.
|
|
|
|
return_documents : typing.Optional[bool]
|
|
- If false, returns results without the doc text - the API will return a list of {index, relevance score} where index is inferred from the list passed into the request.
- If true, returns results with the doc text passed in - the API will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
|
|
|
|
max_tokens_per_doc : typing.Optional[int]
|
|
Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.
|
|
|
|
request_options : typing.Optional[RequestOptions]
|
|
Request-specific configuration.
|
|
|
|
Returns
|
|
-------
|
|
V2RerankResponse
|
|
OK
|
|
|
|
Examples
|
|
--------
|
|
from cohere import Client
|
|
|
|
client = Client(
|
|
client_name="YOUR_CLIENT_NAME",
|
|
token="YOUR_TOKEN",
|
|
)
|
|
client.v2.rerank(
|
|
model="model",
|
|
query="query",
|
|
documents=["documents"],
|
|
)
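
A sketch of reranking a few documents and printing them in relevance
order. It assumes each result exposes `index` and `relevance_score`,
per the v2 rerank response shape:

docs = [
    "Carson City is the capital of Nevada.",
    "Washington, D.C. is the capital of the United States.",
    "Capitalization is the writing of a word with its first letter in uppercase.",
]
response = client.v2.rerank(
    model="rerank-v3.5",
    query="What is the capital of the United States?",
    documents=docs,
    top_n=2,
)
for result in response.results:
    print(result.relevance_score, docs[result.index])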
|
|
"""
|
|
_response = self._client_wrapper.httpx_client.request(
|
|
"v2/rerank",
|
|
method="POST",
|
|
json={
|
|
"model": model,
|
|
"query": query,
|
|
"documents": documents,
|
|
"top_n": top_n,
|
|
"return_documents": return_documents,
|
|
"max_tokens_per_doc": max_tokens_per_doc,
|
|
},
|
|
headers={
|
|
"content-type": "application/json",
|
|
},
|
|
request_options=request_options,
|
|
omit=OMIT,
|
|
)
|
|
try:
|
|
if 200 <= _response.status_code < 300:
|
|
return typing.cast(
|
|
V2RerankResponse,
|
|
construct_type(
|
|
type_=V2RerankResponse, # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
if _response.status_code == 400:
|
|
raise BadRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 401:
|
|
raise UnauthorizedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 403:
|
|
raise ForbiddenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 404:
|
|
raise NotFoundError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 422:
|
|
raise UnprocessableEntityError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 429:
|
|
raise TooManyRequestsError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 498:
|
|
raise InvalidTokenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 499:
|
|
raise ClientClosedRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 500:
|
|
raise InternalServerError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 501:
|
|
raise NotImplementedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 503:
|
|
raise ServiceUnavailableError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 504:
|
|
raise GatewayTimeoutError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
_response_json = _response.json()
|
|
except JSONDecodeError:
|
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
|
raise ApiError(status_code=_response.status_code, body=_response_json)
|
|
|
|
|
|
class AsyncV2Client:
|
|
def __init__(self, *, client_wrapper: AsyncClientWrapper):
|
|
self._client_wrapper = client_wrapper
|
|
|
|
async def chat_stream(
|
|
self,
|
|
*,
|
|
model: str,
|
|
messages: ChatMessages,
|
|
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
|
|
strict_tools: typing.Optional[bool] = OMIT,
|
|
documents: typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]] = OMIT,
|
|
citation_options: typing.Optional[CitationOptions] = OMIT,
|
|
response_format: typing.Optional[ResponseFormatV2] = OMIT,
|
|
safety_mode: typing.Optional[V2ChatStreamRequestSafetyMode] = OMIT,
|
|
max_tokens: typing.Optional[int] = OMIT,
|
|
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
temperature: typing.Optional[float] = OMIT,
|
|
seed: typing.Optional[int] = OMIT,
|
|
frequency_penalty: typing.Optional[float] = OMIT,
|
|
presence_penalty: typing.Optional[float] = OMIT,
|
|
k: typing.Optional[float] = OMIT,
|
|
p: typing.Optional[float] = OMIT,
|
|
return_prompt: typing.Optional[bool] = OMIT,
|
|
logprobs: typing.Optional[bool] = OMIT,
|
|
tool_choice: typing.Optional[V2ChatStreamRequestToolChoice] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> typing.AsyncIterator[StreamedChatResponseV2]:
|
|
"""
|
|
Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming, follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
|
|
|
|
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
|
|
|
|
Parameters
|
|
----------
|
|
model : str
|
|
The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/v2/docs/chat-fine-tuning) model.
|
|
|
|
messages : ChatMessages
|
|
|
|
tools : typing.Optional[typing.Sequence[ToolV2]]
|
|
A list of available tools (functions) that the model may suggest invoking before producing a text response.
|
|
|
|
When `tools` is passed (without `tool_results`), the `text` content in the response will be empty and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.
|
|
|
|
|
|
strict_tools : typing.Optional[bool]
|
|
When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).
|
|
|
|
**Note**: The first few requests with a new set of tools will take longer to process.
|
|
|
|
|
|
documents : typing.Optional[typing.Sequence[V2ChatStreamRequestDocumentsItem]]
|
|
A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.
|
|
|
|
|
|
citation_options : typing.Optional[CitationOptions]
|
|
|
|
response_format : typing.Optional[ResponseFormatV2]
|
|
|
|
safety_mode : typing.Optional[V2ChatStreamRequestSafetyMode]
|
|
Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
|
|
When `OFF` is specified, the safety instruction will be omitted.
|
|
|
|
Safety modes are not yet configurable in combination with the `tools`, `tool_results` and `documents` parameters.

**Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

**Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.
|
|
|
|
|
|
max_tokens : typing.Optional[int]
|
|
The maximum number of tokens the model will generate as part of the response.
|
|
|
|
**Note**: Setting a low value may result in incomplete generations.
|
|
|
|
|
|
stop_sequences : typing.Optional[typing.Sequence[str]]
|
|
A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point, not including the stop sequence.
|
|
|
|
|
|
temperature : typing.Optional[float]
|
|
Defaults to `0.3`.
|
|
|
|
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.
|
|
|
|
Randomness can be further maximized by increasing the value of the `p` parameter.
|
|
|
|
|
|
seed : typing.Optional[int]
|
|
If specified, the backend will make a best effort to sample tokens
|
|
deterministically, such that repeated requests with the same
|
|
seed and parameters should return the same result. However,
|
|
determinism cannot be totally guaranteed.
|
|
|
|
|
|
frequency_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
|
|
|
|
|
|
presence_penalty : typing.Optional[float]
|
|
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
|
|
Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
|
|
|
|
|
|
k : typing.Optional[float]
|
|
Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
|
|
Defaults to `0`, min value of `0`, max value of `500`.
|
|
|
|
|
|
p : typing.Optional[float]
|
|
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
|
|
Defaults to `0.75`, min value of `0.01`, max value of `0.99`.
|
|
|
|
|
|
return_prompt : typing.Optional[bool]
|
|
Whether to return the prompt in the response.
|
|
|
|
logprobs : typing.Optional[bool]
|
|
Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.
|
|
|
|
|
|
tool_choice : typing.Optional[V2ChatStreamRequestToolChoice]
|
|
Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
If `tool_choice` isn't specified, then the model is free to choose whether to use the specified tools or not.

**Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

**Note**: The same functionality can be achieved in `/v1/chat` using the `force_single_step` parameter. If `force_single_step=true`, this is equivalent to specifying `REQUIRED`; if `force_single_step=true` and `tool_results` are passed, this is equivalent to specifying `NONE`.
|
|
|
|
|
|
request_options : typing.Optional[RequestOptions]
|
|
Request-specific configuration.
|
|
|
|
Yields
|
|
------
|
|
typing.AsyncIterator[StreamedChatResponseV2]
|
|
|
|
|
|
Examples
|
|
--------
|
|
import asyncio
|
|
|
|
from cohere import AsyncClient, ToolChatMessageV2
|
|
|
|
client = AsyncClient(
|
|
client_name="YOUR_CLIENT_NAME",
|
|
token="YOUR_TOKEN",
|
|
)
|
|
|
|
|
|
async def main() -> None:
    response = client.v2.chat_stream(
        model="model",
        messages=[
            ToolChatMessageV2(
                tool_call_id="messages",
                content="messages",
            )
        ],
    )
    async for chunk in response:
        print(chunk)
|
|
|
|
|
|
asyncio.run(main())
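
The same accumulation sketch as the synchronous client, assuming
`content-delta` events expose the new text at
`chunk.delta.message.content.text`:

async def collect() -> str:
    full_text = ""
    stream = client.v2.chat_stream(
        model="model",
        messages=[
            ToolChatMessageV2(
                tool_call_id="messages",
                content="messages",
            )
        ],
    )
    async for chunk in stream:
        if chunk.type == "content-delta":
            full_text += chunk.delta.message.content.text
    return full_text

print(asyncio.run(collect()))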
|
|
"""
|
|
async with self._client_wrapper.httpx_client.stream(
|
|
"v2/chat",
|
|
method="POST",
|
|
json={
|
|
"model": model,
|
|
"messages": convert_and_respect_annotation_metadata(
|
|
object_=messages, annotation=ChatMessages, direction="write"
|
|
),
|
|
"tools": convert_and_respect_annotation_metadata(
|
|
object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
|
|
),
|
|
"strict_tools": strict_tools,
|
|
"documents": convert_and_respect_annotation_metadata(
|
|
object_=documents, annotation=typing.Sequence[V2ChatStreamRequestDocumentsItem], direction="write"
|
|
),
|
|
"citation_options": convert_and_respect_annotation_metadata(
|
|
object_=citation_options, annotation=CitationOptions, direction="write"
|
|
),
|
|
"response_format": convert_and_respect_annotation_metadata(
|
|
object_=response_format, annotation=ResponseFormatV2, direction="write"
|
|
),
|
|
"safety_mode": safety_mode,
|
|
"max_tokens": max_tokens,
|
|
"stop_sequences": stop_sequences,
|
|
"temperature": temperature,
|
|
"seed": seed,
|
|
"frequency_penalty": frequency_penalty,
|
|
"presence_penalty": presence_penalty,
|
|
"k": k,
|
|
"p": p,
|
|
"return_prompt": return_prompt,
|
|
"logprobs": logprobs,
|
|
"tool_choice": tool_choice,
|
|
"stream": True,
|
|
},
|
|
headers={
|
|
"content-type": "application/json",
|
|
},
|
|
request_options=request_options,
|
|
omit=OMIT,
|
|
) as _response:
|
|
try:
|
|
if 200 <= _response.status_code < 300:
|
|
_event_source = httpx_sse.EventSource(_response)
|
|
async for _sse in _event_source.aiter_sse():
|
|
try:
|
|
yield typing.cast(
|
|
StreamedChatResponseV2,
|
|
construct_type(
|
|
type_=StreamedChatResponseV2, # type: ignore
|
|
object_=json.loads(_sse.data),
|
|
),
|
|
)
|
|
except Exception:
    # skip SSE events whose payload fails to parse, rather than aborting the stream
    pass
|
|
return
|
|
await _response.aread()
|
|
if _response.status_code == 400:
|
|
raise BadRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 401:
|
|
raise UnauthorizedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 403:
|
|
raise ForbiddenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 404:
|
|
raise NotFoundError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 422:
|
|
raise UnprocessableEntityError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 429:
|
|
raise TooManyRequestsError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 498:
|
|
raise InvalidTokenError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 499:
|
|
raise ClientClosedRequestError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 500:
|
|
raise InternalServerError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 501:
|
|
raise NotImplementedError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 503:
|
|
raise ServiceUnavailableError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
if _response.status_code == 504:
|
|
raise GatewayTimeoutError(
|
|
typing.cast(
|
|
typing.Optional[typing.Any],
|
|
construct_type(
|
|
type_=typing.Optional[typing.Any], # type: ignore
|
|
object_=_response.json(),
|
|
),
|
|
)
|
|
)
|
|
_response_json = _response.json()
|
|
except JSONDecodeError:
|
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
|
raise ApiError(status_code=_response.status_code, body=_response_json)
|
|
|
|
async def chat(
|
|
self,
|
|
*,
|
|
model: str,
|
|
messages: ChatMessages,
|
|
tools: typing.Optional[typing.Sequence[ToolV2]] = OMIT,
|
|
strict_tools: typing.Optional[bool] = OMIT,
|
|
documents: typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]] = OMIT,
|
|
citation_options: typing.Optional[CitationOptions] = OMIT,
|
|
response_format: typing.Optional[ResponseFormatV2] = OMIT,
|
|
safety_mode: typing.Optional[V2ChatRequestSafetyMode] = OMIT,
|
|
max_tokens: typing.Optional[int] = OMIT,
|
|
stop_sequences: typing.Optional[typing.Sequence[str]] = OMIT,
|
|
temperature: typing.Optional[float] = OMIT,
|
|
seed: typing.Optional[int] = OMIT,
|
|
frequency_penalty: typing.Optional[float] = OMIT,
|
|
presence_penalty: typing.Optional[float] = OMIT,
|
|
k: typing.Optional[float] = OMIT,
|
|
p: typing.Optional[float] = OMIT,
|
|
return_prompt: typing.Optional[bool] = OMIT,
|
|
logprobs: typing.Optional[bool] = OMIT,
|
|
tool_choice: typing.Optional[V2ChatRequestToolChoice] = OMIT,
|
|
request_options: typing.Optional[RequestOptions] = None,
|
|
) -> ChatResponse:
|
|
"""
|
|
Generates a text response to a user message. To learn how to use the Chat API and RAG, follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api).
|
|
|
|
Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2.
|
|
|
|
        Parameters
        ----------
        model : str
            The name of a compatible [Cohere model](https://docs.cohere.com/v2/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/v2/docs/chat-fine-tuning) model.

        messages : ChatMessages

        tools : typing.Optional[typing.Sequence[ToolV2]]
            A list of available tools (functions) that the model may suggest invoking before producing a text response.

            When `tools` is passed (without `tool_results`), the `text` content in the response will be empty and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty.

        strict_tools : typing.Optional[bool]
            When set to `true`, tool calls in the Assistant message will be forced to follow the tool definition strictly. Learn more in the [Structured Outputs (Tools) guide](https://docs.cohere.com/docs/structured-outputs-json#structured-outputs-tools).

            **Note**: The first few requests with a new set of tools will take longer to process.

        documents : typing.Optional[typing.Sequence[V2ChatRequestDocumentsItem]]
            A list of relevant documents that the model can cite to generate a more accurate reply. Each document is either a string or document object with content and metadata.

        citation_options : typing.Optional[CitationOptions]

        response_format : typing.Optional[ResponseFormatV2]

        safety_mode : typing.Optional[V2ChatRequestSafetyMode]
            Used to select the [safety instruction](https://docs.cohere.com/v2/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`.
            When `OFF` is specified, the safety instruction will be omitted.

            Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters.

            **Note**: This parameter is only compatible with newer Cohere models, starting with [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release) and [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release).

            **Note**: `command-r7b-12-2024` and newer models only support `"CONTEXTUAL"` and `"STRICT"` modes.

        max_tokens : typing.Optional[int]
            The maximum number of tokens the model will generate as part of the response.

            **Note**: Setting a low value may result in incomplete generations.

        stop_sequences : typing.Optional[typing.Sequence[str]]
            A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point, not including the stop sequence.

        temperature : typing.Optional[float]
            Defaults to `0.3`.

            A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations.

            Randomness can be further maximized by increasing the value of the `p` parameter.

        seed : typing.Optional[int]
            If specified, the backend will make a best effort to sample tokens
            deterministically, such that repeated requests with the same
            seed and parameters should return the same result. However,
            determinism cannot be totally guaranteed.

        frequency_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.

        presence_penalty : typing.Optional[float]
            Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
            Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.

        k : typing.Optional[float]
            Ensures that only the top `k` most likely tokens are considered for generation at each step. When `k` is set to `0`, k-sampling is disabled.
            Defaults to `0`, min value of `0`, max value of `500`.

        p : typing.Optional[float]
            Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            Defaults to `0.75`, min value of `0.01`, max value of `0.99`.

        return_prompt : typing.Optional[bool]
            Whether to return the prompt in the response.

        logprobs : typing.Optional[bool]
            Defaults to `false`. When set to `true`, the log probabilities of the generated tokens will be included in the response.

        tool_choice : typing.Optional[V2ChatRequestToolChoice]
            Used to control whether or not the model will be forced to use a tool when answering. When `REQUIRED` is specified, the model will be forced to use at least one of the user-defined tools, and the `tools` parameter must be passed in the request.
            When `NONE` is specified, the model will be forced **not** to use one of the specified tools, and give a direct response.
            If tool_choice isn't specified, then the model is free to choose whether to use the specified tools or not.

            **Note**: This parameter is only compatible with models [Command-r7b](https://docs.cohere.com/v2/docs/command-r7b) and newer.

            **Note**: The same functionality can be achieved in `/v1/chat` using the `force_single_step` parameter. If `force_single_step=true`, this is equivalent to specifying `REQUIRED`. If `force_single_step=true` and `tool_results` are passed, this is equivalent to specifying `NONE`.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ChatResponse


        Examples
        --------
        import asyncio

        from cohere import AsyncClient, ToolChatMessageV2

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.chat(
                model="model",
                messages=[
                    ToolChatMessageV2(
                        tool_call_id="messages",
                        content="messages",
                    )
                ],
            )


        asyncio.run(main())
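
        The generated example above is minimal. A fuller call might look like the
        sketch below; the model name, message content, and the `get_weather` tool
        schema are illustrative assumptions, not canonical values.

        import asyncio

        from cohere import AsyncClient, ToolV2, ToolV2Function

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = await client.v2.chat(
                model="command-r-plus-08-2024",  # illustrative model name
                messages=[{"role": "user", "content": "What's the weather in Toronto?"}],
                tools=[
                    ToolV2(
                        type="function",
                        function=ToolV2Function(
                            name="get_weather",  # hypothetical tool
                            description="Gets the current weather for a given city.",
                            parameters={
                                "type": "object",
                                "properties": {"city": {"type": "string"}},
                                "required": ["city"],
                            },
                        ),
                    )
                ],
            )
            print(response.message)


        asyncio.run(main())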
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/chat",
            method="POST",
            json={
                "model": model,
                "messages": convert_and_respect_annotation_metadata(
                    object_=messages, annotation=ChatMessages, direction="write"
                ),
                "tools": convert_and_respect_annotation_metadata(
                    object_=tools, annotation=typing.Sequence[ToolV2], direction="write"
                ),
                "strict_tools": strict_tools,
                "documents": convert_and_respect_annotation_metadata(
                    object_=documents, annotation=typing.Sequence[V2ChatRequestDocumentsItem], direction="write"
                ),
                "citation_options": convert_and_respect_annotation_metadata(
                    object_=citation_options, annotation=CitationOptions, direction="write"
                ),
                "response_format": convert_and_respect_annotation_metadata(
                    object_=response_format, annotation=ResponseFormatV2, direction="write"
                ),
                "safety_mode": safety_mode,
                "max_tokens": max_tokens,
                "stop_sequences": stop_sequences,
                "temperature": temperature,
                "seed": seed,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "k": k,
                "p": p,
                "return_prompt": return_prompt,
                "logprobs": logprobs,
                "tool_choice": tool_choice,
                "stream": False,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
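        # Dispatch on status: 2xx parses into ChatResponse, documented error codes
        # raise typed exceptions, and unparseable bodies fall back to a generic ApiError.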
        try:
            if 200 <= _response.status_code < 300:
                return typing.cast(
                    ChatResponse,
                    construct_type(
                        type_=ChatResponse, # type: ignore
                        object_=_response.json(),
                    ),
                )
            if _response.status_code == 400:
                raise BadRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, body=_response.text)
        raise ApiError(status_code=_response.status_code, body=_response_json)

    async def embed(
        self,
        *,
        model: str,
        input_type: EmbedInputType,
        embedding_types: typing.Sequence[EmbeddingType],
        texts: typing.Optional[typing.Sequence[str]] = OMIT,
        images: typing.Optional[typing.Sequence[str]] = OMIT,
        inputs: typing.Optional[typing.Sequence[EmbedInput]] = OMIT,
        max_tokens: typing.Optional[int] = OMIT,
        output_dimension: typing.Optional[int] = OMIT,
        truncate: typing.Optional[V2EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> EmbedByTypeResponse:
        """
        This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.

        Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

        If you want to learn more about how to use the embedding model, have a look at the [Semantic Search Guide](https://docs.cohere.com/docs/semantic-search).

        Parameters
        ----------
        model : str
            Defaults to embed-english-v2.0

            The identifier of the model. Smaller "light" models are faster, while larger models will perform better. [Custom models](https://docs.cohere.com/docs/training-custom-models) can also be supplied with their full ID.

            Available models and corresponding embedding dimensions:

            * `embed-english-v3.0` 1024
            * `embed-multilingual-v3.0` 1024
            * `embed-english-light-v3.0` 384
            * `embed-multilingual-light-v3.0` 384

            * `embed-english-v2.0` 4096
            * `embed-english-light-v2.0` 1024
            * `embed-multilingual-v2.0` 768

        input_type : EmbedInputType

        embedding_types : typing.Sequence[EmbeddingType]
            Specifies the types of embeddings you want to get back. Can be one or more of the following types.

            * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
            * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
            * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for only v3 models.
            * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for only v3 models.
            * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for only v3 models.

        texts : typing.Optional[typing.Sequence[str]]
            An array of strings for the model to embed. Maximum number of texts per call is `96`. We recommend reducing the length of each text to be under `512` tokens for optimal quality.

        images : typing.Optional[typing.Sequence[str]]
            An array of image data URIs for the model to embed. Maximum number of images per call is `1`.

            The image must be a valid [data URI](https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data). The image must be in either `image/jpeg` or `image/png` format and have a maximum size of 5MB.

        inputs : typing.Optional[typing.Sequence[EmbedInput]]
            An array of inputs for the model to embed. Maximum number of inputs per call is `96`. An input can contain a mix of text and image components.

        max_tokens : typing.Optional[int]
            The maximum number of tokens to embed per input. If the input text is longer than this, it will be truncated according to the `truncate` parameter.

        output_dimension : typing.Optional[int]
            The number of dimensions of the output embedding. This is only available for `embed-v4` and newer models.
            Possible values are `256`, `512`, `1024`, and `1536`. The default is `1536`.

        truncate : typing.Optional[V2EmbedRequestTruncate]
            One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.

            Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.

            If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        EmbedByTypeResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.embed(
                model="model",
                input_type="search_document",
                embedding_types=["float"],
            )


        asyncio.run(main())
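
        A fuller call might look like the sketch below; the texts are placeholders,
        `int8` assumes a v3 model (see `embedding_types` above), and the `float_`
        accessor on the response is an assumption about the SDK's field aliasing.

        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = await client.v2.embed(
                model="embed-english-v3.0",
                input_type="search_document",
                embedding_types=["float", "int8"],
                texts=["hello", "goodbye"],
            )
            # One 1024-dimensional vector per input text for embed-english-v3.0;
            # `float_` is assumed to be the Python-side alias for "float".
            print(response.embeddings.float_)


        asyncio.run(main())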
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/embed",
            method="POST",
            json={
                "texts": texts,
                "images": images,
                "model": model,
                "input_type": input_type,
                "inputs": convert_and_respect_annotation_metadata(
                    object_=inputs, annotation=typing.Sequence[EmbedInput], direction="write"
                ),
                "max_tokens": max_tokens,
                "output_dimension": output_dimension,
                "embedding_types": embedding_types,
                "truncate": truncate,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
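        # Same dispatch pattern as chat: 2xx parses into EmbedByTypeResponse, known
        # error codes raise typed exceptions, anything else becomes a generic ApiError.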
        try:
            if 200 <= _response.status_code < 300:
                return typing.cast(
                    EmbedByTypeResponse,
                    construct_type(
                        type_=EmbedByTypeResponse, # type: ignore
                        object_=_response.json(),
                    ),
                )
            if _response.status_code == 400:
                raise BadRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, body=_response.text)
        raise ApiError(status_code=_response.status_code, body=_response_json)

    async def rerank(
        self,
        *,
        model: str,
        query: str,
        documents: typing.Sequence[str],
        top_n: typing.Optional[int] = OMIT,
        return_documents: typing.Optional[bool] = OMIT,
        max_tokens_per_doc: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> V2RerankResponse:
        """
        This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.

        Parameters
        ----------
        model : str
            The identifier of the model to use, e.g. `rerank-v3.5`.

        query : str
            The search query

        documents : typing.Sequence[str]
            A list of texts that will be compared to the `query`.
            For optimal performance we recommend against sending more than 1,000 documents in a single request.

            **Note**: long documents will automatically be truncated to the value of `max_tokens_per_doc`.

            **Note**: structured data should be formatted as YAML strings for best performance.

        top_n : typing.Optional[int]
            Limits the number of returned rerank results to the specified value. If not passed, all the rerank results will be returned.

        return_documents : typing.Optional[bool]
            - If false, returns results without the doc text - the API will return a list of {index, relevance score} where index is inferred from the list passed into the request.
            - If true, returns results with the doc text passed in - the API will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.

        max_tokens_per_doc : typing.Optional[int]
            Defaults to `4096`. Long documents will be automatically truncated to the specified number of tokens.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        V2RerankResponse
            OK

        Examples
        --------
        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            await client.v2.rerank(
                model="model",
                query="query",
                documents=["documents"],
            )


        asyncio.run(main())
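
        A fuller call might look like this sketch; the query and documents are
        illustrative placeholders, and the `results`/`relevance_score` accessors are
        assumptions about the response shape.

        import asyncio

        from cohere import AsyncClient

        client = AsyncClient(
            client_name="YOUR_CLIENT_NAME",
            token="YOUR_TOKEN",
        )


        async def main() -> None:
            response = await client.v2.rerank(
                model="rerank-v3.5",
                query="What is the capital of the United States?",
                documents=[
                    "Carson City is the capital city of the American state of Nevada.",
                    "Washington, D.C. is the capital of the United States.",
                ],
                top_n=1,
                return_documents=True,
            )
            # The highest-ranked document comes first in the ordered results.
            print(response.results[0].relevance_score)


        asyncio.run(main())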
        """
        _response = await self._client_wrapper.httpx_client.request(
            "v2/rerank",
            method="POST",
            json={
                "model": model,
                "query": query,
                "documents": documents,
                "top_n": top_n,
                "return_documents": return_documents,
                "max_tokens_per_doc": max_tokens_per_doc,
            },
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
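        # 2xx parses into V2RerankResponse; known error codes raise typed exceptions;
        # unparseable bodies fall back to a generic ApiError.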
        try:
            if 200 <= _response.status_code < 300:
                return typing.cast(
                    V2RerankResponse,
                    construct_type(
                        type_=V2RerankResponse, # type: ignore
                        object_=_response.json(),
                    ),
                )
            if _response.status_code == 400:
                raise BadRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 401:
                raise UnauthorizedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 403:
                raise ForbiddenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 404:
                raise NotFoundError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 422:
                raise UnprocessableEntityError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 429:
                raise TooManyRequestsError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 498:
                raise InvalidTokenError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 499:
                raise ClientClosedRequestError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 500:
                raise InternalServerError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 501:
                raise NotImplementedError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 503:
                raise ServiceUnavailableError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            if _response.status_code == 504:
                raise GatewayTimeoutError(
                    typing.cast(
                        typing.Optional[typing.Any],
                        construct_type(
                            type_=typing.Optional[typing.Any], # type: ignore
                            object_=_response.json(),
                        ),
                    )
                )
            _response_json = _response.json()
        except JSONDecodeError:
            raise ApiError(status_code=_response.status_code, body=_response.text)
        raise ApiError(status_code=_response.status_code, body=_response_json)