2025-05-16 18:00:22 +04:00

293 lines
9.9 KiB
Python

import base64
import json
import re
import typing
import httpx
from httpx import URL, SyncByteStream, ByteStream
from . import GenerateStreamedResponse, Generation, \
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2
class AwsClient(Client):
    """Cohere v1 ``Client`` whose HTTP traffic is redirected to AWS Bedrock or SageMaker.

    The public API is inherited unchanged from :class:`Client`. What differs is
    the underlying ``httpx.Client``: its event hooks (see ``get_event_hooks``)
    rewrite every outgoing request for the chosen AWS ``service`` and convert
    each response back into Cohere's format.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        transport = httpx.Client(event_hooks=hooks, timeout=timeout)

        # base_url and api_key are placeholders only: the request hook replaces
        # the URL with the AWS endpoint and re-signs the request with SigV4,
        # for both Bedrock and SageMaker.
        Client.__init__(
            self,
            base_url="https://api.cohere.com",
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=transport,
        )
class AwsClientV2(ClientV2):
    """Cohere v2 ``ClientV2`` whose HTTP traffic is redirected to AWS Bedrock or SageMaker.

    The public API is inherited unchanged from :class:`ClientV2`. What differs
    is the underlying ``httpx.Client``: its event hooks (see
    ``get_event_hooks``) rewrite every outgoing request for the chosen AWS
    ``service`` and convert each response back into Cohere's format.
    """

    def __init__(
        self,
        *,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
        service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        transport = httpx.Client(event_hooks=hooks, timeout=timeout)

        # base_url and api_key are placeholders only: the request hook replaces
        # the URL with the AWS endpoint and re-signs the request with SigV4,
        # for both Bedrock and SageMaker.
        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=transport,
        )
# Any callable httpx accepts as an event hook.
EventHook = typing.Callable[..., typing.Any]


def get_event_hooks(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[EventHook]]:
    """Return the ``event_hooks`` mapping for an ``httpx.Client`` talking to AWS.

    It holds exactly one ``"request"`` hook (from ``map_request_to_bedrock``,
    which re-targets and signs requests for ``service``) and one
    ``"response"`` hook (from ``map_response_from_bedrock``, which translates
    AWS responses back into Cohere's format).
    """
    outgoing = map_request_to_bedrock(
        service=service,
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_session_token=aws_session_token,
        aws_region=aws_region,
    )
    incoming = map_response_from_bedrock()
    return {"request": [outgoing], "response": [incoming]}
class TextGeneration(typing.TypedDict):
    """Shape of a ``text-generation`` streaming event."""

    # NOTE(review): annotated as str; confirm against real payloads, which may
    # carry a boolean here.
    text: str
    is_finished: str
    event_type: typing.Literal["text-generation"]


class StreamEnd(typing.TypedDict):
    """Shape of the final ``stream-end`` streaming event."""

    is_finished: str
    event_type: typing.Literal["stream-end"]
    finish_reason: str
    # Apparently Bedrock may also attach an "amazon-bedrock-invocationMetrics"
    # object (inputTokenCount, outputTokenCount, invocationLatency,
    # firstByteLatency: all ints). It is intentionally left undeclared; a
    # hyphenated key could not be declared with class syntax anyway.
class Streamer(SyncByteStream):
    """Expose an iterator of byte chunks as an httpx synchronous byte stream.

    Used to install re-encoded body content as ``response.stream``. The wrapped
    iterator is handed out as-is, so the content can be iterated only once.
    """

    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        # iter() on an iterator returns that same iterator object.
        return iter(self.lines)
# Endpoint name (last URL path segment) -> Cohere model used to re-parse a
# non-streamed AWS response body.
response_mapping: typing.Dict[str, typing.Any] = dict(
    chat=NonStreamedChatResponse,
    embed=EmbedResponse,
    generate=Generation,
    rerank=RerankResponse,
)

# Endpoint name -> Cohere model used to re-parse each decoded streaming event.
# Only these endpoints support streaming through the AWS adapter.
stream_response_mapping: typing.Dict[str, typing.Any] = dict(
    chat=StreamedChatResponse,
    generate=GenerateStreamedResponse,
)
def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator[bytes]:
    """Decode an AWS event-stream response into newline-delimited Cohere events.

    For every decoded text line of ``response``:

    1. take the first ``{...}`` span on the line (the frame's JSON envelope);
    2. if the envelope has a ``"bytes"`` key, base64-decode it into the inner
       UTF-8 JSON event;
    3. if that event has an ``"event_type"``, re-parse it with the model that
       ``stream_response_mapping`` registers for ``endpoint``, and yield it as
       one JSON document followed by ``"\\n"``, encoded as UTF-8 bytes.

    Lines that fail any of these checks are skipped.

    NOTE(review): the pattern stops at the first ``}``, so it relies on the
    envelope containing no nested objects.
    """
    envelope_pattern = re.compile(r"{[^\}]*}")
    for line in response.iter_lines():
        found = envelope_pattern.search(line)
        if found is None:
            continue
        envelope = json.loads(found.group())
        if "bytes" not in envelope:
            continue
        event = json.loads(base64.b64decode(envelope["bytes"]).decode("utf-8"))
        if "event_type" not in event:
            continue
        event_model = stream_response_mapping[endpoint]
        parsed = typing.cast(event_model,  # type: ignore
                             construct_type(type_=event_model, object_=event))
        yield (json.dumps(parsed.dict()) + "\n").encode("utf-8")  # type: ignore
def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Build a Cohere ``ApiMeta`` from Bedrock's token-count response headers.

    ``X-Amzn-Bedrock-Input-Token-Count`` and ``X-Amzn-Bedrock-Output-Token-Count``
    are parsed as ints, and the same counts are reported as both ``tokens`` and
    ``billed_units``. A missing header yields ``-1`` for that count.
    """

    def _header_count(header_name: str) -> int:
        return int(response.headers.get(header_name, -1))

    usage = {
        "input_tokens": _header_count("X-Amzn-Bedrock-Input-Token-Count"),
        "output_tokens": _header_count("X-Amzn-Bedrock-Output-Token-Count"),
    }
    return ApiMeta(
        tokens=ApiMetaTokens(**usage),
        billed_units=ApiMetaBilledUnits(**usage),
    )
def map_response_from_bedrock():
    """Build the httpx response hook that converts AWS replies into Cohere responses.

    The hook reads the endpoint name that the request hook stored in
    ``response.request.extensions["endpoint"]`` and replaces ``response.stream``
    in place:

    * ``application/vnd.amazon.eventstream`` bodies are decoded by
      ``stream_generator`` into newline-delimited Cohere stream events;
    * any other successful body is parsed as JSON, given a ``meta`` block
      built by ``map_token_counts`` from the Bedrock token-count headers
      (replacing any ``meta`` already in the body), and re-serialised through
      the model registered in ``response_mapping``.

    Non-2xx responses are left untouched, so whatever error handling the
    calling client performs sees AWS's original status code and body.

    Returns:
        A callable taking an ``httpx.Response`` and returning ``None``.
    """

    def _hook(
        response: httpx.Response,
    ) -> None:
        # AWS error payloads do not fit the Cohere success models, and nothing
        # guarantees they are JSON at all. Reshaping them would either raise
        # inside this hook or hide the real error, so pass them through as-is.
        if not response.is_success:
            return

        stream = response.headers.get("content-type") == "application/vnd.amazon.eventstream"
        endpoint = response.request.extensions["endpoint"]
        output: typing.Iterator[bytes]
        if stream:
            # Wrap the raw AWS byte stream in a throwaway Response purely to
            # reuse httpx's iter_lines() text decoding.
            output = stream_generator(httpx.Response(
                stream=response.stream,
                status_code=response.status_code,
            ), endpoint)
        else:
            response_type = response_mapping[endpoint]
            response_obj = json.loads(response.read())
            response_obj["meta"] = map_token_counts(response).dict()
            cast_obj: typing.Any = typing.cast(response_type,  # type: ignore
                                               construct_type(
                                                   type_=response_type,
                                                   object_=response_obj))  # type: ignore
            output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])

        response.stream = Streamer(output)

        # Reset the response so the replacement stream can be read as if fresh
        # (read() above cached _content and marked the stream consumed/closed).
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False

    return _hook
