major refactorings WIP

This commit is contained in:
Luke VanderHart
2022-11-24 00:17:24 -05:00
parent bbfde39cd3
commit d9dfd089fb
77 changed files with 1074 additions and 1524 deletions

4
.gitignore vendored
View File

@@ -5,7 +5,7 @@
__pycache__
**/__pycache__
chroma-server/chroma_logs.log
*.log
**/data__nogit
@@ -14,4 +14,4 @@ chroma-server/chroma_logs.log
index_data
/index_data
chroma_logs.log
venv

View File

@@ -1,3 +0,0 @@
### Using core code/modules outside Docker
Adding `127.0.0.1 clickhouse` to /etc/hosts is only necessary if you want to use the app's core code outside of Docker while still running ClickHouse inside Docker. Inside Docker, ClickHouse is reachable under the network name `clickhouse`; outside, Docker containers are simply mapped to ports on 127.0.0.1. There are two ways to use the core code outside of Docker: (1) update the URL from `clickhouse` to `127.0.0.1` inside clickhouse.py (note that chroma-server inside Docker *will* then break), or (2) map requests to `clickhouse` on your host machine over to Docker's `127.0.0.1`. Adding the line to /etc/hosts fulfills option (2). This is outside the normal usage pattern, but it's useful to know how and why it works.
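A quick way to sanity-check the mapping from your host (standard library only; assumes the /etc/hosts entry above is in place):

```
import socket

# Should print 127.0.0.1 once the /etc/hosts entry resolves `clickhouse` locally.
print(socket.gethostbyname("clickhouse"))
```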

View File

@@ -1,5 +0,0 @@
black:
black --fast chroma-server chroma-client
check_black:
black --check --fast chroma-server chroma-client

View File

@@ -1,54 +1,47 @@
# Chroma
# Chroma Server
This repository is a monorepo containing all the core components of
the Chroma product.
## Development
Contents:
- `/doc` - Project documentation
- `/chroma-client` - Python client for Chroma
- `/chroma-server` - FastAPI server used as the backend for Chroma client
### Get up and running on Linux
No requirements
```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/effcbac05021e863bbd634f4b7d0283d/raw/4d38b150809d6ccbc379f88433cadd86c81d32cd/chroma_setup.sh)"
python3 chroma/bin/test.py
```
### Get up and running on Mac
Requirements
- git
- Docker & `docker-compose`
- pip
Set up a virtual environment and install the project's requirements
and dev requirements:
```
/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/27a3cbb28e6521c811da6398346cd35f/raw/55c2d82870436431120a9446b47f19b72d88fa31/chroma_setup_mac.sh)"
python3 chroma/bin/test.py
python3 -m venv venv # Only need to do this once
source venv/bin/activate # Do this each time you use a new shell for the project
pip install -r requirements.txt
pip install -r requirements_dev.txt
```
* These URLs will be swapped out for links in this repo once it is live
To run tests, run `bin/test`. This will run the test suite inside a
docker compose cluster, with the database available, and clean up when complete.
To run the server locally, in development mode, run `uvicorn chroma_server:app --reload`
### You should see something like
## Docker
To build the docker image locally, run `bin/build`.
The version tag of the build is generated by the `bin/version` script,
which uses the `setuptools_scm` library. For full documentation, see
the
[documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/).
In brief, version numbers are generated as follows:
- If the current git head is tagged, the version number is exactly the
tag (e.g., `0.0.1`).
- If the current git head is a clean checkout, but is not tagged,
the version number is a patch version increment of the most recent
tag, plus `devN` where N is the number of commits since the most
recent tag. For example, if there have been 5 commits since the
`0.0.1` tag, the generated version will be `0.0.2-dev5`.
- If the current head is not a clean checkout, a `-dirty` local
version will be appended to the version number. For example,
`0.0.2-dev5-dirty`.
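As a sketch of how the same version string can be queried programmatically (assuming `setuptools_scm` is installed and this is run from the repo root; the exact rendering may differ slightly, e.g. PEP 440 renders it as `0.0.2.dev5`):
```
from setuptools_scm import get_version

# Mirrors the pyproject settings; prints something like "0.0.2.dev5"
# on a clean but untagged checkout, with a "-dirty" suffix when unclean.
print(get_version(root=".", local_scheme="dirty-tag"))
```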
To run the image, first use `docker images` to see which images and tags are available:
```
Getting heartbeat to verify the server is up
{'nanosecond heartbeat': 1667865642509760965000}
Logging embeddings into the database
Generating the index
True
Running a nearest neighbor search
{'ids': ['11540ca6-ebbc-4c81-8299-108d8c47c88c'], 'embeddings': [['sample_space', '11540ca6-ebbc-4c81-8299-108d8c47c88c', [1.0, 2.0, 3.0, 4.0, 5.0], '/images/1', 'training', None, 'spoon']], 'distances': [0.0]}
Success! Everything worked!
docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:<tag name, e.g. 0.0.2-dirty>
```
### Run in-memory Chroma
```
cd chroma-server
CHROMA_MODE="in-memory" uvicorn chroma_server.api:app --reload --log-level=debug
```
This will expose the internal app at `localhost:8000`
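A minimal sketch to confirm the server is reachable (assumes the in-memory server above is running on port 8000):
```
import requests

# The root route returns the heartbeat payload, e.g. {'nanosecond heartbeat': ...}.
print(requests.get("http://localhost:8000/api/v1").json())
```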

View File

@@ -4,16 +4,16 @@
from chroma_client import Chroma
chroma = Chroma()
chroma.set_model_space('sample_space')
chroma.set_model_space("sample_space")
print("Getting heartbeat to verify the server is up")
print(chroma.heartbeat())
print("Logging embeddings into the database")
chroma.add(
[[1,2,3,4,5], [5,4,3,2,1], [10,9,8,7,6]],
["/images/1", "/images/2", "/images/3"],
["training", "training", "training"],
['spoon', 'knife', 'fork']
[[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [10, 9, 8, 7, 6]],
["/images/1", "/images/2", "/images/3"],
["training", "training", "training"],
["spoon", "knife", "fork"],
)
print("count")
@@ -24,6 +24,6 @@ print("Generating the index")
print(chroma.create_index())
print("Running a nearest neighbor search")
print(chroma.get_nearest_neighbors([1,2,3,4,5], 1))
print(chroma.get_nearest_neighbors([1, 2, 3, 4, 5], 1))
print("Success! Everything worked!")

View File

@@ -1,14 +0,0 @@
# general things to ignore
build/
dist/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~
venv
# due to using tox and pytest
.tox
.cache

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1 +0,0 @@
# Chroma Client

View File

@@ -1,7 +0,0 @@
# Dependencies to test, build, and release the project
build
pytest
setuptools_scm
httpx
-e . # Include transitive dependencies and code from the current project

View File

@@ -1,37 +0,0 @@
[build-system]
requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"
[project]
name = "chroma_client"
dynamic = ["version"]
dependencies = [
'pyarrow ~= 9.0',
'requests ~= 2.28',
]
authors = [
{ name="Jeff Huber", email="jeff@trychroma.com" },
{ name="Anton Troynikov", email="anton@trychroma.com" }
]
description = "Chroma."
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache-2.0",
"Operating System :: OS Independent",
]
[project.urls]
"Homepage" = "https://github.com/chroma-core/chroma"
"Bug Tracker" = "https://github.com/chroma-core/chroma/issues"
[tool.pytest.ini_options]
pythonpath = [
"src"
]
[tool.setuptools_scm]
root=".."
local_scheme="dirty-tag"

View File

@@ -1 +0,0 @@
from .client import *

View File

@@ -1,18 +0,0 @@
# use pytest to test chroma_client
from chroma_client import Chroma
import pytest
import time
from httpx import AsyncClient
# from ..api import app # this won't work because the file was moved
@pytest.fixture
def anyio_backend():
return "asyncio"
def test_init():
chroma = Chroma()
assert chroma._api_url == "http://localhost:8000/api/v1"

View File

@@ -1,13 +0,0 @@
# general things to ignore
build/
dist/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~
venv
# due to using tox and pytest
.cache

View File

@@ -1,3 +0,0 @@
disable_anonymized_telemetry=False
environment=development
telemetry_anonymized_uuid=f80b11fc-1c5a-4a90-ba35-8c3a3c5371cc

View File

@@ -1,13 +0,0 @@
# general things to ignore
build/
dist/
*.egg-info/
*.egg
*.py[cod]
__pycache__/
*.so
*~
venv
# due to using tox and pytest
.cache

View File

@@ -1,11 +0,0 @@
# pull official base image
FROM --platform=linux/amd64 python:3.10
WORKDIR /chroma-server
COPY ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt
# copy project
COPY . .

View File

@@ -1,47 +0,0 @@
# Chroma Server
## Development
Set up a virtual environment and install the project's requirements
and dev requirements:
```
python3 -m venv venv # Only need to do this once
source venv/bin/activate # Do this each time you use a new shell for the project
pip install -r requirements.txt
pip install -r requirements_dev.txt
```
To run tests, run `bin/test`. This will run the test suite inside a
docker compose cluster, with the database available, and clean up when complete.
To run the server locally, in development mode, run `uvicorn chroma_server:app --reload`
## Docker
To build the docker image locally, run `bin/build`.
The version tag of the build is generated by the `bin/version` script,
which uses the `setuptools_scm` library. For full documentation, see
the
[documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/).
In brief, version numbers are generated as follows:
- If the current git head is tagged, the version number is exactly the
tag (e.g., `0.0.1`).
- If the current git head is a clean checkout, but is not tagged,
the version number is a patch version increment of the most recent
tag, plus `devN` where N is the number of commits since the most
recent tag. For example, if there have been 5 commits since the
`0.0.1` tag, the generated version will be `0.0.2-dev5`.
- If the current head is not a clean checkout, a `-dirty` local
version will be appended to the version number. For example,
`0.0.2-dev5-dirty`.
To run the image, first use `docker images` to see which images and tags are available:
```
docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:<tag name, e.g. 0.0.2-dirty>
```
This will expose the internal app at `localhost:8000`

View File

@@ -1,3 +0,0 @@
from chroma_server.api import app
app = app

View File

@@ -1,8 +0,0 @@
import random
def rand_bisectional_subsample(data):
"""
Randomly subsample a pandas DataFrame to half its size, sampling with replacement with a fixed random state.
"""
return data.sample(frac=0.5, replace=True, random_state=1)
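A usage sketch (hypothetical data; the `.sample` call implies the input is a pandas DataFrame, not a plain list):
```
import pandas as pd

df = pd.DataFrame({"x": range(10)})
half = rand_bisectional_subsample(df)  # 5 rows, sampled with replacement
```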

View File

@@ -1,89 +0,0 @@
import numpy as np
import json
import ast
def class_distances(data):
"""'
This is all very subject to change, so essentially just copy and paste from what we had before
"""
return False
# def unpack_annotations(embeddings):
# annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings]
# annotations = [annotation for annotation_list in annotations for annotation in annotation_list]
# # Unpack embedding data
# embeddings = [embedding["embedding"] for embedding in embeddings]
# embedding_vectors_by_category = {}
# for embedding_annotation_pair in zip(embeddings, annotations):
# data = np.array(embedding_annotation_pair[0])
# category = embedding_annotation_pair[1]['category_id']
# if category in embedding_vectors_by_category.keys():
# embedding_vectors_by_category[category] = np.append(
# embedding_vectors_by_category[category], data[np.newaxis, :], axis=0
# )
# else:
# embedding_vectors_by_category[category] = data[np.newaxis, :]
# return embedding_vectors_by_category
# # Get the training embeddings. This is the set of embeddings belonging to datapoints of the training dataset, and to the first created embedding set.
# object_embedding_vectors_by_category = unpack_annotations(data.to_dict('records'))
# inv_covs_by_category = {}
# means_by_category = {}
# for category, embeddings in object_embedding_vectors_by_category.items():
# print(f"Computing mean and covariance for label category {category}")
# # Compute the mean and inverse covariance for computing MHB distance
# print(f"category: {category} samples: {embeddings.shape[0]}")
# if embeddings.shape[0] < (embeddings.shape[1] + 1):
# print(f"not enough samples for stable covariance in category {category}")
# continue
# cov = np.cov(embeddings.transpose())
# try:
# inv_cov = np.linalg.inv(cov)
# except np.linalg.LinAlgError as err:
# print(f"covariance for category {category} is singular")
# continue
# mean = np.mean(embeddings, axis=0)
# inv_covs_by_category[category] = inv_cov
# means_by_category[category] = mean
# target_datapoints = data.to_dict('records') #+ panda_train_table.to_dict('records')
# output_distances = []
# # Process each datapoint's inferences individually. This is going to be very slow.
# # This is because there is no way to grab the corresponding metadata off the datapoint.
# # We could instead put it on the embedding directly ?
# inference_metadata = {}
# quality_scores = []
# for idx, datapoint in enumerate(target_datapoints):
# inferences = json.loads(datapoint['infer'])["annotations"]
# embeddings = [datapoint["embedding"]]
# for i in range(len(inferences)):
# emb_data = embeddings[i]
# category = inferences[i]["category_id"]
# if not category in inv_covs_by_category.keys():
# output_distances.append({"distance": None, "id": datapoint["id"]})
# continue
# mean = means_by_category[category]
# inv_cov = inv_covs_by_category[category]
# delta = np.array(emb_data) - mean
# squared_mhb = np.sum((delta * np.matmul(inv_cov, delta)), axis=0)
# if squared_mhb < 0:
# print(f"squared distance for category {category} is negative")
# output_distances.append({"distance": None, "id": datapoint["id"]})
# continue
# distance = np.sqrt(squared_mhb)
# quality_scores.append([distance, datapoint])
# inference_metadata[datapoint["input_uri"]] = distance
# output_distances.append({"id": datapoint["id"], "distance": distance})
# if (len(inferences) == 0):
# raise Exception("No inferences found for datapoint")
# return output_distances

View File

@@ -1,55 +0,0 @@
import time
import os
from chroma_server.db.clickhouse import Clickhouse
from chroma_server.db.duckdb import DuckDB
from chroma_server.index.hnswlib import Hnswlib
from chroma_server.utils.error_reporting import init_error_reporting
from chroma_server.utils.telemetry.capture import Capture
from fastapi import FastAPI
chroma_telemetry = Capture()
chroma_telemetry.capture('server-start')
init_error_reporting()
from chroma_server.routes import ChromaRouter
# current valid modes are 'in-memory' and 'docker'; it defaults to 'docker'
chroma_mode = os.getenv('CHROMA_MODE', 'docker')
if chroma_mode == 'in-memory':
db = DuckDB
else:
db = Clickhouse
ann_index = Hnswlib
app = FastAPI(debug=True)
# init db and index
app._db = db()
app._ann_index = ann_index()
if chroma_mode == 'in-memory':
filesystem_location = os.getcwd()
# create a dir
if not os.path.exists(filesystem_location + '/.chroma'):
os.makedirs(filesystem_location + '/.chroma')
if not os.path.exists(filesystem_location + '/.chroma/index_data'):
os.makedirs(filesystem_location + '/.chroma/index_data')
# specify where to save and load data from
app._db.set_save_folder(filesystem_location + '/.chroma')
app._ann_index.set_save_folder(filesystem_location + '/.chroma/index_data')
print("Initializing Chroma...")
print("Data will be saved to: " + filesystem_location + '/.chroma')
# if the db exists, load it
if os.path.exists(filesystem_location + '/.chroma/chroma.parquet'):
print(f"Existing database found at {filesystem_location + '/.chroma/chroma.parquet'}. Loading...")
app._db.load()
router = ChromaRouter(app=app, db=db, ann_index=ann_index)
app.include_router(router.router)

View File

@@ -1,60 +0,0 @@
import os
from fastapi import FastAPI
from chroma_server.routes import ChromaRouter
from chroma_server.index.hnswlib import Hnswlib
from chroma_server.db.duckdb import DuckDB
# we import types here so that the user can import them from here
from chroma_server.types import (
ProcessEmbedding, AddEmbedding, FetchEmbedding,
QueryEmbedding, CountEmbedding, DeleteEmbedding,
RawSql, Results, SpaceKeyInput)
core = FastAPI(debug=True)
core._db = DuckDB()
core._ann_index = Hnswlib()
router = ChromaRouter(app=core, db=DuckDB, ann_index=Hnswlib)
core.include_router(router.router)
def init(filesystem_location: str = None):
if filesystem_location is None:
filesystem_location = os.getcwd()
# create a dir
if not os.path.exists(filesystem_location + '/.chroma'):
os.makedirs(filesystem_location + '/.chroma')
if not os.path.exists(filesystem_location + '/.chroma/index_data'):
os.makedirs(filesystem_location + '/.chroma/index_data')
# specify where to save and load data from
core._db.set_save_folder(filesystem_location + '/.chroma')
core._ann_index.set_save_folder(filesystem_location + '/.chroma/index_data')
print("Initializing Chroma...")
print("Data will be saved to: " + filesystem_location + '/.chroma')
# if the db exists, load it
if os.path.exists(filesystem_location + '/.chroma/chroma.parquet'):
print(f"Existing database found at {filesystem_location + '/.chroma/chroma.parquet'}. Loading...")
core._db.load()
core.init = init
# headless mode
core.heartbeat = router.root
core.add = router.add
core.count = router.count
core.fetch = router.fetch
core.reset = router.reset
core.delete = router.delete
core.get_nearest_neighbors = router.get_nearest_neighbors
core.raw_sql = router.raw_sql
core.create_index = router.create_index
# these as currently constructed require celery
# chroma_core.process = process
# chroma_core.get_status = get_status
# chroma_core.get_results = get_results

View File

@@ -1,27 +0,0 @@
from abc import abstractmethod
class Database:
@abstractmethod
def __init__(self):
pass
@abstractmethod
def add(self, model_space, embedding, input_uri, dataset=None, custom_quality_score=None, inference_class=None, label_class=None):
pass
@abstractmethod
def count(self, model_space=None):
pass
@abstractmethod
def fetch(self, where={}, sort=None, limit=None):
pass
@abstractmethod
def get_by_ids(self, ids):
pass
@abstractmethod
def reset(self):
pass

View File

@@ -1,118 +0,0 @@
import duckdb
import uuid
import time
from chroma_server.db.clickhouse import Clickhouse, db_array_schema_to_clickhouse_schema, EMBEDDING_TABLE_SCHEMA, RESULTS_TABLE_SCHEMA, db_schema_to_keys
import pandas as pd
import numpy as np
def clickhouse_to_duckdb_schema(table_schema):
for item in table_schema:
if 'embedding' in item:
item['embedding'] = 'REAL[]'
# capitalize the key
item[list(item.keys())[0]] = item[list(item.keys())[0]].upper()
if 'NULLABLE' in item[list(item.keys())[0]]:
item[list(item.keys())[0]] = item[list(item.keys())[0]].replace('NULLABLE(', '').replace(')', '')
if 'UUID' in item[list(item.keys())[0]]:
item[list(item.keys())[0]] = 'STRING'
if 'FLOAT64' in item[list(item.keys())[0]]:
item[list(item.keys())[0]] = 'REAL'
return table_schema
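# Illustration with hypothetical schema entries: {'uuid': 'UUID'} becomes
# {'uuid': 'STRING'}, {'dataset': 'Nullable(String)'} becomes {'dataset': 'STRING'},
# and the 'embedding' entry is forced to 'REAL[]' whatever its ClickHouse type was.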
class DuckDB(Clickhouse):
_save_folder = None
# duckdb has different types, so we want to convert the clickhouse schema to duckdb schema
def _create_table_embeddings(self):
self._conn.execute(f'''CREATE TABLE embeddings (
{db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(EMBEDDING_TABLE_SCHEMA))}
) ''')
def _create_table_results(self):
self._conn.execute(f'''CREATE TABLE results (
{db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(RESULTS_TABLE_SCHEMA))}
) ''')
# duckdb has a different way of connecting to the database
def __init__(self):
self._conn = duckdb.connect()
self._create_table_embeddings()
self._create_table_results()
def set_save_folder(self, path):
self._save_folder = path
def get_save_folder(self):
return self._save_folder
# the executemany syntax differs from clickhouse: duckdb uses (?,?,...) placeholders
def add(self, model_space, embedding, input_uri, dataset=None, inference_class=None, label_class=None):
data_to_insert = []
for i in range(len(embedding)):
data_to_insert.append([model_space[i], str(uuid.uuid4()), embedding[i], input_uri[i], dataset[i], inference_class[i], (label_class[i] if label_class is not None else None)])
insert_string = "model_space, uuid, embedding, input_uri, dataset, inference_class, label_class"
self._conn.executemany(f'''
INSERT INTO embeddings ({insert_string}) VALUES (?,?,?,?,?,?,?)''', data_to_insert)
def count(self, model_space=None):
return self._count(model_space=model_space).fetchall()[0][0]
def _fetch(self, where={}, columnar=False):
val = self._conn.execute(f'''SELECT {db_schema_to_keys()} FROM embeddings {where}''').fetchall()
if columnar:
val = list(zip(*val))
return val
def _delete(self, where={}):
uuids_deleted = self._conn.execute(f'''SELECT uuid FROM embeddings {where}''').fetchall()
self._conn.execute(f'''
DELETE FROM
embeddings
{where}
''').fetchall()[0]
return uuids_deleted
def get_by_ids(self, ids: list):
# select from duckdb table where ids are in the list
if not isinstance(ids, list):
raise Exception("ids must be a list")
if not ids:
# create an empty pandas dataframe
return pd.DataFrame()
return self._conn.execute(f'''
SELECT
{db_schema_to_keys()}
FROM
embeddings
WHERE
uuid IN ({','.join([("'" + str(x) + "'") for x in ids])})
''').fetchall()
def persist(self):
'''
Persist the database to disk
'''
if self._conn is None:
return
self._conn.execute(f'''
COPY
(SELECT * FROM embeddings)
TO '{self._save_folder}/chroma.parquet'
(FORMAT PARQUET);
''')
def load(self):
'''
Load the database from disk
'''
path = self._save_folder + "/chroma.parquet"
self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');")
def __del__(self):
self.persist()
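A hypothetical persistence round trip using the methods above (argument lists follow the `add` signature; the parquet path is derived from the save folder):
```
db = DuckDB()
db.set_save_folder("/tmp/chroma_demo")
db.add(["space"], [[1.0, 2.0]], ["/images/1"], ["training"], ["spoon"], None)
db.persist()  # writes /tmp/chroma_demo/chroma.parquet

db2 = DuckDB()
db2.set_save_folder("/tmp/chroma_demo")
db2.load()
print(db2.count("space"))  # -> 1
```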

View File

@@ -1,27 +0,0 @@
from abc import abstractmethod
class Index:
@abstractmethod
def __init__(self):
pass
@abstractmethod
def run(self, batch):
pass
@abstractmethod
def fetch(self, query):
pass
@abstractmethod
def delete_batch(self, batch):
pass
@abstractmethod
def persist(self):
pass
@abstractmethod
def load(self):
pass

View File

@@ -1,12 +0,0 @@
import logging
def setup_logging():
logging.basicConfig(filename="chroma_logs.log")
logger = logging.getLogger("Chroma")
logger.setLevel(logging.DEBUG)
logger.debug("Logger created")
return logger
logger = setup_logging()

View File

@@ -1,204 +0,0 @@
from celery.result import AsyncResult
import time
import os
from chroma_server.db.clickhouse import Clickhouse, get_col_pos
from chroma_server.db.duckdb import DuckDB
from chroma_server.index.hnswlib import Hnswlib
from chroma_server.types import (AddEmbedding, CountEmbedding, DeleteEmbedding,
FetchEmbedding, ProcessEmbedding,
QueryEmbedding, RawSql, Results,
SpaceKeyInput)
from chroma_server.utils.error_reporting import init_error_reporting
from chroma_server.utils.telemetry.capture import Capture
from chroma_server.worker import heavy_offline_analysis
from fastapi import FastAPI, status, APIRouter
from fastapi.responses import JSONResponse
class ChromaRouter():
_app = None
_db = None
_ann_index = None
_celery_enabled = True
def __init__(self, app: FastAPI, db, ann_index: Hnswlib):
self._app = app
self._db = db
self._ann_index = ann_index
self.router = APIRouter()
self.router.add_api_route("/api/v1", self.root, methods=["GET"])
self.router.add_api_route("/api/v1/add", self.add, methods=["POST"], status_code=status.HTTP_201_CREATED)
self.router.add_api_route("/api/v1/fetch", self.fetch, methods=["POST"])
self.router.add_api_route("/api/v1/delete", self.delete, methods=["POST"])
self.router.add_api_route("/api/v1/count", self.count, methods=["GET"])
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"])
self.router.add_api_route("/api/v1/get_nearest_neighbors", self.get_nearest_neighbors, methods=["POST"])
self.router.add_api_route("/api/v1/create_index", self.create_index, methods=["POST"])
self.router.add_api_route("/api/v1/process", self.process, methods=["POST"])
self.router.add_api_route("/api/v1/get_status", self.get_status, methods=["POST"])
self.router.add_api_route("/api/v1/get_results", self.get_results, methods=["POST"])
# celery requires Clickhouse; disable it when running on DuckDB (in-memory mode)
if issubclass(self._db, DuckDB):
self._celery_enabled = False
# API Endpoints
def root(self):
'''Heartbeat endpoint'''
return {"nanosecond heartbeat": int(1000 * time.time_ns())}
def add(self, new_embedding: AddEmbedding):
'''Save batched embeddings to database'''
number_of_embeddings = len(new_embedding.embedding)
if isinstance(new_embedding.model_space, str):
model_space = [new_embedding.model_space] * number_of_embeddings
elif len(new_embedding.model_space) == 1:
model_space = [new_embedding.model_space[0]] * number_of_embeddings
else:
model_space = new_embedding.model_space
if isinstance(new_embedding.dataset, str):
dataset = [new_embedding.dataset] * number_of_embeddings
elif len(new_embedding.dataset) == 1:
dataset = [new_embedding.dataset[0]] * number_of_embeddings
else:
dataset = new_embedding.dataset
self._app._db.add(
model_space,
new_embedding.embedding,
new_embedding.input_uri,
dataset,
new_embedding.inference_class,
new_embedding.label_class
)
return {"response": "Added records to database"}
def fetch(self, embedding: FetchEmbedding):
'''
Fetches embeddings from the database
- enables filtering by where, sorting by key, and limiting the number of results
'''
return self._app._db.fetch(embedding.where, embedding.sort, embedding.limit, embedding.offset)
def delete(self, embedding: DeleteEmbedding):
'''
Deletes embeddings from the database
- enables filtering by where
'''
deleted_uuids = self._app._db.delete(embedding.where)
if len(embedding.where) == 1:
if 'model_space' in embedding.where:
self._app._ann_index.delete(embedding.where['model_space'])
deleted_uuids = [uuid[0] for uuid in deleted_uuids] # de-tuple
self._app._ann_index.delete_from_index(embedding.where['model_space'], deleted_uuids)
return deleted_uuids
# @app.get("/api/v1/count")
def count(self, model_space: str = None):
'''
Returns the number of records in the database
'''
return {"count": self._app._db.count(model_space=model_space)}
def reset(self):
'''
Reset the database and index - WARNING: Destructive!
'''
index_save_folder = self._app._ann_index.get_save_folder()
db_save_folder = self._app._db.get_save_folder()
self._app._db = self._db()
self._app._db.reset()
self._app._ann_index.reset() # this has to come first I think
self._app._ann_index = self._ann_index()
self._app._ann_index.set_save_folder(index_save_folder)
self._app._db.set_save_folder(db_save_folder)
return True
def get_nearest_neighbors(self, embedding: QueryEmbedding):
'''
return the distance, database ids, and embeddings themselves for the input embedding
'''
if embedding.where.get('model_space') is None:
return {"error": "model_space is required"}
results = self._app._db.fetch(embedding.where)
ids = [str(item[get_col_pos('uuid')]) for item in results]
uuids, distances = self._app._ann_index.get_nearest_neighbors(embedding.where['model_space'], embedding.embedding, embedding.n_results, ids)
return {
"ids": uuids,
"embeddings": self._app._db.get_by_ids(uuids),
"distances": distances.tolist()[0]
}
def raw_sql(self, raw_sql: RawSql):
return self._app._db.raw_sql(raw_sql.raw_sql)
def create_index(self, process_embedding: ProcessEmbedding):
'''
Currently generates an index for the embedding db
'''
fetch = self._app._db.fetch({"model_space": process_embedding.model_space}, columnar=True)
# chroma_telemetry.capture('created-index-run-process', {'n': len(fetch[2])})
self._app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh
def process(self, process_embedding: ProcessEmbedding):
'''
Currently generates an index for the embedding db
'''
if not self._celery_enabled:
raise Exception("in-memory mode does not process because it relies on celery and redis")
fetch = self._app._db.fetch({"model_space": process_embedding.model_space}, columnar=True)
# chroma_telemetry.capture('created-index-run-process', {'n': len(fetch[2])})
self._app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh
task = heavy_offline_analysis.delay(process_embedding.model_space)
# chroma_telemetry.capture('heavy-offline-analysis')
return JSONResponse({"task_id": task.id})
def get_status(self, task_id):
if not self._celery_enabled:
raise Exception("in-memory mode does not process because it relies on celery and redis")
task_result = AsyncResult(task_id)
result = {
"task_id": task_id,
"task_status": task_result.status,
"task_result": task_result.result
}
return JSONResponse(result)
def get_results(self, results: Results):
if not self._celery_enabled:
raise Exception("in-memory mode does not process because it relies on celery and redis")
# if there is no index, generate one
if not self._app._ann_index.has_index(results.model_space):
fetch = self._app._db.fetch({"model_space": results.model_space}, columnar=True)
# chroma_telemetry.capture('run-process', {'n': len(fetch[2])})
print("Generating index for model space: ", results.model_space, " with ", len(fetch[2]), " embeddings")
self._app._ann_index.run(results.model_space, fetch[1], fetch[2]) # more magic number, ugh
print("Done generating index for model space: ", results.model_space)
# if there are no results, generate them
print("self._app._db.count_results(results.model_space): ", self._app._db.count_results(results.model_space))
if self._app._db.count_results(results.model_space) == 0:
print("starting heavy offline analysis")
task = heavy_offline_analysis(results.model_space)
print("ending heavy offline analysis")
return self._app._db.return_results(results.model_space, results.n_results)
else:
return self._app._db.return_results(results.model_space, results.n_results)

View File

@@ -1,181 +0,0 @@
import pytest
import time
from httpx import AsyncClient
from ..api import app
@pytest.fixture
def anyio_backend():
return "asyncio"
@pytest.mark.anyio
async def test_root():
async with AsyncClient(app=app, base_url="http://test") as ac:
response = await ac.get("/api/v1")
assert response.status_code == 200
assert isinstance(response.json()["nanosecond heartbeat"], int)
async def post_batch_records(ac):
return await ac.post(
"/api/v1/add",
json={
"embedding": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": ["training", "training"],
"inference_class": ["knife", "person"],
"model_space": ["test_space", "test_space"],
"label_class": ["person", "person"],
},
)
async def post_batch_records_minimal(ac):
return await ac.post(
"/api/v1/add",
json={
"embedding": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"input_uri": ["https://example.com", "https://example.com"],
"dataset": "training",
"inference_class": ["person", "person"],
"model_space": "test_space"
}, # label_class left off on purpose
)
@pytest.mark.anyio
async def test_add_batch():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
response = await post_batch_records(ac)
assert response.status_code == 201
assert response.json() == {"response": "Added records to database"}
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
@pytest.mark.anyio
async def test_add_batch_minimal():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
response = await post_batch_records_minimal(ac)
assert response.status_code == 201
assert response.json() == {"response": "Added records to database"}
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
@pytest.mark.anyio
async def test_fetch_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
params = {"where": {"model_space": "test_space"}}
response = await ac.post("/api/v1/fetch", json=params)
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.anyio
async def test_count_from_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset") # reset db
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.status_code == 200
assert response.json() == {"count": 2}
@pytest.mark.anyio
async def test_reset_db():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
response = await ac.post("/api/v1/reset")
assert response.json() == True
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 0}
@pytest.mark.anyio
async def test_get_nearest_neighbors():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
await ac.post("/api/v1/create_index", json={"model_space": "test_space"})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}}
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 1
@pytest.mark.anyio
async def test_get_nearest_neighbors_filter():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
await ac.post("/api/v1/create_index", json={"model_space": "test_space"})
response = await ac.post(
"/api/v1/get_nearest_neighbors",
json={
"embedding": [1.1, 2.3, 3.2],
"n_results": 1,
"where":{
"dataset": "training",
"inference_class": "monkey",
"model_space": "test_space",
}
},
)
assert response.status_code == 200
assert len(response.json()["ids"]) == 0
@pytest.mark.anyio
async def test_process():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.post("/api/v1/create_index", json={"model_space": "test_space"})
assert response.status_code == 200
# test delete
@pytest.mark.anyio
async def test_delete():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
response = await ac.post("/api/v1/delete", json={"where": {"model_space": "test_space"}})
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 0}
@pytest.mark.anyio
async def test_delete_with_index():
async with AsyncClient(app=app, base_url="http://test") as ac:
await ac.post("/api/v1/reset")
await post_batch_records(ac)
response = await ac.get("/api/v1/count", params={"model_space": "test_space"})
assert response.json() == {"count": 2}
await ac.post("/api/v1/create_index", json={"model_space": "test_space"})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}}
)
assert response.json()['embeddings'][0][5] == 'knife'
response = await ac.post("/api/v1/delete", json={"where": {"model_space": "test_space", "inference_class": "knife"}})
response = await ac.post(
"/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}}
)
assert response.json()['embeddings'][0][5] == 'person'
# test calculate results
# @pytest.mark.anyio
# async def test_calculate_results():
# async with AsyncClient(app=app, base_url="http://test") as ac:
# await ac.post("/api/v1/reset")
# await post_batch_records(ac)
# await ac.post("/api/v1/process", json={"model_space": "test_space"})
# response = await ac.post(
# "/api/v1/calculate_results",
# json={
# "model_space": "test_space",
# },
# )
# assert response.status_code == 200
# assert response.json() == {"ids": [], "distances": []}

View File

@@ -1,15 +0,0 @@
from functools import lru_cache
from typing import Union
from pydantic import BaseSettings
class Settings(BaseSettings):
disable_anonymized_telemetry: bool = False
telemetry_anonymized_uuid: str = ''
environment: str = 'development'
class Config:
env_file = ".env"
@lru_cache()
def get_settings():
return Settings()
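A usage sketch (pydantic's `BaseSettings` resolves fields from real environment variables or the `.env` file, matching names case-insensitively; the bare calls below assume this module is importable):
```
import os

os.environ["ENVIRONMENT"] = "test"  # picked up by the `environment` field
settings = get_settings()           # cached by lru_cache, so later env changes are ignored
print(settings.environment)                   # -> "test"
print(settings.disable_anonymized_telemetry)  # -> False (default)
```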

View File

@@ -1,39 +0,0 @@
version: '3.9'
networks:
chroma-net:
driver: bridge
services:
server_test:
build:
context: .
dockerfile: Dockerfile
target: chroma_server_test
volumes:
- ./:/chroma-server/
- index_data:/index_data
depends_on:
- clickhouse
networks:
- chroma-net
environment:
- CLICKHOUSE_TCP_PORT=9001
- environment=test
clickhouse:
image: docker.io/bitnami/clickhouse:22.9
environment:
- ALLOW_EMPTY_PASSWORD=yes
- CLICKHOUSE_TCP_PORT=9001
- CLICKHOUSE_HTTP_PORT=8124
ports:
- '8124:8124'
- '9001:9001'
networks:
- chroma-net
volumes:
index_data:
driver: local

View File

@@ -1,20 +0,0 @@
[project]
name = "chroma_server"
dynamic = ["version"]
authors = [
{ name="Jeff Huber", email="jeff@trychroma.com" },
{ name="Anton Troynikov", email="anton@trychroma.com" }
]
description = "Chroma."
readme = "README.md"
requires-python = ">=3.7"
[project.urls]
"Homepage" = "https://github.com/chroma-core/chroma"
"Bug Tracker" = "https://github.com/chroma-core/chroma/issues"
[tool.setuptools_scm]
root=".."
local_scheme="no-local-version"

View File

@@ -1,2 +0,0 @@
python -m pytest
CHROMA_MODE=in-memory python -m pytest

48
chroma/__init__.py Normal file
View File

@@ -0,0 +1,48 @@
import chroma.config
__settings = chroma.config.Settings()
def configure(**kwargs):
"""Override Chroma's default settings, environment variables or .env files"""
global __settings
__settings = chroma.config.Settings(**kwargs)
def get_settings():
return __settings
def get_db(settings=None):
"""Return a chroma.DB instance based on the provided or environmental settings."""
if settings is None:
settings = __settings
if settings.clickhouse_host:
print("Using Clickhouse for database")
import chroma.db.clickhouse
return chroma.db.clickhouse.Clickhouse(settings)
elif settings.chroma_cache_dir:
print("Using DuckDB with local filesystem persistence for database")
import chroma.db.duckdb
return chroma.db.duckdb.PersistentDuckDB(settings)
else:
print("Using DuckDB in-memory for database. Data will be transient.")
import chroma.db.duckdb
return chroma.db.duckdb.DuckDB(settings)
def get_api(settings=None):
"""Return a chroma.API instance based on the provided or environmental
settings, optionally overriding the DB instance."""
if settings is None:
settings = __settings
if settings.chroma_server_host and settings.chroma_server_grpc_port:
print("Running Chroma in client/server mode using ArrowFlight protocol.")
import chroma.api.arrowflight
return chroma.api.arrowflight.ArrowFlightAPI(settings)
elif settings.chroma_server_host and settings.chroma_server_http_port:
print("Running Chroma in client/server mode using REST protocol.")
import chroma.api.fastapi
return chroma.api.fastapi.FastAPI(settings)
elif settings.celery_broker_url:
print("Running Chroma in server mode with Celery jobs enabled.")
import chroma.api.celery
return chroma.api.celery.CeleryAPI(settings, get_db(settings))
else:
print("Running Chroma using direct local API.")
import chroma.api.local
return chroma.api.local.LocalAPI(settings, get_db(settings))
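A hypothetical smoke test of the dispatch above (with none of the server, broker, or ClickHouse settings present, both helpers fall through to their local defaults):
```
import chroma

api = chroma.get_api()  # no server/broker configured -> LocalAPI over in-memory DuckDB
print(api.heartbeat())
```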

View File

@@ -1,67 +1,55 @@
import requests
import json
from abc import ABC, abstractmethod
from typing import Union
class Chroma:
class API(ABC):
_api_url = "http://localhost:8000/api/v1"
_model_space = "default_scope"
_model_space = 'default_scope'
def __init__(self, url=None, model_space=None):
"""Initialize Chroma client"""
@abstractmethod
def __init__(self):
pass
if isinstance(url, str) and url.startswith("http"):
self._api_url = url
if isinstance(model_space, str) and model_space:
self._model_space = model_space
def set_model_space(self, model_space):
'''Sets the space key for the client, enables overriding the string concat'''
self._model_space = model_space
def get_model_space(self):
'''Returns the model_space key'''
return self._model_space
@abstractmethod
def heartbeat(self):
'''Returns the current server time in nanoseconds to check if the server is alive'''
return requests.get(self._api_url).json()
pass
@abstractmethod
def add(self,
embedding: list,
input_uri: list,
dataset: list = None,
inference_class: list = None,
label_class: list = None,
model_spaces: list = None):
"""Add embeddings to the data store"""
pass
@abstractmethod
def count(self, model_space=None):
'''Returns the number of embeddings in the database'''
params = {"model_space": model_space or self._model_space}
x = requests.get(self._api_url + "/count", params=params)
return x.json()
pass
@abstractmethod
def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None):
'''Fetches embeddings from the database'''
if self._model_space:
where["model_space"] = self._model_space
pass
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
return requests.post(self._api_url + "/fetch", data=json.dumps({
"where":where,
"sort":sort,
"limit":limit,
"offset":offset
})).json()
@abstractmethod
def delete(self, where={}):
'''Deletes embeddings from the database'''
if self._model_space:
where["model_space"] = self._model_space
pass
return requests.post(self._api_url + "/delete", data=json.dumps({
"where":where,
})).json()
def add(self,
embedding: list,
input_uri: list,
@abstractmethod
def add(self,
embedding: list,
input_uri: list,
dataset: list = None,
inference_class: list = None,
label_class: list = None,
@@ -70,107 +58,113 @@ class Chroma:
Adds a batch of embeddings to the database
- pass in column oriented data lists
'''
pass
if not model_spaces:
model_spaces = self._model_space
x = requests.post(self._api_url + "/add", data = json.dumps({
"model_space": model_spaces,
"embedding": embedding,
"input_uri": input_uri,
"dataset": dataset,
"inference_class": inference_class,
"label_class": label_class
}) )
@abstractmethod
def get_nearest_neighbors(self, embedding, n_results=10, where={}):
'''Gets the nearest neighbors of a single embedding'''
pass
@abstractmethod
def process(self, model_space=None):
'''
Processes embeddings in the database
- currently this only runs hnswlib and doesn't return anything
'''
pass
@abstractmethod
def reset(self):
'''Resets the database'''
pass
@abstractmethod
def raw_sql(self, sql):
'''Runs a raw SQL query against the database'''
pass
@abstractmethod
def get_results(self, model_space=None, n_results = 100):
'''Gets the results for the given space key'''
pass
@abstractmethod
def get_task_status(self, task_id):
'''Gets the status of a task'''
pass
@abstractmethod
def create_index(self, model_space=None):
'''Creates an index for the given space key'''
pass
def set_model_space(self, model_space):
'''Sets the space key for the client, enables overriding the string concat'''
self._model_space = model_space
def get_model_space(self):
'''Returns the model_space key'''
return self._model_space
def where_with_model_space(self, where_clause):
'''Returns a where clause that specifies the default model space iff it wasn't already specified'''
if self._model_space and "model_space" not in where_clause:
where_clause["model_space"] = self._model_space
return where_clause
if x.status_code == 201:
return True
else:
return False
def add_training(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None):
'''
Small wrapper around add() to add a batch of training embedding - sets dataset to "training"
'''
datasets = ["training"] * len(input_uri)
return self.add(
embedding=embedding,
input_uri=input_uri,
embedding=embedding,
input_uri=input_uri,
dataset=datasets,
inference_class=inference_class,
model_spaces=model_spaces,
label_class=label_class
)
def add_production(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None):
'''
Small wrapper around add() to add a batch of production embedding - sets dataset to "production"
'''
datasets = ["production"] * len(input_uri)
return self.add(
embedding=embedding,
input_uri=input_uri,
embedding=embedding,
input_uri=input_uri,
dataset=datasets,
inference_class=inference_class,
model_spaces=model_spaces,
label_class=label_class
)
def add_triage(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None):
'''
Small wrapper around add() to add a batch of triage embedding - sets dataset to "triage"
'''
datasets = ["triage"] * len(input_uri)
return self.add(
embedding=embedding,
input_uri=input_uri,
embedding=embedding,
input_uri=input_uri,
dataset=datasets,
inference_class=inference_class,
model_spaces=model_spaces,
label_class=label_class
)
def get_nearest_neighbors(self, embedding, n_results=10, where={}):
'''Gets the nearest neighbors of a single embedding'''
if "model_space" not in where:
where["model_space"] = self._model_space
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
"embedding": embedding,
"n_results": n_results,
"where": where
}) )
if x.status_code == 200:
return x.json()
else:
return False
def process(self, model_space=None):
'''
Processes embeddings in the database
- currently this only runs hnswlib and doesn't return anything
'''
x = requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space}))
return x.json()
def reset(self):
'''Resets the database'''
return requests.post(self._api_url + "/reset")
def raw_sql(self, sql):
'''Runs a raw SQL query against the database'''
return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json()
def get_results(self, model_space=None, n_results = 100):
'''Gets the results for the given space key'''
return requests.post(self._api_url + "/get_results", data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json()
def get_task_status(self, task_id):
'''Gets the status of a task'''
return requests.post(self._api_url + f"/tasks/{task_id}").json()
def create_index(self, model_space=None):
'''Creates an index for the given space key'''
return requests.post(self._api_url + "/create_index", data = json.dumps({"model_space": model_space or self._model_space})).json()

View File

@@ -0,0 +1,9 @@
from chroma.api import API
class ArrowFlightAPI(API):
def __init__(self, settings):
print("Constructing Local instance")
# TODO: Implement

47
chroma/api/celery.py Normal file
View File

@@ -0,0 +1,47 @@
from chroma.api.local import LocalAPI
from chroma.worker import heavy_offline_analysis
from celery.result import AsyncResult
class CeleryAPI(LocalAPI):
def __init__(self, settings, db):
super().__init__(settings, db)
def process(self, model_space=None):
self.create_index(model_space)
task = heavy_offline_analysis.delay(model_space)
# chroma_telemetry.capture('heavy-offline-analysis')
return task.id
def get_status(self, task_id):
task_result = AsyncResult(task_id)
result = {
"task_id": task_id,
"task_status": task_result.status,
"task_result": task_result.result
}
return result
def get_results(self, model_space=None, n_results=100):
model_space = model_space or self._model_space
if not self._db.has_index(model_space):
self._db.create_index(model_space)
results_count = self._db.count_results(model_space)
if results_count == 0:
heavy_offline_analysis(model_space)
return self._db.return_results(model_space, n_results)
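A hypothetical flow for the Celery-backed API (assumes a reachable broker and a running worker; `settings` and `db` as produced by `chroma.get_db`):
```
api = CeleryAPI(settings, db)
task_id = api.process("test_space")  # builds the index, then enqueues heavy_offline_analysis
print(api.get_status(task_id))       # {'task_id': ..., 'task_status': 'PENDING', ...}
```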

117
chroma/api/fastapi.py Normal file
View File

@@ -0,0 +1,117 @@
from chroma.api import API
import json
import requests
class FastAPI(API):
def __init__(self, settings):
self._api_url = f'http://{settings.chroma_server_host}:{settings.chroma_server_http_port}/api/v1'
def heartbeat(self):
'''Returns the current server time in nanoseconds to check if the server is alive'''
return int(requests.get(self._api_url).json()['nanosecond heartbeat'])
def count(self, model_space=None):
'''Returns the number of embeddings in the database'''
params = {"model_space": model_space or self._model_space}
x = requests.get(self._api_url + "/count", params=params)
return x.json()['count']
def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None):
'''Fetches embeddings from the database'''
where = self.where_with_model_space(where)
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
return requests.post(self._api_url + "/fetch", data=json.dumps({
"where":where,
"sort":sort,
"limit":limit,
"offset":offset
})).json()
def delete(self, where={}):
'''Deletes embeddings from the database'''
where = self.where_with_model_space(where)
return requests.post(self._api_url + "/delete", data=json.dumps({
"where":where,
})).json()
def add(self,
embedding: list,
input_uri: list,
dataset: list = None,
inference_class: list = None,
label_class: list = None,
model_spaces: list = None):
'''
Adds a batch of embeddings to the database
- pass in column-oriented data lists
'''
if not model_spaces:
model_spaces = self._model_space
x = requests.post(self._api_url + "/add", data = json.dumps({
"model_space": model_spaces,
"embedding": embedding,
"input_uri": input_uri,
"dataset": dataset,
"inference_class": inference_class,
"label_class": label_class
}) )
if x.status_code == 201:
return True
else:
return False
def get_nearest_neighbors(self, embedding, n_results=10, where={}):
'''Gets the nearest neighbors of a single embedding'''
where = self.where_with_model_space(where)
x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({
"embedding": embedding,
"n_results": n_results,
"where": where
}) )
if x.status_code == 200:
return x.json()
else:
return False
def process(self, model_space=None):
'''
Processes embeddings in the database
- currently this only runs hnswlib and doesn't return anything
'''
x = requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space}))
return x.json()
def reset(self):
'''Resets the database'''
return requests.post(self._api_url + "/reset")
def raw_sql(self, sql):
'''Runs a raw SQL query against the database'''
return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json()
def get_results(self, model_space=None, n_results = 100):
'''Gets the results for the given space key'''
return requests.post(self._api_url + "/get_results",
data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json()
def get_task_status(self, task_id):
'''Gets the status of a task'''
return requests.post(self._api_url + f"/tasks/{task_id}").json()
def create_index(self, model_space=None):
'''Creates an index for the given space key'''
return requests.post(self._api_url + "/create_index",
data = json.dumps({"model_space": model_space or self._model_space})).json()
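A usage sketch for this remote client, assuming a server is reachable; the host and port values are placeholders:

```
from chroma.config import Settings
from chroma.api.fastapi import FastAPI

api = FastAPI(Settings(chroma_server_host="localhost",
                       chroma_server_http_port="8000"))

print(api.heartbeat())  # nanoseconds since epoch if the server is up
api.add(
    embedding=[[1, 2, 3, 4, 5]],
    input_uri=["/images/1"],
    dataset=["training"],
    inference_class=["spoon"],
    model_spaces="demo_space",
)
print(api.count("demo_space"))
```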

115
chroma/api/local.py Normal file
View File

@@ -0,0 +1,115 @@
import time
from chroma.api import API
class LocalAPI(API):
def __init__(self, settings, db):
self._db = db
def heartbeat(self):
return int(time.time_ns())  # time_ns() is already in nanoseconds
def add(self,
embedding: list,
input_uri: list,
dataset: list = None,
inference_class: list = None,
label_class: list = None,
model_spaces: list = None):
number_of_embeddings = len(embedding)
if model_spaces is None:
model_spaces = self._model_space
if isinstance(model_spaces, str):
model_space = [model_spaces] * number_of_embeddings
elif len(model_spaces) == 1:
model_space = [model_spaces[0]] * number_of_embeddings
else:
model_space = model_spaces
if dataset is None or isinstance(dataset, str):
ds = [dataset] * number_of_embeddings
elif len(dataset) == 1:
ds = [dataset[0]] * number_of_embeddings
else:
ds = dataset
self._db.add(
model_space,
embedding,
input_uri,
ds,
inference_class,
label_class
)
def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None):
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
return self._db.fetch(where, sort, limit, offset)
def delete(self, where={}):
where = self.where_with_model_space(where)
deleted_uuids = self._db.delete(where)
return deleted_uuids
def count(self, model_space=None):
model_space = model_space or self._model_space
return {"count": self._db.count(model_space=model_space)}
def reset(self):
self._db.reset()
return True
def get_nearest_neighbors(self, embedding, n_results, where):
where = self.where_with_model_space(where)
results = self._db.fetch(where)
ids = [str(item[self._db.get_col_pos('uuid')]) for item in results]
uuids, distances = self._db.get_nearest_neighbors(where['model_space'], embedding, n_results, ids)
return {
"ids": uuids,
"embeddings": self._db.get_by_ids(uuids),
"distances": distances.tolist()[0]
}
def raw_sql(self, raw_sql):
return self._db.raw_sql(raw_sql)
def create_index(self, model_space=None):
self._db.create_index(model_space or self._model_space)
return True
def process(self, model_space=None):
raise NotImplementedError("Cannot launch job: Celery is not configured")
def get_status(self, task_id):
raise NotImplementedError("Cannot get status of job: Celery is not configured")
def get_results(self, model_space=None, n_results=100):
raise NotImplementedError("Cannot get job results: Celery is not configured")

24
chroma/config.py Normal file
View File

@@ -0,0 +1,24 @@
from pydantic import BaseSettings, Field
class Settings(BaseSettings):
disable_anonymized_telemetry: bool = False
telemetry_anonymized_uuid: str = ""
environment: str = ""
clickhouse_host: str = None
clickhouse_port: str = None
celery_broker_url: str = None
celery_result_backend: str = None
chroma_cache_dir: str = None
chroma_server_host: str = None
chroma_server_http_port: str = None
chroma_server_grpc_port: str = None
class Config:
env_file = '.env'
env_file_encoding = 'utf-8'
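Because `Settings` is a pydantic `BaseSettings`, every field can come from keyword arguments, environment variables, or the `.env` file named in `Config`. A sketch:

```
from chroma.config import Settings

# explicit kwargs take precedence over the environment and .env
settings = Settings(clickhouse_host="localhost", clickhouse_port="9000")
print(settings.clickhouse_host)

# the same values could instead live in a .env file next to the process:
#   clickhouse_host=localhost
#   clickhouse_port=9000
```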

75
chroma/db/__init__.py Normal file
View File

@@ -0,0 +1,75 @@
from abc import ABC, abstractmethod
class DB(ABC):
@abstractmethod
def __init__(self):
pass
@abstractmethod
def add(self, model_space, embedding, input_uri, dataset=None, custom_quality_score=None, inference_class=None, label_class=None):
pass
@abstractmethod
def fetch(self, where, sort, limit, offset, columnar):
pass
@abstractmethod
def delete(self, where):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
def get_nearest_neighbors(self, where, embedding, n_results, ids):
pass
@abstractmethod
def get_by_ids(self, uuids):
pass
@abstractmethod
def raw_sql(self, raw_sql):
pass
@abstractmethod
def create_index(self, model_space):
pass
@abstractmethod
def has_index(self, model_space):
pass
@abstractmethod
def count_results(self, model_space):
pass
@abstractmethod
def return_results(self, model_space, n_results):
pass
@abstractmethod
def delete_results(self, model_space):
pass
@abstractmethod
def add_results(self, model_space, uuids, quality_scores):
pass
@abstractmethod
def get_col_pos(self, col_name):
pass

View File

@@ -1,4 +1,5 @@
from chroma_server.db.abstract import Database
from chroma.db import DB
from chroma.db.index.hnswlib import Hnswlib
import uuid
import time
import os
@@ -37,14 +38,12 @@ def db_schema_to_keys():
return_str += f"{list(element.keys())[0]}, "
return return_str
def get_col_pos(col_name):
for i, col in enumerate(EMBEDDING_TABLE_SCHEMA):
if col_name in col:
return i
class Clickhouse(Database):
class Clickhouse(DB):
_conn = None
def _create_table_embeddings(self):
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS embeddings (
{db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)}
@@ -52,17 +51,20 @@ class Clickhouse(Database):
self._conn.execute(f'''SET allow_experimental_lightweight_delete = true''')
self._conn.execute(f'''SET mutations_sync = 1''') # https://clickhouse.com/docs/en/operations/settings/settings/#mutations_sync
def _create_table_results(self):
self._conn.execute(f'''CREATE TABLE IF NOT EXISTS results (
{db_array_schema_to_clickhouse_schema(RESULTS_TABLE_SCHEMA)}
) ENGINE = MergeTree() ORDER BY model_space''')
def __init__(self):
client = Client(host='clickhouse', port=os.getenv('CLICKHOUSE_TCP_PORT', '9000'))
self._conn = client
def __init__(self, settings):
self._settings = settings
self._conn = Client(host=settings.clickhouse_host, port=settings.clickhouse_port)
self._create_table_embeddings()
self._create_table_results()
self._idx = Hnswlib(settings)
def add(self, model_space, embedding, input_uri, dataset=None, inference_class=None, label_class=None):
data_to_insert = []
@@ -74,18 +76,11 @@ class Clickhouse(Database):
self._conn.execute(f'''
INSERT INTO embeddings ({insert_string}) VALUES''', data_to_insert)
def _count(self, model_space=None):
where_string = ""
if model_space is not None:
where_string = f"WHERE model_space = '{model_space}'"
return self._conn.execute(f"SELECT COUNT() FROM embeddings {where_string}")
def count(self, model_space=None):
return self._count(model_space=model_space)[0][0]
def _fetch(self, where={}, columnar=False):
return self._conn.execute(f'''SELECT {db_schema_to_keys()} FROM embeddings {where}''', columnar=columnar)
def fetch(self, where={}, sort=None, limit=None, offset=None, columnar=False):
if where["model_space"] is None:
return {"error": "model_space is required"}
@@ -95,12 +90,12 @@ class Clickhouse(Database):
if where is not None:
if not isinstance(where, dict):
raise Exception("Invalid where: " + str(where))
# ensure where is a flat dict
for key in where:
if isinstance(where[key], dict):
raise Exception("Invalid where: " + str(where))
where = " AND ".join([f"{key} = '{value}'" for key, value in where.items()])
if where:
@@ -113,7 +108,7 @@ class Clickhouse(Database):
if limit is not None or isinstance(limit, int):
where += f" LIMIT {limit}"
if offset is not None or isinstance(offset, int):
where += f" OFFSET {offset}"
@@ -125,7 +120,7 @@ class Clickhouse(Database):
def _delete(self, where={}):
uuids_deleted = self._conn.execute(f'''SELECT toString(uuid) FROM embeddings {where}''')
self._conn.execute(f'''
DELETE FROM
embeddings
{where}
''')
@@ -140,35 +135,56 @@ class Clickhouse(Database):
if where is not None:
if not isinstance(where, dict):
raise Exception("Invalid where: " + str(where))
# ensure where is a flat dict
for key in where:
if isinstance(where[key], dict):
raise Exception("Invalid where: " + str(where))
where = " AND ".join([f"{key} = '{value}'" for key, value in where.items()])
if where:
where = f"WHERE {where}"
val = self._delete(where=where)
deleted_uuids = self._delete(where=where)
print(f"time to fetch {len(val)} embeddings: ", time.time() - s3)
return val
if len(where) == 1:
self._idx.delete(where['model_space'])
self._idx.delete_from_index(where['model_space'], [uuid[0] for uuid in deleted_uuids])
return deleted_uuids
def get_by_ids(self, ids: list):
return self._conn.execute(f'''
SELECT {db_schema_to_keys()} FROM embeddings WHERE uuid IN ({ids})''')
def create_index(self, model_space):
fetch = self.fetch({"model_space": model_space}, columnar=True)
self._idx.run(model_space, fetch[1], fetch[2]) # more magic number, ugh
def has_index(self, model_space):
return self._idx.has_index(model_space)
def reset(self):
self._conn.execute('DROP TABLE embeddings')
self._conn.execute('DROP TABLE results')
self._create_table_embeddings()
self._create_table_results()
self._idx.reset()
self._idx = Hnswlib(self._settings)
def raw_sql(self, sql):
return self._conn.execute(sql)
def add_results(self, model_spaces, uuids, custom_quality_score):
data_to_insert = []
for i in range(len(model_spaces)):
@@ -176,16 +192,19 @@ class Clickhouse(Database):
self._conn.execute('''
INSERT INTO results (model_space, uuid, custom_quality_score) VALUES''', data_to_insert)
def delete_results(self, model_space):
self._conn.execute(f"DELETE FROM results WHERE model_space = '{model_space}'")
def count_results(self, model_space=None):
where_string = ""
if model_space is not None:
where_string = f"WHERE model_space = '{model_space}'"
return self._conn.execute(f"SELECT COUNT() FROM results {where_string}")[0][0]
def return_results(self, model_space, n_results = 100):
return self._conn.execute(f'''
SELECT
@@ -205,8 +224,9 @@ class Clickhouse(Database):
LIMIT {n_results}
''')
def set_save_folder(self, path):
pass
def get_save_folder(self):
pass
def get_col_pos(col_name):
for i, col in enumerate(EMBEDDING_TABLE_SCHEMA):
if col_name in col:
return i
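`get_col_pos` maps a column name to its index in `EMBEDDING_TABLE_SCHEMA` (a list of single-key dicts), which is how callers index into raw fetched rows. A self-contained sketch with a hypothetical schema:

```
# hypothetical schema shape: a list of single-key dicts, name -> type
EMBEDDING_TABLE_SCHEMA = [
    {"model_space": "String"},
    {"uuid": "UUID"},
    {"embedding": "Array(Float64)"},
]

def get_col_pos(col_name):
    for i, col in enumerate(EMBEDDING_TABLE_SCHEMA):
        if col_name in col:
            return i

row = ("demo_space", "123e4567-e89b-12d3-a456-426614174000", [1.0, 2.0])
print(row[get_col_pos("uuid")])  # -> the uuid column
```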

12
chroma/db/duckdb.py Normal file
View File

@@ -0,0 +1,12 @@
from chroma.db import DB
class DuckDB(DB):
def __init__(self, settings):
pass
class PersistentDuckDB(DuckDB):
def __init__(self, settings):
pass

32
chroma/db/index/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
class Index(ABC):
@abstractmethod
def __init__(self, settings):
pass
@abstractmethod
def delete(self, model_space):
pass
@abstractmethod
def delete_from_index(self, model_space, uuids):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
def run(self, model_space, uuids, embeddings):
pass
@abstractmethod
def has_index(self, model_space):
pass

View File

@@ -4,13 +4,11 @@ import time
import hnswlib
import numpy as np
from chroma_server.index.abstract import Index
from chroma_server.logger import logger
from chroma.db.index import Index
from chroma.logger import logger  # assumed new location; logger is still used below
class Hnswlib(Index):
_save_folder = '/index_data'
_model_space = None
_index = None
_index_metadata = {
@@ -22,28 +20,28 @@ class Hnswlib(Index):
_id_to_uuid = {}
_uuid_to_id = {}
def __init__(self):
pass
# set the save folder
def set_save_folder(self, save_folder):
self._save_folder = save_folder
def __init__(self, settings):
self._save_folder = settings.chroma_cache_dir + "/index"
def get_save_folder(self):
return self._save_folder
def run(self, model_space, uuids, embeddings, space='l2', ef=10, num_threads=4):
def run(self, model_space, uuids, embeddings):
space = 'l2'
ef = 10
num_threads = 4
# more comments available at the source: https://github.com/nmslib/hnswlib
dimensionality = len(embeddings[0])
for uuid, i in zip(uuids, range(len(uuids))):
self._id_to_uuid[i] = str(uuid)
self._uuid_to_id[str(uuid)] = i
index = hnswlib.Index(space=space, dim=dimensionality) # possible options are l2, cosine or ip
index.init_index(max_elements=len(embeddings), ef_construction=100, M=16)
index.set_ef(ef)
index.set_num_threads(num_threads)
index.add_items(embeddings, range(len(uuids)))
self._index = index
@@ -53,7 +51,8 @@ class Hnswlib(Index):
'elements': len(embeddings),
'time_created': time.time(),
}
self.save()
self._save()
def delete(self, model_space):
# delete files, dont throw error if they dont exist
@@ -72,9 +71,10 @@ class Hnswlib(Index):
self._id_to_uuid = {}
self._uuid_to_id = {}
def delete_from_index(self, model_space, uuids):
if self._model_space != model_space:
self.load(model_space)
self._load(model_space)
if self._index is not None:
for uuid in uuids:
@@ -82,9 +82,11 @@ class Hnswlib(Index):
del self._id_to_uuid[self._uuid_to_id[uuid]]
del self._uuid_to_id[uuid]
self.save()
def save(self):
self._save()
def _save(self):
# create the directory if it doesn't exist
if not os.path.exists(f'{self._save_folder}'):
os.makedirs(f'{self._save_folder}')
@@ -103,7 +105,9 @@ class Hnswlib(Index):
logger.debug(f'Index saved to {self._save_folder}/index.bin')
def load(self, model_space):
def _load(self, model_space):
# unpickle the mappers
try:
with open(f"{self._save_folder}/id_to_uuid_{model_space}.pkl", 'rb') as f:
@@ -121,14 +125,15 @@ class Hnswlib(Index):
except:
logger.debug('Index not found')
def has_index(self, model_space):
return os.path.isfile(f"{self._save_folder}/index_{model_space}.bin")
# do knn_query on hnswlib to get nearest neighbors
def get_nearest_neighbors(self, model_space, query, k, uuids=None):
if self._model_space != model_space:
self.load(model_space)
self._load(model_space)
s2 = time.time()
# get ids from uuids as a set, if they are available
@@ -137,7 +142,7 @@ class Hnswlib(Index):
ids = {self._uuid_to_id[uuid] for uuid in uuids}
if len(ids) < k:
k = len(ids)
filter_function = None
if len(ids) != 0:
filter_function = lambda id: id in ids
@@ -149,13 +154,16 @@ class Hnswlib(Index):
logger.debug(f'time to run knn query: {time.time() - s3}')
uuids = [self._id_to_uuid[id] for id in database_ids[0]]
return uuids, distances
def reset(self):
if os.path.exists(f'{self._save_folder}'):
for f in os.listdir(f'{self._save_folder}'):
os.remove(os.path.join(f'{self._save_folder}', f))
# recreate the directory
if not os.path.exists(f'{self._save_folder}'):
os.makedirs(f'{self._save_folder}')
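For reference, a standalone sketch of the underlying hnswlib calls that `run()` and `get_nearest_neighbors()` wrap, without the uuid bookkeeping (random data, stock hnswlib API):

```
import hnswlib
import numpy as np

embeddings = np.random.rand(100, 5).astype(np.float32)

index = hnswlib.Index(space="l2", dim=5)  # l2, cosine or ip
index.init_index(max_elements=len(embeddings), ef_construction=100, M=16)
index.add_items(embeddings, list(range(len(embeddings))))
index.set_ef(10)  # query-time accuracy/speed trade-off

labels, distances = index.knn_query(embeddings[:1], k=4)
print(labels[0], distances[0])
```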

11
chroma/server/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from chroma.utils.error_reporting import init_error_reporting
from chroma.server.utils.telemetry.capture import Capture
class Server(ABC):
def __init__(self):
self._chroma_telemetry = Capture()
self._chroma_telemetry.capture('server-start')
init_error_reporting()

9
chroma/server/arrowflight.py Normal file
View File

@@ -0,0 +1,9 @@
import chroma.server
class ArrowFlight(chroma.server.Server):
def __init__(self):
super().__init__()
pass
# TODO: Implement

83
chroma/server/fastapi.py Normal file
View File

@@ -0,0 +1,83 @@
import fastapi
from fastapi import status
from fastapi.responses import JSONResponse
import chroma
import chroma.server
from chroma.server.fastapi.types import (AddEmbedding, CountEmbedding, DeleteEmbedding,
FetchEmbedding, ProcessEmbedding,
QueryEmbedding, RawSql, Results,
SpaceKeyInput)
class FastAPI(chroma.server.Server):
def __init__(self):
super().__init__()
self._app = fastapi.FastAPI(debug=True)
self._api = chroma.get_api()
self.router = fastapi.APIRouter()
self.router.add_api_route("/api/v1", self.root, methods=["GET"])
self.router.add_api_route("/api/v1/add", self.add, methods=["POST"], status_code=status.HTTP_201_CREATED)
self.router.add_api_route("/api/v1/fetch", self.fetch, methods=["POST"])
self.router.add_api_route("/api/v1/delete", self.delete, methods=["POST"])
self.router.add_api_route("/api/v1/count", self.count, methods=["GET"])
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"])
self.router.add_api_route("/api/v1/get_nearest_neighbors", self.get_nearest_neighbors, methods=["POST"])
self.router.add_api_route("/api/v1/create_index", self.create_index, methods=["POST"])
self.router.add_api_route("/api/v1/process", self.process, methods=["POST"])
self.router.add_api_route("/api/v1/get_status", self.get_status, methods=["POST"])
self.router.add_api_route("/api/v1/get_results", self.get_results, methods=["POST"])
self._app.include_router(self.router)
def root(self):
return {"nanosecond heartbeat": self._api.heartbeat()}
def add(self, new_embedding: AddEmbedding):
if self._api.add(**new_embedding.dict()):
return {"response": "Added records to database"}
else:
raise Exception("api.add returned false")
def fetch(self, embedding: FetchEmbedding):
return self._api.fetch(**embedding.dict())
def delete(self, embedding: DeleteEmbedding):
return self._api.delete(**embedding.dict())
def count(self, model_space: str = None):
return self._api.count(model_space)
def reset(self):
return self._api.reset()
def get_nearest_neighbors(self, embedding: QueryEmbedding):
return self._api.get_nearest_neighbors(**embedding.dict())
def raw_sql(self, raw_sql: RawSql):
return self._api.raw_sql(raw_sql.raw_sql)
def create_index(self, process_embedding: ProcessEmbedding):
return self._api.create_index(process_embedding.model_space)
def process(self, process_embedding: ProcessEmbedding):
task_id = self._api.process(process_embedding.model_space)
return JSONResponse({"task_id": task_id})
def get_status(self, task_id):
return JSONResponse(self._api.get_task_status(task_id))
def get_results(self, results: Results):
return self._api.get_results(results.model_space, results.n_results)
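A sketch of serving this class, assuming it is reasonable to hand the internal `_app` attribute from the constructor above to uvicorn and that the backing API and DB services are configured:

```
import uvicorn
from chroma.server.fastapi import FastAPI

server = FastAPI()  # assumes the backing API/DB services are configured
uvicorn.run(server._app, host="0.0.0.0", port=8000)
```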

View File

@@ -38,4 +38,4 @@ class SpaceKeyInput(BaseModel):
model_space: str
class DeleteEmbedding(BaseModel):
where: dict = {}

View File

@@ -1,4 +1,4 @@
from chroma_server.utils.config.settings import get_settings
from chroma import get_settings
import sentry_sdk
from sentry_sdk.client import Client
@@ -27,4 +27,4 @@ def init_error_reporting():
before_send=strip_sensitive_data,
)
with configure_scope() as scope:
scope.set_tag('posthog_distinct_id', get_settings().telemetry_anonymized_uuid)

View File

@@ -1,10 +1,11 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
class Telemetry(ABC):
class Telemetry():
@abstractmethod
def __init__(self):
pass
@abstractmethod
def capture(self, event, properties=None):
pass

View File

@@ -2,8 +2,8 @@ import posthog
import uuid
import sys
import os
from chroma_server.utils.telemetry.abstract import Telemetry
from chroma_server.utils.config.settings import get_settings
from chroma.server.utils.telemetry import Telemetry
from chroma import get_settings
class Capture(Telemetry):
_conn = None
@@ -37,4 +37,3 @@ class Capture(Telemetry):
properties['environment'] = os.getenv('environment', 'development')
self._conn.capture(self._telemetry_anonymized_uuid, event, properties)

4
chroma/test/test_api.py Normal file
View File

@@ -0,0 +1,4 @@
import pytest
def test_init():
assert 1 == 1

56
chroma/test/test_chroma.py Normal file
View File

@@ -0,0 +1,56 @@
import pytest
import unittest
from unittest.mock import patch
import chroma
import chroma.config
class GetDBTest(unittest.TestCase):
@patch('chroma.db.duckdb.DuckDB', autospec=True)
def test_default_db(self, mock):
db = chroma.get_db(chroma.config.Settings())
assert mock.called
@patch('chroma.db.duckdb.PersistentDuckDB', autospec=True)
def test_persistent_duckdb(self, mock):
db = chroma.get_db(chroma.config.Settings(chroma_cache_dir="./foo"))
assert mock.called
@patch('chroma.db.clickhouse.Clickhouse', autospec=True)
def test_clickhouse(self, mock):
db = chroma.get_db(chroma.config.Settings(clickhouse_host="foo"))
assert mock.called
class GetAPITest(unittest.TestCase):
@patch('chroma.db.duckdb.DuckDB', autospec=True)
@patch('chroma.api.local.LocalAPI', autospec=True)
def test_local(self, mock_api, mock_db):
api = chroma.get_api(chroma.config.Settings())
assert mock_api.called
assert mock_db.called
@patch('chroma.db.duckdb.DuckDB', autospec=True)
@patch('chroma.api.celery.CeleryAPI', autospec=True)
def test_celery(self, mock_api, mock_db):
api = chroma.get_api(chroma.config.Settings(celery_broker_url='foo'))
assert mock_api.called
assert mock_db.called
@patch('chroma.api.fastapi.FastAPI', autospec=True)
def test_fastapi(self, mock):
api = chroma.get_api(chroma.config.Settings(chroma_server_host='foo',
chroma_server_http_port='80'))
assert mock.called
@patch('chroma.api.arrowflight.ArrowFlightAPI', autospec=True)
def test_arrowflight(self, mock):
api = chroma.get_api(chroma.config.Settings(chroma_server_host='foo',
chroma_server_http_port='80',
chroma_server_grpc_port='9999'))
assert mock.called
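These tests pin down the dispatch rules `get_db` and `get_api` are expected to follow. A hedged sketch consistent with them (the real `chroma/__init__.py` is not shown in this diff, so this is inferred, not the shipped code):

```
# sketch of the dispatch implied by the tests above (not the shipped code)
def get_db(settings):
    if settings.clickhouse_host:
        from chroma.db.clickhouse import Clickhouse
        return Clickhouse(settings)
    if settings.chroma_cache_dir:
        from chroma.db.duckdb import PersistentDuckDB
        return PersistentDuckDB(settings)
    from chroma.db.duckdb import DuckDB
    return DuckDB(settings)

def get_api(settings):
    if settings.chroma_server_host and settings.chroma_server_grpc_port:
        from chroma.api.arrowflight import ArrowFlightAPI
        return ArrowFlightAPI(settings)
    if settings.chroma_server_host and settings.chroma_server_http_port:
        from chroma.api.fastapi import FastAPI
        return FastAPI(settings)
    if settings.celery_broker_url:
        from chroma.api.celery import CeleryAPI
        db = get_db(settings)      # the test expects a db to be built here too
        return CeleryAPI(settings)  # db wiring is still WIP in this commit
    from chroma.api.local import LocalAPI
    return LocalAPI(settings, get_db(settings))
```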

View File

@@ -1,8 +1,8 @@
import os
import time
import random
import chroma
from celery import Celery
from chroma_server.db.clickhouse import Clickhouse, get_col_pos
celery = Celery(__name__)
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
@@ -15,19 +15,20 @@ def create_task(task_type):
@celery.task(name="heavy_offline_analysis")
def heavy_offline_analysis(model_space):
task_db_conn = Clickhouse()
embedding_rows = task_db_conn.fetch({"model_space": model_space})
db = chroma.get_db()
embedding_rows = db.fetch({"model_space": model_space})
uuids = []
custom_quality_scores = []
for row in embedding_rows:
uuids.append(row[get_col_pos("uuid")])
custom_quality_scores.append(random.random())
spaces = [model_space] * len(uuids)
task_db_conn.delete_results(model_space)
task_db_conn.add_results(spaces, uuids, custom_quality_scores)
db.delete_results(model_space)
db.add_results(spaces, uuids, custom_quality_scores)
return "Wrote custom quality scores to database"

View File

@@ -7,9 +7,7 @@ chroma.reset()
# add
for i in range(10):
chroma.add(
embedding=[1,2,3,4,5,6,7,8,9,10],
input_uri="https://www.google.com",
dataset=None
embedding=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], input_uri="https://www.google.com", dataset=None
)
# fetch all
@@ -17,7 +15,7 @@ allres = chroma.get_all()
print(allres)
# count
print("count is", chroma.count()['count'])
print("count is", chroma.count()["count"])
# persist
chroma.persist()
@@ -33,4 +31,4 @@ chroma.process()
# reset
chroma.reset()
print("count is", chroma.count()['count'])
print("count is", chroma.count()["count"])

View File

@@ -4,6 +4,6 @@ import chroma_client
app = Flask(__name__)
@app.route('/')
@app.route("/")
def hello():
return(str(chroma_client.fetch_new_labels()))
return str(chroma_client.fetch_new_labels())

View File

@@ -7,22 +7,23 @@ chroma.set_model_space("sample_1_1")
print(chroma.heartbeat())
chroma.add([[1,2,3,4,5]], ["/images/1"], ["training"], ['spoon'])
chroma.add([[1,2,3,4,5]], ["/images/2"], ["training"], ['spoon'])
chroma.add([[1,2,3,4,5]], ["/images/3"], ["training"], ['spoon'])
chroma.add([[1,2,3,4,5]], ["/images/1"], ["training"], ['knife'])
chroma.add([[1,2,3,4,5]], ["/images/4"], ["training"], ['knife'])
chroma.add([[1,2,3,4,5]], ["/prod/2"], ["test"], ['knife'])
chroma.add([[1, 2, 3, 4, 5]], ["/images/1"], ["training"], ["spoon"])
chroma.add([[1, 2, 3, 4, 5]], ["/images/2"], ["training"], ["spoon"])
chroma.add([[1, 2, 3, 4, 5]], ["/images/3"], ["training"], ["spoon"])
chroma.add([[1, 2, 3, 4, 5]], ["/images/1"], ["training"], ["knife"])
chroma.add([[1, 2, 3, 4, 5]], ["/images/4"], ["training"], ["knife"])
chroma.add([[1, 2, 3, 4, 5]], ["/prod/2"], ["test"], ["knife"])
process_task = chroma.process()
print(process_task)
print(chroma.get_task_status(process_task['task_id']))
print(chroma.get_task_status(process_task["task_id"]))
print("sleeping for 10s to wait for task to complete")
import time
time.sleep(10)
print(chroma.get_task_status(process_task['task_id']))
print(chroma.get_task_status(process_task["task_id"]))
print(chroma.get_results())
# print(chroma.raw_sql("SELECT * FROM results WHERE space_key = 'yolov3_1_1'"))
@@ -64,4 +65,4 @@ print(chroma.get_results())
# print(chroma.process())
# print(chroma.get_nearest_neighbors([1,2,3,4,5], 2))
# print(chroma.get_nearest_neighbors([1,2,3,4,5], 2, space_key="yolov3_5_1"))
# print(chroma.get_nearest_neighbors([1,2,3,4,5], 2, space_key="yolov3_5_1"))

View File

@@ -2,4 +2,4 @@ from chroma_client import Chroma
chroma = Chroma()
print(chroma.get_results('yolov3_1_1'))
print(chroma.get_results("yolov3_1_1"))

View File

@@ -10,7 +10,7 @@ import pandas as pd
if __name__ == "__main__":
file = 'data__nogit/yolov3_objects_large_5k.parquet'
file = "data__nogit/yolov3_objects_large_5k.parquet"
print("Loading parquet file: ", file)
py = pq.read_table(file)
@@ -20,14 +20,14 @@ if __name__ == "__main__":
data_length = len(df)
chroma = Chroma(model_space="yolov3")
chroma.reset() #make sure we are using a fresh db
chroma.reset() # make sure we are using a fresh db
allstart = time.time()
start = time.time()
dataset = "training"
BATCH_SIZE = 1_000
print("Loading in records with a batch size of: " , data_length)
print("Loading in records with a batch size of: ", data_length)
for i in range(0, data_length, BATCH_SIZE):
if i >= 20_000:
@@ -35,64 +35,157 @@ if __name__ == "__main__":
end = time.time()
page = i * BATCH_SIZE
print("Time to process BATCH_SIZE rows: " + '{0:.2f}'.format((end - start)) + "s, records loaded: " + str(i))
print(
"Time to process BATCH_SIZE rows: "
+ "{0:.2f}".format((end - start))
+ "s, records loaded: "
+ str(i)
)
start = time.time()
batch = df[i:i+BATCH_SIZE]
batch = df[i : i + BATCH_SIZE]
for index, row in batch.iterrows():
for idx, annotation in enumerate(row['infer']['annotations']):
annotation["bbox"] = annotation['bbox'].tolist()
row['infer']['annotations'] = row['infer']['annotations'].tolist()
for idx, annotation in enumerate(row["infer"]["annotations"]):
annotation["bbox"] = annotation["bbox"].tolist()
row["infer"]["annotations"] = row["infer"]["annotations"].tolist()
row['embedding_data'] = row['embedding_data'].tolist()
row["embedding_data"] = row["embedding_data"].tolist()
embedding = batch['embedding_data'].tolist()
input_uri = batch['resource_uri'].tolist()
embedding = batch["embedding_data"].tolist()
input_uri = batch["resource_uri"].tolist()
inference_classes = []
for index, row in batch.iterrows():
for idx, annotation in enumerate(row['infer']['annotations']):
inference_classes.append(annotation['category_name'])
for idx, annotation in enumerate(row["infer"]["annotations"]):
inference_classes.append(annotation["category_name"])
datasets = dataset
chroma.add(
embedding=embedding,
input_uri=input_uri,
dataset=dataset,
inference_class=inference_classes
inference_class=inference_classes,
)
allend = time.time()
print("time to add all: ", "{:.2f}".format(allend - allstart) + 's')
print("time to add all: ", "{:.2f}".format(allend - allstart) + "s")
fetched = chroma.count()
print("Records loaded into the database: ", fetched)
print("Records loaded into the database: ", fetched)
start = time.time()
chroma.create_index()
end = time.time()
print("Time to process: " +'{0:.2f}'.format((end - start)) + 's')
print("Time to process: " + "{0:.2f}".format((end - start)) + "s")
knife_embedding = [
0.2310010939836502,
-0.3462161719799042,
0.29164767265319824,
-0.09828940033912659,
1.814868450164795,
-10.517369270324707,
-13.531850814819336,
-12.730537414550781,
-13.011675834655762,
-10.257010459899902,
-13.779699325561523,
-11.963963508605957,
-13.948140144348145,
-12.46799087524414,
-14.569470405578613,
-16.388280868530273,
-13.76762580871582,
-12.192169189453125,
-12.204055786132812,
-12.259000778198242,
-13.696036338806152,
-14.609177589416504,
-16.951879501342773,
-17.096384048461914,
-14.355693817138672,
-16.643482208251953,
-14.270745277404785,
-14.375198364257812,
-14.381218910217285,
-13.475995063781738,
-12.694938659667969,
-10.011992454528809,
-9.770626068115234,
-13.155019760131836,
-16.136341094970703,
-6.552414417266846,
-11.243837356567383,
-16.678457260131836,
-14.629229545593262,
-10.052337646484375,
-15.451828956604004,
-12.561151504516602,
-11.68396282196045,
-11.975972175598145,
-11.09926986694336,
-13.060500144958496,
-12.075592994689941,
-1.0808746814727783,
1.7046797275543213,
-3.8080708980560303,
-11.401922225952148,
-12.184720039367676,
-13.262567520141602,
-11.299583435058594,
-13.654638290405273,
-10.767330169677734,
-9.012763977050781,
-10.202326774597168,
-10.088111877441406,
-13.247991561889648,
-9.651527404785156,
-11.903244972229004,
-13.922954559326172,
-17.37179946899414,
-12.51513385772705,
-7.8046746253967285,
-14.406414985656738,
-13.172696113586426,
-11.194984436035156,
-12.029500961303711,
-10.996524810791016,
-10.828441619873047,
-8.673471450805664,
-13.800869941711426,
-9.680946350097656,
-12.964024543762207,
-9.694372177124023,
-13.132003784179688,
-9.38864803314209,
-14.305071830749512,
-14.4693603515625,
-5.0566205978393555,
-15.685358047485352,
-12.493011474609375,
-8.424881935119629,
]
start = time.time()
get_nearest_neighbors = chroma.get_nearest_neighbors(knife_embedding, 4, where= {"inference_class": "knife","dataset": "training"})
get_nearest_neighbors = chroma.get_nearest_neighbors(
knife_embedding, 4, where={"inference_class": "knife", "dataset": "training"}
)
print("get_nearest_neighbors: ", get_nearest_neighbors)
res_df = pd.DataFrame(get_nearest_neighbors['embeddings'])
res_df = pd.DataFrame(get_nearest_neighbors["embeddings"])
print(res_df.head())
print("Distances to nearest neighbors: ", get_nearest_neighbors['distances'])
print("Internal ids of nearest neighbors: ", get_nearest_neighbors['ids'])
print("Distances to nearest neighbors: ", get_nearest_neighbors["distances"])
print("Internal ids of nearest neighbors: ", get_nearest_neighbors["ids"])
end = time.time()
print("Time to get nearest neighbors: " +'{0:.2f}'.format((end - start)) + 's')
print("Time to get nearest neighbors: " + "{0:.2f}".format((end - start)) + "s")
task = chroma.calculate_results()
print(task)
print(chroma.get_task_status(task['task_id']))
print(chroma.get_task_status(task["task_id"]))
fetched = chroma.count()
print("Records loaded into the database: ", fetched)
print("Records loaded into the database: ", fetched)
del chroma

View File

@@ -1,8 +1,40 @@
[project]
name = "chroma"
dynamic = ["version"]
authors = [
{ name="Jeff Huber", email="jeff@trychroma.com" },
{ name="Anton Troynikov", email="anton@trychroma.com" }
]
description = "Chroma."
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache-2.0",
"Operating System :: OS Independent",
]
dependencies = [
'pyarrow ~= 9.0',
'requests ~= 2.28',
]
[tool.black]
line-length = 100
required-version = "22.10.0" # Black will refuse to run if it's not this version.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']
# Black will refuse to run if it's not this version.
required-version = "22.6.0"
[tool.pytest.ini_options]
pythonpath = ["."]
[project.urls]
"Homepage" = "https://github.com/chroma-core/chroma"
"Bug Tracker" = "https://github.com/chroma-core/chroma/issues"
[build-system]
requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
local_scheme="no-local-version"
# Ensure black's output will be compatible with all listed versions.
target-version = ['py36', 'py37', 'py38', 'py39', 'py310']

View File

@@ -1,11 +1,11 @@
fastapi==0.85.1
uvicorn[standard]==0.18.3
pyarrow==9.0.0
numpy
pandas
requests==2.28.1
numpy==1.23.5
pandas==1.5.1
duckdb==0.5.1
hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git
redis==3.5.3
celery==4.4.7
clickhouse_driver==0.2.4

View File

@@ -1,3 +1,5 @@
build
pytest
setuptools_scm
httpx
black