Skip to content

Commit

Permalink
Diffusion control with cross attention (#575)
Browse files Browse the repository at this point in the history
* Control diffusion models
* brain Tmaps
* kronecker merge layers
* add av and lmdb to docker
  • Loading branch information
lucidtronix authored Oct 9, 2024
1 parent 230466d commit 01691eb
Show file tree
Hide file tree
Showing 36 changed files with 1,598 additions and 1,153 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/publish-to-pypi.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
name: Publish Python 🐍 distribution 📦 to PyPI

on:
on:
push:
tags:
- '*' # Push events to every tag not containing /
tags:
- '*' # Push events to every tag not containing /

jobs:
build:
Expand Down Expand Up @@ -56,4 +56,4 @@ jobs:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@release/v1
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include docker/vm_boot_images/config/tensorflow-requirements.txt
4 changes: 3 additions & 1 deletion docker/vm_boot_images/config/tensorflow-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ google-cloud-storage
umap-learn[plot]
neurite
voxelmorph
pystrum
pystrum
av
lmdb
44 changes: 29 additions & 15 deletions ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,18 @@
import importlib
import numpy as np
import multiprocessing
from typing import Set, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Set, Dict, List, Optional, Tuple

from ml4h.logger import load_config
from ml4h.TensorMap import TensorMap, TimeSeriesOrder
from ml4h.defines import IMPUTATION_RANDOM, IMPUTATION_MEAN
from ml4h.tensormap.mgb.dynamic import make_mgb_dynamic_tensor_maps
from ml4h.tensormap.tensor_map_maker import generate_categorical_tensor_map_from_file
from ml4h.models.legacy_models import parent_sort, BottleneckType, check_no_bottleneck
from ml4h.models.legacy_models import parent_sort, check_no_bottleneck
from ml4h.tensormap.tensor_map_maker import make_test_tensor_maps, generate_random_pixel_as_text_tensor_maps
from ml4h.models.legacy_models import NORMALIZATION_CLASSES, CONV_REGULARIZATION_CLASSES, DENSE_REGULARIZATION_CLASSES
from ml4h.tensormap.tensor_map_maker import generate_continuous_tensor_map_from_file, generate_random_text_tensor_maps

BOTTLENECK_STR_TO_ENUM = {
'flatten_restructure': BottleneckType.FlattenRestructure,
'global_average_pool': BottleneckType.GlobalAveragePoolStructured,
'variational': BottleneckType.Variational,
'no_bottleneck': BottleneckType.NoBottleNeck,
}
from ml4h.tensormap.tensor_map_maker import generate_categorical_tensor_map_from_file, generate_latent_tensor_map_from_file


def parse_args():
Expand Down Expand Up @@ -104,7 +97,18 @@ def parse_args():
'--categorical_file_columns', nargs='*', default=[],
help='Column headers in file from which categorical TensorMap(s) will be made.',
)

