[CHORE] Add support for pydantic v2 (#1174)

## Description of changes
Closes #893 

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Adds support for pydantic v2.0 by changing how Collection model inits
- this simple change fixes pydantic v2
	 - Fixes the cross version tests to handle pydantic specifically
- Conditionally imports pydantic-settings based on what is available. In
v2 BaseSettings was moved to a new package.
 - New functionality
	 - N/A

## Test plan
Existing tests were run with the following configs
1. Fastapi < 0.100, Pydantic >= 2.0 - Unsupported as the fastapi
dependencies will not allow it. They likely should, as pydantic.v1
imports would support this, but this is a downstream issue.
2. Fastapi >= 0.100, Pydantic >= 2.0, Supported via normal imports 
(Tested with fastapi==0.103.1, pydantic==2.3.0)
3. Fastapi < 0.100 Pydantic < 2.0, Supported via normal imports 
(Tested with fastapi==0.95.2, pydantic==1.9.2)
4. Fastapi >= 0.100, Pydantic < 2.0, Supported via normal imports 
(Tested with latest fastapi, pydantic==1.9.2)

- [x] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
None required.
This commit is contained in:
Hammad Bashir
2023-09-25 09:25:39 -07:00
committed by GitHub
parent c7a0414ea7
commit 8a6ad07127
6 changed files with 28 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Optional, Tuple, cast, List from typing import TYPE_CHECKING, Optional, Tuple, cast, List
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
from uuid import UUID from uuid import UUID
import chromadb.utils.embedding_functions as ef import chromadb.utils.embedding_functions as ef
@@ -50,9 +51,9 @@ class Collection(BaseModel):
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
metadata: Optional[CollectionMetadata] = None, metadata: Optional[CollectionMetadata] = None,
): ):
super().__init__(name=name, metadata=metadata, id=id)
self._client = client self._client = client
self._embedding_function = embedding_function self._embedding_function = embedding_function
super().__init__(name=name, metadata=metadata, id=id)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Collection(name={self.name})" return f"Collection(name={self.name})"

View File

@@ -5,7 +5,6 @@ from typing import cast, Dict, TypeVar, Any
import requests import requests
from overrides import override from overrides import override
from pydantic import SecretStr from pydantic import SecretStr
from chromadb.auth import ( from chromadb.auth import (
ServerAuthCredentialsProvider, ServerAuthCredentialsProvider,
AbstractCredentials, AbstractCredentials,

View File

@@ -9,9 +9,20 @@ from typing import Type, TypeVar, cast
from overrides import EnforceOverrides from overrides import EnforceOverrides
from overrides import override from overrides import override
from pydantic import BaseSettings, validator
from typing_extensions import Literal from typing_extensions import Literal
in_pydantic_v2 = False
try:
from pydantic import BaseSettings
except ImportError:
in_pydantic_v2 = True
from pydantic.v1 import BaseSettings
from pydantic.v1 import validator
if not in_pydantic_v2:
from pydantic import validator # type: ignore # noqa
# The thin client will have a flag to control which implementations to use # The thin client will have a flag to control which implementations to use
is_thin_client = False is_thin_client = False
try: try:

View File

@@ -24,6 +24,9 @@ from chromadb.config import Settings
MINIMUM_VERSION = "0.4.1" MINIMUM_VERSION = "0.4.1"
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")
# Some modules do not work across versions, since we upgrade our support for them, and should be explicitly reimported in the subprocess
VERSIONED_MODULES = ["pydantic"]
def versions() -> List[str]: def versions() -> List[str]:
"""Returns the pinned minimum version and the latest version of chromadb.""" """Returns the pinned minimum version and the latest version of chromadb."""
@@ -49,7 +52,7 @@ def _patch_boolean_metadata(
# boolean value metadata to int # boolean value metadata to int
collection_metadata = collection.metadata collection_metadata = collection.metadata
if collection_metadata is not None: if collection_metadata is not None:
_bool_to_int(collection_metadata) _bool_to_int(collection_metadata) # type: ignore
if embeddings["metadatas"] is not None: if embeddings["metadatas"] is not None:
if isinstance(embeddings["metadatas"], list): if isinstance(embeddings["metadatas"], list):
@@ -162,7 +165,10 @@ def switch_to_version(version: str) -> ModuleType:
old_modules = { old_modules = {
n: m n: m
for n, m in sys.modules.items() for n, m in sys.modules.items()
if n == module_name or (n.startswith(module_name + ".")) if n == module_name
or (n.startswith(module_name + "."))
or n in VERSIONED_MODULES
or (any(n.startswith(m + ".") for m in VERSIONED_MODULES))
} }
for n in old_modules: for n in old_modules:
del sys.modules[n] del sys.modules[n]
@@ -197,7 +203,7 @@ def persist_generated_data_with_old_version(
api.reset() api.reset()
coll = api.create_collection( coll = api.create_collection(
name=collection_strategy.name, name=collection_strategy.name,
metadata=collection_strategy.metadata, metadata=collection_strategy.metadata, # type: ignore
# In order to test old versions, we can't rely on the not_implemented function # In order to test old versions, we can't rely on the not_implemented function
embedding_function=not_implemented_ef(), embedding_function=not_implemented_ef(),
) )

View File

@@ -16,9 +16,9 @@ classifiers = [
] ]
dependencies = [ dependencies = [
'requests >= 2.28', 'requests >= 2.28',
'pydantic>=1.9,<2.0', 'pydantic >= 1.9',
'chroma-hnswlib==0.7.3', 'chroma-hnswlib==0.7.3',
'fastapi>=0.95.2, <0.100.0', 'fastapi >= 0.95.2',
'uvicorn[standard] >= 0.18.3', 'uvicorn[standard] >= 0.18.3',
'numpy == 1.21.6; python_version < "3.8"', 'numpy == 1.21.6; python_version < "3.8"',
'numpy >= 1.22.5; python_version >= "3.8"', 'numpy >= 1.22.5; python_version >= "3.8"',

View File

@@ -1,6 +1,6 @@
bcrypt==4.0.1 bcrypt==4.0.1
chroma-hnswlib==0.7.3 chroma-hnswlib==0.7.3
fastapi>=0.95.2, <0.100.0 fastapi>=0.95.2
graphlib_backport==1.0.3; python_version < '3.9' graphlib_backport==1.0.3; python_version < '3.9'
importlib-resources importlib-resources
numpy==1.21.6; python_version < '3.8' numpy==1.21.6; python_version < '3.8'
@@ -9,11 +9,11 @@ onnxruntime==1.14.1
overrides==7.3.1 overrides==7.3.1
posthog==2.4.0 posthog==2.4.0
pulsar-client==3.1.0 pulsar-client==3.1.0
pydantic>=1.9,<2.0 pydantic>=1.9
pypika==0.48.9 pypika==0.48.9
requests==2.28.1 requests==2.28.1
tokenizers==0.13.2 tokenizers==0.13.2
tqdm==4.65.0 tqdm==4.65.0
typer>=0.9.0 typer>=0.9.0
typing_extensions==4.5.0 typing_extensions>=4.5.0
uvicorn[standard]==0.18.3 uvicorn[standard]==0.18.3