-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1667 from pyiron/decorator
Implement decorators in `pyiron_base`
- Loading branch information
Showing
3 changed files
with
181 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import inspect | ||
from typing import Optional | ||
|
||
from pyiron_base.jobs.job.extension.server.generic import Server | ||
from pyiron_base.project.generic import Project | ||
|
||
|
||
# The combined decorator | ||
def pyiron_job( | ||
funct: Optional[callable] = None, | ||
*, | ||
host: Optional[str] = None, | ||
queue: Optional[str] = None, | ||
cores: int = 1, | ||
threads: int = 1, | ||
gpus: Optional[int] = None, | ||
run_mode: str = "modal", | ||
new_hdf: bool = True, | ||
output_file_lst: list = [], | ||
output_key_lst: list = [], | ||
): | ||
""" | ||
Decorator to create a pyiron job object from any python function | ||
Args: | ||
funct (callable): python function to create a job object from | ||
host (str): the hostname of the current system. | ||
queue (str): the queue selected for a current simulation. | ||
cores (int): the number of cores selected for the current simulation. | ||
threads (int): the number of threads selected for the current simulation. | ||
gpus (int): the number of gpus selected for the current simulation. | ||
run_mode (str): the run mode of the job ['modal', 'non_modal', 'queue', 'manual'] | ||
new_hdf (bool): defines whether a subjob should be stored in the same HDF5 file or in a new one. | ||
output_file_lst (list): | ||
output_key_lst (list): | ||
Returns: | ||
callable: The decorated functions | ||
Example: | ||
>>> from pyiron_base import pyiron_job, Project | ||
>>> | ||
>>> @pyiron_job | ||
>>> def my_function_a(a, b=8): | ||
>>> return a + b | ||
>>> | ||
>>> @pyiron_job(cores=2) | ||
>>> def my_function_b(a, b=8): | ||
>>> return a + b | ||
>>> | ||
>>> pr = Project("test") | ||
>>> c = my_function_a(a=1, b=2, pyiron_project=pr) | ||
>>> d = my_function_b(a=c, b=3, pyiron_project=pr) | ||
>>> print(d.pull()) | ||
Output: 6 | ||
""" | ||
|
||
def get_delayed_object( | ||
*args, | ||
pyiron_project: Project, | ||
python_function: callable, | ||
pyiron_resource_dict: dict, | ||
resource_default_dict: dict, | ||
**kwargs, | ||
): | ||
for k, v in resource_default_dict.items(): | ||
if k not in pyiron_resource_dict: | ||
pyiron_resource_dict[k] = v | ||
delayed_job_object = pyiron_project.wrap_python_function( | ||
python_function=python_function, | ||
*args, | ||
job_name=None, | ||
automatically_rename=True, | ||
execute_job=False, | ||
delayed=True, | ||
output_file_lst=output_file_lst, | ||
output_key_lst=output_key_lst, | ||
**kwargs, | ||
) | ||
delayed_job_object._server = Server(**pyiron_resource_dict) | ||
return delayed_job_object | ||
|
||
# This is the actual decorator function that applies to the decorated function | ||
def pyiron_job_function(f) -> callable: | ||
def function( | ||
*args, pyiron_project: Project, pyiron_resource_dict: dict = {}, **kwargs | ||
): | ||
resource_default_dict = { | ||
"host": None, | ||
"queue": None, | ||
"cores": 1, | ||
"threads": 1, | ||
"gpus": None, | ||
"run_mode": "modal", | ||
"new_hdf": True, | ||
} | ||
return get_delayed_object( | ||
*args, | ||
python_function=f, | ||
pyiron_project=pyiron_project, | ||
pyiron_resource_dict=pyiron_resource_dict, | ||
resource_default_dict=resource_default_dict, | ||
**kwargs, | ||
) | ||
|
||
return function | ||
|
||
# If funct is None, it means the decorator is called with arguments (like @pyiron_job(...)) | ||
if funct is None: | ||
return pyiron_job_function | ||
|
||
# If funct is not None, it means the decorator is called without parentheses (like @pyiron_job) | ||
else: | ||
# Assume this usage and handle the decorator like `pyiron_job_simple` | ||
def function( | ||
*args, | ||
pyiron_project: Project, | ||
pyiron_resource_dict: dict = {}, | ||
**kwargs, | ||
): | ||
resource_default_dict = { | ||
k: v.default for k, v in inspect.signature(Server).parameters.items() | ||
} | ||
return get_delayed_object( | ||
*args, | ||
python_function=funct, | ||
pyiron_project=pyiron_project, | ||
pyiron_resource_dict=pyiron_resource_dict, | ||
resource_default_dict=resource_default_dict, | ||
**kwargs, | ||
) | ||
|
||
return function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from pyiron_base._tests import TestWithProject | ||
from pyiron_base import pyiron_job | ||
import unittest | ||
|
||
|
||
class TestPythonFunctionDecorator(TestWithProject): | ||
def tearDown(self): | ||
self.project.remove_jobs(recursive=True, silently=True) | ||
|
||
def test_delayed(self): | ||
@pyiron_job() | ||
def my_function_a(a, b=8): | ||
return a + b | ||
|
||
@pyiron_job(cores=2) | ||
def my_function_b(a, b=8): | ||
return a + b | ||
|
||
c = my_function_a(a=1, b=2, pyiron_project=self.project) | ||
d = my_function_b(a=c, b=3, pyiron_project=self.project) | ||
self.assertEqual(d.pull(), 6) | ||
nodes_dict, edges_lst = d.get_graph() | ||
self.assertEqual(len(nodes_dict), 6) | ||
self.assertEqual(len(edges_lst), 6) | ||
|
||
def test_delayed_simple(self): | ||
@pyiron_job | ||
def my_function_a(a, b=8): | ||
return a + b | ||
|
||
@pyiron_job | ||
def my_function_b(a, b=8): | ||
return a + b | ||
|
||
c = my_function_a(a=1, b=2, pyiron_project=self.project) | ||
d = my_function_b( | ||
a=c, b=3, pyiron_project=self.project, pyiron_resource_dict={"cores": 2} | ||
) | ||
self.assertEqual(d.pull(), 6) | ||
nodes_dict, edges_lst = d.get_graph() | ||
self.assertEqual(len(nodes_dict), 6) | ||
self.assertEqual(len(edges_lst), 6) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |