2025-05-16 18:00:22 +04:00

270 lines
11 KiB
Python

"""
This class allows for programmatic interactions with Atlas - Nomic's neural database. Initialize AtlasClient in any Python context such as a script
or in a Jupyter Notebook to organize and interact with your unstructured data.
"""
import uuid
from typing import Dict, List, Optional
import numpy as np
from loguru import logger
from tqdm import tqdm
from .project import AtlasProject
from .settings import *
from .utils import b64int, get_random_name
def map_embeddings(
    embeddings: np.ndarray,
    data: Optional[List[Dict]] = None,
    id_field: Optional[str] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    is_public: bool = True,
    colorable_fields: Optional[list] = None,
    build_topic_model: bool = True,
    topic_label_field: Optional[str] = None,
    num_workers: Optional[int] = None,
    organization_name: Optional[str] = None,
    reset_project_if_exists: bool = False,
    add_datums_if_exists: bool = False,
    shard_size: Optional[int] = None,
    projection_n_neighbors: int = DEFAULT_PROJECTION_N_NEIGHBORS,
    projection_epochs: int = DEFAULT_PROJECTION_EPOCHS,
    projection_spread: float = DEFAULT_PROJECTION_SPREAD,
) -> AtlasProject:
    '''
    Generates or updates a map of the given embeddings.

    The embeddings (cast to float16) and their metadata are uploaded to an
    AtlasProject with modality 'embedding'. If the project held no datums
    before the upload, a new index is created with the given index options;
    otherwise the index options are not used and the project's existing maps
    are rebuilt.

    Note: when IDs are generated for metadata you supplied, the dictionaries in
    `data` are modified in place.

    Args:
        embeddings: An [N,d] numpy array containing the batch of N embeddings to add.
        data: An [N,] element list of dictionaries containing metadata for each embedding. If not specified, metadata holding only a generated ID is created for every embedding.
        id_field: Specify your data unique id field. This field can be up to 36 characters in length. If not specified, one will be created for you named `id_`.
        name: A name for your map. If not specified, a random name is generated.
        description: A description for your map.
        is_public: Should this embedding map be public? Private maps can only be accessed by members of your organization.
        colorable_fields: The project fields you want to be able to color by on the map. Must be a subset of the projects fields.
        build_topic_model: Builds a hierarchical topic model over your data to discover patterns.
        topic_label_field: The metadata field to estimate topic labels from. Usually the field you embedded.
        num_workers: Deprecated. Ignored apart from logging a warning.
        organization_name: The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user account's default organization.
        reset_project_if_exists: If the specified project exists in your organization, reset it by deleting all of its data. This means your uploaded data will not be contextualized with existing data.
        add_datums_if_exists: If specifying an existing project and you want to add data to it, set this to true.
        shard_size: Deprecated. Ignored apart from logging a warning.
        projection_n_neighbors: The number of neighbors to build.
        projection_epochs: The number of epochs to build the map with.
        projection_spread: The spread of the map.

    Returns:
        An AtlasProject that now contains your map.

    Raises:
        AssertionError: If `embeddings` is not a numpy array.
        ValueError: If `embeddings` is empty, or `data` is an empty list.
    '''
    # Raised explicitly (rather than via `assert`) so the check survives
    # `python -O`; the exception type is kept for backward compatibility.
    if not isinstance(embeddings, np.ndarray):
        raise AssertionError('You must pass in a numpy array')

    if embeddings.size == 0:
        raise ValueError("Your embeddings cannot be empty")

    # Avoid a shared mutable default argument.
    if colorable_fields is None:
        colorable_fields = []

    if id_field is None:
        id_field = ATLAS_DEFAULT_ID_FIELD

    if description is None:
        description = 'A description for your map.'

    # The project and its index share a name: the caller's, or a random one.
    project_name = name if name else get_random_name()
    index_name = project_name

    added_id_field = False
    if data is None:
        # Key the generated IDs by `id_field` so they always match the
        # project's unique_id_field, including when a custom field is given.
        data = [{id_field: b64int(i)} for i in range(len(embeddings))]
        added_id_field = True
    elif not data:
        # Fail with a clear message before any project is created or reset.
        raise ValueError("Your data cannot be empty when embeddings are provided")
    elif id_field == ATLAS_DEFAULT_ID_FIELD and id_field not in data[0]:
        # Only the first datum is inspected to decide whether IDs are missing.
        for i, datum in enumerate(data):
            datum[id_field] = b64int(i)
        added_id_field = True

    if added_id_field:
        logger.warning("An ID field was not specified in your data so one was generated for you in insertion order.")

    project = AtlasProject(
        name=project_name,
        description=description,
        unique_id_field=id_field,
        modality='embedding',
        is_public=is_public,
        organization_name=organization_name,
        reset_project_if_exists=reset_project_if_exists,
        add_datums_if_exists=add_datums_if_exists,
    )

    # Remembered so we can tell a brand-new project from one being extended.
    number_of_datums_before_upload = project.total_datums

    logger.info("Uploading embeddings to Atlas.")

    embeddings = embeddings.astype(np.float16)

    if shard_size is not None:
        logger.warning("Passing `shard_size` is deprecated and will raise an error in a future release")
    if num_workers is not None:
        logger.warning("Passing `num_workers` is deprecated and will raise an error in a future release")

    try:
        project.add_embeddings(
            embeddings=embeddings,
            data=data,
        )
    except BaseException:
        # BaseException so that an interrupted upload is cleaned up too.
        # A project that was empty before this call is removed.
        if number_of_datums_before_upload == 0:
            logger.info(f"{project.name}: Deleting project due to failure in initial upload.")
            project.delete()
        raise

    logger.info("Embedding upload succeeded.")

    if number_of_datums_before_upload == 0:
        # First upload into this project: build a new index.
        projection = project.create_index(
            name=index_name,
            colorable_fields=colorable_fields,
            build_topic_model=build_topic_model,
            projection_n_neighbors=projection_n_neighbors,
            projection_epochs=projection_epochs,
            projection_spread=projection_spread,
            topic_label_field=topic_label_field,
        )
        logger.info(str(projection))
    else:
        # The project already had data: refresh its existing maps instead.
        project.rebuild_maps()

    return project._latest_project_state()
