import base64
import concurrent.futures
import io
import json
import os
import pickle
import time
import uuid
from collections import defaultdict
from contextlib import contextmanager
from datetime import date, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import pyarrow as pa
import requests
from loguru import logger
from pandas import DataFrame
from pyarrow import compute as pc
from pyarrow import feather, ipc
from pydantic import BaseModel, Field
from tqdm import tqdm
import nomic
from .cli import refresh_bearer_token, validate_api_http_response
from .data_inference import convert_pyarrow_schema_for_atlas
from .data_operations import AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics
from .settings import *
from .utils import assert_valid_project_id, get_object_size_in_bytes
class AtlasUser:
def __init__(self):
self.credentials = refresh_bearer_token()
class AtlasClass(object):
def __init__(self):
'''
Initializes the Atlas client.
'''
if self.credentials['tenant'] == 'staging':
api_hostname = 'staging-api-atlas.nomic.ai'
web_hostname = 'staging-atlas.nomic.ai'
elif self.credentials['tenant'] == 'production':
api_hostname = 'api-atlas.nomic.ai'
web_hostname = 'atlas.nomic.ai'
else:
raise ValueError("Invalid tenant.")
self.atlas_api_path = f"https://{api_hostname}"
self.web_path = f"https://{web_hostname}"
        override_api_path = os.environ.get('ATLAS_API_PATH')
        if override_api_path:
            self.atlas_api_path = override_api_path
token = self.credentials['token']
self.token = token
self.header = {"Authorization": f"Bearer {token}"}
if self.token:
response = requests.get(
self.atlas_api_path + "/v1/user",
headers=self.header,
)
response = validate_api_http_response(response)
            if response.status_code != 200:
logger.warning(str(response))
logger.info("Your authorization token is no longer valid.")
else:
raise ValueError(
"Could not find an authorization token. Run `nomic login` to authorize this client with the Nomic API."
)
@property
def credentials(self):
return refresh_bearer_token()
def _get_current_user(self):
api_base_path = self.atlas_api_path
if self.atlas_api_path.startswith('https://api-atlas.nomic.ai'):
api_base_path = "https://no-cdn-api-atlas.nomic.ai"
response = requests.get(
api_base_path + "/v1/user",
headers=self.header,
)
response = validate_api_http_response(response)
        if response.status_code != 200:
raise ValueError("Your authorization token is no longer valid. Run `nomic login` to obtain a new one.")
return response.json()
def _validate_map_data_inputs(self, colorable_fields, id_field, data):
'''Validates inputs to map data calls.'''
if not isinstance(colorable_fields, list):
raise ValueError("colorable_fields must be a list of fields")
if id_field in colorable_fields:
raise Exception(f'Cannot color by unique id field: {id_field}')
for field in colorable_fields:
if field not in data[0]:
raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.")
def _get_current_users_main_organization(self):
        '''
        Retrieves the current user's main organization (the one they own).
        **Returns:** The organization record for the current user's main organization.
        '''
user = self._get_current_user()
for organization in user['organizations']:
if organization['user_id'] == user['sub'] and organization['access_role'] == 'OWNER':
return organization
    def _delete_project_by_id(self, project_id):
        response = requests.post(
            self.atlas_api_path + "/v1/project/remove",
            headers=self.header,
            json={'project_id': project_id},
        )
        if response.status_code != 200:
            raise Exception(f"Failed to delete project: {response.text}")
def _get_project_by_id(self, project_id: str):
'''
Args:
project_id: The project id
Returns:
Returns the requested project.
'''
assert_valid_project_id(project_id)
response = requests.get(
self.atlas_api_path + f"/v1/project/{project_id}",
headers=self.header,
)
if response.status_code != 200:
raise Exception(f"Could not access project with id {project_id}: {response.text}")
return response.json()
def _get_index_job(self, job_id: str):
'''
Args:
job_id: The job id to retrieve the state of.
Returns:
Job ID meta-data.
'''
response = requests.get(
self.atlas_api_path + f"/v1/project/index/job/{job_id}",
headers=self.header,
)
if response.status_code != 200:
raise Exception(f'Could not access job state: {response.text}')
return response.json()
def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasProject") -> pa.Table:
        '''
        Private method. Validates upload data against the project Arrow schema and applies corrections.
        Ensures the unique ID column is present and non-null, casts columns to the project schema,
        and replaces null or empty strings.
        Args:
            data: an Arrow table.
            project: the Atlas project you are validating the data for.
        Returns:
            The validated and corrected Arrow table.
        '''
if not isinstance(data, pa.Table):
raise Exception("Invalid data type for upload: {}".format(type(data)))
if project.meta['modality'] == 'text':
if "_embeddings" in data:
msg = "Can't add embeddings to a text project."
raise ValueError(msg)
if project.meta['modality'] == 'embedding':
if "_embeddings" not in data.column_names:
msg = "Must include embeddings in embedding project upload."
raise ValueError(msg)
if project.id_field not in data.column_names:
raise ValueError(f'Data must contain the ID column `{project.id_field}`')
if project.schema is None:
project._schema = convert_pyarrow_schema_for_atlas(data.schema)
# Reformat to match the schema of the project.
# This includes shuffling the order around if necessary,
# filling in nulls, etc.
reformatted = {}
if data[project.id_field].null_count > 0:
raise ValueError(
f"{project.id_field} must not contain null values, but {data[project.id_field].null_count} found."
)
for field in project.schema:
if field.name in data.column_names:
# Allow loss of precision in dates and ints, etc.
reformatted[field.name] = data[field.name].cast(field.type, safe=False)
else:
raise KeyError(
f"Field {field.name} present in table schema not found in data. Present fields: {data.column_names}"
)
if pa.types.is_string(field.type):
# Ugly temporary measures
if data[field.name].null_count > 0:
logger.warning(
f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version."
)
reformatted[field.name] = pc.fill_null(reformatted[field.name], "null")
if pc.any(pc.equal(pc.binary_length(reformatted[field.name]), 0)):
mask = pc.equal(pc.binary_length(reformatted[field.name]), 0).combine_chunks()
assert pa.types.is_boolean(mask.type)
reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null")
for field in data.schema:
            if field.name not in reformatted:
if field.name == "_embeddings":
reformatted['_embeddings'] = data['_embeddings']
else:
logger.warning(f"Field {field.name} present in data, but not found in table schema. Ignoring.")
data = pa.Table.from_pydict(reformatted, schema=project.schema)
if project.meta['insert_update_delete_lock']:
raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.")
# The following two conditions should never occur given the above, but just in case...
        assert project.id_field in data.column_names, f"Upload does not contain your specified id_field `{project.id_field}`"
if not pa.types.is_string(data[project.id_field].type):
logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}")
data = data.drop([project.id_field]).append_column(
project.id_field, data[project.id_field].cast(pa.string())
)
for key in data.column_names:
if key.startswith('_'):
if key == '_embeddings':
continue
raise ValueError('Metadata fields cannot start with _')
if pc.max(pc.utf8_length(data[project.id_field])).as_py() > 36:
first_match = data.filter(pc.greater(pc.utf8_length(data[project.id_field]), 36)).to_pylist()[0][
project.id_field
]
raise ValueError(
f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters."
)
return data
def _get_organization(self, organization_name=None, organization_id=None) -> Tuple[str, str]:
'''
        Gets an organization by either its name or id.
Args:
organization_name: the name of the organization
organization_id: the id of the organization
Returns:
The organization_name and organization_id if one was found.
'''
if organization_name is None:
            if organization_id is None:  # default to the current user's organization (the one with their name)
organization = self._get_current_users_main_organization()
organization_name = organization['nickname']
organization_id = organization['organization_id']
else:
raise NotImplementedError("Getting organization by a specific ID is not yet implemented.")
else:
organization_id_request = requests.get(
self.atlas_api_path + f"/v1/organization/search/{organization_name}", headers=self.header
)
if organization_id_request.status_code != 200:
user = self._get_current_user()
users_organizations = [org['nickname'] for org in user['organizations']]
raise Exception(
f"No such organization exists: {organization_name}. You have access to the following organizations: {users_organizations}"
)
organization_id = organization_id_request.json()['organization_id']
return organization_name, organization_id
def _get_existing_project_by_name(self, project_name, organization_name) -> Dict:
'''
Utility method for instantiating an AtlasProject.
        Retrieves an existing project by name in a given organization. If no such project exists, returns the organization details so a new project can be created.
Args:
project_name: the project name
organization_name: the organization name
Returns:
A dictionary containing the project_id, organization_id and organization_name
'''
# check if this project already exists.
response = requests.post(
self.atlas_api_path + "/v1/project/search/name",
headers=self.header,
json={'organization_name': organization_name, 'project_name': project_name},
)
if response.status_code != 200:
raise Exception(f"Failed to find project: {response.text}")
search_results = response.json()['results']
if search_results:
existing_project = search_results[0]
existing_project_id = existing_project['id']
return {
'project_id': existing_project_id,
'organization_name': existing_project['owner'],
}
organization_name, organization_id = self._get_organization(organization_name=organization_name)
return {'organization_id': organization_id, 'organization_name': organization_name}
class AtlasIndex:
"""
An AtlasIndex represents a single view of an Atlas Project at a point in time.
An AtlasIndex typically contains one or more *projections* which are 2D representations of
the points in the index that you can browse online.
"""
def __init__(self, atlas_index_id, name, indexed_field, projections):
'''Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.'''
self.id = atlas_index_id
self.name = name
self.indexed_field = indexed_field
self.projections = projections
def _repr_html_(self):
return '<br>'.join([d._repr_html_() for d in self.projections])
class AtlasProjection:
'''
    Interact with and access the state of an Atlas Map, including text/vector search.
    This class should not be instantiated directly.
    Instead, instantiate an AtlasProject and use its `maps` attribute to retrieve an AtlasProjection.
'''
def __init__(self, project: "AtlasProject", atlas_index_id: str, projection_id: str, name):
"""
Creates an AtlasProjection.
"""
self.project = project
self.id = projection_id
self.atlas_index_id = atlas_index_id
self.projection_id = projection_id
self.name = name
self._duplicates = None
self._embeddings = None
self._topics = None
self._tags = None
self._tile_data = None
@property
def map_link(self):
'''
Retrieves a map link.
'''
return f"{self.project.web_path}/map/{self.project.id}/{self.id}"
@property
def _status(self):
response = requests.get(
self.project.atlas_api_path + f"/v1/project/index/job/progress/{self.atlas_index_id}",
headers=self.project.header,
)
if response.status_code != 200:
raise Exception(response.text)
content = response.json()
return content
def __str__(self):
return f"{self.name}: {self.map_link}"
def __repr__(self):
return self.__str__()
def _iframe(self):
return f"""
<iframe class="iframe" id="iframe{self.id}" allow="clipboard-read; clipboard-write" src="{self.map_link}">
</iframe>
<style>
.iframe {{
/* vh can be **very** large in vscode ipynb. */
height: min(75vh, 66vw);
width: 100%;
}}
</style>
"""
def _embed_html(self):
return f"""<script>
destroy = function() {{
document.getElementById("iframe{self.id}").remove()
}}
</script>
<div class="actions">
<div id="hide" class="action" onclick="destroy()">Hide embedded project</div>
<div class="action" id="out">
<a href="{self.map_link}" target="_blank">Explore on atlas.nomic.ai</a>
</div>
</div>
{self._iframe()}
<style>
.actions {{
display: block;
}}
.action {{
min-height: 18px;
margin: 5px;
transition: all 500ms ease-in-out;
}}
.action:hover {{
cursor: pointer;
}}
#hide:hover::after {{
content: " X";
}}
#out:hover::after {{
content: "";
}}
</style>
"""
def _repr_html_(self):
# Don't make an iframe if the project is locked.
state = self._status['index_build_stage']
if state != 'Completed':
return f"""Atlas Projection {self.name}. Status {state}. <a target="_blank" href="{self.map_link}">view online</a>"""
return f"""
<h3>Project: {self.name}</h3>
{self._embed_html()}
"""
@property
def duplicates(self):
"""Duplicate detection state"""
if self.project.is_locked:
raise Exception('Project is locked! Please wait until the project is unlocked to access duplicates.')
if self._duplicates is None:
self._duplicates = AtlasMapDuplicates(self)
return self._duplicates
@property
def topics(self):
"""Topic state"""
if self.project.is_locked:
raise Exception('Project is locked for state access! Please wait until the project is unlocked to access topics.')
if self._topics is None:
self._topics = AtlasMapTopics(self)
return self._topics
@property
def embeddings(self):
"""Embedding state"""
if self.project.is_locked:
raise Exception('Project is locked for state access! Please wait until the project is unlocked to access embeddings.')
if self._embeddings is None:
self._embeddings = AtlasMapEmbeddings(self)
return self._embeddings
@property
def tags(self):
"""Embedding state"""
if self._tags is None:
self._tags = AtlasMapTags(self)
return self._tags
def _fetch_tiles(self, overwrite: bool = True):
"""
        Downloads all web tile data for the projection to the local tile cache and returns it as a memory-mapped Arrow table.
Args:
overwrite: If True then overwrite web tile files.
Returns:
An Arrow table containing information for all data points in the index.
"""
if self._tile_data is not None:
return self._tile_data
self._download_feather(overwrite=overwrite)
tbs = []
root = feather.read_table(self.tile_destination / "0/0/0.feather")
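        # Sidecar files hold extra columns for each tile; the root tile's schema metadata lists their names.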
        try:
            sidecars = set(json.loads(root.schema.metadata[b'sidecars']).values())
        except (KeyError, TypeError):
            sidecars = set()
for path in self.tile_destination.glob('**/*.feather'):
if len(path.stem.split(".")) > 1:
# Sidecars are loaded alongside
continue
            tb = feather.read_table(path)
for sidecar_file in sidecars:
carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather")
for col in carfile.column_names:
tb = tb.append_column(col, carfile[col])
tbs.append(tb)
self._tile_data = pa.concat_tables(tbs)
return self._tile_data
@property
def tile_destination(self):
return Path("~/.nomic/cache", self.id).expanduser()
    def _download_feather(self, overwrite: bool = True):
'''
        Downloads the feather tree.
        Args:
            overwrite: If True then overwrite existing feather files.
        Returns:
            A list of the quadtile file names downloaded.
'''
self.tile_destination.mkdir(parents=True, exist_ok=True)
root = f'{self.project.atlas_api_path}/v1/project/public/{self.project.id}/index/projection/{self.id}/quadtree/'
        quads = ['0/0/0']
all_quads = []
sidecars = None
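        # Breadth-first traversal of the quadtree, starting from the root tile; each tile's
        # schema metadata lists its children and the sidecar files to fetch alongside it.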
while len(quads) > 0:
rawquad = quads.pop(0)
quad = rawquad + ".feather"
all_quads.append(quad)
path = self.tile_destination / quad
if not path.exists() or overwrite:
data = requests.get(root + quad)
readable = io.BytesIO(data.content)
readable.seek(0)
tb = feather.read_table(readable)
path.parent.mkdir(parents=True, exist_ok=True)
feather.write_feather(tb, path)
schema = ipc.open_file(path).schema
if sidecars is None and b'sidecars' in schema.metadata:
# Grab just the filenames
                sidecars = set(json.loads(schema.metadata[b'sidecars']).values())
elif sidecars is None:
sidecars = set()
if not "." in rawquad:
for sidecar in sidecars:
# The sidecar loses the feather suffix because it's supposed to be raw.
quads.append(quad.replace(".feather", f'.{sidecar}'))
if not schema.metadata or b'children' not in schema.metadata:
# Sidecars don't have children.
continue
kids = schema.metadata.get(b'children')
children = json.loads(kids)
quads.extend(children)
return all_quads
@property
def datum_id_field(self):
return self.project.meta["unique_id_field"]
def _get_atoms(self, ids: List[str]) -> List[Dict]:
'''
Retrieves atoms by id
Args:
ids: list of atom ids
Returns:
            A list of dictionaries containing the resulting atoms.
'''
if not isinstance(ids, list):
raise ValueError("You must specify a list of ids when getting data.")
response = requests.post(
self.project.atlas_api_path + "/v1/project/atoms/get",
headers=self.project.header,
json={'project_id': self.project.id, 'index_id': self.atlas_index_id, 'atom_ids': ids},
)
if response.status_code == 200:
return response.json()['atoms']
else:
raise Exception(response.text)
class AtlasProject(AtlasClass):
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = 'A description for your map.',
unique_id_field: Optional[str] = None,
modality: Optional[str] = None,
organization_name: Optional[str] = None,
is_public: bool = True,
project_id=None,
reset_project_if_exists=False,
add_datums_if_exists=True,
):
"""
Creates or loads an Atlas project.
Atlas projects store data (text, embeddings, etc) that you can organize by building indices.
If the organization already contains a project with this name, it will be returned instead.
        **Parameters:**
        * **name** - The name of the project.
        * **description** - A description for the project.
        * **unique_id_field** - The field that uniquely identifies each datum. If a datum does not contain this field, it will be added and assigned a random unique ID.
        * **modality** - The data modality of this project. Currently, Atlas supports either `text` or `embedding` modality projects.
        * **organization_name** - The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user account's default organization.
        * **is_public** - Should this project be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
        * **project_id** - An alternative way to retrieve a project is by passing the project_id directly. This only works if the project already exists.
        * **reset_project_if_exists** - If the specified project exists in your organization, delete it and re-create it empty. Your uploaded data will not be contextualized with the previously existing data.
        * **add_datums_if_exists** - If specifying an existing project and you want to add data to it, set this to True.
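        Example (a minimal sketch; the project name and fields shown are hypothetical):
            from nomic import AtlasProject
            project = AtlasProject(
                name='My Project',
                unique_id_field='id',
                modality='text',
            )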
"""
assert name is not None or project_id is not None, "You must pass a name or project_id"
super().__init__()
if project_id is not None:
self.meta = self._get_project_by_id(project_id)
return
if organization_name is None:
organization_name = self._get_current_users_main_organization()['nickname']
results = self._get_existing_project_by_name(project_name=name, organization_name=organization_name)
if 'project_id' in results: # project already exists
organization_name = results['organization_name']
project_id = results['project_id']
if reset_project_if_exists: # reset the project
logger.info(
f"Found existing project `{name}` in organization `{organization_name}`. Clearing it of data by request."
)
self._delete_project_by_id(project_id=project_id)
project_id = None
elif not add_datums_if_exists: # prevent adding datums to existing project explicitly
raise ValueError(
f"Project already exists with the name `{name}` in organization `{organization_name}`. "
f"You can add datums to it by settings `add_datums_if_exists = True` or reset it by specifying `reset_project_if_exists=True` on a new upload."
)
else:
logger.info(f"Loading existing project `{name}` from organization `{organization_name}`.")
if project_id is None: # if there is no existing project, make a new one.
            if unique_id_field is None:
                raise ValueError("You must specify a unique_id_field when creating a new project.")
if modality is None:
raise ValueError("You must specify a modality when creating a new project.")
assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`"
assert name is not None
project_id = self._create_project(
project_name=name,
description=description,
unique_id_field=unique_id_field,
modality=modality,
organization_name=organization_name,
is_public=is_public,
)
self.meta = self._get_project_by_id(project_id=project_id)
self._schema = None
def delete(self):
'''
Deletes an atlas project with all associated metadata.
'''
organization = self._get_current_users_main_organization()
organization_name = organization['nickname']
logger.info(f"Deleting project `{self.name}` from organization `{organization_name}`")
self._delete_project_by_id(project_id=self.id)
return False
def _create_project(
self,
project_name: str,
description: Optional[str],
unique_id_field: str,
modality: str,
organization_name: Optional[str] = None,
is_public: bool = True,
):
'''
Creates an Atlas project.
        Atlas projects store data (text, embeddings, etc) that you can organize by building indices.
**Parameters:**
* **project_name** - The name of the project.
* **description** - A description for the project.
* **unique_id_field** - The field that uniquely identifies each datum. If a datum does not contain this field, it will be added and assigned a random unique ID.
* **modality** - The data modality of this project. Currently, Atlas supports either `text` or `embedding` modality projects.
        * **organization_name** - The name of the organization to create this project under. You must be a member of the organization with appropriate permissions. If not specified, defaults to your user account's default organization.
* **is_public** - Should this project be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view.
**Returns:** project_id on success.
'''
organization_name, organization_id = self._get_organization(organization_name=organization_name)
supported_modalities = ['text', 'embedding']
if modality not in supported_modalities:
msg = 'Tried to create project with modality: {}, but Atlas only supports: {}'.format(
modality, supported_modalities
)
raise ValueError(msg)
if unique_id_field is None:
raise ValueError("You must specify a unique id field")
logger.info(f"Creating project `{project_name}` in organization `{organization_name}`")
if description is None:
description = ""
response = requests.post(
self.atlas_api_path + "/v1/project/create",
headers=self.header,
json={
'organization_id': organization_id,
'project_name': project_name,
'description': description,
'unique_id_field': unique_id_field,
'modality': modality,
'is_public': is_public,
},
)
if response.status_code != 201:
raise Exception(f"Failed to create project: {response.json()}")
return response.json()['project_id']
def _latest_project_state(self):
'''
Refreshes the project's state. Try to call this sparingly but use it when you need it.
'''
self.meta = self._get_project_by_id(self.id)
return self
@property
def indices(self) -> List[AtlasIndex]:
self._latest_project_state()
output = []
for index in self.meta['atlas_indices']:
projections = []
for projection in index['projections']:
projection = AtlasProjection(
project=self, projection_id=projection['id'], atlas_index_id=index['id'], name=index['index_name']
)
projections.append(projection)
index = AtlasIndex(
atlas_index_id=index['id'],
name=index['index_name'],
indexed_field=index['indexed_field'],
projections=projections,
)
output.append(index)
return output
@property
def projections(self) -> List[AtlasProjection]:
output = []
for index in self.indices:
for projection in index.projections:
output.append(projection)
return output
@property
def maps(self) -> List[AtlasProjection]:
return self.projections
@property
def id(self) -> str:
'''The UUID of the project.'''
return self.meta['id']
@property
def id_field(self) -> str:
return self.meta['unique_id_field']
@property
def total_datums(self) -> int:
'''The total number of data points in the project.'''
return self.meta['total_datums_in_project']
@property
def modality(self) -> str:
return self.meta['modality']
@property
def name(self) -> str:
'''The name of the project.'''
return self.meta['project_name']
@property
def description(self):
return self.meta['description']
@property
def project_fields(self):
return self.meta['project_fields']
@property
def is_locked(self) -> bool:
self._latest_project_state()
return self.meta['insert_update_delete_lock']
@property
def schema(self) -> Optional[pa.Schema]:
if self._schema is not None:
return self._schema
if 'schema' in self.meta and self.meta['schema'] is not None:
self._schema: pa.Schema = ipc.read_schema(io.BytesIO(base64.b64decode(self.meta['schema'])))
return self._schema
return None
@property
def is_accepting_data(self) -> bool:
'''
Checks if the project can accept data. Projects cannot accept data when they are being indexed.
Returns:
True if project is unlocked for data additions, false otherwise.
'''
return not self.is_locked
@contextmanager
def wait_for_project_lock(self):
        '''Blocks thread execution until the project is in a state where it can ingest data.
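        Example (a minimal sketch; assumes an existing `project`):
            with project.wait_for_project_lock():
                project.rebuild_maps()
        '''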
has_logged = False
while True:
if self.is_accepting_data:
yield self
break
if not has_logged:
logger.info(f"{self.name}: Waiting for Project Lock Release.")
has_logged = True
time.sleep(5)
    def get_map(self, name: Optional[str] = None, atlas_index_id: Optional[str] = None, projection_id: Optional[str] = None) -> AtlasProjection:
'''
        Retrieves a map.
        Args:
            name: The name of your map. This defaults to your project's name, but can differ if you build multiple maps in your project.
            atlas_index_id: If specified, will only return a map if there is one built under the index with the id atlas_index_id.
            projection_id: If specified, will only return a map if there is one built under the index with id projection_id.
        Returns:
            The requested map; raises ValueError if no matching map is found.
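        Example (a minimal sketch; the map name is hypothetical):
            map = project.get_map(name='My Map')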
'''
indices = self.indices
if atlas_index_id is not None:
for index in indices:
if index.id == atlas_index_id:
if len(index.projections) == 0:
raise ValueError(f"No map found under index with atlas_index_id='{atlas_index_id}'")
return index.projections[0]
raise ValueError(f"Could not find a map with atlas_index_id='{atlas_index_id}'")
if projection_id is not None:
for index in indices:
for projection in index.projections:
if projection.id == projection_id:
return projection
raise ValueError(f"Could not find a map with projection_id='{atlas_index_id}'")
if len(indices) == 0:
raise ValueError("You have no maps built in your project")
if len(indices) > 1 and name is None:
raise ValueError("You have multiple maps in this project, specify a name.")
if len(indices) == 1:
if len(indices[0].projections) == 1:
return indices[0].projections[0]
for index in indices:
if index.name == name:
return index.projections[0]
raise ValueError(f"Could not find a map named {name} in your project.")
def create_index(
self,
name: str,
        indexed_field: Optional[str] = None,
colorable_fields: list = [],
multilingual: bool = False,
build_topic_model: bool = False,
projection_n_neighbors: int = DEFAULT_PROJECTION_N_NEIGHBORS,
projection_epochs: int = DEFAULT_PROJECTION_EPOCHS,
projection_spread: float = DEFAULT_PROJECTION_SPREAD,
        topic_label_field: Optional[str] = None,
        reuse_embeddings_from_index: Optional[str] = None,
duplicate_detection: bool = False,
duplicate_threshold: float = DEFAULT_DUPLICATE_THRESHOLD,
) -> AtlasProjection:
'''
Creates an index in the specified project.
Args:
name: The name of the index and the map.
            indexed_field: For text projects, the name of the field containing the text to be mapped.
            colorable_fields: The project fields you want to be able to color by on the map. Must be a subset of the project's fields.
            multilingual: Should the map take language into account? If true, points from different languages with semantically similar text are close together.
            build_topic_model: Should a topic model be built?
            projection_n_neighbors: The number of neighbors to use when building the projection.
            projection_epochs: The number of epochs to use when building the projection.
            projection_spread: The spread hyperparameter of the projection.
            topic_label_field: A text field in your metadata to estimate topic labels from. Defaults to the indexed_field for text projects if not specified.
            reuse_embeddings_from_index: The name of the index to reuse embeddings from.
            duplicate_detection: Whether to run duplicate detection.
            duplicate_threshold: The threshold at which points are considered duplicates.
Returns:
The projection this index has built.
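        Example (a minimal sketch; assumes a text project with a 'text' field):
            projection = project.create_index(
                name='My Map',
                indexed_field='text',
                build_topic_model=True,
            )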
'''
self._latest_project_state()
# for large projects, alter the default projection configurations.
if self.total_datums >= 1_000_000:
if (
projection_epochs == DEFAULT_PROJECTION_EPOCHS
and projection_n_neighbors == DEFAULT_PROJECTION_N_NEIGHBORS
):
projection_n_neighbors = DEFAULT_LARGE_PROJECTION_N_NEIGHBORS
projection_epochs = DEFAULT_LARGE_PROJECTION_EPOCHS
if self.modality == 'embedding':
if duplicate_detection:
raise ValueError("Cannot tag duplicates in an embedding project.")
build_template = {
'project_id': self.id,
'index_name': name,
'indexed_field': None,
'atomizer_strategies': None,
'model': None,
'colorable_fields': colorable_fields,
'model_hyperparameters': None,
'nearest_neighbor_index': 'HNSWIndex',
'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}),
'projection': 'NomicProject',
'projection_hyperparameters': json.dumps(
{'n_neighbors': projection_n_neighbors, 'n_epochs': projection_epochs, 'spread': projection_spread}
),
'topic_model_hyperparameters': json.dumps(
{'build_topic_model': build_topic_model, 'community_description_target_field': topic_label_field}
),
}
elif self.modality == 'text':
# find the index id of the index with name reuse_embeddings_from_index
reuse_embedding_from_index_id = None
indices = self.indices
if reuse_embeddings_from_index is not None:
for index in indices:
if index.name == reuse_embeddings_from_index:
reuse_embedding_from_index_id = index.id
break
if reuse_embedding_from_index_id is None:
raise Exception(
f"Could not find the index '{reuse_embeddings_from_index}' to re-use from. Possible options are {[index.name for index in indices]}"
)
if indexed_field is None:
raise Exception("You did not specify a field to index. Specify an 'indexed_field'.")
if indexed_field not in self.project_fields:
raise Exception(f"Indexing on {indexed_field} not allowed. Valid options are: {self.project_fields}")
model = 'NomicEmbed'
if multilingual:
model = 'NomicEmbedMultilingual'
build_template = {
'project_id': self.id,
'index_name': name,
'indexed_field': indexed_field,
'atomizer_strategies': ['document', 'charchunk'],
'model': model,
'colorable_fields': colorable_fields,
'reuse_atoms_and_embeddings_from': reuse_embedding_from_index_id,
'model_hyperparameters': json.dumps(
{
'dataset_buffer_size': 1000,
'batch_size': 20,
'polymerize_by': 'charchunk',
'norm': 'both',
}
),
'nearest_neighbor_index': 'HNSWIndex',
'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}),
'projection': 'NomicProject',
'projection_hyperparameters': json.dumps(
{'n_neighbors': projection_n_neighbors, 'n_epochs': projection_epochs, 'spread': projection_spread}
),
'topic_model_hyperparameters': json.dumps(
{'build_topic_model': build_topic_model, 'community_description_target_field': indexed_field}
),
'duplicate_detection_hyperparameters': json.dumps(
{'tag_duplicates': duplicate_detection, 'duplicate_cutoff': duplicate_threshold}
),
}
response = requests.post(
self.atlas_api_path + "/v1/project/index/create",
headers=self.header,
json=build_template,
)
if response.status_code != 200:
            logger.info('Create index failed with code: {}'.format(response.status_code))
logger.info('Additional info: {}'.format(response.text))
raise Exception(response.json()['detail'])
job_id = response.json()['job_id']
job = requests.get(
self.atlas_api_path + f"/v1/project/index/job/{job_id}",
headers=self.header,
).json()
index_id = job['index_id']
try:
projection = self.get_map(atlas_index_id=index_id)
except ValueError:
# give some delay
time.sleep(5)
try:
projection = self.get_map(atlas_index_id=index_id)
except ValueError:
projection = None
        if projection is None:
            logger.warning(
                "Could not find a map being built for this project. See atlas.nomic.ai/dashboard for map status."
            )
        else:
            logger.info(f"Created map `{projection.name}` in project `{self.name}`: {projection.map_link}")
return projection
def __repr__(self):
m = self.meta
return f"AtlasProject: <{m}>"
def _repr_html_(self):
self._latest_project_state()
m = self.meta
html = f"""
        <strong><a href="https://atlas.nomic.ai/dashboard/project/{m['id']}">{m['project_name']}</a></strong>
<br>
        {m['description']}
        <br>
        {m['total_datums_in_project']} datums inserted.
        <br>
        {len(m['atlas_indices'])} indices built.
"""
complete_projections = []
if len(self.projections) >= 1:
html += "<br><strong>Projections</strong>\n"
html += "<ul>\n"
for projection in self.projections:
state = projection._status['index_build_stage']
if state == 'Completed':
complete_projections.append(projection)
html += f"""<li>{projection.name}. Status {state}. <a target="_blank" href="{projection.map_link}">view online</a></li>"""
html += "</ul>"
if len(complete_projections) >= 1:
# Display most recent complete projection.
html += "<hr>"
html += complete_projections[-1]._embed_html()
return html
def __str__(self):
return "\n".join([str(projection) for index in self.indices for projection in index.projections])
def get_data(self, ids: List[str]) -> List[Dict]:
'''
        Retrieves the contents of the data for the given ids.
Args:
ids: a list of datum ids
Returns:
            A list of dictionaries corresponding to the requested datums.
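        Example (a minimal sketch; the ids shown are hypothetical):
            datums = project.get_data(ids=['id-1', 'id-2'])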
'''
if not isinstance(ids, list):
raise ValueError("You must specify a list of ids when getting data.")
if isinstance(ids[0], list):
raise ValueError("You must specify a list of ids when getting data, not a nested list.")
response = requests.post(
self.atlas_api_path + "/v1/project/data/get",
headers=self.header,
json={'project_id': self.id, 'datum_ids': ids},
)
if response.status_code == 200:
            return response.json()['datums']
else:
raise Exception(response.text)
def delete_data(self, ids: List[str]) -> bool:
'''
Deletes the specified datums from the project.
Args:
ids: A list of datum ids to delete
Returns:
True if data deleted successfully.
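        Example (a minimal sketch; the ids shown are hypothetical):
            project.delete_data(ids=['id-1', 'id-2'])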
'''
if not isinstance(ids, list):
raise ValueError("You must specify a list of ids when deleting datums.")
response = requests.post(
self.atlas_api_path + "/v1/project/data/delete",
headers=self.header,
json={'project_id': self.id, 'datum_ids': ids},
)
if response.status_code == 200:
return True
else:
raise Exception(response.text)
    def add_text(self, data: Union[DataFrame, List[Dict], pa.Table], pbar=None, shard_size=None, num_workers=None):
"""
        Add text data to the project.
        Args:
            data: A pandas DataFrame, a list of python dictionaries, or a pyarrow Table matching the project schema.
            pbar: (Optional). A tqdm progress bar to display progress.
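        Example (a minimal sketch; the field names shown are hypothetical):
            project.add_text(data=[{'id': 'doc-1', 'text': 'hello world'}])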
"""
if shard_size is not None or num_workers is not None:
raise DeprecationWarning("shard_size and num_workers are deprecated.")
        if isinstance(data, DataFrame):
data = pa.Table.from_pandas(data)
elif isinstance(data, list):
data = pa.Table.from_pylist(data)
elif not isinstance(data, pa.Table):
raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.")
self._add_data(data, pbar=pbar)
def add_embeddings(
self,
data: Union[DataFrame, List[Dict], pa.Table, None],
        embeddings: np.ndarray,
pbar=None,
shard_size=None,
num_workers=None,
):
"""
Add data, with associated embeddings, to the project.
Args:
data: A pandas DataFrame, list of dictionaries, or pyarrow Table matching the project schema.
embeddings: A numpy array of embeddings: each row corresponds to a row in the table.
pbar: (Optional). A tqdm progress bar to update.
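        Example (a minimal sketch; the array sizes and fields shown are hypothetical):
            import numpy as np
            embeddings = np.random.rand(100, 256)
            data = [{'id': str(i)} for i in range(100)]
            project.add_embeddings(data=data, embeddings=embeddings)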
"""
"""
# TODO: validate embedding size.
assert embeddings.shape[1] == self.embedding_size, "Embedding size must match the embedding size of the project."
"""
if shard_size is not None:
raise DeprecationWarning("shard_size is deprecated and no longer has any effect")
if num_workers is not None:
raise DeprecationWarning("num_workers is deprecated and no longer has any effect")
        assert isinstance(embeddings, np.ndarray), "Embeddings must be a numpy array."
assert len(embeddings.shape) == 2, "Embeddings must be a 2D numpy array."
assert len(data) == embeddings.shape[0], "Data and embeddings must have the same number of rows."
assert len(data) > 0, "Data must have at least one row."
tb: pa.Table
        if isinstance(data, DataFrame):
tb = pa.Table.from_pandas(data)
elif isinstance(data, list):
tb = pa.Table.from_pylist(data)
elif isinstance(data, pa.Table):
tb = data
else:
raise ValueError(
f"Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table, not {type(data)}"
)
del data
        # Cast embeddings to float16 to reduce upload size.
        embeddings = embeddings.astype(np.float16)
# Fail if any embeddings are NaN or Inf.
assert not np.isnan(embeddings).any(), "Embeddings must not contain NaN values."
assert not np.isinf(embeddings).any(), "Embeddings must not contain Inf values."
pyarrow_embeddings = pa.FixedSizeListArray.from_arrays(embeddings.reshape((-1)), embeddings.shape[1])
data_with_embeddings = tb.append_column("_embeddings", pyarrow_embeddings)
self._add_data(data_with_embeddings, pbar=pbar)
def _add_data(
self,
data: pa.Table,
pbar=None,
):
'''
Low level interface to upload an Arrow Table. Users should generally call 'add_text' or 'add_embeddings.'
Args:
data: A pyarrow Table that will be cast to the project schema.
pbar: A tqdm progress bar to update.
Returns:
None
'''
# Exactly 10 upload workers at a time.
num_workers = 10
        # Each worker is currently too slow beyond a shard_size of 5,000.
        # The heuristic: never let shards exceed 5,000 items or 4MB uncompressed,
        # whichever limit is hit first.
bytesize = data.nbytes
nrow = len(data)
shard_size = 5000
n_chunks = int(np.ceil(nrow / shard_size))
# Chunk into 4MB pieces. These will probably compress down a bit.
if bytesize / n_chunks > 4_000_000:
shard_size = int(np.ceil(nrow / (bytesize / 4_000_000)))
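        # e.g. 100,000 rows totalling ~1 GB: 20 shards of 5,000 rows would be ~50 MB each,
        # so shard_size drops to ceil(100_000 / (1e9 / 4e6)) = 400 rows per shard.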
data = self._validate_and_correct_arrow_upload(
data=data,
project=self,
)
upload_endpoint = "/v1/project/data/add/arrow"
# Actually do the upload
def send_request(i):
data_shard = data.slice(i, shard_size)
with io.BytesIO() as buffer:
data_shard = data_shard.replace_schema_metadata({'project_id': self.id})
feather.write_feather(data_shard, buffer, compression='zstd', compression_level=6)
buffer.seek(0)
response = requests.post(
self.atlas_api_path + upload_endpoint,
headers=self.header,
data=buffer,
)
return response
# if this method is being called internally, we pass a global progress bar
close_pbar = False
if pbar is None:
close_pbar = True
            pbar = tqdm(total=int(np.ceil(len(data) / shard_size)))
failed = 0
succeeded = 0
errors_504 = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(send_request, i): i for i in range(0, len(data), shard_size)}
while futures:
# check for status of the futures which are currently working
done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)
# process any completed futures
for future in done:
response = future.result()
if response.status_code != 200:
try:
logger.error(f"Shard upload failed: {response.text}")
                            if 'more datums exceeds your organization limit' in str(response.json()):
                                return False
                            if 'Project transaction lock is held' in str(response.json()):
                                raise Exception(
                                    "Project is currently indexing and cannot ingest new datums. Try again later."
                                )
                            if 'Insert failed due to ID conflict' in str(response.json()):
                                # These datums already exist; drop this shard so the upload loop can terminate.
                                response.close()
                                del futures[future]
                                continue
except (requests.JSONDecodeError, json.decoder.JSONDecodeError):
if response.status_code == 413:
# Possibly split in two and retry?
logger.error("Shard upload failed: you are sending meta-data that is too large.")
pbar.update(1)
response.close()
failed += shard_size
elif response.status_code == 504:
errors_504 += shard_size
start_point = futures[future]
logger.debug(
f"{self.name}: Connection failed for records {start_point}-{start_point + shard_size}, retrying."
)
failure_fraction = errors_504 / (failed + succeeded + errors_504)
if failure_fraction > 0.5 and errors_504 > shard_size * 3:
raise RuntimeError(
f"{self.name}: Atlas is under high load and cannot ingest datums at this time. Please try again later."
)
new_submission = executor.submit(send_request, start_point)
futures[new_submission] = start_point
response.close()
else:
logger.error(f"{self.name}: Shard upload failed: {response}")
failed += shard_size
pbar.update(1)
response.close()
else:
# A successful upload.
succeeded += shard_size
pbar.update(1)
response.close()
# remove the now completed future
del futures[future]
        # close the progress bar if this method was called with no external progress bar
if close_pbar:
pbar.close()
if failed:
logger.warning(f"Failed to upload {failed} datums")
if close_pbar:
if failed:
logger.warning("Upload partially succeeded.")
else:
logger.info("Upload succeeded.")
def update_maps(
        self, data: List[Dict], embeddings: Optional[np.ndarray] = None, shard_size: int = 1000, num_workers: int = 10
):
'''
        Utility method to update a project's maps by adding the given data.
Args:
data: An [N,] element list of dictionaries containing metadata for each embedding.
embeddings: An [N, d] matrix of embeddings for updating embedding projects. Leave as None to update text projects.
            shard_size: Deprecated. No longer has any effect.
            num_workers: Deprecated. No longer has any effect.
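        Example (a minimal sketch; the field names shown are hypothetical):
            project.update_maps(data=[{'id': 'doc-2', 'text': 'a new datum'}])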
'''
# Validate data
if self.modality == 'embedding' and embeddings is None:
msg = 'Please specify embeddings for updating an embedding project'
raise ValueError(msg)
if self.modality == 'text' and embeddings is not None:
            msg = 'Please do not specify embeddings when updating a text project'
raise ValueError(msg)
if embeddings is not None and len(data) != embeddings.shape[0]:
            msg = 'Expected data and embeddings to be the same length but found lengths {} and {} respectively.'.format(
                len(data), embeddings.shape[0]
            )
raise ValueError(msg)
# Add new data
logger.info("Uploading data to Nomic's neural database Atlas.")
with tqdm(total=len(data) // shard_size) as pbar:
for i in range(0, len(data), MAX_MEMORY_CHUNK):
if self.modality == 'embedding':
                    self.add_embeddings(
                        embeddings=embeddings[i : i + MAX_MEMORY_CHUNK, :],
                        data=data[i : i + MAX_MEMORY_CHUNK],
                        pbar=pbar,
                    )
else:
                    self.add_text(
                        data=data[i : i + MAX_MEMORY_CHUNK],
                        pbar=pbar,
                    )
logger.info("Upload succeeded.")
        # Finally, rebuild all indices so the maps reflect the new data.
return self.rebuild_maps()
def rebuild_maps(self, rebuild_topic_models: bool = False):
'''
        Rebuilds all maps in a project with the latest state of the project's data. Maps will not
        reflect the additions, deletions, or updates you have made to your data until this method is called.
Args:
rebuild_topic_models: (Default False) - If true, will create new topic models when updating these indices
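        Example (a minimal sketch):
            project.rebuild_maps(rebuild_topic_models=True)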
'''
response = requests.post(
self.atlas_api_path + "/v1/project/update_indices",
headers=self.header,
json={'project_id': self.id, 'rebuild_topic_models': rebuild_topic_models},
)
logger.info(f"Updating maps in project `{self.name}`")