diff --git a/.gitignore b/.gitignore index 391819bfa9c97..0e85fbb786409 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,22 @@ +# Build .pants.d/ dist/ migration_scripts/ -venv/ + +# IDEs .idea +.vscode +.zed + +# Local development +venv/ .venv/ .ipynb_checkpoints .__pycache__ __pycache__ dev_notebooks/ + +# Other llamaindex_registry.txt packages_to_bump_deduped.txt .env diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a65f36a59744..aa807ea83b51b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,13 +22,14 @@ repos: exclude: llama-index-core/llama_index/core/_static - id: trailing-whitespace exclude: llama-index-core/llama_index/core/_static + - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.1.5 - hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] exclude: ".*poetry.lock|.*_static" + - repo: https://github.com/psf/black-pre-commit-mirror rev: 23.10.1 hooks: @@ -36,6 +37,7 @@ repos: name: black-src alias: black exclude: "^docs|.*poetry.lock|.*_static" + - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.0.1 hooks: @@ -56,9 +58,10 @@ repos: --explicit-package-bases, --disallow-untyped-defs, --ignore-missing-imports, - --python-version=3.8, + --python-version=3.9, ] entry: bash -c "export MYPYPATH=llama_index" + - repo: https://github.com/psf/black-pre-commit-mirror rev: 23.10.1 hooks: @@ -68,6 +71,7 @@ repos: files: ^(docs/|examples/) # Using PEP 8's line length in docs prevents excess left/right scrolling args: [--line-length=79] + - repo: https://github.com/adamchainz/blacken-docs rev: 1.16.0 hooks: @@ -78,11 +82,13 @@ repos: additional_dependencies: [black==23.10.1] # Using PEP 8's line length in docs prevents excess left/right scrolling args: [--line-length=79] + - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.0.3 hooks: - id: prettier exclude: llama-index-core/llama_index/core/_static|poetry.lock|llama-index-legacy/llama_index/legacy/_static|docs/docs + - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: @@ -98,13 +104,15 @@ repos: [ "--skip=*/algolia.js", "--ignore-words-list", - "astroid,gallary,momento,narl,ot,rouge,nin,gere,asend", + "astroid,gallary,momento,narl,ot,rouge,nin,gere,asend,seperator", ] + - repo: https://github.com/srstevenson/nb-clean rev: 3.1.0 hooks: - id: nb-clean args: [--preserve-cell-outputs, --remove-empty-cells] + - repo: https://github.com/pappasam/toml-sort rev: v0.23.1 hooks: diff --git a/llama-index-core/llama_index/core/__init__.py b/llama-index-core/llama_index/core/__init__.py index 40fea381cd953..7534635d199cf 100644 --- a/llama-index-core/llama_index/core/__init__.py +++ b/llama-index-core/llama_index/core/__init__.py @@ -6,6 +6,12 @@ from logging import NullHandler from typing import Callable, Optional +try: + # Force pants to install eval_type_backport on 3.9 + import eval_type_backport # noqa # type: ignore +except ImportError: + pass + # response from llama_index.core.base.response.schema import Response @@ -28,8 +34,8 @@ GPTVectorStoreIndex, KeywordTableIndex, KnowledgeGraphIndex, - PropertyGraphIndex, ListIndex, + PropertyGraphIndex, RAKEKeywordTableIndex, SimpleKeywordTableIndex, SummaryIndex, @@ -67,6 +73,9 @@ set_global_service_context, ) +# global settings +from llama_index.core.settings import Settings + # storage from llama_index.core.storage.storage_context import StorageContext @@ -76,9 +85,6 @@ # global tokenizer from llama_index.core.utils import get_tokenizer, set_global_tokenizer -# global settings -from llama_index.core.settings import Settings - # best practices for library logging: # https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/llama-index-core/llama_index/core/bridge/pydantic.py b/llama-index-core/llama_index/core/bridge/pydantic.py index b0c4078e3f0ce..ecec5751a0cd6 100644 --- a/llama-index-core/llama_index/core/bridge/pydantic.py +++ b/llama-index-core/llama_index/core/bridge/pydantic.py @@ -1,29 +1,30 @@ import pydantic from pydantic import ( - ConfigDict, + AnyUrl, BaseModel, - GetJsonSchemaHandler, - GetCoreSchemaHandler, + BeforeValidator, + ConfigDict, Field, + GetCoreSchemaHandler, + GetJsonSchemaHandler, PlainSerializer, PrivateAttr, + Secret, + SecretStr, + SerializeAsAny, StrictFloat, StrictInt, StrictStr, - create_model, - model_validator, - field_validator, - ValidationInfo, - ValidationError, TypeAdapter, + ValidationError, + ValidationInfo, WithJsonSchema, - BeforeValidator, - SerializeAsAny, WrapSerializer, + create_model, field_serializer, - Secret, - SecretStr, + field_validator, model_serializer, + model_validator, ) from pydantic.fields import FieldInfo from pydantic.json_schema import JsonSchemaValue @@ -58,4 +59,5 @@ "Secret", "SecretStr", "model_serializer", + "AnyUrl", ] diff --git a/llama-index-core/llama_index/core/schema.py b/llama-index-core/llama_index/core/schema.py index 9b5d2b7e87052..198c985ff3779 100644 --- a/llama-index-core/llama_index/core/schema.py +++ b/llama-index-core/llama_index/core/schema.py @@ -1,5 +1,8 @@ """Base schema for data structures.""" +from __future__ import annotations + +import base64 import json import logging import pickle @@ -10,28 +13,45 @@ from enum import Enum, auto from hashlib import sha256 from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) +import filetype from dataclasses_json import DataClassJsonMixin +from typing_extensions import Self + from llama_index.core.bridge.pydantic import ( + AnyUrl, BaseModel, + ConfigDict, Field, GetJsonSchemaHandler, - SerializeAsAny, JsonSchemaValue, - ConfigDict, + PlainSerializer, + SerializeAsAny, model_serializer, + model_validator, ) from llama_index.core.bridge.pydantic_core import CoreSchema from llama_index.core.instrumentation import DispatcherSpanMixin from llama_index.core.utils import SAMPLE_TEXT, truncate_text -from typing_extensions import Self -if TYPE_CHECKING: - from haystack.schema import Document as HaystackDocument - from llama_index.core.bridge.langchain import Document as LCDocument - from semantic_kernel.memory.memory_record import MemoryRecord +if TYPE_CHECKING: # pragma: no cover + from haystack.schema import Document as HaystackDocument # type: ignore from llama_cloud.types.cloud_document import CloudDocument + from semantic_kernel.memory.memory_record import MemoryRecord # type: ignore + + from llama_index.core.bridge.langchain import Document as LCDocument # type: ignore DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" @@ -44,6 +64,10 @@ logger = logging.getLogger(__name__) +EnumNameSerializer = PlainSerializer( + lambda e: e.value, return_type="str", when_used="always" +) + class BaseComponent(BaseModel): """Base component object to capture class names.""" @@ -156,14 +180,12 @@ class TransformComponent(BaseComponent, DispatcherSpanMixin): model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod - def __call__( - self, nodes: Sequence["BaseNode"], **kwargs: Any - ) -> Sequence["BaseNode"]: + def __call__(self, nodes: Sequence[BaseNode], **kwargs: Any) -> Sequence[BaseNode]: """Transform nodes.""" async def acall( - self, nodes: Sequence["BaseNode"], **kwargs: Any - ) -> Sequence["BaseNode"]: + self, nodes: Sequence[BaseNode], **kwargs: Any + ) -> Sequence[BaseNode]: """Async transform nodes.""" return self.__call__(nodes, **kwargs) @@ -192,6 +214,14 @@ class ObjectType(str, Enum): IMAGE = auto() INDEX = auto() DOCUMENT = auto() + MULTIMODAL = auto() + + +class Modality(str, Enum): + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() class MetadataMode(str, Enum): @@ -203,7 +233,7 @@ class MetadataMode(str, Enum): class RelatedNodeInfo(BaseComponent): node_id: str - node_type: Optional[ObjectType] = None + node_type: Annotated[ObjectType, EnumNameSerializer] | str | None = None metadata: Dict[str, Any] = Field(default_factory=dict) hash: Optional[str] = None @@ -253,10 +283,24 @@ class BaseNode(BaseComponent): default_factory=list, description="Metadata keys that are excluded from text for the LLM.", ) - relationships: Dict[NodeRelationship, RelatedNodeType] = Field( + relationships: Dict[ + Annotated[NodeRelationship, EnumNameSerializer], + RelatedNodeType, + ] = Field( default_factory=dict, description="A mapping of relationships to other node information.", ) + metadata_template: str = Field( + default=DEFAULT_METADATA_TMPL, + description=( + "Template for how metadata is formatted, with {key} and " + "{value} placeholders." + ), + ) + metadata_separator: str = Field( + default="\n", + description="Separator between metadata fields when converting to string.", + ) @classmethod @abstractmethod @@ -267,9 +311,28 @@ def get_type(cls) -> str: def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: """Get object content.""" - @abstractmethod def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: - """Metadata string.""" + """Metadata info string.""" + if mode == MetadataMode.NONE: + return "" + + usable_metadata_keys = set(self.metadata.keys()) + if mode == MetadataMode.LLM: + for key in self.excluded_llm_metadata_keys: + if key in usable_metadata_keys: + usable_metadata_keys.remove(key) + elif mode == MetadataMode.EMBED: + for key in self.excluded_embed_metadata_keys: + if key in usable_metadata_keys: + usable_metadata_keys.remove(key) + + return self.metadata_separator.join( + [ + self.metadata_template.format(key=key, value=str(value)) + for key, value in self.metadata.items() + if key in usable_metadata_keys + ] + ) @abstractmethod def set_content(self, value: Any) -> None: @@ -348,7 +411,7 @@ def child_nodes(self) -> Optional[List[RelatedNodeInfo]]: return relation @property - def ref_doc_id(self) -> Optional[str]: + def ref_doc_id(self) -> Optional[str]: # pragma: no cover """Deprecated: Get ref doc id.""" source_node = self.source_node if source_node is None: @@ -356,7 +419,7 @@ def ref_doc_id(self) -> Optional[str]: return source_node.node_id @property - def extra_info(self) -> Dict[str, Any]: + def extra_info(self) -> Dict[str, Any]: # pragma: no cover """TODO: DEPRECATED: Extra info.""" return self.metadata @@ -389,7 +452,156 @@ def as_related_node_info(self) -> RelatedNodeInfo: ) +EmbeddingKind = Literal["sparse", "dense"] + + +class MediaResource(BaseModel): + """A container class for media content. + + This class represents a generic media resource that can be stored and accessed + in multiple ways - as raw bytes, on the filesystem, or via URL. It also supports + storing vector embeddings for the media content. + + Attributes: + embeddings: Multi-vector dict representation of this resource for embedding-based search/retrieval + text: Plain text representation of this resource + data: Raw binary data of the media content + mimetype: The MIME type indicating the format/type of the media content + path: Local filesystem path where the media content can be accessed + url: URL where the media content can be accessed remotely + """ + + embeddings: dict[EmbeddingKind, list[float]] | None = Field( + default=None, description="Vector representation of this resource." + ) + data: bytes | None = Field( + default=None, + exclude=True, + description="base64 binary representation of this resource.", + ) + text: str | None = Field( + default=None, description="Text representation of this resource." + ) + mimetype: str | None = Field( + default=None, description="MIME type of this resource." + ) + path: Path | None = Field( + default=None, description="Filesystem path of this resource." + ) + url: AnyUrl | None = Field(default=None, description="URL to reach this resource.") + + @model_validator(mode="after") + def guess_mimetype(self) -> Self: + """Guess the mimetype when possible. + + In case the model was built passing its content but without a mimetype, + we try to guess it using the filetype library. To avoid resource-intense + operations, we won't load the path or the URL to guess the mimetype. + """ + if not self.data or self.mimetype: + return self + + try: + decoded_data = base64.b64decode(self.data) + guess = filetype.guess(decoded_data) + self.mimetype = guess.mime if guess else None + except Exception as e: + logging.debug("Data is not base64 encoded, cannot guess mimetype") + finally: + return self + + @property + def hash(self) -> str: + """Generate a hash to uniquely identify the media resource. + + The hash is generated based on the available content (data, path, text or url). + Returns an empty string if no content is available. + """ + bits: list[str] = [] + if self.text is not None: + bits.append(self.text) + if self.data is not None: + # Hash the binary data if available + bits.append(str(sha256(self.data).hexdigest())) + if self.path is not None: + # Hash the file path if provided + bits.append(str(sha256(str(self.path).encode("utf-8")).hexdigest())) + if self.url is not None: + # Use the URL string as basis for hash + bits.append(str(sha256(str(self.url).encode("utf-8")).hexdigest())) + + doc_identity = "".join(bits) + if not doc_identity: + return "" + return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) + + +class Node(BaseNode): + text: MediaResource | None = Field( + default=None, description="Text content of the node." + ) + image: MediaResource | None = Field( + default=None, description="Image content of the node." + ) + audio: MediaResource | None = Field( + default=None, description="Audio content of the node." + ) + video: MediaResource | None = Field( + default=None, description="Video content of the node." + ) + + @classmethod + def class_name(cls) -> str: + return "Node" + + @classmethod + def get_type(cls) -> str: + """Get Object type.""" + return ObjectType.MULTIMODAL + + def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + """Get the text content for the node if available. + + Provided for backward compatibility, use self.text directly instead. + """ + if self.text: + return self.text.text or "" + return "" + + def set_content(self, value: str) -> None: + """Set the text content of the node. + + Provided for backward compatibility, set self.text instead. + """ + self.text = MediaResource(text=value) + + @property + def hash(self) -> str: + doc_identities = [] + if self.audio is not None: + doc_identities.append(self.audio.hash) + if self.image is not None: + doc_identities.append(self.image.hash) + if self.text is not None: + doc_identities.append(self.text.hash) + if self.video is not None: + doc_identities.append(self.video.hash) + + doc_identity = "-".join(doc_identities) + return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) + + class TextNode(BaseNode): + """Provided for backward compatibility. + + Note: we keep the field with the typo "seperator" to maintain backward compatibility for + serialized objects. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """This is needed to help static checkers with inherited fields.""" + super().__init__(*args, **kwargs) + text: str = Field(default="", description="Text content of the node.") mimetype: str = Field( default="text/plain", description="MIME type of the node content." @@ -400,6 +612,10 @@ class TextNode(BaseNode): end_char_idx: Optional[int] = Field( default=None, description="End char index of the node." ) + metadata_seperator: str = Field( + default="\n", + description="Separator between metadata fields when converting to string.", + ) text_template: str = Field( default=DEFAULT_TEXT_NODE_TMPL, description=( @@ -407,17 +623,6 @@ class TextNode(BaseNode): "{metadata_str} placeholders." ), ) - metadata_template: str = Field( - default=DEFAULT_METADATA_TMPL, - description=( - "Template for how metadata is formatted, with {key} and " - "{value} placeholders." - ), - ) - metadata_seperator: str = Field( - default="\n", - description="Separator between metadata fields when converting to string.", - ) @classmethod def class_name(cls) -> str: @@ -483,10 +688,6 @@ def node_info(self) -> Dict[str, Any]: return self.get_node_info() -# TODO: legacy backport of old Node class -Node = TextNode - - class ImageNode(TextNode): """Node with image.""" @@ -575,7 +776,7 @@ def from_text_node( cls, node: TextNode, index_id: str, - ) -> "IndexNode": + ) -> IndexNode: """Create index node from text node.""" # copy all attributes from text node, add index id return cls( @@ -588,7 +789,7 @@ def from_text_node( def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore output = super().from_dict(data, **kwargs) - obj = data.get("obj", None) + obj = data.get("obj") parsed_obj = None if isinstance(obj, str): @@ -721,7 +922,7 @@ def __setattr__(self, name: str, value: object) -> None: name = self._compat_fields[name] super().__setattr__(name, value) - def to_langchain_format(self) -> "LCDocument": + def to_langchain_format(self) -> LCDocument: """Convert struct to LangChain document format.""" from llama_index.core.bridge.langchain import Document as LCDocument @@ -729,13 +930,13 @@ def to_langchain_format(self) -> "LCDocument": return LCDocument(page_content=self.text, metadata=metadata, id=self.id_) @classmethod - def from_langchain_format(cls, doc: "LCDocument") -> "Document": + def from_langchain_format(cls, doc: LCDocument) -> Document: """Convert struct from LangChain document format.""" if doc.id: return cls(text=doc.page_content, metadata=doc.metadata, id_=doc.id) return cls(text=doc.page_content, metadata=doc.metadata) - def to_haystack_format(self) -> "HaystackDocument": + def to_haystack_format(self) -> HaystackDocument: """Convert struct to Haystack document format.""" from haystack.schema import Document as HaystackDocument @@ -744,7 +945,7 @@ def to_haystack_format(self) -> "HaystackDocument": ) @classmethod - def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": + def from_haystack_format(cls, doc: HaystackDocument) -> Document: """Convert struct from Haystack document format.""" return cls( text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id @@ -758,7 +959,7 @@ def to_embedchain_format(self) -> Dict[str, Any]: } @classmethod - def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document": + def from_embedchain_format(cls, doc: Dict[str, Any]) -> Document: """Convert struct from EmbedChain document format.""" return cls( text=doc["data"]["content"], @@ -766,7 +967,7 @@ def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document": id_=doc["doc_id"], ) - def to_semantic_kernel_format(self) -> "MemoryRecord": + def to_semantic_kernel_format(self) -> MemoryRecord: """Convert struct to Semantic Kernel document format.""" import numpy as np from semantic_kernel.memory.memory_record import MemoryRecord @@ -779,7 +980,7 @@ def to_semantic_kernel_format(self) -> "MemoryRecord": ) @classmethod - def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document": + def from_semantic_kernel_format(cls, doc: MemoryRecord) -> Document: """Convert struct from Semantic Kernel document format.""" return cls( text=doc._text, @@ -799,7 +1000,7 @@ def to_vectorflow(self, client: Any) -> None: client.embed(f.name) @classmethod - def example(cls) -> "Document": + def example(cls) -> Document: return Document( text=SAMPLE_TEXT, metadata={"filename": "README.md", "category": "codebase"}, @@ -809,7 +1010,7 @@ def example(cls) -> "Document": def class_name(cls) -> str: return "Document" - def to_cloud_document(self) -> "CloudDocument": + def to_cloud_document(self) -> CloudDocument: """Convert to LlamaCloud document type.""" from llama_cloud.types.cloud_document import CloudDocument @@ -824,8 +1025,8 @@ def to_cloud_document(self) -> "CloudDocument": @classmethod def from_cloud_document( cls, - doc: "CloudDocument", - ) -> "Document": + doc: CloudDocument, + ) -> Document: """Convert from LlamaCloud document type.""" return Document( text=doc.text, diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml index 654c2a3fb39e4..1e3f55d800fef 100644 --- a/llama-index-core/pyproject.toml +++ b/llama-index-core/pyproject.toml @@ -227,6 +227,9 @@ unfixable = [ [tool.ruff.flake8-annotations] mypy-init-return = true +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = true + [tool.ruff.pydocstyle] convention = "google" diff --git a/llama-index-core/tests/BUILD b/llama-index-core/tests/BUILD index 3718580df954d..1d1d2df6dda7a 100644 --- a/llama-index-core/tests/BUILD +++ b/llama-index-core/tests/BUILD @@ -8,6 +8,7 @@ python_test_utils( "llama-index-core/tests/mock_utils/mock_text_splitter.py", "llama-index-core/tests/mock_utils/mock_prompts.py", "llama-index-core/tests/mock_utils/mock_utils.py", + "llama-index-core:poetry#eval-type-backport" ], ) diff --git a/llama-index-core/tests/schema/BUILD b/llama-index-core/tests/schema/BUILD new file mode 100644 index 0000000000000..57341b1358b56 --- /dev/null +++ b/llama-index-core/tests/schema/BUILD @@ -0,0 +1,3 @@ +python_tests( + name="tests", +) diff --git a/llama-index-core/tests/schema/__init__.py b/llama-index-core/tests/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-core/tests/schema/data/data.txt b/llama-index-core/tests/schema/data/data.txt new file mode 100644 index 0000000000000..67b808feb3620 --- /dev/null +++ b/llama-index-core/tests/schema/data/data.txt @@ -0,0 +1 @@ +Test data diff --git a/llama-index-core/tests/schema/test_base_component.py b/llama-index-core/tests/schema/test_base_component.py new file mode 100644 index 0000000000000..19d8b3a2cdef3 --- /dev/null +++ b/llama-index-core/tests/schema/test_base_component.py @@ -0,0 +1,63 @@ +from typing import Callable + +import pytest +from llama_index.core.schema import BaseComponent +from pydantic.fields import PrivateAttr + + +@pytest.fixture() +def my_component(): + class MyComponent(BaseComponent): + foo: str = "bar" + + return MyComponent + + +def test_identifiers(): + assert BaseComponent.class_name() == "base_component" + + +def test_schema(): + assert ( + BaseComponent.schema_json() + == '{"description": "Base component object to capture class names.", "properties": {"class_name": {"default": "base_component", "title": "Class Name", "type": "string"}}, "title": "BaseComponent", "type": "object"}' + ) + + +def test_json(): + assert BaseComponent().json() == '{"class_name": "base_component"}' + + +def test__getstate__(): + class MyComponent(BaseComponent): + _text: str = PrivateAttr(default="test private attr") + _fn: Callable = PrivateAttr(default=lambda x: x) + + mc = MyComponent() + # add an unpickable field + mc._unpickable = lambda x: x # type: ignore + assert mc.__getstate__() == { + "__dict__": {}, + "__pydantic_extra__": None, + "__pydantic_fields_set__": set(), + "__pydantic_private__": {"_text": "test private attr"}, + } + + +def test__setstate__(): + c = BaseComponent() + c.__setstate__({}) + + +def test_from_dict(my_component): + mc = my_component.from_dict( + {"class_name": "to_be_popped_out", "foo": "test string"} + ) + assert mc.foo == "test string" + + +def test_from_json(my_component): + mc = my_component.from_json( + '{"class_name": "to_be_popped_out", "foo": "test string"}' + ) + assert mc.foo == "test string" diff --git a/llama-index-core/tests/schema/test_base_node.py b/llama-index-core/tests/schema/test_base_node.py new file mode 100644 index 0000000000000..d1d2e1697d6f2 --- /dev/null +++ b/llama-index-core/tests/schema/test_base_node.py @@ -0,0 +1,165 @@ +from typing import Any + +import pytest +from llama_index.core.schema import ( + BaseNode, + MetadataMode, + NodeRelationship, + ObjectType, + RelatedNodeInfo, +) + + +@pytest.fixture() +def MyNode(): + class MyNode(BaseNode): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def get_type(cls): + return ObjectType.TEXT + + def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + return "Test content" + + def set_content(self, value: Any) -> None: + return super().set_content(value) + + @property + def hash(self) -> str: + return super().hash + + return MyNode + + +def test_get_metadata_str(MyNode): + metadata = { + "key": "value", + "forbidden": "true", + } + excluded = ["forbidden"] + node = MyNode( + metadata=metadata, + excluded_llm_metadata_keys=excluded, + excluded_embed_metadata_keys=excluded, + ) + assert node.get_metadata_str(MetadataMode.NONE) == "" + assert node.get_metadata_str(MetadataMode.LLM) == "key: value" + assert node.get_metadata_str(MetadataMode.EMBED) == "key: value" + + +def test_node_id(MyNode): + n = MyNode() + n.node_id = "this" + assert n.node_id == "this" + + +def test_source_node(MyNode): + n1 = MyNode() + n2 = MyNode( + relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id=n1.node_id)} + ) + assert n2.source_node.hash == n1.hash + assert n1.source_node is None + + with pytest.raises( + ValueError, match="Source object must be a single RelatedNodeInfo object" + ): + n3 = MyNode( + relationships={ + NodeRelationship.SOURCE: [RelatedNodeInfo(node_id=n1.node_id)] + } + ) + n3.source_node + + +def test_prev_node(MyNode): + n1 = MyNode() + n2 = MyNode( + relationships={NodeRelationship.PREVIOUS: RelatedNodeInfo(node_id=n1.node_id)} + ) + assert n2.prev_node.hash == n1.hash + assert n1.prev_node is None + + with pytest.raises( + ValueError, match="Previous object must be a single RelatedNodeInfo object" + ): + n3 = MyNode( + relationships={ + NodeRelationship.PREVIOUS: [RelatedNodeInfo(node_id=n1.node_id)] + } + ) + n3.prev_node + + +def test_next_node(MyNode): + n1 = MyNode() + n2 = MyNode( + relationships={NodeRelationship.NEXT: RelatedNodeInfo(node_id=n1.node_id)} + ) + assert n2.next_node.hash == n1.hash + assert n1.next_node is None + + with pytest.raises( + ValueError, match="Next object must be a single RelatedNodeInfo object" + ): + n3 = MyNode( + relationships={NodeRelationship.NEXT: [RelatedNodeInfo(node_id=n1.node_id)]} + ) + n3.next_node + + +def test_parent_node(MyNode): + n1 = MyNode() + n2 = MyNode( + relationships={NodeRelationship.PARENT: RelatedNodeInfo(node_id=n1.node_id)} + ) + assert n2.parent_node.hash == n1.hash + assert n1.parent_node is None + + with pytest.raises( + ValueError, match="Parent object must be a single RelatedNodeInfo object" + ): + n3 = MyNode( + relationships={ + NodeRelationship.PARENT: [RelatedNodeInfo(node_id=n1.node_id)] + } + ) + n3.parent_node + + +def test_child_node(MyNode): + n1 = MyNode() + n2 = MyNode( + relationships={NodeRelationship.CHILD: [RelatedNodeInfo(node_id=n1.node_id)]} + ) + assert n2.child_nodes[0].hash == n1.hash + assert n1.child_nodes is None + + with pytest.raises( + ValueError, match="Child objects must be a list of RelatedNodeInfo objects" + ): + n3 = MyNode( + relationships={NodeRelationship.CHILD: RelatedNodeInfo(node_id=n1.node_id)} + ) + n3.child_nodes + + +def test___str__(MyNode): + n = MyNode() + n.node_id = "test_node" + assert str(n) == "Node ID: test_node\nText: Test content" + + +def test_get_embedding(MyNode): + n = MyNode() + with pytest.raises(ValueError, match="embedding not set."): + n.get_embedding() + n.embedding = [0.0, 0.0] + assert n.get_embedding() == [0.0, 0.0] + + +def test_as_related_node_info(MyNode): + n = MyNode(id_="test_node") + assert n.as_related_node_info().node_id == "test_node" diff --git a/llama-index-core/tests/schema/test_media_resource.py b/llama-index-core/tests/schema/test_media_resource.py new file mode 100644 index 0000000000000..beb0594fec9cd --- /dev/null +++ b/llama-index-core/tests/schema/test_media_resource.py @@ -0,0 +1,30 @@ +from llama_index.core.bridge.pydantic import AnyUrl +from llama_index.core.schema import MediaResource + + +def test_defaults(): + m = MediaResource() + assert m.data is None + assert m.embeddings is None + assert m.mimetype is None + assert m.path is None + assert m.url is None + + +def test_mimetype(): + png_1px = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + m = MediaResource(data=png_1px.encode("utf-8"), mimetype=None) + assert m.mimetype == "image/png" + + +def test_hash(): + assert ( + MediaResource( + data=b"test bytes", + path="foo/bar/baz", + url=AnyUrl("http://example.com"), + text="some text", + ).hash + == "7ac964db7843a9ffb37cda7b5b9822b0f84111d6a271b4991dd26d1fc68490d3" + ) + assert MediaResource().hash == "" diff --git a/llama-index-core/tests/schema/test_node.py b/llama-index-core/tests/schema/test_node.py new file mode 100644 index 0000000000000..7c34f04c1c14c --- /dev/null +++ b/llama-index-core/tests/schema/test_node.py @@ -0,0 +1,21 @@ +from llama_index.core.schema import MediaResource, Node, ObjectType + + +def test_identifiers(): + assert Node.class_name() == "Node" + assert Node.get_type() == ObjectType.MULTIMODAL + + +def test_get_content(): + assert Node().get_content() == "" + + +def test_hash(): + node = Node() + node.audio = MediaResource(data=b"test audio", mimetype="audio/aac") + node.image = MediaResource(data=b"test image", mimetype="image/png") + node.text = MediaResource(text="some text", mimetype="text/plain") + node.video = MediaResource(data=b"some video", mimetype="video/mpeg") + assert ( + node.hash == "ee411edd3dffb27470eef165ccf4df9fabaa02e7c7c39415950d3ac4d7e35e61" + ) diff --git a/llama-index-core/tests/test_schema.py b/llama-index-core/tests/schema/test_schema.py similarity index 92% rename from llama-index-core/tests/test_schema.py rename to llama-index-core/tests/schema/test_schema.py index a32c3a2e42561..eb556494ea57b 100644 --- a/llama-index-core/tests/test_schema.py +++ b/llama-index-core/tests/schema/test_schema.py @@ -1,5 +1,5 @@ import pytest -from llama_index.core.schema import NodeWithScore, TextNode, ImageNode +from llama_index.core.schema import ImageNode, Node, NodeWithScore, TextNode @pytest.fixture() @@ -60,3 +60,7 @@ def test_image_node_hash() -> None: node3 = ImageNode(image_url="base64", id_="id") node4 = ImageNode(image_url="base64", id_="id2") assert node3.hash == node4.hash + + +def test_node() -> None: + node = Node(id_="test_node") diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py index be6e3b75d487a..64b9f6fe488a6 100644 --- a/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py +++ b/llama-index-integrations/node_parser/llama-index-node-parser-docling/tests/test_node_parser_docling.py @@ -1,9 +1,8 @@ import json +from llama_index.core.schema import BaseNode from llama_index.core.schema import Document as LIDocument - from llama_index.node_parser.docling import DoclingNodeParser -from llama_index.core.schema import BaseNode in_json_str = json.dumps( { @@ -20,6 +19,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "Document", } ) @@ -99,6 +99,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "TextNode", }, { @@ -174,6 +175,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "TextNode", }, ] @@ -231,6 +233,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "TextNode", }, { @@ -282,6 +285,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "TextNode", }, ] diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/tests/test_postprocessor_colpali_rerank.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/tests/test_postprocessor_colpali_rerank.py index bf1783ae617e8..472ac17f6a237 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/tests/test_postprocessor_colpali_rerank.py +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-colpali-rerank/tests/test_postprocessor_colpali_rerank.py @@ -1,14 +1,13 @@ import os import tempfile -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import numpy as np import torch -from PIL import Image - from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.core.schema import Node, NodeWithScore, QueryBundle +from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from llama_index.postprocessor.colpali_rerank import ColPaliRerank +from PIL import Image def test_class(): @@ -45,10 +44,10 @@ def test_postprocess(mock_colpali): # Create mock nodes node1 = NodeWithScore( - node=Node(text="test1", metadata={"file_path": image1_path}), score=0.8 + node=TextNode(text="test1", metadata={"file_path": image1_path}), score=0.8 ) node2 = NodeWithScore( - node=Node(text="test2", metadata={"file_path": image2_path}), score=0.6 + node=TextNode(text="test2", metadata={"file_path": image2_path}), score=0.6 ) nodes = [node1, node2] diff --git a/llama-index-integrations/readers/llama-index-readers-docling/tests/test_readers_docling.py b/llama-index-integrations/readers/llama-index-readers-docling/tests/test_readers_docling.py index 8a8e4a3e51375..762d6e297686d 100644 --- a/llama-index-integrations/readers/llama-index-readers-docling/tests/test_readers_docling.py +++ b/llama-index-integrations/readers/llama-index-readers-docling/tests/test_readers_docling.py @@ -1,8 +1,9 @@ import json from pathlib import Path from unittest.mock import MagicMock -from llama_index.readers.docling.base import DoclingReader + from docling_core.types import DoclingDocument as DLDocument +from llama_index.readers.docling.base import DoclingReader in_json_str = json.dumps( { @@ -71,6 +72,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "Document", } ] @@ -92,6 +94,7 @@ "text_template": "{metadata_str}\n\n{content}", "metadata_template": "{key}: {value}", "metadata_seperator": "\n", + "metadata_separator": "\n", "class_name": "Document", } ]