Skip to content

Commit

Permalink
Load patches from store
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivier Michaud committed Oct 17, 2024
1 parent 7c8fbc0 commit c60bec8
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 12 deletions.
7 changes: 7 additions & 0 deletions src/saturn_engine/stores/topologies_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ def patch(*, session: AnySession, patch: BaseObject) -> TopologyPatch:
)
session.execute(stmt) # type: ignore
return topology_patch


def get_patches(*, session: AnySession) -> list[TopologyPatch]:
"""_summary_
Return all the patches
"""
return session.query(TopologyPatch).all()
14 changes: 14 additions & 0 deletions src/saturn_engine/utils/dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import typing as t


def deep_merge(a: dict[str, t.Any], b: dict[str, t.Any]) -> dict[str, t.Any]:
"""
Merge b into a
"""
result = a.copy()
for key, value in b.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = deep_merge(result[key], value)
else:
result[key] = value
return result
37 changes: 35 additions & 2 deletions src/saturn_engine/worker_manager/config/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
from collections import defaultdict

from saturn_engine.models.topology_patches import TopologyPatch
from saturn_engine.utils import dict as dict_utils
from saturn_engine.utils.declarative_config import UncompiledObject
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_path
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str
Expand All @@ -31,8 +33,15 @@

def compile_static_definitions(
uncompiled_objects: list[UncompiledObject],
patches: list[TopologyPatch] | None = None,
) -> StaticDefinitions:
objects_by_kind: DefaultDict[str, dict[str, UncompiledObject]] = defaultdict(dict)

if patches:
uncompiled_objects = merge_with_patches(
uncompiled_objects=uncompiled_objects, patches=patches
)

