
Commit

Add infrastructure for offline pretraining of Bandits (#138)
* fix

* fix#2

* minor

* Implement helper class for offline training

* Rename OfflineTrainer to HistoryCollector

* WIP minor fixes

* Renames and moves

* minor type enhancements

* Add static method into ExperienceBuffer

* WIP add validation and training of OperatorAgent

* Add handling of trajectories into ExperienceBuffer

* WIP add validation methods into AgentLearner

* WIP change interface of ExperienceBuffer slightly

* WIP implement agent validation

* WIP Complete agent_training.py fitting & add ExperienceBuffer.split

* Split history_collector.py into 2 classes

* Minor refactorings

* Move experience_buffer.py

* Move common_types; minor tweaks

* WIP Fixups of minor errors

* Fix errors precluding agent training

Some errors are caused by serialization issues with molecules

* TMP experiment

* Few renames

* Extend docs

* Fix error with `content` loading of molecules

* Drop pretrain line from experiment.py

* Minor fix for fitness.valid check

* Revert "Fix error with `content` loading of molecules"

* Minor modification for nodes when they have a name but no params

* remove HistoryCollector

* minors

* fix example

* fixes after rebase

* fix nxid

* minor

* stable reward minors

* minor

---------

Co-authored-by: Pinchuk Maya <[email protected]>
gkirgizov and maypink authored Oct 25, 2023
1 parent afa11e6 commit 29cd362
Showing 19 changed files with 490 additions and 125 deletions.
28 changes: 23 additions & 5 deletions examples/molecule_search/experiment.py
@@ -16,6 +16,8 @@
normalized_logp, CLScorer
from golem.core.dag.verification_rules import has_no_self_cycled_nodes, has_no_isolated_components, \
has_no_isolated_nodes
from golem.core.optimisers.adaptive.agent_trainer import AgentTrainer
from golem.core.optimisers.adaptive.history_collector import HistoryReader
from golem.core.optimisers.adaptive.operator_agent import MutationAgentTypeEnum
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
from golem.core.optimisers.genetic.gp_params import GPAlgorithmParameters
@@ -25,6 +27,7 @@
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.optimisers.optimizer import GraphGenerationParams, GraphOptimizer
from golem.core.paths import project_root
from golem.visualisation.opt_history.multiple_fitness_line import MultipleFitnessLines
from golem.visualisation.opt_viz_extra import visualise_pareto

@@ -129,6 +132,16 @@ def visualize_results(molecules: Iterable[MolGraph],
image.show()


def pretrain_agent(optimizer: EvoGraphOptimizer, objective: Objective, results_dir: str) -> AgentTrainer:
agent = optimizer.mutation.agent
trainer = AgentTrainer(objective, optimizer.mutation, agent)
# load histories
history_reader = HistoryReader(Path(results_dir))
# train agent
trainer.fit(histories=history_reader.load_histories(), validate_each=1)
return trainer


def run_experiment(optimizer_setup: Callable,
optimizer_cls: Type[GraphOptimizer] = EvoGraphOptimizer,
adaptive_kind: MutationAgentTypeEnum = MutationAgentTypeEnum.random,
@@ -143,13 +156,14 @@ def run_experiment(optimizer_setup: Callable,
trial_iterations: Optional[int] = None,
visualize: bool = False,
save_history: bool = True,
pretrain_dir: Optional[str] = None,
):
metrics = metrics or ['qed_score']
optimizer_id = optimizer_cls.__name__.lower()[:3]
experiment_id = f'Experiment [optimizer={optimizer_id} metrics={", ".join(metrics)} pop_size={pop_size}]'
exp_name = f'{optimizer_id}_{adaptive_kind.value}_popsize{pop_size}_min{trial_timeout}_{"_".join(metrics)}'

atom_types = atom_types or ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br']
metrics = metrics or ['qed_score']
trial_results = []
trial_histories = []
trial_timedelta = timedelta(minutes=trial_timeout) if trial_timeout else None
@@ -165,6 +179,9 @@
pop_size,
metrics,
initial_molecules)
if pretrain_dir:
pretrain_agent(optimizer, objective, pretrain_dir)

found_graphs = optimizer.optimise(objective)
history = optimizer.history

@@ -208,10 +225,11 @@ def plot_experiment_comparison(experiment_ids: Sequence[str], metric_id: int = 0

if __name__ == '__main__':
run_experiment(molecule_search_setup,
adaptive_kind=MutationAgentTypeEnum.random,
adaptive_kind=MutationAgentTypeEnum.bandit,
max_heavy_atoms=38,
trial_timeout=15,
trial_timeout=6,
pop_size=50,
metrics=['qed_score', 'cl_score'],
visualize=True,
num_trials=5)
num_trials=5,
pretrain_dir=os.path.join(project_root(), 'examples', 'molecule_search', 'histories')
)
14 changes: 10 additions & 4 deletions examples/molecule_search/mol_adapter.py
@@ -17,8 +17,10 @@ def __init__(self):

def _restore(self, opt_graph: OptGraph, metadata: Optional[Dict[str, Any]] = None) -> MolGraph:
digraph = self.nx_adapter.restore(opt_graph)
# return to previous node indexing
digraph = nx.relabel_nodes(digraph, dict(digraph.nodes(data='nxid')))
# to ensure backward compatibility with old individuals without 'nxid' field in nodes
if not any(x is None for x in list(dict(digraph.nodes(data='nxid')).values())):
# return to previous node indexing
digraph = nx.relabel_nodes(digraph, dict(digraph.nodes(data='nxid')))
digraph = restore_edges_params_from_nodes(digraph)
nx_graph = digraph.to_undirected()
mol_graph = MolGraph.from_nx_graph(nx_graph)
@@ -50,7 +52,11 @@ def restore_edges_params_from_nodes(graph: nx.DiGraph) -> nx.DiGraph:
all_edges_params = {}
for node in graph.nodes():
for predecessor in graph.predecessors(node):
edge_params = edge_params_by_node[node][predecessor]
all_edges_params.update({(predecessor, node): edge_params})
node_params = edge_params_by_node[node]
# sometimes, for an unknown reason, some nodes are encoded as int and some as str;
# this may be deserialization misbehaving somewhere.
edge_params = node_params.get(predecessor) or node_params.get(str(predecessor))
if edge_params:
all_edges_params[(predecessor, node)] = edge_params
nx.set_edge_attributes(graph, all_edges_params)
return graph
8 changes: 4 additions & 4 deletions examples/molecule_search/mol_graph.py
@@ -5,7 +5,7 @@
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, MolToSmiles, SanitizeMol, Kekulize, MolToInchi
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.rdchem import Atom, BondType, RWMol, GetPeriodicTable
from rdkit.Chem.rdchem import Atom, BondType, RWMol, GetPeriodicTable, ChiralType, HybridizationType


class MolGraph:
@@ -32,10 +32,10 @@ def from_nx_graph(graph: nx.Graph):
node_to_idx = {}
for node in graph.nodes():
a = Chem.Atom(atomic_nums[node])
a.SetChiralTag(chiral_tags[node])
a.SetChiralTag(ChiralType(chiral_tags[node]))
a.SetFormalCharge(formal_charges[node])
a.SetIsAromatic(node_is_aromatics[node])
a.SetHybridization(node_hybridizations[node])
a.SetHybridization(HybridizationType(node_hybridizations[node]))
a.SetNumExplicitHs(num_explicit_hss[node])
idx = mol.AddAtom(a)
node_to_idx[node] = idx
@@ -45,7 +45,7 @@ def from_nx_graph(graph: nx.Graph):
first, second = edge
ifirst = node_to_idx[first]
isecond = node_to_idx[second]
bond_type = bond_types[first, second]
bond_type = BondType(bond_types[first, second])
mol.AddBond(ifirst, isecond, bond_type)

SanitizeMol(mol)
13 changes: 7 additions & 6 deletions golem/core/adapter/nx_adapter.py
@@ -22,13 +22,14 @@ def __init__(self):
def _node_restore(self, node: GraphNode) -> Dict:
"""Transforms GraphNode to dict of NetworkX node attributes.
Override for custom behavior."""
parameters = {}
if hasattr(node, 'parameters'):
parameters = node.parameters
if node.name:
parameters['name'] = node.name
return deepcopy(parameters)
else:
return {}
parameters = deepcopy(node.parameters)

if node.name:
parameters['name'] = node.name

return parameters

def _node_adapt(self, data: Dict) -> OptNode:
"""Transforms a dict of NetworkX node attributes to GraphNode.
198 changes: 198 additions & 0 deletions golem/core/optimisers/adaptive/agent_trainer.py
@@ -0,0 +1,198 @@
import operator
from copy import deepcopy
from functools import reduce
from typing import Sequence, Optional, Any, Tuple, List, Iterable

import numpy as np

from golem.core.dag.graph import Graph
from golem.core.log import default_log
from golem.core.optimisers.adaptive.common_types import TrajectoryStep, GraphTrajectory
from golem.core.optimisers.adaptive.experience_buffer import ExperienceBuffer
from golem.core.optimisers.adaptive.operator_agent import OperatorAgent
from golem.core.optimisers.fitness import Fitness
from golem.core.optimisers.genetic.operators.mutation import Mutation
from golem.core.optimisers.objective import Objective
from golem.core.optimisers.opt_history_objects.individual import Individual
from golem.core.optimisers.opt_history_objects.opt_history import OptHistory
from golem.core.optimisers.opt_history_objects.parent_operator import ParentOperator
from golem.core.utilities.data_structures import unzip


class AgentTrainer:
"""Utility class providing fit/validate logic for adaptive Mutation agents.
Works in tandem with `HistoryReader`.
How to use offline training:
1. Collect histories into some directory using `ExperimentLauncher`
2. Create an optimizer & pretrain its mutation agent on these histories using `HistoryReader` and `AgentTrainer`
3. Optionally, validate the agent on a validation set of histories
4. Run optimization with the pretrained agent
"""

def __init__(self,
objective: Objective,
mutation_operator: Mutation,
agent: Optional[OperatorAgent] = None,
):
self._log = default_log(self)
self.agent = agent if agent is not None else mutation_operator.agent
self.mutation = mutation_operator
self.objective = objective
self._adapter = self.mutation.graph_generation_params.adapter

def fit(self, histories: Iterable[OptHistory], validate_each: int = -1) -> OperatorAgent:
"""
Method to fit the trainer on collected histories.
:param histories: histories to use in training.
:param validate_each: validate the agent once every `validate_each` histories.
"""
# Set mutation probabilities to 1.0
initial_req = deepcopy(self.mutation.requirements)
self.mutation.requirements.mutation_prob = 1.0

for i, history in enumerate(histories):
# Preliminary validity check
# This allows filtering out histories with different objectives automatically
if history.objective.metric_names != self.objective.metric_names:
self._log.warning(f'History #{i+1} has different objective! '
f'Expected {self.objective}, got {history.objective}.')
continue

# Build datasets
experience = ExperienceBuffer.from_history(history)
val_experience = None
if validate_each > 0 and i % validate_each == 0:
experience, val_experience = experience.split(ratio=0.8, shuffle=True)

# Train
self._log.info(f'Training on history #{i+1} with {len(history.generations)} generations')
self.agent.partial_fit(experience)

# Validate
if val_experience:
reward_loss, reward_target = self.validate_agent(experience=val_experience)
self._log.info(f'Agent validation for history #{i+1} & {experience}: '
f'Reward target={reward_target:.3f}, loss={reward_loss:.3f}')

# Reset mutation probabilities to default
self.mutation.update_requirements(requirements=initial_req)
return self.agent

def validate_on_rollouts(self, histories: Sequence[OptHistory]) -> float:
"""Validates rollouts of agent vs. historic trajectories, comparing
their mean total rewards (i.e. total fitness gain over the trajectory)."""

# Collect all trajectories from all histories; and their rewards
trajectories = concat_lists(map(ExperienceBuffer.unroll_trajectories, histories))

mean_traj_len = int(np.mean([len(tr) for tr in trajectories]))
traj_rewards = [sum(reward for _, _, reward in traj) for traj in trajectories]
mean_baseline_reward = np.mean(traj_rewards)

# Collect same number of trajectories of the same length; and their rewards
agent_trajectories = [self._sample_trajectory(initial=tr[0][0], length=mean_traj_len)
for tr in trajectories]
agent_traj_rewards = [sum(reward for _, _, reward in traj) for traj in agent_trajectories]
mean_agent_reward = np.mean(agent_traj_rewards)

# Compute improvement score of agent over baseline histories
improvement = mean_agent_reward - mean_baseline_reward
return improvement

def validate_history(self, history: OptHistory) -> Tuple[float, float]:
"""Validates history of mutated individuals against optimal policy."""
history_trajectories = ExperienceBuffer.unroll_trajectories(history)
return self._validate_against_optimal(history_trajectories)

def validate_agent(self,
graphs: Optional[Sequence[Graph]] = None,
experience: Optional[ExperienceBuffer] = None) -> Tuple[float, float]:
"""Validates agent policy against optimal policy on given graphs."""
if experience:
agent_steps = experience.retrieve_trajectories()
elif graphs:
agent_steps = [self._make_action_step(Individual(g)) for g in graphs]
else:
self._log.warning('Either graphs or experience must be provided for validation!')
return 0., 0.
return self._validate_against_optimal(trajectories=[agent_steps])

def _validate_against_optimal(self, trajectories: Sequence[GraphTrajectory]) -> Tuple[float, float]:
"""Validates a policy trajectories against optimal policy
that at each step always chooses the best action with max reward."""
reward_losses = []
reward_targets = []
for trajectory in trajectories:
inds, actions, rewards = unzip(trajectory)
_, best_actions, best_rewards = self._apply_best_action(inds)
reward_loss = self._compute_reward_loss(rewards, best_rewards)
reward_losses.append(reward_loss)
reward_targets.append(np.mean(best_rewards))
reward_loss = float(np.mean(reward_losses))
reward_target = float(np.mean(reward_targets))
return reward_loss, reward_target

@staticmethod
def _compute_reward_loss(rewards, optimal_rewards, normalized=False) -> float:
"""Returns difference (or deviation) from optimal reward.
When normalized, 0. means actual rewards match optimal rewards completely,
0.5 means they on average deviate by 50% from optimal rewards,
and 2.2 means they on average deviate by more than 2 times from optimal reward."""
reward_losses = np.subtract(optimal_rewards, rewards) # always positive
if normalized:
reward_losses = reward_losses / np.abs(optimal_rewards) \
if np.count_nonzero(optimal_rewards) == optimal_rewards.size else reward_losses
means = np.mean(reward_losses)
return float(means)

def _apply_best_action(self, inds: Sequence[Individual]) -> TrajectoryStep:
"""Returns greedily optimal mutation for given graph and associated reward."""
candidates = []
for ind in inds:
for mutation_id in self.agent.available_actions:
try:
values = self._apply_action(mutation_id, ind)
candidates.append(values)
except Exception as e:
self._log.warning(f'Eval error for mutation <{mutation_id}> '
f'on graph: {ind.graph.descriptive_id}:\n{e}')
continue
best_step = max(candidates, key=lambda step: step[-1])
return best_step

def _apply_action(self, action: Any, ind: Individual) -> TrajectoryStep:
new_graph, applied = self.mutation._adapt_and_apply_mutation(ind.graph, action)
fitness = self._eval_objective(new_graph) if applied else None
parent_op = ParentOperator(type_='mutation', operators=applied, parent_individuals=ind)
new_ind = Individual(new_graph, fitness=fitness, parent_operator=parent_op)

prev_fitness = ind.fitness or self._eval_objective(ind.graph)
if prev_fitness and fitness:
reward = prev_fitness.value - fitness.value
elif prev_fitness and not fitness:
reward = -1.
else:
reward = 0.
return new_ind, action, reward

def _eval_objective(self, graph: Graph) -> Fitness:
return self._adapter.adapt_func(self.objective)(graph)

def _make_action_step(self, ind: Individual) -> TrajectoryStep:
action = self.agent.choose_action(ind.graph)
return self._apply_action(action, ind)

def _sample_trajectory(self, initial: Individual, length: int) -> GraphTrajectory:
trajectory = []
past_ind = initial
for i in range(length):
next_ind, action, reward = self._make_action_step(past_ind)
trajectory.append((next_ind, action, reward))
past_ind = next_ind
return trajectory


def concat_lists(lists: Iterable[List]) -> List:
return reduce(operator.add, lists, [])
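
For reference, a minimal sketch of the offline pretraining flow described in the `AgentTrainer` docstring above, mirroring the new `pretrain_agent` helper in examples/molecule_search/experiment.py. The wrapper name `pretrain_from_histories` and the `results_dir` location are illustrative placeholders, not part of this commit.

from pathlib import Path

from golem.core.optimisers.adaptive.agent_trainer import AgentTrainer
from golem.core.optimisers.adaptive.history_collector import HistoryReader
from golem.core.optimisers.genetic.gp_optimizer import EvoGraphOptimizer
from golem.core.optimisers.objective import Objective


def pretrain_from_histories(optimizer: EvoGraphOptimizer, objective: Objective, results_dir: str) -> AgentTrainer:
    # The trainer wraps the optimizer's mutation operator and its adaptive (bandit) agent
    trainer = AgentTrainer(objective, optimizer.mutation, optimizer.mutation.agent)
    # HistoryReader loads the OptHistory files previously collected in results_dir
    histories = HistoryReader(Path(results_dir)).load_histories()
    # Fit the agent; with validate_each=1 every history is split into train/validation parts
    trainer.fit(histories=histories, validate_each=1)
    return trainer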
11 changes: 11 additions & 0 deletions golem/core/optimisers/adaptive/common_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Union, Hashable, Tuple, Sequence

from golem.core.dag.graph import Graph
from golem.core.optimisers.opt_history_objects.individual import Individual

ObsType = Union[Individual, Graph]
ActType = Hashable
# Trajectory step includes (past observation, action, reward)
TrajectoryStep = Tuple[Individual, ActType, float]
# Trajectory is a sequence of applied mutations and received rewards
GraphTrajectory = Sequence[TrajectoryStep]
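
For illustration, a trajectory is simply a sequence of such step tuples, and its total reward is the sum of the per-step rewards (as used by `AgentTrainer.validate_on_rollouts` above). The values below are stand-ins only: real steps hold `Individual` objects and mutation identifiers rather than strings.

# A GraphTrajectory is a sequence of TrajectoryStep tuples:
# (individual produced by the mutation, mutation action taken, reward = fitness gain)
trajectory = [
    ('individual_after_step_1', 'mutation_a', 0.05),
    ('individual_after_step_2', 'mutation_b', -0.01),
]
# Total reward accumulated over the trajectory
total_reward = sum(reward for _, _, reward in trajectory)  # 0.04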