[ENH]: CIP-5: Large Batch Handling Improvements Proposal (#1077)

- Includes only the CIP for review.

Refs: #1049

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - New proposal to handle large batches of embeddings gracefully

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
TBD

---------

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>
Co-authored-by: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com>
Committed by Trayan Azarov on 2023-09-18 23:00:57 +03:00 (via GitHub)
parent 2b434b8266
commit 82b9c830f7
9 changed files with 302 additions and 44 deletions

View File: chromadb/api/__init__.py

@@ -378,3 +378,10 @@ class API(Component, ABC):
"""
pass
@property
@abstractmethod
def max_batch_size(self) -> int:
"""Return the maximum number of records that can be submitted in a single call
to submit_embeddings."""
pass

View File: chromadb/api/fastapi.py

@@ -1,6 +1,6 @@
import json
import logging
-from typing import Optional, cast
+from typing import Optional, cast, Tuple
from typing import Sequence
from uuid import UUID
@@ -23,6 +23,7 @@ from chromadb.api.types import (
GetResult,
QueryResult,
CollectionMetadata,
validate_batch,
)
from chromadb.auth import (
ClientAuthProvider,
@@ -38,6 +39,7 @@ logger = logging.getLogger(__name__)
class FastAPI(API):
_settings: Settings
_max_batch_size: int = -1
@staticmethod
def _validate_host(host: str) -> None:
@@ -296,6 +298,29 @@ class FastAPI(API):
raise_chroma_error(resp)
return cast(IDs, resp.json())
def _submit_batch(
self,
batch: Tuple[
IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]
],
url: str,
) -> requests.Response:
"""
Submits a batch of embeddings to the database
"""
resp = self._session.post(
self._api_url + url,
data=json.dumps(
{
"ids": batch[0],
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
}
),
)
return resp
@override
def _add(
self,
@@ -309,18 +334,9 @@ class FastAPI(API):
Adds a batch of embeddings to the database
- pass in column oriented data lists
"""
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/add",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
-        )
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
raise_chroma_error(resp)
return True
@@ -337,18 +353,11 @@ class FastAPI(API):
Updates a batch of embeddings in the database
- pass in column oriented data lists
"""
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/update",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(
+            batch, "/collections/" + str(collection_id) + "/update"
        )
resp.raise_for_status()
return True
@@ -365,18 +374,11 @@ class FastAPI(API):
Upserts a batch of embeddings in the database
- pass in column oriented data lists
"""
-        resp = self._session.post(
-            self._api_url + "/collections/" + str(collection_id) + "/upsert",
-            data=json.dumps(
-                {
-                    "ids": ids,
-                    "embeddings": embeddings,
-                    "metadatas": metadatas,
-                    "documents": documents,
-                }
-            ),
+        batch = (ids, embeddings, metadatas, documents)
+        validate_batch(batch, {"max_batch_size": self.max_batch_size})
+        resp = self._submit_batch(
+            batch, "/collections/" + str(collection_id) + "/upsert"
        )
resp.raise_for_status()
return True
@@ -434,6 +436,15 @@ class FastAPI(API):
"""Returns the settings of the client"""
return self._settings
@property
@override
def max_batch_size(self) -> int:
if self._max_batch_size == -1:
resp = self._session.get(self._api_url + "/pre-flight-checks")
raise_chroma_error(resp)
self._max_batch_size = cast(int, resp.json()["max_batch_size"])
return self._max_batch_size
def raise_chroma_error(resp: requests.Response) -> None:
"""Raises an error if the response is not ok, using a ChromaError if possible"""

View File: chromadb/api/segment.py

@@ -26,6 +26,7 @@ from chromadb.api.types import (
validate_update_metadata,
validate_where,
validate_where_document,
validate_batch,
)
from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent
@@ -38,6 +39,7 @@ import time
import logging
import re
logger = logging.getLogger(__name__)
@@ -241,9 +243,18 @@ class SegmentAPI(API):
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
        records_to_submit = []
-        for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.ADD,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -262,9 +273,18 @@ class SegmentAPI(API):
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
        records_to_submit = []
-        for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.UPDATE,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -282,9 +302,18 @@ class SegmentAPI(API):
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
+        validate_batch(
+            (ids, embeddings, metadatas, documents),
+            {"max_batch_size": self.max_batch_size},
+        )
        records_to_submit = []
-        for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
+        for r in _records(
+            t.Operation.UPSERT,
+            ids=ids,
+            embeddings=embeddings,
+            metadatas=metadatas,
+            documents=documents,
+        ):
self._validate_embedding_record(coll, r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
@@ -524,6 +553,11 @@ class SegmentAPI(API):
def get_settings(self) -> Settings:
return self._settings
@property
@override
def max_batch_size(self) -> int:
return self._producer.max_batch_size
def _topic(self, collection_id: UUID) -> str:
return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"

View File: chromadb/api/types.py

@@ -1,4 +1,4 @@
-from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any
+from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple
from typing_extensions import Literal, TypedDict, Protocol
import chromadb.errors as errors
from chromadb.types import (
@@ -367,3 +367,13 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings:
f"Expected each value in the embedding to be a int or float, got {embeddings}"
)
return embeddings
def validate_batch(
batch: Tuple[IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]],
limits: Dict[str, Any],
) -> None:
if len(batch[0]) > limits["max_batch_size"]:
raise ValueError(
f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}"
)

View File: chromadb/server/fastapi/__init__.py

@@ -126,6 +126,9 @@ class FastAPI(chromadb.server.Server):
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v1/version", self.version, methods=["GET"])
self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"])
self.router.add_api_route(
"/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
)
self.router.add_api_route(
"/api/v1/collections",
@@ -312,3 +315,8 @@ class FastAPI(chromadb.server.Server):
include=query.include,
)
return nnresult
def pre_flight_checks(self) -> Dict[str, Any]:
return {
"max_batch_size": self._api.max_batch_size,
}

View File: chromadb/test/property/test_add.py

@@ -1,11 +1,15 @@
-from typing import cast
+import random
+import uuid
+from random import randint
+from typing import cast, List, Any, Dict
import pytest
import hypothesis.strategies as st
from hypothesis import given, settings
from chromadb.api import API
-from chromadb.api.types import Embeddings
+from chromadb.api.types import Embeddings, Metadatas
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
+from chromadb.utils.batch_utils import create_batches
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
@@ -44,6 +48,79 @@ def test_add(
)
def create_large_recordset(
min_size: int = 45000,
max_size: int = 50000,
) -> strategies.RecordSet:
size = randint(min_size, max_size)
ids = [str(uuid.uuid4()) for _ in range(size)]
metadatas = [{"some_key": f"{i}"} for i in range(size)]
documents = [f"Document {i}" for i in range(size)]
embeddings = [[1, 2, 3] for _ in range(size)]
record_set: Dict[str, List[Any]] = {
"ids": ids,
"embeddings": cast(Embeddings, embeddings),
"metadatas": metadatas,
"documents": documents,
}
return record_set
@given(collection=collection_st)
@settings(deadline=None, max_examples=1)
def test_add_large(api: API, collection: strategies.Collection) -> None:
api.reset()
record_set = create_large_recordset(
min_size=api.max_batch_size,
max_size=api.max_batch_size + int(api.max_batch_size * random.random()),
)
coll = api.create_collection(
name=collection.name,
metadata=collection.metadata,
embedding_function=collection.embedding_function,
)
normalized_record_set = invariants.wrap_all(record_set)
if not invariants.is_metadata_valid(normalized_record_set):
with pytest.raises(Exception):
coll.add(**normalized_record_set)
return
for batch in create_batches(
api=api,
ids=cast(List[str], record_set["ids"]),
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=cast(Metadatas, record_set["metadatas"]),
documents=cast(List[str], record_set["documents"]),
):
coll.add(*batch)
invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))
@given(collection=collection_st)
@settings(deadline=None, max_examples=1)
def test_add_large_exceeding(api: API, collection: strategies.Collection) -> None:
api.reset()
record_set = create_large_recordset(
min_size=api.max_batch_size,
max_size=api.max_batch_size + int(api.max_batch_size * random.random()),
)
coll = api.create_collection(
name=collection.name,
metadata=collection.metadata,
embedding_function=collection.embedding_function,
)
normalized_record_set = invariants.wrap_all(record_set)
if not invariants.is_metadata_valid(normalized_record_set):
with pytest.raises(Exception):
coll.add(**normalized_record_set)
return
with pytest.raises(Exception) as e:
coll.add(**record_set)
assert "exceeds maximum batch size" in str(e.value)
# TODO: This test fails right now because the ids are not sorted by the input order
@pytest.mark.xfail(
reason="This is expected to fail right now. We should change the API to sort the \

View File: chromadb/test/test_api.py

@@ -1,6 +1,8 @@
# type: ignore
import requests
import chromadb
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import QueryResult
from chromadb.config import Settings
import chromadb.server.fastapi
@@ -164,6 +166,22 @@ def test_heartbeat(api):
assert heartbeat > datetime.now() - timedelta(seconds=10)
def test_max_batch_size(api):
print(api)
batch_size = api.max_batch_size
assert batch_size > 0
def test_pre_flight_checks(api):
if not isinstance(api, FastAPI):
pytest.skip("Not a FastAPI instance")
resp = requests.get(f"{api._api_url}/pre-flight-checks")
assert resp.status_code == 200
assert resp.json() is not None
assert "max_batch_size" in resp.json().keys()
batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com/1", "https://example.com/2"],

View File: chromadb/utils/batch_utils.py

@@ -0,0 +1,34 @@
from typing import Optional, Tuple, List
from chromadb.api import API
from chromadb.api.types import (
Documents,
Embeddings,
IDs,
Metadatas,
)
def create_batches(
api: API,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
) -> List[Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]]:
_batches: List[
Tuple[IDs, Embeddings, Optional[Metadatas], Optional[Documents]]
] = []
if len(ids) > api.max_batch_size:
        # split into consecutive chunks of at most max_batch_size records
for i in range(0, len(ids), api.max_batch_size):
_batches.append(
( # type: ignore
ids[i : i + api.max_batch_size],
embeddings[i : i + api.max_batch_size] if embeddings else None,
metadatas[i : i + api.max_batch_size] if metadatas else None,
documents[i : i + api.max_batch_size] if documents else None,
)
)
else:
_batches.append((ids, embeddings, metadatas, documents)) # type: ignore
return _batches

View File: docs/CIP_5_Large_Batch_Handling_Improvements.md (new file)

@@ -0,0 +1,59 @@
# CIP-5: Large Batch Handling Improvements Proposal
## Status
Current Status: `Under Discussion`
## **Motivation**
As users put Chroma through its paces and store ever-increasing datasets, we must ensure that errors
related to large and potentially expensive batches are handled gracefully. This CIP proposes adding a new
`max_batch_size` property to the `API` interface, implemented by the local segment API, and using it to split large batches into smaller ones.
## **Public Interfaces**
The following interfaces are impacted:
- New Server API endpoint - `/pre-flight-checks` (see the example probe after this list)
- New `max_batch_size` property on the `API` interface
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI`
- Updated `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI`
- New utility library `batch_utils.py`
- New exception raised when batch size exceeds `max_batch_size`
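
To make the endpoint concrete, here is a minimal sketch that probes it directly with `requests`. The base URL is a placeholder for a local deployment (an assumption, not part of this commit); the response shape follows the server handler added here:

```python
import requests

# Placeholder base URL; point this at your running Chroma server.
resp = requests.get("http://localhost:8000/api/v1/pre-flight-checks")
resp.raise_for_status()
checks = resp.json()  # e.g. {"max_batch_size": 10000}; the value depends on the backing Producer
print(checks["max_batch_size"])
```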
## **Proposed Changes**
We propose the following changes:
- A new `max_batch_size` property is added to the `API` interface. It will be implemented by both
  `chromadb.api.segment.SegmentAPI` and `chromadb.api.fastapi.FastAPI`.
- `chromadb.api.segment.SegmentAPI` will implement the `max_batch_size` property by fetching the value from the
  underlying `Producer` class.
- `chromadb.api.fastapi.FastAPI` will implement `max_batch_size` by fetching the value from a new `/pre-flight-checks`
  endpoint on the Server.
- The new `/pre-flight-checks` endpoint on the Server will return a dictionary of pre-flight checks the client must
  fulfill to integrate with the server side. For now, we propose using this only for `max_batch_size`, but more
  checks can be added in the future. Pre-flight checks will be fetched only once per client and cached for the
  duration of the client's lifetime.
- The `_add`, `_update` and `_upsert` methods on `chromadb.api.segment.SegmentAPI` are updated to validate batch size.
- The `_add`, `_update` and `_upsert` methods on `chromadb.api.fastapi.FastAPI` are updated to validate batch size
  (client-side validation).
- A new utility library, `batch_utils.py`, will contain the logic for splitting batches into smaller ones (a usage
  sketch follows this list).
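
As a usage sketch, an oversized record set can be added batch by batch, mirroring the new property test. The names `client` (an `API` instance) and `collection` (an existing collection) are illustrative assumptions:

```python
import uuid

from chromadb.utils.batch_utils import create_batches

# Build a record set deliberately larger than a single permitted batch.
size = client.max_batch_size * 2 + 7
ids = [str(uuid.uuid4()) for _ in range(size)]
embeddings = [[0.1, 0.2, 0.3] for _ in range(size)]

# create_batches splits the records into chunks of at most max_batch_size,
# so every add() call passes the new batch-size validation.
for batch in create_batches(api=client, ids=ids, embeddings=embeddings):
    collection.add(*batch)
```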
## **Compatibility, Deprecation, and Migration Plan**
The change will be fully backward compatible with existing implementations and transparent to the user.
## **Test Plan**
New tests (a sketch of the oversize-batch check follows this list):
- Batch splitting tests for `chromadb.api.segment.SegmentAPI`
- Batch splitting tests for `chromadb.api.fastapi.FastAPI`
- Tests for `/pre-flight-checks` endpoint
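
A minimal sketch of the oversize-batch check (the `api` fixture and collection name are assumptions; the real test, `test_add_large_exceeding`, is in the diff above):

```python
import pytest

def test_batch_exceeds_max_size(api) -> None:
    api.reset()
    coll = api.create_collection(name="oversize")
    size = api.max_batch_size + 1
    ids = [str(i) for i in range(size)]
    embeddings = [[1.0, 2.0, 3.0]] * size
    # One call with more than max_batch_size records should be rejected.
    with pytest.raises(Exception) as e:
        coll.add(ids=ids, embeddings=embeddings)
    assert "exceeds maximum batch size" in str(e.value)
```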
## **Rejected Alternatives**
N/A