parser.add_argument(
'--latent_input_file', default=None, help=
'Path to a file containing latent space values from which an input TensorMap will be made.'
'Note that setting this argument has the effect of linking the first input_tensors'
'argument to the TensorMap made from this file.',
)
parser.add_argument(
'--latent_output_file', default=None, help=
'Path to a file containing latent space values from which an input TensorMap will be made.'
'Note that setting this argument has the effect of linking the first output_tensors'
'argument to the TensorMap made from this file.',
)
parser.add_argument(
'--categorical_field_ids', nargs='*', default=[], type=int,
help='List of field ids from which input features will be collected.',
Expand Down Expand Up @@ -212,13 +216,17 @@ def parse_args():
'--max_parameters', default=50000000, type=int,
help='Maximum number of trainable parameters in a model during hyperparameter optimization.',
)
parser.add_argument('--bottleneck_type', type=str, default=list(BOTTLENECK_STR_TO_ENUM)[0], choices=list(BOTTLENECK_STR_TO_ENUM))
parser.add_argument('--hidden_layer', default='embed', help='Name of a hidden layer for inspections.')
parser.add_argument('--language_layer', default='ecg_rest_text', help='Name of TensorMap for learning language models (eg train_char_model).')
parser.add_argument('--language_prefix', default='ukb_ecg_rest', help='Path prefix for a TensorMap to learn language models (eg train_char_model)')
parser.add_argument('--text_window', default=32, type=int, help='Size of text window in number of tokens.')
parser.add_argument('--hd5_as_text', default=None, help='Path prefix for a TensorMap to learn language models from flattened HD5 arrays.')
parser.add_argument('--attention_heads', default=4, type=int, help='Number of attention heads in Multi-headed attention layers')
parser.add_argument(
'--attention_window', default=4, type=int,
help='For diffusion models, when U-Net representation size is smaller than attention_window '
'Cross-Attention is applied',
)
parser.add_argument(
'--transformer_size', default=32, type=int,
help='Number of output neurons in Transformer encoders and decoders, '
Expand Down Expand Up @@ -506,6 +514,10 @@ def _process_args(args):
args.tensor_maps_in.append(input_map)
args.tensor_maps_out.append(output_map)

if args.latent_input_file is not None:
args.tensor_maps_in.append(
generate_latent_tensor_map_from_file(args.latent_input_file, args.input_tensors.pop(0)),
)
args.tensor_maps_in.extend([tensormap_lookup(it, args.tensormap_prefix) for it in args.input_tensors])

if args.continuous_file is not None:
Expand All @@ -530,13 +542,15 @@ def _process_args(args):
args.output_tensors.pop(0),
),
)
if args.latent_output_file is not None:
args.tensor_maps_out.append(
generate_latent_tensor_map_from_file(args.latent_output_file, args.output_tensors.pop(0)),
)
args.tensor_maps_out.extend([tensormap_lookup(ot, args.tensormap_prefix) for ot in args.output_tensors])
args.tensor_maps_out = parent_sort(args.tensor_maps_out)
args.tensor_maps_protected = [tensormap_lookup(it, args.tensormap_prefix) for it in args.protected_tensors]

args.bottleneck_type = BOTTLENECK_STR_TO_ENUM[args.bottleneck_type]
if args.bottleneck_type == BottleneckType.NoBottleNeck:
check_no_bottleneck(args.u_connect, args.tensor_maps_out)
check_no_bottleneck(args.u_connect, args.tensor_maps_out)

if args.learning_rate_schedule is not None and args.patience < args.epochs:
raise ValueError(f'learning_rate_schedule is not compatible with ReduceLROnPlateau. Set patience > epochs.')
Expand Down
2 changes: 1 addition & 1 deletion ml4h/ml4ht_integration/tensor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,4 @@ def infer_from_dataloader(dataloader, model, tensor_maps_out, max_batches=125000
except StopIteration:
print('loaded all batches')
break
return pd.DataFrame.from_dict(space_dict)
return pd.DataFrame.from_dict(space_dict)
39 changes: 35 additions & 4 deletions ml4h/models/basic_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self,
*,
tensor_map: TensorMap,
dense_layers: List[int] = [32],
dense_layers: List[int] = [256],
activation: str = 'swish',
dense_normalize: str = None,
dense_regularize: str = None,
Expand All @@ -62,9 +62,9 @@ def __init__(
)

def can_apply(self):
return self.tensor_map.axes() == 1 and not self.tensor_map.is_embedding()
return self.tensor_map.axes() == 1# and not self.tensor_map.is_embedding()

def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]]) -> Tensor:
def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor:
if not self.can_apply():
return x
y = self.fully_connected(x, intermediates)
Expand Down Expand Up @@ -156,6 +156,7 @@ def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]]) -> T
intermediates[self.tensor_map].append(x)
return x


class ModelAsBlock(Block):
"""Takes a serialized model and applies it, can be used to encode or decode Tensors"""
def __init__(
Expand Down Expand Up @@ -231,7 +232,6 @@ def __init__(
self.dense = Dense(units, activation=activation)
self.final_layer = Dense(units=tensor_map.shape[0], name=tensor_map.output_name(), activation=tensor_map.activation)


def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor:
x = self.dense(self.drop(x))
if self.tensor_map.is_continuous():
Expand Down Expand Up @@ -259,3 +259,34 @@ def get_config(self):
"scalar": self.scalar,
})
return config


class IdentityEncoderBlock(Block):
def __init__(
self,
tensor_map: TensorMap,
**kwargs,
):
self.tensor_map = tensor_map

def can_apply(self):
return True

def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor:
intermediates[self.tensor_map].append(x)
return x


class IdentityDecoderBlock(Block):
def __init__(
self,
tensor_map: TensorMap,
**kwargs,
):
self.tensor_map = tensor_map

def can_apply(self):
return True

def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor:
return intermediates[self.tensor_map][-1]
Loading

0 comments on commit 01691eb

Please sign in to comment.