[PERF] Batch SQLite Embeddings Queue (#959)

## Description of changes
Adds a batch submit_embeddings() to the producer API. Implements this
behavior in sqlite based embeddings queue. Refactors tests to add a
fixture for both single and batch submit and substitutes all use of
produce in tests with the fixture so both methods are uniformly tested.

This PR is stacked on #958, please read that before reading this.

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- This PR makes the SQLite-based embeddings queue batch its writes to SQLite:
each DML call (add, update, upsert, delete) now submits all of its records in a
single batched insert instead of one insert per record.
	 
### Quick Benchmarks
N = 100k
D = 128
metadata = {simple key: simple value}
Document = random 100 character string

**Before - Overall time: 102s**
<img width="587" alt="Screenshot 2023-08-09 at 5 43 12 PM"
src="https://github.com/chroma-core/chroma/assets/5598697/752809a2-78a0-4d5b-a238-fc41ca2635cc">

**After - Overall time: 36s**
<img width="583" alt="Screenshot 2023-08-09 at 7 37 41 PM"
src="https://github.com/chroma-core/chroma/assets/5598697/eb1759c8-8f62-4ed7-ad88-450b3f72692b">

**_Overall, when compared to main this and #958 decrease the time of
this benchmark from 469s -> 36s._**

	
Todo:
- [x] Clean up code
- [x] Rethink assumptions about submit_embeddings ordering guarantees
- [x] Add tests
- [x] Add batching to other DML - upsert(), update(), delete()

## Test plan
Existing tests cover the bulk of the functionality.
New tests TODO:

- [x] Producer/Consumer tests for submit_embeddings
- [x] Segment tests with batch embeddings

## Documentation Changes
None required. A possible follow-up feature is a hyperparameter sweep to find
the optimal batch_size for a given platform.
This commit is contained in:
Hammad Bashir
2023-08-10 16:11:11 -07:00
committed by GitHub
parent cdb588b7cc
commit 6739259ef8
7 changed files with 353 additions and 166 deletions

View File

@@ -242,9 +242,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
records_to_submit = []
for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
return True
@@ -261,9 +263,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
records_to_submit = []
for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
return True
@@ -279,9 +283,11 @@ class SegmentAPI(API):
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
records_to_submit = []
for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
return True
@@ -376,9 +382,11 @@ class SegmentAPI(API):
else:
ids_to_delete = ids
records_to_submit = []
for r in _records(t.Operation.DELETE, ids_to_delete):
self._validate_embedding_record(coll, r)
self._producer.submit_embedding(coll["topic"], r)
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)
self._telemetry_client.capture(
CollectionDeleteEvent(str(collection_id), len(ids_to_delete))

View File

@@ -16,7 +16,7 @@ from chromadb.types import (
from chromadb.config import System
from overrides import override
from collections import defaultdict
from typing import Tuple, Optional, Dict, Set, cast
from typing import Sequence, Tuple, Optional, Dict, Set, cast
from uuid import UUID
from pypika import Table, functions
import uuid
@@ -103,22 +103,32 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
if not self._running:
raise RuntimeError("Component not running")
if embedding["embedding"]:
encoding_type = cast(ScalarEncoding, embedding["encoding"])
encoding = encoding_type.value
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
return self.submit_embeddings(topic_name, [embedding])[0]
else:
embedding_bytes = None
encoding = None
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
@override
def submit_embeddings(
self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
if not self._running:
raise RuntimeError("Component not running")
if len(embeddings) == 0:
return []
t = Table("embeddings_queue")
insert = (
self.querybuilder()
.into(t)
.columns(t.operation, t.topic, t.id, t.vector, t.encoding, t.metadata)
.insert(
)
id_to_idx: Dict[str, int] = {}
for embedding in embeddings:
(
embedding_bytes,
encoding,
metadata,
) = self._prepare_vector_encoding_metadata(embedding)
insert = insert.insert(
ParameterValue(_operation_codes[embedding["operation"]]),
ParameterValue(topic_name),
ParameterValue(embedding["id"]),
@@ -126,21 +136,34 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
ParameterValue(encoding),
ParameterValue(metadata),
)
)
id_to_idx[embedding["id"]] = len(id_to_idx)
with self.tx() as cur:
sql, params = get_sql(insert, self.parameter_format())
sql = f"{sql} RETURNING seq_id" # Pypika doesn't support RETURNING
seq_id = int(cur.execute(sql, params).fetchone()[0])
embedding_record = EmbeddingRecord(
id=embedding["id"],
seq_id=seq_id,
embedding=embedding["embedding"],
encoding=embedding["encoding"],
metadata=embedding["metadata"],
operation=embedding["operation"],
)
self._notify_all(topic_name, embedding_record)
return seq_id
# The returning clause does not guarantee order, so we need to do reorder
# the results. https://www.sqlite.org/lang_returning.html
sql = f"{sql} RETURNING seq_id, id" # Pypika doesn't support RETURNING
results = cur.execute(sql, params).fetchall()
# Reorder the results
seq_ids = [cast(SeqId, None)] * len(
results
) # Lie to mypy: https://stackoverflow.com/questions/76694215/python-type-casting-when-preallocating-list
embedding_records = []
for seq_id, id in results:
seq_ids[id_to_idx[id]] = seq_id
submit_embedding_record = embeddings[id_to_idx[id]]
# We allow notifying consumers out of order relative to one call to
# submit_embeddings so we do not reorder the records before submitting them
embedding_record = EmbeddingRecord(
id=id,
seq_id=seq_id,
embedding=submit_embedding_record["embedding"],
encoding=submit_embedding_record["encoding"],
metadata=submit_embedding_record["metadata"],
operation=submit_embedding_record["operation"],
)
embedding_records.append(embedding_record)
self._notify_all(topic_name, embedding_records)
return seq_ids
@override
def subscribe(
@@ -185,6 +208,19 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
def max_seqid(self) -> SeqId:
return 2**63 - 1
def _prepare_vector_encoding_metadata(
self, embedding: SubmitEmbeddingRecord
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
if embedding["embedding"]:
encoding_type = cast(ScalarEncoding, embedding["encoding"])
encoding = encoding_type.value
embedding_bytes = encode_vector(embedding["embedding"], encoding_type)
else:
embedding_bytes = None
encoding = None
metadata = json.dumps(embedding["metadata"]) if embedding["metadata"] else None
return embedding_bytes, encoding, metadata
def _backfill(self, subscription: Subscription) -> None:
"""Backfill the given subscription with any currently matching records in the
DB"""
@@ -211,14 +247,16 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
vector = None
self._notify_one(
subscription,
EmbeddingRecord(
seq_id=row[0],
operation=_operation_codes_inv[row[1]],
id=row[2],
embedding=vector,
encoding=encoding,
metadata=json.loads(row[5]) if row[5] else None,
),
[
EmbeddingRecord(
seq_id=row[0],
operation=_operation_codes_inv[row[1]],
id=row[2],
embedding=vector,
encoding=encoding,
metadata=json.loads(row[5]) if row[5] else None,
)
],
)
def _validate_range(
@@ -242,29 +280,37 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
cur.execute(q.get_sql())
return int(cur.fetchone()[0]) + 1
def _notify_all(self, topic: str, embedding: EmbeddingRecord) -> None:
def _notify_all(self, topic: str, embeddings: Sequence[EmbeddingRecord]) -> None:
"""Send a notification to each subscriber of the given topic."""
if self._running:
for sub in self._subscriptions[topic]:
self._notify_one(sub, embedding)
self._notify_one(sub, embeddings)
def _notify_one(self, sub: Subscription, embedding: EmbeddingRecord) -> None:
def _notify_one(
self, sub: Subscription, embeddings: Sequence[EmbeddingRecord]
) -> None:
"""Send a notification to a single subscriber."""
if embedding["seq_id"] > sub.end:
self.unsubscribe(sub.id)
return
if embedding["seq_id"] <= sub.start:
return
# Filter out any embeddings that are not in the subscription range
should_unsubscribe = False
filtered_embeddings = []
for embedding in embeddings:
if embedding["seq_id"] <= sub.start:
continue
if embedding["seq_id"] > sub.end:
should_unsubscribe = True
break
filtered_embeddings.append(embedding)
# Log errors instead of throwing them to preserve async semantics
# for consistency between local and distributed configurations
try:
sub.callback([embedding])
if len(filtered_embeddings) > 0:
sub.callback(filtered_embeddings)
if should_unsubscribe:
self.unsubscribe(sub.id)
except BaseException as e:
id = embedding.get("id", embedding.get("delete_id"))
logger.error(
f"Exception occurred invoking consumer for subscription {sub.id}"
+ f"to topic {sub.topic_name} for embedding id {id} ",
+ f"to topic {sub.topic_name}",
e,
)

View File

@@ -52,6 +52,16 @@ class Producer(Component):
"""Add an embedding record to the given topic. Returns the SeqID of the record."""
pass
@abstractmethod
def submit_embeddings(
    self, topic_name: str, embeddings: Sequence[SubmitEmbeddingRecord]
) -> Sequence[SeqId]:
    """Add a batch of embedding records to the given topic.

    Returns one SeqID per submitted record, in the same order as the given
    SubmitEmbeddingRecords. Note that this ordering guarantee covers only the
    returned SeqIDs: the records themselves are not guaranteed to be processed
    in the order in which they were given.
    """
    pass
ConsumerCallbackFn = Callable[[Sequence[EmbeddingRecord]], None]

View File

@@ -1,5 +1,6 @@
from chromadb.config import Settings, System
from chromadb.api import API
from chromadb.ingest import Producer
import chromadb.server.fastapi
from requests.exceptions import ConnectionError
import hypothesis
@@ -8,12 +9,23 @@ import os
import uvicorn
import time
import pytest
from typing import Generator, List, Callable, Optional, Tuple
from typing import (
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Callable,
)
from typing_extensions import Protocol
import shutil
import logging
import socket
import multiprocessing
from chromadb.types import SeqId, SubmitEmbeddingRecord
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # This will only run when testing
@@ -189,3 +201,59 @@ def api(system: System) -> Generator[API, None, None]:
system.reset_state()
api = system.instance(API)
yield api
# Producer / Consumer fixtures #
class ProducerFn(Protocol):
    """Signature shared by the helpers that submit records to a producer.

    An implementation is called with a producer, a topic name, an iterator of
    records and a count ``n``; it returns the submitted records together with
    their SeqIds.
    """

    def __call__(
        self,
        producer: Producer,
        topic: str,
        embeddings: Iterator[SubmitEmbeddingRecord],
        n: int,
    ) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
        ...
def produce_n_single(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records from ``embeddings`` to ``topic`` one call at a time.

    Each record is pulled from the iterator and passed to
    ``producer.submit_embedding`` before the next one is pulled. Returns the
    submitted records and the SeqIds returned for them, both in submission
    order.
    """
    records = []
    assigned_ids = []
    for _ in range(n):
        record = next(embeddings)
        assigned_ids.append(producer.submit_embedding(topic, record))
        records.append(record)
    return records, assigned_ids
def produce_n_batch(
    producer: Producer,
    topic: str,
    embeddings: Iterator[SubmitEmbeddingRecord],
    n: int,
) -> Tuple[Sequence[SubmitEmbeddingRecord], Sequence[SeqId]]:
    """Submit ``n`` records from ``embeddings`` to ``topic`` in one batch call.

    All ``n`` records are pulled from the iterator first and then handed to
    ``producer.submit_embeddings`` in a single call. Returns the submitted
    records and whatever sequence of SeqIds the producer returned for them.
    """
    submitted_embeddings = [next(embeddings) for _ in range(n)]
    # The batch call is made even when n == 0, so the empty-batch path of the
    # producer is exercised as well.
    seq_ids = producer.submit_embeddings(topic, submitted_embeddings)
    return submitted_embeddings, seq_ids
def produce_fn_fixtures() -> List[ProducerFn]:
    """List the producer helpers that the ``produce_fns`` fixture runs tests with."""
    single_and_batch = [produce_n_single, produce_n_batch]
    return single_and_batch
@pytest.fixture(scope="module", params=produce_fn_fixtures())
def produce_fns(
    request: pytest.FixtureRequest,
) -> Generator[ProducerFn, None, None]:
    """Parametrized fixture yielding each producer helper from produce_fn_fixtures()."""
    selected_fn = request.param
    yield selected_fn

View File

@@ -16,6 +16,7 @@ from typing import (
)
from chromadb.ingest import Producer, Consumer
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
Operation,
@@ -135,15 +136,13 @@ def assert_records_match(
async def test_backfill(
producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
embeddings = [next(sample_embeddings) for _ in range(3)]
producer.create_topic("test_topic")
for e in embeddings:
producer.submit_embedding("test_topic", e)
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 3)[0]
consume_fn = CapturingConsumeFn()
consumer.subscribe("test_topic", consume_fn, start=consumer.min_seqid())
@@ -212,6 +211,7 @@ async def test_multiple_topics(
async def test_start_seq_id(
producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
@@ -222,22 +222,16 @@ async def test_start_seq_id(
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
embeddings = []
for _ in range(5):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
results_1 = await consume_fn_1.get(5)
assert_records_match(embeddings, results_1)
start = consume_fn_1.embeddings[-1]["seq_id"]
consumer.subscribe("test_topic", consume_fn_2, start=start)
for _ in range(5):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
second_embeddings = produce_fns(producer, "test_topic", sample_embeddings, 5)[0]
assert isinstance(embeddings, list)
embeddings.extend(second_embeddings)
results_2 = await consume_fn_2.get(5)
assert_records_match(embeddings[-5:], results_2)
@@ -246,6 +240,7 @@ async def test_start_seq_id(
async def test_end_seq_id(
producer_consumer: Tuple[Producer, Consumer],
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer, consumer = producer_consumer
producer.reset_state()
@@ -256,11 +251,7 @@ async def test_end_seq_id(
consumer.subscribe("test_topic", consume_fn_1, start=consumer.min_seqid())
embeddings = []
for _ in range(10):
e = next(sample_embeddings)
embeddings.append(e)
producer.submit_embedding("test_topic", e)
embeddings = produce_fns(producer, "test_topic", sample_embeddings, 10)[0]
results_1 = await consume_fn_1.get(10)
assert_records_match(embeddings, results_1)
@@ -274,3 +265,60 @@ async def test_end_seq_id(
# Should never produce a 7th
with pytest.raises(TimeoutError):
_ = await wait_for(consume_fn_2.get(7), timeout=1)
@pytest.mark.asyncio
async def test_submit_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
) -> None:
    """A batch submitted in one call is fully delivered to a consumer that
    subscribes afterwards from the minimum SeqId."""
    topic = "test_topic"
    batch_size = 100

    producer, consumer = producer_consumer
    producer.reset_state()

    batch = [next(sample_embeddings) for _ in range(batch_size)]
    producer.create_topic(topic)
    producer.submit_embeddings(topic, embeddings=batch)

    capture = CapturingConsumeFn()
    consumer.subscribe(topic, capture, start=consumer.min_seqid())

    received = await capture.get(batch_size)
    assert_records_match(batch, received)
@pytest.mark.asyncio
async def test_multiple_topics_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Produce to many topics in rounds and check that each topic's consumer
    receives exactly the records submitted to that topic so far."""
    producer, consumer = producer_consumer
    producer.reset_state()

    N_TOPICS = 100
    consume_fns = [CapturingConsumeFn() for _ in range(N_TOPICS)]
    for topic_idx in range(N_TOPICS):
        producer.create_topic(f"test_topic_{topic_idx}")
        consumer.subscribe(
            f"test_topic_{topic_idx}",
            consume_fns[topic_idx],
            start=consumer.min_seqid(),
        )

    # Records submitted so far, per topic.
    embeddings_n: List[List[SubmitEmbeddingRecord]] = [[] for _ in range(N_TOPICS)]

    PRODUCE_BATCH_SIZE = 10
    N_TO_PRODUCE = 100
    total_produced = 0
    # The round counter is unused; the original reused ``i`` for both loops,
    # so the inner loop shadowed the outer loop variable.
    for _ in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
        for topic_idx in range(N_TOPICS):
            embeddings_n[topic_idx].extend(
                produce_fns(
                    producer,
                    f"test_topic_{topic_idx}",
                    sample_embeddings,
                    PRODUCE_BATCH_SIZE,
                )[0]
            )
            received = await consume_fns[topic_idx].get(
                total_produced + PRODUCE_BATCH_SIZE
            )
            assert_records_match(embeddings_n[topic_idx], received)
        # Advance once per round, after every topic has received this round's
        # batch, so the expected per-topic count stays in step.
        total_produced += PRODUCE_BATCH_SIZE

View File

@@ -4,6 +4,7 @@ import tempfile
import pytest
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
from chromadb.config import System, Settings
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
MetadataEmbeddingRecord,
@@ -128,16 +129,16 @@ def sync(segment: MetadataReader, seq_id: SeqId) -> None:
def test_insert_and_count(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
topic = str(segment_definition["topic"])
max_id = 0
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1]
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
@@ -166,17 +167,15 @@ def assert_equiv_records(
def test_get(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
topic = str(segment_definition["topic"])
embeddings = [next(sample_embeddings) for i in range(10)]
seq_ids = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
@@ -270,7 +269,9 @@ def test_get(
def test_fulltext(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -279,9 +280,7 @@ def test_fulltext(
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
max_id = 0
for i in range(100):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
max_id = produce_fns(producer, topic, sample_embeddings, 100)[1][-1]
sync(segment, max_id)
@@ -331,7 +330,9 @@ def test_fulltext(
def test_delete(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -340,11 +341,8 @@ def test_delete(
segment = SqliteMetadataSegment(system, segment_definition)
segment.start()
embeddings = [next(sample_embeddings) for i in range(10)]
max_id = 0
for e in embeddings:
max_id = producer.submit_embedding(topic, e)
embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
max_id = seq_ids[-1]
sync(segment, max_id)
@@ -353,16 +351,16 @@ def test_delete(
assert_equiv_records(embeddings[:1], results)
# Delete by ID
max_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
delete_embedding = SubmitEmbeddingRecord(
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
)
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
-1
]
sync(segment, max_id)
@@ -370,16 +368,9 @@ def test_delete(
assert segment.get_metadata(ids=["embedding_0"]) == []
# Delete is idempotent
max_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="embedding_0",
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
max_id = produce_fns(producer, topic, (delete_embedding for _ in range(1)), 1)[1][
-1
]
sync(segment, max_id)
assert segment.count() == 9
@@ -420,7 +411,9 @@ def test_update(
def test_upsert(
system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord]
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -439,7 +432,12 @@ def test_upsert(
encoding=None,
operation=Operation.UPSERT,
)
max_id = producer.submit_embedding(topic, update_record)
max_id = produce_fns(
producer=producer,
topic=topic,
embeddings=(update_record for _ in range(1)),
n=1,
)[1][-1]
sync(segment, max_id)
results = segment.get_metadata(ids=["no_such_id"])
assert results[0]["metadata"] == {"foo": "bar"}

View File

@@ -1,6 +1,7 @@
import pytest
from typing import Generator, List, Callable, Iterator, Type, cast
from chromadb.config import System, Settings
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
VectorQuery,
@@ -129,6 +130,7 @@ def test_insert_and_count(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
@@ -136,9 +138,9 @@ def test_insert_and_count(
segment_definition = create_random_segment_definition()
topic = str(segment_definition["topic"])
max_id = 0
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
max_id = produce_fns(
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
)[1][-1]
segment = vector_reader(system, segment_definition)
segment.start()
@@ -146,8 +148,10 @@ def test_insert_and_count(
sync(segment, max_id)
assert segment.count() == 3
for i in range(3):
max_id = producer.submit_embedding(topic, next(sample_embeddings))
max_id = produce_fns(
producer=producer, topic=topic, n=3, embeddings=sample_embeddings
)[1][-1]
sync(segment, max_id)
assert segment.count() == 6
@@ -165,6 +169,7 @@ def test_get_vectors(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -174,11 +179,9 @@ def test_get_vectors(
segment = vector_reader(system, segment_definition)
segment.start()
embeddings = [next(sample_embeddings) for i in range(10)]
seq_ids: List[SeqId] = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=10
)
sync(segment, seq_ids[-1])
@@ -210,6 +213,7 @@ def test_ann_query(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -219,11 +223,9 @@ def test_ann_query(
segment = vector_reader(system, segment_definition)
segment.start()
embeddings = [next(sample_embeddings) for i in range(100)]
seq_ids: List[SeqId] = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=100
)
sync(segment, seq_ids[-1])
@@ -275,6 +277,7 @@ def test_delete(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -284,26 +287,28 @@ def test_delete(
segment = vector_reader(system, segment_definition)
segment.start()
embeddings = [next(sample_embeddings) for i in range(5)]
seq_ids: List[SeqId] = []
for e in embeddings:
seq_ids.append(producer.submit_embedding(topic, e))
embeddings, seq_ids = produce_fns(
producer=producer, topic=topic, embeddings=sample_embeddings, n=5
)
sync(segment, seq_ids[-1])
assert segment.count() == 5
delete_record = SubmitEmbeddingRecord(
id=embeddings[0]["id"],
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
)
assert isinstance(seq_ids, List)
seq_ids.append(
producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id=embeddings[0]["id"],
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(delete_record for _ in range(1)),
)[1][0]
)
sync(segment, seq_ids[-1])
@@ -334,16 +339,12 @@ def test_delete(
# Delete is idempotent
seq_ids.append(
producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id=embeddings[0]["id"],
embedding=None,
encoding=None,
metadata=None,
operation=Operation.DELETE,
),
)
produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(delete_record for _ in range(1)),
)[1][0]
)
sync(segment, seq_ids[-1])
@@ -416,6 +417,7 @@ def test_update(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -428,16 +430,19 @@ def test_update(
_test_update(producer, topic, segment, sample_embeddings, Operation.UPDATE)
# test updating a nonexistent record
seq_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="no_such_record",
embedding=[10.0, 10.0],
encoding=ScalarEncoding.FLOAT32,
metadata=None,
operation=Operation.UPDATE,
),
update_record = SubmitEmbeddingRecord(
id="no_such_record",
embedding=[10.0, 10.0],
encoding=ScalarEncoding.FLOAT32,
metadata=None,
operation=Operation.UPDATE,
)
seq_id = produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(update_record for _ in range(1)),
)[1][0]
sync(segment, seq_id)
@@ -449,6 +454,7 @@ def test_upsert(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
vector_reader: Type[VectorReader],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
@@ -461,16 +467,19 @@ def test_upsert(
_test_update(producer, topic, segment, sample_embeddings, Operation.UPSERT)
# test updating a nonexistent record
seq_id = producer.submit_embedding(
topic,
SubmitEmbeddingRecord(
id="no_such_record",
embedding=[42, 42],
encoding=ScalarEncoding.FLOAT32,
metadata=None,
operation=Operation.UPSERT,
),
upsert_record = SubmitEmbeddingRecord(
id="no_such_record",
embedding=[42, 42],
encoding=ScalarEncoding.FLOAT32,
metadata=None,
operation=Operation.UPSERT,
)
seq_id = produce_fns(
producer=producer,
topic=topic,
n=1,
embeddings=(upsert_record for _ in range(1)),
)[1][0]
sync(segment, seq_id)