import base64
import concurrent.futures
import io
import json
import os
from collections import defaultdict
from datetime import datetime
from typing import Dict, Iterable, List, Tuple
import numpy as np
import pandas
import pandas as pd
import pyarrow as pa
import requests
from loguru import logger
from pyarrow import compute as pc
from pyarrow import ipc
from tqdm import tqdm
from .settings import EMBEDDING_PAGINATION_LIMIT
class AtlasMapDuplicates:
"""
Atlas Duplicate Clusters State
Atlas automatically groups embeddings that are sufficiently close into semantic clusters.
    You can use these clusters for semantic duplicate detection, allowing you to quickly deduplicate
your data.
=== "Accessing Duplicates Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.duplicates)
```
=== "Output"
```
460 deletion candidates in 9540 clusters
id_ duplicate_class cluster_id
0 0A singleton 5178
1 0g retention candidate 271
2 0Q singleton 6672
3 0w singleton 7529
4 1A singleton 1587
... ... ... ...
9999 JZU singleton 6346
```
"""
def __init__(self, projection: "AtlasProjection"):
self.projection = projection
self.id_field = self.projection.project.id_field
try:
self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, '_duplicate_class', '_cluster_id'])
        except pa.lib.ArrowInvalid as e:
            raise ValueError("Duplicate detection has not yet been run on this map.") from e
self._tb = self._tb.rename_columns([self.id_field, 'duplicate_class', 'cluster_id'])
@property
def df(self) -> pd.DataFrame:
"""
Pandas dataframe mapping each data point to its cluster of semantically similar points
=== "Accessing Duplicates Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.duplicates.df)
```
=== "Output"
```
            id_  duplicate_class  cluster_id
0 0A singleton 5178
1 0g retention candidate 271
2 0Q singleton 6672
3 0w singleton 7529
4 1A singleton 1587
... ... ... ...
9999 JZU singleton 6346
```
"""
return self.tb.to_pandas()
@property
def tb(self) -> pa.Table:
"""
Pyarrow table with information about duplicate clusters and candidates.
This table is memmapped from the underlying files and is the most efficient way to
access duplicate information.
"""
return self._tb
def deletion_candidates(self) -> List[str]:
"""
Returns:
The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated.
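        === "Semantic Deduplication Example"
            A minimal sketch; it assumes duplicate detection has already been run on this map.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            candidates = map.duplicates.deletion_candidates()
            print(f"{len(candidates)} datums can be removed to semantically deduplicate the dataset")
            ```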
"""
dupes = self.tb[self.id_field].filter(pc.equal(self.tb['duplicate_class'], 'deletion candidate'))
return dupes.to_pylist()
def __repr__(self) -> str:
repr = f"===Atlas Duplicates for ({self.projection})===\n"
duplicate_count = len(
self.tb[self.id_field].filter(pc.equal(self.tb['duplicate_class'], 'deletion candidate'))
)
cluster_count = len(self.tb['cluster_id'].value_counts())
repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n"
return repr + self.df.__repr__()
class AtlasMapTopics:
"""
Atlas Topics State
=== "Accessing Topics Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.topics)
```
=== "Output"
```
id_ topic_depth_1 topic_depth_2 topic_depth_3
0 000262a5-2811 Space exploration Hurricane Jeanne Spacecraft Cassini
1 000c453d-ee97 English football Athens 2004 Olympics bobby rathore
...
9999 fffcc65c-38dc Space exploration Presidential elections Blood
```
"""
def __init__(self, projection: "AtlasProjection"):
self.projection = projection
self.project = projection.project
self.id_field = self.projection.project.id_field
try:
self._tb: pa.Table = projection._fetch_tiles().select(
[self.id_field, '_topic_depth_1', '_topic_depth_2', '_topic_depth_3']
).rename_columns([self.id_field, 'topic_depth_1', 'topic_depth_2', 'topic_depth_3'])
        except pa.lib.ArrowInvalid as e:
            raise ValueError("Topic modeling has not yet been run on this map.") from e
self._metadata = None
self._hierarchy = None
@property
    def df(self) -> pd.DataFrame:
"""
        A pandas dataframe associating each datapoint on your map with its topics at each topic depth.
"""
return self.tb.to_pandas()
@property
def tb(self) -> pa.Table:
"""
Pyarrow table associating each datapoint on the map to their Atlas assigned topics.
This table is memmapped from the underlying files and is the most efficient way to
access topic information.
"""
return self._tb
@property
    def metadata(self) -> pd.DataFrame:
"""
        Pandas dataframe where each row gives metadata for a single map topic, including:
- topic id
- a human readable topic description
- identifying keywords that differentiate the topic from other topics
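        === "Accessing Topic Metadata Example"
            A brief sketch, assuming topic modeling has been run on this map:
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            metadata = map.topics.metadata
            print(metadata[['topic_id', 'topic_short_description']].head())
            ```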
"""
if self._metadata is not None:
return self._metadata
response = requests.get(
self.projection.project.atlas_api_path
+ "/v1/project/{}/index/projection/{}".format(
self.projection.project.meta['id'], self.projection.projection_id
),
headers=self.projection.project.header,
)
topics = json.loads(response.text)['topic_models'][0]['features']
topic_data = [e['properties'] for e in topics]
topic_data = pd.DataFrame(topic_data)
topic_data = topic_data.rename(columns={"topic": "topic_id",
'_topic_depth_1': 'topic_depth_1',
'_topic_depth_2': 'topic_depth_2',
'_topic_depth_3': 'topic_depth_3'})
self._metadata = topic_data
return topic_data
@property
def hierarchy(self) -> Dict:
"""
A dictionary that allows iteration of the topic hierarchy. Each key is a topic mapping to its sub-topics.
        If a topic is not a key in the hierarchy, it is a leaf in the topic hierarchy.
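        === "Iterating the Hierarchy Example"
            A minimal sketch showing how to walk the hierarchy; it assumes topic modeling has been run on this map.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            for topic, subtopics in map.topics.hierarchy.items():
                print(topic, '->', subtopics)
            ```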
"""
if self._hierarchy is not None:
return self._hierarchy
topic_df = self.metadata
topic_hierarchy = defaultdict(list)
cols = ["topic_id", "topic_depth_1", "topic_depth_2", "topic_depth_3"]
        for _, row in topic_df[cols].iterrows():
            # Only consider the non-null values for each row
            topics = [topic for topic in row if pd.notna(topic)]
            # Iterate over the topics in each row, adding each topic to the
            # list of subtopics for the topic at the previous depth
            for depth in range(1, len(topics) - 1):
                if topics[depth + 1] not in topic_hierarchy[topics[depth]]:
                    topic_hierarchy[topics[depth]].append(topics[depth + 1])
        self._hierarchy = dict(topic_hierarchy)
        return self._hierarchy
def group_by_topic(self, topic_depth: int = 1) -> List[Dict]:
"""
Associates topics at a given depth in the topic hierarchy to the identifiers of their contained datapoints.
Args:
topic_depth: Topic depth to group datums by. Acceptable values
currently are (1, 2, 3).
Returns:
List of dictionaries where each dictionary contains next depth
subtopics, subtopic ids, topic_id, topic_short_description,
topic_long_description, and list of datum_ids.
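        === "Group by Topic Example"
            A brief sketch; the topic names printed depend entirely on your data, so treat the output as illustrative only.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            for group in map.topics.group_by_topic(topic_depth=1):
                print(group['topic_short_description'], len(group['datum_ids']))
            ```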
"""
topic_cols = []
# TODO: This will need to be changed once topic depths becomes dynamic and not hard-coded
if topic_depth not in (1, 2, 3):
raise ValueError("Topic depth out of range.")
# Unique datum id column to aggregate
datum_id_col = self.project.meta["unique_id_field"]
df = self.df
topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict()
topic_df = self.metadata
hierarchy = self.hierarchy
result = []
for topic, datum_ids in topic_datum_dict.items():
# Encountered topic with zero datums
if len(datum_ids) == 0:
continue
result_dict = {}
topic_metadata = topic_df[topic_df["topic_short_description"] == topic]
subtopics = hierarchy[topic]
result_dict["subtopics"] = subtopics
result_dict["subtopic_ids"] = topic_df[topic_df["topic_short_description"].isin(subtopics)][
"topic_id"
].tolist()
result_dict["topic_id"] = topic_metadata["topic_id"].item()
result_dict["topic_short_description"] = topic_metadata["topic_short_description"].item()
result_dict["topic_long_description"] = topic_metadata["topic_description"].item()
result_dict["datum_ids"] = datum_ids
result.append(result_dict)
return result
def get_topic_density(self, time_field: str, start: datetime, end: datetime):
'''
Computes the density/frequency of topics in a given interval of a timestamp field.
Useful for answering questions such as:
- What topics increased in prevalence between December and January?
Args:
time_field: Your metadata field containing isoformat timestamps
start: A datetime object for the window start
end: A datetime object for the window end
Returns:
List[{topic: str, count: int}] - A list of {topic, count} dictionaries, sorted from largest count to smallest count
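        === "Topic Density Example"
            A sketch that assumes your metadata contains an isoformat timestamp field named 'upload_date' (a hypothetical field name; substitute your own).
            ``` py
            from datetime import datetime
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            density = map.topics.get_topic_density(
                time_field='upload_date',
                start=datetime(2022, 12, 1),
                end=datetime(2023, 1, 31),
            )
            print(density)
            ```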
'''
response = requests.post(
self.project.atlas_api_path + "/v1/project/{}/topic_density".format(self.projection.atlas_index_id),
headers=self.project.header,
json={'start': start.isoformat(), 'end': end.isoformat(), 'time_field': time_field},
)
if response.status_code != 200:
raise Exception(response.text)
return response.json()
def vector_search_topics(self, queries: np.array, k: int = 32, depth: int = 3) -> Dict:
'''
Given an embedding, returns a normalized distribution over topics.
        Useful for answering questions such as:
        - What topic does my new datapoint belong to?
        - Does my datapoint belong to the "Dog" topic or the "Cat" topic?
Args:
queries: a 2d numpy array where each row corresponds to a query vector
k: (Default 32) the number of neighbors to use when estimating the posterior
depth: (Default 3) the topic depth at which you want to search
Returns:
A dict mapping {topic: posterior probability} for each query
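        === "Vector Search Topics Example"
            A sketch assuming you already have query embeddings with the same dimensionality as the map's embeddings; the 768 below is only a placeholder.
            ``` py
            import numpy as np
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            queries = np.random.rand(1, 768)  # stand-in for a real (1, d) query embedding
            topic_posterior = map.topics.vector_search_topics(queries=queries, k=32, depth=3)
            print(topic_posterior)
            ```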
'''
if queries.ndim != 2:
raise ValueError(
'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).'
)
bytesio = io.BytesIO()
np.save(bytesio, queries)
response = requests.post(
self.project.atlas_api_path + "/v1/project/data/get/embedding/topic",
headers=self.project.header,
json={
'atlas_index_id': self.projection.atlas_index_id,
'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'),
'k': k,
'depth': depth,
},
)
if response.status_code != 200:
raise Exception(response.text)
return response.json()
def __repr__(self) -> str:
return str(self.df)
class AtlasMapEmbeddings:
"""
Atlas Embeddings State
Access latent (high-dimensional) and projected (two-dimensional) embeddings of your datapoints. High dimensional embeddings
are not immediately downloaded when you access the embeddings attribute.
=== "Accessing Embeddings Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.embeddings)
```
=== "Output"
```
id_ x y
0 0A -6.164423 21.517719
1 0g -6.606402 -5.601104
2 0Q -9.206946 7.448542
... ... ... ...
9998 JZQ 2.110881 -12.937058
9999 JZU 7.865006 -6.876243
```
"""
def __init__(self, projection: "AtlasProjection"):
self.projection = projection
self.id_field = self.projection.project.id_field
self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y'])
self.project = projection.project
@property
def df(self):
"""
Pandas dataframe containing information about embeddings of your datapoints.
Includes only the two-dimensional embeddings
"""
return self.tb.to_pandas()
@property
def tb(self) -> pa.Table:
"""
Pyarrow table containing two-dimensional embeddings of each of your data points.
This table is memmapped from the underlying files and is the most efficient way to
access embedding information.
Does not include high-dimensional embeddings.
"""
return self._tb
@property
def projected(self) -> pd.DataFrame:
"""
Two-dimensional embeddings.
These are the points you see in your web browser.
Returns:
Pandas dataframe mapping your datapoints to their two-dimensional embeddings.
"""
return self.df
@property
def latent(self):
# """
# #TODO
# 1. download embeddings and store it in a fixed location on disk (e.g. .nomic directory)
# 2. make sure the embeddings align with the arrow table download order.
# """
raise NotImplementedError(
"Accessing latent embeddings is not yet implemented. You must use map.download_embeddings() method for now."
)
def vector_search(self, queries: np.array = None, ids: List[str] = None, k: int = 5) -> Dict[str, List]:
'''
Performs semantic vector search over data points on your map.
If ids is specified, receive back the most similar data ids in vector space to your input ids.
If queries is specified, receive back the data ids with representations most similar to the query vectors.
You should not specify both queries and ids.
Args:
queries: a 2d numpy array where each row corresponds to a query vector
ids: a list of ids
k: the number of closest data points (neighbors) to return for each input query/data id
Returns:
A tuple with two elements containing the following information:
neighbors: A set of ids corresponding to the nearest neighbors of each query
distances: A set of distances between each query and its neighbors
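        === "Vector Search Example"
            A minimal sketch searching by an existing datum id; '0A' is a placeholder taken from the example output above and should be replaced with an id from your own project.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            neighbors, distances = map.embeddings.vector_search(ids=['0A'], k=5)
            print(neighbors[0], distances[0])
            ```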
'''
        if queries is None and ids is None:
            raise ValueError('You must specify either a list of datum `ids` or a numpy array of `queries`.')
        if queries is not None and ids is not None:
            raise ValueError('You should specify either `queries` or `ids`, but not both.')
max_k = 128
max_queries = 256
if k > max_k:
raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower")
if ids is not None:
if len(ids) > max_queries:
raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.")
        if queries is not None:
            if not isinstance(queries, np.ndarray):
                raise Exception("`queries` must be an instance of np.array.")
            if queries.shape[0] > max_queries:
                raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.")
            if queries.ndim != 2:
                raise ValueError(
                    'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).'
                )
            bytesio = io.BytesIO()
            np.save(bytesio, queries)
if queries is not None:
response = requests.post(
self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding",
headers=self.projection.project.header,
json={
'atlas_index_id': self.projection.atlas_index_id,
'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'),
'k': k,
},
)
else:
response = requests.post(
self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_id",
headers=self.projection.project.header,
json={'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k},
)
if response.status_code == 500:
raise Exception('Cannot perform vector search on your map at this time. Try again later.')
if response.status_code != 200:
raise Exception(response.text)
response = response.json()
return response['neighbors'], response['distances']
def get_embedding_iterator(self) -> Iterable[Tuple[str, str]]:
'''
Iterate through embeddings of your datums.
Returns:
            An iterable yielding (datum_ids, embeddings) batches, one page of datums at a time.
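        === "Iterating Embeddings Example"
            A minimal sketch; each iteration yields one page of ids together with their latent embeddings.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            for datum_ids, embeddings in map.embeddings.get_embedding_iterator():
                print(len(datum_ids), 'embeddings in this page')
            ```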
'''
if self.project.is_locked:
raise Exception('Project is locked! Please wait until the project is unlocked to download embeddings')
offset = 0
limit = EMBEDDING_PAGINATION_LIMIT
while True:
            response = requests.get(
                self.project.atlas_api_path
                + f"/v1/project/data/get/embedding/{self.project.id}/{self.projection.atlas_index_id}/{offset}/{limit}",
                headers=self.project.header,
            )
if response.status_code != 200:
raise Exception(response.text)
content = response.json()
if len(content['datum_ids']) == 0:
break
offset += len(content['datum_ids'])
yield content['datum_ids'], content['embeddings']
def download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool:
'''
Downloads embeddings to the specified save_directory.
        Args:
            save_directory: The directory to save your embeddings.
            num_workers: The number of worker threads used to download embedding shards concurrently.
Returns:
True on success
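        === "Download Embeddings Example"
            A brief sketch; 'embeddings/' is an arbitrary output directory that must already exist.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            map.embeddings.download_embeddings(save_directory='embeddings/', num_workers=10)
            ```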
'''
self.project._latest_project_state()
total_datums = self.project.total_datums
if self.project.is_locked:
raise Exception('Project is locked! Please wait until the project is unlocked to download embeddings')
offset = 0
limit = EMBEDDING_PAGINATION_LIMIT
def download_shard(offset, check_access=False):
response = requests.get(
self.project.atlas_api_path
+ f"/v1/project/data/get/embedding/{self.project.id}/{self.projection.atlas_index_id}/{offset}/{limit}",
headers=self.project.header,
)
if response.status_code != 200:
raise Exception(response.text)
if check_access:
return
shard_name = '{}_{}_{}.feather'.format(self.projection.atlas_index_id, offset, offset + limit)
shard_path = os.path.join(save_directory, shard_name)
try:
content = response.content
is_arrow_format = content[:6] == b"ARROW1" and content[-6:] == b"ARROW1"
if not is_arrow_format:
raise Exception('Expected response to be in Arrow IPC format')
with open(shard_path, 'wb') as f:
f.write(content)
except Exception as e:
logger.error('Shard {} download failed with error: {}'.format(shard_name, e))
download_shard(0, check_access=True)
        # One shard is downloaded per `limit`-sized page, so round the progress total up
        with tqdm(total=(total_datums + limit - 1) // limit) as pbar:
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(download_shard, cur_offset): cur_offset
for cur_offset in range(0, total_datums, limit)
}
for future in concurrent.futures.as_completed(futures):
_ = future.result()
pbar.update(1)
return True
def __repr__(self) -> str:
return str(self.df)
class AtlasMapTags:
"""
Atlas Map Tag State
    Tags are shared across all maps in your AtlasProject. You can manipulate tags by filtering over
    the associated pandas dataframe.
=== "Accessing Tags Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.tags)
```
=== "Output"
```
id_ oil search_engines
0 0A 0 0
1 0g 0 0
2 0Q 0 0
3 0w 0 0
4 1A 1 0
... ... ... ...
9998 JZQ 0 0
9999 JZU 0 0
```
"""
def __init__(self, projection: "AtlasProjection"):
self.projection = projection
self.project = projection.project
self.id_field = self.projection.project.id_field
self._tb: pa.Table = projection._fetch_tiles().select([self.id_field])
@property
def df(self) -> pd.DataFrame:
"""
Pandas dataframe mapping each data point to its tags.
=== "Accessing Tags Example"
``` py
from nomic import AtlasProject
project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.tags.df)
```
=== "Output"
```
id_ oil search_engines
0 0A 0 0
1 0g 0 0
2 0Q 0 0
3 0w 0 0
4 1A 1 0
... ... ... ...
9998 JZQ 0 0
9999 JZU 0 0
```
"""
id_frame = self._tb.to_pandas()
tag_to_datums = self.get_tags()
        # encoded maps each tag to a multi-hot column with a 1 for every row that carries that tag
encoded = {key: [] for key in list(tag_to_datums.keys())}
for id in id_frame[self.id_field]:
for key in encoded:
if id in tag_to_datums[key]:
encoded[key].append(1)
else:
encoded[key].append(0)
        tag_frame = pd.DataFrame(encoded)
return pd.concat([id_frame, tag_frame], axis=1)
def get_tags(self) -> Dict[str, List[str]]:
'''
        Retrieves all tags created in the web browser for this map.
        Returns:
            A dictionary mapping each tag to the ids of the data points that carry it.
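        === "Reading Tags Example"
            A minimal sketch, assuming some tags have already been created in the web browser:
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            tag_to_datums = map.tags.get_tags()
            for tag, datum_ids in tag_to_datums.items():
                print(tag, len(datum_ids))
            ```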
'''
# now get the tags
datums_and_tags = requests.post(
self.project.atlas_api_path + '/v1/project/tag/read/all_by_datum',
headers=self.project.header,
json={
'project_id': self.project.id,
},
).json()['results']
label_to_datums = {}
for item in datums_and_tags:
for label in item['labels']:
if label not in label_to_datums:
label_to_datums[label] = set()
label_to_datums[label].add(item['datum_id'])
return label_to_datums
def add(self, ids: List[str], tags: List[str]):
'''
Adds tags to datapoints.
Args:
ids: The datum ids you want to tag
tags: A list containing the tags you want to apply to these data points.
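        === "Adding Tags Example"
            A minimal sketch; the ids and the tag name below are placeholders and should come from your own project.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            map.tags.add(ids=['0A', '0g'], tags=['reviewed'])
            ```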
'''
assert isinstance(ids, list), 'ids must be a list of strings'
assert isinstance(tags, list), 'tags must be a list of strings'
colname = json.dumps(
{
'project_id': self.project.id,
'atlas_index_id': self.projection.atlas_index_id,
'type': 'datum_id',
'tags': tags,
}
)
payload_table = pa.table([pa.array(ids, type=pa.string())], [colname])
buffer = io.BytesIO()
writer = ipc.new_file(buffer, payload_table.schema, options=ipc.IpcWriteOptions(compression='zstd'))
writer.write_table(payload_table)
writer.close()
payload = buffer.getvalue()
headers = self.project.header.copy()
headers['Content-Type'] = 'application/octet-stream'
response = requests.post(self.project.atlas_api_path + "/v1/project/tag/add", headers=headers, data=payload)
if response.status_code != 200:
raise Exception("Failed to add tags")
def remove(self, ids: List[str], tags: List[str], delete_all: bool = False) -> bool:
'''
Deletes the specified tags from the given data points.
Args:
ids: The datum_ids to delete tags from.
tags: The list of tags to delete from the data points. Each tag will be applied to all data points in `ids`.
delete_all: If true, ignores ids parameter and deletes all specified tags from all data points.
Returns:
True on success
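        === "Removing Tags Example"
            A minimal sketch; the id and the tag name below are placeholders.
            ``` py
            from nomic import AtlasProject
            project = AtlasProject(name='My Project')
            map = project.maps[0]
            map.tags.remove(ids=['0A'], tags=['reviewed'])
            ```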
'''
assert isinstance(ids, list), 'datum_ids must be a list of strings'
assert isinstance(tags, list), 'tags must be a list of strings'
colname = json.dumps(
{
'project_id': self.project.id,
'atlas_index_id': self.projection.atlas_index_id,
'type': 'datum_id',
'tags': tags,
'delete_all': delete_all,
}
)
payload_table = pa.table([pa.array(ids, type=pa.string())], [colname])
buffer = io.BytesIO()
writer = ipc.new_file(buffer, payload_table.schema, options=ipc.IpcWriteOptions(compression='zstd'))
writer.write_table(payload_table)
writer.close()
payload = buffer.getvalue()
headers = self.project.header.copy()
headers['Content-Type'] = 'application/octet-stream'
response = requests.post(self.project.atlas_api_path + "/v1/project/tag/delete", headers=headers, data=payload)
        if response.status_code != 200:
            raise Exception("Failed to delete tags")
        return True
def __repr__(self) -> str:
return str(self.df)