2025-05-16 18:00:22 +04:00

410 lines
13 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing
from typing import Any, Literal
import google.ai.generativelanguage as glm
from google.generativeai import operations
from google.generativeai.client import get_default_model_client
from google.generativeai.types import model_types
from google.api_core import operation
from google.api_core import protobuf_helpers
from google.protobuf import field_mask_pb2
def get_model(
    name: model_types.AnyModelNameOptions, *, client=None
) -> model_types.Model | model_types.TunedModel:
    """Given a model name, fetch the `types.Model` or `types.TunedModel` object.

    ```
    import pprint
    model = genai.get_model(model_name)
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch.
        client: The client to use.

    Returns:
        A `types.Model` or `types.TunedModel` object.

    Raises:
        ValueError: If the name does not start with `models/` or `tunedModels/`.
    """
    name = model_types.make_model_name(name)
    # Base and tuned models live in different API collections, so dispatch on
    # the resource prefix to the appropriate fetcher.
    if name.startswith("models/"):
        return get_base_model(name, client=client)
    elif name.startswith("tunedModels/"):
        return get_tuned_model(name, client=client)
    else:
        raise ValueError("Model names must start with `models/` or `tunedModels/`")
def get_base_model(name: model_types.BaseModelNameOptions, *, client=None) -> model_types.Model:
    """Get the `types.Model` for the given base model name.

    ```
    import pprint
    model = genai.get_model('models/chat-bison-001')
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch.
        client: The client to use.

    Returns:
        A `types.Model`.

    Raises:
        ValueError: If the name does not start with `models/`.
    """
    if client is None:
        client = get_default_model_client()
    name = model_types.make_model_name(name)
    if not name.startswith("models/"):
        raise ValueError(f"Base model names must start with `models/`, got: {name}")
    result = client.get_model(name=name)
    # Convert the proto message to a plain dict so it can populate the
    # `types.Model` dataclass.
    result = type(result).to_dict(result)
    return model_types.Model(**result)
def get_tuned_model(
    name: model_types.TunedModelNameOptions, *, client=None
) -> model_types.TunedModel:
    """Get the `types.TunedModel` for the given tuned model name.

    ```
    import pprint
    model = genai.get_tuned_model('tunedModels/my-model-1234')
    pprint.pprint(model)
    ```

    Args:
        name: The name of the model to fetch.
        client: The client to use.

    Returns:
        A `types.TunedModel`.

    Raises:
        ValueError: If the name does not start with `tunedModels/`.
    """
    if client is None:
        client = get_default_model_client()
    name = model_types.make_model_name(name)
    if not name.startswith("tunedModels/"):
        raise ValueError("Tuned model names must start with `tunedModels/`")
    result = client.get_tuned_model(name=name)
    return model_types.decode_tuned_model(result)
def get_base_model_name(
    model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
):
    """Resolve any accepted model reference to the name of its base model.

    Accepts a name string, a `types.Model`/`types.TunedModel`, or a
    `glm.Model`/`glm.TunedModel`, and returns the base-model name string.
    A `tunedModels/...` name string requires an API call to look up its base.
    """
    if isinstance(model, str):
        # Only tuned-model names need a lookup; a base-model name is already
        # the answer.
        if not model.startswith("tunedModels/"):
            return model
        fetched = get_model(model, client=client)
        return fetched.base_model
    if isinstance(model, model_types.TunedModel):
        return model.base_model
    if isinstance(model, (model_types.Model, glm.Model)):
        return model.name
    if isinstance(model, glm.TunedModel):
        # A tuned model created from another tuned model records its base
        # under `tuned_model_source` instead of `base_model`.
        direct = getattr(model, "base_model", None)
        return direct if direct else model.tuned_model_source.base_model
    raise TypeError(f"Cannot understand model: {model}")
def list_models(
    *,
    page_size: int | None = 50,
    client: glm.ModelServiceClient | None = None,
) -> model_types.ModelsIterable:
    """Lists available models.

    ```
    import pprint
    for model in genai.list_models():
        pprint.pprint(model)
    ```

    Args:
        page_size: How many `types.Models` to fetch per page (api call).
        client: You may pass a `glm.ModelServiceClient` instead of using the default client.

    Yields:
        `types.Model` objects.
    """
    if client is None:
        client = get_default_model_client()
    # The client handles pagination transparently; each item is a proto that
    # we convert to a dict before wrapping in the `types.Model` dataclass.
    for raw_model in client.list_models(page_size=page_size):
        as_dict = type(raw_model).to_dict(raw_model)
        yield model_types.Model(**as_dict)
