Skip to content

Commit

Permalink
Add langchain dep to ML tests. (#33607)
Browse files (browse the repository at this point in the history)
Co-authored-by: Claude <[email protected]>
  • Loading branch information
claudevdm and Claude authored Jan 16, 2025
1 parent 4797d75 commit 46699a0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
54 changes: 34 additions & 20 deletions sdks/python/apache_beam/ml/rag/chunking/langchain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

"""Tests for apache_beam.ml.rag.chunking.langchain."""

import functools
import unittest

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import is_not_empty

try:
from apache_beam.ml.rag.chunking.langchain import LangChainChunker
Expand All @@ -41,13 +43,10 @@
TRANSFORMERS_AVAILABLE = False


def chunk_equals(expected, actual):
  """Custom equality function for Chunk objects.

  Args:
    expected: The reference value; must be a ``Chunk`` to compare equal.
    actual: The value under test; must be a ``Chunk`` to compare equal.

  Returns:
    False if either argument is not a ``Chunk`` instance. Otherwise the
    result of comparing ``content``, ``index`` and ``metadata`` pairwise,
    in that order.
  """
  # Reject anything that is not a Chunk up front so the attribute accesses
  # below cannot fail on unrelated objects.
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False
  return (
      expected.content == actual.content and expected.index == actual.index and
      expected.metadata == actual.metadata)
def assert_true(elements, assert_fn, error_message_fn):
  """Check a predicate over ``elements``, failing with a descriptive message.

  Intended to be bound with ``functools.partial`` (``assert_fn`` and
  ``error_message_fn`` as keywords) so the resulting one-argument callable
  can be handed to ``assert_that`` as its matcher.

  Args:
    elements: The value handed to both callbacks.
    assert_fn: Callable taking ``elements``; a truthy result means the
      check passed.
    error_message_fn: Callable taking ``elements`` and returning the message
      for the failure. Only invoked when the check fails.

  Returns:
    True when ``assert_fn(elements)`` is truthy.

  Raises:
    BeamAssertException: When ``assert_fn(elements)`` is falsy; the message
      is ``error_message_fn(elements)``.
  """
  if not assert_fn(elements):
    raise BeamAssertException(error_message_fn(elements))
  return True


@unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.')
Expand Down Expand Up @@ -83,9 +82,15 @@ def test_no_metadata_fields(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')

assert_that(chunks, lambda x: all(c.metadata == {} for c in x))
assert_that(
chunks,
functools.partial(
assert_true,
assert_fn=lambda x: (all(c.metadata == {} for c in x)),
error_message_fn=lambda x: f"Expected empty metadata, actual {x}")
)

def test_multiple_metadata_fields(self):
"""Test chunking with multiple metadata fields."""
Expand All @@ -94,6 +99,7 @@ def test_multiple_metadata_fields(self):
document_field='content',
metadata_fields=['source', 'language'],
text_splitter=splitter)
expected_metadata = {'source': 'simple.txt', 'language': 'en'}

with TestPipeline() as p:
chunks = (
Expand All @@ -102,18 +108,20 @@ def test_multiple_metadata_fields(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(
chunks,
lambda x: all(
c.metadata == {
'source': 'simple.txt', 'language': 'en'
} for c in x))
functools.partial(
assert_true,
assert_fn=lambda x: all(
c.metadata == expected_metadata for c in x),
error_message_fn=lambda x:
f"Expected metadata {expected_metadata}, actual {x}"))

def test_recursive_splitter_no_overlap(self):
"""Test RecursiveCharacterTextSplitter with no overlap."""
splitter = RecursiveCharacterTextSplitter(
chunk_size=30, chunk_overlap=0, separators=[". "])
chunk_size=30, chunk_overlap=0, separators=[".", " "])
provider = LangChainChunker(
document_field='content',
metadata_fields=['source'],
Expand All @@ -126,8 +134,14 @@ def test_recursive_splitter_no_overlap(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks, lambda x: all(len(c.content.text) <= 30 for c in x))
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(
chunks,
functools.partial(
assert_true,
assert_fn=lambda x: all(len(c.content.text) <= 30 for c in x),
error_message_fn=lambda x: f"Expected len(chunk) <= 30, \
actual {[len(c.content.text) for c in x]}"))

@unittest.skipIf(not TRANSFORMERS_AVAILABLE, "transformers not available")
def test_huggingface_tokenizer_splitter(self):
Expand Down Expand Up @@ -155,13 +169,13 @@ def check_token_lengths(chunks):
# Verify each chunk's token length is within limits
num_tokens = len(tokenizer.encode(chunk.content.text))
if not num_tokens <= 10:
raise AssertionError(
raise BeamAssertException(
f"Chunk has {num_tokens} tokens, expected <= 10")
return True

chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(chunks, check_token_lengths)

def test_invalid_document_field(self):
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def get_portability_package_data():
'ml_test': [
'datatable',
'embeddings',
'langchain',
'onnxruntime',
'sentence-transformers',
'skl2onnx',
Expand All @@ -505,6 +506,7 @@ def get_portability_package_data():
'datatable',
'embeddings',
'onnxruntime',
'langchain',
'sentence-transformers',
'skl2onnx',
'pillow',
Expand Down

0 comments on commit 46699a0

Please sign in to comment.