293 lines
9.9 KiB
Python
293 lines
9.9 KiB
Python
import base64
import json
import re
import typing

import httpx
from httpx import URL, SyncByteStream, ByteStream

from . import GenerateStreamedResponse, Generation, \
    NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
    ApiMetaBilledUnits
from .client import Client, ClientEnvironment
from .core import construct_type
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
from .client_v2 import ClientV2
|
|
|
|
class AwsClient(Client):
    """``Client`` subclass whose HTTP layer is wired with AWS event hooks.

    The constructor builds an ``httpx.Client`` carrying the request/response
    hooks returned by ``get_event_hooks`` and hands it to ``Client.__init__``
    together with placeholder values for ``client_name`` and ``api_key``.
    """

    def __init__(
            self,
            *,
            aws_access_key: typing.Optional[str] = None,
            aws_secret_key: typing.Optional[str] = None,
            aws_session_token: typing.Optional[str] = None,
            aws_region: typing.Optional[str] = None,
            timeout: typing.Optional[float] = None,
            service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        """Create the client.

        Args:
            aws_access_key: Forwarded to ``get_event_hooks``.
            aws_secret_key: Forwarded to ``get_event_hooks``.
            aws_session_token: Forwarded to ``get_event_hooks``.
            aws_region: Forwarded to ``get_event_hooks``.
            timeout: Passed both to the ``httpx.Client`` and to ``Client.__init__``.
            service: Either ``"bedrock"`` or ``"sagemaker"``; forwarded to ``get_event_hooks``.
        """
        # Hooks that rewrite outgoing requests and incoming responses.
        aws_hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        hooked_http_client = httpx.Client(
            event_hooks=aws_hooks,
            timeout=timeout,
        )

        Client.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=hooked_http_client,
        )
|
|
|
|
|
|
class AwsClientV2(ClientV2):
    """``ClientV2`` subclass whose HTTP layer is wired with AWS event hooks.

    The constructor builds an ``httpx.Client`` carrying the request/response
    hooks returned by ``get_event_hooks`` and hands it to ``ClientV2.__init__``
    together with placeholder values for ``client_name`` and ``api_key``.
    """

    def __init__(
            self,
            *,
            aws_access_key: typing.Optional[str] = None,
            aws_secret_key: typing.Optional[str] = None,
            aws_session_token: typing.Optional[str] = None,
            aws_region: typing.Optional[str] = None,
            timeout: typing.Optional[float] = None,
            service: typing.Union[typing.Literal["bedrock"], typing.Literal["sagemaker"]],
    ):
        """Create the client.

        Args:
            aws_access_key: Forwarded to ``get_event_hooks``.
            aws_secret_key: Forwarded to ``get_event_hooks``.
            aws_session_token: Forwarded to ``get_event_hooks``.
            aws_region: Forwarded to ``get_event_hooks``.
            timeout: Passed both to the ``httpx.Client`` and to ``ClientV2.__init__``.
            service: Either ``"bedrock"`` or ``"sagemaker"``; forwarded to ``get_event_hooks``.
        """
        # Hooks that rewrite outgoing requests and incoming responses.
        aws_hooks = get_event_hooks(
            service=service,
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            aws_session_token=aws_session_token,
            aws_region=aws_region,
        )
        hooked_http_client = httpx.Client(
            event_hooks=aws_hooks,
            timeout=timeout,
        )

        ClientV2.__init__(
            self,
            base_url="https://api.cohere.com",  # this url is unused for BedrockClient
            environment=ClientEnvironment.PRODUCTION,
            client_name="n/a",
            timeout=timeout,
            api_key="n/a",
            httpx_client=hooked_http_client,
        )