def get_boto3_session(
    **kwargs: typing.Any,
):
    """Create a boto3 ``Session`` from only the keyword arguments that are not ``None``.

    Unset options are omitted from the ``boto3.Session`` call entirely instead
    of being passed explicitly as ``None``.
    """
    session_kwargs = {}
    for option, value in kwargs.items():
        if value is None:
            continue
        session_kwargs[option] = value
    return lazy_boto3().Session(**session_kwargs)
def map_request_to_bedrock(
    service: str,
    aws_access_key: typing.Optional[str] = None,
    aws_secret_key: typing.Optional[str] = None,
    aws_session_token: typing.Optional[str] = None,
    aws_region: typing.Optional[str] = None,
) -> EventHook:
    """Build the httpx request hook that re-targets Cohere API calls at AWS.

    A boto3 session is created once from whichever credentials/region were
    supplied (``None`` values are not passed on, presumably letting boto3
    fall back to its default resolution chain). The session's region and
    credentials feed a botocore ``SigV4Auth`` signer for ``service``.

    The returned hook mutates each outgoing request in place:

    * the Cohere URL ``.../<api_version>/<endpoint>`` is replaced by the
      Bedrock/SageMaker invoke URL for ``body["model"]`` (see ``get_url``);
    * ``model`` and ``stream`` are removed from the JSON body, and rerank
      requests get an integer ``api_version`` field;
    * the headers are replaced by a SigV4-signed set for the new host/body;
    * the endpoint name is stored in ``request.extensions["endpoint"]`` for
      the response hook to read.

    Args:
        service: ``"bedrock"`` or ``"sagemaker"``; used both as the SigV4
            service name and to choose the URL scheme.
        aws_access_key, aws_secret_key, aws_session_token, aws_region:
            Optional explicit AWS settings.

    Returns:
        A callable taking an ``httpx.Request`` and returning ``None``.
    """
    session = get_boto3_session(
        region_name=aws_region,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
        aws_session_token=aws_session_token,
    )
    aws_region = session.region_name
    credentials = session.get_credentials()
    signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)

    def _event_hook(request: httpx.Request) -> None:
        headers = request.headers.copy()
        # Keep Connection out of the signed header set. pop() instead of del so
        # a client configured without that header does not raise KeyError.
        headers.pop("connection", None)

        api_version = request.url.path.split("/")[-2]
        endpoint = request.url.path.split("/")[-1]
        body = json.loads(request.read())
        model = body["model"]

        url = get_url(
            platform=service,
            aws_region=aws_region,
            model=model,  # type: ignore
            stream="stream" in body and body["stream"],
        )
        request.url = URL(url)
        request.headers["host"] = request.url.host
        # `headers` is the copy taken above and still carries the original
        # Cohere Host. It is what gets signed and, through the final assignment
        # to request.headers below, what is actually sent, so it must name the
        # AWS host as well.
        headers["host"] = request.url.host

        if endpoint == "rerank":
            body["api_version"] = get_api_version(version=api_version)

        # The model and the streaming choice are now encoded in the URL.
        body.pop("stream", None)
        body.pop("model", None)

        new_body = json.dumps(body).encode("utf-8")
        request.stream = ByteStream(new_body)
        request._content = new_body
        headers["content-length"] = str(len(new_body))

        aws_request = lazy_botocore().awsrequest.AWSRequest(
            method=request.method,
            url=url,
            headers=headers,
            data=new_body,
        )
        signer.add_auth(aws_request)
        request.headers = httpx.Headers(aws_request.prepare().headers)
        request.extensions["endpoint"] = endpoint

    return _event_hook
