[PERF] Batch SQLite Embeddings Queue (#959)

## Description of changes
Adds a batch submit_embeddings() to the producer API. Implements this
behavior in sqlite based embeddings queue. Refactors tests to add a
fixture for both single and batch submit and substitutes all use of
produce in tests with the fixture so both methods are uniformly tested.

This PR is stacked on #958, please read that before reading this.

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- This PR makes the sqlite based embeddings queue batch calls to sqlite
based on the batch size of calls to DML operations.
	 
### Quick Benchmarks
N = 100k
D = 128
metadata = {simple key: simple value}
Document = random 100 character string

**Before - Overall time: 102s**
<img width="587" alt="Screenshot 2023-08-09 at 5 43 12 PM"
src="https://github.com/chroma-core/chroma/assets/5598697/752809a2-78a0-4d5b-a238-fc41ca2635cc">

**After - Overall time: 36s**
<img width="583" alt="Screenshot 2023-08-09 at 7 37 41 PM"
src="https://github.com/chroma-core/chroma/assets/5598697/eb1759c8-8f62-4ed7-ad88-450b3f72692b">

**_Overall, when compared to main this and #958 decrease the time of
this benchmark from 469s -> 36s._**

	
Todo:
- [x] Clean up code
- [x] Rethink assumptions about submit_embeddings ordering guarantees
- [x] Add tests
- [x] Add batching to other DML - upsert(), update(), delete()

## Test plan
Existing tests cover the bulk of the functionality.
New tests TODO:

- [x] Producer/Consumer tests for submit_embeddings
- [x] Segment tests with batch embeddings

## Documentation Changes
None required. As a miscellaneous future feature idea: a hyperparameter sweep to find the
optimal batch_size for your platform.
This commit is contained in:
Hammad Bashir
2023-08-10 16:11:11 -07:00
committed by GitHub
parent cdb588b7cc
commit 6739259ef8
7 changed files with 353 additions and 166 deletions

View File

@@ -242,9 +242,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id) coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD) self._manager.hint_use_collection(collection_id, t.Operation.ADD)
records_to_submit = []
for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents): for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r) self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r) records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
return True return True
@@ -261,9 +263,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id) coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
records_to_submit = []
for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents): for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r) self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r) records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
return True return True
@@ -279,9 +283,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id) coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
records_to_submit = []
for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents): for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r) self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r) records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
return True return True
@@ -376,9 +382,11 @@ class SegmentAPI(API):
else: else:
ids_to_delete = ids ids_to_delete = ids
records_to_submit = []
for r in _records(t.Operation.DELETE, ids_to_delete): for r in _records(t.Operation.DELETE, ids_to_delete):
self._validate_embedding_record(coll, r) self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r) records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._telemetry_client.capture( self._telemetry_client.capture(
CollectionDeleteEvent(str(collection_id), len(ids_to_delete)) CollectionDeleteEvent(str(collection_id), len(ids_to_delete))

View File

@@ -16,7 +16,7 @@ from chromadb.types import (
from chromadb.config import System from chromadb.config import System
from overrides import override from overrides import override
from collections import defaultdict from collections import defaultdict
from typing import Tuple, Optional, Dict, Set, cast from typing import Sequence, Tuple, Optional, Dict, Set, cast
from uuid import UUID from uuid import UUID
from pypika import Table, functions from pypika import Table, functions
import uuid import uuid
@@ -103,22 +103,32 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
if not self._running: if not self._running:
raise RuntimeError("Component not running") raise RuntimeError("Component not running")
if embedding["embedding"]: return self.submit_embeddings(topic_name, [embedding])[0]
encoding_type = cast(ScalarEncoding, embedding["encoding"])
encoding = encoding_type.value
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
else: @override
embedding_bytes = None def submit_embeddings(
encoding = None self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None ) -> Sequence[SeqId]:
if not self._running:
raise RuntimeError("Component not running")
if len(embeddings) == 0:
return []
t = Table("embeddings_queue") t = Table("embeddings_queue")
insert = ( insert = (
self.querybuilder() self.querybuilder()
.into(t) .into(t)
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata) .columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
.insert( )
id_to_idx: Dict[str, int] = {}
for embedding in embeddings:
(
embedding_bytes,
encoding,
metadata,
) = self._prepare_vector_encoding_metadata(embedding)
insert = insert.insert(
ParameterValue(_operation_codes[embedding["operation"]]), ParameterValue(_operation_codes[embedding["operation"]]),
ParameterValue(topic_name), ParameterValue(topic_name),
ParameterValue(embedding["id"]), ParameterValue(embedding["id"]),
@@ -126,21 +136,34 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
ParameterValue(encoding), ParameterValue(encoding),
ParameterValue(metadata), ParameterValue(metadata),
) )
) id_to_idx[embedding["id"]] = len(id_to_idx)
with self.tx() as cur: with self.tx() as cur:
sql, params = get_sql(insert, self.parameter_format()) sql, params = get_sql(insert, self.parameter_format())
sql = f"{sql} RETURNING seq_id" # Pypika doesn't support RETURNING # The returning clause does not guarantee order, so we need to do reorder
seq_id = int(cur.execute(sql, params).fetchone()[0]) # the results. https://www.sqlite.org/lang_returning.html
embedding_record = EmbeddingRecord( sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
id=embedding["id"], results = cur.execute(sql, params).fetchall()
seq_id=seq_id, # Reorder the results
embedding=embedding["embedding"], seq_ids = [cast(SeqId, None)] * len(
encoding=embedding["encoding"], results
metadata=embedding["metadata"], ) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
operation=embedding["operation"], embedding_records = []
) for seq_id, id in results:
self._notify_all(topic_name, embedding_record) seq_ids[id_to_idx[id]] = seq_id
return seq_id submit_embedding_record = embeddings[id_to_idx[id]]
# We allow notifying consumers out of order relative to one call to
# submit_embeddings so we do not reorder the records before submitting them
embedding_record = EmbeddingRecord(
id=id,
seq_id=seq_id,
embedding=submit_embedding_record["embedding"],
encoding=submit_embedding_record["encoding"],
metadata=submit_embedding_record["metadata"],
operation=submit_embedding_record["operation"],
)
embedding_records.append(embedding_record)
self._notify_all(topic_name, embedding_records)
return seq_ids
@override @override
def subscribe( def subscribe(
@@ -185,6 +208,19 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
def max_seqid(self) -> SeqId: def max_seqid(self) -> SeqId:
return 2**63 - 1 return 2**63 - 1
def _prepare_vector_encoding_metadata(
self, embedding: SubmitEmbeddingRecord
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
if embedding["embedding"]:
encoding_type = cast(ScalarEncoding, embedding["encoding"])
encoding = encoding_type.value
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
else:
embedding_bytes = None
encoding = None
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
return embedding_bytes, encoding, metadata
def _backfill(self, subscription: Subscription) -> None: def _backfill(self, subscription: Subscription) -> None:
"""Backfill the given subscription with any currently matching records in the """Backfill the given subscription with any currently matching records in the
DB""" DB"""
@@ -211,14 +247,16 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
vector = None vector = None
self._notify_one( self._notify_one(
subscription, subscription,
EmbeddingRecord( [
seq_id=row[0], EmbeddingRecord(
operation=_operation_codes_inv[row[1]], seq_id=row[0],
id=row[2], operation=_operation_codes_inv[row[1]],
embedding=vector, id=row[2],
encoding=encoding, embedding=vector,
metadata=json.loads(row[5]) if row[5] else None, encoding=encoding,
), metadata=json.loads(row[5]) if row[5] else None,
)
],
) )
def _validate_range( def _validate_range(
@@ -242,29 +280,37 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
cur.execute(q.get_sql()) cur.execute(q.get_sql())
return int(cur.fetchone()[0]) + 1 return int(cur.fetchone()[0]) + 1
def _notify_all(self, topic: str, embedding: EmbeddingRecord) -> None: def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None:
"""Send a notification to each subscriber of the given topic.""" """Send a notification to each subscriber of the given topic."""
if self._running: if self._running:
for sub in self._subscriptions[topic]: for sub in self._subscriptions[topic]:
self._notify_one(sub, embedding) self._notify_one(sub, embeddings)
def _notify_one(self, sub: Subscription, embedding: EmbeddingRecord) -> None: def _notify_one(
self, sub: Subscription, embeddings: Sequence[EmbeddingRecord]
) -> None:
"""Send a notification to a single subscriber.""" """Send a notification to a single subscriber."""
if embedding["seq_id"] > sub.end: # Filter out any embeddings that are not in the subscription range
self.unsubscribe(sub.id) should_unsubscribe = False
return filtered_embeddings = []
for embedding in embeddings:
if embedding["seq_id"] <= sub.start: if embedding["seq_id"] <= sub.start:
return continue
if embedding["seq_id"] > sub.end:
should_unsubscribe = True
break
filtered_embeddings.append(embedding)
# Log errors instead of throwing them to preserve async semantics # Log errors instead of throwing them to preserve async semantics
# for consistency between local and distributed configurations # for consistency between local and distributed configurations
try: try:
sub.callback([embedding]) if len(filtered_embeddings) > 0:
sub.callback(filtered_embeddings)
if should_unsubscribe:
self.unsubscribe(sub.id)
except BaseException as e: except BaseException as e:
id = embedding.get("id", embedding.get("delete_id"))
logger.error( logger.error(
f"Exception occurred invoking consumer for subscription {sub.id}" f"Exception occurred invoking consumer for subscription {sub.id}"
+ f"to topic {sub.topic_name} for embedding id {id} ", + f"to topic {sub.topic_name}",
e, e,
) )

View File

@@ -52,6 +52,16 @@ class Producer(Component):
"""Add an embedding record to the given topic. Returns the SeqID of the record.""" """Add an embedding record to the given topic. Returns the SeqID of the record."""
pass pass
@abstractmethod
def submit_embeddings(
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
"""Add a batch of embedding records to the given topic. Returns the SeqIDs of
the records. The returned SeqIDs will be in the same order as the given
SubmitEmbeddingRecords. However, it is not guaranteed that the SeqIDs will be
processed in the same order as the given SubmitEmbeddingRecords."""
pass
ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None] ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None]

View File

@@ -1,5 +1,6 @@
from chromadb.config import Settings, System from chromadb.config import Settings, System
from chromadb.api import API from chromadb.api import API
from chromadb.ingest import Producer
import chromadb.server.fastapi import chromadb.server.fastapi
from requests.exceptions import ConnectionError from requests.exceptions import ConnectionError
import hypothesis import hypothesis
@@ -8,12 +9,23 @@ import os
import uvicorn import uvicorn
import time import time
import pytest import pytest
from typing import Generator, List, Callable, Optional, Tuple from typing import (
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Callable,
)
from typing_extensions import Protocol
import shutil import shutil
import logging import logging
import socket import socket
import multiprocessing import multiprocessing
from chromadb.types import SeqId, SubmitEmbeddingRecord
root_logger = logging.getLogger() root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # This will only run when testing root_logger.setLevel(logging.DEBUG) # This will only run when testing
@@ -189,3 +201,59 @@ def api(system: System) -> Generator[API, None, None]:
system.reset_state() system.reset_state()
api = system.instance(API) api = system.instance(API)
yield api yield api
# Producer / Consumer fixtures #
class ProducerFn(Protocol):
def __call__(
self,
producer: Producer,
topic: str,
embeddings: Iterator[SubmitEmbeddingRecord],
n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
...
def produce_n_single(
producer: Producer,
topic: str,
embeddings: Iterator[SubmitEmbeddingRecord],
n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
submitted_embeddings = []
seq_ids = []
for _ in range(n):
e = next(embeddings)
seq_id = producer.submit_embedding(topic, e)
submitted_embeddings.append(e)
seq_ids.append(seq_id)
return submitted_embeddings, seq_ids
def produce_n_batch(
producer: Producer,
topic: str,
embeddings: Iterator[SubmitEmbeddingRecord],
n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
submitted_embeddings = []
seq_ids: Sequence[SeqId] = []
for _ in range(n):
e = next(embeddings)
submitted_embeddings.append(e)
seq_ids = producer.submit_embeddings(topic, submitted_embeddings)
return submitted_embeddings, seq_ids
def produce_fn_fixtures() -> List[ProducerFn]:
return [produce_n_single, produce_n_batch]
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
yield request.param

View File

@@ -16,6 +16,7 @@ from typing import (
) )
from chromadb.ingest import Producer, Consumer from chromadb.ingest import Producer, Consumer
from chromadb.db.impl.sqlite import SqliteDB from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import ( from chromadb.types import (
SubmitEmbeddingRecord, SubmitEmbeddingRecord,
Operation, Operation,
@@ -135,15 +136,13 @@ def assert_records_match(
async def test_backfill( async def test_backfill(
producer_consumer: Tuple[Producer, Consumer], producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer, consumer = producer_consumer producer, consumer = producer_consumer
producer.reset_state() producer.reset_state()
embeddings = [next(sample_embeddings) for _ in range(3)]
producer.create_topic("test_topic") producer.create_topic("test_topic")
for e in embeddings: embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
producer.submit_embedding("test_topic", e)
consume_fn = CapturingConsumeFn() consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid()) consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
@@ -212,6 +211,7 @@ async def test_multiple_topics(
async def test_start_seq_id( async def test_start_seq_id(
producer_consumer: Tuple[Producer, Consumer], producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer, consumer = producer_consumer producer, consumer = producer_consumer
producer.reset_state() producer.reset_state()
@@ -222,22 +222,16 @@ async def test_start_seq_id(
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
embeddings = [] embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
for _ in range(5):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
results_1 = await consume_fn_1.get(5) results_1 = await consume_fn_1.get(5)
assert_records_match(embeddings, results_1) assert_records_match(embeddings, results_1)
start = consume_fn_1.embeddings[-1]["seq_id"] start = consume_fn_1.embeddings[-1]["seq_id"]
consumer.subscribe("test_topic", consume_fn_2, start=start) consumer.subscribe("test_topic", consume_fn_2, start=start)
for _ in range(5): second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
e = next(sample_embeddings) assert isinstance(embeddings, list)
embeddings.append(e) embeddings.extend(second_embeddings)
producer.submit_embedding("test_topic", e)
results_2 = await consume_fn_2.get(5) results_2 = await consume_fn_2.get(5)
assert_records_match(embeddings[-5:], results_2) assert_records_match(embeddings[-5:], results_2)
@@ -246,6 +240,7 @@ async def test_start_seq_id(
async def test_end_seq_id( async def test_end_seq_id(
producer_consumer: Tuple[Producer, Consumer], producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer, consumer = producer_consumer producer, consumer = producer_consumer
producer.reset_state() producer.reset_state()
@@ -256,11 +251,7 @@ async def test_end_seq_id(
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid()) consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
embeddings = [] embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
for _ in range(10):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
results_1 = await consume_fn_1.get(10) results_1 = await consume_fn_1.get(10)
assert_records_match(embeddings, results_1) assert_records_match(embeddings, results_1)
@@ -274,3 +265,60 @@ async def test_end_seq_id(
# Should never produce a 7th # Should never produce a 7th
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
_ = await wait_for(consume_fn_2.get(7), timeout=1) _ = await wait_for(consume_fn_2.get(7), timeout=1)
@pytest.mark.asyncio
async def test_submit_batch(
producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
embeddings = [next(sample_embeddings) for _ in range(100)]
producer.create_topic("test_topic")
producer.submit_embeddings("test_topic", embeddings=embeddings)
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
recieved = await consume_fn.get(100)
assert_records_match(embeddings, recieved)
@pytest.mark.asyncio
async def test_multiple_topics_batch(
producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
N_TOPICS = 100
consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
for i in range(N_TOPICS):
producer.create_topic(f"test_topic_{i}")
consumer.subscribe(
f"test_topic_{i}", consume_fns[i], start=consumer.min_seqid()
)
embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]
PRODUCE_BATCH_SIZE = 10
N_TO_PRODUCE = 100
total_produced = 0
for i in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
for i in range(N_TOPICS):
embeddings_n[i].extend(
produce_fns(
producer,
f"test_topic_{i}",
sample_embeddings,
PRODUCE_BATCH_SIZE,
)[0]
)
recieved = await consume_fns[i].get(total_produced + PRODUCE_BATCH_SIZE)
assert_records_match(embeddings_n[i], recieved)
total_produced += PRODUCE_BATCH_SIZE

View File

@@ -4,6 +4,7 @@ import tempfile
import pytest import pytest
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
from chromadb.config import System, Settings from chromadb.config import System, Settings
from chromadb.test.conftest import ProducerFn
from chromadb.types import ( from chromadb.types import (
SubmitEmbeddingRecord, SubmitEmbeddingRecord,
MetadataEmbeddingRecord, MetadataEmbeddingRecord,
@@ -128,16 +129,16 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None:
def test_insert_and_count( def test_insert_and_count(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
topic = str(segment_definition["topic"]) topic = str(segment_definition["topic"])
max_id = 0 max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1]
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
segment = SqliteMetadataSegment(system, segment_definition) segment = SqliteMetadataSegment(system, segment_definition)
segment.start() segment.start()
@@ -166,17 +167,15 @@ def assert_equiv_records(
def test_get( def test_get(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
topic = str(segment_definition["topic"]) topic = str(segment_definition["topic"])
embeddings = [next(sample_embeddings) for i in range(10)] embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
seq_ids = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
segment = SqliteMetadataSegment(system, segment_definition) segment = SqliteMetadataSegment(system, segment_definition)
segment.start() segment.start()
@@ -270,7 +269,9 @@ def test_get(
def test_fulltext( def test_fulltext(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -279,9 +280,7 @@ def test_fulltext(
segment = SqliteMetadataSegment(system, segment_definition) segment = SqliteMetadataSegment(system, segment_definition)
segment.start() segment.start()
max_id = 0 max_id = produce_fns(producer, topic, sample_embeddings, 100)[1][-1]
for i in range(100):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
sync(segment, max_id) sync(segment, max_id)
@@ -331,7 +330,9 @@ def test_fulltext(
def test_delete( def test_delete(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -340,11 +341,8 @@ def test_delete(
segment = SqliteMetadataSegment(system, segment_definition) segment = SqliteMetadataSegment(system, segment_definition)
segment.start() segment.start()
embeddings = [next(sample_embeddings) for i in range(10)] embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
max_id = seq_ids[-1]
max_id = 0
for e in embeddings:
max_id = producer.submit_embedding(topic, e)
sync(segment, max_id) sync(segment, max_id)
@@ -353,16 +351,16 @@ def test_delete(
assert_equiv_records(embeddings[:1], results) assert_equiv_records(embeddings[:1], results)
# Delete by ID # Delete by ID
max_id = producer.submit_embedding( delete_embedding = SubmitEmbeddingRecord(
topic, id="embedding_0",
SubmitEmbeddingRecord( embedding=None,
id="embedding_0", encoding=None,
embedding=None, metadata=None,
encoding=None, operation=Operation.DELETE,
metadata=None,
operation=Operation.DELETE,
),
) )
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
-1
]
sync(segment, max_id) sync(segment, max_id)
@@ -370,16 +368,9 @@ def test_delete(
assert segment.get_metadata(ids=["embedding_0"]) == [] assert segment.get_metadata(ids=["embedding_0"]) == []
# Delete is idempotent # Delete is idempotent
max_id = producer.submit_embedding( max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
topic, -1
SubmitEmbeddingRecord( ]
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
sync(segment, max_id) sync(segment, max_id)
assert segment.count() == 9 assert segment.count() == 9
@@ -420,7 +411,9 @@ def test_update(
def test_upsert( def test_upsert(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord] system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -439,7 +432,12 @@ def test_upsert(
encoding=None, encoding=None,
operation=Operation.UPSERT, operation=Operation.UPSERT,
) )
max_id = producer.submit_embedding(topic, update_record) max_id = produce_fns(
producer=producer,
topic=topic,
embeddings=(update_record for _ in range(1)),
n=1,
)[1][-1]
sync(segment, max_id) sync(segment, max_id)
results = segment.get_metadata(ids=["no_such_id"]) results = segment.get_metadata(ids=["no_such_id"])
assert results[0]["metadata"] == {"foo": "bar"} assert results[0]["metadata"] == {"foo": "bar"}

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from typing import Generator, List, Callable, Iterator, Type, cast from typing import Generator, List, Callable, Iterator, Type, cast
from chromadb.config import System, Settings from chromadb.config import System, Settings
from chromadb.test.conftest import ProducerFn
from chromadb.types import ( from chromadb.types import (
SubmitEmbeddingRecord, SubmitEmbeddingRecord,
VectorQuery, VectorQuery,
@@ -129,6 +130,7 @@ def test_insert_and_count(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
@@ -136,9 +138,9 @@ def test_insert_and_count(
segment_definition = create_random_segment_definition() segment_definition = create_random_segment_definition()
topic = str(segment_definition["topic"]) topic = str(segment_definition["topic"])
max_id = 0 max_id = produce_fns(
for i in range(3): producer=producer, topic=topic, n=3, embeddings=sample_embeddings
max_id = producer.submit_embedding(topic, next(sample_embeddings)) )[1][-1]
segment = vector_reader(system, segment_definition) segment = vector_reader(system, segment_definition)
segment.start() segment.start()
@@ -146,8 +148,10 @@ def test_insert_and_count(
sync(segment, max_id) sync(segment, max_id)
assert segment.count() == 3 assert segment.count() == 3
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings)) max_id = produce_fns(
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
)[1][-1]
sync(segment, max_id) sync(segment, max_id)
assert segment.count() == 6 assert segment.count() == 6
@@ -165,6 +169,7 @@ def test_get_vectors(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -174,11 +179,9 @@ def test_get_vectors(
segment = vector_reader(system, segment_definition) segment = vector_reader(system, segment_definition)
segment.start() segment.start()
embeddings = [next(sample_embeddings) for i in range(10)] embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=10
seq_ids: List[SeqId] = [] )
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
sync(segment, seq_ids[-1]) sync(segment, seq_ids[-1])
@@ -210,6 +213,7 @@ def test_ann_query(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -219,11 +223,9 @@ def test_ann_query(
segment = vector_reader(system, segment_definition) segment = vector_reader(system, segment_definition)
segment.start() segment.start()
embeddings = [next(sample_embeddings) for i in range(100)] embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=100
seq_ids: List[SeqId] = [] )
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
sync(segment, seq_ids[-1]) sync(segment, seq_ids[-1])
@@ -275,6 +277,7 @@ def test_delete(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -284,26 +287,28 @@ def test_delete(
segment = vector_reader(system, segment_definition) segment = vector_reader(system, segment_definition)
segment.start() segment.start()
embeddings = [next(sample_embeddings) for i in range(5)] embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=5
seq_ids: List[SeqId] = [] )
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
sync(segment, seq_ids[-1]) sync(segment, seq_ids[-1])
assert segment.count() == 5 assert segment.count() == 5
delete_record = SubmitEmbeddingRecord(
id=embeddings[0]["id"],
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
)
assert isinstance(seq_ids, List)
seq_ids.append( seq_ids.append(
producer.submit_embedding( produce_fns(
topic, producer=producer,
SubmitEmbeddingRecord( topic=topic,
id=embeddings[0]["id"], n=1,
embedding=None, embeddings=(delete_record for _ in range(1)),
encoding=None, )[1][0]
metadata=None,
operation=Operation.DELETE,
),
)
) )
sync(segment, seq_ids[-1]) sync(segment, seq_ids[-1])
@@ -334,16 +339,12 @@ def test_delete(
# Delete is idempotent # Delete is idempotent
seq_ids.append( seq_ids.append(
producer.submit_embedding( produce_fns(
topic, producer=producer,
SubmitEmbeddingRecord( topic=topic,
id=embeddings[0]["id"], n=1,
embedding=None, embeddings=(delete_record for _ in range(1)),
encoding=None, )[1][0]
metadata=None,
operation=Operation.DELETE,
),
)
) )
sync(segment, seq_ids[-1]) sync(segment, seq_ids[-1])
@@ -416,6 +417,7 @@ def test_update(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -428,16 +430,19 @@ def test_update(
_test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE) _test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)
# test updating a nonexistent record # test updating a nonexistent record
seq_id = producer.submit_embedding( update_record = SubmitEmbeddingRecord(
topic, id="no_such_record",
SubmitEmbeddingRecord( embedding=[10.0, 10.0],
id="no_such_record", encoding=ScalarEncoding.FLOAT32,
embedding=[10.0, 10.0], metadata=None,
encoding=ScalarEncoding.FLOAT32, operation=Operation.UPDATE,
metadata=None,
operation=Operation.UPDATE,
),
) )
seq_id = produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(update_record for _ in range(1)),
)[1][0]
sync(segment, seq_id) sync(segment, seq_id)
@@ -449,6 +454,7 @@ def test_upsert(
system: System, system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord], sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader], vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None: ) -> None:
producer = system.instance(Producer) producer = system.instance(Producer)
system.reset_state() system.reset_state()
@@ -461,16 +467,19 @@ def test_upsert(
_test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT) _test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)
# test updating a nonexistent record # test updating a nonexistent record
seq_id = producer.submit_embedding( upsert_record = SubmitEmbeddingRecord(
topic, id="no_such_record",
SubmitEmbeddingRecord( embedding=[42, 42],
id="no_such_record", encoding=ScalarEncoding.FLOAT32,
embedding=[42, 42], metadata=None,
encoding=ScalarEncoding.FLOAT32, operation=Operation.UPSERT,
metadata=None,
operation=Operation.UPSERT,
),
) )
seq_id = produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(upsert_record for _ in range(1)),
)[1][0]
sync(segment, seq_id) sync(segment, seq_id)