|
|
|
|
|
|
# Signature-agnostic callable type for the hooks placed in the ``event_hooks``
# dict that is passed to ``httpx.Client`` in this module.
EventHook = typing.Callable[..., typing.Any]
|
|
|
|
|
|
def get_event_hooks(
        service: str,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[EventHook]]:
    """Assemble the ``event_hooks`` mapping used by the AWS-backed clients.

    Args:
        service: Forwarded to ``map_request_to_bedrock``.
        aws_access_key: Forwarded to ``map_request_to_bedrock``.
        aws_secret_key: Forwarded to ``map_request_to_bedrock``.
        aws_session_token: Forwarded to ``map_request_to_bedrock``.
        aws_region: Forwarded to ``map_request_to_bedrock``.

    Returns:
        A dict with a single hook under ``"request"`` and a single hook under
        ``"response"``.
    """
    # The request hook is built first, then the response hook, as before.
    outgoing_hook = map_request_to_bedrock(
        service=service,
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_session_token=aws_session_token,
        aws_region=aws_region,
    )
    incoming_hook = map_response_from_bedrock()

    hooks: typing.Dict[str, typing.List[EventHook]] = {
        "request": [outgoing_hook],
        "response": [incoming_hook],
    }
    return hooks
|
|
|
|
|
|
# Typed shape of an event whose ``event_type`` is "text-generation".
# NOTE(review): ``is_finished`` is annotated as ``str`` — confirm this is intended.
TextGeneration = typing.TypedDict('TextGeneration',
                                  {"text": str, "is_finished": str, "event_type": typing.Literal["text-generation"]})
# Typed shape of an event whose ``event_type`` is "stream-end".
# The commented-out keys below are not part of the declared type.
StreamEnd = typing.TypedDict('StreamEnd',
                             {"is_finished": str, "event_type": typing.Literal["stream-end"], "finish_reason": str,
                              # "amazon-bedrock-invocationMetrics": {
                              #     "inputTokenCount": int, "outputTokenCount": int, "invocationLatency": int,
                              #     "firstByteLatency": int}
                              })
|
|
|
|
|
|
class Streamer(SyncByteStream):
    """``SyncByteStream`` that serves chunks from a caller-supplied iterator of bytes."""

    # The wrapped iterator; stored by ``__init__`` and returned unchanged by ``__iter__``.
    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        """Store ``lines`` without consuming or copying it."""
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        """Return the stored iterator object itself (the same object on every call)."""
        return self.lines
|
|
|
|
|
|
# Endpoint name -> type used to construct the body of a non-streamed response.
# The key is the value stored under ``request.extensions["endpoint"]``.
response_mapping: typing.Dict[str, typing.Any] = {
    "chat": NonStreamedChatResponse,
    "embed": EmbedResponse,
    "generate": Generation,
    "rerank": RerankResponse
}

# Endpoint name -> type used to construct each event of a streamed response.
stream_response_mapping: typing.Dict[str, typing.Any] = {
    "chat": StreamedChatResponse,
    "generate": GenerateStreamedResponse,
}
|
|
|
|
|
|
def stream_generator(response: httpx.Response, endpoint: str) -> typing.Iterator[bytes]:
    """Yield newline-terminated JSON events decoded from a streamed response.

    For every line of ``response`` the first brace-delimited fragment is parsed
    as JSON. When it carries a ``"bytes"`` entry, that entry is base64-decoded
    and parsed as JSON again; decoded objects that contain ``"event_type"`` are
    constructed as the type registered for ``endpoint`` in
    ``stream_response_mapping`` and re-serialised.

    Args:
        response: Response whose lines are iterated.
        endpoint: Key into ``stream_response_mapping``.

    Yields:
        UTF-8 encoded JSON documents, each followed by ``"\\n"``.
    """
    envelope_pattern = re.compile(r"{[^\}]*}")

    for raw_line in response.iter_lines():
        found = envelope_pattern.search(raw_line)
        if not found:
            continue

        envelope = json.loads(found.group())
        if "bytes" not in envelope:
            continue

        decoded_payload = base64.b64decode(envelope["bytes"]).decode("utf-8")
        event = json.loads(decoded_payload)
        if "event_type" not in event:
            continue

        event_type = stream_response_mapping[endpoint]
        typed_event = typing.cast(event_type,  # type: ignore
                                  construct_type(type_=event_type, object_=event))
        serialized = json.dumps(typed_event.dict()) + "\n"  # type: ignore
        yield serialized.encode("utf-8")
