Skip to content

Commit

Permalink
added admin methods
Browse files Browse the repository at this point in the history
  • Loading branch information
teo-milea committed Oct 11, 2024
1 parent 99e6659 commit 4f8e4c9
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 2 deletions.
35 changes: 33 additions & 2 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,22 @@
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from .sync_stash import SyncStash
from .sync_state import SyncState

logger = logging.getLogger(__name__)


def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash:
if isinstance(item, ActionObject):
return get_store_by_type(context=context, obj_type=type(item))


def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash:
if issubclass(obj_type, ActionObject):
service = context.server.services.action # type: ignore
return service.stash # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore
return service.stash


Expand Down Expand Up @@ -450,3 +455,29 @@ def build_current_state(
)
def _get_state(self, context: AuthedServiceContext) -> SyncState:
return self.build_current_state(context).unwrap()

@service_method(
path="sync._get_object",
name="_get_object",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def _get_object(
self, context: AuthedServiceContext, uid: UID, object_type: type
) -> Any:
return (
get_store_by_type(context, object_type)
.get_by_uid(credentials=context.credentials, uid=uid)
.unwrap()
)

@service_method(
path="sync._update_object",
name="_update_object",
roles=ADMIN_ROLE_LEVEL,
)
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
return (
get_store(context, object)
.update(credentials=context.credentials, obj=object)
.unwrap()
)
70 changes: 70 additions & 0 deletions packages/syft/tests/syft/service/sync/get_set_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# third party
import numpy as np
import pytest

# syft absolute
import syft
import syft as sy
from syft.client.datasite_client import DatasiteClient
from syft.client.sync_decision import SyncDecision
from syft.client.syncing import compare_clients
from syft.client.syncing import resolve
from syft.server.worker import Worker
from syft.service.action.action_object import ActionObject
from syft.service.code.user_code import ApprovalDecision
from syft.service.code.user_code import UserCodeStatus
from syft.service.dataset.dataset import Dataset
from syft.service.job.job_stash import Job
from syft.service.request.request import RequestStatus
from syft.service.response import SyftSuccess
from syft.service.sync.resolve_widget import ResolveWidget
from syft.service.user.user import User, UserView
from syft.types.errors import SyftException


def get_ds_client(client: DatasiteClient) -> DatasiteClient:
client.register(
name="a",
email="[email protected]",
password="asdf",
password_verify="asdf",
)
return client.login(email="[email protected]", password="asdf")


def test_get_set_object(high_worker):
high_client: DatasiteClient = high_worker.root_client
_ = get_ds_client(high_client)
root_datasite_client = high_worker.root_client
dataset = sy.Dataset(
name="local_test",
asset_list=[
sy.Asset(
name="local_test",
data=[1, 2, 3],
mock=[1, 1, 1],
)
],
)
root_datasite_client.upload_dataset(dataset)
dataset = root_datasite_client.datasets[0]

other_dataset = high_client.api.services.sync._get_object(uid=dataset.id, object_type=Dataset)
other_dataset.server_uid = dataset.server_uid
assert dataset == other_dataset
other_dataset.name = "new_name"
updated_dataset = high_client.api.services.sync._update_object(
object=other_dataset
)
assert updated_dataset.name == "new_name"

asset = root_datasite_client.datasets[0].assets[0]
source_ao = high_client.api.services.action.get(uid=asset.action_id)
ao = high_client.api.services.sync._get_object(
uid=asset.action_id, object_type=ActionObject
)
ao._set_obj_location_(
high_worker.id,
root_datasite_client.credentials,
)
assert source_ao == ao

0 comments on commit 4f8e4c9

Please sign in to comment.