From 8167d7d88ee63f802afe80e13406e3324d05d8ea Mon Sep 17 00:00:00 2001 From: YektaY Date: Thu, 22 Jun 2023 16:35:52 -0700 Subject: [PATCH 1/3] added a way to save generator states and rebuild the old states --- xopt/base.py | 20 ++++++++++++++++---- xopt/pydantic.py | 19 +++++++++++++++++-- xopt/utils.py | 16 +++++++++++++++- 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/xopt/base.py b/xopt/base.py index 248f6af7..da472356 100644 --- a/xopt/base.py +++ b/xopt/base.py @@ -8,8 +8,10 @@ from xopt.evaluator import Evaluator, validate_outputs from xopt.generator import Generator from xopt.generators import get_generator + # from xopt.generators import get_generator_and_defaults from xopt.pydantic import XoptBaseModel +from xopt.utils import build_generator_from_saved_state from xopt.vocs import VOCS __version__ = _version.get_versions()["version"] @@ -65,7 +67,8 @@ def __init__( Initialize Xopt object using either a config dictionary or explicitly Args: - config: dict, or YAML or JSON str or file. This overrides all other arguments. + config: dict, or YAML or JSON str or file. + This overrides all other arguments. generator: Generator object evaluator: Evaluator object @@ -311,7 +314,7 @@ def check_components(self): def dump_state(self): """dump data to file""" if self.options.dump_file is not None: - output = state_to_dict(self) + output = state_to_dict(self, include_history=True) with open(self.options.dump_file, "w") as f: yaml.dump(output, f) logger.debug(f"Dumped state to YAML file: {self.options.dump_file}") @@ -392,7 +395,8 @@ def yaml(self, filename=None, *, include_data=False): def __repr__(self): """ - Returns infor about the Xopt object, including the YAML representation without data. + Returns infor about the Xopt object, + including the YAML representation without data. """ return f""" Xopt @@ -489,7 +493,7 @@ def xopt_kwargs_from_dict(config: dict) -> dict: } -def state_to_dict(X, include_data=True): +def state_to_dict(X, include_data=True, include_history=False): # dump data to dict with config metadata output = { "xopt": json.loads(X.options.json()), @@ -503,4 +507,12 @@ def state_to_dict(X, include_data=True): if include_data: output["data"] = json.loads(X.data.to_json()) + if include_history: + output["history"] = json.loads(X.generator.to_json()) + return output + + +def rebuild_from_previous_state(self, index): + """rebuild generator from saved history""" + return build_generator_from_saved_state(index, self.options.dump_file) diff --git a/xopt/pydantic.py b/xopt/pydantic.py index 607740b1..08876325 100644 --- a/xopt/pydantic.py +++ b/xopt/pydantic.py @@ -11,7 +11,15 @@ import numpy as np import orjson import torch.nn -from pydantic import BaseModel, create_model, Extra, Field, root_validator, validator +from pydantic import ( + BaseModel, + create_model, + Extra, + Field, + parse_obj_as, + root_validator, + validator, +) from pydantic.generics import GenericModel ObjType = TypeVar("ObjType") @@ -80,6 +88,12 @@ def recursive_deserialize(v: dict): return v +def rebuild_from_json(model, json_data): + """Method to rebuild a generator from a json file.""" + rebuilt_generator = parse_obj_as(model, json_data) + return rebuilt_generator + + # define custom json_dumps using orjson def orjson_dumps(v, *, default, base_key=""): v = recursive_serialize(v, base_key=base_key) @@ -414,7 +428,8 @@ def map(self, fn, *iter: Iterable, **kwargs) -> Iterable[Future]: def get_callable_from_string(callable: str, bind: Any = None) -> Callable: - """Get callable from a string. In the case that the callable points to a bound method, + """Get callable from a string. + In the case that the callable points to a bound method, the function returns a callable taking the bind instance as the first arg. Args: diff --git a/xopt/utils.py b/xopt/utils.py index 31d0b6b8..aea36c80 100644 --- a/xopt/utils.py +++ b/xopt/utils.py @@ -1,6 +1,7 @@ import datetime import importlib import inspect +import json import sys import time import traceback @@ -9,7 +10,7 @@ import torch import yaml -from .pydantic import get_descriptions_defaults +from .pydantic import get_descriptions_defaults, rebuild_from_json from .vocs import VOCS @@ -172,6 +173,19 @@ def read_xopt_csv(*files): return pd.concat(dfs) +def build_generator_from_saved_state(self, index, dump_file): + """rebuild generator from saved history""" + with open(dump_file, "r") as file: + data = json.load(file) + + model = data["generator"]["name"] + desired_state = data["history"][index] + + generator = rebuild_from_json(model, desired_state) + + return generator + + def visualize_model(generator, data, axes=None): test_x = torch.linspace(*torch.tensor(generator.vocs.bounds.flatten()), 100) generator.add_data(data) From 378d633d4e53d90a4e8f034288bf8dad21732f7b Mon Sep 17 00:00:00 2001 From: YektaY Date: Fri, 23 Jun 2023 11:13:43 -0700 Subject: [PATCH 2/3] moved rebuild method --- xopt/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xopt/base.py b/xopt/base.py index 6ecb9db4..a3ac1ed3 100644 --- a/xopt/base.py +++ b/xopt/base.py @@ -312,6 +312,11 @@ def check_components(self): if self.vocs is None: raise XoptError("Xopt VOCS is not specified") + def rebuild_from_previous_state(self, index): + """rebuild generator from saved history""" + if self.options.dump_file is not None: + return build_generator_from_saved_state(index, self.options.dump_file) + def dump_state(self): """dump data to file""" if self.options.dump_file is not None: @@ -512,8 +517,3 @@ def state_to_dict(X, include_data=True, include_history=False): output["history"] = json.loads(X.generator.to_json()) return output - - -def rebuild_from_previous_state(self, index): - """rebuild generator from saved history""" - return build_generator_from_saved_state(index, self.options.dump_file) From 3b0ffa3beeb373fe72661cce23a6f1af8707d6d6 Mon Sep 17 00:00:00 2001 From: YektaY Date: Fri, 30 Jun 2023 14:13:15 -0700 Subject: [PATCH 3/3] building out the generator save and rebuilt feature --- xopt/base.py | 6 ++++-- xopt/generator.py | 4 ++++ xopt/pydantic.py | 16 +--------------- xopt/utils.py | 20 ++++++++++++-------- 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/xopt/base.py b/xopt/base.py index a3ac1ed3..c2d24221 100644 --- a/xopt/base.py +++ b/xopt/base.py @@ -315,7 +315,9 @@ def check_components(self): def rebuild_from_previous_state(self, index): """rebuild generator from saved history""" if self.options.dump_file is not None: - return build_generator_from_saved_state(index, self.options.dump_file) + return build_generator_from_saved_state( + index=index, dump_file=self.options.dump_file + ) def dump_state(self): """dump data to file""" @@ -514,6 +516,6 @@ def state_to_dict(X, include_data=True, include_history=False): output["data"] = json.loads(X.data.to_json()) if include_history: - output["history"] = json.loads(X.generator.to_json()) + output["history"] = json.loads(X.generator.json()) return output diff --git a/xopt/generator.py b/xopt/generator.py index 3374af5d..8fa6fb64 100644 --- a/xopt/generator.py +++ b/xopt/generator.py @@ -46,6 +46,10 @@ def __init__(self, **kwargs): vocs: The vocs to use. options: The options to use. """ + print("start") + for key, value in kwargs.items(): + print(key, value) + print("end") super().__init__(**kwargs) _check_vocs(self.vocs, self.supports_multi_objective) logger.info(f"Initialized generator {self.name}") diff --git a/xopt/pydantic.py b/xopt/pydantic.py index a8347a44..5d9cdaa0 100644 --- a/xopt/pydantic.py +++ b/xopt/pydantic.py @@ -11,15 +11,7 @@ import numpy as np import orjson import torch.nn -from pydantic import ( - BaseModel, - create_model, - Extra, - Field, - parse_obj_as, - root_validator, - validator, -) +from pydantic import BaseModel, create_model, Extra, Field, root_validator, validator from pydantic.generics import GenericModel ObjType = TypeVar("ObjType") @@ -88,12 +80,6 @@ def recursive_deserialize(v: dict): return v -def rebuild_from_json(model, json_data): - """Method to rebuild a generator from a json file.""" - rebuilt_generator = parse_obj_as(model, json_data) - return rebuilt_generator - - # define custom json_dumps using orjson def orjson_dumps(v, *, default, base_key=""): v = recursive_serialize(v, base_key=base_key) diff --git a/xopt/utils.py b/xopt/utils.py index aea36c80..73d24d94 100644 --- a/xopt/utils.py +++ b/xopt/utils.py @@ -1,7 +1,6 @@ import datetime import importlib import inspect -import json import sys import time import traceback @@ -9,8 +8,11 @@ import pandas as pd import torch import yaml +from pydantic import parse_obj_as -from .pydantic import get_descriptions_defaults, rebuild_from_json +from xopt.generators import get_generator + +from .pydantic import get_descriptions_defaults from .vocs import VOCS @@ -173,17 +175,19 @@ def read_xopt_csv(*files): return pd.concat(dfs) -def build_generator_from_saved_state(self, index, dump_file): +def build_generator_from_saved_state(index, dump_file): """rebuild generator from saved history""" with open(dump_file, "r") as file: - data = json.load(file) + data = yaml.safe_load(file) - model = data["generator"]["name"] - desired_state = data["history"][index] + list_of_saved_generators = data["history"] + desired_state = list_of_saved_generators[index] - generator = rebuild_from_json(model, desired_state) + # desired_state['vocs'] = data['vocs'] + generator_class = get_generator(data["generator"].pop("name")) + rebuilt_generator = parse_obj_as(generator_class, desired_state) - return generator + return rebuilt_generator def visualize_model(generator, data, axes=None):