Add BoTorch kernel preset, which uses dimensions-scaled prior #483
Draft
Hrovatin wants to merge 21 commits into main from feature/botorch_kernel_preset
+12,246 −2
Commits
39ad74c  Add direct arylation benchmark for TL with temperature as a task (Hrovatin)
863541f  Update changelog (Hrovatin)
520eab8  remove random seed that was set in the paper as it is redundant with … (Hrovatin)
fd990b6  Benchmark for transfer learning on arylhalides with dissimilar susbst… (Hrovatin)
4219451  Transfer learning benchmark with inverted Hartmann functions as tasks (Hrovatin)
64475fd  Add non-transfer learning campaign and transfer learning campaign wit… (Hrovatin)
d47b295  Transfer learning benchmark with noisy Michalewicz functions as tasks (Hrovatin)
8d20313  Transfer learning benchmark with noisy Easom functions as tasks. (Hrovatin)
ad8cbe1  Move data to benchmark folder (Hrovatin)
a4469a2  restructure benchmark data access (Hrovatin)
efcf7af  Make data paths general (Hrovatin)
ac8d371  Use csv instead of xlsx (Hrovatin)
deeda82  Add BoTorch kernel preset, which uses dimensions-scaled prior (Hrovatin)
66a63cf  pre-commit fixes (Hrovatin)
6d929e5  add to changelog (Hrovatin)
afcd803  Add a few botorch kernel preset benchmarks and adapt scripts for a te… (Hrovatin)
88044e4  Set N repeat iterations (Hrovatin)
0b44ff7  Added more benchmarks (Hrovatin)
eec73a7  Add benchmark (Hrovatin)
c1bb99e  Define benchmarks to run (Hrovatin)
1b46aa6  Reduce number of replicates to speed up benchmark time (Hrovatin)
@@ -0,0 +1,56 @@
"""Presets adapted from BoTorch."""

from __future__ import annotations

from math import log, sqrt
from typing import TYPE_CHECKING

from attrs import define
from gpytorch.constraints import GreaterThan
from typing_extensions import override

from baybe.kernels.basic import RBFKernel
from baybe.parameters import TaskParameter
from baybe.priors.basic import LogNormalPrior
from baybe.searchspace import SearchSpace
from baybe.surrogates.gaussian_process.kernel_factory import KernelFactory

if TYPE_CHECKING:
    from torch import Tensor

    from baybe.kernels.base import Kernel


@define
class BotorchKernelFactory(KernelFactory):
    """A kernel factory for Gaussian process surrogates adapted from BoTorch.

    References:
        * https://github.com/pytorch/botorch/blob/a018a5ffbcbface6229d6c39f7ac6ef9baf5765e/botorch/models/multitask.py#L220
        * https://github.com/pytorch/botorch/blob/a018a5ffbcbface6229d6c39f7ac6ef9baf5765e/botorch/models/utils/gpytorch_modules.py#L100

    """

    @override
    def __call__(
        self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
    ) -> Kernel:
        ard_num_dims = train_x.shape[-1] - len(
            [
                param
                for param in searchspace.discrete.parameters
                if isinstance(param, TaskParameter)
            ]
        )
        lengthscale_prior = LogNormalPrior(
            loc=sqrt(2) + log(ard_num_dims) * 0.5, scale=sqrt(3)
        )

        return RBFKernel(
            lengthscale_prior=lengthscale_prior,
            lengthscale_constraint=GreaterThan(
                2.5e-2,
                transform=None,
                initial_value=lengthscale_prior.to_gpytorch().mode,
            ),
        )
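For readers of the diff above: the lengthscale prior depends on the number of non-task input dimensions d through loc = sqrt(2) + 0.5 * log(d), with a fixed scale = sqrt(3). Below is a minimal sketch in plain Python (no BayBE, GPyTorch, or torch required) of how those parameters, and the prior mode used as the initial lengthscale, grow with d. The helper names are hypothetical and not part of the PR, and the mode is computed with the standard log-normal formula exp(loc - scale**2) rather than by calling the library.

```python
from math import exp, isclose, log, sqrt


def dim_scaled_lognormal_params(ard_num_dims: int) -> tuple[float, float]:
    """Return (loc, scale) of the dimension-scaled log-normal prior.

    Mirrors the expression in the diff: loc = sqrt(2) + 0.5 * log(d).
    """
    if ard_num_dims < 1:
        raise ValueError("need at least one non-task dimension")
    return sqrt(2) + 0.5 * log(ard_num_dims), sqrt(3)


def lognormal_mode(loc: float, scale: float) -> float:
    """Mode of a log-normal distribution: exp(loc - scale**2)."""
    return exp(loc - scale**2)


def count_ard_dims(n_columns: int, n_task_params: int) -> int:
    """Hypothetical helper: input columns minus task-parameter columns."""
    return n_columns - n_task_params


if __name__ == "__main__":
    for d in (1, 4, 16, 64):
        loc, scale = dim_scaled_lognormal_params(d)
        print(f"d={d:3d}  loc={loc:.3f}  mode={lognormal_mode(loc, scale):.3f}")
```

Since exp(0.5 * log(d)) = sqrt(d), the mode scales with sqrt(d): quadrupling the number of dimensions doubles the initial lengthscale.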
@@ -0,0 +1,5 @@
"""Utils for reading data."""

import os

DATA_PATH = os.sep.join(__file__.split(os.sep)[:-1]) + os.sep
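As a side note on the `DATA_PATH` expression above: it takes the directory containing the module and appends a trailing separator. The sketch below compares that split-and-join approach with a `pathlib` variant on a hypothetical relative path; the function names and the example path are illustrative assumptions, not part of the PR.

```python
import os
from pathlib import Path


def dir_via_split(file_path: str) -> str:
    """Directory of file_path plus trailing separator, as done in the diff."""
    return os.sep.join(file_path.split(os.sep)[:-1]) + os.sep


def dir_via_pathlib(file_path: str) -> str:
    """Same result using pathlib, without splitting on os.sep by hand."""
    return str(Path(file_path).parent) + os.sep


# Hypothetical module location, built with os.sep so it works on any platform.
example = os.sep.join(["repo", "benchmarks", "data", "utils.py"])

if __name__ == "__main__":
    print(dir_via_split(example))
    print(dir_via_pathlib(example))
```

For ordinary nested paths both variants agree. The `pathlib` form may be easier to extend, for example by joining sub-directories with `/` instead of concatenating strings.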
@@ -1,10 +1,26 @@
"""Benchmark domains."""

from benchmarks.definition.base import Benchmark
from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark
from benchmarks.domains.kernel_presets.arylhalides_tl_substance import (
    arylhalides_tl_substance_benchmark,
)
from benchmarks.domains.kernel_presets.direct_arylation_tl_temp import (
    direct_arylation_tl_temp_benchmark,
)
from benchmarks.domains.kernel_presets.easom_tl_noise import easom_tl_noise_benchmark
from benchmarks.domains.kernel_presets.hartmann_tl_inverted_noise import (
    hartmann_tl_inverted_noise_benchmark,
)
from benchmarks.domains.kernel_presets.michalewicz_tl_noise import (
    michalewicz_tl_noise_benchmark,
)

BENCHMARKS: list[Benchmark] = [
    synthetic_2C1D_1C_benchmark,
    hartmann_tl_inverted_noise_benchmark,
    easom_tl_noise_benchmark,
    michalewicz_tl_noise_benchmark,
    arylhalides_tl_substance_benchmark,
    direct_arylation_tl_temp_benchmark,
]

__all__ = ["BENCHMARKS"]
@@ -0,0 +1,165 @@
"""Benchmark on ArylHalides data with two distinct arylhalides as TL tasks."""

from __future__ import annotations

import os

import pandas as pd

from baybe.campaign import Campaign
from baybe.objectives import SingleTargetObjective
from baybe.parameters import SubstanceParameter, TaskParameter
from baybe.searchspace import SearchSpace
from baybe.simulation import simulate_scenarios
from baybe.targets import NumericalTarget
from benchmarks.data.utils import DATA_PATH
from benchmarks.definition import (
    ConvergenceBenchmark,
    ConvergenceBenchmarkSettings,
)


def get_data() -> pd.DataFrame:
    """Load data for benchmark.

    Returns:
        Data for benchmark.
    """
    data_path = DATA_PATH + "ArylHalides" + os.sep
    data = pd.read_table(data_path + "data.csv", sep=",")
    data_raw = pd.read_table(data_path + "data_raw.csv", sep=",")
    for substance in ["base", "ligand", "additive"]:
        data[substance + "_smiles"] = data[substance].map(
            dict(zip(data_raw[substance], data_raw[substance + "_smiles"]))
        )
    return data


data = get_data()

test_task = "1-iodo-4-methoxybenzene"
source_task = [
    # Dissimilar source task
    "1-chloro-4-(trifluoromethyl)benzene"
]


def space_data() -> (
    SingleTargetObjective,
    SearchSpace,
    SearchSpace,
    pd.DataFrame,
    pd.DataFrame,
):
    """Definition of search space, objective, and data.

    Returns:
        Objective, TL search space, non-TL search space,
        pre-measured task data (source task),
        and lookup for the active (target) task.
    """
    data_params = [
        SubstanceParameter(
            name=substance,
            data=dict(zip(data[substance], data[f"{substance}_smiles"])),
            encoding="MORDRED",
        )
        for substance in ["base", "ligand", "additive"]
    ]

    task_param = TaskParameter(
        name="aryl_halide",
        values=[test_task] + source_task,
        active_values=[test_task],
    )

    objective = SingleTargetObjective(NumericalTarget(name="yield", mode="MAX"))
    searchspace = SearchSpace.from_product(parameters=[*data_params, task_param])
    searchspace_nontl = SearchSpace.from_product(parameters=data_params)

    lookup = data.query(f'aryl_halide=="{test_task}"').copy(deep=True)
    initial_data = data.query("aryl_halide.isin(@source_task)", engine="python").copy(
        deep=True
    )

    return objective, searchspace, searchspace_nontl, initial_data, lookup


def arylhalides_tl_substance(settings: ConvergenceBenchmarkSettings) -> pd.DataFrame:
    """Benchmark function comparing TL and non-TL campaigns.

    Inputs:
        base            Discrete substance with numerical encoding
        ligand          Discrete substance with numerical encoding
        additive        Discrete substance with numerical encoding
        Concentration   Continuous
        aryl_halide     Discrete task parameter
    Output: continuous
    Objective: Maximization
    Optimal Inputs: [
        {
            base        MTBD
            ligand      AdBrettPhos
            additive    N,N-dibenzylisoxazol-3-amine
        }
    ]
    Optimal Output: 68.24812709999999
    """
    objective, searchspace, searchspace_nontl, initial_data, lookup = space_data()

    campaign = Campaign(
        searchspace=searchspace,
        objective=objective,
    )

    results = []
    for p in [0.01, 0.02, 0.05, 0.1, 0.2]:
        results.append(
            simulate_scenarios(
                {f"{int(100 * p)}": campaign},
                lookup,
                initial_data=[
                    initial_data.sample(frac=p) for _ in range(settings.n_mc_iterations)
                ],
                batch_size=settings.batch_size,
                n_doe_iterations=settings.n_doe_iterations,
                impute_mode="error",
            )
        )
    # No training data
    results.append(
        simulate_scenarios(
            {"0": campaign},
            lookup,
            batch_size=settings.batch_size,
            n_doe_iterations=settings.n_doe_iterations,
            n_mc_iterations=settings.n_mc_iterations,
            impute_mode="error",
        )
    )
    # Non-TL campaign
    results.append(
        simulate_scenarios(
            {"non-TL": Campaign(searchspace=searchspace_nontl, objective=objective)},
            lookup,
            batch_size=settings.batch_size,
            n_doe_iterations=settings.n_doe_iterations,
            n_mc_iterations=settings.n_mc_iterations,
            impute_mode="error",
        )
    )
    results = pd.concat(results)
    return results


benchmark_config = ConvergenceBenchmarkSettings(
    batch_size=2,
    n_doe_iterations=10,
    n_mc_iterations=100,
)

arylhalides_tl_substance_benchmark = ConvergenceBenchmark(
    function=arylhalides_tl_substance,
    optimal_target_values={"yield": 68.24812709999999},
    settings=benchmark_config,
)
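To make the scenario layout of the benchmark above easier to follow, here is a small stand-alone sketch (plain Python, no BayBE or pandas) that reproduces the scenario labels and estimates how many source-task rows each fraction would contribute. The helper names and the source-table size of 1000 rows are hypothetical. The row count assumes that `DataFrame.sample(frac=p)` returns `round(p * n_rows)` rows, which is my understanding of the pandas behaviour rather than something stated in the PR.

```python
FRACTIONS = [0.01, 0.02, 0.05, 0.1, 0.2]


def scenario_label(p: float) -> str:
    """Label used for a TL scenario, copied from the diff: f"{int(100 * p)}"."""
    return f"{int(100 * p)}"


def rows_sampled(n_rows: int, frac: float) -> int:
    """Assumed number of rows returned by DataFrame.sample(frac=frac)."""
    return round(frac * n_rows)


# The five sampled-fraction scenarios, then the two baselines from the diff.
labels = [scenario_label(p) for p in FRACTIONS] + ["0", "non-TL"]

if __name__ == "__main__":
    n_source_rows = 1000  # hypothetical size of the source-task data
    for p in FRACTIONS:
        print(f"scenario {scenario_label(p):>2}: "
              f"{rows_sampled(n_source_rows, p)} source rows per MC iteration")
    print(labels)
```

One design note: `int(100 * p)` truncates, so it gives the intended label only when `100 * p` is exact in floating point, as it is for the five fractions used here. For other fractions (0.07, say), `round(100 * p)` would likely be the safer choice.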
Hi @Hrovatin, just two comments upfront, to save us some work later:
Cheers 🙃
Ah, then I strongly misunderstood the aim. We should align on this during the meeting next week.