for uncompiled_object in uncompiled_objects:
if uncompiled_object.name in objects_by_kind[uncompiled_object.kind]:
raise Exception(
Expand Down Expand Up @@ -122,12 +131,14 @@ def load_definitions_from_str(definitions: str) -> StaticDefinitions:
return compile_static_definitions(load_uncompiled_objects_from_str(definitions))


def load_definitions_from_paths(config_dirs: list[str]) -> StaticDefinitions:
def load_definitions_from_paths(
config_dirs: list[str], patches: list[TopologyPatch] | None = None
) -> StaticDefinitions:
uncompiled_objects = []
for config_dir in config_dirs:
uncompiled_objects.extend(load_uncompiled_objects_from_path(config_dir))

return compile_static_definitions(uncompiled_objects)
return compile_static_definitions(uncompiled_objects, patches=patches)


def filter_with_jobs_selector(
Expand All @@ -141,3 +152,25 @@ def filter_with_jobs_selector(
if pattern.search(name)
}
return dataclasses.replace(definitions, jobs=jobs, job_definitions=job_definitions)


def merge_with_patches(
uncompiled_objects: list[UncompiledObject], patches: list[TopologyPatch]
) -> list[UncompiledObject]:
uncompiled_object_by_kind_and_name = {
(u.kind, u.name): u for u in uncompiled_objects
}
for patch in patches:
uncompiled_object = uncompiled_object_by_kind_and_name.get(
(patch.kind, patch.name)
)
if not uncompiled_object:
logging.warning(
f"Can't find an uncompiled objects to use with patch {patch=}"
)
continue

uncompiled_object.data = dict_utils.deep_merge(
a=uncompiled_object.data, b=patch.data
)
return uncompiled_objects
7 changes: 6 additions & 1 deletion src/saturn_engine/worker_manager/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from saturn_engine.config import WorkerManagerConfig
from saturn_engine.stores import topologies_store
from saturn_engine.utils.sqlalchemy import AnySession
from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector
from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths
Expand Down Expand Up @@ -38,7 +39,11 @@ def _load_static_definition(
- Jobs
- JobDefinitions
"""
definitions = load_definitions_from_paths(config.static_definitions_directories)
patches = topologies_store.get_patches(session=session)
definitions = load_definitions_from_paths(
config.static_definitions_directories, patches=patches
)

if config.static_definitions_jobs_selector:
definitions = filter_with_jobs_selector(
definitions=definitions,
Expand Down
3 changes: 0 additions & 3 deletions src/saturn_engine/worker_manager/services/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def lock_jobs(
session: AnySyncSession,
) -> LockResponse:
logger = logging.getLogger(f"{__name__}.lock_jobs")

# Note:
# - Leftover items remain unassigned.
assignation_expiration_cutoff: datetime = datetime.now() - timedelta(minutes=15)
Expand Down Expand Up @@ -61,7 +60,6 @@ def lock_jobs(
selector=lock_input.selector,
)
)

# Join definitions and filtered out by executors
for item in assigned_items.copy():
try:
Expand Down Expand Up @@ -129,7 +127,6 @@ def lock_jobs(
continue

executors.setdefault(executor.name, executor)

# Refresh assignments
new_assigned_at = datetime.now()
for assigned_item in assigned_items:
Expand Down
152 changes: 152 additions & 0 deletions tests/worker_manager/api/test_topologies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from unittest import mock

from flask.testing import FlaskClient
from sqlalchemy.orm import Session

from saturn_engine.worker_manager.app import SaturnApp


def test_put_topology_patch(client: FlaskClient) -> None:
Expand Down Expand Up @@ -36,3 +41,150 @@ def test_put_topology_patch(client: FlaskClient) -> None:
"metadata": {"name": "test-topic", "labels": {}},
"spec": {"type": "RabbitMQTopic", "options": {"queue_name": "queue_2"}},
}


def test_put_topology_patch_ensure_topology_changed(
tmp_path: str, app: SaturnApp, client: FlaskClient, session: Session
) -> None:
topology = """
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnExecutor
metadata:
name: default
spec:
type: ARQExecutor
options:
redis_url: "redis://redis"
queue_name: "arq:saturn-default"
redis_pool_args:
max_connections: 10000
concurrency: 108
---
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnInventory
metadata:
name: test-inventory
spec:
type: testtype
---
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnJobDefinition
metadata:
name: job_1
labels:
owner: team-saturn
spec:
minimalInterval: "@weekly"
template:
input:
inventory: test-inventory
pipeline:
name: something.saturn.pipelines.aa.bb
---
"""
with open(f"{tmp_path}/topology.yaml", "+w") as f:
f.write(topology)

app.saturn.config.static_definitions_directories = [tmp_path]
app.saturn.load_static_definition(session=session)

resp = client.post("/api/jobs/sync")
assert resp.status_code == 200
assert resp.json == {}
resp = client.post("/api/lock", json={"worker_id": "worker-1"})
assert resp.json == {
"executors": [
{
"name": "default",
"options": {
"concurrency": 108,
"queue_name": "arq:saturn-default",
"redis_pool_args": {"max_connections": 10000},
"redis_url": "redis://redis",
},
"type": "ARQExecutor",
}
],
"items": [
{
"config": {},
"executor": "default",
"input": {"name": "test-inventory", "options": {}, "type": "testtype"},
"labels": {"owner": "team-saturn"},
"name": mock.ANY,
"output": {},
"pipeline": {
"args": {},
"info": {
"name": "something.saturn.pipelines.aa.bb",
"resources": {},
},
},
"state": {
"cursor": None,
"started_at": mock.ANY,
},
}
],
"resources": [],
"resources_providers": [],
}

# Let's change the pipeline name
resp = client.put(
"/api/topologies/patch",
json={
"apiVersion": "saturn.flared.io/v1alpha1",
"kind": "SaturnJobDefinition",
"metadata": {"name": "job_1"},
"spec": {
"template": {
"pipeline": {"name": "something.else.saturn.pipelines.aa.bb"},
},
},
},
)

# And reset the static definition
session.commit()
app.saturn.load_static_definition(session=session)

# Make sure we have the new topology version
resp = client.post("/api/lock", json={"worker_id": "worker-1"})
assert resp.json == {
"executors": [
{
"name": "default",
"options": {
"concurrency": 108,
"queue_name": "arq:saturn-default",
"redis_pool_args": {"max_connections": 10000},
"redis_url": "redis://redis",
},
"type": "ARQExecutor",
}
],
"items": [
{
"config": {},
"executor": "default",
"input": {"name": "test-inventory", "options": {}, "type": "testtype"},
"labels": {"owner": "team-saturn"},
"name": mock.ANY,
"output": {},
"pipeline": {
"args": {},
"info": {
"name": "something.else.saturn.pipelines.aa.bb",
"resources": {},
},
},
"state": {
"cursor": None,
"started_at": mock.ANY,
},
}
],
"resources": [],
"resources_providers": [],
}
74 changes: 74 additions & 0 deletions tests/worker_manager/config/test_declarative.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os

import pytest
from flask.testing import FlaskClient
from sqlalchemy.orm import Session

from saturn_engine.core.api import ComponentDefinition
from saturn_engine.core.api import JobDefinition
from saturn_engine.core.api import ResourceItem
from saturn_engine.stores import topologies_store
from saturn_engine.utils.declarative_config import BaseObject
from saturn_engine.utils.declarative_config import ObjectMetadata
from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_str
from saturn_engine.worker_manager.config.declarative import compile_static_definitions
from saturn_engine.worker_manager.config.declarative import filter_with_jobs_selector
from saturn_engine.worker_manager.config.declarative import load_definitions_from_paths
from saturn_engine.worker_manager.config.declarative import load_definitions_from_str
Expand Down Expand Up @@ -606,3 +613,70 @@ def test_dynamic_definition() -> None:
static_definitions = load_definitions_from_str(resources_provider_str)
assert "test-inventory" in static_definitions.inventories
assert static_definitions.inventories["test-inventory"].name == "test-inventory"


def test_compile_static_definitions_with_patches(
client: FlaskClient, session: Session
) -> None:
concurrency_definition_str = """
apiVersion: saturn.flared.io/v1alpha1
kind: SaturnResource
metadata:
name: test-resource
labels:
owner: team-saturn
spec:
type: TestApiKey
data:
key: "qwe"
default_delay: 10
concurrency: 2
"""

uncompiled_objects = load_uncompiled_objects_from_str(concurrency_definition_str)

compileed_static_definitions_without_patch = compile_static_definitions(
uncompiled_objects=uncompiled_objects
)

assert compileed_static_definitions_without_patch.resources == {
"test-resource-1": ResourceItem(
name="test-resource-1",
type="TestApiKey",
data={"key": "qwe"},
default_delay=10.0,
rate_limit=None,
),
"test-resource-2": ResourceItem(
name="test-resource-2",
type="TestApiKey",
data={"key": "qwe"},
default_delay=10.0,
rate_limit=None,
),
}

# Now we create a patch to change the resource concurrency
patch = topologies_store.patch(
session=session,
patch=BaseObject(
kind="SaturnResource",
apiVersion="saturn.flared.io/v1alpha1",
metadata=ObjectMetadata(name="test-resource"),
spec={"concurrency": 1},
),
)

compileed_static_definitions_without_patch = compile_static_definitions(
uncompiled_objects=uncompiled_objects, patches=[patch]
)

assert compileed_static_definitions_without_patch.resources == {
"test-resource": ResourceItem(
name="test-resource",
type="TestApiKey",
data={"key": "qwe"},
default_delay=10.0,
rate_limit=None,
)
}
Loading

0 comments on commit c60bec8

Please sign in to comment.