import base64
import json
import re
import typing

import httpx
from httpx import URL, SyncByteStream, ByteStream

from . import (
    GenerateStreamedResponse,
    Generation,
    NonStreamedChatResponse,
    EmbedResponse,
    StreamedChatResponse,
    RerankResponse,
    ApiMeta,
    ApiMetaTokens,
    ApiMetaBilledUnits,
)
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2


class AwsClient(Client):
    """``Client`` whose HTTP traffic goes through AWS-specific httpx event hooks.

    The underlying ``httpx.Client`` is created here with the request/response
    hooks returned by ``get_event_hooks``. The API key and client name handed
    to the base class are the placeholder string ``"n/a"``.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        """Create the hooked httpx client and initialise the base ``Client``.

        Args:
            aws_access_key: Optional AWS access key forwarded to the hooks.
            aws_secret_key: Optional AWS secret key forwarded to the hooks.
            aws_session_token: Optional AWS session token forwarded to the hooks.
            aws_region: Optional AWS region forwarded to the hooks.
            timeout: Timeout applied to both the httpx client and the base client.
            service: Target AWS service, ``"bedrock"`` or ``"sagemaker"``.
        """
        hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        hooked_client = httpx.Client(event_hooks=hooks, timeout=timeout)

        Client.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=hooked_client,
        )


class AwsClientV2(ClientV2):
    """``ClientV2`` whose HTTP traffic goes through AWS-specific httpx event hooks.

    Mirrors ``AwsClient``: the same hooks and placeholder credentials are used,
    only the base class differs.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        """Create the hooked httpx client and initialise the base ``ClientV2``.

        Args:
            aws_access_key: Optional AWS access key forwarded to the hooks.
            aws_secret_key: Optional AWS secret key forwarded to the hooks.
            aws_session_token: Optional AWS session token forwarded to the hooks.
            aws_region: Optional AWS region forwarded to the hooks.
            timeout: Timeout applied to both the httpx client and the base client.
            service: Target AWS service, ``"bedrock"`` or ``"sagemaker"``.
        """
        hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        hooked_client = httpx.Client(event_hooks=hooks, timeout=timeout)

        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=hooked_client,
        )
EventHook = typing.Callable[..., typing.Any]


def get_event_hooks(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[EventHook]]:
    """Return the httpx ``event_hooks`` mapping for an AWS-backed client.

    The mapping holds one request hook (from ``map_request_to_bedrock``) and
    one response hook (from ``map_response_from_bedrock``).

    Args:
        service: AWS service name passed to the request hook.
        aws_access_key: Optional AWS access key.
        aws_secret_key: Optional AWS secret key.
        aws_session_token: Optional AWS session token.
        aws_region: Optional AWS region.

    Returns:
        A dict with ``"request"`` and ``"response"`` keys, each a list of hooks.
    """
    return {
        "request": [
            map_request_to_bedrock(
                service=service,
                aws_access_key=aws_access_key,
                aws_secret_key=aws_secret_key,
                aws_session_token=aws_session_token,
                aws_region=aws_region,
            ),
        ],
        "response": [
            map_response_from_bedrock()
        ],
    }


TextGeneration = typing.TypedDict(
    'TextGeneration',
    {"text": str, "is_finished": str, "event_type": typing.Literal["text-generation"]},
)

StreamEnd = typing.TypedDict(
    'StreamEnd',
    {
        "is_finished": str,
        "event_type": typing.Literal["stream-end"],
        "finish_reason": str,
        # "amazon-bedrock-invocationMetrics": {
        #     "inputTokenCount": int, "outputTokenCount": int, "invocationLatency": int,
        #     "firstByteLatency": int}
    },
)


class Streamer(SyncByteStream):
    """Byte stream that hands back a pre-built iterator of byte chunks."""

    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        # The same iterator object is returned on every call, so the stream
        # can only be consumed once.
        return self.lines


# Endpoint name -> response model used for non-streaming responses.
response_mapping: typing.Dict[str, typing.Any] = {
    "chat": NonStreamedChatResponse,
    "embed": EmbedResponse,
    "generate": Generation,
    "rerank": RerankResponse
}

# Endpoint name -> response model used for each streamed event.
stream_response_mapping: typing.Dict[str, typing.Any] = {
    "chat": StreamedChatResponse,
    "generate": GenerateStreamedResponse,
}


def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator[bytes]:
    """Yield newline-terminated JSON events decoded from a streamed response.

    For every line of ``response``, the first ``{...}`` fragment is parsed as
    JSON. When it carries a ``"bytes"`` field, that field is base64-decoded and
    parsed as JSON again; payloads containing ``"event_type"`` are converted to
    the model registered for ``endpoint`` in ``stream_response_mapping`` and
    re-serialised. Lines that do not match any of these steps are skipped.

    Args:
        response: Response whose lines are iterated.
        endpoint: Key into ``stream_response_mapping``.

    Yields:
        UTF-8 encoded JSON documents, each followed by ``"\\n"``.

    Raises:
        KeyError: If an event arrives for an ``endpoint`` that has no entry in
            ``stream_response_mapping``.
    """
    # Compiled once instead of handing the pattern string to re.search per line.
    envelope_pattern = re.compile(r"{[^\}]*}")

    for _text in response.iter_lines():
        match = envelope_pattern.search(_text)
        if not match:
            continue

        obj = json.loads(match.group())
        if "bytes" not in obj:
            continue

        base64_payload = base64.b64decode(obj["bytes"]).decode("utf-8")
        streamed_obj = json.loads(base64_payload)
        if "event_type" not in streamed_obj:
            continue

        response_type = stream_response_mapping[endpoint]
        parsed = typing.cast(response_type,  # type: ignore
                             construct_type(type_=response_type, object_=streamed_obj))
        yield (json.dumps(parsed.dict()) + "\n").encode("utf-8")  # type: ignore


