import base64 import concurrent import concurrent.futures import io import json import os import pickle import time import uuid from collections import defaultdict from contextlib import contextmanager from datetime import date, datetime from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import pandas as pd import pyarrow as pa import requests from loguru import logger from pandas import DataFrame from pyarrow import compute as pc from pyarrow import feather, ipc from pydantic import BaseModel, Field from tqdm import tqdm import nomic from .cli import refresh_bearer_token, validate_api_http_response from .data_inference import convert_pyarrow_schema_for_atlas from .data_operations import AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics from .settings import * from .utils import assert_valid_project_id, get_object_size_in_bytes class AtlasUser: def __init__(self): self.credentials = refresh_bearer_token() class AtlasClass(object): def __init__(self): ''' Initializes the Atlas client. 
''' if self.credentials['tenant'] == 'staging': api_hostname = 'staging-api-atlas.nomic.ai' web_hostname = 'staging-atlas.nomic.ai' elif self.credentials['tenant'] == 'production': api_hostname = 'api-atlas.nomic.ai' web_hostname = 'atlas.nomic.ai' else: raise ValueError("Invalid tenant.") self.atlas_api_path = f"https://{api_hostname}" self.web_path = f"https://{web_hostname}" try: override_api_path = os.environ['ATLAS_API_PATH'] except KeyError: override_api_path = None if override_api_path: self.atlas_api_path = override_api_path token = self.credentials['token'] self.token = token self.header = {"Authorization": f"Bearer {token}"} if self.token: response = requests.get( self.atlas_api_path + "/v1/user", headers=self.header, ) response = validate_api_http_response(response) if not response.status_code == 200: logger.warning(str(response)) logger.info("Your authorization token is no longer valid.") else: raise ValueError( "Could not find an authorization token. Run `nomic login` to authorize this client with the Nomic API." ) @property def credentials(self): return refresh_bearer_token() def _get_current_user(self): api_base_path = self.atlas_api_path if self.atlas_api_path.startswith('https://api-atlas.nomic.ai'): api_base_path = "https://no-cdn-api-atlas.nomic.ai" response = requests.get( api_base_path + "/v1/user", headers=self.header, ) response = validate_api_http_response(response) if not response.status_code == 200: raise ValueError("Your authorization token is no longer valid. 
Run `nomic login` to obtain a new one.") return response.json() def _validate_map_data_inputs(self, colorable_fields, id_field, data): '''Validates inputs to map data calls.''' if not isinstance(colorable_fields, list): raise ValueError("colorable_fields must be a list of fields") if id_field in colorable_fields: raise Exception(f'Cannot color by unique id field: {id_field}') for field in colorable_fields: if field not in data[0]: raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.") def _get_current_users_main_organization(self): ''' Retrieves the ID of the current users default organization. **Returns:** The ID of the current users default organization ''' user = self._get_current_user() for organization in user['organizations']: if organization['user_id'] == user['sub'] and organization['access_role'] == 'OWNER': return organization def _delete_project_by_id(self, project_id): response = requests.post( self.atlas_api_path + "/v1/project/remove", headers=self.header, json={'project_id': project_id}, ) def _get_project_by_id(self, project_id: str): ''' Args: project_id: The project id Returns: Returns the requested project. ''' assert_valid_project_id(project_id) response = requests.get( self.atlas_api_path + f"/v1/project/{project_id}", headers=self.header, ) if response.status_code != 200: raise Exception(f"Could not access project with id {project_id}: {response.text}") return response.json() def _get_index_job(self, job_id: str): ''' Args: job_id: The job id to retrieve the state of. Returns: Job ID meta-data. ''' response = requests.get( self.atlas_api_path + f"/v1/project/index/job/{job_id}", headers=self.header, ) if response.status_code != 200: raise Exception(f'Could not access job state: {response.text}') return response.json() def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasProject") -> pa.Table: ''' Private method. 
validates upload data against the project arrow schema, and associated other checks. 1. If unique_id_field is specified, validates that each datum has that field. If not, adds it and then notifies the user that it was added. Args: data: an arrow table. project: the atlas project you are validating the data for. Returns: ''' if not isinstance(data, pa.Table): raise Exception("Invalid data type for upload: {}".format(type(data))) if project.meta['modality'] == 'text': if "_embeddings" in data: msg = "Can't add embeddings to a text project." raise ValueError(msg) if project.meta['modality'] == 'embedding': if "_embeddings" not in data.column_names: msg = "Must include embeddings in embedding project upload." raise ValueError(msg) if project.id_field not in data.column_names: raise ValueError(f'Data must contain the ID column `{project.id_field}`') if project.schema is None: project._schema = convert_pyarrow_schema_for_atlas(data.schema) # Reformat to match the schema of the project. # This includes shuffling the order around if necessary, # filling in nulls, etc. reformatted = {} if data[project.id_field].null_count > 0: raise ValueError( f"{project.id_field} must not contain null values, but {data[project.id_field].null_count} found." ) for field in project.schema: if field.name in data.column_names: # Allow loss of precision in dates and ints, etc. reformatted[field.name] = data[field.name].cast(field.type, safe=False) else: raise KeyError( f"Field {field.name} present in table schema not found in data. Present fields: {data.column_names}" ) if pa.types.is_string(field.type): # Ugly temporary measures if data[field.name].null_count > 0: logger.warning( f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version." 
) reformatted[field.name] = pc.fill_null(reformatted[field.name], "null") if pc.any(pc.equal(pc.binary_length(reformatted[field.name]), 0)): mask = pc.equal(pc.binary_length(reformatted[field.name]), 0).combine_chunks() assert pa.types.is_boolean(mask.type) reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null") for field in data.schema: if not field.name in reformatted: if field.name == "_embeddings": reformatted['_embeddings'] = data['_embeddings'] else: logger.warning(f"Field {field.name} present in data, but not found in table schema. Ignoring.") data = pa.Table.from_pydict(reformatted, schema=project.schema) if project.meta['insert_update_delete_lock']: raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.") # The following two conditions should never occur given the above, but just in case... assert project.id_field in data.column_names, f"Upload does not contain your specified id_field" if not pa.types.is_string(data[project.id_field].type): logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}") data = data.drop([project.id_field]).append_column( project.id_field, data[project.id_field].cast(pa.string()) ) for key in data.column_names: if key.startswith('_'): if key == '_embeddings': continue raise ValueError('Metadata fields cannot start with _') if pc.max(pc.utf8_length(data[project.id_field])).as_py() > 36: first_match = data.filter(pc.greater(pc.utf8_length(data[project.id_field]), 36)).to_pylist()[0][ project.id_field ] raise ValueError( f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters." ) return data def _get_organization(self, organization_name=None, organization_id=None) -> Tuple[str, str]: ''' Gets an organization by either it's name or id. 
Args: organization_name: the name of the organization organization_id: the id of the organization Returns: The organization_name and organization_id if one was found. ''' if organization_name is None: if organization_id is None: # default to current users organization (the one with their name) organization = self._get_current_users_main_organization() organization_name = organization['nickname'] organization_id = organization['organization_id'] else: raise NotImplementedError("Getting organization by a specific ID is not yet implemented.") else: organization_id_request = requests.get( self.atlas_api_path + f"/v1/organization/search/{organization_name}", headers=self.header ) if organization_id_request.status_code != 200: user = self._get_current_user() users_organizations = [org['nickname'] for org in user['organizations']] raise Exception( f"No such organization exists: {organization_name}. You have access to the following organizations: {users_organizations}" ) organization_id = organization_id_request.json()['organization_id'] return organization_name, organization_id def _get_existing_project_by_name(self, project_name, organization_name) -> Dict: ''' Utility method for instantiating an AtlasProject. Retrieves an existing project by name in a given organization. Fail Args: project_name: the project name organization_name: the organization name Returns: A dictionary containing the project_id, organization_id and organization_name ''' # check if this project already exists. 
response = requests.post( self.atlas_api_path + "/v1/project/search/name", headers=self.header, json={'organization_name': organization_name, 'project_name': project_name}, ) if response.status_code != 200: raise Exception(f"Failed to find project: {response.text}") search_results = response.json()['results'] if search_results: existing_project = search_results[0] existing_project_id = existing_project['id'] return { 'project_id': existing_project_id, 'organization_name': existing_project['owner'], } organization_name, organization_id = self._get_organization(organization_name=organization_name) return {'organization_id': organization_id, 'organization_name': organization_name} class AtlasIndex: """ An AtlasIndex represents a single view of an Atlas Project at a point in time. An AtlasIndex typically contains one or more *projections* which are 2D representations of the points in the index that you can browse online. """ def __init__(self, atlas_index_id, name, indexed_field, projections): '''Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.''' self.id = atlas_index_id self.name = name self.indexed_field = indexed_field self.projections = projections def _repr_html_(self): return '
'.join([d._repr_html_() for d in self.projections]) class AtlasProjection: ''' Interact and access state of an Atlas Map including text/vector search. This class should not be instantiated directly. Instead instantiate an AtlasProject and use the project.maps attribute to retrieve an AtlasProjection. ''' def __init__(self, project: "AtlasProject", atlas_index_id: str, projection_id: str, name): """ Creates an AtlasProjection. """ self.project = project self.id = projection_id self.atlas_index_id = atlas_index_id self.projection_id = projection_id self.name = name self._duplicates = None self._embeddings = None self._topics = None self._tags = None self._tile_data = None @property def map_link(self): ''' Retrieves a map link. ''' return f"{self.project.web_path}/map/{self.project.id}/{self.id}" @property def _status(self): response = requests.get( self.project.atlas_api_path + f"/v1/project/index/job/progress/{self.atlas_index_id}", headers=self.project.header, ) if response.status_code != 200: raise Exception(response.text) content = response.json() return content def __str__(self): return f"{self.name}: {self.map_link}" def __repr__(self): return self.__str__() def _iframe(self): return f""" """ def _embed_html(self): return f"""
Hide embedded project
Explore on atlas.nomic.ai
{self._iframe()} """ def _repr_html_(self): # Don't make an iframe if the project is locked. state = self._status['index_build_stage'] if state != 'Completed': return f"""Atlas Projection {self.name}. Status {state}. view online""" return f"""

Project: {self.name}

{self._embed_html()} """ @property def duplicates(self): """Duplicate detection state""" if self.project.is_locked: raise Exception('Project is locked! Please wait until the project is unlocked to access duplicates.') if self._duplicates is None: self._duplicates = AtlasMapDuplicates(self) return self._duplicates @property def topics(self): """Topic state""" if self.project.is_locked: raise Exception('Project is locked for state access! Please wait until the project is unlocked to access topics.') if self._topics is None: self._topics = AtlasMapTopics(self) return self._topics @property def embeddings(self): """Embedding state""" if self.project.is_locked: raise Exception('Project is locked for state access! Please wait until the project is unlocked to access embeddings.') if self._embeddings is None: self._embeddings = AtlasMapEmbeddings(self) return self._embeddings @property def tags(self): """Embedding state""" if self._tags is None: self._tags = AtlasMapTags(self) return self._tags def _fetch_tiles(self, overwrite: bool = True): """ Downloads all web data for the projection to the specified directory and returns it as a memmapped arrow table. Args: overwrite: If True then overwrite web tile files. Returns: An Arrow table containing information for all data points in the index. 
""" if self._tile_data is not None: return self._tile_data self._download_feather(overwrite=overwrite) tbs = [] root = feather.read_table(self.tile_destination / "0/0/0.feather") try: sidecars = set([v for k, v in json.loads(root.schema.metadata[b'sidecars']).items()]) except KeyError: sidecars = [] for path in self.tile_destination.glob('**/*.feather'): if len(path.stem.split(".")) > 1: # Sidecars are loaded alongside continue tb = pa.feather.read_table(path) for sidecar_file in sidecars: carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather") for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) self._tile_data = pa.concat_tables(tbs) return self._tile_data @property def tile_destination(self): return Path("~/.nomic/cache", self.id).expanduser() def _download_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True): ''' Downloads the feather tree. Args: overwrite: if True then overwrite existing feather files. Returns: A list containing all quadtiles downloads ''' self.tile_destination.mkdir(parents=True, exist_ok=True) root = f'{self.project.atlas_api_path}/v1/project/public/{self.project.id}/index/projection/{self.id}/quadtree/' quads = [f'0/0/0'] all_quads = [] sidecars = None while len(quads) > 0: rawquad = quads.pop(0) quad = rawquad + ".feather" all_quads.append(quad) path = self.tile_destination / quad if not path.exists() or overwrite: data = requests.get(root + quad) readable = io.BytesIO(data.content) readable.seek(0) tb = feather.read_table(readable) path.parent.mkdir(parents=True, exist_ok=True) feather.write_feather(tb, path) schema = ipc.open_file(path).schema if sidecars is None and b'sidecars' in schema.metadata: # Grab just the filenames sidecars = set([v for k, v in json.loads(schema.metadata.get(b'sidecars')).items()]) elif sidecars is None: sidecars = set() if not "." 
in rawquad: for sidecar in sidecars: # The sidecar loses the feather suffix because it's supposed to be raw. quads.append(quad.replace(".feather", f'.{sidecar}')) if not schema.metadata or b'children' not in schema.metadata: # Sidecars don't have children. continue kids = schema.metadata.get(b'children') children = json.loads(kids) quads.extend(children) return all_quads @property def datum_id_field(self): return self.project.meta["unique_id_field"] def _get_atoms(self, ids: List[str]) -> List[Dict]: ''' Retrieves atoms by id Args: ids: list of atom ids Returns: A dictionary containing the resulting atoms, keyed by atom id. ''' if not isinstance(ids, list): raise ValueError("You must specify a list of ids when getting data.") response = requests.post( self.project.atlas_api_path + "/v1/project/atoms/get", headers=self.project.header, json={'project_id': self.project.id, 'index_id': self.atlas_index_id, 'atom_ids': ids}, ) if response.status_code == 200: return response.json()['atoms'] else: raise Exception(response.text) class AtlasProject(AtlasClass): def __init__( self, name: Optional[str] = None, description: Optional[str] = 'A description for your map.', unique_id_field: Optional[str] = None, modality: Optional[str] = None, organization_name: Optional[str] = None, is_public: bool = True, project_id=None, reset_project_if_exists=False, add_datums_if_exists=True, ): """ Creates or loads an Atlas project. Atlas projects store data (text, embeddings, etc) that you can organize by building indices. If the organization already contains a project with this name, it will be returned instead. **Parameters:** * **project_name** - The name of the project. * **description** - A description for the project. * **unique_id_field** - The field that uniquely identifies each datum. If a datum does not contain this field, it will be added and assigned a random unique ID. * **modality** - The data modality of this project. 
Currently, Atlas supports either `text` or `embedding` modality projects. * **organization_name** - The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user account's default organization. * **is_public** - Should this project be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view. * **reset_project_if_exists** - If the requested project exists in your organization, will delete it and re-create it. * **project_id** - An alternative way to retrieve a project is by passing the project_id directly. This only works if a project exists. * **reset_project_if_exists** - If the specified project exists in your organization, reset it by deleting all of its data. This means your uploaded data will not be contextualized with existing data. * **add_datums_if_exists** - If specifying an existing project and you want to add data to it, set this to true. """ assert name is not None or project_id is not None, "You must pass a name or project_id" super().__init__() if project_id is not None: self.meta = self._get_project_by_id(project_id) return if organization_name is None: organization_name = self._get_current_users_main_organization()['nickname'] results = self._get_existing_project_by_name(project_name=name, organization_name=organization_name) if 'project_id' in results: # project already exists organization_name = results['organization_name'] project_id = results['project_id'] if reset_project_if_exists: # reset the project logger.info( f"Found existing project `{name}` in organization `{organization_name}`. Clearing it of data by request." ) self._delete_project_by_id(project_id=project_id) project_id = None elif not add_datums_if_exists: # prevent adding datums to existing project explicitly raise ValueError( f"Project already exists with the name `{name}` in organization `{organization_name}`. 
" f"You can add datums to it by settings `add_datums_if_exists = True` or reset it by specifying `reset_project_if_exists=True` on a new upload." ) else: logger.info(f"Loading existing project `{name}` from organization `{organization_name}`.") if project_id is None: # if there is no existing project, make a new one. if unique_id_field is None: unique_id_field = ATLAS_DEFAULT_ID_FIELD raise ValueError("You must specify a unique_id_field when creating a new project.") if modality is None: raise ValueError("You must specify a modality when creating a new project.") assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`" assert name is not None project_id = self._create_project( project_name=name, description=description, unique_id_field=unique_id_field, modality=modality, organization_name=organization_name, is_public=is_public, ) self.meta = self._get_project_by_id(project_id=project_id) self._schema = None def delete(self): ''' Deletes an atlas project with all associated metadata. ''' organization = self._get_current_users_main_organization() organization_name = organization['nickname'] logger.info(f"Deleting project `{self.name}` from organization `{organization_name}`") self._delete_project_by_id(project_id=self.id) return False def _create_project( self, project_name: str, description: Optional[str], unique_id_field: str, modality: str, organization_name: Optional[str] = None, is_public: bool = True, ): ''' Creates an Atlas project. Atlas projects store data (text, embeddings, etc) that you can organize by building indices. If the organization already contains a project with this name, it will be returned instead. **Parameters:** * **project_name** - The name of the project. * **description** - A description for the project. * **unique_id_field** - The field that uniquely identifies each datum. If a datum does not contain this field, it will be added and assigned a random unique ID. 
* **modality** - The data modality of this project. Currently, Atlas supports either `text` or `embedding` modality projects. * **organization_name** - The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user accounts default organization. * **is_public** - Should this project be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view. **Returns:** project_id on success. ''' organization_name, organization_id = self._get_organization(organization_name=organization_name) supported_modalities = ['text', 'embedding'] if modality not in supported_modalities: msg = 'Tried to create project with modality: {}, but Atlas only supports: {}'.format( modality, supported_modalities ) raise ValueError(msg) if unique_id_field is None: raise ValueError("You must specify a unique id field") logger.info(f"Creating project `{project_name}` in organization `{organization_name}`") if description is None: description = "" response = requests.post( self.atlas_api_path + "/v1/project/create", headers=self.header, json={ 'organization_id': organization_id, 'project_name': project_name, 'description': description, 'unique_id_field': unique_id_field, 'modality': modality, 'is_public': is_public, }, ) if response.status_code != 201: raise Exception(f"Failed to create project: {response.json()}") return response.json()['project_id'] def _latest_project_state(self): ''' Refreshes the project's state. Try to call this sparingly but use it when you need it. 
''' self.meta = self._get_project_by_id(self.id) return self @property def indices(self) -> List[AtlasIndex]: self._latest_project_state() output = [] for index in self.meta['atlas_indices']: projections = [] for projection in index['projections']: projection = AtlasProjection( project=self, projection_id=projection['id'], atlas_index_id=index['id'], name=index['index_name'] ) projections.append(projection) index = AtlasIndex( atlas_index_id=index['id'], name=index['index_name'], indexed_field=index['indexed_field'], projections=projections, ) output.append(index) return output @property def projections(self) -> List[AtlasProjection]: output = [] for index in self.indices: for projection in index.projections: output.append(projection) return output @property def maps(self) -> List[AtlasProjection]: return self.projections @property def id(self) -> str: '''The UUID of the project.''' return self.meta['id'] @property def id_field(self) -> str: return self.meta['unique_id_field'] @property def total_datums(self) -> int: '''The total number of data points in the project.''' return self.meta['total_datums_in_project'] @property def modality(self) -> str: return self.meta['modality'] @property def name(self) -> str: '''The name of the project.''' return self.meta['project_name'] @property def description(self): return self.meta['description'] @property def project_fields(self): return self.meta['project_fields'] @property def is_locked(self) -> bool: self._latest_project_state() return self.meta['insert_update_delete_lock'] @property def schema(self) -> Optional[pa.Schema]: if self._schema is not None: return self._schema if 'schema' in self.meta and self.meta['schema'] is not None: self._schema: pa.Schema = ipc.read_schema(io.BytesIO(base64.b64decode(self.meta['schema']))) return self._schema return None @property def is_accepting_data(self) -> bool: ''' Checks if the project can accept data. Projects cannot accept data when they are being indexed. 
Returns: True if project is unlocked for data additions, false otherwise. ''' return not self.is_locked @contextmanager def wait_for_project_lock(self): '''Blocks thread execution until project is in a state where it can ingest data.''' has_logged = False while True: if self.is_accepting_data: yield self break if not has_logged: logger.info(f"{self.name}: Waiting for Project Lock Release.") has_logged = True time.sleep(5) def get_map(self, name: str = None, atlas_index_id: str = None, projection_id: str = None) -> AtlasProjection: ''' Retrieves a Map Args: name: The name of your map. This defaults to your projects name but can be different if you build multiple maps in your project. atlas_index_id: If specified, will only return a map if there is one built under the index with the id atlas_index_id. projection_id: If projection_id is specified, will only return a map if there is one built under the index with id projection_id. Returns: The map or a ValueError. ''' indices = self.indices if atlas_index_id is not None: for index in indices: if index.id == atlas_index_id: if len(index.projections) == 0: raise ValueError(f"No map found under index with atlas_index_id='{atlas_index_id}'") return index.projections[0] raise ValueError(f"Could not find a map with atlas_index_id='{atlas_index_id}'") if projection_id is not None: for index in indices: for projection in index.projections: if projection.id == projection_id: return projection raise ValueError(f"Could not find a map with projection_id='{atlas_index_id}'") if len(indices) == 0: raise ValueError("You have no maps built in your project") if len(indices) > 1 and name is None: raise ValueError("You have multiple maps in this project, specify a name.") if len(indices) == 1: if len(indices[0].projections) == 1: return indices[0].projections[0] for index in indices: if index.name == name: return index.projections[0] raise ValueError(f"Could not find a map named {name} in your project.") def create_index( self, name: 
str, indexed_field: str = None, colorable_fields: list = [], multilingual: bool = False, build_topic_model: bool = False, projection_n_neighbors: int = DEFAULT_PROJECTION_N_NEIGHBORS, projection_epochs: int = DEFAULT_PROJECTION_EPOCHS, projection_spread: float = DEFAULT_PROJECTION_SPREAD, topic_label_field: str = None, reuse_embeddings_from_index: str = None, duplicate_detection: bool = False, duplicate_threshold: float = DEFAULT_DUPLICATE_THRESHOLD, ) -> AtlasProjection: ''' Creates an index in the specified project. Args: name: The name of the index and the map. indexed_field: For text projects, name the data field corresponding to the text to be mapped. colorable_fields: The project fields you want to be able to color by on the map. Must be a subset of the projects fields. multilingual: Should the map take language into account? If true, points from different languages but semantically similar text are close together. build_topic_model: Should a topic model be built? projection_n_neighbors: A projection hyperparameter projection_epochs: A projection hyperparameter projection_spread: A projection hyperparameter topic_label_field: A text field in your metadata to estimate topic labels from. Defaults to the indexed_field for text projects if not specified. reuse_embeddings_from_index: the name of the index to reuse embeddings from. duplicate_detection: A boolean whether to run duplicate detection duplicate_threshold: At which threshold to consider points to be duplicates Returns: The projection this index has built. ''' self._latest_project_state() # for large projects, alter the default projection configurations. 
if self.total_datums >= 1_000_000: if ( projection_epochs == DEFAULT_PROJECTION_EPOCHS and projection_n_neighbors == DEFAULT_PROJECTION_N_NEIGHBORS ): projection_n_neighbors = DEFAULT_LARGE_PROJECTION_N_NEIGHBORS projection_epochs = DEFAULT_LARGE_PROJECTION_EPOCHS if self.modality == 'embedding': if duplicate_detection: raise ValueError("Cannot tag duplicates in an embedding project.") build_template = { 'project_id': self.id, 'index_name': name, 'indexed_field': None, 'atomizer_strategies': None, 'model': None, 'colorable_fields': colorable_fields, 'model_hyperparameters': None, 'nearest_neighbor_index': 'HNSWIndex', 'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}), 'projection': 'NomicProject', 'projection_hyperparameters': json.dumps( {'n_neighbors': projection_n_neighbors, 'n_epochs': projection_epochs, 'spread': projection_spread} ), 'topic_model_hyperparameters': json.dumps( {'build_topic_model': build_topic_model, 'community_description_target_field': topic_label_field} ), } elif self.modality == 'text': # find the index id of the index with name reuse_embeddings_from_index reuse_embedding_from_index_id = None indices = self.indices if reuse_embeddings_from_index is not None: for index in indices: if index.name == reuse_embeddings_from_index: reuse_embedding_from_index_id = index.id break if reuse_embedding_from_index_id is None: raise Exception( f"Could not find the index '{reuse_embeddings_from_index}' to re-use from. Possible options are {[index.name for index in indices]}" ) if indexed_field is None: raise Exception("You did not specify a field to index. Specify an 'indexed_field'.") if indexed_field not in self.project_fields: raise Exception(f"Indexing on {indexed_field} not allowed. 
Valid options are: {self.project_fields}") model = 'NomicEmbed' if multilingual: model = 'NomicEmbedMultilingual' build_template = { 'project_id': self.id, 'index_name': name, 'indexed_field': indexed_field, 'atomizer_strategies': ['document', 'charchunk'], 'model': model, 'colorable_fields': colorable_fields, 'reuse_atoms_and_embeddings_from': reuse_embedding_from_index_id, 'model_hyperparameters': json.dumps( { 'dataset_buffer_size': 1000, 'batch_size': 20, 'polymerize_by': 'charchunk', 'norm': 'both', } ), 'nearest_neighbor_index': 'HNSWIndex', 'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}), 'projection': 'NomicProject', 'projection_hyperparameters': json.dumps( {'n_neighbors': projection_n_neighbors, 'n_epochs': projection_epochs, 'spread': projection_spread} ), 'topic_model_hyperparameters': json.dumps( {'build_topic_model': build_topic_model, 'community_description_target_field': indexed_field} ), 'duplicate_detection_hyperparameters': json.dumps( {'tag_duplicates': duplicate_detection, 'duplicate_cutoff': duplicate_threshold} ), } response = requests.post( self.atlas_api_path + "/v1/project/index/create", headers=self.header, json=build_template, ) if response.status_code != 200: logger.info('Create project failed with code: {}'.format(response.status_code)) logger.info('Additional info: {}'.format(response.text)) raise Exception(response.json()['detail']) job_id = response.json()['job_id'] job = requests.get( self.atlas_api_path + f"/v1/project/index/job/{job_id}", headers=self.header, ).json() index_id = job['index_id'] try: projection = self.get_map(atlas_index_id=index_id) except ValueError: # give some delay time.sleep(5) try: projection = self.get_map(atlas_index_id=index_id) except ValueError: projection = None if projection is None: logger.warning( "Could not find a map being built for this project. See atlas.nomic.ai/dashboard for map status." 
) logger.info(f"Created map `{projection.name}` in project `{self.name}`: {projection.map_link}") return projection def __repr__(self): m = self.meta return f"AtlasProject: <{m}>" def _repr_html_(self): self._latest_project_state() m = self.meta html = f""" {m['project_name']}
{m['description']} {m['total_datums_in_project']} datums inserted.
{len(m['atlas_indices'])} index built. """ complete_projections = [] if len(self.projections) >= 1: html += "
Projections\n" html += "" if len(complete_projections) >= 1: # Display most recent complete projection. html += "
def __str__(self):
    """Return the string form of every projection of every index, one per line."""
    return "\n".join(str(projection) for index in self.indices for projection in index.projections)


def get_data(self, ids: List[str]) -> List[Dict]:
    '''
    Retrieve the contents of the data given ids

    Args:
        ids: a list of datum ids

    Returns:
        A list with the entries of the 'datums' field of the API response.
        An empty ``ids`` list returns an empty list without contacting the API.

    Raises:
        ValueError: if ``ids`` is not a list, or is a nested list.
        Exception: if the API answers with a non-200 status code.
    '''
    if not isinstance(ids, list):
        raise ValueError("You must specify a list of ids when getting data.")
    # Nothing to look up; this also keeps the ids[0] probe below from raising IndexError.
    if not ids:
        return []
    if isinstance(ids[0], list):
        raise ValueError("You must specify a list of ids when getting data, not a nested list.")

    response = requests.post(
        self.atlas_api_path + "/v1/project/data/get",
        headers=self.header,
        json={'project_id': self.id, 'datum_ids': ids},
    )
    if response.status_code != 200:
        raise Exception(response.text)
    return list(response.json()['datums'])


def delete_data(self, ids: List[str]) -> bool:
    '''
    Deletes the specified datums from the project.

    Args:
        ids: A list of datum ids to delete

    Returns:
        True if data deleted successfully.

    Raises:
        ValueError: if ``ids`` is not a list.
        Exception: if the API answers with a non-200 status code.
    '''
    if not isinstance(ids, list):
        raise ValueError("You must specify a list of ids when deleting datums.")

    response = requests.post(
        self.atlas_api_path + "/v1/project/data/delete",
        headers=self.header,
        json={'project_id': self.id, 'datum_ids': ids},
    )
    if response.status_code != 200:
        raise Exception(response.text)
    return True


def add_text(self, data: "Union[DataFrame, List[Dict], pa.Table]", pbar=None, shard_size=None, num_workers=None):
    """
    Add text data to the project.

    Args:
        data: A pandas DataFrame, a list of python dictionaries, or a pyarrow Table matching the project schema.
        pbar: (Optional). A tqdm progress bar to display progress.
        shard_size: Deprecated. Any value other than None raises DeprecationWarning.
        num_workers: Deprecated. Any value other than None raises DeprecationWarning.

    Raises:
        DeprecationWarning: if ``shard_size`` or ``num_workers`` is given.
        ValueError: if ``data`` is not one of the supported container types.
    """
    if shard_size is not None or num_workers is not None:
        raise DeprecationWarning("shard_size and num_workers are deprecated.")

    # Normalise every supported input container to a pyarrow Table before uploading.
    if DataFrame is not None and isinstance(data, DataFrame):
        data = pa.Table.from_pandas(data)
    elif isinstance(data, list):
        data = pa.Table.from_pylist(data)
    elif not isinstance(data, pa.Table):
        raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
    self._add_data(data, pbar=pbar)


def add_embeddings(
    self,
    data: "Union[DataFrame, List[Dict], pa.Table, None]",
    embeddings: np.array,
    pbar=None,
    shard_size=None,
    num_workers=None,
):
    """
    Add data, with associated embeddings, to the project.

    Args:
        data: A pandas DataFrame, list of dictionaries, or pyarrow Table matching the project schema.
        embeddings: A numpy array of embeddings: each row corresponds to a row in the table.
        pbar: (Optional). A tqdm progress bar to update.
        shard_size: Deprecated. Any value other than None raises DeprecationWarning.
        num_workers: Deprecated. Any value other than None raises DeprecationWarning.

    Raises:
        DeprecationWarning: if ``shard_size`` or ``num_workers`` is given.
        ValueError: if ``data`` is None or not one of the supported container types.
        AssertionError: if the embeddings are not a non-empty 2D numpy array with one row per
            data row, or contain NaN/Inf after conversion to float16.
    """
    # TODO: validate embedding size against the embedding size of the project.
    if shard_size is not None:
        raise DeprecationWarning("shard_size is deprecated and no longer has any effect")
    if num_workers is not None:
        raise DeprecationWarning("num_workers is deprecated and no longer has any effect")

    # None has no rows to pair with the embeddings; reject it with the same message as any
    # other unsupported type instead of failing inside len().
    if data is None:
        raise ValueError(
            f"Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table, not {type(data)}"
        )

    assert isinstance(embeddings, np.ndarray), "Embeddings must be a numpy array."
    assert len(embeddings.shape) == 2, "Embeddings must be a 2D numpy array."
    assert len(data) == embeddings.shape[0], "Data and embeddings must have the same number of rows."
    assert len(data) > 0, "Data must have at least one row."

    tb: pa.Table
    if DataFrame is not None and isinstance(data, DataFrame):
        tb = pa.Table.from_pandas(data)
    elif isinstance(data, list):
        tb = pa.Table.from_pylist(data)
    elif isinstance(data, pa.Table):
        tb = data
    else:
        raise ValueError(
            f"Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table, not {type(data)}"
        )
    # Drop the local reference to the caller's container; only the Arrow table is used from here on.
    del data

    # Add embeddings to the data.
    embeddings = embeddings.astype(np.float16)

    # Fail if any embeddings are NaN or Inf. The check runs on the float16 values, so magnitudes
    # that overflow float16 are reported as Inf as well.
    assert not np.isnan(embeddings).any(), "Embeddings must not contain NaN values."
    assert not np.isinf(embeddings).any(), "Embeddings must not contain Inf values."

    pyarrow_embeddings = pa.FixedSizeListArray.from_arrays(embeddings.reshape((-1)), embeddings.shape[1])
    data_with_embeddings = tb.append_column("_embeddings", pyarrow_embeddings)
    self._add_data(data_with_embeddings, pbar=pbar)


def _add_data(
    self,
    data: "pa.Table",
    pbar=None,
):
    '''
    Low level interface to upload an Arrow Table. Users should generally call 'add_text' or 'add_embeddings.'

    The table is split into shards that are uploaded concurrently. Shards answered with
    HTTP 504 are resubmitted; other failures are logged and counted.

    Args:
        data: A pyarrow Table that will be cast to the project schema.
        pbar: A tqdm progress bar to update. When None, a bar is created and closed here.

    Returns:
        False if the upload stopped because the organization datum limit was exceeded, otherwise None.

    Raises:
        Exception: if the project transaction lock is held (project is indexing).
        RuntimeError: if most shards keep failing with HTTP 504.
    '''
    # Exactly 10 upload workers at a time.
    num_workers = 10

    # Each worker currently is too slow beyond a shard_size of 5000.
    # The heuristic here is: never let shards be more than 5,000 items,
    # OR more than 4MB uncompressed. Whichever is smaller.
    bytesize = data.nbytes
    nrow = len(data)
    shard_size = 5000
    # max(1, ...) keeps the bytes-per-chunk estimate defined for an empty table.
    n_chunks = max(1, int(np.ceil(nrow / shard_size)))
    # Chunk into 4MB pieces. These will probably compress down a bit.
    if bytesize / n_chunks > 4_000_000:
        shard_size = int(np.ceil(nrow / (bytesize / 4_000_000)))

    data = self._validate_and_correct_arrow_upload(
        data=data,
        project=self,
    )
    total_rows = len(data)

    upload_endpoint = "/v1/project/data/add/arrow"

    # Actually do the upload
    def send_request(i):
        data_shard = data.slice(i, shard_size)
        with io.BytesIO() as buffer:
            data_shard = data_shard.replace_schema_metadata({'project_id': self.id})
            feather.write_feather(data_shard, buffer, compression='zstd', compression_level=6)
            buffer.seek(0)
            return requests.post(
                self.atlas_api_path + upload_endpoint,
                headers=self.header,
                data=buffer,
            )

    # if this method is being called internally, we pass a global progress bar
    close_pbar = pbar is None
    if close_pbar:
        # Ceiling division: a trailing partial shard is still one upload.
        pbar = tqdm(total=-(-total_rows // shard_size))

    failed = 0
    succeeded = 0
    errors_504 = 0
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {executor.submit(send_request, i): i for i in range(0, total_rows, shard_size)}
            try:
                while futures:
                    # check for status of the futures which are currently working
                    done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)

                    for future in done:
                        # Always drop the finished future first: leaving it in the mapping would
                        # make wait() hand it back on every iteration.
                        start_point = futures.pop(future)
                        shard_rows = min(shard_size, total_rows - start_point)
                        response = future.result()
                        try:
                            if response.status_code == 200:
                                # A successful upload.
                                succeeded += shard_rows
                                pbar.update(1)
                                continue

                            # Match the known error phrases against the raw body so they are
                            # found regardless of how the server wraps the message.
                            body = response.text
                            logger.error(f"Shard upload failed: {body}")
                            if 'more datums exceeds your organization limit' in body:
                                return False
                            if 'Project transaction lock is held' in body:
                                raise Exception(
                                    "Project is currently indexing and cannot ingest new datums. Try again later."
                                )
                            if 'Insert failed due to ID conflict' in body:
                                # Skipped without a retry and without counting as a failure.
                                pbar.update(1)
                                continue

                            if response.status_code == 413:
                                # Possibly split in two and retry?
                                logger.error("Shard upload failed: you are sending meta-data that is too large.")
                                failed += shard_rows
                                pbar.update(1)
                            elif response.status_code == 504:
                                errors_504 += shard_rows
                                logger.debug(
                                    f"{self.name}: Connection failed for records {start_point}-{start_point + shard_size}, retrying."
                                )
                                failure_fraction = errors_504 / (failed + succeeded + errors_504)
                                if failure_fraction > 0.5 and errors_504 > shard_size * 3:
                                    raise RuntimeError(
                                        f"{self.name}: Atlas is under high load and cannot ingest datums at this time. Please try again later."
                                    )
                                futures[executor.submit(send_request, start_point)] = start_point
                            else:
                                logger.error(f"{self.name}: Shard upload failed: {response}")
                                failed += shard_rows
                                pbar.update(1)
                        finally:
                            response.close()
            finally:
                # Empty on normal completion; on an early return/raise this keeps shards that
                # have not started uploading from being sent.
                for pending in futures:
                    pending.cancel()
    finally:
        # close the progress bar if this method was called with no external progress bar
        if close_pbar:
            pbar.close()

    if failed:
        logger.warning(f"Failed to upload {failed} datums")
    if close_pbar:
        if failed:
            logger.warning("Upload partially succeeded.")
        else:
            logger.info("Upload succeeded.")


def update_maps(
    self, data: List[Dict], embeddings: Optional[np.array] = None, shard_size: int = 1000, num_workers: int = 10
):
    '''
    Utility method to update a projects maps by adding the given data.

    Args:
        data: An [N,] element list of dictionaries containing metadata for each embedding.
        embeddings: An [N, d] matrix of embeddings for updating embedding projects. Leave as None to update text projects.
        shard_size: Only used to size the progress bar. It is not forwarded to the upload
            methods, which reject it as deprecated.
        num_workers: Kept for backward compatibility; it has no effect.

    Returns:
        The result of `rebuild_maps()`.

    Raises:
        ValueError: if embeddings are missing for an embedding project, given for a text
            project, or do not have one row per element of ``data``.
    '''
    # Validate data
    if self.modality == 'embedding' and embeddings is None:
        msg = 'Please specify embeddings for updating an embedding project'
        raise ValueError(msg)

    if self.modality == 'text' and embeddings is not None:
        msg = 'Please dont specify embeddings for updating a text project'
        raise ValueError(msg)

    if embeddings is not None and len(data) != embeddings.shape[0]:
        msg = 'Expected data and embeddings to be the same length but found lengths {} and {} respectively.'.format(
            len(data), embeddings.shape[0]
        )
        raise ValueError(msg)

    # Add new data
    logger.info("Uploading data to Nomic's neural database Atlas.")
    with tqdm(total=len(data) // shard_size) as pbar:
        for i in range(0, len(data), MAX_MEMORY_CHUNK):
            if self.modality == 'embedding':
                self.add_embeddings(
                    embeddings=embeddings[i : i + MAX_MEMORY_CHUNK, :],
                    data=data[i : i + MAX_MEMORY_CHUNK],
                    pbar=pbar,
                )
            else:
                self.add_text(
                    data=data[i : i + MAX_MEMORY_CHUNK],
                    pbar=pbar,
                )
    logger.info("Upload succeeded.")

    # Update maps
    # finally, update all the indices
    return self.rebuild_maps()


def rebuild_maps(self, rebuild_topic_models: bool = False):
    '''
    Rebuilds all maps in a project with the latest state project data state. Maps will not be rebuilt to
    reflect the additions, deletions or updates you have made to your data until this method is called.

    Args:
        rebuild_topic_models: (Default False) - If true, will create new topic models when updating these indices
    '''
    response = requests.post(
        self.atlas_api_path + "/v1/project/update_indices",
        headers=self.header,
        json={'project_id': self.id, 'rebuild_topic_models': rebuild_topic_models},
    )

    # Surface a rejected request in the log instead of reporting an update that did not start.
    if response.status_code != 200:
        logger.warning(f"Failed to update maps in project `{self.name}`: {response.text}")
    else:
        logger.info(f"Updating maps in project `{self.name}`")