mirror of
https://github.com/placeholder-soft/chroma.git
synced 2026-01-12 08:44:18 +08:00
[CHORE] Add support for pydantic v2 (#1174)
## Description of changes Closes #893 *Summarize the changes made by this PR.* - Improvements & Bug fixes - Adds support for pydantic v2.0 by changing how Collection model inits - this simple change fixes pydantic v2 - Fixes the cross version tests to handle pydantic specifically - Conditionally imports pydantic-settings based on what is available. In v2 BaseSettings was moved to a new package. - New functionality - N/A ## Test plan Existing tests were run with the following configs 1. Fastapi < 0.100, Pydantic >= 2.0 - Unsupported as the fastapi dependencies will not allow it. They likely should, as pydantic.v1 imports would support this, but this is a downstream issue. 2. Fastapi >= 0.100, Pydantic >= 2.0, Supported via normal imports ✅ (Tested with fastapi==0.103.1, pydantic==2.3.0) 3. Fastapi < 0.100 Pydantic < 2.0, Supported via normal imports ✅ (Tested with fastapi==0.95.2, pydantic==1.9.2) 4. Fastapi >= 0.100, Pydantic < 2.0, Supported via normal imports ✅ (Tested with latest fastapi, pydantic==1.9.2) - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes None required.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, cast, List
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from uuid import UUID
|
||||
import chromadb.utils.embedding_functions as ef
|
||||
|
||||
@@ -50,9 +51,9 @@ class Collection(BaseModel):
|
||||
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
|
||||
metadata: Optional[CollectionMetadata] = None,
|
||||
):
|
||||
super().__init__(name=name, metadata=metadata, id=id)
|
||||
self._client = client
|
||||
self._embedding_function = embedding_function
|
||||
super().__init__(name=name, metadata=metadata, id=id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Collection(name={self.name})"
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import cast, Dict, TypeVar, Any
|
||||
import requests
|
||||
from overrides import override
|
||||
from pydantic import SecretStr
|
||||
|
||||
from chromadb.auth import (
|
||||
ServerAuthCredentialsProvider,
|
||||
AbstractCredentials,
|
||||
|
||||
@@ -9,9 +9,20 @@ from typing import Type, TypeVar, cast
|
||||
|
||||
from overrides import EnforceOverrides
|
||||
from overrides import override
|
||||
from pydantic import BaseSettings, validator
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
in_pydantic_v2 = False
|
||||
try:
|
||||
from pydantic import BaseSettings
|
||||
except ImportError:
|
||||
in_pydantic_v2 = True
|
||||
from pydantic.v1 import BaseSettings
|
||||
from pydantic.v1 import validator
|
||||
|
||||
if not in_pydantic_v2:
|
||||
from pydantic import validator # type: ignore # noqa
|
||||
|
||||
# The thin client will have a flag to control which implementations to use
|
||||
is_thin_client = False
|
||||
try:
|
||||
|
||||
@@ -24,6 +24,9 @@ from chromadb.config import Settings
|
||||
MINIMUM_VERSION = "0.4.1"
|
||||
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")
|
||||
|
||||
# Some modules do not work across versions, since we upgrade our support for them, and should be explicitly reimported in the subprocess
|
||||
VERSIONED_MODULES = ["pydantic"]
|
||||
|
||||
|
||||
def versions() -> List[str]:
|
||||
"""Returns the pinned minimum version and the latest version of chromadb."""
|
||||
@@ -49,7 +52,7 @@ def _patch_boolean_metadata(
|
||||
# boolean value metadata to int
|
||||
collection_metadata = collection.metadata
|
||||
if collection_metadata is not None:
|
||||
_bool_to_int(collection_metadata)
|
||||
_bool_to_int(collection_metadata) # type: ignore
|
||||
|
||||
if embeddings["metadatas"] is not None:
|
||||
if isinstance(embeddings["metadatas"], list):
|
||||
@@ -162,7 +165,10 @@ def switch_to_version(version: str) -> ModuleType:
|
||||
old_modules = {
|
||||
n: m
|
||||
for n, m in sys.modules.items()
|
||||
if n == module_name or (n.startswith(module_name + "."))
|
||||
if n == module_name
|
||||
or (n.startswith(module_name + "."))
|
||||
or n in VERSIONED_MODULES
|
||||
or (any(n.startswith(m + ".") for m in VERSIONED_MODULES))
|
||||
}
|
||||
for n in old_modules:
|
||||
del sys.modules[n]
|
||||
@@ -197,7 +203,7 @@ def persist_generated_data_with_old_version(
|
||||
api.reset()
|
||||
coll = api.create_collection(
|
||||
name=collection_strategy.name,
|
||||
metadata=collection_strategy.metadata,
|
||||
metadata=collection_strategy.metadata, # type: ignore
|
||||
# In order to test old versions, we can't rely on the not_implemented function
|
||||
embedding_function=not_implemented_ef(),
|
||||
)
|
||||
|
||||
@@ -16,9 +16,9 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
'requests >= 2.28',
|
||||
'pydantic>=1.9,<2.0',
|
||||
'pydantic >= 1.9',
|
||||
'chroma-hnswlib==0.7.3',
|
||||
'fastapi>=0.95.2, <0.100.0',
|
||||
'fastapi >= 0.95.2',
|
||||
'uvicorn[standard] >= 0.18.3',
|
||||
'numpy == 1.21.6; python_version < "3.8"',
|
||||
'numpy >= 1.22.5; python_version >= "3.8"',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
bcrypt==4.0.1
|
||||
chroma-hnswlib==0.7.3
|
||||
fastapi>=0.95.2, <0.100.0
|
||||
fastapi>=0.95.2
|
||||
graphlib_backport==1.0.3; python_version < '3.9'
|
||||
importlib-resources
|
||||
numpy==1.21.6; python_version < '3.8'
|
||||
@@ -9,11 +9,11 @@ onnxruntime==1.14.1
|
||||
overrides==7.3.1
|
||||
posthog==2.4.0
|
||||
pulsar-client==3.1.0
|
||||
pydantic>=1.9,<2.0
|
||||
pydantic>=1.9
|
||||
pypika==0.48.9
|
||||
requests==2.28.1
|
||||
tokenizers==0.13.2
|
||||
tqdm==4.65.0
|
||||
typer>=0.9.0
|
||||
typing_extensions==4.5.0
|
||||
typing_extensions>=4.5.0
|
||||
uvicorn[standard]==0.18.3
|
||||
|
||||
Reference in New Issue
Block a user