Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support partial list update (append) in state manager #1134

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 80 additions & 14 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
InspectionStatuses,
)
from beanie.odm.operators.find.comparison import In
from beanie.odm.operators.update.array import Push
from beanie.odm.operators.update.general import (
CurrentDate,
Inc,
Expand Down Expand Up @@ -625,6 +626,65 @@ async def save(
**kwargs,
)

@saved_state_needed
def _build_update_operations_from_changes(
self,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Analyze document changes and build appropriate MongoDB update operations.
Separates list append operations (for $push) from regular field updates (for $set).

Returns:
Tuple[Dict[str, Any], Dict[str, Any]]: A tuple of (set_operations, push_operations)
"""
changes = self.get_changes()

set_operations = {}
push_operations = {}

# Analyze changes to determine operation type
for field_name, field_value in changes.items():
saved_value = None
field_path = field_name.split(".")
current_dict = self._saved_state

# Navigate to the correct nested field in saved state
for path_part in field_path:
if current_dict is None:
break
if (
isinstance(current_dict, dict)
and path_part in current_dict
):
saved_value = current_dict[path_part]
current_dict = saved_value
else:
saved_value = None
break

# Check if this is a list append operation
if (
isinstance(saved_value, list)
and isinstance(field_value, list)
and not self.state_management_replace_objects()
):
# If the new list is longer and contains all elements from the old list at the beginning
if len(field_value) > len(saved_value) and all(
field_value[i] == saved_value[i]
for i in range(len(saved_value))
):
# we can partially update list
appended_items = field_value[len(saved_value) :]
push_operations[field_name] = {"$each": appended_items}
else:
# This is a more complex list modification (element changed or removed)
set_operations[field_name] = field_value
else:
# Regular field update
set_operations[field_name] = field_value

return set_operations, push_operations

@saved_state_needed
@wrap_with_actions(EventTypes.SAVE_CHANGES)
@validate_self_before
Expand All @@ -638,30 +698,38 @@ async def save_changes(
"""
Save changes.
State management usage must be turned on

:param ignore_revision: bool - ignore revision id, if revision is turned on
:param bulk_writer: "BulkWriter" - Beanie bulk writer
:return: Optional[self]
"""
if not self.is_changed:
return None
changes = self.get_changes()

set_operations, push_operations = (
self._build_update_operations_from_changes()
)

update_operations: List[Union[SetOperator, Push, Unset]] = []

if set_operations:
update_operations.append(SetOperator(set_operations))

if push_operations:
update_operations.append(Push(push_operations))

if self.get_settings().keep_nulls is False:
update_operations.append(Unset(get_top_level_nones(self)))

if update_operations:
return await self.update(
SetOperator(changes),
Unset(get_top_level_nones(self)),
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
)
else:
return await self.set(
changes,
*update_operations,
ignore_revision=ignore_revision,
session=session,
bulk_writer=bulk_writer,
)

return self

@classmethod
async def replace_many(
cls: Type[DocType],
Expand Down Expand Up @@ -708,7 +776,7 @@ async def update(
:param pymongo_kwargs: pymongo native parameters for update operation
:return: self
"""
arguments: list[Any] = list(args)
arguments: List[Any] = list(args)

if skip_sync is not None:
raise DeprecationWarning(
Expand Down Expand Up @@ -1031,9 +1099,7 @@ def _collect_updates(
Args:
old_dict: dict1
new_dict: dict2

Returns: dictionary with updates

"""
updates = {}
if old_dict.keys() - new_dict.keys():
Expand Down
24 changes: 21 additions & 3 deletions docs/tutorial/state_management.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# State Management

## Why Use State Management?

- **Prevents Data Loss**: When queries update the same document simultaneously, changes to different fields won't overwrite each other
- **Improved Performance**: Only changed fields are sent to MongoDB, reducing network traffic and database load
- **Change Control**: Track modifications and roll back changes if needed

## Configuration

Beanie can keep the document state synced with the database in order to find local changes and save only them.

This feature must be explicitly turned on in the `Settings` inner class:
Expand Down Expand Up @@ -39,6 +47,18 @@ await s.save_changes()

The `save_changes()` method can only be used with already inserted documents.

## Array Operations

When using state management with arrays, only top-level append operations (`append()`) generate `$push` MongoDB operations, as they are the only array modification that is _mostly_ safe in concurrent scenarios.

For arrays containing Pydantic models, the behavior is as follows:

- Appending a new model object uses `$push` with the model's dictionary representation
- Modifying a model's fields in an existing array element replaces that entire model since we can't atomically update nested model fields
- With `state_management_replace_objects = True`, any change to a nested model replaces the entire array
- Changing elements or lists inside lists updates whole values, because update by index is unsafe

All other array modifications use standard update operations since they can create race conditions in concurrent updates.

## Interacting with changes

Expand Down Expand Up @@ -74,7 +94,6 @@ s.get_previous_changes() == {"num": 200}
s.get_changes() == {}
```


## Options

By default, state management will merge the changes made to nested objects,
Expand Down Expand Up @@ -125,7 +144,7 @@ i = Item(name="Test", attributes={"attribute_1": 1.0, "attribute_2": 2.0})
await i.insert()
i.attributes.attribute_1 = 1.0
await i.save_changes()
# Changes will consist of: {"attributes.attribute_1": 1.0, "attributes.attribute_2": 2.0}
# Changes will consist of: {"attributes": {"attribute_1": 1.0, "attribute_2": 2.0}}
# Keeping attribute_2
```

Expand All @@ -138,4 +157,3 @@ i.attributes = {"attribute_1": 1.0}
await i.save_changes()
# Changes will consist of: {"attributes": {"attribute_1": 1.0}}
# Removing attribute_2
```
11 changes: 9 additions & 2 deletions docs/tutorial/update.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ except (ValueError, beanie.exceptions.DocumentNotFound):
print("Can't replace a non existing document")
```

Note that these methods require multiple queries to the database and replace the entire document with the new version.
A more tailored solution can often be created by applying update queries directly on the database level.
> ⚠️ **Important**: Both `save()` and `replace()` methods update the **entire document** in the database, even if you only changed one field. This can lead to:
>
> - **Performance issues**: Sending and processing the entire document is less efficient than updating just changed fields
> - **Data loss**: If another process modified different fields of the same document since you retrieved it, those changes will be overwritten

To avoid these issues, two alternatives are available:

1. **State Management**: Use [State Management](state_management.md) with `save_changes()` method to track and update only modified fields
2. **Update Queries**: Use database-level update operations as described below for direct field updates

## Update queries

Expand Down
47 changes: 46 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import List, Optional, Set

import motor.motor_asyncio
import pymongo.monitoring
import pytest
from pymongo.monitoring import CommandListener

from beanie.odm.utils.pydantic import IS_PYDANTIC_V2

Expand All @@ -19,8 +23,49 @@ def settings():
return Settings()


# Command monitor to track MongoDB operations
class CommandLogger(CommandListener):
def __init__(self, command_names_to_track: Optional[Set[str]] = None):
self.commands: List[tuple] = []
self.command_names_to_track: Set[str] = command_names_to_track or {
"findAndModify",
}

def started(self, event):
if event.command_name in self.command_names_to_track:
self.commands.append((event.command_name, event.command))

def succeeded(self, event):
pass

def failed(self, event):
pass

def clear(self):
self.commands = []

def get_commands_by_name(self, command_name: str):
return [
command for command in self.commands if command[0] == command_name
]


@pytest.fixture
def command_logger():
"""
Fixture that provides a pre-configured CommandLogger for tracking MongoDB commands.
The logger tracks "findAndModify", "update", "insert" commands.

Returns:
A configured CommandLogger instance
"""
logger = CommandLogger({"findAndModify", "update", "insert"})
pymongo.monitoring.register(logger)
yield logger


@pytest.fixture()
def cli(settings):
def cli(settings, command_logger):
return motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_dsn)


Expand Down
Loading
Loading