[ENH] Pulsar Producer & Consumer (#921)

## Description of changes

- New functionality
  - Adds a basic Pulsar producer, consumer, and associated tests, as well as a docker compose file for the distributed version of Chroma.

## Test plan
We added `bin/cluster-test.sh`, which starts Pulsar and allows `test_producer_consumer` to run against the Pulsar fixture.

## Documentation Changes
None required.
Author: Hammad Bashir
Committed: 2023-09-20 02:03:07 -07:00 (via GitHub)
Parent: 020950470c
Commit: 896822231e
17 changed files with 1105 additions and 43 deletions


@@ -0,0 +1,31 @@
name: Chroma Cluster Tests
on:
push:
branches:
- main
pull_request:
branches:
- main
- '**'
workflow_dispatch:
jobs:
test:
strategy:
matrix:
python: ['3.7']
platform: [ubuntu-latest]
testfile: ["chromadb/test/ingest/test_producer_consumer.py"] # Just this one test for now
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Install test dependencies
run: python -m pip install -r requirements.txt && python -m pip install -r requirements_dev.txt
- name: Integration Test
run: bin/cluster-test.sh ${{ matrix.testfile }}


@@ -32,4 +32,4 @@ repos:
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract]
additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy"]
additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf"]

bin/cluster-test.sh (new executable file, 16 lines)

@@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -e
function cleanup {
docker compose -f docker-compose.cluster.yml down --rmi local --volumes
}
trap cleanup EXIT
docker compose -f docker-compose.cluster.yml up -d --wait pulsar
export CHROMA_CLUSTER_TEST_ONLY=1
echo testing: python -m pytest "$@"
python -m pytest "$@"


@@ -1,6 +1,7 @@
from chromadb.api import API
from chromadb.config import Settings, System
from chromadb.db.system import SysDB
from chromadb.ingest.impl.utils import create_topic_name
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry import Telemetry
from chromadb.ingest import Producer
@@ -130,6 +131,9 @@ class SegmentAPI(API):
coll = t.Collection(
id=id, name=name, metadata=metadata, topic=self._topic(id), dimension=None
)
# TODO: Topic creation right now lives in the producer but it should be moved to the coordinator,
# and the producer should just be responsible for publishing messages. Coordinator should
# be responsible for all management of topics.
self._producer.create_topic(coll["topic"])
segments = self._manager.create_segments(coll)
self._sysdb.create_collection(coll)
@@ -559,7 +563,7 @@ class SegmentAPI(API):
return self._producer.max_batch_size
def _topic(self, collection_id: UUID) -> str:
return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"
return create_topic_name(self._tenant_id, self._topic_ns, str(collection_id))
# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.


@@ -92,6 +92,10 @@ class Settings(BaseSettings): # type: ignore
chroma_server_grpc_port: Optional[str] = None
chroma_server_cors_allow_origins: List[str] = [] # eg ["http://localhost:3000"]
pulsar_broker_url: Optional[str] = None
pulsar_admin_port: Optional[str] = None
pulsar_broker_port: Optional[str] = None
chroma_server_auth_provider: Optional[str] = None
@validator("chroma_server_auth_provider", pre=True, always=True, allow_reuse=True)
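
For context, a minimal sketch of how these new settings select the Pulsar implementations, mirroring the test fixture added later in this PR; the localhost host and 8080/6650 ports match the docker compose defaults and are assumptions about your local setup:

```python
from chromadb.config import Settings, System
from chromadb.ingest import Consumer, Producer

# Assumes a standalone Pulsar broker reachable on localhost with default ports.
system = System(
    Settings(
        allow_reset=True,
        chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
        chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
        pulsar_broker_url="localhost",
        pulsar_admin_port="8080",
        pulsar_broker_port="6650",
    )
)
producer = system.require(Producer)
consumer = system.require(Consumer)
system.start()
```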


@@ -0,0 +1,304 @@
from __future__ import annotations
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
import uuid
from chromadb.config import Settings, System
from chromadb.ingest import Consumer, ConsumerCallbackFn, Producer
from overrides import overrides, EnforceOverrides
from uuid import UUID
from chromadb.ingest.impl.pulsar_admin import PulsarAdmin
from chromadb.ingest.impl.utils import create_pulsar_connection_str
from chromadb.proto.convert import from_proto_submit, to_proto_submit
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import SeqId, SubmitEmbeddingRecord
import pulsar
from concurrent.futures import wait, Future
from chromadb.utils.messageid import int_to_pulsar, pulsar_to_int
class PulsarProducer(Producer, EnforceOverrides):
_connection_str: str
_topic_to_producer: Dict[str, pulsar.Producer]
_client: pulsar.Client
_admin: PulsarAdmin
_settings: Settings
def __init__(self, system: System) -> None:
pulsar_host = system.settings.require("pulsar_broker_url")
pulsar_port = system.settings.require("pulsar_broker_port")
self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
self._topic_to_producer = {}
self._settings = system.settings
self._admin = PulsarAdmin(system)
super().__init__(system)
@overrides
def start(self) -> None:
self._client = pulsar.Client(self._connection_str)
super().start()
@overrides
def stop(self) -> None:
self._client.close()
super().stop()
@overrides
def create_topic(self, topic_name: str) -> None:
self._admin.create_topic(topic_name)
@overrides
def delete_topic(self, topic_name: str) -> None:
self._admin.delete_topic(topic_name)
@overrides
def submit_embedding(
self, topic_name: str, embedding: SubmitEmbeddingRecord
) -> SeqId:
"""Add an embedding record to the given topic. Returns the SeqID of the record."""
producer = self._get_or_create_producer(topic_name)
proto_submit: proto.SubmitEmbeddingRecord = to_proto_submit(embedding)
# TODO: batch performance / async
msg_id: pulsar.MessageId = producer.send(proto_submit.SerializeToString())
return pulsar_to_int(msg_id)
@overrides
def submit_embeddings(
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
if not self._running:
raise RuntimeError("Component not running")
if len(embeddings) == 0:
return []
if len(embeddings) > self.max_batch_size:
raise ValueError(
f"""
Cannot submit more than {self.max_batch_size:,} embeddings at once.
Please submit your embeddings in batches of size
{self.max_batch_size:,} or less.
"""
)
producer = self._get_or_create_producer(topic_name)
protos_to_submit = [to_proto_submit(embedding) for embedding in embeddings]
def create_producer_callback(
future: Future[int],
) -> Callable[[Any, pulsar.MessageId], None]:
def producer_callback(res: Any, msg_id: pulsar.MessageId) -> None:
if msg_id:
future.set_result(pulsar_to_int(msg_id))
else:
future.set_exception(
Exception(
"Unknown error while submitting embedding in producer_callback"
)
)
return producer_callback
futures = []
for proto_to_submit in protos_to_submit:
future: Future[int] = Future()
producer.send_async(
proto_to_submit.SerializeToString(),
callback=create_producer_callback(future),
)
futures.append(future)
wait(futures)
results: List[SeqId] = []
for future in futures:
exception = future.exception()
if exception is not None:
raise exception
results.append(future.result())
return results
@property
@overrides
def max_batch_size(self) -> int:
# For now, we use 1,000
# TODO: tune this to a reasonable value by default
return 1000
def _get_or_create_producer(self, topic_name: str) -> pulsar.Producer:
if topic_name not in self._topic_to_producer:
producer = self._client.create_producer(topic_name)
self._topic_to_producer[topic_name] = producer
return self._topic_to_producer[topic_name]
@overrides
def reset_state(self) -> None:
if not self._settings.require("allow_reset"):
raise ValueError(
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
)
for topic_name in self._topic_to_producer:
self._admin.delete_topic(topic_name)
self._topic_to_producer = {}
super().reset_state()
class PulsarConsumer(Consumer, EnforceOverrides):
class PulsarSubscription:
id: UUID
topic_name: str
start: int
end: int
callback: ConsumerCallbackFn
consumer: pulsar.Consumer
def __init__(
self,
id: UUID,
topic_name: str,
start: int,
end: int,
callback: ConsumerCallbackFn,
consumer: pulsar.Consumer,
):
self.id = id
self.topic_name = topic_name
self.start = start
self.end = end
self.callback = callback
self.consumer = consumer
_connection_str: str
_client: pulsar.Client
_subscriptions: Dict[str, Set[PulsarSubscription]]
_settings: Settings
def __init__(self, system: System) -> None:
pulsar_host = system.settings.require("pulsar_broker_url")
pulsar_port = system.settings.require("pulsar_broker_port")
self._connection_str = create_pulsar_connection_str(pulsar_host, pulsar_port)
self._subscriptions = defaultdict(set)
self._settings = system.settings
super().__init__(system)
@overrides
def start(self) -> None:
self._client = pulsar.Client(self._connection_str)
super().start()
@overrides
def stop(self) -> None:
self._client.close()
super().stop()
@overrides
def subscribe(
self,
topic_name: str,
consume_fn: ConsumerCallbackFn,
start: Optional[SeqId] = None,
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
"""Register a function that will be called to recieve embeddings for a given
topic. The given function may be called any number of times, with any number of
records, and may be called concurrently.
Only records between start (exclusive) and end (inclusive) SeqIDs will be
returned. If start is None, the first record returned will be the next record
generated, not including those generated before creating the subscription. If
end is None, the consumer will consume indefinitely, otherwise it will
automatically be unsubscribed when the end SeqID is reached.
If the function throws an exception, the function may be called again with the
same or different records.
Takes an optional UUID as a unique subscription ID. If no ID is provided, a new
ID will be generated and returned."""
if not self._running:
raise RuntimeError("Consumer must be started before subscribing")
subscription_id = (
id or uuid.uuid4()
) # TODO: this should really be created by the coordinator and stored in sysdb
start, end = self._validate_range(start, end)
def wrap_callback(consumer: pulsar.Consumer, message: pulsar.Message) -> None:
msg_data = message.data()
msg_id = pulsar_to_int(message.message_id())
submit_embedding_record = proto.SubmitEmbeddingRecord()
proto.SubmitEmbeddingRecord.ParseFromString(
submit_embedding_record, msg_data
)
embedding_record = from_proto_submit(submit_embedding_record, msg_id)
consume_fn([embedding_record])
consumer.acknowledge(message)
if msg_id == end:
self.unsubscribe(subscription_id)
consumer = self._client.subscribe(
topic_name,
subscription_id.hex,
message_listener=wrap_callback,
)
subscription = self.PulsarSubscription(
subscription_id, topic_name, start, end, consume_fn, consumer
)
self._subscriptions[topic_name].add(subscription)
# NOTE: For some reason the seek() method expects a shadowed MessageId type
# which resides in _msg_id.
consumer.seek(int_to_pulsar(start)._msg_id)
return subscription_id
def _validate_range(
self, start: Optional[SeqId], end: Optional[SeqId]
) -> Tuple[int, int]:
"""Validate and normalize the start and end SeqIDs for a subscription using this
impl."""
start = start or pulsar_to_int(pulsar.MessageId.latest)
end = end or self.max_seqid()
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("SeqIDs must be integers")
if start >= end:
raise ValueError(f"Invalid SeqID range: {start} to {end}")
return start, end
@overrides
def unsubscribe(self, subscription_id: UUID) -> None:
"""Unregister a subscription. The consume function will no longer be invoked,
and resources associated with the subscription will be released."""
for topic_name, subscriptions in self._subscriptions.items():
for subscription in subscriptions:
if subscription.id == subscription_id:
subscription.consumer.close()
subscriptions.remove(subscription)
if len(subscriptions) == 0:
del self._subscriptions[topic_name]
return
@overrides
def min_seqid(self) -> SeqId:
"""Return the minimum possible SeqID in this implementation."""
return pulsar_to_int(pulsar.MessageId.earliest)
@overrides
def max_seqid(self) -> SeqId:
"""Return the maximum possible SeqID in this implementation."""
return 2**192 - 1
@overrides
def reset_state(self) -> None:
if not self._settings.require("allow_reset"):
raise ValueError(
"Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
)
for topic_name, subscriptions in self._subscriptions.items():
for subscription in subscriptions:
subscription.consumer.close()
self._subscriptions = defaultdict(set)
super().reset_state()
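
To show how the pieces fit together, here is a hedged usage sketch of the new producer and consumer. It assumes the `system`, `producer`, and `consumer` from the settings sketch above, a running standalone Pulsar, and illustrative topic and record values:

```python
from typing import Sequence

from chromadb.ingest.impl.utils import create_topic_name
from chromadb.types import (
    EmbeddingRecord,
    Operation,
    ScalarEncoding,
    SubmitEmbeddingRecord,
)

# Topics are fully qualified Pulsar names; only default/default is supported for now.
topic = create_topic_name("default", "default", "my_collection")
producer.create_topic(topic)

def on_records(records: Sequence[EmbeddingRecord]) -> None:
    # Called on the Pulsar listener thread, possibly with several records at once.
    for record in records:
        print(record["id"], record["seq_id"])

subscription_id = consumer.subscribe(topic, on_records, start=consumer.min_seqid())

seq_id = producer.submit_embedding(
    topic,
    SubmitEmbeddingRecord(
        id="embedding-1",
        embedding=[0.1, 0.2, 0.3],
        encoding=ScalarEncoding.FLOAT32,
        metadata={"source": "docs"},
        operation=Operation.ADD,
    ),
)
# producer.submit_embeddings(topic, embeddings=[...]) batches up to max_batch_size.

consumer.unsubscribe(subscription_id)
```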


@@ -0,0 +1,81 @@
# A thin wrapper around the pulsar admin api
import requests
from chromadb.config import System
from chromadb.ingest.impl.utils import parse_topic_name
class PulsarAdmin:
"""A thin wrapper around the pulsar admin api, only used for interim development towards distributed chroma.
This functionality will be moved to the chroma coordinator."""
_connection_str: str
def __init__(self, system: System):
pulsar_host = system.settings.require("pulsar_broker_url")
pulsar_port = system.settings.require("pulsar_admin_port")
self._connection_str = f"http://{pulsar_host}:{pulsar_port}"
# Create the default tenant and namespace
# This is a temporary workaround until we have a proper tenant/namespace management system
self.create_tenant("default")
self.create_namespace("default", "default")
def create_tenant(self, tenant: str) -> None:
"""Make a PUT request to the admin api to create the tenant"""
path = f"/admin/v2/tenants/{tenant}"
url = self._connection_str + path
response = requests.put(
url, json={"allowedClusters": ["standalone"], "adminRoles": []}
) # TODO: how to manage clusters?
if response.status_code != 204 and response.status_code != 409:
raise RuntimeError(f"Failed to create tenant {tenant}")
def create_namespace(self, tenant: str, namespace: str) -> None:
"""Make a PUT request to the admin api to create the namespace"""
path = f"/admin/v2/namespaces/{tenant}/{namespace}"
url = self._connection_str + path
response = requests.put(url)
if response.status_code != 204 and response.status_code != 409:
raise RuntimeError(f"Failed to create namespace {namespace}")
def create_topic(self, topic: str) -> None:
# TODO: support non-persistent topics?
tenant, namespace, topic_name = parse_topic_name(topic)
if tenant != "default":
raise ValueError(f"Only the default tenant is supported, got {tenant}")
if namespace != "default":
raise ValueError(
f"Only the default namespace is supported, got {namespace}"
)
# Make a PUT request to the admin api to create the topic
path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
url = self._connection_str + path
response = requests.put(url)
if response.status_code != 204 and response.status_code != 409:
raise RuntimeError(f"Failed to create topic {topic_name}")
def delete_topic(self, topic: str) -> None:
tenant, namespace, topic_name = parse_topic_name(topic)
if tenant != "default":
raise ValueError(f"Only the default tenant is supported, got {tenant}")
if namespace != "default":
raise ValueError(
f"Only the default namespace is supported, got {namespace}"
)
# Make a DELETE request to the admin api to delete the topic
path = f"/admin/v2/persistent/{tenant}/{namespace}/{topic_name}"
# Force delete the topic
path += "?force=true"
url = self._connection_str + path
response = requests.delete(url)
if response.status_code != 204 and response.status_code != 409:
raise RuntimeError(f"Failed to delete topic {topic_name}")


@@ -0,0 +1,20 @@
import re
from typing import Tuple
topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"
def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:
"""Parse the topic name into the tenant, namespace and topic name"""
match = re.match(topic_regex, topic_name)
if not match:
raise ValueError(f"Invalid topic name: {topic_name}")
return match.group("tenant"), match.group("namespace"), match.group("topic")
def create_pulsar_connection_str(host: str, port: str) -> str:
return f"pulsar://{host}:{port}"
def create_topic_name(tenant: str, namespace: str, topic: str) -> str:
return f"persistent://{tenant}/{namespace}/{topic}"


@@ -0,0 +1,40 @@
syntax = "proto3";
package chroma;
enum Operation {
ADD = 0;
UPDATE = 1;
UPSERT = 2;
DELETE = 3;
}
enum ScalarEncoding {
FLOAT32 = 0;
INT32 = 1;
}
message Vector {
int32 dimension = 1;
bytes vector = 2;
ScalarEncoding encoding = 3;
}
message UpdateMetadataValue {
oneof value {
string string_value = 1;
int64 int_value = 2;
double float_value = 3;
}
}
message UpdateMetadata {
map<string, UpdateMetadataValue> metadata = 1;
}
message SubmitEmbeddingRecord {
string id = 1;
optional Vector vector = 2;
optional UpdateMetadata metadata = 3;
Operation operation = 4;
}
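
To make the wire format concrete, a sketch of building one of these records by hand from the generated Python bindings; chromadb.proto.convert (added below) is the real conversion path, and the field values here are illustrative:

```python
import array

import chromadb.proto.chroma_pb2 as proto

vector = proto.Vector(
    dimension=3,
    vector=array.array("f", [0.1, 0.2, 0.3]).tobytes(),
    encoding=proto.ScalarEncoding.FLOAT32,
)
record = proto.SubmitEmbeddingRecord(
    id="embedding-1",
    vector=vector,
    metadata=proto.UpdateMetadata(
        metadata={"source": proto.UpdateMetadataValue(string_value="docs")}
    ),
    operation=proto.Operation.ADD,
)
payload = record.SerializeToString()  # the bytes the producer publishes to Pulsar
```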


@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: chromadb/proto/chroma.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1b\x63hromadb/proto/chroma.proto\x12\x06\x63hroma"U\n\x06Vector\x12\x11\n\tdimension\x18\x01 \x01(\x05\x12\x0e\n\x06vector\x18\x02 \x01(\x0c\x12(\n\x08\x65ncoding\x18\x03 \x01(\x0e\x32\x16.chroma.ScalarEncoding"b\n\x13UpdateMetadataValue\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x42\x07\n\x05value"\x96\x01\n\x0eUpdateMetadata\x12\x36\n\x08metadata\x18\x01 \x03(\x0b\x32$.chroma.UpdateMetadata.MetadataEntry\x1aL\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0b\x32\x1b.chroma.UpdateMetadataValue:\x02\x38\x01"\xb5\x01\n\x15SubmitEmbeddingRecord\x12\n\n\x02id\x18\x01 \x01(\t\x12#\n\x06vector\x18\x02 \x01(\x0b\x32\x0e.chroma.VectorH\x00\x88\x01\x01\x12-\n\x08metadata\x18\x03 \x01(\x0b\x32\x16.chroma.UpdateMetadataH\x01\x88\x01\x01\x12$\n\toperation\x18\x04 \x01(\x0e\x32\x11.chroma.OperationB\t\n\x07_vectorB\x0b\n\t_metadata*8\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\n\n\x06UPDATE\x10\x01\x12\n\n\x06UPSERT\x10\x02\x12\n\n\x06\x44\x45LETE\x10\x03*(\n\x0eScalarEncoding\x12\x0b\n\x07\x46LOAT32\x10\x00\x12\t\n\x05INT32\x10\x01\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "chromadb.proto.chroma_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_UPDATEMETADATA_METADATAENTRY._options = None
_UPDATEMETADATA_METADATAENTRY._serialized_options = b"8\001"
_globals["_OPERATION"]._serialized_start = 563
_globals["_OPERATION"]._serialized_end = 619
_globals["_SCALARENCODING"]._serialized_start = 621
_globals["_SCALARENCODING"]._serialized_end = 661
_globals["_VECTOR"]._serialized_start = 39
_globals["_VECTOR"]._serialized_end = 124
_globals["_UPDATEMETADATAVALUE"]._serialized_start = 126
_globals["_UPDATEMETADATAVALUE"]._serialized_end = 224
_globals["_UPDATEMETADATA"]._serialized_start = 227
_globals["_UPDATEMETADATA"]._serialized_end = 377
_globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_start = 301
_globals["_UPDATEMETADATA_METADATAENTRY"]._serialized_end = 377
_globals["_SUBMITEMBEDDINGRECORD"]._serialized_start = 380
_globals["_SUBMITEMBEDDINGRECORD"]._serialized_end = 561
# @@protoc_insertion_point(module_scope)


@@ -0,0 +1,247 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class _Operation:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _OperationEnumTypeWrapper(
google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Operation.ValueType],
builtins.type,
):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
ADD: _Operation.ValueType # 0
UPDATE: _Operation.ValueType # 1
UPSERT: _Operation.ValueType # 2
DELETE: _Operation.ValueType # 3
class Operation(_Operation, metaclass=_OperationEnumTypeWrapper): ...
ADD: Operation.ValueType # 0
UPDATE: Operation.ValueType # 1
UPSERT: Operation.ValueType # 2
DELETE: Operation.ValueType # 3
global___Operation = Operation
class _ScalarEncoding:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ScalarEncodingEnumTypeWrapper(
google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
_ScalarEncoding.ValueType
],
builtins.type,
):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
FLOAT32: _ScalarEncoding.ValueType # 0
INT32: _ScalarEncoding.ValueType # 1
class ScalarEncoding(_ScalarEncoding, metaclass=_ScalarEncodingEnumTypeWrapper): ...
FLOAT32: ScalarEncoding.ValueType # 0
INT32: ScalarEncoding.ValueType # 1
global___ScalarEncoding = ScalarEncoding
@typing_extensions.final
class Vector(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DIMENSION_FIELD_NUMBER: builtins.int
VECTOR_FIELD_NUMBER: builtins.int
ENCODING_FIELD_NUMBER: builtins.int
dimension: builtins.int
vector: builtins.bytes
encoding: global___ScalarEncoding.ValueType
def __init__(
self,
*,
dimension: builtins.int = ...,
vector: builtins.bytes = ...,
encoding: global___ScalarEncoding.ValueType = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"dimension", b"dimension", "encoding", b"encoding", "vector", b"vector"
],
) -> None: ...
global___Vector = Vector
@typing_extensions.final
class UpdateMetadataValue(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STRING_VALUE_FIELD_NUMBER: builtins.int
INT_VALUE_FIELD_NUMBER: builtins.int
FLOAT_VALUE_FIELD_NUMBER: builtins.int
string_value: builtins.str
int_value: builtins.int
float_value: builtins.float
def __init__(
self,
*,
string_value: builtins.str = ...,
int_value: builtins.int = ...,
float_value: builtins.float = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"float_value",
b"float_value",
"int_value",
b"int_value",
"string_value",
b"string_value",
"value",
b"value",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"float_value",
b"float_value",
"int_value",
b"int_value",
"string_value",
b"string_value",
"value",
b"value",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["value", b"value"]
) -> (
typing_extensions.Literal["string_value", "int_value", "float_value"] | None
): ...
global___UpdateMetadataValue = UpdateMetadataValue
@typing_extensions.final
class UpdateMetadata(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing_extensions.final
class MetadataEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
@property
def value(self) -> global___UpdateMetadataValue: ...
def __init__(
self,
*,
key: builtins.str = ...,
value: global___UpdateMetadataValue | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["value", b"value"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal["key", b"key", "value", b"value"],
) -> None: ...
METADATA_FIELD_NUMBER: builtins.int
@property
def metadata(
self,
) -> google.protobuf.internal.containers.MessageMap[
builtins.str, global___UpdateMetadataValue
]: ...
def __init__(
self,
*,
metadata: collections.abc.Mapping[builtins.str, global___UpdateMetadataValue]
| None = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["metadata", b"metadata"]
) -> None: ...
global___UpdateMetadata = UpdateMetadata
@typing_extensions.final
class SubmitEmbeddingRecord(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
VECTOR_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
OPERATION_FIELD_NUMBER: builtins.int
id: builtins.str
@property
def vector(self) -> global___Vector: ...
@property
def metadata(self) -> global___UpdateMetadata: ...
operation: global___Operation.ValueType
def __init__(
self,
*,
id: builtins.str = ...,
vector: global___Vector | None = ...,
metadata: global___UpdateMetadata | None = ...,
operation: global___Operation.ValueType = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_metadata",
b"_metadata",
"_vector",
b"_vector",
"metadata",
b"metadata",
"vector",
b"vector",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_metadata",
b"_metadata",
"_vector",
b"_vector",
"id",
b"id",
"metadata",
b"metadata",
"operation",
b"operation",
"vector",
b"vector",
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
) -> typing_extensions.Literal["metadata"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_vector", b"_vector"]
) -> typing_extensions.Literal["vector"] | None: ...
global___SubmitEmbeddingRecord = SubmitEmbeddingRecord

chromadb/proto/convert.py (new file, 150 lines)

@@ -0,0 +1,150 @@
import array
from typing import Optional, Tuple, Union
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
EmbeddingRecord,
Metadata,
Operation,
ScalarEncoding,
SeqId,
SubmitEmbeddingRecord,
Vector,
)
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
if encoding == ScalarEncoding.FLOAT32:
as_bytes = array.array("f", vector).tobytes()
proto_encoding = proto.ScalarEncoding.FLOAT32
elif encoding == ScalarEncoding.INT32:
as_bytes = array.array("i", vector).tobytes()
proto_encoding = proto.ScalarEncoding.INT32
else:
raise ValueError(
f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \
or {ScalarEncoding.INT32}"
)
return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding)
def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
encoding = vector.encoding
as_array: array.array[float] | array.array[int]
if encoding == proto.ScalarEncoding.FLOAT32:
as_array = array.array("f")
out_encoding = ScalarEncoding.FLOAT32
elif encoding == proto.ScalarEncoding.INT32:
as_array = array.array("i")
out_encoding = ScalarEncoding.INT32
else:
raise ValueError(
f"Unknown encoding {encoding}, expected one of \
{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
)
as_array.frombytes(vector.vector)
return (as_array.tolist(), out_encoding)
def from_proto_operation(operation: proto.Operation.ValueType) -> Operation:
if operation == proto.Operation.ADD:
return Operation.ADD
elif operation == proto.Operation.UPDATE:
return Operation.UPDATE
elif operation == proto.Operation.UPSERT:
return Operation.UPSERT
elif operation == proto.Operation.DELETE:
return Operation.DELETE
else:
raise RuntimeError(f"Unknown operation {operation}") # TODO: full error
def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
if not metadata.metadata:
return None
out_metadata = {}
for key, value in metadata.metadata.items():
if value.HasField("string_value"):
out_metadata[key] = value.string_value
elif value.HasField("int_value"):
out_metadata[key] = value.int_value
elif value.HasField("float_value"):
out_metadata[key] = value.float_value
else:
raise RuntimeError(f"Unknown metadata value type {value}")
return out_metadata
def from_proto_submit(
submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId
) -> EmbeddingRecord:
embedding, encoding = from_proto_vector(submit_embedding_record.vector)
record = EmbeddingRecord(
id=submit_embedding_record.id,
seq_id=seq_id,
embedding=embedding,
encoding=encoding,
metadata=from_proto_metadata(submit_embedding_record.metadata),
operation=from_proto_operation(submit_embedding_record.operation),
)
return record
def to_proto_metadata_update_value(
value: Union[str, int, float, None]
) -> proto.UpdateMetadataValue:
if isinstance(value, str):
return proto.UpdateMetadataValue(string_value=value)
elif isinstance(value, int):
return proto.UpdateMetadataValue(int_value=value)
elif isinstance(value, float):
return proto.UpdateMetadataValue(float_value=value)
elif value is None:
return proto.UpdateMetadataValue()
else:
raise ValueError(
f"Unknown metadata value type {type(value)}, expected one of str, int, \
float, or None"
)
def to_proto_operation(operation: Operation) -> proto.Operation.ValueType:
if operation == Operation.ADD:
return proto.Operation.ADD
elif operation == Operation.UPDATE:
return proto.Operation.UPDATE
elif operation == Operation.UPSERT:
return proto.Operation.UPSERT
elif operation == Operation.DELETE:
return proto.Operation.DELETE
else:
raise ValueError(
f"Unknown operation {operation}, expected one of {Operation.ADD}, \
{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
)
def to_proto_submit(
submit_record: SubmitEmbeddingRecord,
) -> proto.SubmitEmbeddingRecord:
vector = None
if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])
metadata = None
if submit_record["metadata"] is not None:
metadata = {
k: to_proto_metadata_update_value(v)
for k, v in submit_record["metadata"].items()
}
return proto.SubmitEmbeddingRecord(
id=submit_record["id"],
vector=vector,
metadata=proto.UpdateMetadata(metadata=metadata)
if metadata is not None
else None,
operation=to_proto_operation(submit_record["operation"]),
)
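
A small sanity-check sketch of the vector conversion round trip; values are packed with array("f"), so they come back at float32 precision:

```python
from chromadb.proto.convert import from_proto_vector, to_proto_vector
from chromadb.types import ScalarEncoding

proto_vec = to_proto_vector([1.0, 2.0, 3.0], ScalarEncoding.FLOAT32)
values, encoding = from_proto_vector(proto_vec)
assert encoding == ScalarEncoding.FLOAT32
assert values == [1.0, 2.0, 3.0]  # exact here because these values are float32-representable
```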


@@ -469,9 +469,9 @@ class SqliteMetadataSegment(MetadataReader):
def _encode_seq_id(seq_id: SeqId) -> bytes:
"""Encode a SeqID into a byte array"""
if seq_id.bit_length() < 64:
if seq_id.bit_length() <= 64:
return int.to_bytes(seq_id, 8, "big")
elif seq_id.bit_length() < 192:
elif seq_id.bit_length() <= 192:
return int.to_bytes(seq_id, 24, "big")
else:
raise ValueError(f"Unsupported SeqID: {seq_id}")


@@ -207,7 +207,6 @@ class PersistentLocalHnswSegment(LocalHnswSegment):
"""Add a batch of embeddings to the index"""
if not self._running:
raise RuntimeError("Cannot add embeddings to stopped component")
with WriteRWLock(self._lock):
for record in records:
if record["embedding"] is not None:


@@ -1,3 +1,4 @@
import asyncio
import os
import shutil
import tempfile
@@ -16,6 +17,7 @@ from typing import (
)
from chromadb.ingest import Producer, Consumer
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.ingest.impl.utils import create_topic_name
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
@@ -51,8 +53,33 @@ def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
shutil.rmtree(save_path)
def pulsar() -> Generator[Tuple[Producer, Consumer], None, None]:
"""Fixture generator for pulsar Producer + Consumer. This fixture requires a running
pulsar cluster. You can use bin/cluster-test.sh to start a standalone pulsar and run this test
"""
system = System(
Settings(
allow_reset=True,
chroma_producer_impl="chromadb.ingest.impl.pulsar.PulsarProducer",
chroma_consumer_impl="chromadb.ingest.impl.pulsar.PulsarConsumer",
pulsar_broker_url="localhost",
pulsar_admin_port="8080",
pulsar_broker_port="6650",
)
)
producer = system.require(Producer)
consumer = system.require(Consumer)
system.start()
yield producer, consumer
system.stop()
def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
return [sqlite, sqlite_persistent]
fixtures = [sqlite, sqlite_persistent]
if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
fixtures = [pulsar]
return fixtures
@pytest.fixture(scope="module", params=fixtures())
@@ -89,14 +116,20 @@ class CapturingConsumeFn:
waiters: List[Tuple[int, Event]]
def __init__(self) -> None:
"""A function that captures embeddings and allows you to wait for a certain
number of embeddings to be available. It must be constructed in the thread with
the main event loop
"""
self.embeddings = []
self.waiters = []
self._loop = asyncio.get_event_loop()
def __call__(self, embeddings: Sequence[EmbeddingRecord]) -> None:
self.embeddings.extend(embeddings)
for n, event in self.waiters:
if len(self.embeddings) >= n:
event.set()
# event.set() is not thread safe, so we need to call it in the main event loop
self._loop.call_soon_threadsafe(event.set)
async def get(self, n: int, timeout_secs: int = 10) -> Sequence[EmbeddingRecord]:
"Wait until at least N embeddings are available, then return all embeddings"
@@ -132,6 +165,10 @@ def assert_records_match(
assert_approx_equal(inserted["embedding"], consumed["embedding"])
def full_topic_name(topic_name: str) -> str:
return create_topic_name("default", "default", topic_name)
@pytest.mark.asyncio
async def test_backfill(
producer_consumer: Tuple[Producer, Consumer],
@@ -140,12 +177,14 @@ async def test_backfill(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
consumer.reset_state()
producer.create_topic("test_topic")
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
topic_name = full_topic_name("test_topic")
producer.create_topic(topic_name)
embeddings = produce_fns(producer, topic_name, sample_embeddings, 3)[0]
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
recieved = await consume_fn.get(3)
assert_records_match(embeddings, recieved)
@@ -158,18 +197,21 @@ async def test_notifications(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
producer.create_topic("test_topic")
consumer.reset_state()
topic_name = full_topic_name("test_topic")
producer.create_topic(topic_name)
embeddings: List[SubmitEmbeddingRecord] = []
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
for i in range(10):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
producer.submit_embedding(topic_name, e)
received = await consume_fn.get(i + 1)
assert_records_match(embeddings, received)
@@ -181,8 +223,11 @@ async def test_multiple_topics(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
producer.create_topic("test_topic_1")
producer.create_topic("test_topic_2")
consumer.reset_state()
topic_name_1 = full_topic_name("test_topic_1")
topic_name_2 = full_topic_name("test_topic_2")
producer.create_topic(topic_name_1)
producer.create_topic(topic_name_2)
embeddings_1: List[SubmitEmbeddingRecord] = []
embeddings_2: List[SubmitEmbeddingRecord] = []
@@ -190,19 +235,19 @@ async def test_multiple_topics(
consume_fn_1 = CapturingConsumeFn()
consume_fn_2 = CapturingConsumeFn()
consumer.subscribe("test_topic_1", consume_fn_1, start=consumer.min_seqid())
consumer.subscribe("test_topic_2", consume_fn_2, start=consumer.min_seqid())
consumer.subscribe(topic_name_1, consume_fn_1, start=consumer.min_seqid())
consumer.subscribe(topic_name_2, consume_fn_2, start=consumer.min_seqid())
for i in range(10):
e_1 = next(sample_embeddings)
embeddings_1.append(e_1)
producer.submit_embedding("test_topic_1", e_1)
producer.submit_embedding(topic_name_1, e_1)
results_2 = await consume_fn_1.get(i + 1)
assert_records_match(embeddings_1, results_2)
e_2 = next(sample_embeddings)
embeddings_2.append(e_2)
producer.submit_embedding("test_topic_2", e_2)
producer.submit_embedding(topic_name_2, e_2)
results_2 = await consume_fn_2.get(i + 1)
assert_records_match(embeddings_2, results_2)
@@ -215,21 +260,23 @@ async def test_start_seq_id(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
producer.create_topic("test_topic")
consumer.reset_state()
topic_name = full_topic_name("test_topic")
producer.create_topic(topic_name)
consume_fn_1 = CapturingConsumeFn()
consume_fn_2 = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]
results_1 = await consume_fn_1.get(5)
assert_records_match(embeddings, results_1)
start = consume_fn_1.embeddings[-1]["seq_id"]
consumer.subscribe("test_topic", consume_fn_2, start=start)
second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
consumer.subscribe(topic_name, consume_fn_2, start=start)
second_embeddings = produce_fns(producer, topic_name, sample_embeddings, 5)[0]
assert isinstance(embeddings, list)
embeddings.extend(second_embeddings)
results_2 = await consume_fn_2.get(5)
@@ -244,20 +291,22 @@ async def test_end_seq_id(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
producer.create_topic("test_topic")
consumer.reset_state()
topic_name = full_topic_name("test_topic")
producer.create_topic(topic_name)
consume_fn_1 = CapturingConsumeFn()
consume_fn_2 = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
consumer.subscribe(topic_name, consume_fn_1, start=consumer.min_seqid())
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
embeddings = produce_fns(producer, topic_name, sample_embeddings, 10)[0]
results_1 = await consume_fn_1.get(10)
assert_records_match(embeddings, results_1)
end = consume_fn_1.embeddings[-5]["seq_id"]
consumer.subscribe("test_topic", consume_fn_2, start=consumer.min_seqid(), end=end)
consumer.subscribe(topic_name, consume_fn_2, start=consumer.min_seqid(), end=end)
results_2 = await consume_fn_2.get(6)
assert_records_match(embeddings[:6], results_2)
@@ -274,14 +323,16 @@ async def test_submit_batch(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
consumer.reset_state()
topic_name = full_topic_name("test_topic")
embeddings = [next(sample_embeddings) for _ in range(100)]
producer.create_topic("test_topic")
producer.submit_embeddings("test_topic", embeddings=embeddings)
producer.create_topic(topic_name)
producer.submit_embeddings(topic_name, embeddings=embeddings)
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
recieved = await consume_fn.get(100)
assert_records_match(embeddings, recieved)
@@ -295,13 +346,16 @@ async def test_multiple_topics_batch(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
consumer.reset_state()
N_TOPICS = 100
N_TOPICS = 2
consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
for i in range(N_TOPICS):
producer.create_topic(f"test_topic_{i}")
producer.create_topic(full_topic_name(f"test_topic_{i}"))
consumer.subscribe(
f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
full_topic_name(f"test_topic_{i}"),
consume_fns[i],
start=consumer.min_seqid(),
)
embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]
@@ -310,17 +364,17 @@ async def test_multiple_topics_batch(
N_TO_PRODUCE = 100
total_produced = 0
for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
for i in range(N_TOPICS):
embeddings_n[i].extend(
for n in range(N_TOPICS):
embeddings_n[n].extend(
produce_fns(
producer,
f"test_topic_{i}",
full_topic_name(f"test_topic_{n}"),
sample_embeddings,
PRODUCE_BATCH_SIZE,
)[0]
)
recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
assert_records_match(embeddings_n[i], recieved)
recieved = await consume_fns[n].get(total_produced + PRODUCE_BATCH_SIZE)
assert_records_match(embeddings_n[n], recieved)
total_produced += PRODUCE_BATCH_SIZE
@@ -331,19 +385,21 @@ async def test_max_batch_size(
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
max_batch_size = producer_consumer[0].max_batch_size
consumer.reset_state()
topic_name = full_topic_name("test_topic")
max_batch_size = producer.max_batch_size
assert max_batch_size > 0
# Make sure that we can produce a batch of size max_batch_size
embeddings = [next(sample_embeddings) for _ in range(max_batch_size)]
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
producer.submit_embeddings("test_topic", embeddings=embeddings)
consumer.subscribe(topic_name, consume_fn, start=consumer.min_seqid())
producer.submit_embeddings(topic_name, embeddings=embeddings)
received = await consume_fn.get(max_batch_size, timeout_secs=120)
assert_records_match(embeddings, received)
embeddings = [next(sample_embeddings) for _ in range(max_batch_size + 1)]
# Make sure that we can't produce a batch of size > max_batch_size
with pytest.raises(ValueError) as e:
producer.submit_embeddings("test_topic", embeddings=embeddings)
producer.submit_embeddings(topic_name, embeddings=embeddings)
assert "Cannot submit more than" in str(e.value)


@@ -0,0 +1,66 @@
# This docker compose file is not meant to be used. It is a work in progress
# for the distributed version of Chroma. It is not yet functional.
version: '3.9'
networks:
net:
driver: bridge
services:
server:
image: server
build:
context: .
dockerfile: Dockerfile
volumes:
- ./:/chroma
- index_data:/index_data
command: uvicorn chromadb.app:app --reload --workers 1 --host 0.0.0.0 --port 8000 --log-config log_config.yml
environment:
- IS_PERSISTENT=TRUE
- CHROMA_PRODUCER_IMPL=chromadb.ingest.impl.pulsar.PulsarProducer
- CHROMA_CONSUMER_IMPL=chromadb.ingest.impl.pulsar.PulsarConsumer
- PULSAR_BROKER_URL=pulsar
- PULSAR_BROKER_PORT=6650
- PULSAR_ADMIN_PORT=8080
ports:
- 8000:8000
depends_on:
pulsar:
condition: service_healthy
networks:
- net
pulsar:
image: apachepulsar/pulsar
volumes:
- pulsardata:/pulsar/data
- pulsarconf:/pulsar/conf
command: bin/pulsar standalone
ports:
- 6650:6650
- 8080:8080
networks:
- net
healthcheck:
test:
[
"CMD",
"curl",
"-f",
"localhost:8080/admin/v2/brokers/health"
]
interval: 3s
timeout: 1m
retries: 10
volumes:
index_data:
driver: local
backups:
driver: local
pulsardata:
driver: local
pulsarconf:
driver: local


@@ -3,8 +3,10 @@ build
httpx
hypothesis
hypothesis[numpy]
mypy-protobuf
pre-commit
pytest
pytest-asyncio
setuptools_scm
types-protobuf
types-requests==2.30.0.0