From 6739259ef8e35ece4166fff10e2c3010d50647c4 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Thu, 10 Aug 2023 16:11:11 -0700 Subject: [PATCH] [PERF] Batch SQLite Embeddings Queue (#959) ## Description of changes Adds a batch submit_embeddings() to the producer API. Implements this behavior in the SQLite-based embeddings queue. Refactors tests to add a fixture for both single and batch submit, and replaces all direct calls to submit_embedding() in tests with the fixture so both methods are uniformly tested. This PR is stacked on #958; please read that one before reading this. *Summarize the changes made by this PR.* - Improvements & Bug fixes - This PR makes the SQLite-based embeddings queue batch its writes to SQLite, issuing one multi-row insert per DML call (sized by the number of records in that call) instead of one insert per record. ### Quick Benchmarks N = 100k D = 128 metadata = {simple key: simple value} Document = random 100-character string **Before - Overall time: 102s** Screenshot 2023-08-09 at 5 43 12 PM **After - Overall time: 36s** Screenshot 2023-08-09 at 7 37 41 PM **_Overall, compared to main, this PR and #958 together decrease the time of this benchmark from 469s to 36s._** Todo: - [x] Clean up code - [x] Rethink assumptions about submit_embeddings ordering guarantees - [x] Add tests - [x] Add batching to other DML - upsert(), update(), delete() ## Test plan Existing tests cover the bulk of the functionality. New tests TODO: - [x] Producer/Consumer tests for submit_embeddings - [x] Segment tests with batch embeddings ## Documentation Changes None required. 
But a misc feature idea is hyperparameter sweep for the optimal batch_size for your platform --- chromadb/api/segment.py | 16 ++- chromadb/db/mixins/embeddings_queue.py | 132 ++++++++++++------ chromadb/ingest/__init__.py | 10 ++ chromadb/test/conftest.py | 70 +++++++++- .../test/ingest/test_producer_consumer.py | 86 +++++++++--- chromadb/test/segment/test_metadata.py | 80 ++++++----- chromadb/test/segment/test_vector.py | 125 +++++++++-------- 7 files changed, 353 insertions(+), 166 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e7f911b..e14a506 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -242,9 +242,11 @@ class SegmentAPI(API): coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) + records_to_submit = [] for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) - self._producer.submit_embedding(coll["topic"], r) + records_to_submit.append(r) + self._producer.submit_embeddings(coll["topic"], records_to_submit) self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) return True @@ -261,9 +263,11 @@ class SegmentAPI(API): coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) + records_to_submit = [] for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) - self._producer.submit_embedding(coll["topic"], r) + records_to_submit.append(r) + self._producer.submit_embeddings(coll["topic"], records_to_submit) return True @@ -279,9 +283,11 @@ class SegmentAPI(API): coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) + records_to_submit = [] for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents): self._validate_embedding_record(coll, r) - 
self._producer.submit_embedding(coll["topic"], r) + records_to_submit.append(r) + self._producer.submit_embeddings(coll["topic"], records_to_submit) return True @@ -376,9 +382,11 @@ class SegmentAPI(API): else: ids_to_delete = ids + records_to_submit = [] for r in _records(t.Operation.DELETE, ids_to_delete): self._validate_embedding_record(coll, r) - self._producer.submit_embedding(coll["topic"], r) + records_to_submit.append(r) + self._producer.submit_embeddings(coll["topic"], records_to_submit) self._telemetry_client.capture( CollectionDeleteEvent(str(collection_id), len(ids_to_delete)) diff --git a/chromadb/db/mixins/embeddings_queue.py b/chromadb/db/mixins/embeddings_queue.py index c5f3d18..225de6b 100644 --- a/chromadb/db/mixins/embeddings_queue.py +++ b/chromadb/db/mixins/embeddings_queue.py @@ -16,7 +16,7 @@ from chromadb.types import ( from chromadb.config import System from overrides import override from collections import defaultdict -from typing import Tuple, Optional, Dict, Set, cast +from typing import Sequence, Tuple, Optional, Dict, Set, cast from uuid import UUID from pypika import Table, functions import uuid @@ -103,22 +103,32 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): if not self._running: raise RuntimeError("Component not running") - if embedding["embedding"]: - encoding_type = cast(ScalarEncoding, embedding["encoding"]) - encoding = encoding_type.value - embedding_bytes = encode_vector(embedding["embedding"], encoding_type) + return self.submit_embeddings(topic_name, [embedding])[0] - else: - embedding_bytes = None - encoding = None - metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None + @override + def submit_embeddings( + self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] + ) -> Sequence[SeqId]: + if not self._running: + raise RuntimeError("Component not running") + + if len(embeddings) == 0: + return [] t = Table("embeddings_queue") insert = ( self.querybuilder() .into(t) 
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata) - .insert( + ) + id_to_idx: Dict[str, int] = {} + for embedding in embeddings: + ( + embedding_bytes, + encoding, + metadata, + ) = self._prepare_vector_encoding_metadata(embedding) + insert = insert.insert( ParameterValue(_operation_codes[embedding["operation"]]), ParameterValue(topic_name), ParameterValue(embedding["id"]), @@ -126,21 +136,34 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): ParameterValue(encoding), ParameterValue(metadata), ) - ) + id_to_idx[embedding["id"]] = len(id_to_idx) with self.tx() as cur: sql, params = get_sql(insert, self.parameter_format()) - sql = f"{sql} RETURNING seq_id" # Pypika doesn't support RETURNING - seq_id = int(cur.execute(sql, params).fetchone()[0]) - embedding_record = EmbeddingRecord( - id=embedding["id"], - seq_id=seq_id, - embedding=embedding["embedding"], - encoding=embedding["encoding"], - metadata=embedding["metadata"], - operation=embedding["operation"], - ) - self._notify_all(topic_name, embedding_record) - return seq_id + # The returning clause does not guarantee order, so we need to do reorder + # the results. 
https://www.sqlite.org/lang_returning.html + sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING + results = cur.execute(sql, params).fetchall() + # Reorder the results + seq_ids = [cast(SeqId, None)] * len( + results + ) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list + embedding_records = [] + for seq_id, id in results: + seq_ids[id_to_idx[id]] = seq_id + submit_embedding_record = embeddings[id_to_idx[id]] + # We allow notifying consumers out of order relative to one call to + # submit_embeddings so we do not reorder the records before submitting them + embedding_record = EmbeddingRecord( + id=id, + seq_id=seq_id, + embedding=submit_embedding_record["embedding"], + encoding=submit_embedding_record["encoding"], + metadata=submit_embedding_record["metadata"], + operation=submit_embedding_record["operation"], + ) + embedding_records.append(embedding_record) + self._notify_all(topic_name, embedding_records) + return seq_ids @override def subscribe( @@ -185,6 +208,19 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): def max_seqid(self) -> SeqId: return 2**63 - 1 + def _prepare_vector_encoding_metadata( + self, embedding: SubmitEmbeddingRecord + ) -> Tuple[Optional[bytes], Optional[str], Optional[str]]: + if embedding["embedding"]: + encoding_type = cast(ScalarEncoding, embedding["encoding"]) + encoding = encoding_type.value + embedding_bytes = encode_vector(embedding["embedding"], encoding_type) + else: + embedding_bytes = None + encoding = None + metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None + return embedding_bytes, encoding, metadata + def _backfill(self, subscription: Subscription) -> None: """Backfill the given subscription with any currently matching records in the DB""" @@ -211,14 +247,16 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): vector = None self._notify_one( subscription, - EmbeddingRecord( - seq_id=row[0], - 
operation=_operation_codes_inv[row[1]], - id=row[2], - embedding=vector, - encoding=encoding, - metadata=json.loads(row[5]) if row[5] else None, - ), + [ + EmbeddingRecord( + seq_id=row[0], + operation=_operation_codes_inv[row[1]], + id=row[2], + embedding=vector, + encoding=encoding, + metadata=json.loads(row[5]) if row[5] else None, + ) + ], ) def _validate_range( @@ -242,29 +280,37 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer): cur.execute(q.get_sql()) return int(cur.fetchone()[0]) + 1 - def _notify_all(self, topic: str, embedding: EmbeddingRecord) -> None: + def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None: """Send a notification to each subscriber of the given topic.""" if self._running: for sub in self._subscriptions[topic]: - self._notify_one(sub, embedding) + self._notify_one(sub, embeddings) - def _notify_one(self, sub: Subscription, embedding: EmbeddingRecord) -> None: + def _notify_one( + self, sub: Subscription, embeddings: Sequence[EmbeddingRecord] + ) -> None: """Send a notification to a single subscriber.""" - if embedding["seq_id"] > sub.end: - self.unsubscribe(sub.id) - return - - if embedding["seq_id"] <= sub.start: - return + # Filter out any embeddings that are not in the subscription range + should_unsubscribe = False + filtered_embeddings = [] + for embedding in embeddings: + if embedding["seq_id"] <= sub.start: + continue + if embedding["seq_id"] > sub.end: + should_unsubscribe = True + break + filtered_embeddings.append(embedding) # Log errors instead of throwing them to preserve async semantics # for consistency between local and distributed configurations try: - sub.callback([embedding]) + if len(filtered_embeddings) > 0: + sub.callback(filtered_embeddings) + if should_unsubscribe: + self.unsubscribe(sub.id) except BaseException as e: - id = embedding.get("id", embedding.get("delete_id")) logger.error( f"Exception occurred invoking consumer for subscription {sub.id}" - + f"to topic {sub.topic_name} 
for embedding id {id} ", + + f"to topic {sub.topic_name}", e, ) diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py index 38d6a4f..6aad15e 100644 --- a/chromadb/ingest/__init__.py +++ b/chromadb/ingest/__init__.py @@ -52,6 +52,16 @@ class Producer(Component): """Add an embedding record to the given topic. Returns the SeqID of the record.""" pass + @abstractmethod + def submit_embeddings( + self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord] + ) -> Sequence[SeqId]: + """Add a batch of embedding records to the given topic. Returns the SeqIDs of + the records. The returned SeqIDs will be in the same order as the given + SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be + processed in the same order as the given SubmitEmbeddingRecords.""" + pass + ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None] diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 526e1a0..acd2f84 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -1,5 +1,6 @@ from chromadb.config import Settings, System from chromadb.api import API +from chromadb.ingest import Producer import chromadb.server.fastapi from requests.exceptions import ConnectionError import hypothesis @@ -8,12 +9,23 @@ import os import uvicorn import time import pytest -from typing import Generator, List, Callable, Optional, Tuple +from typing import ( + Generator, + Iterator, + List, + Optional, + Sequence, + Tuple, + Callable, +) +from typing_extensions import Protocol import shutil import logging import socket import multiprocessing +from chromadb.types import SeqId, SubmitEmbeddingRecord + root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) # This will only run when testing @@ -189,3 +201,59 @@ def api(system: System) -> Generator[API, None, None]: system.reset_state() api = system.instance(API) yield api + + +# Producer / Consumer fixtures # + + +class ProducerFn(Protocol): + def __call__( + self, 
+ producer: Producer, + topic: str, + embeddings: Iterator[SubmitEmbeddingRecord], + n: int, + ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]: + ... + + +def produce_n_single( + producer: Producer, + topic: str, + embeddings: Iterator[SubmitEmbeddingRecord], + n: int, +) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]: + submitted_embeddings = [] + seq_ids = [] + for _ in range(n): + e = next(embeddings) + seq_id = producer.submit_embedding(topic, e) + submitted_embeddings.append(e) + seq_ids.append(seq_id) + return submitted_embeddings, seq_ids + + +def produce_n_batch( + producer: Producer, + topic: str, + embeddings: Iterator[SubmitEmbeddingRecord], + n: int, +) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]: + submitted_embeddings = [] + seq_ids: Sequence[SeqId] = [] + for _ in range(n): + e = next(embeddings) + submitted_embeddings.append(e) + seq_ids = producer.submit_embeddings(topic, submitted_embeddings) + return submitted_embeddings, seq_ids + + +def produce_fn_fixtures() -> List[ProducerFn]: + return [produce_n_single, produce_n_batch] + + +@pytest.fixture(scope="module", params=produce_fn_fixtures()) +def produce_fns( + request: pytest.FixtureRequest, +) -> Generator[ProducerFn, None, None]: + yield request.param diff --git a/chromadb/test/ingest/test_producer_consumer.py b/chromadb/test/ingest/test_producer_consumer.py index 02808aa..22f3695 100644 --- a/chromadb/test/ingest/test_producer_consumer.py +++ b/chromadb/test/ingest/test_producer_consumer.py @@ -16,6 +16,7 @@ from typing import ( ) from chromadb.ingest import Producer, Consumer from chromadb.db.impl.sqlite import SqliteDB +from chromadb.test.conftest import ProducerFn from chromadb.types import ( SubmitEmbeddingRecord, Operation, @@ -135,15 +136,13 @@ def assert_records_match( async def test_backfill( producer_consumer: Tuple[Producer, Consumer], sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer, consumer = 
producer_consumer producer.reset_state() - embeddings = [next(sample_embeddings) for _ in range(3)] - producer.create_topic("test_topic") - for e in embeddings: - producer.submit_embedding("test_topic", e) + embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0] consume_fn = CapturingConsumeFn() consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) @@ -212,6 +211,7 @@ async def test_multiple_topics( async def test_start_seq_id( producer_consumer: Tuple[Producer, Consumer], sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer, consumer = producer_consumer producer.reset_state() @@ -222,22 +222,16 @@ async def test_start_seq_id( consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) - embeddings = [] - for _ in range(5): - e = next(sample_embeddings) - embeddings.append(e) - producer.submit_embedding("test_topic", e) + embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0] results_1 = await consume_fn_1.get(5) assert_records_match(embeddings, results_1) start = consume_fn_1.embeddings[-1]["seq_id"] consumer.subscribe("test_topic", consume_fn_2, start=start) - for _ in range(5): - e = next(sample_embeddings) - embeddings.append(e) - producer.submit_embedding("test_topic", e) - + second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0] + assert isinstance(embeddings, list) + embeddings.extend(second_embeddings) results_2 = await consume_fn_2.get(5) assert_records_match(embeddings[-5:], results_2) @@ -246,6 +240,7 @@ async def test_start_seq_id( async def test_end_seq_id( producer_consumer: Tuple[Producer, Consumer], sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer, consumer = producer_consumer producer.reset_state() @@ -256,11 +251,7 @@ async def test_end_seq_id( consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) - embeddings = [] - for _ in range(10): 
- e = next(sample_embeddings) - embeddings.append(e) - producer.submit_embedding("test_topic", e) + embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0] results_1 = await consume_fn_1.get(10) assert_records_match(embeddings, results_1) @@ -274,3 +265,60 @@ async def test_end_seq_id( # Should never produce a 7th with pytest.raises(TimeoutError): _ = await wait_for(consume_fn_2.get(7), timeout=1) + + +@pytest.mark.asyncio +async def test_submit_batch( + producer_consumer: Tuple[Producer, Consumer], + sample_embeddings: Iterator[SubmitEmbeddingRecord], +) -> None: + producer, consumer = producer_consumer + producer.reset_state() + + embeddings = [next(sample_embeddings) for _ in range(100)] + + producer.create_topic("test_topic") + producer.submit_embeddings("test_topic", embeddings=embeddings) + + consume_fn = CapturingConsumeFn() + consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) + + recieved = await consume_fn.get(100) + assert_records_match(embeddings, recieved) + + +@pytest.mark.asyncio +async def test_multiple_topics_batch( + producer_consumer: Tuple[Producer, Consumer], + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, +) -> None: + producer, consumer = producer_consumer + producer.reset_state() + + N_TOPICS = 100 + consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)] + for i in range(N_TOPICS): + producer.create_topic(f"test_topic_{i}") + consumer.subscribe( + f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid() + ) + + embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)] + + PRODUCE_BATCH_SIZE = 10 + N_TO_PRODUCE = 100 + total_produced = 0 + for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE): + for i in range(N_TOPICS): + embeddings_n[i].extend( + produce_fns( + producer, + f"test_topic_{i}", + sample_embeddings, + PRODUCE_BATCH_SIZE, + )[0] + ) + recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE) + 
assert_records_match(embeddings_n[i], recieved) + total_produced += PRODUCE_BATCH_SIZE diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 9e83271..772502b 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -4,6 +4,7 @@ import tempfile import pytest from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence from chromadb.config import System, Settings +from chromadb.test.conftest import ProducerFn from chromadb.types import ( SubmitEmbeddingRecord, MetadataEmbeddingRecord, @@ -128,16 +129,16 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None: def test_insert_and_count( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() topic = str(segment_definition["topic"]) - max_id = 0 - for i in range(3): - max_id = producer.submit_embedding(topic, next(sample_embeddings)) + max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1] segment = SqliteMetadataSegment(system, segment_definition) segment.start() @@ -166,17 +167,15 @@ def assert_equiv_records( def test_get( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() topic = str(segment_definition["topic"]) - embeddings = [next(sample_embeddings) for i in range(10)] - - seq_ids = [] - for e in embeddings: - seq_ids.append(producer.submit_embedding(topic, e)) + embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10) segment = SqliteMetadataSegment(system, segment_definition) segment.start() @@ -270,7 +269,9 @@ def test_get( def test_fulltext( - system: System, sample_embeddings: 
Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -279,9 +280,7 @@ def test_fulltext( segment = SqliteMetadataSegment(system, segment_definition) segment.start() - max_id = 0 - for i in range(100): - max_id = producer.submit_embedding(topic, next(sample_embeddings)) + max_id = produce_fns(producer, topic, sample_embeddings, 100)[1][-1] sync(segment, max_id) @@ -331,7 +330,9 @@ def test_fulltext( def test_delete( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -340,11 +341,8 @@ def test_delete( segment = SqliteMetadataSegment(system, segment_definition) segment.start() - embeddings = [next(sample_embeddings) for i in range(10)] - - max_id = 0 - for e in embeddings: - max_id = producer.submit_embedding(topic, e) + embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10) + max_id = seq_ids[-1] sync(segment, max_id) @@ -353,16 +351,16 @@ def test_delete( assert_equiv_records(embeddings[:1], results) # Delete by ID - max_id = producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id="embedding_0", - embedding=None, - encoding=None, - metadata=None, - operation=Operation.DELETE, - ), + delete_embedding = SubmitEmbeddingRecord( + id="embedding_0", + embedding=None, + encoding=None, + metadata=None, + operation=Operation.DELETE, ) + max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][ + -1 + ] sync(segment, max_id) @@ -370,16 +368,9 @@ def test_delete( assert segment.get_metadata(ids=["embedding_0"]) == [] # Delete is idempotent - max_id = producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id="embedding_0", - embedding=None, - encoding=None, - metadata=None, - 
operation=Operation.DELETE, - ), - ) + max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][ + -1 + ] sync(segment, max_id) assert segment.count() == 9 @@ -420,7 +411,9 @@ def test_update( def test_upsert( - system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -439,7 +432,12 @@ def test_upsert( encoding=None, operation=Operation.UPSERT, ) - max_id = producer.submit_embedding(topic, update_record) + max_id = produce_fns( + producer=producer, + topic=topic, + embeddings=(update_record for _ in range(1)), + n=1, + )[1][-1] sync(segment, max_id) results = segment.get_metadata(ids=["no_such_id"]) assert results[0]["metadata"] == {"foo": "bar"} diff --git a/chromadb/test/segment/test_vector.py b/chromadb/test/segment/test_vector.py index de142d7..111d1db 100644 --- a/chromadb/test/segment/test_vector.py +++ b/chromadb/test/segment/test_vector.py @@ -1,6 +1,7 @@ import pytest from typing import Generator, List, Callable, Iterator, Type, cast from chromadb.config import System, Settings +from chromadb.test.conftest import ProducerFn from chromadb.types import ( SubmitEmbeddingRecord, VectorQuery, @@ -129,6 +130,7 @@ def test_insert_and_count( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) @@ -136,9 +138,9 @@ def test_insert_and_count( segment_definition = create_random_segment_definition() topic = str(segment_definition["topic"]) - max_id = 0 - for i in range(3): - max_id = producer.submit_embedding(topic, next(sample_embeddings)) + max_id = produce_fns( + producer=producer, topic=topic, n=3, embeddings=sample_embeddings + )[1][-1] segment = vector_reader(system, segment_definition) segment.start() @@ -146,8 +148,10 @@ def 
test_insert_and_count( sync(segment, max_id) assert segment.count() == 3 - for i in range(3): - max_id = producer.submit_embedding(topic, next(sample_embeddings)) + + max_id = produce_fns( + producer=producer, topic=topic, n=3, embeddings=sample_embeddings + )[1][-1] sync(segment, max_id) assert segment.count() == 6 @@ -165,6 +169,7 @@ def test_get_vectors( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -174,11 +179,9 @@ def test_get_vectors( segment = vector_reader(system, segment_definition) segment.start() - embeddings = [next(sample_embeddings) for i in range(10)] - - seq_ids: List[SeqId] = [] - for e in embeddings: - seq_ids.append(producer.submit_embedding(topic, e)) + embeddings, seq_ids = produce_fns( + producer=producer, topic=topic, embeddings=sample_embeddings, n=10 + ) sync(segment, seq_ids[-1]) @@ -210,6 +213,7 @@ def test_ann_query( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -219,11 +223,9 @@ def test_ann_query( segment = vector_reader(system, segment_definition) segment.start() - embeddings = [next(sample_embeddings) for i in range(100)] - - seq_ids: List[SeqId] = [] - for e in embeddings: - seq_ids.append(producer.submit_embedding(topic, e)) + embeddings, seq_ids = produce_fns( + producer=producer, topic=topic, embeddings=sample_embeddings, n=100 + ) sync(segment, seq_ids[-1]) @@ -275,6 +277,7 @@ def test_delete( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -284,26 +287,28 @@ def test_delete( segment = vector_reader(system, segment_definition) segment.start() - embeddings = 
[next(sample_embeddings) for i in range(5)] - - seq_ids: List[SeqId] = [] - for e in embeddings: - seq_ids.append(producer.submit_embedding(topic, e)) + embeddings, seq_ids = produce_fns( + producer=producer, topic=topic, embeddings=sample_embeddings, n=5 + ) sync(segment, seq_ids[-1]) assert segment.count() == 5 + delete_record = SubmitEmbeddingRecord( + id=embeddings[0]["id"], + embedding=None, + encoding=None, + metadata=None, + operation=Operation.DELETE, + ) + assert isinstance(seq_ids, List) seq_ids.append( - producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id=embeddings[0]["id"], - embedding=None, - encoding=None, - metadata=None, - operation=Operation.DELETE, - ), - ) + produce_fns( + producer=producer, + topic=topic, + n=1, + embeddings=(delete_record for _ in range(1)), + )[1][0] ) sync(segment, seq_ids[-1]) @@ -334,16 +339,12 @@ def test_delete( # Delete is idempotent seq_ids.append( - producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id=embeddings[0]["id"], - embedding=None, - encoding=None, - metadata=None, - operation=Operation.DELETE, - ), - ) + produce_fns( + producer=producer, + topic=topic, + n=1, + embeddings=(delete_record for _ in range(1)), + )[1][0] ) sync(segment, seq_ids[-1]) @@ -416,6 +417,7 @@ def test_update( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -428,16 +430,19 @@ def test_update( _test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE) # test updating a nonexistent record - seq_id = producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id="no_such_record", - embedding=[10.0, 10.0], - encoding=ScalarEncoding.FLOAT32, - metadata=None, - operation=Operation.UPDATE, - ), + update_record = SubmitEmbeddingRecord( + id="no_such_record", + embedding=[10.0, 10.0], + encoding=ScalarEncoding.FLOAT32, + metadata=None, + 
operation=Operation.UPDATE, ) + seq_id = produce_fns( + producer=producer, + topic=topic, + n=1, + embeddings=(update_record for _ in range(1)), + )[1][0] sync(segment, seq_id) @@ -449,6 +454,7 @@ def test_upsert( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], vector_reader: Type[VectorReader], + produce_fns: ProducerFn, ) -> None: producer = system.instance(Producer) system.reset_state() @@ -461,16 +467,19 @@ def test_upsert( _test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT) # test updating a nonexistent record - seq_id = producer.submit_embedding( - topic, - SubmitEmbeddingRecord( - id="no_such_record", - embedding=[42, 42], - encoding=ScalarEncoding.FLOAT32, - metadata=None, - operation=Operation.UPSERT, - ), + upsert_record = SubmitEmbeddingRecord( + id="no_such_record", + embedding=[42, 42], + encoding=ScalarEncoding.FLOAT32, + metadata=None, + operation=Operation.UPSERT, ) + seq_id = produce_fns( + producer=producer, + topic=topic, + n=1, + embeddings=(upsert_record for _ in range(1)), + )[1][0] sync(segment, seq_id)