From d9dfd089fb9177fe4ec2cf7087bf28987b5e9d48 Mon Sep 17 00:00:00 2001 From: Luke VanderHart Date: Thu, 24 Nov 2022 00:17:24 -0500 Subject: [PATCH] major refactorings WIP --- .gitignore | 4 +- chroma-client/DEVELOP.md => DEVELOP.md | 0 DEV_README.md | 3 - chroma-server/Dockerfile => Dockerfile | 0 Makefile | 5 - README.md | 77 +++--- {chroma-server/bin => bin}/build | 0 {chroma-server/bin => bin}/test | 0 bin/test.py | 12 +- {chroma-server/bin => bin}/version | 0 chroma-client/.gitignore | 14 -- chroma-client/LICENSE | 201 ---------------- chroma-client/README.md | 1 - chroma-client/dev_requirements.txt | 7 - chroma-client/pyproject.toml | 37 --- chroma-client/src/chroma_client/__init__.py | 1 - chroma-client/tests/test_client.py | 18 -- chroma-server/.dockerignore | 13 -- chroma-server/.env | 3 - chroma-server/.gitignore | 13 -- chroma-server/DockerfileCelery | 11 - chroma-server/README.md | 47 ---- chroma-server/app.py | 3 - .../chroma_server/algorithms/__init__.py | 0 .../algorithms/rand_subsample.py | 8 - .../algorithms/stub_distances.py | 89 ------- chroma-server/chroma_server/api.py | 55 ----- chroma-server/chroma_server/core.py | 60 ----- chroma-server/chroma_server/db/__init__.py | 0 chroma-server/chroma_server/db/abstract.py | 27 --- chroma-server/chroma_server/db/duckdb.py | 118 ---------- chroma-server/chroma_server/index/__init__.py | 0 chroma-server/chroma_server/index/abstract.py | 27 --- chroma-server/chroma_server/logger.py | 12 - chroma-server/chroma_server/routes.py | 204 ---------------- chroma-server/chroma_server/test/__init__.py | 0 chroma-server/chroma_server/test/test_api.py | 181 -------------- chroma-server/chroma_server/utils/__init__.py | 0 .../chroma_server/utils/config/__init__.py | 0 .../chroma_server/utils/config/settings.py | 15 -- .../chroma_server/utils/telemetry/__init__.py | 0 chroma-server/docker-compose.test.yml | 39 ---- chroma-server/pyproject.toml | 20 -- chroma-server/run_tests.sh | 2 - chroma/__init__.py | 48 ++++ .../client.py => chroma/api/__init__.py | 220 +++++++++--------- chroma/api/arrowflight.py | 9 + chroma/api/celery.py | 47 ++++ chroma/api/fastapi.py | 117 ++++++++++ chroma/api/local.py | 115 +++++++++ chroma/config.py | 24 ++ chroma/db/__init__.py | 75 ++++++ .../chroma_server => chroma}/db/clickhouse.py | 86 ++++--- chroma/db/duckdb.py | 12 + chroma/db/index/__init__.py | 32 +++ .../db}/index/hnswlib.py | 60 ++--- chroma/server/__init__.py | 11 + chroma/server/arrowflight.py | 9 + chroma/server/fastapi.py | 83 +++++++ .../server/fastapi}/__init__.py | 0 .../server/fastapi/types.py | 2 +- .../server}/utils/error_reporting.py | 4 +- .../server/utils/telemetry/__init__.py | 7 +- .../server}/utils/telemetry/capture.py | 5 +- chroma/test/test_api.py | 4 + chroma/test/test_chroma.py | 56 +++++ .../chroma_server => chroma}/worker.py | 15 +- .../chroma-in-notebook.ipynb | 0 .../in-memory_demo.ipynb | 0 examples/misc/play.py | 8 +- examples/sample-app/app.py | 4 +- examples/sample-script/sample_script.py | 19 +- examples/yolov3/results.py | 2 +- examples/yolov3/yolov3.py | 147 +++++++++--- pyproject.toml | 40 +++- .../requirements.txt => requirements.txt | 6 +- ...quirements_dev.txt => requirements_dev.txt | 4 +- 77 files changed, 1074 insertions(+), 1524 deletions(-) rename chroma-client/DEVELOP.md => DEVELOP.md (100%) delete mode 100644 DEV_README.md rename chroma-server/Dockerfile => Dockerfile (100%) delete mode 100644 Makefile rename {chroma-server/bin => bin}/build (100%) rename {chroma-server/bin => bin}/test (100%) rename {chroma-server/bin => bin}/version (100%) delete mode 100644 chroma-client/.gitignore delete mode 100644 chroma-client/LICENSE delete mode 100644 chroma-client/README.md delete mode 100644 chroma-client/dev_requirements.txt delete mode 100644 chroma-client/pyproject.toml delete mode 100644 chroma-client/src/chroma_client/__init__.py delete mode 100644 chroma-client/tests/test_client.py delete mode 100644 chroma-server/.dockerignore delete mode 100644 chroma-server/.env delete mode 100644 chroma-server/.gitignore delete mode 100644 chroma-server/DockerfileCelery delete mode 100644 chroma-server/README.md delete mode 100644 chroma-server/app.py delete mode 100644 chroma-server/chroma_server/algorithms/__init__.py delete mode 100644 chroma-server/chroma_server/algorithms/rand_subsample.py delete mode 100644 chroma-server/chroma_server/algorithms/stub_distances.py delete mode 100644 chroma-server/chroma_server/api.py delete mode 100644 chroma-server/chroma_server/core.py delete mode 100644 chroma-server/chroma_server/db/__init__.py delete mode 100644 chroma-server/chroma_server/db/abstract.py delete mode 100644 chroma-server/chroma_server/db/duckdb.py delete mode 100644 chroma-server/chroma_server/index/__init__.py delete mode 100644 chroma-server/chroma_server/index/abstract.py delete mode 100644 chroma-server/chroma_server/logger.py delete mode 100644 chroma-server/chroma_server/routes.py delete mode 100644 chroma-server/chroma_server/test/__init__.py delete mode 100644 chroma-server/chroma_server/test/test_api.py delete mode 100644 chroma-server/chroma_server/utils/__init__.py delete mode 100644 chroma-server/chroma_server/utils/config/__init__.py delete mode 100644 chroma-server/chroma_server/utils/config/settings.py delete mode 100644 chroma-server/chroma_server/utils/telemetry/__init__.py delete mode 100644 chroma-server/docker-compose.test.yml delete mode 100644 chroma-server/pyproject.toml delete mode 100644 chroma-server/run_tests.sh create mode 100644 chroma/__init__.py rename chroma-client/src/chroma_client/client.py => chroma/api/__init__.py (51%) create mode 100644 chroma/api/arrowflight.py create mode 100644 chroma/api/celery.py create mode 100644 chroma/api/fastapi.py create mode 100644 chroma/api/local.py create mode 100644 chroma/config.py create mode 100644 chroma/db/__init__.py rename {chroma-server/chroma_server => chroma}/db/clickhouse.py (86%) create mode 100644 chroma/db/duckdb.py create mode 100644 chroma/db/index/__init__.py rename {chroma-server/chroma_server => chroma/db}/index/hnswlib.py (86%) create mode 100644 chroma/server/__init__.py create mode 100644 chroma/server/arrowflight.py create mode 100644 chroma/server/fastapi.py rename {chroma-server/chroma_server => chroma/server/fastapi}/__init__.py (100%) rename chroma-server/chroma_server/types/__init__.py => chroma/server/fastapi/types.py (97%) rename {chroma-server/chroma_server => chroma/server}/utils/error_reporting.py (91%) rename chroma-server/chroma_server/utils/telemetry/abstract.py => chroma/server/utils/telemetry/__init__.py (63%) rename {chroma-server/chroma_server => chroma/server}/utils/telemetry/capture.py (90%) create mode 100644 chroma/test/test_api.py create mode 100644 chroma/test/test_chroma.py rename {chroma-server/chroma_server => chroma}/worker.py (72%) rename {chroma-server => examples}/chroma-in-notebook.ipynb (100%) rename {chroma-server => examples}/in-memory_demo.ipynb (100%) rename chroma-server/requirements.txt => requirements.txt (88%) rename chroma-server/requirements_dev.txt => requirements_dev.txt (70%) diff --git a/.gitignore b/.gitignore index bbaf92e..126f00e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ __pycache__ **/__pycache__ -chroma-server/chroma_logs.log +*.log **/data__nogit @@ -14,4 +14,4 @@ chroma-server/chroma_logs.log index_data /index_data -chroma_logs.log \ No newline at end of file +venv \ No newline at end of file diff --git a/chroma-client/DEVELOP.md b/DEVELOP.md similarity index 100% rename from chroma-client/DEVELOP.md rename to DEVELOP.md diff --git a/DEV_README.md b/DEV_README.md deleted file mode 100644 index 58160c6..0000000 --- a/DEV_README.md +++ /dev/null @@ -1,3 +0,0 @@ -### Using core code/modules outside Docker - -Adding `127.0.0.1 clickhouse` to /etc/hosts is only necessary if you want to use the app's core code outside of Docker, while still using Clickhouse inside Docker. This is because inside Docker Clickhouse is mapped to `clickhouse` as the network name. Outside, Docker containers are simply mapped to ports on 127.0.0.1. There are two ways to use the core code outside of Docker then: (1) update the url from `clickhouse` to `127.0.01` inside clickhouse.py (FYI chroma-server inside Docker *will* break), or map requests locally on your host machine to `clickhouse` over to Docker's `127.0.0.1`. Adding the line to /etc/hosts fulfills #2. This is outside the usage pattern of normal use, but it's useful to know how and why it can work. \ No newline at end of file diff --git a/chroma-server/Dockerfile b/Dockerfile similarity index 100% rename from chroma-server/Dockerfile rename to Dockerfile diff --git a/Makefile b/Makefile deleted file mode 100644 index 75e62a2..0000000 --- a/Makefile +++ /dev/null @@ -1,5 +0,0 @@ -black: - black --fast chroma-server chroma-client - -check_black: - black --check --fast chroma-server chroma-client \ No newline at end of file diff --git a/README.md b/README.md index b38bca1..cabe66b 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,47 @@ -# Chroma +# Chroma Server -This repository is a monorepo containing all the core components of -the Chroma product. +## Development -Contents: - -- `/doc` - Project documentation -- `/chroma-client` - Python client for Chroma -- `/chroma-server` - FastAPI server used as the backend for Chroma client - - - -### Get up and running on Linux -No requirements -``` -/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/effcbac05021e863bbd634f4b7d0283d/raw/4d38b150809d6ccbc379f88433cadd86c81d32cd/chroma_setup.sh)" -python3 chroma/bin/test.py -``` - -### Get up and running on Mac -Requirements -- git -- Docker & `docker-compose` -- pip +Set up a virtual environment and install the project's requirements +and dev requirements: ``` -/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/jeffchuber/27a3cbb28e6521c811da6398346cd35f/raw/55c2d82870436431120a9446b47f19b72d88fa31/chroma_setup_mac.sh)" -python3 chroma/bin/test.py +python3 -m venv venv # Only need to do this once +source venv/bin/activate # Do this each time you use a new shell for the project +pip install -r requirements.txt +pip install -r requirements_dev.txt ``` -* These urls will be swapped out for the link in the repo once it is live +To run tests, run `bin/test`. This will run the test suite inside a +docker compose cluster, with the database available, and clean up when complete. +To run the server locally, in development mode, run `uvicorn chroma_server:app --reload` -### You should see something like +## Docker +To build the docker image locally, run `bin/build`. + +The version tag of the build is generated by the `bin/version` script, +which uses the `setuptools_scm` library. For full documentation, see +the +[documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/). + +In brief, version numbers are generated as follows: + +- If the current git head is tagged, the version number is exactly the + tag (e.g, `0.0.1`). +- If the the current git head is a clean checkout, but is not tagged, + the version number is a patch version increment of the most recent + tag, plus `devN` where N is the number of commits since the most + recent tag. For example, if there have been 5 commits since the + `0.0.1` tag, the generated version will be `0.0.2-dev5`. +- If the current head is not a clean checkout, a `-dirty` local + version will be appended to the version number. For example, + `0.0.2-dev5-dirty`. + +To run use `docker images` to see what containers and tags you have available: ``` -Getting heartbeat to verify the server is up -{'nanosecond heartbeat': 1667865642509760965000} -Logging embeddings into the database -Generating the index -True -Running a nearest neighbor search -{'ids': ['11540ca6-ebbc-4c81-8299-108d8c47c88c'], 'embeddings': [['sample_space', '11540ca6-ebbc-4c81-8299-108d8c47c88c', [1.0, 2.0, 3.0, 4.0, 5.0], '/images/1', 'training', None, 'spoon']], 'distances': [0.0]} -Success! Everything worked! +docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:> ``` - -### Run in-memory Chroma - -``` -cd chroma-server -CHROMA_MODE="in-memory" uvicorn chroma_server.api:app --reload --log-level=debug -``` \ No newline at end of file +This will expose the internal app at `localhost:8000` diff --git a/chroma-server/bin/build b/bin/build similarity index 100% rename from chroma-server/bin/build rename to bin/build diff --git a/chroma-server/bin/test b/bin/test similarity index 100% rename from chroma-server/bin/test rename to bin/test diff --git a/bin/test.py b/bin/test.py index f628c33..4ab2a73 100644 --- a/bin/test.py +++ b/bin/test.py @@ -4,16 +4,16 @@ from chroma_client import Chroma chroma = Chroma() -chroma.set_model_space('sample_space') +chroma.set_model_space("sample_space") print("Getting heartbeat to verify the server is up") print(chroma.heartbeat()) print("Logging embeddings into the database") chroma.add( - [[1,2,3,4,5], [5,4,3,2,1], [10,9,8,7,6]], - ["/images/1", "/images/2", "/images/3"], - ["training", "training", "training"], - ['spoon', 'knife', 'fork'] + [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [10, 9, 8, 7, 6]], + ["/images/1", "/images/2", "/images/3"], + ["training", "training", "training"], + ["spoon", "knife", "fork"], ) print("count") @@ -24,6 +24,6 @@ print("Generating the index") print(chroma.create_index()) print("Running a nearest neighbor search") -print(chroma.get_nearest_neighbors([1,2,3,4,5], 1)) +print(chroma.get_nearest_neighbors([1, 2, 3, 4, 5], 1)) print("Success! Everything worked!") diff --git a/chroma-server/bin/version b/bin/version similarity index 100% rename from chroma-server/bin/version rename to bin/version diff --git a/chroma-client/.gitignore b/chroma-client/.gitignore deleted file mode 100644 index fa37994..0000000 --- a/chroma-client/.gitignore +++ /dev/null @@ -1,14 +0,0 @@ -# general things to ignore -build/ -dist/ -*.egg-info/ -*.egg -*.py[cod] -__pycache__/ -*.so -*~ -venv - -# due to using tox and pytest -.tox -.cache diff --git a/chroma-client/LICENSE b/chroma-client/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/chroma-client/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/chroma-client/README.md b/chroma-client/README.md deleted file mode 100644 index c354c4c..0000000 --- a/chroma-client/README.md +++ /dev/null @@ -1 +0,0 @@ -# Chroma Client diff --git a/chroma-client/dev_requirements.txt b/chroma-client/dev_requirements.txt deleted file mode 100644 index 1d80bb9..0000000 --- a/chroma-client/dev_requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Depenences to test, build and release the project -build -pytest -setuptools_scm -httpx --e . # Include transitive dependencies and code from the current project - diff --git a/chroma-client/pyproject.toml b/chroma-client/pyproject.toml deleted file mode 100644 index 0b15166..0000000 --- a/chroma-client/pyproject.toml +++ /dev/null @@ -1,37 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"] -build-backend = "setuptools.build_meta" - -[project] -name = "chroma_client" -dynamic = ["version"] -dependencies = [ - 'pyarrow ~= 9.0', - 'requests ~= 2.28', -] - -authors = [ - { name="Jeff Huber", email="jeff@trychroma.com" }, - { name="Anton Troynikov", email="anton@trychroma.com" } -] -description = "Chroma." -readme = "README.md" -requires-python = ">=3.7" -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache-2.0", - "Operating System :: OS Independent", -] - -[project.urls] -"Homepage" = "https://github.com/chroma-core/chroma" -"Bug Tracker" = "https://github.com/chroma-core/chroma/issues" - -[tool.pytest.ini_options] -pythonpath = [ - "src" -] - -[tool.setuptools_scm] -root=".." -local_scheme="dirty-tag" diff --git a/chroma-client/src/chroma_client/__init__.py b/chroma-client/src/chroma_client/__init__.py deleted file mode 100644 index 421945b..0000000 --- a/chroma-client/src/chroma_client/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .client import * diff --git a/chroma-client/tests/test_client.py b/chroma-client/tests/test_client.py deleted file mode 100644 index b576fd9..0000000 --- a/chroma-client/tests/test_client.py +++ /dev/null @@ -1,18 +0,0 @@ -# use pytest to test chroma_client - -from chroma_client import Chroma -import pytest -import time -from httpx import AsyncClient - -# from ..api import app # this wont work because i moved the file - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -def test_init(): - chroma = Chroma() - assert chroma._api_url == "http://localhost:8000/api/v1" diff --git a/chroma-server/.dockerignore b/chroma-server/.dockerignore deleted file mode 100644 index a692159..0000000 --- a/chroma-server/.dockerignore +++ /dev/null @@ -1,13 +0,0 @@ -# general things to ignore -build/ -dist/ -*.egg-info/ -*.egg -*.py[cod] -__pycache__/ -*.so -*~ -venv - -# due to using tox and pytest -.cache diff --git a/chroma-server/.env b/chroma-server/.env deleted file mode 100644 index a75cfd9..0000000 --- a/chroma-server/.env +++ /dev/null @@ -1,3 +0,0 @@ -disable_anonymized_telemetry=False -environment=development -telemetry_anonymized_uuid=f80b11fc-1c5a-4a90-ba35-8c3a3c5371cc diff --git a/chroma-server/.gitignore b/chroma-server/.gitignore deleted file mode 100644 index a692159..0000000 --- a/chroma-server/.gitignore +++ /dev/null @@ -1,13 +0,0 @@ -# general things to ignore -build/ -dist/ -*.egg-info/ -*.egg -*.py[cod] -__pycache__/ -*.so -*~ -venv - -# due to using tox and pytest -.cache diff --git a/chroma-server/DockerfileCelery b/chroma-server/DockerfileCelery deleted file mode 100644 index dbc14c2..0000000 --- a/chroma-server/DockerfileCelery +++ /dev/null @@ -1,11 +0,0 @@ -# pull official base image -FROM --platform=linux/amd64 python:3.10 - -WORKDIR /chroma-server - -COPY ./requirements.txt requirements.txt - -RUN pip install --no-cache-dir --upgrade -r requirements.txt - -# copy project -COPY . . \ No newline at end of file diff --git a/chroma-server/README.md b/chroma-server/README.md deleted file mode 100644 index cabe66b..0000000 --- a/chroma-server/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Chroma Server - -## Development - -Set up a virtual environment and install the project's requirements -and dev requirements: - -``` -python3 -m venv venv # Only need to do this once -source venv/bin/activate # Do this each time you use a new shell for the project -pip install -r requirements.txt -pip install -r requirements_dev.txt -``` - -To run tests, run `bin/test`. This will run the test suite inside a -docker compose cluster, with the database available, and clean up when complete. - -To run the server locally, in development mode, run `uvicorn chroma_server:app --reload` - -## Docker - -To build the docker image locally, run `bin/build`. - -The version tag of the build is generated by the `bin/version` script, -which uses the `setuptools_scm` library. For full documentation, see -the -[documentation for setuptools_scm](https://github.com/pypa/setuptools_scm/). - -In brief, version numbers are generated as follows: - -- If the current git head is tagged, the version number is exactly the - tag (e.g, `0.0.1`). -- If the the current git head is a clean checkout, but is not tagged, - the version number is a patch version increment of the most recent - tag, plus `devN` where N is the number of commits since the most - recent tag. For example, if there have been 5 commits since the - `0.0.1` tag, the generated version will be `0.0.2-dev5`. -- If the current head is not a clean checkout, a `-dirty` local - version will be appended to the version number. For example, - `0.0.2-dev5-dirty`. - -To run use `docker images` to see what containers and tags you have available: -``` -docker run -p 8000:8000 ghcr.io/chroma-core/chroma-server:> -``` - -This will expose the internal app at `localhost:8000` diff --git a/chroma-server/app.py b/chroma-server/app.py deleted file mode 100644 index b77e070..0000000 --- a/chroma-server/app.py +++ /dev/null @@ -1,3 +0,0 @@ -from chroma_server.api import app - -app = app \ No newline at end of file diff --git a/chroma-server/chroma_server/algorithms/__init__.py b/chroma-server/chroma_server/algorithms/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/algorithms/rand_subsample.py b/chroma-server/chroma_server/algorithms/rand_subsample.py deleted file mode 100644 index 4952d28..0000000 --- a/chroma-server/chroma_server/algorithms/rand_subsample.py +++ /dev/null @@ -1,8 +0,0 @@ -import random - - -def rand_bisectional_subsample(data): - """ - Randomly bisectionally subsample a list of data to size. - """ - return data.sample(frac=0.5, replace=True, random_state=1) diff --git a/chroma-server/chroma_server/algorithms/stub_distances.py b/chroma-server/chroma_server/algorithms/stub_distances.py deleted file mode 100644 index 0d79677..0000000 --- a/chroma-server/chroma_server/algorithms/stub_distances.py +++ /dev/null @@ -1,89 +0,0 @@ -import numpy as np -import json -import ast - - -def class_distances(data): - """' - This is all very subject to change, so essentially just copy and paste from what we had before - """ - - return False - - # def unpack_annotations(embeddings): - # annotations = [json.loads(embedding['infer'])["annotations"]for embedding in embeddings] - # annotations = [annotation for annotation_list in annotations for annotation in annotation_list] - # # Unpack embedding data - # embeddings = [embedding["embedding"] for embedding in embeddings] - # embedding_vectors_by_category = {} - # for embedding_annotation_pair in zip(embeddings, annotations): - # data = np.array(embedding_annotation_pair[0]) - # category = embedding_annotation_pair[1]['category_id'] - # if category in embedding_vectors_by_category.keys(): - # embedding_vectors_by_category[category] = np.append( - # embedding_vectors_by_category[category], data[np.newaxis, :], axis=0 - # ) - # else: - # embedding_vectors_by_category[category] = data[np.newaxis, :] - - # return embedding_vectors_by_category - - # # Get the training embeddings. This is the set of embeddings belonging to datapoints of the training dataset, and to the first created embedding set. - # object_embedding_vectors_by_category = unpack_annotations(data.to_dict('records')) - - # inv_covs_by_category = {} - # means_by_category = {} - # for category, embeddings in object_embedding_vectors_by_category.items(): - # print(f"Computing mean and covariance for label category {category}") - - # # Compute the mean and inverse covariance for computing MHB distance - # print(f"category: {category} samples: {embeddings.shape[0]}") - # if embeddings.shape[0] < (embeddings.shape[1] + 1): - # print(f"not enough samples for stable covariance in category {category}") - # continue - # cov = np.cov(embeddings.transpose()) - # try: - # inv_cov = np.linalg.inv(cov) - # except np.linalg.LinAlgError as err: - # print(f"covariance for category {category} is singular") - # continue - # mean = np.mean(embeddings, axis=0) - # inv_covs_by_category[category] = inv_cov - # means_by_category[category] = mean - - # target_datapoints = data.to_dict('records') #+ panda_train_table.to_dict('records') - - # output_distances = [] - - # # Process each datapoint's inferences individually. This is going to be very slow. - # # This is because there is no way to grab the corresponding metadata off the datapoint. - # # We could instead put it on the embedding directly ? - # inference_metadata = {} - # quality_scores = [] - # for idx, datapoint in enumerate(target_datapoints): - # inferences = json.loads(datapoint['infer'])["annotations"] - # embeddings = [datapoint["embedding"]] - - # for i in range(len(inferences)): - # emb_data = embeddings[i] - # category = inferences[i]["category_id"] - # if not category in inv_covs_by_category.keys(): - # output_distances.append({"distance": None, "id": datapoint["id"]}) - # continue - # mean = means_by_category[category] - # inv_cov = inv_covs_by_category[category] - # delta = np.array(emb_data) - mean - # squared_mhb = np.sum((delta * np.matmul(inv_cov, delta)), axis=0) - # if squared_mhb < 0: - # print(f"squared distance for category {category} is negative") - # output_distances.append({"distance": None, "id": datapoint["id"]}) - # continue - # distance = np.sqrt(squared_mhb) - # quality_scores.append([distance, datapoint]) - # inference_metadata[datapoint["input_uri"]] = distance - # output_distances.append({"id": datapoint["id"], "distance": distance}) - - # if (len(inferences) == 0): - # raise Exception("No inferences found for datapoint") - - # return output_distances diff --git a/chroma-server/chroma_server/api.py b/chroma-server/chroma_server/api.py deleted file mode 100644 index fba31f3..0000000 --- a/chroma-server/chroma_server/api.py +++ /dev/null @@ -1,55 +0,0 @@ -import time -import os - -from chroma_server.db.clickhouse import Clickhouse -from chroma_server.db.duckdb import DuckDB -from chroma_server.index.hnswlib import Hnswlib -from chroma_server.utils.error_reporting import init_error_reporting -from chroma_server.utils.telemetry.capture import Capture -from fastapi import FastAPI - -chroma_telemetry = Capture() -chroma_telemetry.capture('server-start') -init_error_reporting() - -from chroma_server.routes import ChromaRouter - -# current valid modes are 'in-memory' and 'docker', it defaults to docker -chroma_mode = os.getenv('CHROMA_MODE', 'docker') -if chroma_mode == 'in-memory': - db = DuckDB -else: - db = Clickhouse - -ann_index = Hnswlib - -app = FastAPI(debug=True) - -# init db and index -app._db = db() -app._ann_index = ann_index() - -if chroma_mode == 'in-memory': - filesystem_location = os.getcwd() - - # create a dir - if not os.path.exists(filesystem_location + '/.chroma'): - os.makedirs(filesystem_location + '/.chroma') - - if not os.path.exists(filesystem_location + '/.chroma/index_data'): - os.makedirs(filesystem_location + '/.chroma/index_data') - - # specify where to save and load data from - app._db.set_save_folder(filesystem_location + '/.chroma') - app._ann_index.set_save_folder(filesystem_location + '/.chroma/index_data') - - print("Initializing Chroma...") - print("Data will be saved to: " + filesystem_location + '/.chroma') - - # if the db exists, load it - if os.path.exists(filesystem_location + '/.chroma/chroma.parquet'): - print(f"Existing database found at {filesystem_location + '/.chroma/chroma.parquet'}. Loading...") - app._db.load() - -router = ChromaRouter(app=app, db=db, ann_index=ann_index) -app.include_router(router.router) diff --git a/chroma-server/chroma_server/core.py b/chroma-server/chroma_server/core.py deleted file mode 100644 index cb8dd7c..0000000 --- a/chroma-server/chroma_server/core.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -from fastapi import FastAPI - -from chroma_server.routes import ChromaRouter -from chroma_server.index.hnswlib import Hnswlib -from chroma_server.db.duckdb import DuckDB - -# we import types here so that the user can import them from here -from chroma_server.types import ( - ProcessEmbedding, AddEmbedding, FetchEmbedding, - QueryEmbedding, CountEmbedding, DeleteEmbedding, - RawSql, Results, SpaceKeyInput) - -core = FastAPI(debug=True) -core._db = DuckDB() -core._ann_index = Hnswlib() - -router = ChromaRouter(app=core, db=DuckDB, ann_index=Hnswlib) -core.include_router(router.router) - -def init(filesystem_location: str = None): - if filesystem_location is None: - filesystem_location = os.getcwd() - - # create a dir - if not os.path.exists(filesystem_location + '/.chroma'): - os.makedirs(filesystem_location + '/.chroma') - - if not os.path.exists(filesystem_location + '/.chroma/index_data'): - os.makedirs(filesystem_location + '/.chroma/index_data') - - # specify where to save and load data from - core._db.set_save_folder(filesystem_location + '/.chroma') - core._ann_index.set_save_folder(filesystem_location + '/.chroma/index_data') - - print("Initializing Chroma...") - print("Data will be saved to: " + filesystem_location + '/.chroma') - - # if the db exists, load it - if os.path.exists(filesystem_location + '/.chroma/chroma.parquet'): - print(f"Existing database found at {filesystem_location + '/.chroma/chroma.parquet'}. Loading...") - core._db.load() - -core.init = init - -# headless mode -core.heartbeat = router.root -core.add = router.add -core.count = router.count -core.fetch = router.fetch -core.reset = router.reset -core.delete = router.delete -core.get_nearest_neighbors = router.get_nearest_neighbors -core.raw_sql = router.raw_sql -core.create_index = router.create_index - -# these as currently constructed require celery -# chroma_core.process = process -# chroma_core.get_status = get_status -# chroma_core.get_results = get_results diff --git a/chroma-server/chroma_server/db/__init__.py b/chroma-server/chroma_server/db/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/db/abstract.py b/chroma-server/chroma_server/db/abstract.py deleted file mode 100644 index 602d984..0000000 --- a/chroma-server/chroma_server/db/abstract.py +++ /dev/null @@ -1,27 +0,0 @@ -from abc import abstractmethod - - -class Database: - @abstractmethod - def __init__(self): - pass - - @abstractmethod - def add(self, model_space, embedding, input_uri, dataset=None, custom_quality_score=None, inference_class=None, label_class=None): - pass - - @abstractmethod - def count(self, model_space=None): - pass - - @abstractmethod - def fetch(self, where={}, sort=None, limit=None): - pass - - @abstractmethod - def get_by_ids(self, ids): - pass - - @abstractmethod - def reset(self): - pass diff --git a/chroma-server/chroma_server/db/duckdb.py b/chroma-server/chroma_server/db/duckdb.py deleted file mode 100644 index 1de58ff..0000000 --- a/chroma-server/chroma_server/db/duckdb.py +++ /dev/null @@ -1,118 +0,0 @@ -import duckdb -import uuid -import time -from chroma_server.db.clickhouse import Clickhouse, db_array_schema_to_clickhouse_schema, EMBEDDING_TABLE_SCHEMA, RESULTS_TABLE_SCHEMA, db_schema_to_keys -import pandas as pd -import numpy as np - -def clickhouse_to_duckdb_schema(table_schema): - for item in table_schema: - if 'embedding' in item: - item['embedding'] = 'REAL[]' - # capitalize the key - item[list(item.keys())[0]] = item[list(item.keys())[0]].upper() - if 'NULLABLE' in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = item[list(item.keys())[0]].replace('NULLABLE(', '').replace(')', '') - if 'UUID' in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = 'STRING' - if 'FLOAT64' in item[list(item.keys())[0]]: - item[list(item.keys())[0]] = 'REAL' - - return table_schema - -class DuckDB(Clickhouse): - - _save_folder = None - - # duckdb has different types, so we want to convert the clickhouse schema to duckdb schema - def _create_table_embeddings(self): - self._conn.execute(f'''CREATE TABLE embeddings ( - {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(EMBEDDING_TABLE_SCHEMA))} - ) ''') - - def _create_table_results(self): - self._conn.execute(f'''CREATE TABLE results ( - {db_array_schema_to_clickhouse_schema(clickhouse_to_duckdb_schema(RESULTS_TABLE_SCHEMA))} - ) ''') - - # duckdb has a different way of connecting to the database - def __init__(self): - self._conn = duckdb.connect() - self._create_table_embeddings() - self._create_table_results() - - def set_save_folder(self, path): - self._save_folder = path - - def get_save_folder(self): - return self._save_folder - - # the execute many syntax is different than clickhouse, the (?,?) syntax is different than clickhouse - def add(self, model_space, embedding, input_uri, dataset=None, inference_class=None, label_class=None): - data_to_insert = [] - for i in range(len(embedding)): - data_to_insert.append([model_space[i], str(uuid.uuid4()), embedding[i], input_uri[i], dataset[i], inference_class[i], (label_class[i] if label_class is not None else None)]) - - insert_string = "model_space, uuid, embedding, input_uri, dataset, inference_class, label_class" - self._conn.executemany(f''' - INSERT INTO embeddings ({insert_string}) VALUES (?,?,?,?,?,?,?)''', data_to_insert) - - def count(self, model_space=None): - return self._count(model_space=model_space).fetchall()[0][0] - - def _fetch(self, where={}, columnar=False): - val = self._conn.execute(f'''SELECT {db_schema_to_keys()} FROM embeddings {where}''').fetchall() - if columnar: - val = list(zip(*val)) - return val - - def _delete(self, where={}): - uuids_deleted = self._conn.execute(f'''SELECT uuid FROM embeddings {where}''').fetchall() - self._conn.execute(f''' - DELETE FROM - embeddings - {where} - ''').fetchall()[0] - return uuids_deleted - - def get_by_ids(self, ids=list): - # select from duckdb table where ids are in the list - if not isinstance(ids, list): - raise Exception("ids must be a list") - - if not ids: - # create an empty pandas dataframe - return pd.DataFrame() - - return self._conn.execute(f''' - SELECT - {db_schema_to_keys()} - FROM - embeddings - WHERE - uuid IN ({','.join([("'" + str(x) + "'") for x in ids])}) - ''').fetchall() - - def persist(self): - ''' - Persist the database to disk - ''' - if self._conn is None: - return - - self._conn.execute(f''' - COPY - (SELECT * FROM embeddings) - TO '{self._save_folder}/chroma.parquet' - (FORMAT PARQUET); - ''') - - def load(self): - ''' - Load the database from disk - ''' - path = self._save_folder + "/chroma.parquet" - self._conn.execute(f"INSERT INTO embeddings SELECT * FROM read_parquet('{path}');") - - def __del__(self): - self.persist() \ No newline at end of file diff --git a/chroma-server/chroma_server/index/__init__.py b/chroma-server/chroma_server/index/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/index/abstract.py b/chroma-server/chroma_server/index/abstract.py deleted file mode 100644 index 86df060..0000000 --- a/chroma-server/chroma_server/index/abstract.py +++ /dev/null @@ -1,27 +0,0 @@ -from abc import abstractmethod - - -class Index: - @abstractmethod - def __init__(self): - pass - - @abstractmethod - def run(self, batch): - pass - - @abstractmethod - def fetch(self, query): - pass - - @abstractmethod - def delete_batch(self, batch): - pass - - @abstractmethod - def persist(self): - pass - - @abstractmethod - def load(self): - pass diff --git a/chroma-server/chroma_server/logger.py b/chroma-server/chroma_server/logger.py deleted file mode 100644 index 1976460..0000000 --- a/chroma-server/chroma_server/logger.py +++ /dev/null @@ -1,12 +0,0 @@ -import logging - - -def setup_logging(): - logging.basicConfig(filename="chroma_logs.log") - logger = logging.getLogger("Chroma") - logger.setLevel(logging.DEBUG) - logger.debug("Logger created") - return logger - - -logger = setup_logging() diff --git a/chroma-server/chroma_server/routes.py b/chroma-server/chroma_server/routes.py deleted file mode 100644 index 6eb647d..0000000 --- a/chroma-server/chroma_server/routes.py +++ /dev/null @@ -1,204 +0,0 @@ -from celery.result import AsyncResult -import time -import os - -from chroma_server.db.clickhouse import Clickhouse, get_col_pos -from chroma_server.db.duckdb import DuckDB -from chroma_server.index.hnswlib import Hnswlib -from chroma_server.types import (AddEmbedding, CountEmbedding, DeleteEmbedding, - FetchEmbedding, ProcessEmbedding, - QueryEmbedding, RawSql, Results, - SpaceKeyInput) -from chroma_server.utils.error_reporting import init_error_reporting -from chroma_server.utils.telemetry.capture import Capture -from chroma_server.worker import heavy_offline_analysis -from fastapi import FastAPI, status, APIRouter -from fastapi.responses import JSONResponse - -class ChromaRouter(): - - _app = None - _db = None - _ann_index = None - _celery_enabed = True - - def __init__(self, app: FastAPI, db, ann_index: Hnswlib): - self._app = app - self._db = db - self._ann_index = ann_index - - self.router = APIRouter() - self.router.add_api_route("/api/v1", self.root, methods=["GET"]) - self.router.add_api_route("/api/v1/add", self.add, methods=["POST"], status_code=status.HTTP_201_CREATED) - self.router.add_api_route("/api/v1/fetch", self.fetch, methods=["POST"]) - self.router.add_api_route("/api/v1/delete", self.delete, methods=["POST"]) - self.router.add_api_route("/api/v1/count", self.count, methods=["GET"]) - self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"]) - self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"]) - self.router.add_api_route("/api/v1/get_nearest_neighbors", self.get_nearest_neighbors, methods=["POST"]) - self.router.add_api_route("/api/v1/create_index", self.create_index, methods=["POST"]) - self.router.add_api_route("/api/v1/process", self.process, methods=["POST"]) - self.router.add_api_route("/api/v1/get_status", self.get_status, methods=["POST"]) - self.router.add_api_route("/api/v1/get_results", self.get_results, methods=["POST"]) - - # if the type of the db is not duckdb, then disable celery - if not isinstance(self._db, Clickhouse): - self._celery_enabed = False - - # API Endpoints - def root(self): - '''Heartbeat endpoint''' - return {"nanosecond heartbeat": int(1000 * time.time_ns())} - - def add(self, new_embedding: AddEmbedding): - '''Save batched embeddings to database''' - - number_of_embeddings = len(new_embedding.embedding) - - if isinstance(new_embedding.model_space, str): - model_space = [new_embedding.model_space] * number_of_embeddings - elif len(new_embedding.model_space) == 1: - model_space = [new_embedding.model_space[0]] * number_of_embeddings - else: - model_space = new_embedding.model_space - - if isinstance(new_embedding.dataset, str): - dataset = [new_embedding.dataset] * number_of_embeddings - elif len(new_embedding.dataset) == 1: - dataset = [new_embedding.dataset[0]] * number_of_embeddings - else: - dataset = new_embedding.dataset - - self._app._db.add( - model_space, - new_embedding.embedding, - new_embedding.input_uri, - dataset, - new_embedding.inference_class, - new_embedding.label_class - ) - - return {"response": "Added records to database"} - - def fetch(self, embedding: FetchEmbedding): - ''' - Fetches embeddings from the database - - enables filtering by where, sorting by key, and limiting the number of results - ''' - return self._app._db.fetch(embedding.where, embedding.sort, embedding.limit, embedding.offset) - - def delete(self, embedding: DeleteEmbedding): - ''' - Deletes embeddings from the database - - enables filtering by where - ''' - deleted_uuids = self._app._db.delete(embedding.where) - if len(embedding.where) == 1: - if 'model_space' in embedding.where: - self._app._ann_index.delete(embedding.where['model_space']) - - deleted_uuids = [uuid[0] for uuid in deleted_uuids] # de-tuple - self._app._ann_index.delete_from_index(embedding.where['model_space'], deleted_uuids) - return deleted_uuids - - # @app.get("/api/v1/count") - def count(self, model_space: str = None): - ''' - Returns the number of records in the database - ''' - return {"count": self._app._db.count(model_space=model_space)} - - def reset(self): - ''' - Reset the database and index - WARNING: Destructive! - ''' - index_save_folder = self._app._ann_index.get_save_folder() - db_save_folder = self._app._db.get_save_folder() - - self._app._db = self._db() - self._app._db.reset() - self._app._ann_index.reset() # this has to come first I think - self._app._ann_index = self._ann_index() - - self._app._ann_index.set_save_folder(index_save_folder) - self._app._db.set_save_folder(db_save_folder) - - return True - - def get_nearest_neighbors(self,embedding: QueryEmbedding): - ''' - return the distance, database ids, and embedding themselves for the input embedding - ''' - if embedding.where['model_space'] is None: - return {"error": "model_space is required"} - - results = self._app._db.fetch(embedding.where) - ids = [str(item[get_col_pos('uuid')]) for item in results] - - uuids, distances = self._app._ann_index.get_nearest_neighbors(embedding.where['model_space'], embedding.embedding, embedding.n_results, ids) - return { - "ids": uuids, - "embeddings": self._app._db.get_by_ids(uuids), - "distances": distances.tolist()[0] - } - - def raw_sql(self, raw_sql: RawSql): - return self._app._db.raw_sql(raw_sql.raw_sql) - - def create_index(self, process_embedding: ProcessEmbedding): - ''' - Currently generates an index for the embedding db - ''' - fetch = self._app._db.fetch({"model_space": process_embedding.model_space}, columnar=True) - # chroma_telemetry.capture('created-index-run-process', {'n': len(fetch[2])}) - self._app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh - - def process(self, process_embedding: ProcessEmbedding): - ''' - Currently generates an index for the embedding db - ''' - if not self._celery_enabed: - raise Exception("in-memory mode does not process because it relies on celery and redis") - - fetch = self._app._db.fetch({"model_space": process_embedding.model_space}, columnar=True) - # chroma_telemetry.capture('created-index-run-process', {'n': len(fetch[2])}) - self._app._ann_index.run(process_embedding.model_space, fetch[1], fetch[2]) # more magic number, ugh - - task = heavy_offline_analysis.delay(process_embedding.model_space) - # chroma_telemetry.capture('heavy-offline-analysis') - return JSONResponse({"task_id": task.id}) - - def get_status(self, task_id): - if not self._celery_enabed: - raise Exception("in-memory mode does not process because it relies on celery and redis") - - task_result = AsyncResult(task_id) - result = { - "task_id": task_id, - "task_status": task_result.status, - "task_result": task_result.result - } - return JSONResponse(result) - - def get_results(self, results: Results): - if not self._celery_enabed: - raise Exception("in-memory mode does not process because it relies on celery and redis") - - # if there is no index, generate one - if not self._app._ann_index.has_index(results.model_space): - fetch = self._app._db.fetch({"model_space": results.model_space}, columnar=True) - # chroma_telemetry.capture('run-process', {'n': len(fetch[2])}) - print("Generating index for model space: ", results.model_space, " with ", len(fetch[2]), " embeddings") - self._app._ann_index.run(results.model_space, fetch[1], fetch[2]) # more magic number, ugh - print("Done generating index for model space: ", results.model_space) - - # if there are no results, generate them - print("self._app._db.count_results(results.model_space): ", self._app._db.count_results(results.model_space)) - if self._app._db.count_results(results.model_space) == 0: - print("starting heavy offline analysis") - task = heavy_offline_analysis(results.model_space) - print("ending heavy offline analysis") - return self._app._db.return_results(results.model_space, results.n_results) - - else: - return self._app._db.return_results(results.model_space, results.n_results) \ No newline at end of file diff --git a/chroma-server/chroma_server/test/__init__.py b/chroma-server/chroma_server/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/test/test_api.py b/chroma-server/chroma_server/test/test_api.py deleted file mode 100644 index 7b15590..0000000 --- a/chroma-server/chroma_server/test/test_api.py +++ /dev/null @@ -1,181 +0,0 @@ -import pytest -import time -from httpx import AsyncClient -from ..api import app - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - -@pytest.mark.anyio -async def test_root(): - async with AsyncClient(app=app, base_url="http://test") as ac: - response = await ac.get("/api/v1") - assert response.status_code == 200 - assert isinstance(response.json()["nanosecond heartbeat"], int) - -async def post_batch_records(ac): - return await ac.post( - "/api/v1/add", - json={ - "embedding": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], - "input_uri": ["https://example.com", "https://example.com"], - "dataset": ["training", "training"], - "inference_class": ["knife", "person"], - "model_space": ["test_space", "test_space"], - "label_class": ["person", "person"], - }, - ) - -async def post_batch_records_minimal(ac): - return await ac.post( - "/api/v1/add", - json={ - "embedding": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]], - "input_uri": ["https://example.com", "https://example.com"], - "dataset": "training", - "inference_class": ["person", "person"], - "model_space": "test_space" - }, #label_class left off on purpose - ) - - -@pytest.mark.anyio -async def test_add_batch(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - response = await post_batch_records(ac) - assert response.status_code == 201 - assert response.json() == {"response": "Added records to database"} - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 2} - - -@pytest.mark.anyio -async def test_add_batch_minimal(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - response = await post_batch_records_minimal(ac) - assert response.status_code == 201 - assert response.json() == {"response": "Added records to database"} - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 2} - -@pytest.mark.anyio -async def test_fetch_from_db(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - params = {"where": {"model_space": "test_space"}} - response = await ac.post("/api/v1/fetch", json=params) - assert response.status_code == 200 - assert len(response.json()) == 2 - -@pytest.mark.anyio -async def test_count_from_db(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") # reset db - await post_batch_records(ac) - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.status_code == 200 - assert response.json() == {"count": 2} - -@pytest.mark.anyio -async def test_reset_db(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 2} - response = await ac.post("/api/v1/reset") - assert response.json() == True - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 0} - -@pytest.mark.anyio -async def test_get_nearest_neighbors(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - await ac.post("/api/v1/create_index", json={"model_space": "test_space"}) - response = await ac.post( - "/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}} - ) - assert response.status_code == 200 - assert len(response.json()["ids"]) == 1 - -@pytest.mark.anyio -async def test_get_nearest_neighbors_filter(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - await ac.post("/api/v1/create_index", json={"model_space": "test_space"}) - response = await ac.post( - "/api/v1/get_nearest_neighbors", - json={ - "embedding": [1.1, 2.3, 3.2], - "n_results": 1, - "where":{ - "dataset": "training", - "inference_class": "monkey", - "model_space": "test_space", - } - }, - ) - assert response.status_code == 200 - assert len(response.json()["ids"]) == 0 - -@pytest.mark.anyio -async def test_process(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - response = await ac.post("/api/v1/create_index", json={"model_space": "test_space"}) - assert response.status_code == 200 - -# test delete -@pytest.mark.anyio -async def test_delete(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 2} - response = await ac.post("/api/v1/delete", json={"where": {"model_space": "test_space"}}) - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 0} - -@pytest.mark.anyio -async def test_delete_with_index(): - async with AsyncClient(app=app, base_url="http://test") as ac: - await ac.post("/api/v1/reset") - await post_batch_records(ac) - response = await ac.get("/api/v1/count", params={"model_space": "test_space"}) - assert response.json() == {"count": 2} - await ac.post("/api/v1/create_index", json={"model_space": "test_space"}) - response = await ac.post( - "/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}} - ) - assert response.json()['embeddings'][0][5] == 'knife' - response = await ac.post("/api/v1/delete", json={"where": {"model_space": "test_space", "inference_class": "knife"}}) - response = await ac.post( - "/api/v1/get_nearest_neighbors", json={"embedding": [1.1, 2.3, 3.2], "n_results": 1, "where":{"model_space": "test_space"}} - ) - assert response.json()['embeddings'][0][5] == 'person' - -# test calculate results -# @pytest.mark.anyio -# async def test_calculate_results(): -# async with AsyncClient(app=app, base_url="http://test") as ac: -# await ac.post("/api/v1/reset") -# await post_batch_records(ac) -# await ac.post("/api/v1/process", json={"model_space": "test_space"}) -# response = await ac.post( -# "/api/v1/calculate_results", -# json={ -# "model_space": "test_space", -# }, -# ) -# assert response.status_code == 200 -# assert response.json() == {"ids": [], "distances": []} \ No newline at end of file diff --git a/chroma-server/chroma_server/utils/__init__.py b/chroma-server/chroma_server/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/utils/config/__init__.py b/chroma-server/chroma_server/utils/config/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/chroma_server/utils/config/settings.py b/chroma-server/chroma_server/utils/config/settings.py deleted file mode 100644 index e04d0dc..0000000 --- a/chroma-server/chroma_server/utils/config/settings.py +++ /dev/null @@ -1,15 +0,0 @@ -from functools import lru_cache -from typing import Union -from pydantic import BaseSettings - -class Settings(BaseSettings): - disable_anonymized_telemetry: bool = False - telemetry_anonymized_uuid: str = '' - environment: str = 'development' - - class Config: - env_file = ".env" - -@lru_cache() -def get_settings(): - return Settings() \ No newline at end of file diff --git a/chroma-server/chroma_server/utils/telemetry/__init__.py b/chroma-server/chroma_server/utils/telemetry/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma-server/docker-compose.test.yml b/chroma-server/docker-compose.test.yml deleted file mode 100644 index ff8d92e..0000000 --- a/chroma-server/docker-compose.test.yml +++ /dev/null @@ -1,39 +0,0 @@ -version: '3.9' - -networks: - chroma-net: - driver: bridge - -services: - server_test: - build: - context: . - dockerfile: Dockerfile - target: chroma_server_test - volumes: - - ./:/chroma-server/ - - index_data:/index_data - depends_on: - - clickhouse - networks: - - chroma-net - environment: - - CLICKHOUSE_TCP_PORT=9001 - - environment=test - - clickhouse: - image: docker.io/bitnami/clickhouse:22.9 - environment: - - ALLOW_EMPTY_PASSWORD=yes - - CLICKHOUSE_TCP_PORT=9001 - - CLICKHOUSE_HTTP_PORT=8124 - ports: - - '8124:8124' - - '9001:9001' - networks: - - chroma-net - -volumes: - index_data: - driver: local - diff --git a/chroma-server/pyproject.toml b/chroma-server/pyproject.toml deleted file mode 100644 index b06b173..0000000 --- a/chroma-server/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[project] -name = "chroma_server" -dynamic = ["version"] - -authors = [ - { name="Jeff Huber", email="jeff@trychroma.com" }, - { name="Anton Troynikov", email="anton@trychroma.com" } -] -description = "Chroma." -readme = "README.md" -requires-python = ">=3.7" - -[project.urls] -"Homepage" = "https://github.com/chroma-core/chroma" -"Bug Tracker" = "https://github.com/chroma-core/chroma/issues" - -[tool.setuptools_scm] -root=".." -local_scheme="no-local-version" - diff --git a/chroma-server/run_tests.sh b/chroma-server/run_tests.sh deleted file mode 100644 index ccba284..0000000 --- a/chroma-server/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -python -m pytest -CHROMA_MODE=in-memory python -m pytest \ No newline at end of file diff --git a/chroma/__init__.py b/chroma/__init__.py new file mode 100644 index 0000000..ba6b8d0 --- /dev/null +++ b/chroma/__init__.py @@ -0,0 +1,48 @@ +import chroma.config + +__settings = chroma.config.Settings() + +def configure(**kwags): + """Override Chroma's default settings, environment variables or .env files""" + __settings = chroma.config.Settings(**kwargs) + +def get_settings(): + return __settings + +def get_db(settings=__settings): + """Return a chroma.DB instance based on the provided or environmental settings.""" + + if settings.clickhouse_host: + print("Using Clickhouse for database") + import chroma.db.clickhouse + return chroma.db.clickhouse.Clickhouse(settings) + elif settings.chroma_cache_dir: + print("Using DuckDB with local filesystem persistence for database") + import chroma.db.duckdb + return chroma.db.duckdb.PersistentDuckDB(settings) + else: + print("Using DuckDB in-memory for database. Data will be transient.") + import chroma.db.duckdb + return chroma.db.duckdb.DuckDB(settings) + + +def get_api(settings=__settings): + """Return a chroma.API instance based on the provided or environmental + settings, optionally overriding the DB instance.""" + + if settings.chroma_server_host and settings.chroma_server_grpc_port: + print("Running Chroma in client/server mode using ArrowFlight protocol.") + import chroma.api.arrowflight + return chroma.api.arrowflight.ArrowFlightAPI(settings) + elif settings.chroma_server_host and settings.chroma_server_http_port: + print("Running Chroma in client/server mode using REST protocol.") + import chroma.api.fastapi + return chroma.api.fastapi.FastAPI(settings) + elif settings.celery_broker_url: + print("Running Chroma in server mode with Celery jobs enabled.") + import chroma.api.celery + return chroma.api.celery.CeleryAPI(settings, get_db(settings)) + else: + print("Running Chroma using direct local API.") + import chroma.api.local + return chroma.api.local.LocalAPI(settings, get_db(settings)) diff --git a/chroma-client/src/chroma_client/client.py b/chroma/api/__init__.py similarity index 51% rename from chroma-client/src/chroma_client/client.py rename to chroma/api/__init__.py index 7b3b18c..43b8396 100644 --- a/chroma-client/src/chroma_client/client.py +++ b/chroma/api/__init__.py @@ -1,67 +1,55 @@ -import requests -import json +from abc import ABC, abstractmethod from typing import Union -class Chroma: +class API(ABC): - _api_url = "http://localhost:8000/api/v1" - _model_space = "default_scope" + _model_space = 'default_scope' - def __init__(self, url=None, model_space=None): - """Initialize Chroma client""" + @abstractmethod + def __init__(self): + pass - if isinstance(url, str) and url.startswith("http"): - self._api_url = url - - if isinstance(model_space, str) and model_space: - self._model_space = model_space - - def set_model_space(self, model_space): - '''Sets the space key for the client, enables overriding the string concat''' - self._model_space = model_space - - def get_model_space(self): - '''Returns the model_space key''' - return self._model_space + @abstractmethod def heartbeat(self): '''Returns the current server time in nanoseconds to check if the server is alive''' - return requests.get(self._api_url).json() + pass + + @abstractmethod + def add(self, + embedding: list, + input_uri: list, + dataset: list = None, + inference_class: list = None, + label_class: list = None, + model_spaces: list = None): + """Add embeddings to the data store""" + pass + + + @abstractmethod def count(self, model_space=None): '''Returns the number of embeddings in the database''' - params = {"model_space": model_space or self._model_space} - x = requests.get(self._api_url + "/count", params=params) - return x.json() + pass + + @abstractmethod def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None): '''Fetches embeddings from the database''' - if self._model_space: - where["model_space"] = self._model_space + pass - if page and page_size: - offset = (page - 1) * page_size - limit = page_size - - return requests.post(self._api_url + "/fetch", data=json.dumps({ - "where":where, - "sort":sort, - "limit":limit, - "offset":offset - })).json() + @abstractmethod def delete(self, where={}): '''Deletes embeddings from the database''' - if self._model_space: - where["model_space"] = self._model_space + pass - return requests.post(self._api_url + "/delete", data=json.dumps({ - "where":where, - })).json() - - def add(self, - embedding: list, - input_uri: list, + + @abstractmethod + def add(self, + embedding: list, + input_uri: list, dataset: list = None, inference_class: list = None, label_class: list = None, @@ -70,107 +58,113 @@ class Chroma: Addss a batch of embeddings to the database - pass in column oriented data lists ''' + pass - if not model_spaces: - model_spaces = self._model_space - x = requests.post(self._api_url + "/add", data = json.dumps({ - "model_space": model_spaces, - "embedding": embedding, - "input_uri": input_uri, - "dataset": dataset, - "inference_class": inference_class, - "label_class": label_class - }) ) + @abstractmethod + def get_nearest_neighbors(self, embedding, n_results=10, where={}): + '''Gets the nearest neighbors of a single embedding''' + pass + + + @abstractmethod + def process(self, model_space=None): + ''' + Processes embeddings in the database + - currently this only runs hnswlib, doesnt return anything + ''' + pass + + + @abstractmethod + def reset(self): + '''Resets the database''' + pass + + + @abstractmethod + def raw_sql(self, sql): + '''Runs a raw SQL query against the database''' + pass + + + @abstractmethod + def get_results(self, model_space=None, n_results = 100): + '''Gets the results for the given space key''' + pass + + + @abstractmethod + def get_task_status(self, task_id): + '''Gets the status of a task''' + pass + + + @abstractmethod + def create_index(self, model_space=None): + '''Creates an index for the given space key''' + pass + + + def set_model_space(self, model_space): + '''Sets the space key for the client, enables overriding the string concat''' + self._model_space = model_space + + + def get_model_space(self): + '''Returns the model_space key''' + return self._model_space + + + def where_with_model_space(self, where_clause): + '''Returns a where clause that specifies the default model space iff it wasn't already specified''' + + if self._model_space and "model_space" not in where_clause: + where_clause["model_space"] = self._model_space + + return where_clause + - if x.status_code == 201: - return True - else: - return False - def add_training(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None): ''' Small wrapper around add() to add a batch of training embedding - sets dataset to "training" ''' datasets = ["training"] * len(input_uri) return self.add( - embedding=embedding, - input_uri=input_uri, + embedding=embedding, + input_uri=input_uri, dataset=datasets, inference_class=inference_class, model_spaces=model_spaces, label_class=label_class ) - + + def add_production(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None): ''' Small wrapper around add() to add a batch of production embedding - sets dataset to "production" ''' datasets = ["production"] * len(input_uri) return self.add( - embedding=embedding, - input_uri=input_uri, + embedding=embedding, + input_uri=input_uri, dataset=datasets, inference_class=inference_class, model_spaces=model_spaces, label_class=label_class ) - + + def add_triage(self, embedding: list, input_uri: list, inference_class: list, label_class: list = None, model_spaces: list = None): ''' Small wrapper around add() to add a batch of triage embedding - sets dataset to "triage" ''' datasets = ["triage"] * len(input_uri) return self.add( - embedding=embedding, - input_uri=input_uri, + embedding=embedding, + input_uri=input_uri, dataset=datasets, inference_class=inference_class, model_spaces=model_spaces, label_class=label_class ) - - def get_nearest_neighbors(self, embedding, n_results=10, where={}): - '''Gets the nearest neighbors of a single embedding''' - - if "model_space" not in where: - where["model_space"] = self._model_space - - x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({ - "embedding": embedding, - "n_results": n_results, - "where": where - }) ) - - if x.status_code == 200: - return x.json() - else: - return False - - def process(self, model_space=None): - ''' - Processes embeddings in the database - - currently this only runs hnswlib, doesnt return anything - ''' - x = requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space})) - return x.json() - - def reset(self): - '''Resets the database''' - return requests.post(self._api_url + "/reset") - - def raw_sql(self, sql): - '''Runs a raw SQL query against the database''' - return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json() - - def get_results(self, model_space=None, n_results = 100): - '''Gets the results for the given space key''' - return requests.post(self._api_url + "/get_results", data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json() - - def get_task_status(self, task_id): - '''Gets the status of a task''' - return requests.post(self._api_url + f"/tasks/{task_id}").json() - - def create_index(self, model_space=None): - '''Creates an index for the given space key''' - return requests.post(self._api_url + "/create_index", data = json.dumps({"model_space": model_space or self._model_space})).json() \ No newline at end of file diff --git a/chroma/api/arrowflight.py b/chroma/api/arrowflight.py new file mode 100644 index 0000000..0829db5 --- /dev/null +++ b/chroma/api/arrowflight.py @@ -0,0 +1,9 @@ +from chroma.api import API + +class ArrowFlightAPI(API): + + def __init__(self, settings): + print("Constructing Local instance") + + # TODO: Implement + diff --git a/chroma/api/celery.py b/chroma/api/celery.py new file mode 100644 index 0000000..ab2d1fc --- /dev/null +++ b/chroma/api/celery.py @@ -0,0 +1,47 @@ +from chroma.api.local import LocalAPI +from chroma.worker import heavy_offline_analysis +from celery.result import AsyncResult + +class CeleryAPI(LocalAPI): + + def __init__(self, settings): + pass + + + def process(self, model_space=None): + + self.create_index(model_space) + + task = heavy_offline_analysis.delay(model_space) + # chroma_telemetry.capture('heavy-offline-analysis') + + return task.id + + + def get_status(self, task_id): + + task_result = AsyncResult(task_id) + result = { + "task_id": task_id, + "task_status": task_result.status, + "task_result": task_result.result + } + + return result + + + def get_results(self, model_space=None, n_results=100): + + model_space = model_space or self._model_space + + if not self._db.has_index(model_space): + self._db.create_index(model_space) + + results_count = self._db.count_results(model_space) + + if results_count == 0: + heavy_offline_analysis(model_space) + + return self._db.return_results(model_space, n_results) + + diff --git a/chroma/api/fastapi.py b/chroma/api/fastapi.py new file mode 100644 index 0000000..4d0be1f --- /dev/null +++ b/chroma/api/fastapi.py @@ -0,0 +1,117 @@ +from chroma.api import API +import requests + +class FastAPI(API): + + def __init__(self, settings): + self._api_url = f'http://{settings.chroma_server_host}:{settings.chroma_server_http_port}/api/v1' + + def heartbeat(self): + '''Returns the current server time in nanoseconds to check if the server is alive''' + return int(requests.get(self._api_url).json()['nanosecond heartbeat']) + + def count(self, model_space=None): + '''Returns the number of embeddings in the database''' + params = {"model_space": model_space or self._model_space} + x = requests.get(self._api_url + "/count", params=params) + return x.json()['count'] + + def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None): + '''Fetches embeddings from the database''' + + where = self.where_with_model_space(where) + + if page and page_size: + offset = (page - 1) * page_size + limit = page_size + + return requests.post(self._api_url + "/fetch", data=json.dumps({ + "where":where, + "sort":sort, + "limit":limit, + "offset":offset + })).json() + + def delete(self, where={}): + '''Deletes embeddings from the database''' + + where = self.where_with_model_space(where) + + return requests.post(self._api_url + "/delete", data=json.dumps({ + "where":where, + })).json() + + def add(self, + embedding: list, + input_uri: list, + dataset: list = None, + inference_class: list = None, + label_class: list = None, + model_spaces: list = None): + ''' + Addss a batch of embeddings to the database + - pass in column oriented data lists + ''' + + if not model_spaces: + model_spaces = self._model_space + + x = requests.post(self._api_url + "/add", data = json.dumps({ + "model_space": model_spaces, + "embedding": embedding, + "input_uri": input_uri, + "dataset": dataset, + "inference_class": inference_class, + "label_class": label_class + }) ) + + if x.status_code == 201: + return True + else: + return False + + def get_nearest_neighbors(self, embedding, n_results=10, where={}): + '''Gets the nearest neighbors of a single embedding''' + + where = self.where_with_model_space(where) + + x = requests.post(self._api_url + "/get_nearest_neighbors", data = json.dumps({ + "embedding": embedding, + "n_results": n_results, + "where": where + }) ) + + if x.status_code == 200: + return x.json() + else: + return False + + def process(self, model_space=None): + ''' + Processes embeddings in the database + - currently this only runs hnswlib, doesnt return anything + ''' + x = requests.post(self._api_url + "/process", data = json.dumps({"model_space": model_space or self._model_space})) + return x.json() + + def reset(self): + '''Resets the database''' + return requests.post(self._api_url + "/reset") + + def raw_sql(self, sql): + '''Runs a raw SQL query against the database''' + return requests.post(self._api_url + "/raw_sql", data = json.dumps({"raw_sql": sql})).json() + + def get_results(self, model_space=None, n_results = 100): + '''Gets the results for the given space key''' + return requests.post(self._api_url + "/get_results", + data = json.dumps({"model_space": model_space or self._model_space, "n_results": n_results})).json() + + def get_task_status(self, task_id): + '''Gets the status of a task''' + return requests.post(self._api_url + f"/tasks/{task_id}").json() + + def create_index(self, model_space=None): + '''Creates an index for the given space key''' + return requests.post(self._api_url + "/create_index", + data = json.dumps({"model_space": model_space or self._model_space})).json() diff --git a/chroma/api/local.py b/chroma/api/local.py new file mode 100644 index 0000000..ef5a9d2 --- /dev/null +++ b/chroma/api/local.py @@ -0,0 +1,115 @@ +import time +from chroma.api import API + +class LocalAPI(API): + + def __init__(self, settings, db): + self._db = db + + + def heartbeat(self): + return int(1000 * time.time_ns()) + + + def add(self, + embedding: list, + input_uri: list, + dataset: list = None, + inference_class: list = None, + label_class: list = None, + model_spaces: list = None): + + number_of_embeddings = len(embedding) + + if isinstance(model_spaces, str): + model_space = [model_spaces] * number_of_embeddings + elif len(model_spaces) == 1: + model_space = [model_spaces[0]] * number_of_embeddings + else: + model_space = model_spaces + + if isinstance(dataset, str): + ds = [dataset] * number_of_embeddings + elif len(dataset) == 1: + ds = [dataset[0]] * number_of_embeddings + else: + ds = dataset + + self._db.add( + model_space, + embedding, + input_uri, + ds, + inference_class, + label_class + ) + + + def fetch(self, where={}, sort=None, limit=None, offset=None, page=None, page_size=None): + + if page and page_size: + offset = (page - 1) * page_size + limit = page_size + + return self._db.fetch(where, sort, limit, offset) + + + def delete(self, where={}): + + where = self.where_with_model_space(where) + deleted_uuids = self._db.delete(where) + return deleted_uuids + + + def count(self, model_space=None): + + model_space = model_space or self._model_space + return {"count": self._db.count(model_space=model_space)} + + + def reset(self): + + self._db.reset() + return True + + + def get_nearest_neighbors(self, embedding, n_results, where): + + where = self.where_with_model_space(where) + + results = self._db.fetch(where) + ids = [str(item[get_col_pos('uuid')]) for item in results] + + uuids, distances = self._db.get_nearest_neighbors(where['model_space'], embedding, n_results, ids) + + return { + "ids": uuids, + "embeddings": self._db.get_by_ids(uuids), + "distances": distances.tolist()[0] + } + + + def raw_sql(self, raw_sql): + + return self._db.raw_sql(raw_sql) + + + def create_index(self, model_space=None): + + self._db.create_index(model_space or self._model_space) + return True + + + def process(self, model_space=None): + + raise NotImplementedError("Cannot launch job: Celery is not configured") + + + def get_status(self, task_id): + + raise NotImplementedError("Cannot get status of job: Celery is not configured") + + + def get_results(self, model_space=None, n_results=100): + + raise NotImplementedError("Cannot get job results: Celery is not configured") diff --git a/chroma/config.py b/chroma/config.py new file mode 100644 index 0000000..4f6d731 --- /dev/null +++ b/chroma/config.py @@ -0,0 +1,24 @@ +from pydantic import BaseSettings, Field + + +class Settings(BaseSettings): + + disable_anonymized_telemetry: bool = False + telemetry_anonymized_uuid: str = "" + environment: str = "" + + clickhouse_host: str = None + clickhouse_port: str = None + + celery_broker_url: str = None + celery_result_backend: str = None + + chroma_cache_dir: str = None + + chroma_server_host: str = None + chroma_server_http_port: str = None + chroma_server_grpc_port: str = None + + class Config: + env_file = '.env' + env_file_encoding = 'utf-8' diff --git a/chroma/db/__init__.py b/chroma/db/__init__.py new file mode 100644 index 0000000..cfe4e65 --- /dev/null +++ b/chroma/db/__init__.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod + +class DB(ABC): + + @abstractmethod + def __init__(self): + pass + + + @abstractmethod + def add(self, model_space, embedding, input_uri, dataset=None, custom_quality_score=None, inference_class=None, label_class=None): + pass + + + @abstractmethod + def fetch(self, where, sort, limit, offset, columnar): + pass + + @abstractmethod + def delete(self, where): + pass + + @abstractmethod + def reset(self): + pass + + + @abstractmethod + def get_nearest_neighbors(self, where, embedding, n_results, ids): + pass + + + @abstractmethod + def get_by_ids(self, uuids): + pass + + + @abstractmethod + def raw_sql(self, raw_sql): + pass + + + @abstractmethod + def create_index(self, model_space): + pass + + + @abstractmethod + def has_index(self, model_space): + pass + + + @abstractmethod + def count_results(self, model_space): + pass + + + @abstractmethod + def return_results(self, model_space, n_results): + pass + + + @abstractmethod + def delete_results(self, model_space): + pass + + + @abstractmethod + def add_results(self, model_space, uuids, quality_scores): + pass + + + @abstractmethod + def get_col_pos(self, col_name): + pass diff --git a/chroma-server/chroma_server/db/clickhouse.py b/chroma/db/clickhouse.py similarity index 86% rename from chroma-server/chroma_server/db/clickhouse.py rename to chroma/db/clickhouse.py index 708aa7a..f96109c 100644 --- a/chroma-server/chroma_server/db/clickhouse.py +++ b/chroma/db/clickhouse.py @@ -1,4 +1,5 @@ -from chroma_server.db.abstract import Database +from chroma.db import DB +from chroma.db.index.hnswlib import Hnswlib import uuid import time import os @@ -37,14 +38,12 @@ def db_schema_to_keys(): return_str += f"{list(element.keys())[0]}, " return return_str -def get_col_pos(col_name): - for i, col in enumerate(EMBEDDING_TABLE_SCHEMA): - if col_name in col: - return i -class Clickhouse(Database): +class Clickhouse(DB): + _conn = None + def _create_table_embeddings(self): self._conn.execute(f'''CREATE TABLE IF NOT EXISTS embeddings ( {db_array_schema_to_clickhouse_schema(EMBEDDING_TABLE_SCHEMA)} @@ -52,17 +51,20 @@ class Clickhouse(Database): self._conn.execute(f'''SET allow_experimental_lightweight_delete = true''') self._conn.execute(f'''SET mutations_sync = 1''') # https://clickhouse.com/docs/en/operations/settings/settings/#mutations_sync - + + def _create_table_results(self): self._conn.execute(f'''CREATE TABLE IF NOT EXISTS results ( {db_array_schema_to_clickhouse_schema(RESULTS_TABLE_SCHEMA)} ) ENGINE = MergeTree() ORDER BY model_space''') - def __init__(self): - client = Client(host='clickhouse', port=os.getenv('CLICKHOUSE_TCP_PORT', '9000')) - self._conn = client + + def __init__(self, settings): + self._conn = Client(host=settings.clickhouse_host, port=clickhouse_port) self._create_table_embeddings() self._create_table_results() + self._idx = Hnswlib() + def add(self, model_space, embedding, input_uri, dataset=None, inference_class=None, label_class=None): data_to_insert = [] @@ -74,18 +76,11 @@ class Clickhouse(Database): self._conn.execute(f''' INSERT INTO embeddings ({insert_string}) VALUES''', data_to_insert) - def _count(self, model_space=None): - where_string = "" - if model_space is not None: - where_string = f"WHERE model_space = '{model_space}'" - return self._conn.execute(f"SELECT COUNT() FROM embeddings {where_string}") - - def count(self, model_space=None): - return self._count(model_space=model_space)[0][0] def _fetch(self, where={}, columnar=False): return self._conn.execute(f'''SELECT {db_schema_to_keys()} FROM embeddings {where}''', columnar=columnar) + def fetch(self, where={}, sort=None, limit=None, offset=None, columnar=False): if where["model_space"] is None: return {"error": "model_space is required"} @@ -95,12 +90,12 @@ class Clickhouse(Database): if where is not None: if not isinstance(where, dict): raise Exception("Invalid where: " + str(where)) - + # ensure where is a flat dict for key in where: if isinstance(where[key], dict): raise Exception("Invalid where: " + str(where)) - + where = " AND ".join([f"{key} = '{value}'" for key, value in where.items()]) if where: @@ -113,7 +108,7 @@ class Clickhouse(Database): if limit is not None or isinstance(limit, int): where += f" LIMIT {limit}" - + if offset is not None or isinstance(offset, int): where += f" OFFSET {offset}" @@ -125,7 +120,7 @@ class Clickhouse(Database): def _delete(self, where={}): uuids_deleted = self._conn.execute(f'''SELECT toString(uuid) FROM embeddings {where}''') self._conn.execute(f''' - DELETE FROM + DELETE FROM embeddings {where} ''') @@ -140,35 +135,56 @@ class Clickhouse(Database): if where is not None: if not isinstance(where, dict): raise Exception("Invalid where: " + str(where)) - + # ensure where is a flat dict for key in where: if isinstance(where[key], dict): raise Exception("Invalid where: " + str(where)) - + where = " AND ".join([f"{key} = '{value}'" for key, value in where.items()]) if where: where = f"WHERE {where}" - val = self._delete(where=where) + deleted_uuids = self._delete(where=where) print(f"time to fetch {len(val)} embeddings: ", time.time() - s3) - return val + if len(where) == 1: + self._idx.delete(where['model_space']) + + self._idx.delete_from_index(where['model_space'], [uuid[0] for uuid in deleted_uuids]) + + return deleted_uuids + def get_by_ids(self, ids=list): return self._conn.execute(f''' SELECT {db_schema_to_keys()} FROM embeddings WHERE uuid IN ({ids})''') + + def create_index(self, model_space): + fetch = self.fetch({"model_space": model_space}, columnar=True) + self._idx.run(model_space, fetch[1], fetch[2]) # more magic number, ugh + + + def has_index(self, model_space): + return self._idx.has_index(self, model_space) + + def reset(self): self._conn.execute('DROP TABLE embeddings') self._conn.execute('DROP TABLE results') self._create_table_embeddings() self._create_table_results() + self._idx.reset() + self._idx = Hnswlib() + + def raw_sql(self, sql): return self._conn.execute(sql) + def add_results(self, model_spaces, uuids, custom_quality_score): data_to_insert = [] for i in range(len(model_spaces)): @@ -176,16 +192,19 @@ class Clickhouse(Database): self._conn.execute(''' INSERT INTO results (model_space, uuid, custom_quality_score) VALUES''', data_to_insert) - + + def delete_results(self, model_space): self._conn.execute(f"DELETE FROM results WHERE model_space = '{model_space}'") + def count_results(self, model_space=None): where_string = "" if model_space is not None: where_string = f"WHERE model_space = '{model_space}'" return self._conn.execute(f"SELECT COUNT() FROM results {where_string}")[0][0] - + + def return_results(self, model_space, n_results = 100): return self._conn.execute(f''' SELECT @@ -205,8 +224,9 @@ class Clickhouse(Database): LIMIT {n_results} ''') - def set_save_folder(self, path): - pass - - def get_save_folder(self): - pass \ No newline at end of file + + def get_col_pos(col_name): + for i, col in enumerate(EMBEDDING_TABLE_SCHEMA): + if col_name in col: + return i + diff --git a/chroma/db/duckdb.py b/chroma/db/duckdb.py new file mode 100644 index 0000000..1ffbbf4 --- /dev/null +++ b/chroma/db/duckdb.py @@ -0,0 +1,12 @@ +from chroma.db import DB + +class DuckDB(DB): + + def __init__(self, settings): + pass + +class PersistentDuckDB(DuckDB): + + def __init__(self, settings): + pass + diff --git a/chroma/db/index/__init__.py b/chroma/db/index/__init__.py new file mode 100644 index 0000000..8c1caa7 --- /dev/null +++ b/chroma/db/index/__init__.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod + +class Index(ABC): + + @abstractmethod + def __init__(self, settings): + pass + + + @abstractmethod + def delete(self, model_space): + pass + + + @abstractmethod + def delete_from_index(self, model_space, uuids): + pass + + + @abstractmethod + def reset(self): + pass + + + @abstractmethod + def run(self, model_space, uuids, embeddings): + pass + + + @abstractmethod + def has_index(self, model_space): + pass diff --git a/chroma-server/chroma_server/index/hnswlib.py b/chroma/db/index/hnswlib.py similarity index 86% rename from chroma-server/chroma_server/index/hnswlib.py rename to chroma/db/index/hnswlib.py index de7f36f..ad1ca14 100644 --- a/chroma-server/chroma_server/index/hnswlib.py +++ b/chroma/db/index/hnswlib.py @@ -4,13 +4,11 @@ import time import hnswlib import numpy as np -from chroma_server.index.abstract import Index -from chroma_server.logger import logger +from chroma.db.index import Index class Hnswlib(Index): - _save_folder = '/index_data' _model_space = None _index = None _index_metadata = { @@ -22,28 +20,28 @@ class Hnswlib(Index): _id_to_uuid = {} _uuid_to_id = {} - def __init__(self): - pass - # set the save folder - def set_save_folder(self, save_folder): - self._save_folder = save_folder + def __init__(self, settings): + self._save_folder = settings.chroma_cache_dir + "/index" - def get_save_folder(self): - return self._save_folder - def run(self, model_space, uuids, embeddings, space='l2', ef=10, num_threads=4): + def run(self, model_space, uuids, embeddings): + + space = 'l2' + ef=10 + num_threads='4' + # more comments available at the source: https://github.com/nmslib/hnswlib dimensionality = len(embeddings[0]) - + for uuid, i in zip(uuids, range(len(uuids))): self._id_to_uuid[i] = str(uuid) self._uuid_to_id[str(uuid)] = i index = hnswlib.Index(space=space, dim=dimensionality) # possible options are l2, cosine or ip - index.init_index(max_elements=len(embeddings), ef_construction=100, M=16) - index.set_ef(ef) - index.set_num_threads(num_threads) + index.init_index(max_elements=len(embeddings), ef_construction=100, M=16) + index.set_ef(ef) + index.set_num_threads(num_threads) index.add_items(embeddings, range(len(uuids))) self._index = index @@ -53,7 +51,8 @@ class Hnswlib(Index): 'elements': len(embeddings) , 'time_created': time.time(), } - self.save() + self._save() + def delete(self, model_space): # delete files, dont throw error if they dont exist @@ -72,9 +71,10 @@ class Hnswlib(Index): self._id_to_uuid = {} self._uuid_to_id = {} + def delete_from_index(self, model_space, uuids): if self._model_space != model_space: - self.load(model_space) + self._load(model_space) if self._index is not None: for uuid in uuids: @@ -82,9 +82,11 @@ class Hnswlib(Index): del self._id_to_uuid[self._uuid_to_id[uuid]] del self._uuid_to_id[uuid] - self.save() - - def save(self): + self._save() + + + def _save(): + # create the directory if it doesn't exist if not os.path.exists(f'{self._save_folder}'): os.makedirs(f'{self._save_folder}') @@ -103,7 +105,9 @@ class Hnswlib(Index): logger.debug('Index saved to {self._save_folder}/index.bin') - def load(self, model_space): + + def _load(self, model_space): + # unpickle the mappers try: with open(f"{self._save_folder}/id_to_uuid_{model_space}.pkl", 'rb') as f: @@ -121,14 +125,15 @@ class Hnswlib(Index): except: logger.debug('Index not found') + def has_index(self, model_space): return os.path.isfile(f"{self._save_folder}/index_{model_space}.bin") - # do knn_query on hnswlib to get nearest neighbors + def get_nearest_neighbors(self, model_space, query, k, uuids=None): if self._model_space != model_space: - self.load(model_space) + self._load(model_space) s2= time.time() # get ids from uuids as a set, if they are available @@ -137,7 +142,7 @@ class Hnswlib(Index): ids = {self._uuid_to_id[uuid] for uuid in uuids} if len(ids) < k : k = len(ids) - + filter_function = None if len(ids) != 0: filter_function = lambda id: id in ids @@ -149,13 +154,16 @@ class Hnswlib(Index): logger.debug(f'time to run knn query: {time.time() - s3}') uuids = [self._id_to_uuid[id] for id in database_ids[0]] - + return uuids, distances + def reset(self): if os.path.exists(f'{self._save_folder}'): for f in os.listdir(f'{self._save_folder}'): os.remove(os.path.join(f'{self._save_folder}', f)) # recreate the directory if not os.path.exists(f'{self._save_folder}'): - os.makedirs(f'{self._save_folder}') \ No newline at end of file + os.makedirs(f'{self._save_folder}') + + diff --git a/chroma/server/__init__.py b/chroma/server/__init__.py new file mode 100644 index 0000000..24bf10e --- /dev/null +++ b/chroma/server/__init__.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + +from chroma.utils.error_reporting import init_error_reporting +from chroma.server.utils.telemetry.capture import Capture + +class Server(ABC): + + def __init__(self): + self._chroma_telemetry = Capture() + self._chroma_telemetry.capture('server-start') + init_error_reporting() diff --git a/chroma/server/arrowflight.py b/chroma/server/arrowflight.py new file mode 100644 index 0000000..8061aa5 --- /dev/null +++ b/chroma/server/arrowflight.py @@ -0,0 +1,9 @@ +import chroma.server + +class ArrowFlight(chroma.server.Server): + + def __init__(self): + super().__init__() + pass + + #TODO: Implement diff --git a/chroma/server/fastapi.py b/chroma/server/fastapi.py new file mode 100644 index 0000000..a0aba47 --- /dev/null +++ b/chroma/server/fastapi.py @@ -0,0 +1,83 @@ +import fastapi +from fastapi.responses import JSONResponse +import chroma +import chroma.server +from chroma.server.fastapi.types import (AddEmbedding, CountEmbedding, DeleteEmbedding, + FetchEmbedding, ProcessEmbedding, + QueryEmbedding, RawSql, Results, + SpaceKeyInput) + +class FastAPI(chroma.server.Server): + + def __init__(self): + super().__init__() + self._app = fastapi.FastAPI(debug=True) + self._api = chroma.get_api() + + self.router = fastpi.APIRouter() + self.router.add_api_route("/api/v1", self.root, methods=["GET"]) + self.router.add_api_route("/api/v1/add", self.add, methods=["POST"], status_code=status.HTTP_201_CREATED) + self.router.add_api_route("/api/v1/fetch", self.fetch, methods=["POST"]) + self.router.add_api_route("/api/v1/delete", self.delete, methods=["POST"]) + self.router.add_api_route("/api/v1/count", self.count, methods=["GET"]) + self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"]) + self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"]) + self.router.add_api_route("/api/v1/get_nearest_neighbors", self.get_nearest_neighbors, methods=["POST"]) + self.router.add_api_route("/api/v1/create_index", self.create_index, methods=["POST"]) + self.router.add_api_route("/api/v1/process", self.process, methods=["POST"]) + self.router.add_api_route("/api/v1/get_status", self.get_status, methods=["POST"]) + self.router.add_api_route("/api/v1/get_results", self.get_results, methods=["POST"]) + + self._app.include_router(router) + + + def root(self): + return {"nanosecond heartbeat": self._api.heartbeat()} + + + def add(self, new_embedding: AddEmbedding): + if self._api.add(**new_embedding): + return {"response": "Added records to database"} + else: + raise Exception("api.add returned false") + + + def fetch(self, embedding: FetchEmbedding): + return self._api.fetch(embedding) + + + def delete(self, embedding: DeleteEmbedding): + return self._api.delete(embedding) + + + def count(self, model_space: str = None): + return self._api.count(model_space) + + + def reset(self): + return self._api.reset() + + + def get_nearest_neighbors(self, embedding: QueryEmbedding): + return self._api.get_nearest_neighbors(**embedding) + + + def raw_sql(self, raw_sql: RawSql): + return self._api.raw_sql(raw_sql.raw_sql) + + + def create_index(self, process_embedding: ProcessEmbedding): + return self._api.create_index(process_embedding.model_space) + + + def process(self, process_embedding: ProcessEmbedding): + task_id = self._api.process(process_embedding.model_space) + return JSONResponse({"task_id": task_id}) + + + def get_status(self, task_id): + return JSONResponse(self._api.get_task_status(task_id)) + + + def get_results(self, results: Results): + return self._api.get_results(results.model_space, results.n_results) diff --git a/chroma-server/chroma_server/__init__.py b/chroma/server/fastapi/__init__.py similarity index 100% rename from chroma-server/chroma_server/__init__.py rename to chroma/server/fastapi/__init__.py diff --git a/chroma-server/chroma_server/types/__init__.py b/chroma/server/fastapi/types.py similarity index 97% rename from chroma-server/chroma_server/types/__init__.py rename to chroma/server/fastapi/types.py index f3c37a4..c770384 100644 --- a/chroma-server/chroma_server/types/__init__.py +++ b/chroma/server/fastapi/types.py @@ -38,4 +38,4 @@ class SpaceKeyInput(BaseModel): model_space: str class DeleteEmbedding(BaseModel): - where: dict = {} \ No newline at end of file + where: dict = {} diff --git a/chroma-server/chroma_server/utils/error_reporting.py b/chroma/server/utils/error_reporting.py similarity index 91% rename from chroma-server/chroma_server/utils/error_reporting.py rename to chroma/server/utils/error_reporting.py index d207c51..e67650d 100644 --- a/chroma-server/chroma_server/utils/error_reporting.py +++ b/chroma/server/utils/error_reporting.py @@ -1,4 +1,4 @@ -from chroma_server.utils.config.settings import get_settings +from chroma import get_settings import sentry_sdk from sentry_sdk.client import Client @@ -27,4 +27,4 @@ def init_error_reporting(): before_send=strip_sensitive_data, ) with configure_scope() as scope: - scope.set_tag('posthog_distinct_id', get_settings().telemetry_anonymized_uuid) \ No newline at end of file + scope.set_tag('posthog_distinct_id', get_settings().telemetry_anonymized_uuid) diff --git a/chroma-server/chroma_server/utils/telemetry/abstract.py b/chroma/server/utils/telemetry/__init__.py similarity index 63% rename from chroma-server/chroma_server/utils/telemetry/abstract.py rename to chroma/server/utils/telemetry/__init__.py index 921c399..ce95a1d 100644 --- a/chroma-server/chroma_server/utils/telemetry/abstract.py +++ b/chroma/server/utils/telemetry/__init__.py @@ -1,10 +1,11 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod + +class Telemetry(ABC): -class Telemetry(): @abstractmethod def __init__(self): pass @abstractmethod def capture(self, event, properties=None): - pass \ No newline at end of file + pass diff --git a/chroma-server/chroma_server/utils/telemetry/capture.py b/chroma/server/utils/telemetry/capture.py similarity index 90% rename from chroma-server/chroma_server/utils/telemetry/capture.py rename to chroma/server/utils/telemetry/capture.py index 56eb39f..8d1964d 100644 --- a/chroma-server/chroma_server/utils/telemetry/capture.py +++ b/chroma/server/utils/telemetry/capture.py @@ -2,8 +2,8 @@ import posthog import uuid import sys import os -from chroma_server.utils.telemetry.abstract import Telemetry -from chroma_server.utils.config.settings import get_settings +from chroma.server.utils.telemetry import Telemetry +from chroma import get_settings class Capture(Telemetry): _conn = None @@ -37,4 +37,3 @@ class Capture(Telemetry): properties['environment'] = os.getenv('environment', 'development') self._conn.capture(self._telemetry_anonymized_uuid, event, properties) - diff --git a/chroma/test/test_api.py b/chroma/test/test_api.py new file mode 100644 index 0000000..a231dbb --- /dev/null +++ b/chroma/test/test_api.py @@ -0,0 +1,4 @@ +import pytest + +def test_init(): + assert(1==1) diff --git a/chroma/test/test_chroma.py b/chroma/test/test_chroma.py new file mode 100644 index 0000000..f4d7173 --- /dev/null +++ b/chroma/test/test_chroma.py @@ -0,0 +1,56 @@ +import pytest +import unittest +from unittest.mock import patch + +import chroma +import chroma.config + +class GetDBTest(unittest.TestCase): + + @patch('chroma.db.duckdb.DuckDB', autospec=True) + def test_default_db(self, mock): + db = chroma.get_db(chroma.config.Settings()) + assert mock.called + + @patch('chroma.db.duckdb.PersistentDuckDB', autospec=True) + def test_persistent_duckdb(self, mock): + db = chroma.get_db(chroma.config.Settings(chroma_cache_dir="./foo")) + assert mock.called + + + @patch('chroma.db.clickhouse.Clickhouse', autospec=True) + def test_clickhouse(self, mock): + db = chroma.get_db(chroma.config.Settings(clickhouse_host="foo")) + assert mock.called + +class GetAPITest(unittest.TestCase): + + @patch('chroma.db.duckdb.DuckDB', autospec=True) + @patch('chroma.api.local.LocalAPI', autospec=True) + def test_local(self, mock_api, mock_db): + api = chroma.get_api(chroma.config.Settings()) + assert mock_api.called + assert mock_db.called + + @patch('chroma.db.duckdb.DuckDB', autospec=True) + @patch('chroma.api.celery.CeleryAPI', autospec=True) + def test_celery(self, mock_api, mock_db): + api = chroma.get_api(chroma.config.Settings(celery_broker_url='foo')) + assert mock_api.called + assert mock_db.called + + + @patch('chroma.api.fastapi.FastAPI', autospec=True) + def test_fastapi(self, mock): + api = chroma.get_api(chroma.config.Settings(chroma_server_host='foo', + chroma_server_http_port='80')) + assert mock.called + + + @patch('chroma.api.arrowflight.ArrowFlightAPI', autospec=True) + def test_arrowflight(self, mock): + api = chroma.get_api(chroma.config.Settings(chroma_server_host='foo', + chroma_server_http_port='80', + chroma_server_grpc_port='9999')) + assert mock.called + diff --git a/chroma-server/chroma_server/worker.py b/chroma/worker.py similarity index 72% rename from chroma-server/chroma_server/worker.py rename to chroma/worker.py index 6937774..f6f7009 100644 --- a/chroma-server/chroma_server/worker.py +++ b/chroma/worker.py @@ -1,8 +1,8 @@ import os import time import random +import chroma from celery import Celery -from chroma_server.db.clickhouse import Clickhouse, get_col_pos celery = Celery(__name__) celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") @@ -15,19 +15,20 @@ def create_task(task_type): @celery.task(name="heavy_offline_analysis") def heavy_offline_analysis(model_space): - task_db_conn = Clickhouse() - embedding_rows = task_db_conn.fetch({"model_space": model_space}) + db = chroma.get_db() + + embedding_rows = db.fetch({"model_space": model_space}) uuids = [] custom_quality_scores = [] - + for row in embedding_rows: uuids.append(row[get_col_pos("uuid")]) custom_quality_scores.append(random.random()) spaces = [model_space] * len(uuids) - task_db_conn.delete_results(model_space) - task_db_conn.add_results(spaces, uuids, custom_quality_scores) - + db.delete_results(model_space) + db.add_results(spaces, uuids, custom_quality_scores) + return "Wrote custom quality scores to database" diff --git a/chroma-server/chroma-in-notebook.ipynb b/examples/chroma-in-notebook.ipynb similarity index 100% rename from chroma-server/chroma-in-notebook.ipynb rename to examples/chroma-in-notebook.ipynb diff --git a/chroma-server/in-memory_demo.ipynb b/examples/in-memory_demo.ipynb similarity index 100% rename from chroma-server/in-memory_demo.ipynb rename to examples/in-memory_demo.ipynb diff --git a/examples/misc/play.py b/examples/misc/play.py index 9774089..34b860b 100644 --- a/examples/misc/play.py +++ b/examples/misc/play.py @@ -7,9 +7,7 @@ chroma.reset() # add for i in range(10): chroma.add( - embedding=[1,2,3,4,5,6,7,8,9,10], - input_uri="https://www.google.com", - dataset=None + embedding=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], input_uri="https://www.google.com", dataset=None ) # fetch all @@ -17,7 +15,7 @@ allres = chroma.get_all() print(allres) # count -print("count is", chroma.count()['count']) +print("count is", chroma.count()["count"]) # persist chroma.persist() @@ -33,4 +31,4 @@ chroma.process() # reset chroma.reset() -print("count is", chroma.count()['count']) \ No newline at end of file +print("count is", chroma.count()["count"]) diff --git a/examples/sample-app/app.py b/examples/sample-app/app.py index 91e7264..25f094b 100644 --- a/examples/sample-app/app.py +++ b/examples/sample-app/app.py @@ -4,6 +4,6 @@ import chroma_client app = Flask(__name__) -@app.route('/') +@app.route("/") def hello(): - return(str(chroma_client.fetch_new_labels())) \ No newline at end of file + return str(chroma_client.fetch_new_labels()) diff --git a/examples/sample-script/sample_script.py b/examples/sample-script/sample_script.py index 6f0da05..c2c0fa4 100644 --- a/examples/sample-script/sample_script.py +++ b/examples/sample-script/sample_script.py @@ -7,22 +7,23 @@ chroma.set_model_space("sample_1_1") print(chroma.heartbeat()) -chroma.add([[1,2,3,4,5]], ["/images/1"], ["training"], ['spoon']) -chroma.add([[1,2,3,4,5]], ["/images/2"], ["training"], ['spoon']) -chroma.add([[1,2,3,4,5]], ["/images/3"], ["training"], ['spoon']) -chroma.add([[1,2,3,4,5]], ["/images/1"], ["training"], ['knife']) -chroma.add([[1,2,3,4,5]], ["/images/4"], ["training"], ['knife']) -chroma.add([[1,2,3,4,5]], ["/prod/2"], ["test"], ['knife']) +chroma.add([[1, 2, 3, 4, 5]], ["/images/1"], ["training"], ["spoon"]) +chroma.add([[1, 2, 3, 4, 5]], ["/images/2"], ["training"], ["spoon"]) +chroma.add([[1, 2, 3, 4, 5]], ["/images/3"], ["training"], ["spoon"]) +chroma.add([[1, 2, 3, 4, 5]], ["/images/1"], ["training"], ["knife"]) +chroma.add([[1, 2, 3, 4, 5]], ["/images/4"], ["training"], ["knife"]) +chroma.add([[1, 2, 3, 4, 5]], ["/prod/2"], ["test"], ["knife"]) process_task = chroma.process() print(process_task) -print(chroma.get_task_status(process_task['task_id'])) +print(chroma.get_task_status(process_task["task_id"])) print("sleeping for 10s to wait for task to complete") import time + time.sleep(10) -print(chroma.get_task_status(process_task['task_id'])) +print(chroma.get_task_status(process_task["task_id"])) print(chroma.get_results()) # print(chroma.raw_sql("SELECT * FROM results WHERE space_key = 'yolov3_1_1'")) @@ -64,4 +65,4 @@ print(chroma.get_results()) # print(chroma.process()) # print(chroma.get_nearest_neighbors([1,2,3,4,5], 2)) -# print(chroma.get_nearest_neighbors([1,2,3,4,5], 2, space_key="yolov3_5_1")) \ No newline at end of file +# print(chroma.get_nearest_neighbors([1,2,3,4,5], 2, space_key="yolov3_5_1")) diff --git a/examples/yolov3/results.py b/examples/yolov3/results.py index 921d8a4..b2bb55d 100644 --- a/examples/yolov3/results.py +++ b/examples/yolov3/results.py @@ -2,4 +2,4 @@ from chroma_client import Chroma chroma = Chroma() -print(chroma.get_results('yolov3_1_1')) +print(chroma.get_results("yolov3_1_1")) diff --git a/examples/yolov3/yolov3.py b/examples/yolov3/yolov3.py index 7261c8e..ed19733 100644 --- a/examples/yolov3/yolov3.py +++ b/examples/yolov3/yolov3.py @@ -10,7 +10,7 @@ import pandas as pd if __name__ == "__main__": - file = 'data__nogit/yolov3_objects_large_5k.parquet' + file = "data__nogit/yolov3_objects_large_5k.parquet" print("Loading parquet file: ", file) py = pq.read_table(file) @@ -20,14 +20,14 @@ if __name__ == "__main__": data_length = len(df) chroma = Chroma(model_space="yolov3") - chroma.reset() #make sure we are using a fresh db + chroma.reset() # make sure we are using a fresh db allstart = time.time() start = time.time() dataset = "training" BATCH_SIZE = 1_000 - print("Loading in records with a batch size of: " , data_length) + print("Loading in records with a batch size of: ", data_length) for i in range(0, data_length, BATCH_SIZE): if i >= 20_000: @@ -35,64 +35,157 @@ if __name__ == "__main__": end = time.time() page = i * BATCH_SIZE - print("Time to process BATCH_SIZE rows: " + '{0:.2f}'.format((end - start)) + "s, records loaded: " + str(i)) + print( + "Time to process BATCH_SIZE rows: " + + "{0:.2f}".format((end - start)) + + "s, records loaded: " + + str(i) + ) start = time.time() - batch = df[i:i+BATCH_SIZE] + batch = df[i : i + BATCH_SIZE] for index, row in batch.iterrows(): - for idx, annotation in enumerate(row['infer']['annotations']): - annotation["bbox"] = annotation['bbox'].tolist() - row['infer']['annotations'] = row['infer']['annotations'].tolist() + for idx, annotation in enumerate(row["infer"]["annotations"]): + annotation["bbox"] = annotation["bbox"].tolist() + row["infer"]["annotations"] = row["infer"]["annotations"].tolist() - row['embedding_data'] = row['embedding_data'].tolist() + row["embedding_data"] = row["embedding_data"].tolist() - embedding = batch['embedding_data'].tolist() - input_uri = batch['resource_uri'].tolist() + embedding = batch["embedding_data"].tolist() + input_uri = batch["resource_uri"].tolist() inference_classes = [] for index, row in batch.iterrows(): - for idx, annotation in enumerate(row['infer']['annotations']): - inference_classes.append(annotation['category_name']) + for idx, annotation in enumerate(row["infer"]["annotations"]): + inference_classes.append(annotation["category_name"]) datasets = dataset chroma.add( - embedding=embedding, - input_uri=input_uri, + embedding=embedding, + input_uri=input_uri, dataset=dataset, - inference_class=inference_classes + inference_class=inference_classes, ) allend = time.time() - print("time to add all: ", "{:.2f}".format(allend - allstart) + 's') + print("time to add all: ", "{:.2f}".format(allend - allstart) + "s") fetched = chroma.count() - print("Records loaded into the database: ", fetched) + print("Records loaded into the database: ", fetched) start = time.time() chroma.create_index() end = time.time() - print("Time to process: " +'{0:.2f}'.format((end - start)) + 's') + print("Time to process: " + "{0:.2f}".format((end - start)) + "s") - knife_embedding = [0.2310010939836502, -0.3462161719799042, 0.29164767265319824, -0.09828940033912659, 1.814868450164795, -10.517369270324707, -13.531850814819336, -12.730537414550781, -13.011675834655762, -10.257010459899902, -13.779699325561523, -11.963963508605957, -13.948140144348145, -12.46799087524414, -14.569470405578613, -16.388280868530273, -13.76762580871582, -12.192169189453125, -12.204055786132812, -12.259000778198242, -13.696036338806152, -14.609177589416504, -16.951879501342773, -17.096384048461914, -14.355693817138672, -16.643482208251953, -14.270745277404785, -14.375198364257812, -14.381218910217285, -13.475995063781738, -12.694938659667969, -10.011992454528809, -9.770626068115234, -13.155019760131836, -16.136341094970703, -6.552414417266846, -11.243837356567383, -16.678457260131836, -14.629229545593262, -10.052337646484375, -15.451828956604004, -12.561151504516602, -11.68396282196045, -11.975972175598145, -11.09926986694336, -13.060500144958496, -12.075592994689941, -1.0808746814727783, 1.7046797275543213, -3.8080708980560303, -11.401922225952148, -12.184720039367676, -13.262567520141602, -11.299583435058594, -13.654638290405273, -10.767330169677734, -9.012763977050781, -10.202326774597168, -10.088111877441406, -13.247991561889648, -9.651527404785156, -11.903244972229004, -13.922954559326172, -17.37179946899414, -12.51513385772705, -7.8046746253967285, -14.406414985656738, -13.172696113586426, -11.194984436035156, -12.029500961303711, -10.996524810791016, -10.828441619873047, -8.673471450805664, -13.800869941711426, -9.680946350097656, -12.964024543762207, -9.694372177124023, -13.132003784179688, -9.38864803314209, -14.305071830749512, -14.4693603515625, -5.0566205978393555, -15.685358047485352, -12.493011474609375, -8.424881935119629] + knife_embedding = [ + 0.2310010939836502, + -0.3462161719799042, + 0.29164767265319824, + -0.09828940033912659, + 1.814868450164795, + -10.517369270324707, + -13.531850814819336, + -12.730537414550781, + -13.011675834655762, + -10.257010459899902, + -13.779699325561523, + -11.963963508605957, + -13.948140144348145, + -12.46799087524414, + -14.569470405578613, + -16.388280868530273, + -13.76762580871582, + -12.192169189453125, + -12.204055786132812, + -12.259000778198242, + -13.696036338806152, + -14.609177589416504, + -16.951879501342773, + -17.096384048461914, + -14.355693817138672, + -16.643482208251953, + -14.270745277404785, + -14.375198364257812, + -14.381218910217285, + -13.475995063781738, + -12.694938659667969, + -10.011992454528809, + -9.770626068115234, + -13.155019760131836, + -16.136341094970703, + -6.552414417266846, + -11.243837356567383, + -16.678457260131836, + -14.629229545593262, + -10.052337646484375, + -15.451828956604004, + -12.561151504516602, + -11.68396282196045, + -11.975972175598145, + -11.09926986694336, + -13.060500144958496, + -12.075592994689941, + -1.0808746814727783, + 1.7046797275543213, + -3.8080708980560303, + -11.401922225952148, + -12.184720039367676, + -13.262567520141602, + -11.299583435058594, + -13.654638290405273, + -10.767330169677734, + -9.012763977050781, + -10.202326774597168, + -10.088111877441406, + -13.247991561889648, + -9.651527404785156, + -11.903244972229004, + -13.922954559326172, + -17.37179946899414, + -12.51513385772705, + -7.8046746253967285, + -14.406414985656738, + -13.172696113586426, + -11.194984436035156, + -12.029500961303711, + -10.996524810791016, + -10.828441619873047, + -8.673471450805664, + -13.800869941711426, + -9.680946350097656, + -12.964024543762207, + -9.694372177124023, + -13.132003784179688, + -9.38864803314209, + -14.305071830749512, + -14.4693603515625, + -5.0566205978393555, + -15.685358047485352, + -12.493011474609375, + -8.424881935119629, + ] start = time.time() - get_nearest_neighbors = chroma.get_nearest_neighbors(knife_embedding, 4, where= {"inference_class": "knife","dataset": "training"}) + get_nearest_neighbors = chroma.get_nearest_neighbors( + knife_embedding, 4, where={"inference_class": "knife", "dataset": "training"} + ) print("get_nearest_neighbors: ", get_nearest_neighbors) - res_df = pd.DataFrame(get_nearest_neighbors['embeddings']) + res_df = pd.DataFrame(get_nearest_neighbors["embeddings"]) print(res_df.head()) - print("Distances to nearest neighbors: ", get_nearest_neighbors['distances']) - print("Internal ids of nearest neighbors: ", get_nearest_neighbors['ids']) + print("Distances to nearest neighbors: ", get_nearest_neighbors["distances"]) + print("Internal ids of nearest neighbors: ", get_nearest_neighbors["ids"]) end = time.time() - print("Time to get nearest neighbors: " +'{0:.2f}'.format((end - start)) + 's') + print("Time to get nearest neighbors: " + "{0:.2f}".format((end - start)) + "s") task = chroma.calculate_results() print(task) - print(chroma.get_task_status(task['task_id'])) + print(chroma.get_task_status(task["task_id"])) fetched = chroma.count() - print("Records loaded into the database: ", fetched) + print("Records loaded into the database: ", fetched) del chroma diff --git a/pyproject.toml b/pyproject.toml index ac1b975..9b793c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,40 @@ +[project] +name = "chroma" +dynamic = ["version"] + +authors = [ + { name="Jeff Huber", email="jeff@trychroma.com" }, + { name="Anton Troynikov", email="anton@trychroma.com" } +] +description = "Chroma." +readme = "README.md" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache-2.0", + "Operating System :: OS Independent", +] +dependencies = [ + 'pyarrow ~= 9.0', + 'requests ~= 2.28', +] + [tool.black] line-length = 100 +required-version = "22.10.0" # Black will refuse to run if it's not this version. +target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] -# Black will refuse to run if it's not this version. -required-version = "22.6.0" +[tool.pytest.ini_options] +pythonpath = ["."] + +[project.urls] +"Homepage" = "https://github.com/chroma-core/chroma" +"Bug Tracker" = "https://github.com/chroma-core/chroma/issues" + +[build-system] +requires = ["setuptools>=61.0", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +local_scheme="no-local-version" -# Ensure black's output will be compatible with all listed versions. -target-version = ['py36', 'py37', 'py38', 'py39', 'py310'] \ No newline at end of file diff --git a/chroma-server/requirements.txt b/requirements.txt similarity index 88% rename from chroma-server/requirements.txt rename to requirements.txt index dd55aa4..60b9d1f 100644 --- a/chroma-server/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ fastapi==0.85.1 uvicorn[standard]==0.18.3 pyarrow==9.0.0 -numpy -pandas +requests==2.28.1 +numpy==1.23.5 +pandas==1.5.1 duckdb==0.5.1 hnswlib @ git+https://oauth2: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@github.com/chroma-core/hnswlib.git - redis==3.5.3 celery==4.4.7 clickhouse_driver==0.2.4 diff --git a/chroma-server/requirements_dev.txt b/requirements_dev.txt similarity index 70% rename from chroma-server/requirements_dev.txt rename to requirements_dev.txt index be3caac..7c96455 100644 --- a/chroma-server/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,5 @@ -httpx +build pytest setuptools_scm +httpx +black