Add ESM-GearNet (#16)
* init esm_gearnet & pre-training

* fix a bug in loading scheduler
Oxer11 authored Mar 12, 2023
1 parent 38c9ca5 commit e4d1bc7
Showing 4 changed files with 196 additions and 1 deletion.
74 changes: 74 additions & 0 deletions config/downstream/EC/ESM_gearnet.yaml
@@ -0,0 +1,74 @@
output_dir: ~/scratch/protein_output

dataset:
  class: EnzymeCommission
  path: ~/scratch/protein-datasets/
  test_cutoff: 0.95
  atom_feature: null
  bond_feature: null
  transform:
    class: Compose
    transforms:
      - class: ProteinView
        view: residue
      - class: TruncateProtein
        max_length: 550

task:
  class: MultipleBinaryClassification
  model:
    class: FusionNetwork
    sequence_model:
      class: ESM
      path: ~/scratch/protein-model-weights/esm-model-weights/
      model: ESM-1b
    structure_model:
      class: GearNet
      input_dim: 1280
      hidden_dims: [512, 512, 512, 512, 512, 512]
      batch_norm: True
      concat_hidden: True
      short_cut: True
      readout: 'sum'
      num_relation: 7
  graph_construction_model:
    class: GraphConstruction
    node_layers:
      - class: AlphaCarbonNode
    edge_layers:
      - class: SequentialEdge
        max_distance: 2
      - class: SpatialEdge
        radius: 10.0
        min_distance: 5
      - class: KNNEdge
        k: 10
        min_distance: 5
    edge_feature: gearnet
  criterion: bce
  num_mlp_layer: 3
  metric: ['auprc@micro', 'f1_max']

optimizer:
  class: AdamW
  lr: 1.0e-4
  weight_decay: 0

scheduler:
  class: ReduceLROnPlateau
  factor: 0.6
  patience: 5

engine:
  gpus: {{ gpus }}
  batch_size: 2
  log_interval: 1000

model_checkpoint: {{ ckpt }}

sequence_model_lr_ratio: 0.1

metric: f1_max

train:
  num_epoch: 50
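The `{{ gpus }}` and `{{ ckpt }}` entries above are template placeholders filled in at launch time. A minimal sketch of how such a templated config can be rendered and parsed, assuming a jinja2-plus-YAML loader in the style of GearNet's config handling (the helper name and signature here are illustrative, not this commit's API):

import jinja2
import yaml
from easydict import EasyDict

def load_config(cfg_file, context):
    # Render {{ gpus }} / {{ ckpt }} placeholders first, then parse the YAML.
    with open(cfg_file, "r") as f:
        rendered = jinja2.Template(f.read()).render(context)
    return EasyDict(yaml.safe_load(rendered))

# Usage sketch; the checkpoint path is illustrative.
cfg = load_config("config/downstream/EC/ESM_gearnet.yaml",
                  {"gpus": [0], "ckpt": "~/scratch/model.pth"})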
70 changes: 70 additions & 0 deletions config/pretrain/mc_esm_gearnet.yaml
@@ -0,0 +1,70 @@
output_dir: ~/scratch/protein_output

dataset:
  class: AlphaFoldDB
  path: ~/scratch/protein-datasets/alphafold
  species_start: 0
  species_end: 22
  # species_id: 3
  # split_id: 1
  atom_feature: null
  bond_feature: null
  transform:
    class: ProteinView
    view: residue

task:
  class: Unsupervised
  model:
    class: MultiviewContrast
    crop_funcs:
      - class: SubsequenceNode
        max_length: 50
    noise_funcs:
      - class: IdentityNode
      - class: RandomEdgeMask
        mask_rate: 0.15
    model:
      class: FusionNetwork
      sequence_model:
        class: ESM
        path: ~/scratch/protein-model-weights/esm-model-weights/
        model: ESM-1b
      structure_model:
        class: GearNet
        input_dim: 1280
        hidden_dims: [512, 512, 512, 512, 512, 512]
        batch_norm: True
        concat_hidden: True
        short_cut: True
        readout: 'sum'
        num_relation: 7
  graph_construction_model:
    class: GraphConstruction
    node_layers:
      - class: AlphaCarbonNode
    edge_layers:
      - class: SequentialEdge
        max_distance: 2
      - class: SpatialEdge
        radius: 10.0
        min_distance: 5
      - class: KNNEdge
        k: 10
        min_distance: 5
    edge_feature: gearnet

optimizer:
  class: Adam
  lr: 2.0e-4

engine:
  gpus: {{ gpus }}
  batch_size: 48
  log_interval: 100

save_interval: 5
fix_sequence_model: True

train:
  num_epoch: 50
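The MultiviewContrast task configured above builds two augmented views of each protein (a crop function plus a noise function per view), encodes both, and trains the encoder to align them. A conceptual sketch of the contrastive objective, assuming a standard InfoNCE loss over graph embeddings; this illustrates the recipe, it is not code from this commit:

import torch
import torch.nn.functional as F

def info_nce(z1, z2, tau=0.07):
    # z1, z2: (batch, dim) embeddings of two views of the same proteins.
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    logits = z1 @ z2.t() / tau            # pairwise cosine similarities, scaled
    labels = torch.arange(z1.size(0))     # diagonal pairs are the positives
    return F.cross_entropy(logits, labels)

loss = info_nce(torch.randn(8, 64), torch.randn(8, 64))  # toy dimensions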
27 changes: 27 additions & 0 deletions gearnet/model.py
@@ -135,3 +135,30 @@ def forward(self, graph, input, all_loss=None, metric=None):
"graph_feature": graph_feature,
"node_feature": node_feature
}


@R.register("models.FusionNetwork")
class FusionNetwork(nn.Module, core.Configurable):

def __init__(self, sequence_model, structure_model):
super(FusionNetwork, self).__init__()
self.sequence_model = sequence_model
self.structure_model = structure_model
self.output_dim = sequence_model.output_dim + structure_model.output_dim

def forward(self, graph, input, all_loss=None, metric=None):
output1 = self.sequence_model(graph, input, all_loss, metric)
node_output1 = output1.get("node_feature", output1.get("residue_feature"))
output2 = self.structure_model(graph, node_output1, all_loss, metric)
node_output2 = output2.get("node_feature", output2.get("residue_feature"))

node_feature = torch.cat([node_output1, node_output2], dim=-1)
graph_feature = torch.cat([
output1['graph_feature'],
output2['graph_feature']
], dim=-1)

return {
"graph_feature": graph_feature,
"node_feature": node_feature
}
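For reference, a minimal usage sketch of FusionNetwork, assuming torchdrug's ESM and GearNet model classes and the hyperparameters from the configs above; the import path and constructor arguments are assumptions based on torchdrug, not part of this commit:

from torchdrug import models

sequence_model = models.ESM(
    path="~/scratch/protein-model-weights/esm-model-weights/",
    model="ESM-1b")
structure_model = models.GearNet(
    input_dim=1280, hidden_dims=[512] * 6, num_relation=7,
    batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")
model = FusionNetwork(sequence_model, structure_model)
# model.output_dim == sequence_model.output_dim + structure_model.output_dim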
26 changes: 25 additions & 1 deletion util.py
@@ -129,6 +129,22 @@ def build_downstream_solver(cfg, dataset):
        ]
        optimizer = core.Configurable.load_config_dict(cfg.optimizer)
        solver.optimizer = optimizer
    elif "sequence_model_lr_ratio" in cfg:
        assert cfg.task.model["class"] == "FusionNetwork"
        # Train the ESM branch with a smaller learning rate than GearNet and the MLP head.
        cfg.optimizer.params = [
            {'params': solver.model.model.sequence_model.parameters(), 'lr': cfg.optimizer.lr * cfg.sequence_model_lr_ratio},
            {'params': solver.model.model.structure_model.parameters(), 'lr': cfg.optimizer.lr},
            {'params': solver.model.mlp.parameters(), 'lr': cfg.optimizer.lr}
        ]
        optimizer = core.Configurable.load_config_dict(cfg.optimizer)
        solver.optimizer = optimizer

    # Rebuild the scheduler so it tracks the newly constructed optimizer.
    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, **cfg.scheduler)
    elif scheduler is not None:
        cfg.scheduler.optimizer = optimizer
        scheduler = core.Configurable.load_config_dict(cfg.scheduler)
    solver.scheduler = scheduler

    if cfg.get("checkpoint") is not None:
        solver.load(cfg.checkpoint)
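This hunk is the commit message's "fix a bug in loading scheduler": once solver.optimizer is replaced by the per-group-LR optimizer, a scheduler created earlier still holds a reference to the old optimizer and would step the wrong one, so it must be re-created. A standalone sketch of the failure mode, with illustrative names:

import torch
from torch import nn
from torch.optim import lr_scheduler

params = [nn.Parameter(torch.zeros(1))]
old_optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = lr_scheduler.ReduceLROnPlateau(old_optimizer, factor=0.6, patience=5)

new_optimizer = torch.optim.AdamW(params, lr=1e-4)  # swapped-in optimizer
assert scheduler.optimizer is old_optimizer          # stale reference: the bug
scheduler = lr_scheduler.ReduceLROnPlateau(new_optimizer, factor=0.6, patience=5)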
@@ -149,7 +165,15 @@ def build_pretrain_solver(cfg, dataset):
    logger.warning("#dataset: %d" % (len(dataset)))

    task = core.Configurable.load_config_dict(cfg.task)
    cfg.optimizer.params = task.parameters()
    if "fix_sequence_model" in cfg:
        if cfg.task["class"] == "Unsupervised":
            model_dict = cfg.task.model.model
        else:
            model_dict = cfg.task.model
        assert model_dict["class"] == "FusionNetwork"
        # Freeze the ESM branch; only GearNet parameters are pre-trained.
        for p in task.model.model.sequence_model.parameters():
            p.requires_grad = False
    cfg.optimizer.params = [p for p in task.parameters() if p.requires_grad]
    optimizer = core.Configurable.load_config_dict(cfg.optimizer)
    solver = core.Engine(task, dataset, None, None, optimizer, **cfg.engine)

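The fix_sequence_model flag freezes the ESM branch so only GearNet is updated during pre-training; the frozen branch still runs in the forward pass but contributes no gradients, and only trainable parameters reach the optimizer. A minimal sketch of the pattern with a toy module (illustrative, not this repo's code):

import torch
from torch import nn

model = nn.ModuleDict({"sequence_model": nn.Linear(8, 8),
                       "structure_model": nn.Linear(8, 8)})
for p in model["sequence_model"].parameters():
    p.requires_grad = False                         # freeze the ESM-like branch
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=2.0e-4)  # structure branch only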