def get_url(
    *,
    platform: str,
    aws_region: typing.Optional[str],
    model: str,
    stream: bool,
) -> str:
    """Build the AWS invocation URL for ``model`` on ``platform``.

    * ``"bedrock"``: ``https://bedrock-runtime.<region>.amazonaws.com/model/<model>/``
      followed by ``invoke`` or ``invoke-with-response-stream``.
    * ``"sagemaker"``: ``https://runtime.sagemaker.<region>.amazonaws.com/endpoints/<model>/``
      followed by ``invocations`` or ``invocations-response-stream``.

    The streaming variant is chosen when ``stream`` is truthy. Any other
    ``platform`` yields an empty string.

    NOTE(review): ``model`` and ``aws_region`` are interpolated verbatim (no URL
    quoting; ``None`` becomes ``"None"``). Confirm this is acceptable for
    model IDs that contain reserved characters.
    """
    # platform -> (URL template, non-streaming action, streaming action)
    routes = {
        "bedrock": (
            "https://bedrock-runtime.{region}.amazonaws.com/model/{target}/{action}",
            "invoke",
            "invoke-with-response-stream",
        ),
        "sagemaker": (
            "https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{target}/{action}",
            "invocations",
            "invocations-response-stream",
        ),
    }
    route = routes.get(platform)
    if route is None:
        return ""
    template, unary_action, streaming_action = route
    return template.format(
        region=aws_region,
        target=model,
        action=streaming_action if stream else unary_action,
    )
def get_api_version(*, version: str):
    """Map a URL path segment such as ``"v2"`` to the integer rerank ``api_version``.

    ``"v1"`` maps to 1 and ``"v2"`` to 2. Any other value falls back to 1.
    """
    known_versions = dict(v1=1, v2=2)
    return known_versions.get(version, 1)