def map_text(
    data: List[Dict],
    indexed_field: str,
    id_field: Optional[str] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
    build_topic_model: bool = True,
    multilingual: bool = False,
    is_public: bool = True,
    colorable_fields: Optional[list] = None,
    num_workers: Optional[int] = None,
    organization_name: Optional[str] = None,
    reset_project_if_exists: bool = False,
    add_datums_if_exists: bool = False,
    shard_size: Optional[int] = None,
    projection_n_neighbors: int = DEFAULT_PROJECTION_N_NEIGHBORS,
    projection_epochs: int = DEFAULT_PROJECTION_EPOCHS,
    projection_spread: float = DEFAULT_PROJECTION_SPREAD,
    duplicate_detection: bool = False,
    duplicate_threshold: float = DEFAULT_DUPLICATE_THRESHOLD,
) -> AtlasProject:
    '''
    Generates or updates a map of the given text.

    The data is uploaded to an AtlasProject with modality 'text'. If the
    project held no datums before the upload, a new index is created over
    `indexed_field` with the given index options; otherwise the index options
    are not used and the project's existing maps are rebuilt.

    Note: when IDs are generated, the dictionaries in `data` are modified in place.

    Args:
        data: An [N,] element list of dictionaries containing metadata for each embedding.
        indexed_field: The name of the data field containing the text you want to map.
        id_field: Specify your data unique id field. This field can be up to 36 characters in length. If not specified, one will be created for you named `id_`.
        name: A name for your map. If not specified, a random name is generated.
        description: A description for your map.
        build_topic_model: Builds a hierarchical topic model over your data to discover patterns.
        multilingual: Should the map take language into account? If true, points from different languages with semantically similar text are considered similar.
        is_public: Should this embedding map be public? Private maps can only be accessed by members of your organization.
        colorable_fields: The project fields you want to be able to color by on the map. Must be a subset of the projects fields.
        num_workers: Deprecated. Ignored apart from logging a warning.
        organization_name: The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user account's default organization.
        reset_project_if_exists: If the specified project exists in your organization, reset it by deleting all of its data. This means your uploaded data will not be contextualized with existing data.
        add_datums_if_exists: If specifying an existing project and you want to add data to it, set this to true.
        shard_size: Deprecated. Ignored apart from logging a warning.
        projection_n_neighbors: The number of neighbors to build.
        projection_epochs: The number of epochs to build the map with.
        projection_spread: The spread of the map.
        duplicate_detection: Forwarded to the index creation call. Presumably toggles duplicate detection; confirm against `AtlasProject.create_index`.
        duplicate_threshold: Forwarded to the index creation call. Presumably the threshold used by duplicate detection; confirm against `AtlasProject.create_index`.

    Returns:
        The AtlasProject containing your map.

    Raises:
        ValueError: If `data` is empty.
    '''
    # Validate before touching the server: previously an empty `data` only
    # failed (IndexError) after the project had been created or reset.
    if not data:
        raise ValueError("Your data cannot be empty")

    # Avoid a shared mutable default argument.
    if colorable_fields is None:
        colorable_fields = []

    if id_field is None:
        id_field = ATLAS_DEFAULT_ID_FIELD

    if description is None:
        description = 'A description for your map.'

    # The project and its index share a name: the caller's, or a random one.
    project_name = name if name else get_random_name()
    index_name = project_name

    project = AtlasProject(
        name=project_name,
        description=description,
        unique_id_field=id_field,
        modality='text',
        is_public=is_public,
        organization_name=organization_name,
        reset_project_if_exists=reset_project_if_exists,
        add_datums_if_exists=add_datums_if_exists,
    )

    # Only the first datum is inspected to decide whether IDs are missing.
    if id_field == ATLAS_DEFAULT_ID_FIELD and id_field not in data[0]:
        for i, datum in enumerate(data):
            datum[id_field] = b64int(i)
        logger.warning("An ID field was not specified in your data so one was generated for you in insertion order.")

    project._validate_map_data_inputs(colorable_fields=colorable_fields, id_field=id_field, data=data)

    # Remembered so we can tell a brand-new project from one being extended.
    number_of_datums_before_upload = project.total_datums

    logger.info("Uploading text to Atlas.")
    if shard_size is not None:
        logger.warning("Passing 'shard_size' is deprecated and will be removed in a future release.")
    if num_workers is not None:
        logger.warning("Passing 'num_workers' is deprecated and will be removed in a future release.")

    try:
        # The caller's `shard_size` is deliberately not forwarded (deprecated).
        project.add_text(
            data,
            shard_size=None,
        )
    except BaseException:
        # BaseException so that an interrupted upload is cleaned up too.
        # A project that was empty before this call is removed.
        if number_of_datums_before_upload == 0:
            logger.info(f"{project.name}: Deleting project due to failure in initial upload.")
            project.delete()
        raise

    logger.info("Text upload succeeded.")

    if number_of_datums_before_upload == 0:
        # First upload into this project: build a new index.
        projection = project.create_index(
            name=index_name,
            indexed_field=indexed_field,
            colorable_fields=colorable_fields,
            build_topic_model=build_topic_model,
            projection_n_neighbors=projection_n_neighbors,
            projection_epochs=projection_epochs,
            projection_spread=projection_spread,
            multilingual=multilingual,
            duplicate_detection=duplicate_detection,
            duplicate_threshold=duplicate_threshold,
        )
        logger.info(str(projection))
    else:
        # The project already had data: refresh its existing maps instead.
        project.rebuild_maps()

    return project._latest_project_state()