def list_tuned_models(
    *,
    page_size: int | None = 50,
    client: glm.ModelServiceClient | None = None,
) -> model_types.TunedModelsIterable:
    """Lists available tuned models.

    ```
    import pprint
    for model in genai.list_tuned_models():
        pprint.pprint(model)
    ```

    Args:
        page_size: How many `types.TunedModels` to fetch per page (api call).
        client: You may pass a `glm.ModelServiceClient` instead of using the default client.

    Yields:
        `types.TunedModel` objects.
    """
    if client is None:
        client = get_default_model_client()
    for model in client.list_tuned_models(page_size=page_size):
        # Convert each proto to a dict, then decode into `types.TunedModel`.
        model = type(model).to_dict(model)
        yield model_types.decode_tuned_model(model)
def create_tuned_model(
    source_model: model_types.AnyModelNameOptions,
    training_data: model_types.TuningDataOptions,
    *,
    id: str | None = None,
    display_name: str | None = None,
    description: str | None = None,
    temperature: float | None = None,
    top_p: float | None = None,
    top_k: int | None = None,
    epoch_count: int | None = None,
    batch_size: int | None = None,
    learning_rate: float | None = None,
    input_key: str = "text_input",
    output_key: str = "output",
    client: glm.ModelServiceClient | None = None,
) -> operations.CreateTunedModelOperation:
    """Launches a tuning job to create a TunedModel.

    Since tuning a model can take significant time, this API doesn't wait for the tuning to complete.
    Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the
    status of the tuning job, or wait for it to complete, and check the result.

    After the job completes you can either find the resulting `TunedModel` object in
    `Operation.result()` or `palm.list_tuned_models` or `palm.get_tuned_model(model_id)`.

    ```
    my_id = "my-tuned-model-id"
    operation = palm.create_tuned_model(
        id = my_id,
        source_model="models/text-bison-001",
        training_data=[{'text_input': 'example input', 'output': 'example output'},...]
    )
    tuned_model=operation.result()  # Wait for tuning to finish
    palm.generate_text(f"tunedModels/{my_id}", prompt="...")
    ```

    Args:
        source_model: The name of the model to tune.
        training_data: The dataset to tune the model on. This must be either:
            * A `glm.Dataset`, or
            * An `Iterable` of:
                * `glm.TuningExample`,
                * {'text_input': text_input, 'output': output} dicts, or
                * `(text_input, output)` tuples.
            * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which
              columns to use as the input/output
            * A csv file (will be read with `pd.read_csv` and handles as a `Mapping`
              above). This can be:
                * A local path as a `str` or `pathlib.Path`.
                * A url for a csv file.
                * The url of a Google Sheets file.
            * A JSON file - Its contents will be handled either as an `Iterable` or `Mapping`
              above. This can be:
                * A local path as a `str` or `pathlib.Path`.
        id: The model identifier, used to refer to the model in the API
            `tunedModels/{id}`. Must be unique.
        display_name: A human-readable name for display.
        description: A description of the tuned model.
        temperature: The default temperature for the tuned model, see `types.Model` for details.
        top_p: The default `top_p` for the model, see `types.Model` for details.
        top_k: The default `top_k` for the model, see `types.Model` for details.
        epoch_count: The number of tuning epochs to run. An epoch is a pass over the whole dataset.
        batch_size: The number of examples to use in each training batch.
        learning_rate: The step size multiplier for the gradient updates.
        input_key: The key (or column) of `training_data` holding the inputs.
        output_key: The key (or column) of `training_data` holding the outputs.
        client: Which client to use.

    Returns:
        A [`google.api_core.operation.Operation`](https://googleapis.dev/python/google-api-core/latest/operation.html)

    Raises:
        ValueError: If the source model name is not a `models/` or `tunedModels/` name.
    """
    if client is None:
        client = get_default_model_client()

    source_model_name = model_types.make_model_name(source_model)
    # Forward the caller's client so the base-model lookup doesn't fall back
    # to a separately-constructed default client.
    base_model_name = get_base_model_name(source_model, client=client)
    if source_model_name.startswith("models/"):
        # Tuning directly from a base model.
        source_model = {"base_model": source_model_name}
    elif source_model_name.startswith("tunedModels/"):
        # Tuning from an existing tuned model: record both the tuned source
        # and its underlying base model.
        source_model = {
            "tuned_model_source": {
                "tuned_model": source_model_name,
                "base_model": base_model_name,
            }
        }
    else:
        # Bug fix: the original built this ValueError without raising it,
        # silently falling through with an unusable `source_model`.
        raise ValueError(f"Not understood: `{source_model=}`")

    training_data = model_types.encode_tuning_data(
        training_data, input_key=input_key, output_key=output_key
    )

    hyperparameters = glm.Hyperparameters(
        epoch_count=epoch_count,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    tuning_task = glm.TuningTask(
        training_data=training_data,
        hyperparameters=hyperparameters,
    )

    tuned_model = glm.TunedModel(
        **source_model,
        display_name=display_name,
        description=description,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        tuning_task=tuning_task,
    )
    operation = client.create_tuned_model(dict(tuned_model_id=id, tuned_model=tuned_model))

    return operations.CreateTunedModelOperation.from_core_operation(operation)
@typing.overload
def update_tuned_model(
    tuned_model: glm.TunedModel,
    updates: None = None,
    *,
    client: glm.ModelServiceClient | None = None,
) -> model_types.TunedModel:
    # Overload: pass an already-modified `glm.TunedModel`; the update mask is
    # computed by diffing it against the current server-side state.
    pass
@typing.overload
def update_tuned_model(
    tuned_model: str,
    updates: dict[str, Any],
    *,
    client: glm.ModelServiceClient | None = None,
) -> model_types.TunedModel:
    # Overload: pass the model name plus a (possibly nested) dict of field
    # updates; dotted field paths are derived from the dict's structure.
    pass
def update_tuned_model(
    tuned_model: str | glm.TunedModel,
    updates: dict[str, Any] | None = None,
    *,
    client: glm.ModelServiceClient | None = None,
) -> model_types.TunedModel:
    """Push updates to the tuned model. Only certain attributes are updatable."""
    if client is None:
        client = get_default_model_client()

    if isinstance(tuned_model, str):
        # Name + updates-dict form: fetch the current model, apply the updates
        # locally, and build the field mask from the dotted update paths.
        name = tuned_model
        if not isinstance(updates, dict):
            raise TypeError(
                "When calling `update_tuned_model(name:str, updates: dict)`,\n"
                "`updates` must be a `dict`.\n"
                f"got: {type(updates)}"
            )
        tuned_model = client.get_tuned_model(name=name)

        updates = _flatten_update_paths(updates)
        field_mask = field_mask_pb2.FieldMask()
        field_mask.paths.extend(updates)
        for path, value in updates.items():
            _apply_update(tuned_model, path, value)
    elif isinstance(tuned_model, glm.TunedModel):
        # Pre-modified proto form: derive the field mask by diffing against
        # the current server-side state.
        if updates is not None:
            raise ValueError(
                "When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`,"
                "`updates` must not be set."
            )
        name = tuned_model.name
        current = client.get_tuned_model(name=name)
        field_mask = protobuf_helpers.field_mask(current._pb, tuned_model._pb)
    else:
        raise TypeError(
            "For `update_tuned_model(tuned_model:dict|glm.TunedModel)`,"
            f"`tuned_model` must be a `dict` or a `glm.TunedModel`. Got a: `{type(tuned_model)}`"
        )

    result = client.update_tuned_model(
        glm.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask)
    )
    return model_types.decode_tuned_model(result)
def _flatten_update_paths(updates):
new_updates = {}
for key, value in updates.items():
if isinstance(value, dict):
for sub_key, sub_value in _flatten_update_paths(value).items():
new_updates[f"{key}.{sub_key}"] = sub_value
else:
new_updates[key] = value
return new_updates
def _apply_update(thing, path, value):
parts = path.split(".")
for part in parts[:-1]:
thing = getattr(thing, part)
setattr(thing, parts[-1], value)
def delete_tuned_model(
    tuned_model: model_types.TunedModelNameOptions,
    client: glm.ModelServiceClient | None = None,
) -> None:
    """Delete the given tuned model.

    Args:
        tuned_model: The tuned model to delete — a name string, a
            `types.TunedModel`, or a `glm.TunedModel`.
        client: You may pass a `glm.ModelServiceClient` instead of using the
            default client.
    """
    if client is None:
        client = get_default_model_client()
    # Normalize any accepted form to the canonical `tunedModels/...` name.
    name = model_types.make_model_name(tuned_model)
    client.delete_tuned_model(name=name)