From 126b25865a8f78326523784e86888d12ba90e4ed Mon Sep 17 00:00:00 2001 From: bushig Date: Thu, 27 Feb 2025 18:10:55 +0300 Subject: [PATCH] support partial list update (append) in state manager --- beanie/odm/documents.py | 94 ++++++- docs/tutorial/state_management.md | 24 +- docs/tutorial/update.md | 11 +- tests/conftest.py | 47 +++- tests/odm/test_list_operations.py | 432 ++++++++++++++++++++++++++++++ 5 files changed, 588 insertions(+), 20 deletions(-) create mode 100644 tests/odm/test_list_operations.py diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 9e670dad..f484eda0 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -77,6 +77,7 @@ InspectionStatuses, ) from beanie.odm.operators.find.comparison import In +from beanie.odm.operators.update.array import Push from beanie.odm.operators.update.general import ( CurrentDate, Inc, @@ -625,6 +626,65 @@ async def save( **kwargs, ) + @saved_state_needed + def _build_update_operations_from_changes( + self, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Analyze document changes and build appropriate MongoDB update operations. + Separates list append operations (for $push) from regular field updates (for $set). + + Returns: + Tuple[Dict[str, Any], Dict[str, Any]]: A tuple of (set_operations, push_operations) + """ + changes = self.get_changes() + + set_operations = {} + push_operations = {} + + # Analyze changes to determine operation type + for field_name, field_value in changes.items(): + saved_value = None + field_path = field_name.split(".") + current_dict = self._saved_state + + # Navigate to the correct nested field in saved state + for path_part in field_path: + if current_dict is None: + break + if ( + isinstance(current_dict, dict) + and path_part in current_dict + ): + saved_value = current_dict[path_part] + current_dict = saved_value + else: + saved_value = None + break + + # Check if this is a list append operation + if ( + isinstance(saved_value, list) + and isinstance(field_value, list) + and not self.state_management_replace_objects() + ): + # If the new list is longer and contains all elements from the old list at the beginning + if len(field_value) > len(saved_value) and all( + field_value[i] == saved_value[i] + for i in range(len(saved_value)) + ): + # we can partially update list + appended_items = field_value[len(saved_value) :] + push_operations[field_name] = {"$each": appended_items} + else: + # This is a more complex list modification (element changed or removed) + set_operations[field_name] = field_value + else: + # Regular field update + set_operations[field_name] = field_value + + return set_operations, push_operations + @saved_state_needed @wrap_with_actions(EventTypes.SAVE_CHANGES) @validate_self_before @@ -638,30 +698,38 @@ async def save_changes( """ Save changes. State management usage must be turned on - :param ignore_revision: bool - ignore revision id, if revision is turned on :param bulk_writer: "BulkWriter" - Beanie bulk writer :return: Optional[self] """ if not self.is_changed: return None - changes = self.get_changes() + + set_operations, push_operations = ( + self._build_update_operations_from_changes() + ) + + update_operations: List[Union[SetOperator, Push, Unset]] = [] + + if set_operations: + update_operations.append(SetOperator(set_operations)) + + if push_operations: + update_operations.append(Push(push_operations)) + if self.get_settings().keep_nulls is False: + update_operations.append(Unset(get_top_level_nones(self))) + + if update_operations: return await self.update( - SetOperator(changes), - Unset(get_top_level_nones(self)), - ignore_revision=ignore_revision, - session=session, - bulk_writer=bulk_writer, - ) - else: - return await self.set( - changes, + *update_operations, ignore_revision=ignore_revision, session=session, bulk_writer=bulk_writer, ) + return self + @classmethod async def replace_many( cls: Type[DocType], @@ -708,7 +776,7 @@ async def update( :param pymongo_kwargs: pymongo native parameters for update operation :return: self """ - arguments: list[Any] = list(args) + arguments: List[Any] = list(args) if skip_sync is not None: raise DeprecationWarning( @@ -1031,9 +1099,7 @@ def _collect_updates( Args: old_dict: dict1 new_dict: dict2 - Returns: dictionary with updates - """ updates = {} if old_dict.keys() - new_dict.keys(): diff --git a/docs/tutorial/state_management.md b/docs/tutorial/state_management.md index 658d9cae..7ab07332 100644 --- a/docs/tutorial/state_management.md +++ b/docs/tutorial/state_management.md @@ -1,5 +1,13 @@ # State Management +## Why Use State Management? + +- **Prevents Data Loss**: When queries update the same document simultaneously, changes to different fields won't overwrite each other +- **Improved Performance**: Only changed fields are sent to MongoDB, reducing network traffic and database load +- **Change Control**: Track modifications and roll back changes if needed + +## Configuration + Beanie can keep the document state synced with the database in order to find local changes and save only them. This feature must be explicitly turned on in the `Settings` inner class: @@ -39,6 +47,18 @@ await s.save_changes() The `save_changes()` method can only be used with already inserted documents. +## Array Operations + +When using state management with arrays, only top-level append operations (`append()`) generate `$push` MongoDB operations, as they are the only array modification that is _mostly_ safe in concurrent scenarios. + +For arrays containing Pydantic models, the behavior is as follows: + +- Appending a new model object uses `$push` with the model's dictionary representation +- Modifying a model's fields in an existing array element replaces that entire model since we can't atomically update nested model fields +- With `state_management_replace_objects = True`, any change to a nested model replaces the entire array +- Changing elements or lists inside lists updates whole values, because update by index is unsafe + +All other array modifications use standard update operations since they can create race conditions in concurrent updates. ## Interacting with changes @@ -74,7 +94,6 @@ s.get_previous_changes() == {"num": 200} s.get_changes() == {} ``` - ## Options By default, state management will merge the changes made to nested objects, @@ -125,7 +144,7 @@ i = Item(name="Test", attributes={"attribute_1": 1.0, "attribute_2": 2.0}) await i.insert() i.attributes.attribute_1 = 1.0 await i.save_changes() -# Changes will consist of: {"attributes.attribute_1": 1.0, "attributes.attribute_2": 2.0} +# Changes will consist of: {"attributes": {"attribute_1": 1.0, "attribute_2": 2.0}} # Keeping attribute_2 ``` @@ -138,4 +157,3 @@ i.attributes = {"attribute_1": 1.0} await i.save_changes() # Changes will consist of: {"attributes": {"attribute_1": 1.0}} # Removing attribute_2 -``` diff --git a/docs/tutorial/update.md b/docs/tutorial/update.md index 9362fa43..ab6d74ca 100644 --- a/docs/tutorial/update.md +++ b/docs/tutorial/update.md @@ -28,8 +28,15 @@ except (ValueError, beanie.exceptions.DocumentNotFound): print("Can't replace a non existing document") ``` -Note that these methods require multiple queries to the database and replace the entire document with the new version. -A more tailored solution can often be created by applying update queries directly on the database level. +> ⚠️ **Important**: Both `save()` and `replace()` methods update the **entire document** in the database, even if you only changed one field. This can lead to: +> +> - **Performance issues**: Sending and processing the entire document is less efficient than updating just changed fields +> - **Data loss**: If another process modified different fields of the same document since you retrieved it, those changes will be overwritten + +To avoid these issues, two alternatives are available: + +1. **State Management**: Use [State Management](state_management.md) with `save_changes()` method to track and update only modified fields +2. **Update Queries**: Use database-level update operations as described below for direct field updates ## Update queries diff --git a/tests/conftest.py b/tests/conftest.py index ed5a91d5..cccb720f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,9 @@ +from typing import List, Optional, Set + import motor.motor_asyncio +import pymongo.monitoring import pytest +from pymongo.monitoring import CommandListener from beanie.odm.utils.pydantic import IS_PYDANTIC_V2 @@ -19,8 +23,49 @@ def settings(): return Settings() +# Command monitor to track MongoDB operations +class CommandLogger(CommandListener): + def __init__(self, command_names_to_track: Optional[Set[str]] = None): + self.commands: List[tuple] = [] + self.command_names_to_track: Set[str] = command_names_to_track or { + "findAndModify", + } + + def started(self, event): + if event.command_name in self.command_names_to_track: + self.commands.append((event.command_name, event.command)) + + def succeeded(self, event): + pass + + def failed(self, event): + pass + + def clear(self): + self.commands = [] + + def get_commands_by_name(self, command_name: str): + return [ + command for command in self.commands if command[0] == command_name + ] + + +@pytest.fixture +def command_logger(): + """ + Fixture that provides a pre-configured CommandLogger for tracking MongoDB commands. + The logger tracks "findAndModify", "update", "insert" commands. + + Returns: + A configured CommandLogger instance + """ + logger = CommandLogger({"findAndModify", "update", "insert"}) + pymongo.monitoring.register(logger) + yield logger + + @pytest.fixture() -def cli(settings): +def cli(settings, command_logger): return motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_dsn) diff --git a/tests/odm/test_list_operations.py b/tests/odm/test_list_operations.py new file mode 100644 index 00000000..73a3ead8 --- /dev/null +++ b/tests/odm/test_list_operations.py @@ -0,0 +1,432 @@ +import asyncio +from typing import Any, Dict, List, Optional + +import pytest +from pydantic import BaseModel, Field + +from beanie import Document, init_beanie + +pytestmark = pytest.mark.asyncio + + +class Item(BaseModel): + name: str + quantity: int + + +class ComplexListOperationsModel(Document): + str_list: List[str] = Field(default_factory=list) + int_list: Optional[List[int]] = None + nested_dict_list: List[Dict[str, Any]] = Field(default_factory=list) + matrix: List[List[int]] = Field(default_factory=list) + items: List[Item] = Field(default_factory=list) + + class Settings: + name = "complex_list_operations" + use_state_management = True + + +class ReplaceObjectsDocument(Document): + int_list: List[int] = Field(default_factory=list) + + class Settings: + name = "replace_objects_document" + use_state_management = True + state_management_replace_objects = True + + +async def test_append_operation_generates_push_with_exact_values( + db, command_logger +): + """Test that appending to a list generates $push operation with exact values in MongoDB""" + await init_beanie( + database=db, document_models=[ComplexListOperationsModel] + ) + + # Create document with initial list + doc = ComplexListOperationsModel( + str_list=["a", "b"], + int_list=[1, 2], + nested_dict_list=[{"key": "value1"}, {"key": "value2"}], + matrix=[[1, 2], [3, 4]], + items=[Item(name="item1", quantity=5)], + ) + await doc.insert() + + # Clear any previous commands + command_logger.clear() + + # Append to lists + doc.str_list.append("c") + doc.int_list.append(3) + doc.nested_dict_list.append({"key": "value3"}) + doc.matrix.append([5, 6]) + doc.items.append(Item(name="item2", quantity=10)) + await doc.save_changes() + + # Check operations + has_push = False + push_values = {} + + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$push" in update: + has_push = True + push_values = update["$push"] + break + + assert has_push, "$push operation was not found in commands" + + # Verify the exact values in the push operation + assert "str_list" in push_values, "str_list not found in $push operation" + assert push_values["str_list"]["$each"] == [ + "c" + ], f"Expected push value ['c'], got {push_values['str_list']['$each']}" + + assert "int_list" in push_values, "int_list not found in $push operation" + assert push_values["int_list"]["$each"] == [ + 3 + ], f"Expected push value [3], got {push_values['int_list']['$each']}" + + # Fetch from DB and verify the values were correctly stored + updated_doc = await ComplexListOperationsModel.get(doc.id) + assert updated_doc.str_list == ["a", "b", "c"] + assert updated_doc.int_list == [1, 2, 3] + assert updated_doc.nested_dict_list == [ + {"key": "value1"}, + {"key": "value2"}, + {"key": "value3"}, + ] + assert updated_doc.matrix == [[1, 2], [3, 4], [5, 6]] + assert len(updated_doc.items) == 2 + assert updated_doc.items[0].name == "item1" + assert updated_doc.items[1].name == "item2" + + await ComplexListOperationsModel.delete_all() + + +async def test_modify_element_updates_whole_list_with_exact_values( + db, command_logger +): + """Test that modifying an element in a list updates the entire list with correct values""" + await init_beanie( + database=db, document_models=[ComplexListOperationsModel] + ) + + # Create document with initial list + doc = ComplexListOperationsModel( + str_list=["a", "b", "c"], + int_list=[1, 2, 3], + nested_dict_list=[{"key": "value1"}, {"key": "value2"}], + matrix=[[1, 2], [3, 4]], + items=[ + Item(name="item1", quantity=5), + Item(name="item2", quantity=10), + ], + ) + await doc.insert() + + # Clear any previous commands + command_logger.clear() + + # Modify elements in lists + doc.str_list[1] = "MODIFIED" + doc.int_list[0] = 100 + doc.nested_dict_list[0]["key"] = "MODIFIED_VALUE" + doc.matrix[1][0] = 99 + doc.items[0].quantity = 50 + + await doc.save_changes() + + # Check operations + has_set = False + set_values = {} + + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$set" in update: + has_set = True + set_values = update["$set"] + break + + assert has_set, "$set operation was not found in commands" + + # Verify the exact values in the set operations + assert "str_list" in set_values, "str_list not found in $set operation" + assert ( + set_values["str_list"] == ["a", "MODIFIED", "c"] + ), f"Expected set value ['a', 'MODIFIED', 'c'], got {set_values['str_list']}" + + assert "int_list" in set_values, "int_list not found in $set operation" + assert set_values["int_list"] == [ + 100, + 2, + 3, + ], f"Expected set value [100, 2, 3], got {set_values['int_list']}" + + # Fetch from DB and verify the values were correctly stored + updated_doc = await ComplexListOperationsModel.get(doc.id) + assert updated_doc.str_list == ["a", "MODIFIED", "c"] + assert updated_doc.int_list == [100, 2, 3] + assert updated_doc.nested_dict_list[0]["key"] == "MODIFIED_VALUE" + assert updated_doc.matrix[1][0] == 99 + assert updated_doc.items[0].quantity == 50 + + await ComplexListOperationsModel.delete_all() + + +async def test_multiple_list_operations_with_exact_values(db, command_logger): + """Test handling multiple list operations in a single save_changes call with verification of exact values""" + await init_beanie( + database=db, document_models=[ComplexListOperationsModel] + ) + + # Create document with initial list + doc = ComplexListOperationsModel( + str_list=["a", "b"], + int_list=[1, 2], + nested_dict_list=[{"key": "value1"}], + matrix=[[1, 2]], + items=[Item(name="item1", quantity=5)], + ) + await doc.insert() + + # Clear any previous commands + command_logger.clear() + + # Multiple operations: append and modify + doc.str_list.append("c") # append operation + doc.int_list[0] = 100 # modify operation + await doc.save_changes() + + # Check operations + has_push = False + has_set = False + push_values = {} + set_values = {} + + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$push" in update: + has_push = True + push_values = update["$push"] + if "$set" in update: + has_set = True + set_values = update["$set"] + + assert has_push, "$push operation was not found in commands" + assert has_set, "$set operation was not found in commands" + + # Verify exact values + assert "str_list" in push_values, "str_list not found in $push operation" + assert push_values["str_list"]["$each"] == [ + "c" + ], f"Expected push value ['c'], got {push_values['str_list']['$each']}" + + assert "int_list" in set_values, "int_list not found in $set operation" + assert set_values["int_list"] == [ + 100, + 2, + ], f"Expected set value [100, 2], got {set_values['int_list']}" + + # Fetch from DB and verify + updated_doc = await ComplexListOperationsModel.get(doc.id) + assert updated_doc.str_list == ["a", "b", "c"] + assert updated_doc.int_list == [100, 2] + + await ComplexListOperationsModel.delete_all() + + +async def test_nested_list_updates_use_full_update(db, command_logger): + """Test that updating lists nested in lists will always generate a full update of the outer list""" + await init_beanie( + database=db, document_models=[ComplexListOperationsModel] + ) + + # Create document with initial nested lists + doc = ComplexListOperationsModel( + str_list=[], + int_list=[], + nested_dict_list=[{"key": "value1", "nested_list": [1, 2, 3]}], + matrix=[[1, 2], [3, 4], [5, 6]], + items=[], + ) + await doc.insert() + + # Clear any previous commands + command_logger.clear() + + # Modify a nested list element + doc.matrix[1][0] = 99 # Modify element in nested list + await doc.save_changes() + + # Check operations + has_set_on_outer_list = False + set_values = {} + + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$set" in update and "matrix" in update["$set"]: + has_set_on_outer_list = True + set_values = update["$set"] + break + + assert has_set_on_outer_list, "Full update on the outer list was not found" + assert "matrix" in set_values, "Matrix not found in $set operation" + + # Verify the entire matrix was updated, not just the nested element + assert set_values["matrix"] == [ + [1, 2], + [99, 4], + [5, 6], + ], f"Expected set value for matrix, got {set_values['matrix']}" + + # Now test updating a list in a nested dictionary + command_logger.clear() # Clear previous commands + + doc.nested_dict_list[0]["nested_list"].append( + 4 + ) # Append to a list inside a dictionary + await doc.save_changes() + + # Verify that the entire outer list (nested_dict_list) was updated + has_set_on_outer_dict_list = False + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$set" in update and "nested_dict_list" in update["$set"]: + has_set_on_outer_dict_list = True + set_values = update["$set"] + break + + assert ( + has_set_on_outer_dict_list + ), "Full update on the nested_dict_list was not found" + assert set_values["nested_dict_list"][0]["nested_list"] == [ + 1, + 2, + 3, + 4, + ], "Nested list not updated correctly" + + # Fetch from DB and verify the values + updated_doc = await ComplexListOperationsModel.get(doc.id) + assert updated_doc.matrix == [[1, 2], [99, 4], [5, 6]] + assert updated_doc.nested_dict_list[0]["nested_list"] == [1, 2, 3, 4] + + await ComplexListOperationsModel.delete_all() + + +async def test_concurrent_updates_to_list(db): + """Test that concurrent updates to lists are correctly saved""" + + await init_beanie( + database=db, document_models=[ComplexListOperationsModel] + ) + + # Create initial document with empty lists + doc = ComplexListOperationsModel(str_list=[], int_list=[]) + await doc.insert() + doc_id = doc.id + + # Define concurrent update functions + async def append_to_str_list(): + doc1 = await ComplexListOperationsModel.get(doc_id) + doc1.str_list.append("concurrent1") + await doc1.save_changes() + + async def append_to_int_list(): + doc2 = await ComplexListOperationsModel.get(doc_id) + doc2.int_list.append(42) + await doc2.save_changes() + + async def update_both_lists(): + doc3 = await ComplexListOperationsModel.get(doc_id) + doc3.str_list.append("concurrent2") + doc3.int_list.append(99) + await doc3.save_changes() + + # Execute concurrent updates + await asyncio.gather( + append_to_str_list(), append_to_int_list(), update_both_lists() + ) + + # Verify final state reflects all updates + final_doc = await ComplexListOperationsModel.get(doc_id) + assert ( + "concurrent1" in final_doc.str_list + ), "First concurrent str_list update missing" + assert ( + "concurrent2" in final_doc.str_list + ), "Second concurrent str_list update missing" + assert 42 in final_doc.int_list, "First concurrent int_list update missing" + assert ( + 99 in final_doc.int_list + ), "Second concurrent int_list update missing" + + # All items should be preserved with no duplicates + assert ( + len(final_doc.str_list) == 2 + ), "str_list contains wrong number of items" + assert ( + len(final_doc.int_list) == 2 + ), "int_list contains wrong number of items" + + await ComplexListOperationsModel.delete_all() + + +async def test_state_management_replace_objects_behavior(db, command_logger): + """Test that when state_management_replace_objects is True, lists are always fully updated""" + await init_beanie(database=db, document_models=[ReplaceObjectsDocument]) + + # Create document with initial list + doc = ReplaceObjectsDocument(int_list=[1, 2, 3]) + await doc.insert() + + # Clear any previous commands + command_logger.clear() + + # Append to the list (which would normally use $push if replace_objects were False) + doc.int_list.append(4) + await doc.save_changes() + + # Check operations - should use $set even for append operations + has_push = False + has_set = False + set_values = {} + + for cmd_name, command in command_logger.get_commands_by_name( + "findAndModify" + ): + update = command["update"] + if "$push" in update: + has_push = True + if "$set" in update and "int_list" in update["$set"]: + has_set = True + set_values = update["$set"] + break + + # With replace_objects=True, we should always use $set, not $push + assert not has_push, "$push operation was found but shouldn't be used with state_management_replace_objects=True" + assert has_set, "$set operation for int_list was not found" + assert set_values["int_list"] == [ + 1, + 2, + 3, + 4, + ], f"Expected set value [1, 2, 3, 4], got {set_values['int_list']}" + + # Verify final state + final_doc = await ReplaceObjectsDocument.get(doc.id) + assert final_doc.int_list == [1, 2, 3, 4] + + await ReplaceObjectsDocument.delete_all()