# -*- coding: utf-8 -*-
|
|
# Copyright 2023 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from collections.abc import Iterable, Sequence, Mapping
|
|
import itertools
|
|
from typing import Iterable, overload, TypeVar, Union, Mapping
|
|
|
|
import google.ai.generativelanguage as glm
|
|
|
|
from google.generativeai.client import get_default_generative_client
|
|
|
|
from google.generativeai.types import text_types
|
|
from google.generativeai.types import model_types
|
|
from google.generativeai.types import content_types
|
|
|
|
# Default model to use when the caller does not name one.
DEFAULT_EMB_MODEL = "models/embedding-001"
# Maximum number of EmbedContentRequests allowed per BatchEmbedContents call.
EMBEDDING_MAX_BATCH_SIZE = 100

EmbeddingTaskType = glm.TaskType

# Anything a caller may pass as a task type: the enum itself, its integer
# value, or a string alias.
EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType]

# Lookup table mapping every accepted spelling of a task type (enum member,
# raw int value, canonical lowercase name, or short alias) to the enum member.
# Built from one row per task type so each type's aliases stay together.
_EMBEDDING_TASK_TYPE: dict[EmbeddingTaskTypeOptions, EmbeddingTaskType] = {}
for _task, _aliases in (
    (EmbeddingTaskType.TASK_TYPE_UNSPECIFIED, ("task_type_unspecified", "unspecified")),
    (EmbeddingTaskType.RETRIEVAL_QUERY, ("retrieval_query", "query")),
    (EmbeddingTaskType.RETRIEVAL_DOCUMENT, ("retrieval_document", "document")),
    (EmbeddingTaskType.SEMANTIC_SIMILARITY, ("semantic_similarity", "similarity")),
    (EmbeddingTaskType.CLASSIFICATION, ("classification",)),
    (EmbeddingTaskType.CLUSTERING, ("clustering",)),
):
    _EMBEDDING_TASK_TYPE[_task] = _task
    # proto enums are IntEnum-based, so the int key hashes equal to the enum
    # key (just as in a literal dict); kept for documentation parity.
    _EMBEDDING_TASK_TYPE[int(_task)] = _task
    for _alias in _aliases:
        _EMBEDDING_TASK_TYPE[_alias] = _task
del _task, _aliases, _alias
|
|
|
|
|
|
def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType:
    """Normalize `x` (enum member, int value, or string alias) to an `EmbeddingTaskType`.

    String aliases are matched case-insensitively. Raises `KeyError` for
    values that are not in the alias table.
    """
    key = x.lower() if isinstance(x, str) else x
    return _EMBEDDING_TASK_TYPE[key]
|
|
|
|
|
|
try:
|
|
# python 3.12+
|
|
_batched = itertools.batched # type: ignore
|
|
except AttributeError:
|
|
T = TypeVar("T")
|
|
|
|
def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
|
|
if n < 1:
|
|
raise ValueError(f"Batch size `n` must be >0, got: {n}")
|
|
batch = []
|
|
for item in iterable:
|
|
batch.append(item)
|
|
if len(batch) == n:
|
|
yield batch
|
|
batch = []
|
|
|
|
if batch:
|
|
yield batch
|
|
|
|
|
|
@overload
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType,
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    client: glm.GenerativeServiceClient | None = None,
) -> text_types.EmbeddingDict:
    # Single piece of content -> a single embedding dict
    # ({"embedding": [float, ...]}).
    ...
|
|
|
|
|
|
@overload
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    client: glm.GenerativeServiceClient | None = None,
) -> text_types.BatchEmbeddingDict:
    # Iterable of contents -> a batch embedding dict
    # ({"embedding": [[float, ...], ...]}), one entry per input, in order.
    ...
|
|
|
|
|
|
def embed_content(
    model: model_types.BaseModelNameOptions,
    content: content_types.ContentType | Iterable[content_types.ContentType],
    task_type: EmbeddingTaskTypeOptions | None = None,
    title: str | None = None,
    client: glm.GenerativeServiceClient | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
    """
    Calls the API to create embeddings for content passed in.

    Args:
        model: Which model to call, as a string or a `types.Model`.

        content: Content to embed, or an iterable of contents to batch-embed.

        task_type: Optional task type for which the embeddings will be used. Can only be set for `models/embedding-001`.

        title: An optional title for the text. Only applicable when task_type is `RETRIEVAL_DOCUMENT`.

        client: Optional low-level API client; a default client is created when omitted.

    Return:
        Dictionary containing the embedding (list of float values) for the input content.
        For a single content: `{"embedding": [...]}`; for an iterable of contents:
        `{"embedding": [[...], ...]}` with one entry per input, in order.

    Raises:
        ValueError: If `title` is given but `task_type` is not `RETRIEVAL_DOCUMENT`.
    """
    model = model_types.make_model_name(model)

    if client is None:
        client = get_default_generative_client()

    # Normalize the task type *before* validating `title`, so that a title
    # supplied without a task type raises the intended ValueError below
    # rather than a KeyError from `to_task_type(None)`.
    if task_type is not None:
        task_type = to_task_type(task_type)

    if title and task_type is not EmbeddingTaskType.RETRIEVAL_DOCUMENT:
        raise ValueError(
            "If a title is specified, the task must be a retrieval document type task."
        )

    # Strings and Mappings are themselves iterable, but count as a single
    # piece of content; only other iterables take the batch path.
    if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)):
        result = {"embedding": []}
        requests = (
            glm.EmbedContentRequest(
                model=model, content=content_types.to_content(c), task_type=task_type, title=title
            )
            for c in content
        )
        # Send requests in chunks to stay under the API's per-call limit,
        # flattening the per-batch results back into input order.
        for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
            embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch)
            embedding_response = client.batch_embed_contents(embedding_request)
            embedding_dict = type(embedding_response).to_dict(embedding_response)
            result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
        return result
    else:
        embedding_request = glm.EmbedContentRequest(
            model=model, content=content_types.to_content(content), task_type=task_type, title=title
        )
        embedding_response = client.embed_content(embedding_request)
        embedding_dict = type(embedding_response).to_dict(embedding_response)
        # Flatten {"embedding": {"values": [...]}} to {"embedding": [...]}.
        embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
        return embedding_dict
|