|
|
|
|
|
|
def map_token_counts(response: httpx.Response) -> ApiMeta:
    """Build an ``ApiMeta`` from the token-count headers of ``response``.

    The ``X-Amzn-Bedrock-Input-Token-Count`` and
    ``X-Amzn-Bedrock-Output-Token-Count`` headers are read as integers; a
    missing header yields ``-1``. The same two numbers populate both
    ``tokens`` and ``billed_units``.
    """
    response_headers = response.headers
    input_tokens = int(response_headers.get("X-Amzn-Bedrock-Input-Token-Count", -1))
    output_tokens = int(response_headers.get("X-Amzn-Bedrock-Output-Token-Count", -1))

    token_usage = ApiMetaTokens(input_tokens=input_tokens, output_tokens=output_tokens)
    billed = ApiMetaBilledUnits(input_tokens=input_tokens, output_tokens=output_tokens)
    return ApiMeta(tokens=token_usage, billed_units=billed)
|
|
|
|
|
|
def map_response_from_bedrock() -> typing.Callable[[httpx.Response], None]:
    """Return a response hook that rewrites the response body in place.

    The returned hook replaces ``response.stream`` with a ``Streamer`` over
    re-serialised JSON built from the types in ``response_mapping`` /
    ``stream_response_mapping``.
    """

    def _hook(
            response: httpx.Response,
    ) -> None:
        # A response is treated as streamed only when its content-type is exactly the event-stream type.
        stream = response.headers["content-type"] == "application/vnd.amazon.eventstream"
        # Stored on the request by the request hook built in ``map_request_to_bedrock``.
        endpoint = response.request.extensions["endpoint"]
        output: typing.Iterator[bytes]

        if stream:
            # Wrap the raw stream in a fresh Response so the generator can iterate its lines.
            output = stream_generator(httpx.Response(
                stream=response.stream,
                status_code=response.status_code,
            ), endpoint)
        else:
            response_type = response_mapping[endpoint]
            response_obj = json.loads(response.read())
            # Token counts are taken from the response headers, not from the body.
            response_obj["meta"] = map_token_counts(response).dict()
            cast_obj: typing.Any = typing.cast(response_type,  # type: ignore
                                               construct_type(
                                                   type_=response_type,
                                                   # type: ignore
                                                   object_=response_obj))

            # Single-chunk iterator holding the re-serialised body.
            output = iter([json.dumps(cast_obj.dict()).encode("utf-8")])

        response.stream = Streamer(output)

        # reset response object to allow for re-reading
        if hasattr(response, "_content"):
            del response._content
        response.is_stream_consumed = False
        response.is_closed = False

    return _hook
|
|
|
|
def get_boto3_session(
        **kwargs: typing.Any,
):
    """Create a boto3 ``Session`` from the keyword arguments that are not ``None``.

    Arguments whose value is ``None`` are dropped before the call, so they are
    not passed to ``Session`` at all.
    """
    provided = {name: value for name, value in kwargs.items() if value is not None}
    boto3 = lazy_boto3()
    return boto3.Session(**provided)