def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Build an ``ApiMeta`` from the token-count headers of ``response``.

    Reads ``X-Amzn-Bedrock-Input-Token-Count`` and
    ``X-Amzn-Bedrock-Output-Token-Count``; a missing header is reported as
    ``-1``. The same counts are used for both ``tokens`` and ``billed_units``.
    """
    input_tokens = int(response.headers.get("X-Amzn-Bedrock-Input-Token-Count", -1))
    output_tokens = int(response.headers.get("X-Amzn-Bedrock-Output-Token-Count", -1))
    return ApiMeta(
        tokens=ApiMetaTokens(input_tokens=input_tokens, output_tokens=output_tokens),
        billed_units=ApiMetaBilledUnits(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def map_response_from_bedrock() -> EventHook:
    """Return an httpx response hook that rewrites the response body in place.

    The hook replaces ``response.stream`` with a ``Streamer`` over re-serialised
    content: streamed responses go through ``stream_generator``, all others are
    parsed as JSON, given a ``"meta"`` entry from ``map_token_counts``, and
    converted via the model registered in ``response_mapping``.
    """

    def _hook(
        response: httpx.Response,
    ) -> None:
        # .get() so a response without a Content-Type header is treated as
        # non-streaming instead of raising KeyError inside the hook.
        stream = response.headers.get("content-type") == "application/vnd.amazon.eventstream"
        # Set by the request hook created in map_request_to_bedrock.
        endpoint = response.request.extensions["endpoint"]
        output: typing.Iterator[bytes]

        if stream:
            output = stream_generator(httpx.Response(
                stream=response.stream,
                status_code=response.status_code,
            ), endpoint)
        else:
            response_type = response_mapping[endpoint]
            response_obj = json.loads(response.read())
            response_obj["meta"] = map_token_counts(response).dict()
            cast_obj: typing.Any = typing.cast(response_type,  # type: ignore
                                               construct_type(
                                                   type_=response_type,  # type: ignore
                                                   object_=response_obj))

            output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])

        response.stream = Streamer(output)

        # reset response object to allow for re-reading
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False

    return _hook


def get_boto3_session(
    **kwargs: typing.Any,
):
    """Create a boto3 ``Session`` from ``kwargs``, dropping entries that are ``None``."""
    non_none_args = {k: v for k, v in kwargs.items() if v is not None}
    return lazy_boto3().Session(**non_none_args)


def map_request_to_bedrock(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> EventHook:
    """Return an httpx request hook that retargets and SigV4-signs a request.

    A boto3 session is created once, up front; its region and credentials are
    captured by the returned hook.

    Args:
        service: Service name used for signing and for URL construction.
        aws_access_key: Optional AWS access key for the session.
        aws_secret_key: Optional AWS secret key for the session.
        aws_session_token: Optional AWS session token for the session.
        aws_region: Optional AWS region for the session; the region reported
            by the session is what the hook actually uses.

    Returns:
        A callable taking an ``httpx.Request`` and mutating it in place.
    """
    session = get_boto3_session(
        region_name=aws_region,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
        aws_session_token=aws_session_token,
    )
    aws_region = session.region_name
    credentials = session.get_credentials()
    signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)

    def _event_hook(request: httpx.Request) -> None:
        headers = request.headers.copy()
        # Guarded: an unconditional `del` raises KeyError when the request
        # carries no Connection header.
        if "connection" in headers:
            del headers["connection"]

        # The last two path segments are read as API version and endpoint name.
        path_segments = request.url.path.split("/")
        api_version = path_segments[-2]
        endpoint = path_segments[-1]

        body = json.loads(request.read())
        model = body["model"]

        url = get_url(
            platform=service,
            aws_region=aws_region,
            model=model,  # type: ignore
            stream="stream" in body and body["stream"],
        )
        request.url = URL(url)
        request.headers["host"] = request.url.host

        if endpoint == "rerank":
            body["api_version"] = get_api_version(version=api_version)

        # "stream" and "model" are encoded in the URL, so they are removed
        # from the forwarded body.
        body.pop("stream", None)
        body.pop("model", None)

        new_body = json.dumps(body).encode("utf-8")
        request.stream = ByteStream(new_body)
        request._content = new_body
        headers["content-length"] = str(len(new_body))

        aws_request = lazy_botocore().awsrequest.AWSRequest(
            method=request.method,
            url=url,
            headers=headers,
            data=new_body,
        )
        signer.add_auth(aws_request)

        request.headers = httpx.Headers(aws_request.prepare().headers)
        # Read back by the response hook to pick the response model.
        request.extensions["endpoint"] = endpoint

    return _event_hook


def get_url(
    *,
    platform: str,
    aws_region: typing.Optional[str],
    model: str,
    stream: bool,
) -> str:
    """Build the invocation URL for ``model`` on the given platform.

    Args:
        platform: ``"bedrock"`` or ``"sagemaker"``.
        aws_region: Region interpolated into the host name.
        model: Model (bedrock) or endpoint (sagemaker) identifier placed in the path.
        stream: Whether to target the streaming variant of the invocation path.

    Returns:
        The URL string, or ``""`` when ``platform`` is neither supported value.
    """
    if platform == "bedrock":
        endpoint = "invoke" if not stream else "invoke-with-response-stream"
        return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{endpoint}"
    elif platform == "sagemaker":
        endpoint = "invocations" if not stream else "invocations-response-stream"
        return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{endpoint}"
    return ""


def get_api_version(*, version: str) -> int:
    """Map a version path segment (``"v1"``/``"v2"``) to its integer; default ``1``."""
    int_version = {
        "v1": 1,
        "v2": 2,
    }

    return int_version.get(version, 1)