|
|
|
|
|
|
|
|
def map_request_to_bedrock(
        service: str,
        aws_access_key: typing.Optional[str] = None,
        aws_secret_key: typing.Optional[str] = None,
        aws_session_token: typing.Optional[str] = None,
        aws_region: typing.Optional[str] = None,
) -> EventHook:
    """Build a request hook that retargets a request to AWS and signs it.

    A boto3 session is created once from the given credentials/region, and a
    ``SigV4Auth`` signer is built from the session's credentials, ``service``
    and the session's region. The returned hook then, for every request:

    * derives the API version and endpoint from the last two path segments,
    * reads ``model`` from the JSON body and rewrites the URL via ``get_url``,
    * removes ``stream`` and ``model`` from the body (adding ``api_version``
      for the ``rerank`` endpoint),
    * signs the rewritten request and replaces the request headers with the
      signed ones,
    * stores the endpoint name in ``request.extensions["endpoint"]``.

    Args:
        service: Service name handed to the signer and to ``get_url``.
        aws_access_key: Optional access key for the boto3 session.
        aws_secret_key: Optional secret key for the boto3 session.
        aws_session_token: Optional session token for the boto3 session.
        aws_region: Optional region for the boto3 session.

    Returns:
        The request hook.

    Raises:
        KeyError: From the hook, when the request body has no ``"model"`` entry.
    """
    session = get_boto3_session(
        region_name=aws_region,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
        aws_session_token=aws_session_token,
    )
    # From here on use the region reported by the session rather than the raw argument.
    aws_region = session.region_name
    credentials = session.get_credentials()
    signer = lazy_botocore().auth.SigV4Auth(credentials, service, aws_region)

    def _event_hook(request: httpx.Request) -> None:
        # Work on a copy: these are the headers that get signed and finally
        # replace ``request.headers`` below.
        headers = request.headers.copy()
        # Tolerate a request that carries no Connection header.
        if "connection" in headers:
            del headers["connection"]

        path_segments = request.url.path.split("/")
        api_version = path_segments[-2]
        endpoint = path_segments[-1]
        body = json.loads(request.read())
        model = body["model"]

        url = get_url(
            platform=service,
            aws_region=aws_region,
            model=model,  # type: ignore
            stream="stream" in body and body["stream"],
        )
        request.url = URL(url)
        request.headers["host"] = request.url.host
        # The copy was taken before the URL changed, so it still holds the
        # original host. It is this copy that is signed and sent, so it must
        # carry the new host as well.
        headers["host"] = request.url.host

        if endpoint == "rerank":
            body["api_version"] = get_api_version(version=api_version)

        # Neither key is forwarded in the rewritten body.
        body.pop("stream", None)
        body.pop("model", None)

        new_body = json.dumps(body).encode("utf-8")
        request.stream = ByteStream(new_body)
        request._content = new_body
        headers["content-length"] = str(len(new_body))

        aws_request = lazy_botocore().awsrequest.AWSRequest(
            method=request.method,
            url=url,
            headers=headers,
            data=request.read(),
        )
        signer.add_auth(aws_request)

        request.headers = httpx.Headers(aws_request.prepare().headers)
        # Read back by the response hook to pick the response type.
        request.extensions["endpoint"] = endpoint

    return _event_hook
|
|
|
|
|
|
def get_url(
        *,
        platform: str,
        aws_region: typing.Optional[str],
        model: str,
        stream: bool,
) -> str:
    """Return the AWS invocation URL for ``model`` on ``platform``.

    Args:
        platform: ``"bedrock"`` or ``"sagemaker"``.
        aws_region: Region interpolated into the host name.
        model: Model (bedrock) or endpoint (sagemaker) path segment.
        stream: Truthy selects the streaming variant of the final path segment.

    Returns:
        The URL, or an empty string when ``platform`` is neither of the two
        supported values.
    """
    if platform == "bedrock":
        action = "invoke-with-response-stream" if stream else "invoke"
        return f"https://{platform}-runtime.{aws_region}.amazonaws.com/model/{model}/{action}"

    if platform == "sagemaker":
        action = "invocations-response-stream" if stream else "invocations"
        return f"https://runtime.sagemaker.{aws_region}.amazonaws.com/endpoints/{model}/{action}"

    # Any other platform keeps the existing contract: empty string, no exception.
    return ""
|
|
|
|
|
|
def get_api_version(*, version: str) -> int:
    """Translate an API version path segment into its integer form.

    Args:
        version: Path segment such as ``"v1"`` or ``"v2"``.

    Returns:
        ``1`` for ``"v1"``, ``2`` for ``"v2"``, and ``1`` for any other value.
    """
    int_version = {
        "v1": 1,
        "v2": 2,
    }

    # Unknown segments deliberately fall back to version 1.
    return int_version.get(version, 1)