From a8e9e4314812f46f1d3c872e210c459d6d4a18e5 Mon Sep 17 00:00:00 2001 From: Johnnykoch02 Date: Fri, 12 Apr 2024 03:27:45 -0400 Subject: [PATCH 1/4] initial commit --- .gitignore | 8 + README.md | 419 +----- app/main.py | 4 +- app/main_distributed.py | 2 +- app/vjepa/train.py | 16 +- app/vjepa/transforms.py | 4 +- app/vjepa/utils.py | 10 +- {src => build/lib}/datasets/data_manager.py | 4 +- {src => build/lib}/datasets/image_dataset.py | 0 .../lib}/datasets/utils/video/functional.py | 0 .../lib}/datasets/utils/video/randaugment.py | 0 .../lib}/datasets/utils/video/randerase.py | 0 .../lib}/datasets/utils/video/transforms.py | 4 +- .../datasets/utils/video/volume_transforms.py | 0 .../lib}/datasets/utils/weighted_sampler.py | 0 {src => build/lib}/datasets/video_dataset.py | 8 +- build/lib/jepa_src/__init__.py | 0 build/lib/jepa_src/datasets/__init__.py | 0 build/lib/jepa_src/datasets/data_manager.py | 91 ++ build/lib/jepa_src/datasets/image_dataset.py | 79 ++ build/lib/jepa_src/datasets/utils/__init__.py | 0 .../jepa_src/datasets/utils/video/__init__.py | 0 .../datasets/utils/video/functional.py | 96 ++ .../datasets/utils/video/randaugment.py | 518 ++++++++ .../datasets/utils/video/randerase.py | 180 +++ .../datasets/utils/video/transforms.py | 1184 +++++++++++++++++ .../datasets/utils/video/volume_transforms.py | 151 +++ .../datasets/utils/weighted_sampler.py | 97 ++ build/lib/jepa_src/datasets/video_dataset.py | 272 ++++ build/lib/jepa_src/masks/__init__.py | 0 {src => build/lib/jepa_src}/masks/default.py | 0 .../lib/jepa_src}/masks/multiblock3d.py | 0 .../lib/jepa_src}/masks/random_tube.py | 0 {src => build/lib/jepa_src}/masks/utils.py | 0 build/lib/jepa_src/models/__init__.py | 0 .../lib/jepa_src}/models/attentive_pooler.py | 4 +- .../lib/jepa_src}/models/predictor.py | 8 +- build/lib/jepa_src/models/utils/__init__.py | 0 .../lib/jepa_src}/models/utils/modules.py | 0 .../lib/jepa_src}/models/utils/multimask.py | 0 .../lib/jepa_src}/models/utils/patch_embed.py | 0 .../lib/jepa_src}/models/utils/pos_embs.py | 0 .../jepa_src}/models/vision_transformer.py | 10 +- build/lib/jepa_src/utils/__init__.py | 0 .../lib/jepa_src}/utils/distributed.py | 0 {src => build/lib/jepa_src}/utils/logging.py | 0 .../lib/jepa_src}/utils/monitoring.py | 0 .../lib/jepa_src}/utils/schedulers.py | 0 {src => build/lib/jepa_src}/utils/tensors.py | 0 build/lib/masks/default.py | 20 + build/lib/masks/multiblock3d.py | 203 +++ build/lib/masks/random_tube.py | 117 ++ build/lib/masks/utils.py | 23 + build/lib/models/attentive_pooler.py | 136 ++ build/lib/models/predictor.py | 246 ++++ build/lib/models/utils/modules.py | 183 +++ build/lib/models/utils/multimask.py | 48 + build/lib/models/utils/patch_embed.py | 57 + build/lib/models/utils/pos_embs.py | 99 ++ build/lib/models/vision_transformer.py | 307 +++++ build/lib/utils/distributed.py | 113 ++ build/lib/utils/logging.py | 118 ++ build/lib/utils/monitoring.py | 175 +++ build/lib/utils/schedulers.py | 76 ++ build/lib/utils/tensors.py | 71 + build/lib/vjepa_encoder/__init__.py | 0 build/lib/vjepa_encoder/vision_encoder.py | 327 +++++ build/lib/vjepa_encoder/vjepa/__init__.py | 0 build/lib/vjepa_encoder/vjepa/train.py | 586 ++++++++ build/lib/vjepa_encoder/vjepa/transforms.py | 153 +++ build/lib/vjepa_encoder/vjepa/utils.py | 210 +++ evals/image_classification_frozen/eval.py | 12 +- evals/main.py | 2 +- evals/video_classification_frozen/eval.py | 12 +- evals/video_classification_frozen/utils.py | 10 +- fair_documentation.md | 407 ++++++ huggingface/README.md | 78 ++ 
huggingface/demo_jepa_encoder.py | 14 + huggingface/params-encoder.yaml | 89 ++ jepa_encoder.egg-info/PKG-INFO | 17 + jepa_encoder.egg-info/SOURCES.txt | 10 + jepa_encoder.egg-info/dependency_links.txt | 1 + jepa_encoder.egg-info/requires.txt | 11 + jepa_encoder.egg-info/top_level.txt | 1 + jepa_src/__init__.py | 0 jepa_src/datasets/__init__.py | 0 jepa_src/datasets/data_manager.py | 91 ++ jepa_src/datasets/image_dataset.py | 79 ++ jepa_src/datasets/utils/__init__.py | 0 jepa_src/datasets/utils/video/__init__.py | 0 jepa_src/datasets/utils/video/functional.py | 96 ++ jepa_src/datasets/utils/video/randaugment.py | 518 ++++++++ jepa_src/datasets/utils/video/randerase.py | 180 +++ jepa_src/datasets/utils/video/transforms.py | 1184 +++++++++++++++++ .../datasets/utils/video/volume_transforms.py | 151 +++ jepa_src/datasets/utils/weighted_sampler.py | 97 ++ jepa_src/datasets/video_dataset.py | 272 ++++ jepa_src/masks/__init__.py | 0 jepa_src/masks/default.py | 20 + jepa_src/masks/multiblock3d.py | 203 +++ jepa_src/masks/random_tube.py | 117 ++ jepa_src/masks/utils.py | 23 + jepa_src/models/__init__.py | 0 jepa_src/models/attentive_pooler.py | 136 ++ jepa_src/models/predictor.py | 246 ++++ jepa_src/models/utils/__init__.py | 0 jepa_src/models/utils/modules.py | 183 +++ jepa_src/models/utils/multimask.py | 48 + jepa_src/models/utils/patch_embed.py | 57 + jepa_src/models/utils/pos_embs.py | 99 ++ jepa_src/models/vision_transformer.py | 307 +++++ jepa_src/utils/__init__.py | 0 jepa_src/utils/distributed.py | 113 ++ jepa_src/utils/logging.py | 118 ++ jepa_src/utils/monitoring.py | 175 +++ jepa_src/utils/schedulers.py | 76 ++ jepa_src/utils/tensors.py | 71 + requirements.txt | 2 - setup.py | 11 +- vjepa_encoder.egg-info/PKG-INFO | 19 + vjepa_encoder.egg-info/SOURCES.txt | 47 + vjepa_encoder.egg-info/dependency_links.txt | 1 + vjepa_encoder.egg-info/requires.txt | 11 + vjepa_encoder.egg-info/top_level.txt | 2 + vjepa_encoder/__init__.py | 0 vjepa_encoder/vision_encoder.py | 327 +++++ vjepa_encoder/vjepa/__init__.py | 0 vjepa_encoder/vjepa/train.py | 586 ++++++++ vjepa_encoder/vjepa/transforms.py | 153 +++ vjepa_encoder/vjepa/utils.py | 210 +++ 130 files changed, 12697 insertions(+), 433 deletions(-) rename {src => build/lib}/datasets/data_manager.py (94%) rename {src => build/lib}/datasets/image_dataset.py (100%) rename {src => build/lib}/datasets/utils/video/functional.py (100%) rename {src => build/lib}/datasets/utils/video/randaugment.py (100%) rename {src => build/lib}/datasets/utils/video/randerase.py (100%) rename {src => build/lib}/datasets/utils/video/transforms.py (99%) rename {src => build/lib}/datasets/utils/video/volume_transforms.py (100%) rename {src => build/lib}/datasets/utils/weighted_sampler.py (100%) rename {src => build/lib}/datasets/video_dataset.py (97%) create mode 100644 build/lib/jepa_src/__init__.py create mode 100644 build/lib/jepa_src/datasets/__init__.py create mode 100644 build/lib/jepa_src/datasets/data_manager.py create mode 100644 build/lib/jepa_src/datasets/image_dataset.py create mode 100644 build/lib/jepa_src/datasets/utils/__init__.py create mode 100644 build/lib/jepa_src/datasets/utils/video/__init__.py create mode 100644 build/lib/jepa_src/datasets/utils/video/functional.py create mode 100644 build/lib/jepa_src/datasets/utils/video/randaugment.py create mode 100644 build/lib/jepa_src/datasets/utils/video/randerase.py create mode 100644 build/lib/jepa_src/datasets/utils/video/transforms.py create mode 100644 
build/lib/jepa_src/datasets/utils/video/volume_transforms.py create mode 100644 build/lib/jepa_src/datasets/utils/weighted_sampler.py create mode 100644 build/lib/jepa_src/datasets/video_dataset.py create mode 100644 build/lib/jepa_src/masks/__init__.py rename {src => build/lib/jepa_src}/masks/default.py (100%) rename {src => build/lib/jepa_src}/masks/multiblock3d.py (100%) rename {src => build/lib/jepa_src}/masks/random_tube.py (100%) rename {src => build/lib/jepa_src}/masks/utils.py (100%) create mode 100644 build/lib/jepa_src/models/__init__.py rename {src => build/lib/jepa_src}/models/attentive_pooler.py (97%) rename {src => build/lib/jepa_src}/models/predictor.py (97%) create mode 100644 build/lib/jepa_src/models/utils/__init__.py rename {src => build/lib/jepa_src}/models/utils/modules.py (100%) rename {src => build/lib/jepa_src}/models/utils/multimask.py (100%) rename {src => build/lib/jepa_src}/models/utils/patch_embed.py (100%) rename {src => build/lib/jepa_src}/models/utils/pos_embs.py (100%) rename {src => build/lib/jepa_src}/models/vision_transformer.py (96%) create mode 100644 build/lib/jepa_src/utils/__init__.py rename {src => build/lib/jepa_src}/utils/distributed.py (100%) rename {src => build/lib/jepa_src}/utils/logging.py (100%) rename {src => build/lib/jepa_src}/utils/monitoring.py (100%) rename {src => build/lib/jepa_src}/utils/schedulers.py (100%) rename {src => build/lib/jepa_src}/utils/tensors.py (100%) create mode 100644 build/lib/masks/default.py create mode 100644 build/lib/masks/multiblock3d.py create mode 100644 build/lib/masks/random_tube.py create mode 100644 build/lib/masks/utils.py create mode 100644 build/lib/models/attentive_pooler.py create mode 100644 build/lib/models/predictor.py create mode 100644 build/lib/models/utils/modules.py create mode 100644 build/lib/models/utils/multimask.py create mode 100644 build/lib/models/utils/patch_embed.py create mode 100644 build/lib/models/utils/pos_embs.py create mode 100644 build/lib/models/vision_transformer.py create mode 100644 build/lib/utils/distributed.py create mode 100644 build/lib/utils/logging.py create mode 100644 build/lib/utils/monitoring.py create mode 100644 build/lib/utils/schedulers.py create mode 100644 build/lib/utils/tensors.py create mode 100644 build/lib/vjepa_encoder/__init__.py create mode 100644 build/lib/vjepa_encoder/vision_encoder.py create mode 100644 build/lib/vjepa_encoder/vjepa/__init__.py create mode 100644 build/lib/vjepa_encoder/vjepa/train.py create mode 100644 build/lib/vjepa_encoder/vjepa/transforms.py create mode 100644 build/lib/vjepa_encoder/vjepa/utils.py create mode 100644 fair_documentation.md create mode 100644 huggingface/README.md create mode 100644 huggingface/demo_jepa_encoder.py create mode 100644 huggingface/params-encoder.yaml create mode 100644 jepa_encoder.egg-info/PKG-INFO create mode 100644 jepa_encoder.egg-info/SOURCES.txt create mode 100644 jepa_encoder.egg-info/dependency_links.txt create mode 100644 jepa_encoder.egg-info/requires.txt create mode 100644 jepa_encoder.egg-info/top_level.txt create mode 100644 jepa_src/__init__.py create mode 100644 jepa_src/datasets/__init__.py create mode 100644 jepa_src/datasets/data_manager.py create mode 100644 jepa_src/datasets/image_dataset.py create mode 100644 jepa_src/datasets/utils/__init__.py create mode 100644 jepa_src/datasets/utils/video/__init__.py create mode 100644 jepa_src/datasets/utils/video/functional.py create mode 100644 jepa_src/datasets/utils/video/randaugment.py create mode 100644 
jepa_src/datasets/utils/video/randerase.py create mode 100644 jepa_src/datasets/utils/video/transforms.py create mode 100644 jepa_src/datasets/utils/video/volume_transforms.py create mode 100644 jepa_src/datasets/utils/weighted_sampler.py create mode 100644 jepa_src/datasets/video_dataset.py create mode 100644 jepa_src/masks/__init__.py create mode 100644 jepa_src/masks/default.py create mode 100644 jepa_src/masks/multiblock3d.py create mode 100644 jepa_src/masks/random_tube.py create mode 100644 jepa_src/masks/utils.py create mode 100644 jepa_src/models/__init__.py create mode 100644 jepa_src/models/attentive_pooler.py create mode 100644 jepa_src/models/predictor.py create mode 100644 jepa_src/models/utils/__init__.py create mode 100644 jepa_src/models/utils/modules.py create mode 100644 jepa_src/models/utils/multimask.py create mode 100644 jepa_src/models/utils/patch_embed.py create mode 100644 jepa_src/models/utils/pos_embs.py create mode 100644 jepa_src/models/vision_transformer.py create mode 100644 jepa_src/utils/__init__.py create mode 100644 jepa_src/utils/distributed.py create mode 100644 jepa_src/utils/logging.py create mode 100644 jepa_src/utils/monitoring.py create mode 100644 jepa_src/utils/schedulers.py create mode 100644 jepa_src/utils/tensors.py create mode 100644 vjepa_encoder.egg-info/PKG-INFO create mode 100644 vjepa_encoder.egg-info/SOURCES.txt create mode 100644 vjepa_encoder.egg-info/dependency_links.txt create mode 100644 vjepa_encoder.egg-info/requires.txt create mode 100644 vjepa_encoder.egg-info/top_level.txt create mode 100644 vjepa_encoder/__init__.py create mode 100644 vjepa_encoder/vision_encoder.py create mode 100644 vjepa_encoder/vjepa/__init__.py create mode 100644 vjepa_encoder/vjepa/train.py create mode 100644 vjepa_encoder/vjepa/transforms.py create mode 100644 vjepa_encoder/vjepa/utils.py diff --git a/.gitignore b/.gitignore index 3bb2efd..5e71834 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,10 @@ .*.swp *.pyc +*.tar + +bin/ +dist/ +.vscode/ +logs/ + +jepa_src/jepa.egg-info/ \ No newline at end of file diff --git a/README.md b/README.md index a3579e1..1fc98b8 100644 --- a/README.md +++ b/README.md @@ -1,407 +1,80 @@ -# V-JEPA: Video Joint Embedding Predictive Architecture + VJEPA Encoder -Official PyTorch codebase for the _video joint-embedding predictive architecture_, V-JEPA, a method for self-supervised learning of visual representations from video. +The VJEPA Encoder is a Python package that provides an implementation of the encoder component from JEPA (Joint Embedding Predictive Architecture), proposed by Facebook AI Research. The encoder is designed to extract meaningful representations from visual data. I do not own the rights or lay claim to the copyright of this software. This package is an adaptation of `facebookresearch/jepa` that makes the JEPA architecture, built on Vision Transformers, easier to use. 
-**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** +## Installation -Adrien Bardes, Quentin Garrido, Jean Ponce, Xinlei Chen, Michael Rabbat, Yann LeCun, Mahmoud Assran*, Nicolas Ballas* +To install the VJEPA Encoder package, you can use pip: -[\[Blog\]](https://ai.meta.com/blog/v-jepa-yann-lecun-ai-model-video-joint-embedding-predictive-architecture/) -[\[Paper\]](https://ai.meta.com/research/publications/revisiting-feature-prediction-for-learning-visual-representations-from-video/) -[\[Yannic Kilcher's Video\]](https://www.youtube.com/watch?v=7UkJPwz_N_0) - -V-JEPA models are trained by passively watching video pixels from the VideoMix2M dataset, and produce versatile visual representations that perform well on downstream video and image tasks, without adaption of the model’s parameters; e.g., using a frozen backbone and only a light-weight task-specific attentive probe. - -## Method -V-JEPA pretraining is based solely on an unsupervised feature prediction objective, and does not utilize pretrained image encoders, text, negative examples, human annotations, or pixel-level reconstruction. - - - -      - - - - -## Visualizations -As opposed to generative methods that have a pixel decoder, V-JEPA has a predictor that makes predictions in latent space. -We train a conditional diffusion model to decode the V-JEPA feature-space predictions to interpretable pixels; the pretrained V-JEPA encoder and predictor networks are kept frozen in this process. -The decoder is only fed the representations predicted for the missing regions of the video, and does not have access to the unmasked regions of the video. - -The V-JEPA feature predictions are indeed grounded, and exhibit spatio-temporal consistency with the unmasked regions of the video. - - -
- - - - -
- -## MODEL ZOO - -#### Pretrained models - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelpatch sizeresolutioniterationsbatch sizedatadownload
ViT-L2x16x16224x22490K3072VideoMix2Mcheckpointconfigs
ViT-H2x16x16224x22490K3072VideoMix2Mcheckpointconfigs
ViT-H2x16x16384x38490K2400VideoMix2Mcheckpointconfigs
- -#### K400 Attentive probes - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelresolutionaccuracy (16x8x3)download
ViT-L/16224x22480.8attentive probe checkpointconfigs
ViT-H/16224x22482.0attentive probe checkpointconfigs
ViT-H/16384x38481.9attentive probe checkpointconfigs
- -#### SSv2 Attentive probes - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelresolutionaccuracy (16x2x3)download
ViT-L/16224x22469.5attentive probe checkpointconfigs
ViT-H/16224x22471.4attentive probe checkpointconfigs
ViT-H/16384x38472.2attentive probe checkpointconfigs
- -#### ImageNet1K Attentive probes - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelresolutionaccuracydownload
ViT-L/16224x22474.8attentive probe checkpointconfigs
ViT-H/16224x22475.9attentive probe checkpointconfigs
ViT-H/16384x38477.4attentive probe checkpointconfigs
- -#### Places205 Attentive probes - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelresolutionaccuracydownload
ViT-L/16224x22460.3attentive probe checkpointconfigs
ViT-H/16224x22461.7attentive probe checkpointconfigs
ViT-H/16384x38462.8attentive probe checkpointconfigs
- -#### iNat21 Attentive probes +``` +pip install vjepa_encoder +``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
modelresolutionaccuracydownload
ViT-L/16224x22467.8attentive probe checkpointconfigs
ViT-H/16224x22467.9attentive probe checkpointconfigs
ViT-H/16384x38472.6attentive probe checkpointconfigs
+## Usage -## Code Structure +To use the VJEPA Encoder in your Python code, you can import it as follows: -**Config files:** -All experiment parameters are specified in config files (as opposed to command-line arguments). See the [configs/](configs/) directory for example config files. Note, before launching an experiment, you must update the paths in the config file to point to your own directories, indicating where to save the logs and checkpoints and where to find the training data. +```python +from vjepa_encoder.vision_encoder import JepaEncoder +``` +### Loading the Encoder -``` -. -├── app # the only place where training loops are allowed -│ ├── vjepa # Video JEPA pre-training -│ ├── main_distributed.py # entrypoint for launching app on slurm cluster -│ └── main.py # entrypoint for launching app locally on your machine for debugging -├── evals # the only place where evaluation of 'apps' are allowed -│ ├── image_classification # training an attentive probe for image classification with frozen backbone -│ ├── video_classification # training an attentive probe for video classification with frozen backbone -│ ├── main_distributed.py # entrypoint for launching distributed evaluations on slurm cluster -│ └── main.py # entrypoint for launching evaluations locally on your machine for debugging -├── src # the package -│ ├── datasets # datasets, data loaders, ... -│ ├── models # model definitions -│ ├── masks # mask collators, masking utilities, ... -│ └── utils # shared utilities -└── configs # the only place where config files are allowed (specify experiment params for app/eval runs) - ├── evals # configs for launching vjepa frozen evaluations - └── pretrain # configs for launching vjepa pretraining +To load the pre-trained encoder, you can use the `load_model` function: +```python +config_file_path = "./params-encoder.yaml" +devices = ["cuda:0"] +encoder = JepaEncoder.load_model(config_file_path, devices) ``` -## Data preparation +- `config_file_path`: Path to the configuration file (YAML) containing the model settings. +- `devices`: List of devices (e.g., `['cuda:0']`) to use for distributed training. If not provided, the model will be loaded on the CPU. -### Video Datasets -V-JEPA pretraining and evaluations work with many standard video formats. -To make a video dataset compatible with the V-JEPA codebase, you simply need to create a `.csv` file with the following format and then specify the path to this CSV file in your config. -``` -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -... -``` -Since V-JEPA is entirely unsupervised, the pretraining code will disregard the `$integer_class_label` in the CSV file. -Thus, feel free to put a random value in this column. -However, if you wish to run a supervised video classification evaluation on your video dataset, you must replace ```$integer_class_label``` with the ground truth label for each video. +### Preprocessing Data -### Image Datasets -We use the standard PyTorch ```ImageFolder``` class in our image classification evals. -Thus, to set up an image dataset for the image classification evaluation, first create a directory to store your image datasets ```$your_directory_containing_image_datasets```. -Next, download your image datasets into this directory in a format compatible with [PyTorch ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html). 
+The VJEPA Encoder provides a `preprocess_data` function to preprocess input data before feeding it to the encoder: -For example, suppose we have a directory called ``my_image_datasets``. We would then download our image datasets into this directory so that we end up with the following file tree -``` -. -└── /my_image_datasets/ # where we store image datasets - ├── places205/121517/pytorch/ # Places205 - │ └── [...] - ├── iNaturalist-2021/110421/ # iNaturalist21 - │ └── [...] - ├── [...] # Other Image Datasets - │ └── [...] - └── imagenet_full_size/061417/ # ImageNet1k - └── train - │ ├── $class_1 - │ │ ├── xxx.[png, jpeg, etc.] - │ │ ├── [...] - │ │ └── xxz.[png, jpeg, etc.] - │ ├── [...] - │ └── $class_n - │ ├── abc.[png, jpeg, etc.] - │ ├── [...] - │ └── abz.[png, jpeg, etc.] - └── val - ├── $class_1 - │ ├── xxx.[png, jpeg, etc.] - │ ├── [...] - │ └── xxz.[png, jpeg, etc.] - ├── [...] - └── $class_n - ├── abc.[png, jpeg, etc.] - ├── [...] - └── abz.[png, jpeg, etc.] +```python +preprocessed_data = encoder.preprocess_data(input_data) ``` +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. -## Launching V-JEPA pretraining +### Embedding Images -### Local training -If you wish to debug your code or setup before launching a distributed training run, we provide the functionality to do so by running the pretraining script locally on a multi-GPU (or single-GPU) machine, however, reproducing our results requires launching distributed training. +To obtain the embeddings for an image, you can use the `embed_image` function: -The single-machine implementation starts from the [app/main.py](appmain.py), which parses the experiment config file and runs the pretraining locally on a multi-GPU (or single-GPU) machine. -For example, to run V-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: -```bash -python -m app.main \ - --fname configs/pretrain/vitl16.yaml \ - --devices cuda:0 cuda:1 cuda:2 +```python +embeddings = encoder.embed_image(input_data) ``` -### Distributed training -To launch a distributed training run, the implementation starts from [app/main_distributed.py](app/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. -For example, to launch a distributed pre-training experiment using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: -```bash -python -m app.main_distributed \ - --fname configs/pretrain/vitl16.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition -``` +The function returns the embeddings generated by the encoder. -## Launching Evaluations +## Configuration -### Local training -If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine, however, reproducing the full eval would require launching distributed training. 
-The single-machine implementation starts from the [eval/main.py](eval/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. +The VJEPA Encoder requires a configuration file in YAML format to specify the model settings. The configuration file should include the following sections: -For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: -```bash -python -m evals.main \ - --fname configs/eval/vitl16_in1k.yaml \ - --devices cuda:0 cuda:1 cuda:2 -``` +- `meta`: General settings such as the checkpoint file path, random seed, etc. +- `mask`: Settings related to masking. +- `model`: Model architecture settings. +- `data`: Data-related settings such as crop size, patch size, etc. +- `logging`: Logging settings. +Please refer to the provided configuration file template for more details. -### Distributed training -To launch a distributed evaluation run, the implementation starts from [eval/main_distributed.py](eval/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. - -For example, to launch a distributed ImageNet image classification experiment using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: -```bash -python -m evals.main_distributed \ - --fname configs/eval/vitl16_in1k.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition -``` +## License -Similarly, to launch a distributed K400 video classification experiment using the config [configs/eval/vitl16_k400.yaml](configs/eval/vitl16_k400.yaml), type the command: -```bash -python -m evals.main_distributed \ - --fname configs/eval/vitl16_k400.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition -``` +The VJEPA Encoder is released under the [MIT License](LICENSE). ---- +## Acknowledgments -### Setup +The VJEPA Encoder is based on the research work conducted by Facebook AI Research. We would like to acknowledge their contributions to the field of computer vision and representation learning. -Run: -```bash -conda create -n jepa python=3.9 pip -conda activate jepa -python setup.py install -``` +## Contact -## License -See the [LICENSE](./LICENSE) file for details about the license under which this code is made available. +If you have any questions or suggestions regarding the VJEPA Encoder, please feel free to contact us at johnnykoch02@gmail.com. 
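+## End-to-End Example
+
+The following sketch ties the sections above together. It is a minimal example, not part of the original package documentation: it assumes a local copy of the configuration file (see `huggingface/params-encoder.yaml`) and a single CUDA device, the image path is a placeholder, and the exact form of the returned embeddings depends on the model and data settings in your config.
+
+```python
+from vjepa_encoder.vision_encoder import JepaEncoder
+
+# Placeholder paths: point these at your own config file and image.
+config_file_path = "./params-encoder.yaml"
+devices = ["cuda:0"]  # drop or change this list to run on CPU or other GPUs
+
+# Load the pre-trained encoder from the YAML configuration.
+encoder = JepaEncoder.load_model(config_file_path, devices)
+
+# embed_image accepts an image path, image array, PIL Image, or PyTorch tensor.
+embeddings = encoder.embed_image("path/to/image.jpg")
+
+# Inspect the result; the shape depends on the crop/patch settings in the config.
+print(type(embeddings), getattr(embeddings, "shape", None))
+```
+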
-## Citation -If you find this repository useful in your research, please consider giving a star :star: and a citation -```bibtex -@article{bardes2024revisiting, - title={Revisiting Feature Prediction for Learning Visual Representations from Video}, - author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael, and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas}, - journal={arXiv preprint}, - year={2024} -} +--- \ No newline at end of file diff --git a/app/main.py b/app/main.py index 52e1596..9f66229 100644 --- a/app/main.py +++ b/app/main.py @@ -13,7 +13,7 @@ import yaml from app.scaffold import main as app_main -from src.utils.distributed import init_distributed +from jepa_src.utils.distributed import init_distributed parser = argparse.ArgumentParser() parser.add_argument( @@ -30,7 +30,7 @@ def process_main(rank, fname, world_size, devices): os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) import logging - from src.utils.logging import get_logger + from jepa_src.utils.logging import get_logger logger = get_logger(force=True) if rank == 0: logger.setLevel(logging.INFO) diff --git a/app/main_distributed.py b/app/main_distributed.py index 11ac3a2..fe2e160 100644 --- a/app/main_distributed.py +++ b/app/main_distributed.py @@ -13,7 +13,7 @@ import submitit from app.scaffold import main as app_main -from src.utils.logging import get_logger +from jepa_src.utils.logging import get_logger logger = get_logger(force=True) diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 2b55616..ccb2e75 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -26,19 +26,19 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel -from src.datasets.data_manager import init_data -from src.masks.random_tube import MaskCollator as TubeMaskCollator -from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator -from src.masks.utils import apply_masks -from src.utils.distributed import init_distributed, AllReduce -from src.utils.logging import ( +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( CSVLogger, gpu_timer, get_logger, grad_logger, adamw_logger, AverageMeter) -from src.utils.tensors import repeat_interleave_batch +from jepa_src.utils.tensors import repeat_interleave_batch from app.vjepa.utils import ( load_checkpoint, @@ -77,7 +77,7 @@ def main(args, resume_preempt=False): skip_batches = cfgs_meta.get('skip_batches', -1) use_sdpa = cfgs_meta.get('use_sdpa', False) which_dtype = cfgs_meta.get('dtype') - logger.info(f'{which_dtype=}') + logger.info(f'{which_dtype}') if which_dtype.lower() == 'bfloat16': dtype = torch.bfloat16 mixed_precision = True diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index 0854dd9..ba62555 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -8,8 +8,8 @@ import torch import torchvision.transforms as transforms -import src.datasets.utils.video.transforms as video_transforms -from src.datasets.utils.video.randerase import RandomErasing +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing def make_transforms( diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index 
dc8668d..2636ed7 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -13,13 +13,13 @@ import torch -import src.models.vision_transformer as video_vit -import src.models.predictor as vit_pred -from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper -from src.utils.schedulers import ( +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule) -from src.utils.tensors import trunc_normal_ +from jepa_src.utils.tensors import trunc_normal_ logging.basicConfig(stream=sys.stdout, level=logging.INFO) logger = logging.getLogger() diff --git a/src/datasets/data_manager.py b/build/lib/datasets/data_manager.py similarity index 94% rename from src/datasets/data_manager.py rename to build/lib/datasets/data_manager.py index cdb7ade..cf53940 100644 --- a/src/datasets/data_manager.py +++ b/build/lib/datasets/data_manager.py @@ -48,7 +48,7 @@ def init_data( if (data.lower() == 'imagenet') \ or (data.lower() == 'inat21') \ or (data.lower() == 'places205'): - from src.datasets.image_dataset import make_imagedataset + from jepa_src.datasets.image_dataset import make_imagedataset dataset, data_loader, dist_sampler = make_imagedataset( transform=transform, batch_size=batch_size, @@ -66,7 +66,7 @@ def init_data( subset_file=subset_file) elif data.lower() == 'videodataset': - from src.datasets.video_dataset import make_videodataset + from jepa_src.datasets.video_dataset import make_videodataset dataset, data_loader, dist_sampler = make_videodataset( data_paths=root_path, batch_size=batch_size, diff --git a/src/datasets/image_dataset.py b/build/lib/datasets/image_dataset.py similarity index 100% rename from src/datasets/image_dataset.py rename to build/lib/datasets/image_dataset.py diff --git a/src/datasets/utils/video/functional.py b/build/lib/datasets/utils/video/functional.py similarity index 100% rename from src/datasets/utils/video/functional.py rename to build/lib/datasets/utils/video/functional.py diff --git a/src/datasets/utils/video/randaugment.py b/build/lib/datasets/utils/video/randaugment.py similarity index 100% rename from src/datasets/utils/video/randaugment.py rename to build/lib/datasets/utils/video/randaugment.py diff --git a/src/datasets/utils/video/randerase.py b/build/lib/datasets/utils/video/randerase.py similarity index 100% rename from src/datasets/utils/video/randerase.py rename to build/lib/datasets/utils/video/randerase.py diff --git a/src/datasets/utils/video/transforms.py b/build/lib/datasets/utils/video/transforms.py similarity index 99% rename from src/datasets/utils/video/transforms.py rename to build/lib/datasets/utils/video/transforms.py index ffa8e61..979985d 100644 --- a/src/datasets/utils/video/transforms.py +++ b/build/lib/datasets/utils/video/transforms.py @@ -17,8 +17,8 @@ import torchvision.transforms.functional as F from torchvision import transforms -import src.datasets.utils.video.functional as FF -from src.datasets.utils.video.randaugment import rand_augment_transform +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform _pil_interpolation_to_str = { diff --git a/src/datasets/utils/video/volume_transforms.py b/build/lib/datasets/utils/video/volume_transforms.py similarity index 100% rename from src/datasets/utils/video/volume_transforms.py rename to 
build/lib/datasets/utils/video/volume_transforms.py diff --git a/src/datasets/utils/weighted_sampler.py b/build/lib/datasets/utils/weighted_sampler.py similarity index 100% rename from src/datasets/utils/weighted_sampler.py rename to build/lib/datasets/utils/weighted_sampler.py diff --git a/src/datasets/video_dataset.py b/build/lib/datasets/video_dataset.py similarity index 97% rename from src/datasets/video_dataset.py rename to build/lib/datasets/video_dataset.py index b05cc70..82cee52 100644 --- a/src/datasets/video_dataset.py +++ b/build/lib/datasets/video_dataset.py @@ -18,7 +18,7 @@ import torch -from src.datasets.utils.weighted_sampler import DistributedWeightedSampler +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler _GLOBAL_SEED = 0 logger = getLogger() @@ -188,15 +188,15 @@ def loadvideo_decord(self, sample): fname = sample if not os.path.exists(fname): - warnings.warn(f'video path not found {fname=}') + warnings.warn(f'video path not found {fname}') return [], None _fsize = os.path.getsize(fname) if _fsize < 1 * 1024: # avoid hanging issue - warnings.warn(f'video too short {fname=}') + warnings.warn(f'video too short {fname}') return [], None if _fsize > self.filter_long_videos: - warnings.warn(f'skipping long video of size {_fsize=} (bytes)') + warnings.warn(f'skipping long video of size {_fsize} (bytes)') return [], None try: diff --git a/build/lib/jepa_src/__init__.py b/build/lib/jepa_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/__init__.py b/build/lib/jepa_src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/data_manager.py b/build/lib/jepa_src/datasets/data_manager.py new file mode 100644 index 0000000..cf53940 --- /dev/null +++ b/build/lib/jepa_src/datasets/data_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from logging import getLogger + + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def init_data( + batch_size, + transform=None, + shared_transform=None, + data='ImageNet', + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + tokenize_txt=True, + subset_file=None, + clip_len=8, + frame_sample_rate=2, + duration=None, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(1e9), + decode_one_clip=True, + datasets_weights=None, + persistent_workers=False, + repeat_wds=False, + ipe=300, + log_dir=None, +): + + if (data.lower() == 'imagenet') \ + or (data.lower() == 'inat21') \ + or (data.lower() == 'places205'): + from jepa_src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( + transform=transform, + batch_size=batch_size, + collator=collator, + pin_mem=pin_mem, + training=training, + num_workers=num_workers, + world_size=world_size, + rank=rank, + root_path=root_path, + image_folder=image_folder, + persistent_workers=persistent_workers, + copy_data=copy_data, + drop_last=drop_last, + subset_file=subset_file) + + elif data.lower() == 'videodataset': + from jepa_src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( + data_paths=root_path, + batch_size=batch_size, + frames_per_clip=clip_len, + frame_step=frame_sample_rate, + duration=duration, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + shared_transform=shared_transform, + transform=transform, + datasets_weights=datasets_weights, + collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + drop_last=drop_last, + log_dir=log_dir) + + return (data_loader, dist_sampler) diff --git a/build/lib/jepa_src/datasets/image_dataset.py b/build/lib/jepa_src/datasets/image_dataset.py new file mode 100644 index 0000000..84e9b08 --- /dev/null +++ b/build/lib/jepa_src/datasets/image_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import os + +from logging import getLogger + +import torch +import torchvision + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class ImageFolder(torchvision.datasets.ImageFolder): + + def __init__( + self, + root, + image_folder='imagenet_full_size/061417/', + transform=None, + train=True, + ): + """ + ImageFolder + :param root: root network directory for ImageFolder data + :param image_folder: path to images inside root network directory + :param train: whether to load train data (or validation) + """ + + suffix = 'train/' if train else 'val/' + data_path = os.path.join(root, image_folder, suffix) + logger.info(f'data-path {data_path}') + super(ImageFolder, self).__init__(root=data_path, transform=transform) + logger.info('Initialized ImageFolder') + + +def make_imagedataset( + transform, + batch_size, + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + persistent_workers=False, + subset_file=None +): + dataset = ImageFolder( + root=root_path, + image_folder=image_folder, + transform=transform, + train=training) + logger.info('ImageFolder dataset created') + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, + num_replicas=world_size, + rank=rank) + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=persistent_workers) + logger.info('ImageFolder unsupervised data loader created') + + return dataset, data_loader, dist_sampler diff --git a/build/lib/jepa_src/datasets/utils/__init__.py b/build/lib/jepa_src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/utils/video/__init__.py b/build/lib/jepa_src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/utils/video/functional.py b/build/lib/jepa_src/datasets/utils/video/functional.py new file mode 100644 index 0000000..a91d15d --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/functional.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import numbers +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/build/lib/jepa_src/datasets/utils/video/randaugment.py b/build/lib/jepa_src/datasets/utils/video/randaugment.py new file mode 100644 index 0000000..4c80a99 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/randaugment.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. +""" + +import math +import numpy as np +import random +import re +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. 
+_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs + ) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs + ) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def 
_rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": 
_translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] + if "img_mean" in hparams + else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None + else () + ) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
+_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + choice_weights = ( + None if weight_idx is None else _select_rand_weights(weight_idx) + ) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git 
a/build/lib/jepa_src/datasets/utils/video/randerase.py b/build/lib/jepa_src/datasets/utils/video/randerase.py new file mode 100644 index 0000000..d1f185c --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/randerase.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +published under an Apache License 2.0. +""" +import math +import random +import torch + + +def _get_pixels( + per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" +): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty( + (patch_size[0], 1, 1), dtype=dtype, device=device + ).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value.
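The erase fill is produced by _get_pixels according to mode ('const' zeros, 'rand' one normal color per block, 'pixel' per-pixel noise), and cube=True reuses one rectangle across all frames of a clip. A small sketch, assuming the new module path is importable; device="cpu" overrides the CUDA default used above:

    import torch
    from jepa_src.datasets.utils.video.randerase import RandomErasing

    erase = RandomErasing(probability=1.0, mode="pixel", device="cpu", cube=True)
    clip = torch.randn(16, 3, 224, 224)  # (T, C, H, W), already normalized
    clip = erase(clip)  # the same rectangle is filled with per-pixel noise on every frame
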
+ """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[ + :, top:top + h, left:left + w + ] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0 + ) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/build/lib/jepa_src/datasets/utils/video/transforms.py b/build/lib/jepa_src/datasets/utils/video/transforms.py new file mode 100644 index 0000000..979985d --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/transforms.py @@ -0,0 +1,1184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +import numpy as np +import random +import numbers +import PIL +from PIL import Image + +import torch +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms + +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, boxes=None, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) + ) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or ( + height <= width and height == size + ): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode='bilinear', + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. 
The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
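uniform_crop selects the left/center/right (or top/center/bottom) view via spatial_idx, which is how multi-view evaluation crops are produced. A short sketch on a synthetic clip tensor, assuming the new module path is importable:

    import torch
    from jepa_src.datasets.utils.video.transforms import uniform_crop

    frames = torch.randn(16, 3, 256, 320)  # (T, C, H, W)
    views = [uniform_crop(frames, size=224, spatial_idx=i)[0] for i in range(3)]
    # each element of views has shape (16, 3, 224, 224); boxes are omitted, so the
    # second return value (cropped_boxes) is None and is dropped here
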
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode='bilinear', + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum( + width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) + ) + clipped_boxes[:, [1, 3]] = np.minimum( + height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) + ) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = ( + 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + ) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. 
+ """ + + jitter = [] + if img_brightness != 0: + jitter.append('brightness') + if img_contrast != 0: + jitter.append('contrast') + if img_saturation != 0: + jitter.append('saturation') + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == 'brightness': + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == 'contrast': + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == 'saturation': + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[0] + ), 'channel stddev not computed properly' + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[1] + ), 'channel stddev not computed properly' + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind:ind + 1, + i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation='bilinear', +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {'translate_const': int(img_size_min * 0.45)} + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)] + ) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
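create_random_augment above wires the config string straight into rand_augment_transform, so a training recipe can be expressed as a single string. A sketch using the example string from its docstring, with synthetic PIL frames and assuming the new module path is importable:

    from PIL import Image
    from jepa_src.datasets.utils.video.transforms import create_random_augment

    aug = create_random_augment(
        input_size=(224, 224),
        auto_augment="rand-m7-n4-mstd0.5-inc1",  # magnitude 7, 4 ops, mstd 0.5, increasing set
        interpolation="bicubic",
    )
    frames = [Image.new("RGB", (224, 224)) for _ in range(16)]
    frames = aug(frames)  # the sampled ops are applied consistently to every frame of the clip
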
+ """ + assert ( + len(im.shape) == 3 + ), 'Currently only support image for random_sized_crop' + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode='bilinear', + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation='bilinear', + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale) + ) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio) + ) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire 
clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """ + Randomly change the brightness, contrast and saturation and hue of the clip + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
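Unlike per-image jitter, this ColorJitter samples one set of factors per call and reuses it for every frame, keeping the clip temporally consistent; numpy clips are not supported. A usage sketch, assuming the new module path is importable and using synthetic frames:

    from PIL import Image
    from jepa_src.datasets.utils.video.transforms import ColorJitter

    jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    clip = [Image.new("RGB", (224, 224)) for _ in range(16)]
    clip = jitter(clip)  # list of PIL frames, all jittered with the same sampled factors
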
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/build/lib/jepa_src/datasets/utils/video/volume_transforms.py b/build/lib/jepa_src/datasets/utils/video/volume_transforms.py new file mode 100644 index 0000000..0a01bb3 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/volume_transforms.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np +from PIL import Image + +import torch + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, H, W) format""" + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor.
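Both collators stack a list of H x W x C frames into a (C, T, H, W) float tensor; ClipToTensor scales values to [0, 1] while ClipToTensor_K maps them to [-1, 1]. A quick sketch with synthetic uint8 frames, assuming the new module path is importable:

    import numpy as np
    from jepa_src.datasets.utils.video.volume_transforms import ClipToTensor, ClipToTensor_K

    frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(16)]
    x01 = ClipToTensor()(frames)    # FloatTensor of shape (3, 16, 224, 224), values in [0, 1]
    x11 = ClipToTensor_K()(frames)  # same shape, values in [-1, 1]
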
+ """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor""" + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/build/lib/jepa_src/datasets/utils/weighted_sampler.py b/build/lib/jepa_src/datasets/utils/weighted_sampler.py new file mode 100644 index 0000000..fd40825 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/weighted_sampler.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Iterator, Optional +from operator import itemgetter +import numpy as np + +import torch +from torch.utils.data import ( + Dataset, + Sampler, + DistributedSampler, + WeightedRandomSampler +) + + +class DatasetFromSampler(Dataset): + + def __init__(self, sampler: Sampler): + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ Convert any Pytorch Sampler to a DistributedSampler """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class CustomWeightedRandomSampler(WeightedRandomSampler): + """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + rand_tensor = np.random.choice( + range(0, len(self.weights)), + size=self.num_samples, + p=self.weights.numpy() / torch.sum(self.weights).numpy(), + replace=self.replacement + ) + rand_tensor = torch.from_numpy(rand_tensor) + return iter(rand_tensor.tolist()) + + +class DistributedWeightedSampler(DistributedSamplerWrapper): + + def __init__( + self, + weights, + num_replicas: 
Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + weighted_sampler = CustomWeightedRandomSampler( + weights=weights, + num_samples=len(weights), + replacement=False) + + super(DistributedWeightedSampler, self).__init__( + sampler=weighted_sampler, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) diff --git a/build/lib/jepa_src/datasets/video_dataset.py b/build/lib/jepa_src/datasets/video_dataset.py new file mode 100644 index 0000000..82cee52 --- /dev/null +++ b/build/lib/jepa_src/datasets/video_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import pathlib +import warnings + +from logging import getLogger + +import numpy as np +import pandas as pd + +from decord import VideoReader, cpu + +import torch + +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def make_videodataset( + data_paths, + batch_size, + frames_per_clip=8, + frame_step=4, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + transform=None, + shared_transform=None, + rank=0, + world_size=1, + datasets_weights=None, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, + duration=None, + log_dir=None, +): + dataset = VideoDataset( + data_paths=data_paths, + datasets_weights=datasets_weights, + frames_per_clip=frames_per_clip, + frame_step=frame_step, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + duration=duration, + shared_transform=shared_transform, + transform=transform) + + logger.info('VideoDataset dataset created') + if datasets_weights is not None: + dist_sampler = DistributedWeightedSampler( + dataset.sample_weights, + num_replicas=world_size, + rank=rank, + shuffle=True) + else: + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0) + logger.info('VideoDataset unsupervised data loader created') + + return dataset, data_loader, dist_sampler + + +class VideoDataset(torch.utils.data.Dataset): + """ Video classification dataset. 
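make_videodataset builds the VideoDataset, falls back to a plain DistributedSampler when datasets_weights is None and uses the DistributedWeightedSampler above otherwise, and returns the dataset, loader and sampler together. A sketch with a hypothetical index file (a space-delimited "path label" CSV, as the .csv branch below expects); the path is a placeholder, not part of the patch:

    from jepa_src.datasets.video_dataset import make_videodataset

    dataset, loader, sampler = make_videodataset(
        data_paths=["/path/to/train_index.csv"],  # placeholder index file
        batch_size=8,
        frames_per_clip=16,
        frame_step=4,
        num_clips=1,
        rank=0,
        world_size=1,
        num_workers=4,
    )
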
""" + + def __init__( + self, + data_paths, + datasets_weights=None, + frames_per_clip=16, + frame_step=4, + num_clips=1, + transform=None, + shared_transform=None, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + duration=None, # duration in seconds + ): + self.data_paths = data_paths + self.datasets_weights = datasets_weights + self.frames_per_clip = frames_per_clip + self.frame_step = frame_step + self.num_clips = num_clips + self.transform = transform + self.shared_transform = shared_transform + self.random_clip_sampling = random_clip_sampling + self.allow_clip_overlap = allow_clip_overlap + self.filter_short_videos = filter_short_videos + self.filter_long_videos = filter_long_videos + self.duration = duration + + if VideoReader is None: + raise ImportError('Unable to import "decord" which is required to read videos.') + + # Load video paths and labels + samples, labels = [], [] + self.num_samples_per_dataset = [] + for data_path in self.data_paths: + + if data_path[-4:] == '.csv': + data = pd.read_csv(data_path, header=None, delimiter=" ") + samples += list(data.values[:, 0]) + labels += list(data.values[:, 1]) + num_samples = len(data) + self.num_samples_per_dataset.append(num_samples) + + elif data_path[-4:] == '.npy': + data = np.load(data_path, allow_pickle=True) + data = list(map(lambda x: repr(x)[1:-1], data)) + samples += data + labels += [0] * len(data) + num_samples = len(data) + self.num_samples_per_dataset.append(len(data)) + + # [Optional] Weights for each sample to be used by downstream + # weighted video sampler + self.sample_weights = None + if self.datasets_weights is not None: + self.sample_weights = [] + for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): + self.sample_weights += [dw / ns] * ns + + self.samples = samples + self.labels = labels + + def __getitem__(self, index): + sample = self.samples[index] + + # Keep trying to load videos until you find a valid sample + loaded_video = False + while not loaded_video: + buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] + loaded_video = len(buffer) > 0 + if not loaded_video: + index = np.random.randint(self.__len__()) + sample = self.samples[index] + + # Label/annotations for video + label = self.labels[index] + + def split_into_clips(video): + """ Split video into a list of clips """ + fpc = self.frames_per_clip + nc = self.num_clips + return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + + # Parse video into frames & apply data augmentations + if self.shared_transform is not None: + buffer = self.shared_transform(buffer) + buffer = split_into_clips(buffer) + if self.transform is not None: + buffer = [self.transform(clip) for clip in buffer] + + return buffer, label, clip_indices + + def loadvideo_decord(self, sample): + """ Load video content using Decord """ + + fname = sample + if not os.path.exists(fname): + warnings.warn(f'video path not found {fname}') + return [], None + + _fsize = os.path.getsize(fname) + if _fsize < 1 * 1024: # avoid hanging issue + warnings.warn(f'video too short {fname}') + return [], None + if _fsize > self.filter_long_videos: + warnings.warn(f'skipping long video of size {_fsize} (bytes)') + return [], None + + try: + vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) + except Exception: + return [], None + + fpc = self.frames_per_clip + fstp = self.frame_step + if self.duration is not None: + try: + fps = vr.get_avg_fps() + fstp = int(self.duration * fps / fpc) + except Exception as 
e: + warnings.warn(e) + clip_len = int(fpc * fstp) + + if self.filter_short_videos and len(vr) < clip_len: + warnings.warn(f'skipping video of length {len(vr)}') + return [], None + + vr.seek(0) # Go to start of video before sampling frames + + # Partition video into equal sized segments and sample each clip + # from a different segment + partition_len = len(vr) // self.num_clips + + all_indices, clip_indices = [], [] + for i in range(self.num_clips): + + if partition_len > clip_len: + # If partition_len > clip len, then sample a random window of + # clip_len frames within the segment + end_indx = clip_len + if self.random_clip_sampling: + end_indx = np.random.randint(clip_len, partition_len) + start_indx = end_indx - clip_len + indices = np.linspace(start_indx, end_indx, num=fpc) + indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + # -- + indices = indices + i * partition_len + else: + # If partition overlap not allowed and partition_len < clip_len + # then repeatedly append the last frame in the segment until + # we reach the desired clip length + if not self.allow_clip_overlap: + indices = np.linspace(0, partition_len, num=partition_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) + indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + # -- + indices = indices + i * partition_len + + # If partition overlap is allowed and partition_len < clip_len + # then start_indx of segment i+1 will lie within segment i + else: + sample_len = min(clip_len, len(vr)) - 1 + indices = np.linspace(0, sample_len, num=sample_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) + indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + # -- + clip_step = 0 + if len(vr) > clip_len: + clip_step = (len(vr) - clip_len) // (self.num_clips - 1) + indices = indices + i * clip_step + + clip_indices.append(indices) + all_indices.extend(list(indices)) + + buffer = vr.get_batch(all_indices).asnumpy() + return buffer, clip_indices + + def __len__(self): + return len(self.samples) diff --git a/build/lib/jepa_src/masks/__init__.py b/build/lib/jepa_src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/masks/default.py b/build/lib/jepa_src/masks/default.py similarity index 100% rename from src/masks/default.py rename to build/lib/jepa_src/masks/default.py diff --git a/src/masks/multiblock3d.py b/build/lib/jepa_src/masks/multiblock3d.py similarity index 100% rename from src/masks/multiblock3d.py rename to build/lib/jepa_src/masks/multiblock3d.py diff --git a/src/masks/random_tube.py b/build/lib/jepa_src/masks/random_tube.py similarity index 100% rename from src/masks/random_tube.py rename to build/lib/jepa_src/masks/random_tube.py diff --git a/src/masks/utils.py b/build/lib/jepa_src/masks/utils.py similarity index 100% rename from src/masks/utils.py rename to build/lib/jepa_src/masks/utils.py diff --git a/build/lib/jepa_src/models/__init__.py b/build/lib/jepa_src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/attentive_pooler.py b/build/lib/jepa_src/models/attentive_pooler.py similarity index 97% rename from src/models/attentive_pooler.py rename to build/lib/jepa_src/models/attentive_pooler.py index ecd9986..26b0e0e 100644 --- a/src/models/attentive_pooler.py +++ b/build/lib/jepa_src/models/attentive_pooler.py @@ -10,12 +10,12 @@ import torch import torch.nn as nn -from src.models.utils.modules import ( +from 
jepa_src.models.utils.modules import ( Block, CrossAttention, CrossAttentionBlock ) -from src.utils.tensors import trunc_normal_ +from jepa_src.utils.tensors import trunc_normal_ class AttentivePooler(nn.Module): diff --git a/src/models/predictor.py b/build/lib/jepa_src/models/predictor.py similarity index 97% rename from src/models/predictor.py rename to build/lib/jepa_src/models/predictor.py index 2dd9a38..95f6bc0 100644 --- a/src/models/predictor.py +++ b/build/lib/jepa_src/models/predictor.py @@ -11,13 +11,13 @@ import torch import torch.nn as nn -from src.models.utils.modules import Block -from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import ( +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( trunc_normal_, repeat_interleave_batch ) -from src.masks.utils import apply_masks +from jepa_src.masks.utils import apply_masks class VisionTransformerPredictor(nn.Module): diff --git a/build/lib/jepa_src/models/utils/__init__.py b/build/lib/jepa_src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/utils/modules.py b/build/lib/jepa_src/models/utils/modules.py similarity index 100% rename from src/models/utils/modules.py rename to build/lib/jepa_src/models/utils/modules.py diff --git a/src/models/utils/multimask.py b/build/lib/jepa_src/models/utils/multimask.py similarity index 100% rename from src/models/utils/multimask.py rename to build/lib/jepa_src/models/utils/multimask.py diff --git a/src/models/utils/patch_embed.py b/build/lib/jepa_src/models/utils/patch_embed.py similarity index 100% rename from src/models/utils/patch_embed.py rename to build/lib/jepa_src/models/utils/patch_embed.py diff --git a/src/models/utils/pos_embs.py b/build/lib/jepa_src/models/utils/pos_embs.py similarity index 100% rename from src/models/utils/pos_embs.py rename to build/lib/jepa_src/models/utils/pos_embs.py diff --git a/src/models/vision_transformer.py b/build/lib/jepa_src/models/vision_transformer.py similarity index 96% rename from src/models/vision_transformer.py rename to build/lib/jepa_src/models/vision_transformer.py index a8748df..946246e 100644 --- a/src/models/vision_transformer.py +++ b/build/lib/jepa_src/models/vision_transformer.py @@ -11,11 +11,11 @@ import torch import torch.nn as nn -from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D -from src.models.utils.modules import Block -from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import trunc_normal_ -from src.masks.utils import apply_masks +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks class VisionTransformer(nn.Module): diff --git a/build/lib/jepa_src/utils/__init__.py b/build/lib/jepa_src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/distributed.py b/build/lib/jepa_src/utils/distributed.py similarity index 100% rename from src/utils/distributed.py rename to build/lib/jepa_src/utils/distributed.py diff --git a/src/utils/logging.py b/build/lib/jepa_src/utils/logging.py similarity index 100% rename from src/utils/logging.py rename to 
build/lib/jepa_src/utils/logging.py diff --git a/src/utils/monitoring.py b/build/lib/jepa_src/utils/monitoring.py similarity index 100% rename from src/utils/monitoring.py rename to build/lib/jepa_src/utils/monitoring.py diff --git a/src/utils/schedulers.py b/build/lib/jepa_src/utils/schedulers.py similarity index 100% rename from src/utils/schedulers.py rename to build/lib/jepa_src/utils/schedulers.py diff --git a/src/utils/tensors.py b/build/lib/jepa_src/utils/tensors.py similarity index 100% rename from src/utils/tensors.py rename to build/lib/jepa_src/utils/tensors.py diff --git a/build/lib/masks/default.py b/build/lib/masks/default.py new file mode 100644 index 0000000..2810c0a --- /dev/null +++ b/build/lib/masks/default.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class DefaultCollator(object): + + def __call__(self, batch): + collated_batch = torch.utils.data.default_collate(batch) + return collated_batch, None, None diff --git a/build/lib/masks/multiblock3d.py b/build/lib/masks/multiblock3d.py new file mode 100644 index 0000000..a7bbc3e --- /dev/null +++ b/build/lib/masks/multiblock3d.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +from multiprocessing import Value + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + spatial_pred_mask_scale=m.get('spatial_scale'), + temporal_pred_mask_scale=m.get('temporal_scale'), + aspect_ratio=m.get('aspect_ratio'), + npred=m.get('num_blocks'), + max_context_frames_ratio=m.get('max_temporal_keep', 1.0), + max_keep=m.get('max_keep', None), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + spatial_pred_mask_scale=(0.2, 0.8), + temporal_pred_mask_scale=(1.0, 1.0), + aspect_ratio=(0.3, 3.0), + npred=1, + max_context_frames_ratio=1.0, + max_keep=None, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // 
spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.aspect_ratio = aspect_ratio + self.spatial_pred_mask_scale = spatial_pred_mask_scale + self.temporal_pred_mask_scale = temporal_pred_mask_scale + self.npred = npred + self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_keep = max_keep # maximum number of patches to keep in context + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size( + self, + generator, + temporal_scale, + spatial_scale, + aspect_ratio_scale + ): + # -- Sample temporal block mask scale + _rand = torch.rand(1, generator=generator).item() + min_t, max_t = temporal_scale + temporal_mask_scale = min_t + _rand * (max_t - min_t) + t = max(1, int(self.duration * temporal_mask_scale)) + + # -- Sample spatial block mask scale + _rand = torch.rand(1, generator=generator).item() + min_s, max_s = spatial_scale + spatial_mask_scale = min_s + _rand * (max_s - min_s) + spatial_num_keep = int(self.height * self.width * spatial_mask_scale) + + # -- Sample block aspect-ratio + _rand = torch.rand(1, generator=generator).item() + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + + # -- Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(spatial_num_keep * aspect_ratio))) + w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) + h = min(h, self.height) + w = min(w, self.width) + + return (t, h, w) + + def _sample_block_mask(self, b_size): + t, h, w = b_size + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) + start = torch.randint(0, self.duration - t + 1, (1,)) + + mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask[start:start+t, top:top+h, left:left+w] = 0 + + # Context mask will only span the first X frames + # (X=self.max_context_frames) + if self.max_context_duration < self.duration: + mask[self.max_context_duration:, :, :] = 0 + + # -- + return mask + + def __call__(self, batch_size): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample pred block size using seed + # 2. sample several pred block locations for each image (w/o seed) + # 3. 
return pred masks and complement (enc mask) + """ + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + temporal_scale=self.temporal_pred_mask_scale, + spatial_scale=self.spatial_pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_enc = min_keep_pred = self.duration * self.height * self.width + for _ in range(batch_size): + + empty_context = True + while empty_context: + + mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + for _ in range(self.npred): + mask_e *= self._sample_block_mask(p_size) + mask_e = mask_e.flatten() + + mask_p = torch.argwhere(mask_e == 0).squeeze() + mask_e = torch.nonzero(mask_e).squeeze() + + empty_context = len(mask_e) == 0 + if not empty_context: + min_keep_pred = min(min_keep_pred, len(mask_p)) + min_keep_enc = min(min_keep_enc, len(mask_e)) + collated_masks_pred.append(mask_p) + collated_masks_enc.append(mask_e) + + if self.max_keep is not None: + min_keep_enc = min(min_keep_enc, self.max_keep) + + collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/random_tube.py b/build/lib/masks/random_tube.py new file mode 100644 index 0000000..84c0640 --- /dev/null +++ b/build/lib/masks/random_tube.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
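A minimal usage sketch for the multiblock3d MaskCollator defined above, assuming the jepa_src package added by this patch is importable. The mask-config values are illustrative placeholders rather than the repository's YAML settings, and patch_size is passed as an int so the crop-size division inside _MaskGenerator yields a 14x14 patch grid.

import torch
from jepa_src.masks.multiblock3d import MaskCollator

cfgs_mask = [{                      # assumed values, one block-mask generator
    'spatial_scale': (0.15, 0.15),
    'temporal_scale': (1.0, 1.0),
    'aspect_ratio': (0.75, 1.5),
    'num_blocks': 8,
    'max_temporal_keep': 1.0,
    'max_keep': None,
}]
collator = MaskCollator(cfgs_mask, crop_size=(224, 224), num_frames=16,
                        patch_size=16, tubelet_size=2)

# In training this object is handed to the DataLoader as collate_fn.
batch = [torch.randn(3, 16, 224, 224) for _ in range(4)]   # fake [C, T, H, W] clips
clips, masks_enc, masks_pred = collator(batch)
# clips: [4, 3, 16, 224, 224]; masks_enc / masks_pred: one [4, K] tensor of
# patch indices per entry in cfgs_mask (context tokens kept by the encoder,
# target tokens to be predicted).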
+# + +from multiprocessing import Value + +from logging import getLogger + +import torch +import numpy as np + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + ratio=m.get('ratio'), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + ratio=0.9, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches_spatial = self.height*self.width + + self.ratio = ratio + + self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep = self.num_keep_spatial * self.duration + + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def __call__(self, batch_size): + def sample_mask(): + mask = np.hstack([ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ]) + np.random.shuffle(mask) + mask = torch.tensor(np.tile(mask, (self.duration, 1))) + mask = mask.flatten() + mask_p = torch.argwhere(mask == 0).squeeze() + mask_e = torch.nonzero(mask).squeeze() + return mask_e, mask_p + + collated_masks_pred, collated_masks_enc = [], [] + for _ in range(batch_size): + mask_e, mask_p = sample_mask() + collated_masks_enc.append(mask_e) + collated_masks_pred.append(mask_p) + + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/utils.py b/build/lib/masks/utils.py new file mode 100644 index 0000000..ca04af1 --- /dev/null +++ b/build/lib/masks/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
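The tube-masking collator above can be driven the same way; its single 'ratio' field controls how many spatial patches are dropped. With a 224x224 crop, 16-pixel patches and ratio 0.9, each sample keeps int(196 * 0.1) = 19 spatial patches, tiled across all 8 tubelet steps, i.e. 152 context tokens. The ratio below is an assumption for illustration.

import torch
from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator

tube_collator = TubeMaskCollator(cfgs_mask=[{'ratio': 0.9}],   # assumed ratio
                                 crop_size=(224, 224), num_frames=16,
                                 patch_size=16, tubelet_size=2)
batch = [torch.randn(3, 16, 224, 224) for _ in range(2)]
clips, masks_enc, masks_pred = tube_collator(batch)
print(masks_enc[0].shape)   # torch.Size([2, 152]): 19 spatial patches x 8 steps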
+# + +import torch + + +def apply_masks(x, masks, concat=True): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + if not concat: + return all_x + + return torch.cat(all_x, dim=0) diff --git a/build/lib/models/attentive_pooler.py b/build/lib/models/attentive_pooler.py new file mode 100644 index 0000000..26b0e0e --- /dev/null +++ b/build/lib/models/attentive_pooler.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import ( + Block, + CrossAttention, + CrossAttentionBlock +) +from jepa_src.utils.tensors import trunc_normal_ + + +class AttentivePooler(nn.Module): + """ Attentive Pooler """ + def __init__( + self, + num_queries=1, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + complete_block=True + ): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) + + self.complete_block = complete_block + if complete_block: + self.cross_attention_block = CrossAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer) + else: + self.cross_attention_block = CrossAttention( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.blocks = None + if depth > 1: + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer) + for i in range(depth-1)]) + + self.init_std = init_std + trunc_normal_(self.query_tokens, std=self.init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + if self.complete_block: + rescale(self.cross_attention_block.xattn.proj.weight.data, 1) + rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) + else: + rescale(self.cross_attention_block.proj.weight.data, 1) + if self.blocks is not None: + for layer_id, layer in enumerate(self.blocks, 1): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + q = self.query_tokens.repeat(len(x), 1, 1) + q = self.cross_attention_block(q, x) + if self.blocks is not None: + for blk in self.blocks: + q = blk(q) + return q + + +class AttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + num_classes=1000, + complete_block=True, + ): + 
super().__init__() + self.pooler = AttentivePooler( + num_queries=1, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + depth=depth, + norm_layer=norm_layer, + init_std=init_std, + qkv_bias=qkv_bias, + complete_block=complete_block, + ) + self.linear = nn.Linear(embed_dim, num_classes, bias=True) + + def forward(self, x): + x = self.pooler(x).squeeze(1) + x = self.linear(x) + return x diff --git a/build/lib/models/predictor.py b/build/lib/models/predictor.py new file mode 100644 index 0000000..95f6bc0 --- /dev/null +++ b/build/lib/models/predictor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( + trunc_normal_, + repeat_interleave_batch +) +from jepa_src.masks.utils import apply_masks + + +class VisionTransformerPredictor(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + depth=6, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + **kwargs + ): + super().__init__() + # Map input to predictor dimension + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ]) + + # Determine positional embedding + self.input_size = img_size + self.patch_size = patch_size + # -- + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + if self.is_video: + self.num_patches = num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.num_patches = num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + # Position embedding + self.uniform_power = uniform_power + self.predictor_pos_embed = None + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + + # Attention Blocks + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer) + for i in range(depth)]) + + # Normalize & project back to input dimension + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + + # ------ initialize weights + if self.predictor_pos_embed is not None: + self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed + self.init_std = init_std + if not 
zero_init_mask_tokens: + for mt in self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.predictor_blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): + + # Prepare diffusion noise schedule + b1, b2 = noise_beta + beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + alpha_scheduler = [] + _alpha = 1.0 + for _beta in beta_scheduler: + _alpha *= 1.-_beta + alpha_scheduler += [_alpha] + + # Sample diffusion time step + T = torch.randint(0, steps, (len(x),)) + alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + + # Normalize features and apply noise + x = torch.nn.functional.layer_norm(x, (x.size(-1),)) + x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + return x + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): + """ + :param ctxt: context tokens + :param tgt: target tokens + :param masks_ctxt: indices of context tokens in input + :params masks_tgt: indices of target tokens in input + """ + + assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_ctxt, list): + masks_ctxt = [masks_ctxt] + + if not isinstance(masks_tgt, list): + masks_tgt = [masks_tgt] + + # Batch Size + B = len(ctxt) // len(masks_ctxt) + + # Map context tokens to pedictor dimensions + x = self.predictor_embed(ctxt) + _, N_ctxt, D = x.shape + + # Add positional embedding to ctxt tokens + if self.predictor_pos_embed is not None: + ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(ctxt_pos_embed, masks_ctxt) + + # Map target tokens to predictor dimensions & add noise (fwd diffusion) + if self.mask_tokens is None: + pred_tokens = self.predictor_embed(tgt) + pred_tokens = self.diffusion(pred_tokens) + else: + mask_index = mask_index % self.num_mask_tokens + pred_tokens = self.mask_tokens[mask_index] + pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) + pred_tokens = apply_masks(pred_tokens, masks_tgt) + + # Add positional embedding to target tokens + if self.predictor_pos_embed is not None: + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_tgt) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) + pred_tokens += pos_embs + + # Concatenate context & target tokens + x = x.repeat(len(masks_tgt), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # FIXME: this implementation currently 
assumes masks_ctxt and masks_tgt + # are alligned 1:1 (ok with MultiMask wrapper on predictor but + # otherwise will break) + masks_ctxt = torch.cat(masks_ctxt, dim=0) + masks_tgt = torch.cat(masks_tgt, dim=0) + masks = torch.cat([masks_ctxt, masks_tgt], dim=1) + + # Fwd prop + for blk in self.predictor_blocks: + x = blk(x, mask=masks) + x = self.predictor_norm(x) + + # Return output corresponding to target tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +def vit_predictor(**kwargs): + model = VisionTransformerPredictor( + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model diff --git a/build/lib/models/utils/modules.py b/build/lib/models/utils/modules.py new file mode 100644 index 0000000..dc470d9 --- /dev/null +++ b/build/lib/models/utils/modules.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + grid_size=None, + grid_depth=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, return_attention=False, mask=None): + y, attn = self.attn(self.norm1(x), mask=mask) + if 
return_attention: + return attn + x = x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads=12, + qkv_bias=False, + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.use_sdpa = use_sdpa + + def forward(self, q, x): + B, n, C = q.shape + q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + B, N, C = x.shape + kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + q = F.scaled_dot_product_attention(q, k, v) + else: + xattn = (q @ k.transpose(-2, -1)) * self.scale + xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) + q = (xattn @ v) + + q = q.transpose(1, 2).reshape(B, n, C) + q = self.proj(q) + + return q + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, q, x): + y = self.xattn(q, self.norm1(x)) + q = q + y + q = q + self.mlp(self.norm2(q)) + return q diff --git a/build/lib/models/utils/multimask.py b/build/lib/models/utils/multimask.py new file mode 100644 index 0000000..d480086 --- /dev/null +++ b/build/lib/models/utils/multimask.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch.nn as nn + + +class MultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, masks=None): + if masks is None: + return self.backbone(x) + + if (masks is not None) and not isinstance(masks, list): + masks = [masks] + outs = [] + for m in masks: + outs += [self.backbone(x, masks=m)] + return outs + + +class PredictorMultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): + if type(ctxt) is not list: + ctxt = [ctxt] + if type(tgt) is not list: + tgt = [tgt] + if type(masks_ctxt) is not list: + masks_ctxt = [masks_ctxt] + if type(masks_tgt) is not list: + masks_tgt = [masks_tgt] + + outs = [] + for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): + outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] + return outs diff --git a/build/lib/models/utils/patch_embed.py b/build/lib/models/utils/patch_embed.py new file mode 100644 index 0000000..4ff4de5 --- /dev/null +++ b/build/lib/models/utils/patch_embed.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
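A small sketch of the fan-out behaviour of MultiMaskWrapper defined above: passing a list of K masks returns a list of K backbone outputs, one per mask. DummyBackbone is a stand-in introduced only for this example; in the patch the wrapped module is the vision transformer encoder.

import torch
import torch.nn as nn
from jepa_src.models.utils.multimask import MultiMaskWrapper
from jepa_src.masks.utils import apply_masks

class DummyBackbone(nn.Module):
    # Mimics the encoder interface: keep only the indexed patch tokens.
    def forward(self, x, masks=None):
        return x if masks is None else apply_masks(x, [masks])

wrapped = MultiMaskWrapper(DummyBackbone())
x = torch.randn(2, 10, 4)                      # [B, N, D]
masks = [torch.arange(6).repeat(2, 1),         # keep 6 patches per sample
         torch.arange(3).repeat(2, 1)]         # keep 3 patches per sample
outs = wrapped(x, masks)
print([tuple(o.shape) for o in outs])          # [(2, 6, 4), (2, 3, 4)]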
+# + +import torch.nn as nn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768 + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + patch_size=16, + tubelet_size=2, + in_chans=3, + embed_dim=768, + ): + super().__init__() + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + self.proj = nn.Conv3d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size, patch_size), + stride=(tubelet_size, patch_size, patch_size), + ) + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/build/lib/models/utils/pos_embs.py b/build/lib/models/utils/pos_embs.py new file mode 100644 index 0000000..d1d82e2 --- /dev/null +++ b/build/lib/models/utils/pos_embs.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=False +): + """ + grid_size: int of the grid height and width + grid_depth: int of the grid depth + returns: + pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_d = np.arange(grid_depth, dtype=float) + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + + if not uniform_power: + h_embed_dim = embed_dim // 4 + w_embed_dim = embed_dim // 4 + d_embed_dim = embed_dim // 2 + else: + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + + emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) + emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) + emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) + pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) + pos_embed = pos_embed[:, :embed_dim] + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + returns: + pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) + pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 
+ """ + embed_dim: output dimension for each position + grid_size: int of the grid length + returns: + pos_embed: [grid_size, embed_dim] (w/o cls_token) + or [1+grid_size, embed_dim] (w/ cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + returns: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/build/lib/models/vision_transformer.py b/build/lib/models/vision_transformer.py new file mode 100644 index 0000000..946246e --- /dev/null +++ b/build/lib/models/vision_transformer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + out_layers=None, + uniform_power=False, + **kwargs + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + + self.input_size = img_size + self.patch_size = patch_size + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + # Tokenize pixels with convolution + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + tubelet_size=tubelet_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + + # Position embedding + self.uniform_power = uniform_power + self.pos_embed = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, embed_dim), + requires_grad=False) + + # Attention Blocks + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + 
grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # ------ initialize weights + if self.pos_embed is not None: + self._init_pos_embed(self.pos_embed.data) # sincos pos-embed + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv3d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {} + + def forward(self, x, masks=None): + """ + :param x: input image/video + :param masks: indices of patch tokens to mask (remove) + """ + + if masks is not None and not isinstance(masks, list): + masks = [masks] + + # Tokenize input + pos_embed = self.pos_embed + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: + x += pos_embed + B, N, D = x.shape + + # Mask away unwanted tokens (if masks provided) + if masks is not None: + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Fwd prop + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x, mask=masks) + if self.out_layers is not None and i in self.out_layers: + outs.append(self.norm(x)) + + if self.out_layers is not None: + return outs + + if self.norm is not None: + x = self.norm(x) + + return x + + def interpolate_pos_encoding(self, x, pos_embed): + + _, N, dim = pos_embed.shape + + if self.is_video: + + # If pos_embed already corret size, just return + _, _, T, H, W = x.shape + if H == self.input_size and W == self.input_size and T == self.num_frames: + return pos_embed + + # Convert depth, height, width of input to be measured in patches + # instead of pixels/frames + T = T // self.tubelet_size + H = H // self.patch_size + W = W // self.patch_size + + # Compute the initialized shape of the positional embedding measured + # in patches + N_t = self.num_frames // self.tubelet_size + N_h = N_w = self.input_size // self.patch_size + assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + + # Compute scale factor for spatio-temporal interpolation + scale_factor = (T/N_t, H/N_h, W/N_w) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), + 
scale_factor=scale_factor, + mode='trilinear') + pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) + return pos_embed + + else: + + # If pos_embed already corret size, just return + _, _, H, W = x.shape + if H == self.input_size and W == self.input_size: + return pos_embed + + # Compute scale factor for spatial interpolation + npatch = (H // self.patch_size) * (W // self.patch_size) + scale_factor = math.sqrt(npatch / N) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=scale_factor, + mode='bicubic') + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_large(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_huge(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_giant(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_gigantic(patch_size=14, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) + return model + + +VIT_EMBED_DIMS = { + 'vit_tiny': 192, + 'vit_small': 384, + 'vit_base': 768, + 'vit_large': 1024, + 'vit_huge': 1280, + 'vit_giant': 1408, + 'vit_gigantic': 1664, +} diff --git a/build/lib/utils/distributed.py b/build/lib/utils/distributed.py new file mode 100644 index 0000000..cfba444 --- /dev/null +++ b/build/lib/utils/distributed.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
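As a quick shape check for the factory functions above (a sketch assuming the jepa_src package from this patch is installed): a ViT-Tiny video encoder over a 16-frame 224x224 clip produces (16/2) x (224/16) x (224/16) = 1568 patch tokens of width 192.

import torch
from jepa_src.models.vision_transformer import vit_tiny, VIT_EMBED_DIMS

encoder = vit_tiny(patch_size=16, img_size=224, num_frames=16, tubelet_size=2)
clip = torch.randn(1, 3, 16, 224, 224)     # [B, C, T, H, W]
with torch.no_grad():
    tokens = encoder(clip)
print(tokens.shape)                        # torch.Size([1, 1568, 192])
assert tokens.shape[-1] == VIT_EMBED_DIMS['vit_tiny']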
+# + +import os + +import torch +import torch.distributed as dist + +from logging import getLogger + +logger = getLogger() + + +def init_distributed(port=37123, rank_and_world_size=(None, None)): + + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + + rank, world_size = rank_and_world_size + os.environ['MASTER_ADDR'] = 'localhost' + + if (rank is None) or (world_size is None): + try: + world_size = int(os.environ['SLURM_NTASKS']) + rank = int(os.environ['SLURM_PROCID']) + os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + except Exception: + logger.info('SLURM vars not set (distributed training not available)') + world_size, rank = 1, 0 + return world_size, rank + + try: + os.environ['MASTER_PORT'] = str(port) + torch.distributed.init_process_group( + backend='nccl', + world_size=world_size, + rank=rank + ) + except Exception as e: + world_size, rank = 1, 0 + logger.info(f'Rank: {rank}. Distributed training not available {e}') + + return world_size, rank + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(outputs, x) + return torch.cat(outputs, 0) + return x + + @staticmethod + def backward(ctx, grads): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() + e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) + grads = grads.contiguous() + dist.all_reduce(grads) + return grads[s:e] + return grads + + +class AllReduceSum(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads + + +class AllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() / dist.get_world_size() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads diff --git a/build/lib/utils/logging.py b/build/lib/utils/logging.py new file mode 100644 index 0000000..fcdd3fa --- /dev/null +++ b/build/lib/utils/logging.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import sys + +import torch + + +def gpu_timer(closure, log_timings=True): + """ Helper to time gpu-time to execute closure() """ + log_timings = log_timings and torch.cuda.is_available() + + elapsed_time = -1. 
+ if log_timings: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + result = closure() + + if log_timings: + end.record() + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + + return result, elapsed_time + + +LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def get_logger(name=None, force=False): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + return logging.getLogger(name=name) + + +class CSVLogger(object): + + def __init__(self, fname, *argv): + self.fname = fname + self.types = [] + # -- print headers + with open(self.fname, '+a') as f: + for i, v in enumerate(argv, 1): + self.types.append(v[0]) + if i < len(argv): + print(v[1], end=',', file=f) + else: + print(v[1], end='\n', file=f) + + def log(self, *argv): + with open(self.fname, '+a') as f: + for i, tv in enumerate(zip(self.types, argv), 1): + end = ',' if i < len(argv) else '\n' + print(tv[0] % tv[1], end=end, file=f) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.max = float('-inf') + self.min = float('inf') + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + try: + self.max = max(val, self.max) + self.min = min(val, self.min) + except Exception: + pass + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def grad_logger(named_params): + stats = AverageMeter() + stats.first_layer = None + stats.last_layer = None + for n, p in named_params: + if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + grad_norm = float(torch.norm(p.grad.data)) + stats.update(grad_norm) + if 'qkv' in n: + stats.last_layer = grad_norm + if stats.first_layer is None: + stats.first_layer = grad_norm + if stats.first_layer is None or stats.last_layer is None: + stats.first_layer = stats.last_layer = 0. + return stats + + +def adamw_logger(optimizer): + """ logging magnitude of first and second momentum buffers in adamw """ + # TODO: assert that optimizer is instance of torch.optim.AdamW + state = optimizer.state_dict().get('state') + exp_avg_stats = AverageMeter() + exp_avg_sq_stats = AverageMeter() + for key in state: + s = state.get(key) + exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) + exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) + return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/build/lib/utils/monitoring.py b/build/lib/utils/monitoring.py new file mode 100644 index 0000000..95a7845 --- /dev/null +++ b/build/lib/utils/monitoring.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
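A short usage sketch for the logging helpers above; the output path is an arbitrary example. Each (format, header) pair given to CSVLogger defines one column, and log() appends one formatted row.

from jepa_src.utils.logging import CSVLogger, AverageMeter, gpu_timer

csv_logger = CSVLogger('/tmp/train_stats.csv',        # example path
                       ('%d', 'epoch'), ('%d', 'itr'), ('%.5f', 'loss'))
csv_logger.log(0, 10, 0.4231)                         # appends "0,10,0.42310"

meter = AverageMeter()
for v in (1.0, 2.0, 3.0):
    meter.update(v)
print(meter.avg, meter.min, meter.max)                # 2.0 1.0 3.0

result, elapsed_ms = gpu_timer(lambda: sum(range(1000)))
# elapsed_ms stays -1. when CUDA is unavailable (timing is skipped).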
+# + +import dataclasses +import threading +from typing import Dict, Tuple + +import psutil + + +@dataclasses.dataclass +class ResourceStatsSample: + timestamp: float + cpu_percent: float + read_count: int + write_count: int + read_bytes: int + write_bytes: int + read_chars: int + write_chars: int + cpu_times_user: float + cpu_times_system: float + cpu_times_children_user: float + cpu_times_children_system: float + cpu_times_iowait: float + cpu_affinity: str + cpu_num: int + num_threads: int + num_voluntary_ctx_switches: int + num_involuntary_ctx_switches: int + + def as_tuple(self) -> Dict: + """Return values mirroring fields.""" + return dataclasses.astuple(self) + + def fields(self) -> Tuple[dataclasses.Field, ...]: + """Return fields in this dataclass.""" + return dataclasses.fields(self.__class__) + + +class ResourceMonitoringThread(threading.Thread): + def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): + """Starts a thread to monitor pid every refresh_interval seconds. + + Passes a ResourceStatsSample object to the callback.""" + super(ResourceMonitoringThread, self).__init__() + if refresh_interval is None: + refresh_interval = 5 + self.is_running_event = threading.Event() + self.p = psutil.Process(pid) + self.refresh_interval = refresh_interval + if stats_callback_fn is None: + # Default callback + def stats_callback_fn(resource_sample: ResourceStatsSample): + print( + f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): + raise ValueError("Callback needs to be callable, got {}".format( + type(stats_callback_fn))) + self.stats_callback_fn = stats_callback_fn + + def stop(self) -> None: + self.is_running_event.set() + + def run(self) -> None: + while not self.is_running_event.is_set(): + self.sample_counters() + self.is_running_event.wait(self.refresh_interval) + + def log_sample(self, resource_sample: ResourceStatsSample) -> None: + self.stats_callback_fn(resource_sample) + + def sample_counters(self) -> None: + if not self.p.is_running(): + self.stop() + return + + with self.p.oneshot(): + cpu_percent = self.p.cpu_percent() + cpu_times = self.p.cpu_times() + io_counters = self.p.io_counters() + cpu_affinity = self.p.cpu_affinity() + cpu_num = self.p.cpu_num() + num_threads = self.p.num_threads() + num_ctx_switches = self.p.num_ctx_switches() + timestamp = time.time() + + read_count = io_counters.read_count + write_count = io_counters.write_count + read_bytes = io_counters.read_bytes + write_bytes = io_counters.write_bytes + read_chars = io_counters.read_chars + write_chars = io_counters.write_chars + + def compress_cpu_affinity(cpu_affinity): + """Change list representation to interval/range representation.""" + if not cpu_affinity: + return "" + cpu_affinity_compressed = [] + min_x = None + max_x = None + last_x = None + + # Find contiguous ranges + for x in cpu_affinity: + if last_x is None: + # Start interval + min_x = x + max_x = x + last_x = x + continue + elif x == (last_x + 1): + # Move interval up + max_x = x + elif max_x is not None: + # Interval ended, start again + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + min_x = x + max_x = x + last_x = x + # Terminate last range + if max_x is not None: + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + + # Concat + cpu_affinity_compressed = 
",".join(cpu_affinity_compressed) + + return cpu_affinity_compressed + + cpu_affinity = compress_cpu_affinity(cpu_affinity) + + resource_sample = ResourceStatsSample( + timestamp=timestamp, + cpu_percent=cpu_percent, + read_count=read_count, + write_count=write_count, + read_bytes=read_bytes, + write_bytes=write_bytes, + read_chars=read_chars, + write_chars=write_chars, + cpu_times_user=cpu_times.user, + cpu_times_system=cpu_times.system, + cpu_times_children_user=cpu_times.children_user, + cpu_times_children_system=cpu_times.children_system, + cpu_times_iowait=cpu_times.iowait, + cpu_affinity=cpu_affinity, + cpu_num=cpu_num, + num_threads=num_threads, + num_voluntary_ctx_switches=num_ctx_switches.voluntary, + num_involuntary_ctx_switches=num_ctx_switches.involuntary, + ) + self.log_sample(resource_sample) + + +if __name__ == "__main__": + import multiprocessing + import time + pid = multiprocessing.current_process().pid + monitor_thread = ResourceMonitoringThread(pid, 1) + monitor_thread.start() + time.sleep(5) + print("Shutdown") + monitor_thread.stop() diff --git a/build/lib/utils/schedulers.py b/build/lib/utils/schedulers.py new file mode 100644 index 0000000..df02e2b --- /dev/null +++ b/build/lib/utils/schedulers.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + + +class WarmupCosineSchedule(object): + + def __init__( + self, + optimizer, + warmup_steps, + start_lr, + ref_lr, + T_max, + last_epoch=-1, + final_lr=0. + ): + self.optimizer = optimizer + self.start_lr = start_lr + self.ref_lr = ref_lr + self.final_lr = final_lr + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + self._step = 0. + + def step(self): + self._step += 1 + if self._step < self.warmup_steps: + progress = float(self._step) / float(max(1, self.warmup_steps)) + new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) + else: + # -- progress after warmup + progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) + new_lr = max(self.final_lr, + self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + return new_lr + + +class CosineWDSchedule(object): + + def __init__( + self, + optimizer, + ref_wd, + T_max, + final_wd=0. + ): + self.optimizer = optimizer + self.ref_wd = ref_wd + self.final_wd = final_wd + self.T_max = T_max + self._step = 0. + + def step(self): + self._step += 1 + progress = self._step / self.T_max + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + + if self.final_wd <= self.ref_wd: + new_wd = max(self.final_wd, new_wd) + else: + new_wd = min(self.final_wd, new_wd) + + for group in self.optimizer.param_groups: + if ('WD_exclude' not in group) or not group['WD_exclude']: + group['weight_decay'] = new_wd + return new_wd diff --git a/build/lib/utils/tensors.py b/build/lib/utils/tensors.py new file mode 100644 index 0000000..6ae2850 --- /dev/null +++ b/build/lib/utils/tensors.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import math + +import torch + +from logging import getLogger + +logger = getLogger() + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches [0,N) to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x diff --git a/build/lib/vjepa_encoder/__init__.py b/build/lib/vjepa_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/vjepa_encoder/vision_encoder.py b/build/lib/vjepa_encoder/vision_encoder.py new file mode 100644 index 0000000..7d74393 --- /dev/null +++ b/build/lib/vjepa_encoder/vision_encoder.py @@ -0,0 +1,327 @@ +# Extension of Jepa by Robot Perception and Action Laboratory, USF +# +# Non-Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
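Usage sketches for the schedulers and tensor helpers introduced above; all hyper-parameter and tensor values are placeholders. WarmupCosineSchedule and CosineWDSchedule are stepped once per optimization step; apply_masks gathers the listed patch indices per sample, and repeat_interleave_batch repeats each batch-sized chunk in place.

import torch
from jepa_src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule
from jepa_src.utils.tensors import apply_masks, repeat_interleave_batch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0, weight_decay=0.04)
lr_sched = WarmupCosineSchedule(optimizer, warmup_steps=10, start_lr=1e-6,
                                ref_lr=1e-3, T_max=100, final_lr=1e-6)
wd_sched = CosineWDSchedule(optimizer, ref_wd=0.04, T_max=100, final_wd=0.4)
for _ in range(100):
    lr = lr_sched.step()      # linear warmup to ref_lr, then cosine decay to final_lr
    wd = wd_sched.step()      # cosine ramp of weight decay from ref_wd to final_wd

x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)   # [B=2, N=4, D=3]
masks = [torch.tensor([[0, 2], [1, 3]])]   # keep patches 0,2 of sample 0 and 1,3 of sample 1
print(apply_masks(x, masks).shape)         # torch.Size([2, 2, 3])

y = torch.arange(4.).unsqueeze(-1)         # two chunks of batch-size B=2: [0, 1 | 2, 3]
print(repeat_interleave_batch(y, B=2, repeat=2).squeeze().tolist())
# [0.0, 1.0, 0.0, 1.0, 2.0, 3.0, 2.0, 3.0]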
+# + +from typing import List, Optional, Any +import multiprocessing as mp + +import pprint +import yaml +import os + +import torch + +from jepa_src.utils.distributed import init_distributed + +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple + +from vjepa_encoder.vjepa.utils import init_video_model +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +# from torch.nn.parallel import DistributedDataParallel +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import get_logger + +from vjepa_encoder.vjepa.utils import init_video_model + +import torch +from torchvision import transforms +from PIL import Image +import numpy as np + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +import logging +from jepa_src.utils.logging import get_logger +logger = get_logger(force=True) +logger.setLevel(logging.INFO) + +class JepaEncoder(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.encoder, self.predictor = None, None + + def preprocess_image(self, input_data: Any): + """ + Preprocess the input image data. + + Args: + input_data (Any): Input data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Preprocessed image data as a tensor. + - If the input is a batch of images, the output will have shape (batch_size, channels, height, width). + - If the input is a single image, the output will have shape (1, channels, height, width). + + Raises: + ValueError: If the input type is not supported. + """ + if isinstance(input_data, str): + img = Image.open(input_data).convert('RGB') + + elif isinstance(input_data, list): + imgs = [ + self.preprocess_image(i).squeeze() for i in input_data + ] + preprocessed_input = torch.stack(imgs) + return preprocessed_input + + elif isinstance(input_data, np.ndarray): + if len(input_data.shape) == 4: + input_data = input_data.transpose(0, 3, 1, 2) + preprocessed_input = torch.from_numpy(input_data).float() + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + img = Image.fromarray(input_data.astype(np.uint8)) + + elif isinstance(input_data, Image.Image): + img = input_data + + elif isinstance(input_data, torch.Tensor): + preprocessed_input = input_data + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + else: + raise ValueError("Unsupported input type. 
Expected image path, image array, or PIL Image.") + + # Define the preprocessing transforms + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Apply preprocessing transforms + preprocessed_input = preprocess(img) + + preprocessed_input = preprocessed_input.unsqueeze(0) # Add batch dimension + return preprocessed_input + + def embed_image(self, x): + """ + Generate embeddings for the input image data. + + Args: + x (Any): Input image data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Embeddings for the input image data. + - If the input is a batch of images, the output will have shape (batch_size, num_patches, embedding_size). + - If the input is a single image, the output will have shape (1, num_patches, embedding_size). + + Notes: + - The input image data is preprocessed using the `preprocess_image` method before generating embeddings. + - If the preprocessed input has fewer than 5 dimensions, an additional dimension is added to represent the time dimension. + - The embeddings are generated using the forward pass of the model. + - The computation is performed on the available device (GPU if available, otherwise CPU). 
+ """ + x = self.preprocess_image(x) + + # Unsqueeze along the time Dimension + if len(x.shape) < 5: + x = x.unsqueeze(2) + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + + x = x.to(device) + + with torch.no_grad(): + embeddings = self.forward(x) + + return embeddings + + def load_encoder_checkpoint( + self, + r_path, + encoder, + ): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + try: + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return encoder + + + def forward(self, clips: torch.Tensor, masks_enc: List[torch.Tensor], masks_pred: List[torch.Tensor]) -> List[torch.Tensor]: + z = self.encoder(clips, masks_enc) + h = self._forward_target(clips, masks_pred) + z = self.predictor(z, h, masks_enc, masks_pred) + return z + + def freeze_encoder(self): + for p in self.encoder.parameters(): + p.requires_grad = False + + def forward(self, x): + return self.encoder(x) + + @classmethod + def load_model(cls, config_file_path: str, device: Optional[List[str]] = None) -> "JepaEncoder": + # TODO: Fix this so it works properly + # os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + args = None + with open(config_file_path, 'r') as y_file: + args = yaml.load(y_file, Loader=yaml.FullLoader) + logger.info('loaded params...') + + pprint.PrettyPrinter(indent=4).pprint(args) + dump = os.path.join(args['logging']['folder'], 'params-encoder.yaml') + with open(dump, 'w') as f: + yaml.dump(args, f) + + + model = cls(args) + + world_size, rank = init_distributed() + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') + assert load_model, "Cannot load model without checkpoint file specified" + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = cfgs_data.get('patch_size') + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = 
cfgs_logging.get('write_tag') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = r_file + if not os.path.exists(load_path): + raise RuntimeError("Cannot load model. Ensure you specify the path to the model .tar file in the input config.") + + # -- Attempt to initialize model + model.encoder, model.predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + + # model.encoder = DistributedDataParallel(model.encoder, static_graph=True) + + # -- load training checkpoint + model.encoder = model.load_encoder_checkpoint( + load_path, model.encoder + ) + + return model + + diff --git a/build/lib/vjepa_encoder/vjepa/__init__.py b/build/lib/vjepa_encoder/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/vjepa_encoder/vjepa/train.py b/build/lib/vjepa_encoder/vjepa/train.py new file mode 100644 index 0000000..ccb2e75 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/train.py @@ -0,0 +1,586 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import copy +import time +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel + +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter) +from jepa_src.utils.tensors import repeat_interleave_batch + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + + +def main(args, resume_preempt=False): + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') or resume_preempt + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + dataset_type = cfgs_data.get('dataset_type', 'videodataset') + mask_type = cfgs_data.get('mask_type', 'multiblock3d') + dataset_paths = cfgs_data.get('datasets', []) + datasets_weights = cfgs_data.get('datasets_weights', None) + if datasets_weights is not None: + assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' + batch_size = cfgs_data.get('batch_size') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = 
cfgs_data.get('patch_size') + pin_mem = cfgs_data.get('pin_mem', False) + num_workers = cfgs_data.get('num_workers', 1) + filter_short_videos = cfgs_data.get('filter_short_videos', False) + decode_one_clip = cfgs_data.get('decode_one_clip', True) + log_resource_util_data = cfgs_data.get('log_resource_utilization', False) + + # -- DATA AUGS + cfgs_data_aug = args.get('data_aug') + ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) + rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) + motion_shift = cfgs_data_aug.get('motion_shift', False) + reprob = cfgs_data_aug.get('reprob', 0.) + use_aa = cfgs_data_aug.get('auto_augment', False) + + # -- LOSS + cfgs_loss = args.get('loss') + loss_exp = cfgs_loss.get('loss_exp') + reg_coeff = cfgs_loss.get('reg_coeff') + + # -- OPTIMIZATION + cfgs_opt = args.get('optimization') + ipe = cfgs_opt.get('ipe', None) + ipe_scale = cfgs_opt.get('ipe_scale', 1.0) + clip_grad = cfgs_opt.get('clip_grad', None) + wd = float(cfgs_opt.get('weight_decay')) + final_wd = float(cfgs_opt.get('final_weight_decay')) + num_epochs = cfgs_opt.get('epochs') + warmup = cfgs_opt.get('warmup') + start_lr = cfgs_opt.get('start_lr') + lr = cfgs_opt.get('lr') + final_lr = cfgs_opt.get('final_lr') + ema = cfgs_opt.get('ema') + betas = cfgs_opt.get('betas', (0.9, 0.999)) + eps = cfgs_opt.get('eps', 1.e-8) + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = cfgs_logging.get('write_tag') + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method('spawn') + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ('%d', 'epoch'), + ('%d', 'itr'), + ('%.5f', 'loss'), + ('%.5f', 'loss-jepa'), + ('%.5f', 'reg-loss'), + ('%.5f', 'enc-grad-norm'), + ('%.5f', 'pred-grad-norm'), + ('%d', 'gpu-time(ms)'), + ('%d', 'wall-time(ms)'), + ) + + # -- init model + encoder, predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == 'multiblock3d': + logger.info('Initializing basic multi-block mask') + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + else: + 
logger.info('Initializing random tube mask') + mask_collator = TubeMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size) + + # -- init data-loaders/samplers + (unsupervised_loader, + unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in range(int(ipe*num_epochs*ipe_scale)+1)) + + start_epoch = 0 + # -- load training checkpoint + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + 'encoder': encoder.state_dict(), + 'predictor': predictor.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'target_encoder': target_encoder.state_dict(), + 'epoch': epoch, + 'loss': loss_meter.avg, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + + logger.info('Initializing loader...') + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f'Skip {skip_batches} batches') + unsupervised_sampler.set_epoch(start_epoch) + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f'Skip {itr}/{skip_batches} batches') + try: + udata = next(loader) + except Exception: + loader = iter(unsupervised_loader) + udata = next(loader) + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + 
logger.info('Epoch %d' % (epoch + 1)) + + # -- update distributed-data-loader epoch + unsupervised_sampler.set_epoch(epoch) + + loss_meter = AverageMeter() + input_var_meter = AverageMeter() + input_var_min_meter = AverageMeter() + jepa_loss_meter = AverageMeter() + reg_loss_meter = AverageMeter() + mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))] + gpu_time_meter = AverageMeter() + wall_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + try: + udata, masks_enc, masks_pred = next(loader) + except Exception: + logger.info('Exhausted data loaders. Refreshing...') + loader = iter(unsupervised_loader) + udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len(masks_pred), \ + 'Currently require num encoder masks = num predictor masks' + + def load_clips(): + # -- unsupervised video clips + # Put each clip on the GPU and concatenate along batch + # dimension + clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) + + # Put each mask-enc/mask-pred pair on the GPU and reuse the + # same mask pair for each clip + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _me = repeat_interleave_batch(_me, batch_size, repeat=num_clips) + _mp = repeat_interleave_batch(_mp, batch_size, repeat=num_clips) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() + + for _i, m in enumerate(mask_meters): + m.update(masks_enc[_i][0].size(-1)) + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + # -- + + def forward_target(c): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + with torch.no_grad(): + h = target_encoder(c) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] + # -- create targets (masked regions of h) + h = apply_masks(h, masks_pred, concat=False) + return h + + def forward_context(c, h): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + z = encoder(c, masks_enc) + z = predictor(z, h, masks_enc, masks_pred) + return z + + def loss_fn(z, h): + loss = 0. + # Compute loss and accumulate for each mask-enc/mask-pred pair + for zi, hi in zip(z, h): + loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss /= len(masks_pred) + return loss + + def reg_fn(z): + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) + + # Step 1. Forward + loss_jepa, loss_reg = 0., 0. + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h = forward_target(clips) + z = forward_context(clips, h) + loss_jepa = loss_fn(z, h) # jepa prediction loss + pstd_z = reg_fn(z) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.-pstd_z)) + loss = loss_jepa + reg_coeff * loss_reg + + # Step 2. Backward & step + _enc_norm, _pred_norm = 0., 0. 
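+                # With mixed precision, the loss is scaled before backprop and the
+                # gradients are unscaled before clipping, so `clip_grad` below acts
+                # on the true gradient magnitudes.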
+ if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if (epoch > warmup) and (clip_grad is not None): + _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) + _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + grad_stats = grad_logger(encoder.named_parameters()) + grad_stats.global_norm = float(_enc_norm) + grad_stats_pred = grad_logger(predictor.named_parameters()) + grad_stats_pred.global_norm = float(_pred_norm) + optimizer.zero_grad() + optim_stats = adamw_logger(optimizer) + + # Step 3. momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. + loss_meter.update(loss) + input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) + input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' + 'input_var: %.3f %.3f | ' + 'masks: %s ' + '[wd: %.2e] [lr: %.2e] ' + '[mem: %.2e] ' + '[gpu: %.1f ms]' + '[wall: %.1f ms]' + % (epoch + 1, itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg)) + + if optim_stats is not None: + logger.info( + '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' + % (epoch + 1, itr, + optim_stats.get('exp_avg').avg, + optim_stats.get('exp_avg').min, + optim_stats.get('exp_avg').max, + optim_stats.get('exp_avg_sq').avg, + optim_stats.get('exp_avg_sq').min, + optim_stats.get('exp_avg_sq').max)) + + if grad_stats is not None: + logger.info( + '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm)) + + if grad_stats_pred is not None: + logger.info( + '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm)) + log_stats() + assert not np.isnan(loss), 'loss is nan' + + # -- Save Checkpoint + logger.info('avg. 
loss %.3f' % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f'{tag}-e{epoch}.pth.tar' + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/build/lib/vjepa_encoder/vjepa/transforms.py b/build/lib/vjepa_encoder/vjepa/transforms.py new file mode 100644 index 0000000..ba62555 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/transforms.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torchvision.transforms as transforms + +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing + + +def make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) +): + + _frames_augmentation = VideoTransform( + random_horizontal_flip=random_horizontal_flip, + random_resize_aspect_ratio=random_resize_aspect_ratio, + random_resize_scale=random_resize_scale, + reprob=reprob, + auto_augment=auto_augment, + motion_shift=motion_shift, + crop_size=crop_size, + normalize=normalize, + ) + return _frames_augmentation + + +class VideoTransform(object): + + def __init__( + self, + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ): + + self.random_horizontal_flip = random_horizontal_flip + self.random_resize_aspect_ratio = random_resize_aspect_ratio + self.random_resize_scale = random_resize_scale + self.auto_augment = auto_augment + self.motion_shift = motion_shift + self.crop_size = crop_size + self.mean = torch.tensor(normalize[0], dtype=torch.float32) + self.std = torch.tensor(normalize[1], dtype=torch.float32) + if not self.auto_augment: + # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. + self.mean *= 255. + self.std *= 255. 
+ + self.autoaug_transform = video_transforms.create_random_augment( + input_size=(crop_size, crop_size), + auto_augment='rand-m7-n4-mstd0.5-inc1', + interpolation='bicubic', + ) + + self.spatial_transform = video_transforms.random_resized_crop_with_shift \ + if motion_shift else video_transforms.random_resized_crop + + self.reprob = reprob + self.erase_transform = RandomErasing( + reprob, + mode='pixel', + max_count=1, + num_splits=1, + device='cpu', + ) + + def __call__(self, buffer): + + if self.auto_augment: + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + buffer = self.autoaug_transform(buffer) + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + else: + buffer = torch.tensor(buffer, dtype=torch.float32) + + buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W + + buffer = self.spatial_transform( + images=buffer, + target_height=self.crop_size, + target_width=self.crop_size, + scale=self.random_resize_scale, + ratio=self.random_resize_aspect_ratio, + ) + if self.random_horizontal_flip: + buffer, _ = video_transforms.horizontal_flip(0.5, buffer) + + buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) + if self.reprob > 0: + buffer = buffer.permute(1, 0, 2, 3) + buffer = self.erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) + + return buffer + + +def tensor_normalize(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if type(mean) == list: + mean = torch.tensor(mean) + if type(std) == list: + std = torch.tensor(std) + tensor = tensor - mean + tensor = tensor / std + return tensor + + +def _tensor_normalize_inplace(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize (with dimensions C, T, H, W). + mean (tensor): mean value to subtract (in 0 to 255 floats). + std (tensor): std to divide (in 0 to 255 floats). + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + + C, T, H, W = tensor.shape + tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension + tensor.sub_(mean).div_(std) + tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front + return tensor diff --git a/build/lib/vjepa_encoder/vjepa/utils.py b/build/lib/vjepa_encoder/vjepa/utils.py new file mode 100644 index 0000000..2636ed7 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/utils.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import logging +import sys +import warnings +import yaml + + +import torch + +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule) +from jepa_src.utils.tensors import trunc_normal_ + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger() + + +def load_checkpoint( + r_path, + encoder, + predictor, + target_encoder, + opt, + scaler, +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + epoch = 0 + try: + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + # -- loading predictor + pretrained_dict = checkpoint['predictor'] + msg = predictor.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + + # -- loading target_encoder + if target_encoder is not None: + print(list(checkpoint.keys())) + pretrained_dict = checkpoint['target_encoder'] + msg = target_encoder.load_state_dict(pretrained_dict) + logger.info( + f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + ) + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: {r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return ( + encoder, + predictor, + target_encoder, + opt, + scaler, + epoch, + ) + + +def init_video_model( + device, + patch_size=16, + num_frames=16, + tubelet_size=2, + model_name='vit_base', + crop_size=224, + pred_depth=6, + pred_embed_dim=384, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_sdpa=False, +): + encoder = video_vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + ) + encoder = MultiMaskWrapper(encoder) + predictor = vit_pred.__dict__['vit_predictor']( + img_size=crop_size, + use_mask_tokens=use_mask_tokens, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + embed_dim=encoder.backbone.embed_dim, + predictor_embed_dim=pred_embed_dim, + depth=pred_depth, + num_heads=encoder.backbone.num_heads, + uniform_power=uniform_power, + num_mask_tokens=num_mask_tokens, + zero_init_mask_tokens=zero_init_mask_tokens, + use_sdpa=use_sdpa, + ) + predictor = PredictorMultiMaskWrapper(predictor) + + def init_weights(m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + for m in encoder.modules(): + init_weights(m) + + for m in predictor.modules(): + init_weights(m) + + encoder.to(device) + predictor.to(device) + logger.info(encoder) + logger.info(predictor) + + def count_parameters(model): + return sum(p.numel() for p in 
model.parameters() if p.requires_grad) + + logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') + logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + + return encoder, predictor + + +def init_opt( + encoder, + predictor, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + mixed_precision=False, + ipe_scale=1.25, + betas=(0.9, 0.999), + eps=1e-8, + zero_init_bias_wd=True, +): + param_groups = [ + { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=int(warmup * iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + scaler = torch.cuda.amp.GradScaler() if mixed_precision else None + return optimizer, scaler, scheduler, wd_scheduler diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 56d2f28..248d6aa 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -30,20 +30,20 @@ from timm.data import create_transform as timm_make_transforms -import src.models.vision_transformer as vit -from src.models.attentive_pooler import AttentiveClassifier -from src.datasets.data_manager import ( +import jepa_src.models.vision_transformer as vit +from jepa_src.models.attentive_pooler import AttentiveClassifier +from jepa_src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( +from jepa_src.utils.distributed import ( init_distributed, AllReduce ) -from src.utils.schedulers import ( +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( +from jepa_src.utils.logging import ( AverageMeter, CSVLogger ) diff --git a/evals/main.py b/evals/main.py index c614edb..2efa2a0 100644 --- a/evals/main.py +++ b/evals/main.py @@ -12,7 +12,7 @@ import pprint import yaml -from src.utils.distributed import init_distributed +from jepa_src.utils.distributed import init_distributed from evals.scaffold import main as eval_main diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526..91af6e7 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -28,20 +28,20 @@ from torch.nn.parallel import DistributedDataParallel -import src.models.vision_transformer as vit -from src.models.attentive_pooler import AttentiveClassifier -from src.datasets.data_manager import ( +import jepa_src.models.vision_transformer as vit +from jepa_src.models.attentive_pooler import AttentiveClassifier +from jepa_src.datasets.data_manager import ( init_data, ) -from 
src.utils.distributed import ( +from jepa_src.utils.distributed import ( init_distributed, AllReduce ) -from src.utils.schedulers import ( +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( +from jepa_src.utils.logging import ( AverageMeter, CSVLogger ) diff --git a/evals/video_classification_frozen/utils.py b/evals/video_classification_frozen/utils.py index 450f799..6853588 100644 --- a/evals/video_classification_frozen/utils.py +++ b/evals/video_classification_frozen/utils.py @@ -11,13 +11,13 @@ import torch.nn as nn import torchvision.transforms as transforms -import src.datasets.utils.video.transforms as video_transforms -import src.datasets.utils.video.volume_transforms as volume_transforms +import jepa_src.datasets.utils.video.transforms as video_transforms +import jepa_src.datasets.utils.video.volume_transforms as volume_transforms -from src.datasets.utils.video.randerase import RandomErasing +from jepa_src.datasets.utils.video.randerase import RandomErasing -from src.models.utils.pos_embs import get_1d_sincos_pos_embed -from src.masks.utils import apply_masks +from jepa_src.models.utils.pos_embs import get_1d_sincos_pos_embed +from jepa_src.masks.utils import apply_masks class FrameAggregation(nn.Module): diff --git a/fair_documentation.md b/fair_documentation.md new file mode 100644 index 0000000..a3579e1 --- /dev/null +++ b/fair_documentation.md @@ -0,0 +1,407 @@ +# V-JEPA: Video Joint Embedding Predictive Architecture + +Official PyTorch codebase for the _video joint-embedding predictive architecture_, V-JEPA, a method for self-supervised learning of visual representations from video. + +**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** + +Adrien Bardes, Quentin Garrido, Jean Ponce, Xinlei Chen, Michael Rabbat, Yann LeCun, Mahmoud Assran*, Nicolas Ballas* + +[\[Blog\]](https://ai.meta.com/blog/v-jepa-yann-lecun-ai-model-video-joint-embedding-predictive-architecture/) +[\[Paper\]](https://ai.meta.com/research/publications/revisiting-feature-prediction-for-learning-visual-representations-from-video/) +[\[Yannic Kilcher's Video\]](https://www.youtube.com/watch?v=7UkJPwz_N_0) + +V-JEPA models are trained by passively watching video pixels from the VideoMix2M dataset, and produce versatile visual representations that perform well on downstream video and image tasks, without adaption of the model’s parameters; e.g., using a frozen backbone and only a light-weight task-specific attentive probe. + +## Method +V-JEPA pretraining is based solely on an unsupervised feature prediction objective, and does not utilize pretrained image encoders, text, negative examples, human annotations, or pixel-level reconstruction. + + + +      + + + + +## Visualizations +As opposed to generative methods that have a pixel decoder, V-JEPA has a predictor that makes predictions in latent space. +We train a conditional diffusion model to decode the V-JEPA feature-space predictions to interpretable pixels; the pretrained V-JEPA encoder and predictor networks are kept frozen in this process. +The decoder is only fed the representations predicted for the missing regions of the video, and does not have access to the unmasked regions of the video. + +The V-JEPA feature predictions are indeed grounded, and exhibit spatio-temporal consistency with the unmasked regions of the video. + + +
+
+## MODEL ZOO
+
+#### Pretrained models
+
+| model | patch size | resolution | iterations | batch size | data | download |
+| --- | --- | --- | --- | --- | --- | --- |
+| ViT-L | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
+| ViT-H | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
+| ViT-H | 2x16x16 | 384x384 | 90K | 2400 | VideoMix2M | checkpoint / configs |
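+
+The checkpoints above are plain `torch.save` archives. Below is a minimal sketch of
+reading the pretrained encoder weights back out, assuming the dictionary layout
+written by the training code in this patch (keys such as `encoder` and
+`target_encoder`) and a hypothetical local file name; it is not an official loading
+API.
+
+```python
+import torch
+
+# Hypothetical local path to a checkpoint downloaded from the table above.
+ckpt_path = "vitl16.pth.tar"
+
+# Load on CPU; the archive holds the context encoder under 'encoder' and the
+# EMA target encoder under 'target_encoder'.
+checkpoint = torch.load(ckpt_path, map_location="cpu")
+encoder_state = checkpoint["encoder"]
+
+# State-dict keys carry DistributedDataParallel's 'module.' prefix from training;
+# strip it before loading into an unwrapped model.
+encoder_state = {k.replace("module.", "", 1): v for k, v in encoder_state.items()}
+```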
+
+#### K400 Attentive probes
+
+| model | resolution | accuracy (16x8x3) | download |
+| --- | --- | --- | --- |
+| ViT-L/16 | 224x224 | 80.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 82.0 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 81.9 | attentive probe checkpoint / configs |
+
+#### SSv2 Attentive probes
+
+| model | resolution | accuracy (16x2x3) | download |
+| --- | --- | --- | --- |
+| ViT-L/16 | 224x224 | 69.5 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 71.4 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 72.2 | attentive probe checkpoint / configs |
+
+#### ImageNet1K Attentive probes
+
+| model | resolution | accuracy | download |
+| --- | --- | --- | --- |
+| ViT-L/16 | 224x224 | 74.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 75.9 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 77.4 | attentive probe checkpoint / configs |
+
+#### Places205 Attentive probes
+
+| model | resolution | accuracy | download |
+| --- | --- | --- | --- |
+| ViT-L/16 | 224x224 | 60.3 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 61.7 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 62.8 | attentive probe checkpoint / configs |
+
+#### iNat21 Attentive probes
+
+| model | resolution | accuracy | download |
+| --- | --- | --- | --- |
+| ViT-L/16 | 224x224 | 67.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 67.9 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 72.6 | attentive probe checkpoint / configs |
+ +## Code Structure + +**Config files:** +All experiment parameters are specified in config files (as opposed to command-line arguments). See the [configs/](configs/) directory for example config files. Note, before launching an experiment, you must update the paths in the config file to point to your own directories, indicating where to save the logs and checkpoints and where to find the training data. + + +``` +. +├── app # the only place where training loops are allowed +│ ├── vjepa # Video JEPA pre-training +│ ├── main_distributed.py # entrypoint for launching app on slurm cluster +│ └── main.py # entrypoint for launching app locally on your machine for debugging +├── evals # the only place where evaluation of 'apps' are allowed +│ ├── image_classification # training an attentive probe for image classification with frozen backbone +│ ├── video_classification # training an attentive probe for video classification with frozen backbone +│ ├── main_distributed.py # entrypoint for launching distributed evaluations on slurm cluster +│ └── main.py # entrypoint for launching evaluations locally on your machine for debugging +├── src # the package +│ ├── datasets # datasets, data loaders, ... +│ ├── models # model definitions +│ ├── masks # mask collators, masking utilities, ... +│ └── utils # shared utilities +└── configs # the only place where config files are allowed (specify experiment params for app/eval runs) + ├── evals # configs for launching vjepa frozen evaluations + └── pretrain # configs for launching vjepa pretraining + +``` + +## Data preparation + +### Video Datasets +V-JEPA pretraining and evaluations work with many standard video formats. +To make a video dataset compatible with the V-JEPA codebase, you simply need to create a `.csv` file with the following format and then specify the path to this CSV file in your config. +``` +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +... +``` +Since V-JEPA is entirely unsupervised, the pretraining code will disregard the `$integer_class_label` in the CSV file. +Thus, feel free to put a random value in this column. +However, if you wish to run a supervised video classification evaluation on your video dataset, you must replace ```$integer_class_label``` with the ground truth label for each video. + +### Image Datasets +We use the standard PyTorch ```ImageFolder``` class in our image classification evals. +Thus, to set up an image dataset for the image classification evaluation, first create a directory to store your image datasets ```$your_directory_containing_image_datasets```. +Next, download your image datasets into this directory in a format compatible with [PyTorch ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html). + +For example, suppose we have a directory called ``my_image_datasets``. We would then download our image datasets into this directory so that we end up with the following file tree +``` +. +└── /my_image_datasets/ # where we store image datasets + ├── places205/121517/pytorch/ # Places205 + │ └── [...] + ├── iNaturalist-2021/110421/ # iNaturalist21 + │ └── [...] + ├── [...] # Other Image Datasets + │ └── [...] + └── imagenet_full_size/061417/ # ImageNet1k + └── train + │ ├── $class_1 + │ │ ├── xxx.[png, jpeg, etc.] + │ │ ├── [...] + │ │ └── xxz.[png, jpeg, etc.] + │ ├── [...] + │ └── $class_n + │ ├── abc.[png, jpeg, etc.] + │ ├── [...] 
+ │ └── abz.[png, jpeg, etc.] + └── val + ├── $class_1 + │ ├── xxx.[png, jpeg, etc.] + │ ├── [...] + │ └── xxz.[png, jpeg, etc.] + ├── [...] + └── $class_n + ├── abc.[png, jpeg, etc.] + ├── [...] + └── abz.[png, jpeg, etc.] +``` + + +## Launching V-JEPA pretraining + +### Local training +If you wish to debug your code or setup before launching a distributed training run, we provide the functionality to do so by running the pretraining script locally on a multi-GPU (or single-GPU) machine, however, reproducing our results requires launching distributed training. + +The single-machine implementation starts from the [app/main.py](appmain.py), which parses the experiment config file and runs the pretraining locally on a multi-GPU (or single-GPU) machine. +For example, to run V-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: +```bash +python -m app.main \ + --fname configs/pretrain/vitl16.yaml \ + --devices cuda:0 cuda:1 cuda:2 +``` + +### Distributed training +To launch a distributed training run, the implementation starts from [app/main_distributed.py](app/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. + +For example, to launch a distributed pre-training experiment using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: +```bash +python -m app.main_distributed \ + --fname configs/pretrain/vitl16.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +## Launching Evaluations + +### Local training +If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine, however, reproducing the full eval would require launching distributed training. +The single-machine implementation starts from the [eval/main.py](eval/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. + +For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +```bash +python -m evals.main \ + --fname configs/eval/vitl16_in1k.yaml \ + --devices cuda:0 cuda:1 cuda:2 +``` + + +### Distributed training +To launch a distributed evaluation run, the implementation starts from [eval/main_distributed.py](eval/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. 
+ +For example, to launch a distributed ImageNet image classification experiment using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +```bash +python -m evals.main_distributed \ + --fname configs/eval/vitl16_in1k.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +Similarly, to launch a distributed K400 video classification experiment using the config [configs/eval/vitl16_k400.yaml](configs/eval/vitl16_k400.yaml), type the command: +```bash +python -m evals.main_distributed \ + --fname configs/eval/vitl16_k400.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +--- + +### Setup + +Run: +```bash +conda create -n jepa python=3.9 pip +conda activate jepa +python setup.py install +``` + +## License +See the [LICENSE](./LICENSE) file for details about the license under which this code is made available. + +## Citation +If you find this repository useful in your research, please consider giving a star :star: and a citation +```bibtex +@article{bardes2024revisiting, + title={Revisiting Feature Prediction for Learning Visual Representations from Video}, + author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael, and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas}, + journal={arXiv preprint}, + year={2024} +} diff --git a/huggingface/README.md b/huggingface/README.md new file mode 100644 index 0000000..263b08f --- /dev/null +++ b/huggingface/README.md @@ -0,0 +1,78 @@ + VJEPA Encoder + +The VJEPA Encoder finetuned JEPA model trained on [High Speed and High Dynamic Range Video with an Event Camera IEEE Transactions on Pattern Analysis and Machine Intelligence, 2019](https://rpg.ifi.uzh.ch/event_driving_datasets.html). This package is an adaptation to `facebookresearch/jepa` to enable ease of use of the Jepa Architecture built with Vision Transformers. + +## Installation + +To install the VJEPA Encoder package, you can use pip: + +``` +pip install vjepa_encoder +``` + +## Usage + +To use the VJEPA Encoder in your Python code, you can import it as follows: + +```python +from vjepa_encoder.vision_encoder import JepaEncoder +``` + +### Loading the Encoder + +To load the pre-trained encoder, you can use the `load_model` function: + +```python +encoder = JepaEncoder.load_model(config_file_path, devices) +``` + +- `config_file_path`: Path to the configuration file (YAML) containing the model settings. +- `devices`: List of devices (e.g., `['cuda:0']`) to use for distributed training. If not provided, the model will be loaded on the CPU. + +### Preprocessing Data + +The VJEPA Encoder provides a `preprocess_data` function to preprocess input data before feeding it to the encoder: + +```python +preprocessed_data = encoder.preprocess_data(input_data) +``` + +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. + +### Embedding Images + +To obtain the embeddings for an image, you can use the `embed_image` function: + +```python +embeddings = encoder.embed_image(input_data) +``` + +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. + +The function returns the embeddings generated by the encoder. + +## Configuration + +The VJEPA Encoder requires a configuration file in YAML format to specify the model settings. The configuration file should include the following sections: + +- `meta`: General settings such as the checkpoint file path, random seed, etc. 
+- `mask`: Settings related to masking. +- `model`: Model architecture settings. +- `data`: Data-related settings such as crop size, patch size, etc. +- `logging`: Logging settings. + +Please refer to the provided configuration file template for more details. + +## License + +The VJEPA Encoder is released under the [MIT License](LICENSE). + +## Acknowledgments + +The VJEPA Encoder is based on the research work conducted by Facebook AI Research. We would like to acknowledge their contributions to the field of computer vision and representation learning. + +## Contact + +If you have any questions or suggestions regarding the VJEPA Encoder, please feel free to contact me at johnnykoch02@gmail.com. + +--- \ No newline at end of file diff --git a/huggingface/demo_jepa_encoder.py b/huggingface/demo_jepa_encoder.py new file mode 100644 index 0000000..878bc38 --- /dev/null +++ b/huggingface/demo_jepa_encoder.py @@ -0,0 +1,14 @@ +from vjepa_encoder.vision_encoder import JepaEncoder + +encoder = JepaEncoder.load_model( + "logs/params-encoder.yaml" +) + +import numpy +img = numpy.random.random(size=(360, 480, 3)) + +print("Input Img:", img.shape) +embedding = encoder.embed_image(img) + +print(embedding) +print(embedding.shape) \ No newline at end of file diff --git a/huggingface/params-encoder.yaml b/huggingface/params-encoder.yaml new file mode 100644 index 0000000..d6e4f64 --- /dev/null +++ b/huggingface/params-encoder.yaml @@ -0,0 +1,89 @@ +app: vjepa +data: + batch_size: 8 + clip_duration: null + crop_size: 224 + dataset_type: VideoDataset + datasets: + - /path/to/dataset.csv + decode_one_clip: true + filter_short_videos: false + num_clips: 1 + num_frames: 16 + num_workers: 4 + patch_size: 16 + pin_mem: true + sampling_rate: 4 + tubelet_size: 2 +data_aug: + auto_augment: false + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +logging: + folder: /path/to/logs + write_tag: jepa +loss: + loss_exp: 1.0 + reg_coeff: 0.0 +mask: +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 8 + spatial_scale: + - 0.15 + - 0.15 + temporal_scale: + - 1.0 + - 1.0 +- aspect_ratio: + - 0.75 + - 1.5 + max_keep: null + max_temporal_keep: 1.0 + num_blocks: 2 + spatial_scale: + - 0.7 + - 0.7 + temporal_scale: + - 1.0 + - 1.0 +meta: + dtype: bfloat16 + eval_freq: 100 + load_checkpoint: true + read_checkpoint: /path/to/vitl16.pth.tar + save_every_freq: 5 + seed: 234 + use_sdpa: true +model: + model_name: vit_large + pred_depth: 12 + pred_embed_dim: 384 + uniform_power: true + use_mask_tokens: true + zero_init_mask_tokens: true +nodes: 16 +optimization: + clip_grad: 10.0 + ema: + - 0.998 + - 1.0 + epochs: 25 + final_lr: 1.0e-06 + final_weight_decay: 0.4 + ipe: 300 + ipe_scale: 1.25 + lr: 0.000625 + start_lr: 0.0002 + warmup: 40 + weight_decay: 0.04 +tasks_per_node: 8 diff --git a/jepa_encoder.egg-info/PKG-INFO b/jepa_encoder.egg-info/PKG-INFO new file mode 100644 index 0000000..6a3951f --- /dev/null +++ b/jepa_encoder.egg-info/PKG-INFO @@ -0,0 +1,17 @@ +Metadata-Version: 2.1 +Name: jepa-encoder +Version: 0.0.1 +Summary: JEPA research code. 
+Requires-Python: >=3.9 +License-File: LICENSE +Requires-Dist: pyyaml +Requires-Dist: numpy +Requires-Dist: opencv-python +Requires-Dist: submitit +Requires-Dist: braceexpand +Requires-Dist: webdataset +Requires-Dist: timm +Requires-Dist: decord +Requires-Dist: pandas +Requires-Dist: einops +Requires-Dist: beartype diff --git a/jepa_encoder.egg-info/SOURCES.txt b/jepa_encoder.egg-info/SOURCES.txt new file mode 100644 index 0000000..00be8b0 --- /dev/null +++ b/jepa_encoder.egg-info/SOURCES.txt @@ -0,0 +1,10 @@ +LICENSE +README.md +setup.py +jepa_encoder.egg-info/PKG-INFO +jepa_encoder.egg-info/SOURCES.txt +jepa_encoder.egg-info/dependency_links.txt +jepa_encoder.egg-info/requires.txt +jepa_encoder.egg-info/top_level.txt +vjepa_encoder/__init__.py +vjepa_encoder/vision_encoder.py \ No newline at end of file diff --git a/jepa_encoder.egg-info/dependency_links.txt b/jepa_encoder.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jepa_encoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/jepa_encoder.egg-info/requires.txt b/jepa_encoder.egg-info/requires.txt new file mode 100644 index 0000000..386919b --- /dev/null +++ b/jepa_encoder.egg-info/requires.txt @@ -0,0 +1,11 @@ +pyyaml +numpy +opencv-python +submitit +braceexpand +webdataset +timm +decord +pandas +einops +beartype diff --git a/jepa_encoder.egg-info/top_level.txt b/jepa_encoder.egg-info/top_level.txt new file mode 100644 index 0000000..cca3137 --- /dev/null +++ b/jepa_encoder.egg-info/top_level.txt @@ -0,0 +1 @@ +vjepa_encoder diff --git a/jepa_src/__init__.py b/jepa_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/__init__.py b/jepa_src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/data_manager.py b/jepa_src/datasets/data_manager.py new file mode 100644 index 0000000..cf53940 --- /dev/null +++ b/jepa_src/datasets/data_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
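The egg-info metadata above fixes the distribution name, version, Python floor, and runtime dependencies (the two egg-info directories suggest the distribution was renamed from `jepa-encoder` to `vjepa_encoder` during development). A minimal `setup.py` consistent with that metadata would look roughly like the sketch below; the actual `setup.py` touched by this patch is not shown here and may differ, e.g. in which packages it includes.

```python
# Hypothetical reconstruction from the egg-info metadata; not the patch's actual setup.py.
from setuptools import setup, find_packages

setup(
    name="vjepa_encoder",
    version="0.0.1",
    description="JEPA research code.",
    python_requires=">=3.9",
    packages=find_packages(include=["vjepa_encoder", "vjepa_encoder.*",
                                    "jepa_src", "jepa_src.*"]),
    install_requires=[
        "pyyaml", "numpy", "opencv-python", "submitit", "braceexpand",
        "webdataset", "timm", "decord", "pandas", "einops", "beartype",
    ],
)
```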
+# + +from logging import getLogger + + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def init_data( + batch_size, + transform=None, + shared_transform=None, + data='ImageNet', + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + tokenize_txt=True, + subset_file=None, + clip_len=8, + frame_sample_rate=2, + duration=None, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(1e9), + decode_one_clip=True, + datasets_weights=None, + persistent_workers=False, + repeat_wds=False, + ipe=300, + log_dir=None, +): + + if (data.lower() == 'imagenet') \ + or (data.lower() == 'inat21') \ + or (data.lower() == 'places205'): + from jepa_src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( + transform=transform, + batch_size=batch_size, + collator=collator, + pin_mem=pin_mem, + training=training, + num_workers=num_workers, + world_size=world_size, + rank=rank, + root_path=root_path, + image_folder=image_folder, + persistent_workers=persistent_workers, + copy_data=copy_data, + drop_last=drop_last, + subset_file=subset_file) + + elif data.lower() == 'videodataset': + from jepa_src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( + data_paths=root_path, + batch_size=batch_size, + frames_per_clip=clip_len, + frame_step=frame_sample_rate, + duration=duration, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + shared_transform=shared_transform, + transform=transform, + datasets_weights=datasets_weights, + collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + drop_last=drop_last, + log_dir=log_dir) + + return (data_loader, dist_sampler) diff --git a/jepa_src/datasets/image_dataset.py b/jepa_src/datasets/image_dataset.py new file mode 100644 index 0000000..84e9b08 --- /dev/null +++ b/jepa_src/datasets/image_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
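As a usage sketch, the `'VideoDataset'` branch of `init_data` above can be driven directly from the values in `params-encoder.yaml` (batch size 8, 16 frames per clip, frame step 4). The CSV path is a placeholder, and `transform`/`collator` stand in for whatever the caller builds from the video transforms added later in this patch.

```python
# Sketch only: the CSV path is a placeholder; transform/collator are left to the caller.
from jepa_src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    batch_size=8,
    data='VideoDataset',
    root_path=['/path/to/dataset.csv'],  # list of CSVs, as in the YAML's `datasets` field
    clip_len=16,                         # num_frames in the YAML
    frame_sample_rate=4,                 # sampling_rate in the YAML
    num_clips=1,
    transform=None,                      # plug in a clip transform here
    collator=None,
    num_workers=4,
    world_size=1,
    rank=0,
)

batch = next(iter(data_loader))          # one mini-batch of decoded clips
```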
+# + +import os + +from logging import getLogger + +import torch +import torchvision + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class ImageFolder(torchvision.datasets.ImageFolder): + + def __init__( + self, + root, + image_folder='imagenet_full_size/061417/', + transform=None, + train=True, + ): + """ + ImageFolder + :param root: root network directory for ImageFolder data + :param image_folder: path to images inside root network directory + :param train: whether to load train data (or validation) + """ + + suffix = 'train/' if train else 'val/' + data_path = os.path.join(root, image_folder, suffix) + logger.info(f'data-path {data_path}') + super(ImageFolder, self).__init__(root=data_path, transform=transform) + logger.info('Initialized ImageFolder') + + +def make_imagedataset( + transform, + batch_size, + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + persistent_workers=False, + subset_file=None +): + dataset = ImageFolder( + root=root_path, + image_folder=image_folder, + transform=transform, + train=training) + logger.info('ImageFolder dataset created') + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, + num_replicas=world_size, + rank=rank) + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=persistent_workers) + logger.info('ImageFolder unsupervised data loader created') + + return dataset, data_loader, dist_sampler diff --git a/jepa_src/datasets/utils/__init__.py b/jepa_src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/utils/video/__init__.py b/jepa_src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/utils/video/functional.py b/jepa_src/datasets/utils/video/functional.py new file mode 100644 index 0000000..a91d15d --- /dev/null +++ b/jepa_src/datasets/utils/video/functional.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
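For the image branch, `make_imagedataset` above wraps a standard `torchvision` `ImageFolder` whose directory layout is `<root_path>/<image_folder>/train/` (or `val/`). A minimal sketch with placeholder paths and an arbitrary eval-style transform:

```python
# Paths are placeholders; the transform is an arbitrary example.
import torchvision.transforms as T
from jepa_src.datasets.image_dataset import make_imagedataset

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

dataset, data_loader, dist_sampler = make_imagedataset(
    transform=transform,
    batch_size=64,
    root_path='/datasets',
    image_folder='imagenet_full_size/061417/',
    training=True,
    world_size=1,   # explicit world_size/rank avoids needing torch.distributed init
    rank=0,
    num_workers=8,
)
```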
+# + +import numbers +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/jepa_src/datasets/utils/video/randaugment.py b/jepa_src/datasets/utils/video/randaugment.py new file mode 100644 index 0000000..4c80a99 --- /dev/null +++ b/jepa_src/datasets/utils/video/randaugment.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. +""" + +import math +import numpy as np +import random +import re +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. 
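The helpers above operate on a clip represented as a list of frames (numpy arrays or PIL images), plus one tensor-level `normalize`. A small sketch with made-up frame sizes and the usual ImageNet statistics:

```python
# Example values only: frame sizes and normalization stats are arbitrary.
import numpy as np
import torch
import jepa_src.datasets.utils.video.functional as vf

clip = [np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8) for _ in range(16)]

# Resize so the short side is 224, then take a 224x224 crop from the top-left corner.
clip = vf.resize_clip(clip, 224, interpolation='bilinear')
clip = vf.crop_clip(clip, 0, 0, 224, 224)

# normalize() expects a 4-D float tensor; stack to (C, T, H, W) so the
# per-channel mean/std broadcast along the first dimension.
tensor = torch.from_numpy(np.stack(clip)).permute(3, 0, 1, 2).float() / 255.0
tensor = vf.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(tensor.shape)  # torch.Size([3, 16, 224, 224])
```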
+_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs + ) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs + ) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def 
_rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": 
_translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] + if "img_mean" in hparams + else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None + else () + ) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
+_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + choice_weights = ( + None if weight_idx is None else _select_rand_weights(weight_idx) + ) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/jepa_src/datasets/utils/video/randerase.py 
b/jepa_src/datasets/utils/video/randerase.py new file mode 100644 index 0000000..d1f185c --- /dev/null +++ b/jepa_src/datasets/utils/video/randerase.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. +""" +import math +import random +import torch + + +def _get_pixels( + per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" +): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty( + (patch_size[0], 1, 1), dtype=dtype, device=device + ).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
+ """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[ + :, top:top + h, left:left + w + ] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0 + ) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/jepa_src/datasets/utils/video/transforms.py b/jepa_src/datasets/utils/video/transforms.py new file mode 100644 index 0000000..979985d --- /dev/null +++ b/jepa_src/datasets/utils/video/transforms.py @@ -0,0 +1,1184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +import numpy as np +import random +import numbers +import PIL +from PIL import Image + +import torch +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms + +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, boxes=None, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) + ) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or ( + height <= width and height == size + ): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode='bilinear', + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. 
The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode='bilinear', + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum( + width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) + ) + clipped_boxes[:, [1, 3]] = np.minimum( + height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) + ) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = ( + 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + ) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. 
+ """ + + jitter = [] + if img_brightness != 0: + jitter.append('brightness') + if img_contrast != 0: + jitter.append('contrast') + if img_saturation != 0: + jitter.append('saturation') + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == 'brightness': + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == 'contrast': + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == 'saturation': + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[0] + ), 'channel stddev not computed properly' + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[1] + ), 'channel stddev not computed properly' + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind:ind + 1, + i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation='bilinear', +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {'translate_const': int(img_size_min * 0.45)} + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)] + ) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
+ """ + assert ( + len(im.shape) == 3 + ), 'Currently only support image for random_sized_crop' + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode='bilinear', + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation='bilinear', + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale) + ) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio) + ) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire 
clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """ + Randomly change the brightness, contrast and saturation and hue of the clip + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/jepa_src/datasets/utils/video/volume_transforms.py b/jepa_src/datasets/utils/video/volume_transforms.py new file mode 100644 index 0000000..0a01bb3 --- /dev/null +++ b/jepa_src/datasets/utils/video/volume_transforms.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import numpy as np +from PIL import Image + +import torch + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. 
+ """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor""" + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/jepa_src/datasets/utils/weighted_sampler.py b/jepa_src/datasets/utils/weighted_sampler.py new file mode 100644 index 0000000..fd40825 --- /dev/null +++ b/jepa_src/datasets/utils/weighted_sampler.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Iterator, Optional +from operator import itemgetter +import numpy as np + +import torch +from torch.utils.data import ( + Dataset, + Sampler, + DistributedSampler, + WeightedRandomSampler +) + + +class DatasetFromSampler(Dataset): + + def __init__(self, sampler: Sampler): + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ Convert any Pytorch Sampler to a DistributedSampler """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class CustomWeightedRandomSampler(WeightedRandomSampler): + """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + rand_tensor = np.random.choice( + range(0, len(self.weights)), + size=self.num_samples, + p=self.weights.numpy() / torch.sum(self.weights).numpy(), + replace=self.replacement + ) + rand_tensor = torch.from_numpy(rand_tensor) + return iter(rand_tensor.tolist()) + + +class DistributedWeightedSampler(DistributedSamplerWrapper): + + def __init__( + self, + weights, + num_replicas: Optional[int] = None, + rank: 
Optional[int] = None, + shuffle: bool = True, + ): + weighted_sampler = CustomWeightedRandomSampler( + weights=weights, + num_samples=len(weights), + replacement=False) + + super(DistributedWeightedSampler, self).__init__( + sampler=weighted_sampler, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) diff --git a/jepa_src/datasets/video_dataset.py b/jepa_src/datasets/video_dataset.py new file mode 100644 index 0000000..82cee52 --- /dev/null +++ b/jepa_src/datasets/video_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import pathlib +import warnings + +from logging import getLogger + +import numpy as np +import pandas as pd + +from decord import VideoReader, cpu + +import torch + +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def make_videodataset( + data_paths, + batch_size, + frames_per_clip=8, + frame_step=4, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + transform=None, + shared_transform=None, + rank=0, + world_size=1, + datasets_weights=None, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, + duration=None, + log_dir=None, +): + dataset = VideoDataset( + data_paths=data_paths, + datasets_weights=datasets_weights, + frames_per_clip=frames_per_clip, + frame_step=frame_step, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + duration=duration, + shared_transform=shared_transform, + transform=transform) + + logger.info('VideoDataset dataset created') + if datasets_weights is not None: + dist_sampler = DistributedWeightedSampler( + dataset.sample_weights, + num_replicas=world_size, + rank=rank, + shuffle=True) + else: + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0) + logger.info('VideoDataset unsupervised data loader created') + + return dataset, data_loader, dist_sampler + + +class VideoDataset(torch.utils.data.Dataset): + """ Video classification dataset. 
""" + + def __init__( + self, + data_paths, + datasets_weights=None, + frames_per_clip=16, + frame_step=4, + num_clips=1, + transform=None, + shared_transform=None, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + duration=None, # duration in seconds + ): + self.data_paths = data_paths + self.datasets_weights = datasets_weights + self.frames_per_clip = frames_per_clip + self.frame_step = frame_step + self.num_clips = num_clips + self.transform = transform + self.shared_transform = shared_transform + self.random_clip_sampling = random_clip_sampling + self.allow_clip_overlap = allow_clip_overlap + self.filter_short_videos = filter_short_videos + self.filter_long_videos = filter_long_videos + self.duration = duration + + if VideoReader is None: + raise ImportError('Unable to import "decord" which is required to read videos.') + + # Load video paths and labels + samples, labels = [], [] + self.num_samples_per_dataset = [] + for data_path in self.data_paths: + + if data_path[-4:] == '.csv': + data = pd.read_csv(data_path, header=None, delimiter=" ") + samples += list(data.values[:, 0]) + labels += list(data.values[:, 1]) + num_samples = len(data) + self.num_samples_per_dataset.append(num_samples) + + elif data_path[-4:] == '.npy': + data = np.load(data_path, allow_pickle=True) + data = list(map(lambda x: repr(x)[1:-1], data)) + samples += data + labels += [0] * len(data) + num_samples = len(data) + self.num_samples_per_dataset.append(len(data)) + + # [Optional] Weights for each sample to be used by downstream + # weighted video sampler + self.sample_weights = None + if self.datasets_weights is not None: + self.sample_weights = [] + for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): + self.sample_weights += [dw / ns] * ns + + self.samples = samples + self.labels = labels + + def __getitem__(self, index): + sample = self.samples[index] + + # Keep trying to load videos until you find a valid sample + loaded_video = False + while not loaded_video: + buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] + loaded_video = len(buffer) > 0 + if not loaded_video: + index = np.random.randint(self.__len__()) + sample = self.samples[index] + + # Label/annotations for video + label = self.labels[index] + + def split_into_clips(video): + """ Split video into a list of clips """ + fpc = self.frames_per_clip + nc = self.num_clips + return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + + # Parse video into frames & apply data augmentations + if self.shared_transform is not None: + buffer = self.shared_transform(buffer) + buffer = split_into_clips(buffer) + if self.transform is not None: + buffer = [self.transform(clip) for clip in buffer] + + return buffer, label, clip_indices + + def loadvideo_decord(self, sample): + """ Load video content using Decord """ + + fname = sample + if not os.path.exists(fname): + warnings.warn(f'video path not found {fname}') + return [], None + + _fsize = os.path.getsize(fname) + if _fsize < 1 * 1024: # avoid hanging issue + warnings.warn(f'video too short {fname}') + return [], None + if _fsize > self.filter_long_videos: + warnings.warn(f'skipping long video of size {_fsize} (bytes)') + return [], None + + try: + vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) + except Exception: + return [], None + + fpc = self.frames_per_clip + fstp = self.frame_step + if self.duration is not None: + try: + fps = vr.get_avg_fps() + fstp = int(self.duration * fps / fpc) + except Exception as 
e: + warnings.warn(e) + clip_len = int(fpc * fstp) + + if self.filter_short_videos and len(vr) < clip_len: + warnings.warn(f'skipping video of length {len(vr)}') + return [], None + + vr.seek(0) # Go to start of video before sampling frames + + # Partition video into equal sized segments and sample each clip + # from a different segment + partition_len = len(vr) // self.num_clips + + all_indices, clip_indices = [], [] + for i in range(self.num_clips): + + if partition_len > clip_len: + # If partition_len > clip len, then sample a random window of + # clip_len frames within the segment + end_indx = clip_len + if self.random_clip_sampling: + end_indx = np.random.randint(clip_len, partition_len) + start_indx = end_indx - clip_len + indices = np.linspace(start_indx, end_indx, num=fpc) + indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + # -- + indices = indices + i * partition_len + else: + # If partition overlap not allowed and partition_len < clip_len + # then repeatedly append the last frame in the segment until + # we reach the desired clip length + if not self.allow_clip_overlap: + indices = np.linspace(0, partition_len, num=partition_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) + indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + # -- + indices = indices + i * partition_len + + # If partition overlap is allowed and partition_len < clip_len + # then start_indx of segment i+1 will lie within segment i + else: + sample_len = min(clip_len, len(vr)) - 1 + indices = np.linspace(0, sample_len, num=sample_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) + indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + # -- + clip_step = 0 + if len(vr) > clip_len: + clip_step = (len(vr) - clip_len) // (self.num_clips - 1) + indices = indices + i * clip_step + + clip_indices.append(indices) + all_indices.extend(list(indices)) + + buffer = vr.get_batch(all_indices).asnumpy() + return buffer, clip_indices + + def __len__(self): + return len(self.samples) diff --git a/jepa_src/masks/__init__.py b/jepa_src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/masks/default.py b/jepa_src/masks/default.py new file mode 100644 index 0000000..2810c0a --- /dev/null +++ b/jepa_src/masks/default.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class DefaultCollator(object): + + def __call__(self, batch): + collated_batch = torch.utils.data.default_collate(batch) + return collated_batch, None, None diff --git a/jepa_src/masks/multiblock3d.py b/jepa_src/masks/multiblock3d.py new file mode 100644 index 0000000..a7bbc3e --- /dev/null +++ b/jepa_src/masks/multiblock3d.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
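+#
+# Minimal usage sketch (illustrative only); the mask config values below are
+# made-up examples, not taken from any shipped YAML:
+#
+#   from jepa_src.masks.multiblock3d import MaskCollator
+#
+#   cfgs_mask = [{'spatial_scale': (0.2, 0.8), 'temporal_scale': (1.0, 1.0),
+#                 'aspect_ratio': (0.3, 3.0), 'num_blocks': 8}]
+#   collator = MaskCollator(cfgs_mask, crop_size=(224, 224), num_frames=16,
+#                           patch_size=16, tubelet_size=2)
+#   # pass `collator` as a DataLoader's collate_fn; each batch then unpacks as
+#   # (clips, masks_enc, masks_pred), one mask pair per entry in cfgs_mask
+#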
+# + +import math + +from multiprocessing import Value + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + spatial_pred_mask_scale=m.get('spatial_scale'), + temporal_pred_mask_scale=m.get('temporal_scale'), + aspect_ratio=m.get('aspect_ratio'), + npred=m.get('num_blocks'), + max_context_frames_ratio=m.get('max_temporal_keep', 1.0), + max_keep=m.get('max_keep', None), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + spatial_pred_mask_scale=(0.2, 0.8), + temporal_pred_mask_scale=(1.0, 1.0), + aspect_ratio=(0.3, 3.0), + npred=1, + max_context_frames_ratio=1.0, + max_keep=None, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.aspect_ratio = aspect_ratio + self.spatial_pred_mask_scale = spatial_pred_mask_scale + self.temporal_pred_mask_scale = temporal_pred_mask_scale + self.npred = npred + self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_keep = max_keep # maximum number of patches to keep in context + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size( + self, + generator, + temporal_scale, + spatial_scale, + aspect_ratio_scale + ): + # -- Sample temporal block mask scale + _rand = torch.rand(1, generator=generator).item() + min_t, max_t = temporal_scale + temporal_mask_scale = min_t + _rand * (max_t - min_t) + t = max(1, int(self.duration * temporal_mask_scale)) + + # -- Sample spatial block mask scale + _rand = torch.rand(1, generator=generator).item() + min_s, max_s = spatial_scale + spatial_mask_scale = min_s + _rand * (max_s - min_s) + spatial_num_keep = int(self.height * self.width * spatial_mask_scale) + + # -- Sample block aspect-ratio + _rand = torch.rand(1, generator=generator).item() + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + + # -- Compute block height and width (given scale and aspect-ratio) + h = 
int(round(math.sqrt(spatial_num_keep * aspect_ratio))) + w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) + h = min(h, self.height) + w = min(w, self.width) + + return (t, h, w) + + def _sample_block_mask(self, b_size): + t, h, w = b_size + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) + start = torch.randint(0, self.duration - t + 1, (1,)) + + mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask[start:start+t, top:top+h, left:left+w] = 0 + + # Context mask will only span the first X frames + # (X=self.max_context_frames) + if self.max_context_duration < self.duration: + mask[self.max_context_duration:, :, :] = 0 + + # -- + return mask + + def __call__(self, batch_size): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample pred block size using seed + # 2. sample several pred block locations for each image (w/o seed) + # 3. return pred masks and complement (enc mask) + """ + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + temporal_scale=self.temporal_pred_mask_scale, + spatial_scale=self.spatial_pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_enc = min_keep_pred = self.duration * self.height * self.width + for _ in range(batch_size): + + empty_context = True + while empty_context: + + mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + for _ in range(self.npred): + mask_e *= self._sample_block_mask(p_size) + mask_e = mask_e.flatten() + + mask_p = torch.argwhere(mask_e == 0).squeeze() + mask_e = torch.nonzero(mask_e).squeeze() + + empty_context = len(mask_e) == 0 + if not empty_context: + min_keep_pred = min(min_keep_pred, len(mask_p)) + min_keep_enc = min(min_keep_enc, len(mask_e)) + collated_masks_pred.append(mask_p) + collated_masks_enc.append(mask_e) + + if self.max_keep is not None: + min_keep_enc = min(min_keep_enc, self.max_keep) + + collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_masks_enc, collated_masks_pred diff --git a/jepa_src/masks/random_tube.py b/jepa_src/masks/random_tube.py new file mode 100644 index 0000000..84c0640 --- /dev/null +++ b/jepa_src/masks/random_tube.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
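+#
+# Minimal usage sketch (illustrative only); `clips` below is a stand-in for
+# real dataset items and the 0.9 ratio is just an example value:
+#
+#   import torch
+#   from jepa_src.masks.random_tube import MaskCollator
+#
+#   collator = MaskCollator([{'ratio': 0.9}], crop_size=(224, 224),
+#                           num_frames=16, patch_size=16, tubelet_size=2)
+#   clips = [torch.randn(3, 16, 224, 224) for _ in range(2)]
+#   batch, masks_enc, masks_pred = collator(clips)
+#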
+# + +from multiprocessing import Value + +from logging import getLogger + +import torch +import numpy as np + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + ratio=m.get('ratio'), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + ratio=0.9, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches_spatial = self.height*self.width + + self.ratio = ratio + + self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep = self.num_keep_spatial * self.duration + + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def __call__(self, batch_size): + def sample_mask(): + mask = np.hstack([ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ]) + np.random.shuffle(mask) + mask = torch.tensor(np.tile(mask, (self.duration, 1))) + mask = mask.flatten() + mask_p = torch.argwhere(mask == 0).squeeze() + mask_e = torch.nonzero(mask).squeeze() + return mask_e, mask_p + + collated_masks_pred, collated_masks_enc = [], [] + for _ in range(batch_size): + mask_e, mask_p = sample_mask() + collated_masks_enc.append(mask_e) + collated_masks_pred.append(mask_p) + + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + + return collated_masks_enc, collated_masks_pred diff --git a/jepa_src/masks/utils.py b/jepa_src/masks/utils.py new file mode 100644 index 0000000..ca04af1 --- /dev/null +++ b/jepa_src/masks/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
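+#
+# Minimal sketch of the indexing semantics, assuming each mask holds patch
+# indices (integers in [0, N)) rather than boolean maps:
+#
+#   import torch
+#   x = torch.randn(2, 8, 4)                         # [B, N, D]
+#   masks = [torch.tensor([[0, 3, 5], [1, 2, 7]])]   # keep 3 of 8 patches per sample
+#   out = apply_masks(x, masks)                      # -> shape [2, 3, 4]
+#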
+# + +import torch + + +def apply_masks(x, masks, concat=True): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + if not concat: + return all_x + + return torch.cat(all_x, dim=0) diff --git a/jepa_src/models/__init__.py b/jepa_src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/models/attentive_pooler.py b/jepa_src/models/attentive_pooler.py new file mode 100644 index 0000000..26b0e0e --- /dev/null +++ b/jepa_src/models/attentive_pooler.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import ( + Block, + CrossAttention, + CrossAttentionBlock +) +from jepa_src.utils.tensors import trunc_normal_ + + +class AttentivePooler(nn.Module): + """ Attentive Pooler """ + def __init__( + self, + num_queries=1, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + complete_block=True + ): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) + + self.complete_block = complete_block + if complete_block: + self.cross_attention_block = CrossAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer) + else: + self.cross_attention_block = CrossAttention( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.blocks = None + if depth > 1: + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer) + for i in range(depth-1)]) + + self.init_std = init_std + trunc_normal_(self.query_tokens, std=self.init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + if self.complete_block: + rescale(self.cross_attention_block.xattn.proj.weight.data, 1) + rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) + else: + rescale(self.cross_attention_block.proj.weight.data, 1) + if self.blocks is not None: + for layer_id, layer in enumerate(self.blocks, 1): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + q = self.query_tokens.repeat(len(x), 1, 1) + q = self.cross_attention_block(q, x) + if self.blocks is not None: + for blk in self.blocks: + q = blk(q) + return q + + +class AttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + 
norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + num_classes=1000, + complete_block=True, + ): + super().__init__() + self.pooler = AttentivePooler( + num_queries=1, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + depth=depth, + norm_layer=norm_layer, + init_std=init_std, + qkv_bias=qkv_bias, + complete_block=complete_block, + ) + self.linear = nn.Linear(embed_dim, num_classes, bias=True) + + def forward(self, x): + x = self.pooler(x).squeeze(1) + x = self.linear(x) + return x diff --git a/jepa_src/models/predictor.py b/jepa_src/models/predictor.py new file mode 100644 index 0000000..95f6bc0 --- /dev/null +++ b/jepa_src/models/predictor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( + trunc_normal_, + repeat_interleave_batch +) +from jepa_src.masks.utils import apply_masks + + +class VisionTransformerPredictor(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + depth=6, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + **kwargs + ): + super().__init__() + # Map input to predictor dimension + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ]) + + # Determine positional embedding + self.input_size = img_size + self.patch_size = patch_size + # -- + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + if self.is_video: + self.num_patches = num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.num_patches = num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + # Position embedding + self.uniform_power = uniform_power + self.predictor_pos_embed = None + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + + # Attention Blocks + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer) + for i in range(depth)]) + + # Normalize & project back to input dimension + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + + # ------ initialize weights + if self.predictor_pos_embed is not None: + 
self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed + self.init_std = init_std + if not zero_init_mask_tokens: + for mt in self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.predictor_blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): + + # Prepare diffusion noise schedule + b1, b2 = noise_beta + beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + alpha_scheduler = [] + _alpha = 1.0 + for _beta in beta_scheduler: + _alpha *= 1.-_beta + alpha_scheduler += [_alpha] + + # Sample diffusion time step + T = torch.randint(0, steps, (len(x),)) + alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + + # Normalize features and apply noise + x = torch.nn.functional.layer_norm(x, (x.size(-1),)) + x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + return x + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): + """ + :param ctxt: context tokens + :param tgt: target tokens + :param masks_ctxt: indices of context tokens in input + :params masks_tgt: indices of target tokens in input + """ + + assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_ctxt, list): + masks_ctxt = [masks_ctxt] + + if not isinstance(masks_tgt, list): + masks_tgt = [masks_tgt] + + # Batch Size + B = len(ctxt) // len(masks_ctxt) + + # Map context tokens to pedictor dimensions + x = self.predictor_embed(ctxt) + _, N_ctxt, D = x.shape + + # Add positional embedding to ctxt tokens + if self.predictor_pos_embed is not None: + ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(ctxt_pos_embed, masks_ctxt) + + # Map target tokens to predictor dimensions & add noise (fwd diffusion) + if self.mask_tokens is None: + pred_tokens = self.predictor_embed(tgt) + pred_tokens = self.diffusion(pred_tokens) + else: + mask_index = mask_index % self.num_mask_tokens + pred_tokens = self.mask_tokens[mask_index] + pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) + pred_tokens = apply_masks(pred_tokens, masks_tgt) + + # Add positional embedding to target tokens + if self.predictor_pos_embed is not None: + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_tgt) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) + pred_tokens += pos_embs + + # Concatenate context & target tokens + x = 
x.repeat(len(masks_tgt), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # FIXME: this implementation currently assumes masks_ctxt and masks_tgt + # are alligned 1:1 (ok with MultiMask wrapper on predictor but + # otherwise will break) + masks_ctxt = torch.cat(masks_ctxt, dim=0) + masks_tgt = torch.cat(masks_tgt, dim=0) + masks = torch.cat([masks_ctxt, masks_tgt], dim=1) + + # Fwd prop + for blk in self.predictor_blocks: + x = blk(x, mask=masks) + x = self.predictor_norm(x) + + # Return output corresponding to target tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +def vit_predictor(**kwargs): + model = VisionTransformerPredictor( + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model diff --git a/jepa_src/models/utils/__init__.py b/jepa_src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/models/utils/modules.py b/jepa_src/models/utils/modules.py new file mode 100644 index 0000000..dc470d9 --- /dev/null +++ b/jepa_src/models/utils/modules.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + grid_size=None, + grid_depth=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = 
int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, return_attention=False, mask=None): + y, attn = self.attn(self.norm1(x), mask=mask) + if return_attention: + return attn + x = x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads=12, + qkv_bias=False, + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.use_sdpa = use_sdpa + + def forward(self, q, x): + B, n, C = q.shape + q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + B, N, C = x.shape + kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + q = F.scaled_dot_product_attention(q, k, v) + else: + xattn = (q @ k.transpose(-2, -1)) * self.scale + xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) + q = (xattn @ v) + + q = q.transpose(1, 2).reshape(B, n, C) + q = self.proj(q) + + return q + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, q, x): + y = self.xattn(q, self.norm1(x)) + q = q + y + q = q + self.mlp(self.norm2(q)) + return q diff --git a/jepa_src/models/utils/multimask.py b/jepa_src/models/utils/multimask.py new file mode 100644 index 0000000..d480086 --- /dev/null +++ b/jepa_src/models/utils/multimask.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
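+#
+# Minimal usage sketch (illustrative only); `encoder`, `x` and the mask
+# tensors are placeholders for a backbone whose forward accepts (x, masks=...):
+#
+#   wrapped = MultiMaskWrapper(encoder)
+#   feats = wrapped(x)                           # no masks -> plain backbone forward
+#   outs = wrapped(x, masks=[masks_a, masks_b])  # list with one output per mask set
+#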
+# + +import torch.nn as nn + + +class MultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, masks=None): + if masks is None: + return self.backbone(x) + + if (masks is not None) and not isinstance(masks, list): + masks = [masks] + outs = [] + for m in masks: + outs += [self.backbone(x, masks=m)] + return outs + + +class PredictorMultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): + if type(ctxt) is not list: + ctxt = [ctxt] + if type(tgt) is not list: + tgt = [tgt] + if type(masks_ctxt) is not list: + masks_ctxt = [masks_ctxt] + if type(masks_tgt) is not list: + masks_tgt = [masks_tgt] + + outs = [] + for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): + outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] + return outs diff --git a/jepa_src/models/utils/patch_embed.py b/jepa_src/models/utils/patch_embed.py new file mode 100644 index 0000000..4ff4de5 --- /dev/null +++ b/jepa_src/models/utils/patch_embed.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch.nn as nn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768 + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + patch_size=16, + tubelet_size=2, + in_chans=3, + embed_dim=768, + ): + super().__init__() + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + self.proj = nn.Conv3d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size, patch_size), + stride=(tubelet_size, patch_size, patch_size), + ) + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/jepa_src/models/utils/pos_embs.py b/jepa_src/models/utils/pos_embs.py new file mode 100644 index 0000000..d1d82e2 --- /dev/null +++ b/jepa_src/models/utils/pos_embs.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
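+#
+# Minimal shape check (illustrative only), assuming the defaults used elsewhere
+# in this patch: 224px crops with 16px patches and 16 frames with tubelet size 2:
+#
+#   pe = get_3d_sincos_pos_embed(embed_dim=768, grid_size=14, grid_depth=8,
+#                                cls_token=False)
+#   # pe.shape == (8 * 14 * 14, 768) == (1568, 768)
+#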
+# + +import numpy as np + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=False +): + """ + grid_size: int of the grid height and width + grid_depth: int of the grid depth + returns: + pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_d = np.arange(grid_depth, dtype=float) + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + + if not uniform_power: + h_embed_dim = embed_dim // 4 + w_embed_dim = embed_dim // 4 + d_embed_dim = embed_dim // 2 + else: + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + + emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) + emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) + emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) + pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) + pos_embed = pos_embed[:, :embed_dim] + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + returns: + pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) + pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + embed_dim: output dimension for each position + grid_size: int of the grid length + returns: + pos_embed: [grid_size, embed_dim] (w/o cls_token) + or [1+grid_size, embed_dim] (w/ cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + returns: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/jepa_src/models/vision_transformer.py b/jepa_src/models/vision_transformer.py new file mode 100644 index 0000000..946246e --- /dev/null +++ b/jepa_src/models/vision_transformer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
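+#
+# Minimal usage sketch (illustrative only), assuming a video batch laid out as
+# [B, C, T, H, W]; the vit_* factory functions are defined at the bottom of
+# this file:
+#
+#   import torch
+#   model = vit_base(img_size=224, patch_size=16, num_frames=16, tubelet_size=2)
+#   x = torch.randn(2, 3, 16, 224, 224)
+#   tokens = model(x)   # -> [2, (16 // 2) * (224 // 16) ** 2, 768] == [2, 1568, 768]
+#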
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + out_layers=None, + uniform_power=False, + **kwargs + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + + self.input_size = img_size + self.patch_size = patch_size + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + # Tokenize pixels with convolution + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + tubelet_size=tubelet_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + + # Position embedding + self.uniform_power = uniform_power + self.pos_embed = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, embed_dim), + requires_grad=False) + + # Attention Blocks + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # ------ initialize weights + if self.pos_embed is not None: + self._init_pos_embed(self.pos_embed.data) # sincos pos-embed + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv3d): + 
trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {} + + def forward(self, x, masks=None): + """ + :param x: input image/video + :param masks: indices of patch tokens to mask (remove) + """ + + if masks is not None and not isinstance(masks, list): + masks = [masks] + + # Tokenize input + pos_embed = self.pos_embed + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: + x += pos_embed + B, N, D = x.shape + + # Mask away unwanted tokens (if masks provided) + if masks is not None: + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Fwd prop + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x, mask=masks) + if self.out_layers is not None and i in self.out_layers: + outs.append(self.norm(x)) + + if self.out_layers is not None: + return outs + + if self.norm is not None: + x = self.norm(x) + + return x + + def interpolate_pos_encoding(self, x, pos_embed): + + _, N, dim = pos_embed.shape + + if self.is_video: + + # If pos_embed already corret size, just return + _, _, T, H, W = x.shape + if H == self.input_size and W == self.input_size and T == self.num_frames: + return pos_embed + + # Convert depth, height, width of input to be measured in patches + # instead of pixels/frames + T = T // self.tubelet_size + H = H // self.patch_size + W = W // self.patch_size + + # Compute the initialized shape of the positional embedding measured + # in patches + N_t = self.num_frames // self.tubelet_size + N_h = N_w = self.input_size // self.patch_size + assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + + # Compute scale factor for spatio-temporal interpolation + scale_factor = (T/N_t, H/N_h, W/N_w) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), + scale_factor=scale_factor, + mode='trilinear') + pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) + return pos_embed + + else: + + # If pos_embed already corret size, just return + _, _, H, W = x.shape + if H == self.input_size and W == self.input_size: + return pos_embed + + # Compute scale factor for spatial interpolation + npatch = (H // self.patch_size) * (W // self.patch_size) + scale_factor = math.sqrt(npatch / N) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=scale_factor, + mode='bicubic') + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, 
num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_large(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_huge(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_giant(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_gigantic(patch_size=14, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) + return model + + +VIT_EMBED_DIMS = { + 'vit_tiny': 192, + 'vit_small': 384, + 'vit_base': 768, + 'vit_large': 1024, + 'vit_huge': 1280, + 'vit_giant': 1408, + 'vit_gigantic': 1664, +} diff --git a/jepa_src/utils/__init__.py b/jepa_src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/utils/distributed.py b/jepa_src/utils/distributed.py new file mode 100644 index 0000000..cfba444 --- /dev/null +++ b/jepa_src/utils/distributed.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os + +import torch +import torch.distributed as dist + +from logging import getLogger + +logger = getLogger() + + +def init_distributed(port=37123, rank_and_world_size=(None, None)): + + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + + rank, world_size = rank_and_world_size + os.environ['MASTER_ADDR'] = 'localhost' + + if (rank is None) or (world_size is None): + try: + world_size = int(os.environ['SLURM_NTASKS']) + rank = int(os.environ['SLURM_PROCID']) + os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + except Exception: + logger.info('SLURM vars not set (distributed training not available)') + world_size, rank = 1, 0 + return world_size, rank + + try: + os.environ['MASTER_PORT'] = str(port) + torch.distributed.init_process_group( + backend='nccl', + world_size=world_size, + rank=rank + ) + except Exception as e: + world_size, rank = 1, 0 + logger.info(f'Rank: {rank}. 
Distributed training not available {e}') + + return world_size, rank + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(outputs, x) + return torch.cat(outputs, 0) + return x + + @staticmethod + def backward(ctx, grads): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() + e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) + grads = grads.contiguous() + dist.all_reduce(grads) + return grads[s:e] + return grads + + +class AllReduceSum(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads + + +class AllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() / dist.get_world_size() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads diff --git a/jepa_src/utils/logging.py b/jepa_src/utils/logging.py new file mode 100644 index 0000000..fcdd3fa --- /dev/null +++ b/jepa_src/utils/logging.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import sys + +import torch + + +def gpu_timer(closure, log_timings=True): + """ Helper to time gpu-time to execute closure() """ + log_timings = log_timings and torch.cuda.is_available() + + elapsed_time = -1. 
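+    # CUDA events time the closure on the device; the synchronize() call
+    # below ensures the end event has completed before the elapsed time is
+    # read. When timing is disabled (e.g. no CUDA), elapsed_time stays -1.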
+ if log_timings: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + result = closure() + + if log_timings: + end.record() + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + + return result, elapsed_time + + +LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def get_logger(name=None, force=False): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + return logging.getLogger(name=name) + + +class CSVLogger(object): + + def __init__(self, fname, *argv): + self.fname = fname + self.types = [] + # -- print headers + with open(self.fname, '+a') as f: + for i, v in enumerate(argv, 1): + self.types.append(v[0]) + if i < len(argv): + print(v[1], end=',', file=f) + else: + print(v[1], end='\n', file=f) + + def log(self, *argv): + with open(self.fname, '+a') as f: + for i, tv in enumerate(zip(self.types, argv), 1): + end = ',' if i < len(argv) else '\n' + print(tv[0] % tv[1], end=end, file=f) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.max = float('-inf') + self.min = float('inf') + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + try: + self.max = max(val, self.max) + self.min = min(val, self.min) + except Exception: + pass + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def grad_logger(named_params): + stats = AverageMeter() + stats.first_layer = None + stats.last_layer = None + for n, p in named_params: + if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + grad_norm = float(torch.norm(p.grad.data)) + stats.update(grad_norm) + if 'qkv' in n: + stats.last_layer = grad_norm + if stats.first_layer is None: + stats.first_layer = grad_norm + if stats.first_layer is None or stats.last_layer is None: + stats.first_layer = stats.last_layer = 0. + return stats + + +def adamw_logger(optimizer): + """ logging magnitude of first and second momentum buffers in adamw """ + # TODO: assert that optimizer is instance of torch.optim.AdamW + state = optimizer.state_dict().get('state') + exp_avg_stats = AverageMeter() + exp_avg_sq_stats = AverageMeter() + for key in state: + s = state.get(key) + exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) + exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) + return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/jepa_src/utils/monitoring.py b/jepa_src/utils/monitoring.py new file mode 100644 index 0000000..95a7845 --- /dev/null +++ b/jepa_src/utils/monitoring.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
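For orientation, a minimal usage sketch of the logging helpers added above; the CSV path and the timed closure are hypothetical, and the import path assumes the jepa_src package layout installed by this patch's setup.py:

    import torch
    from jepa_src.utils.logging import AverageMeter, CSVLogger, gpu_timer

    csv_logger = CSVLogger('/tmp/stats.csv', ('%d', 'epoch'), ('%.5f', 'loss'))
    loss_meter = AverageMeter()

    # gpu_timer returns (closure result, elapsed milliseconds); the time is -1
    # when CUDA is unavailable and timing is skipped.
    result, etime_ms = gpu_timer(lambda: torch.randn(1024, 1024).sum())

    loss_meter.update(0.42)
    csv_logger.log(1, loss_meter.avg)   # appends a "1,0.42000" row
    print(result, etime_ms, loss_meter.avg, loss_meter.max)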
+# + +import dataclasses +import threading +from typing import Dict, Tuple + +import psutil + + +@dataclasses.dataclass +class ResourceStatsSample: + timestamp: float + cpu_percent: float + read_count: int + write_count: int + read_bytes: int + write_bytes: int + read_chars: int + write_chars: int + cpu_times_user: float + cpu_times_system: float + cpu_times_children_user: float + cpu_times_children_system: float + cpu_times_iowait: float + cpu_affinity: str + cpu_num: int + num_threads: int + num_voluntary_ctx_switches: int + num_involuntary_ctx_switches: int + + def as_tuple(self) -> Dict: + """Return values mirroring fields.""" + return dataclasses.astuple(self) + + def fields(self) -> Tuple[dataclasses.Field, ...]: + """Return fields in this dataclass.""" + return dataclasses.fields(self.__class__) + + +class ResourceMonitoringThread(threading.Thread): + def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): + """Starts a thread to monitor pid every refresh_interval seconds. + + Passes a ResourceStatsSample object to the callback.""" + super(ResourceMonitoringThread, self).__init__() + if refresh_interval is None: + refresh_interval = 5 + self.is_running_event = threading.Event() + self.p = psutil.Process(pid) + self.refresh_interval = refresh_interval + if stats_callback_fn is None: + # Default callback + def stats_callback_fn(resource_sample: ResourceStatsSample): + print( + f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): + raise ValueError("Callback needs to be callable, got {}".format( + type(stats_callback_fn))) + self.stats_callback_fn = stats_callback_fn + + def stop(self) -> None: + self.is_running_event.set() + + def run(self) -> None: + while not self.is_running_event.is_set(): + self.sample_counters() + self.is_running_event.wait(self.refresh_interval) + + def log_sample(self, resource_sample: ResourceStatsSample) -> None: + self.stats_callback_fn(resource_sample) + + def sample_counters(self) -> None: + if not self.p.is_running(): + self.stop() + return + + with self.p.oneshot(): + cpu_percent = self.p.cpu_percent() + cpu_times = self.p.cpu_times() + io_counters = self.p.io_counters() + cpu_affinity = self.p.cpu_affinity() + cpu_num = self.p.cpu_num() + num_threads = self.p.num_threads() + num_ctx_switches = self.p.num_ctx_switches() + timestamp = time.time() + + read_count = io_counters.read_count + write_count = io_counters.write_count + read_bytes = io_counters.read_bytes + write_bytes = io_counters.write_bytes + read_chars = io_counters.read_chars + write_chars = io_counters.write_chars + + def compress_cpu_affinity(cpu_affinity): + """Change list representation to interval/range representation.""" + if not cpu_affinity: + return "" + cpu_affinity_compressed = [] + min_x = None + max_x = None + last_x = None + + # Find contiguous ranges + for x in cpu_affinity: + if last_x is None: + # Start interval + min_x = x + max_x = x + last_x = x + continue + elif x == (last_x + 1): + # Move interval up + max_x = x + elif max_x is not None: + # Interval ended, start again + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + min_x = x + max_x = x + last_x = x + # Terminate last range + if max_x is not None: + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + + # Concat + cpu_affinity_compressed = 
",".join(cpu_affinity_compressed) + + return cpu_affinity_compressed + + cpu_affinity = compress_cpu_affinity(cpu_affinity) + + resource_sample = ResourceStatsSample( + timestamp=timestamp, + cpu_percent=cpu_percent, + read_count=read_count, + write_count=write_count, + read_bytes=read_bytes, + write_bytes=write_bytes, + read_chars=read_chars, + write_chars=write_chars, + cpu_times_user=cpu_times.user, + cpu_times_system=cpu_times.system, + cpu_times_children_user=cpu_times.children_user, + cpu_times_children_system=cpu_times.children_system, + cpu_times_iowait=cpu_times.iowait, + cpu_affinity=cpu_affinity, + cpu_num=cpu_num, + num_threads=num_threads, + num_voluntary_ctx_switches=num_ctx_switches.voluntary, + num_involuntary_ctx_switches=num_ctx_switches.involuntary, + ) + self.log_sample(resource_sample) + + +if __name__ == "__main__": + import multiprocessing + import time + pid = multiprocessing.current_process().pid + monitor_thread = ResourceMonitoringThread(pid, 1) + monitor_thread.start() + time.sleep(5) + print("Shutdown") + monitor_thread.stop() diff --git a/jepa_src/utils/schedulers.py b/jepa_src/utils/schedulers.py new file mode 100644 index 0000000..df02e2b --- /dev/null +++ b/jepa_src/utils/schedulers.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + + +class WarmupCosineSchedule(object): + + def __init__( + self, + optimizer, + warmup_steps, + start_lr, + ref_lr, + T_max, + last_epoch=-1, + final_lr=0. + ): + self.optimizer = optimizer + self.start_lr = start_lr + self.ref_lr = ref_lr + self.final_lr = final_lr + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + self._step = 0. + + def step(self): + self._step += 1 + if self._step < self.warmup_steps: + progress = float(self._step) / float(max(1, self.warmup_steps)) + new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) + else: + # -- progress after warmup + progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) + new_lr = max(self.final_lr, + self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + return new_lr + + +class CosineWDSchedule(object): + + def __init__( + self, + optimizer, + ref_wd, + T_max, + final_wd=0. + ): + self.optimizer = optimizer + self.ref_wd = ref_wd + self.final_wd = final_wd + self.T_max = T_max + self._step = 0. + + def step(self): + self._step += 1 + progress = self._step / self.T_max + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + + if self.final_wd <= self.ref_wd: + new_wd = max(self.final_wd, new_wd) + else: + new_wd = min(self.final_wd, new_wd) + + for group in self.optimizer.param_groups: + if ('WD_exclude' not in group) or not group['WD_exclude']: + group['weight_decay'] = new_wd + return new_wd diff --git a/jepa_src/utils/tensors.py b/jepa_src/utils/tensors.py new file mode 100644 index 0000000..6ae2850 --- /dev/null +++ b/jepa_src/utils/tensors.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import math + +import torch + +from logging import getLogger + +logger = getLogger() + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches [0,N) to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x diff --git a/requirements.txt b/requirements.txt index d297071..386919b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -torch>=2 -torchvision pyyaml numpy opencv-python diff --git a/setup.py b/setup.py index 82de1e0..a852a6c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,9 @@ import os from setuptools import setup -VERSION = "0.0.1" +VERSION = "0.0.2" + +from setuptools import setup, find_packages def get_requirements(): with open("./requirements.txt") as reqsf: @@ -17,9 +19,12 @@ def get_requirements(): if __name__ == "__main__": setup( - name="jepa", + name="vjepa_encoder", version=VERSION, description="JEPA research code.", + author="Jonathan Koch", + author_email="johnnykoch02@gmail.com", python_requires=">=3.9", + packages=find_packages(), install_requires=get_requirements(), - ) + ) \ No newline at end of file diff --git a/vjepa_encoder.egg-info/PKG-INFO b/vjepa_encoder.egg-info/PKG-INFO new file mode 100644 index 0000000..2fa15e3 --- /dev/null +++ b/vjepa_encoder.egg-info/PKG-INFO @@ -0,0 +1,19 @@ +Metadata-Version: 2.1 +Name: vjepa-encoder +Version: 0.0.2 +Summary: JEPA research code. 
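To make the masking convention concrete, a toy example of apply_masks from the tensors module above (shapes are illustrative): each mask lists, per sample, the indices of the patch tokens to keep, and the results for multiple masks are concatenated along the batch dimension:

    import torch
    from jepa_src.utils.tensors import apply_masks

    B, N, D = 2, 5, 3                            # batch, patches, feature dim
    x = torch.arange(B * N * D, dtype=torch.float32).reshape(B, N, D)

    # Keep patches {0, 2} of sample 0 and {1, 4} of sample 1.
    mask = torch.tensor([[0, 2], [1, 4]])

    out = apply_masks(x, [mask])
    print(out.shape)                             # torch.Size([2, 2, 3])
    print(out[0])                                # rows 0 and 2 of x[0]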
+Author: Jonathan Koch +Author-email: johnnykoch02@gmail.com +Requires-Python: >=3.9 +License-File: LICENSE +Requires-Dist: pyyaml +Requires-Dist: numpy +Requires-Dist: opencv-python +Requires-Dist: submitit +Requires-Dist: braceexpand +Requires-Dist: webdataset +Requires-Dist: timm +Requires-Dist: decord +Requires-Dist: pandas +Requires-Dist: einops +Requires-Dist: beartype diff --git a/vjepa_encoder.egg-info/SOURCES.txt b/vjepa_encoder.egg-info/SOURCES.txt new file mode 100644 index 0000000..8b7b93f --- /dev/null +++ b/vjepa_encoder.egg-info/SOURCES.txt @@ -0,0 +1,47 @@ +LICENSE +README.md +setup.py +jepa_src/__init__.py +jepa_src/datasets/__init__.py +jepa_src/datasets/data_manager.py +jepa_src/datasets/image_dataset.py +jepa_src/datasets/video_dataset.py +jepa_src/datasets/utils/__init__.py +jepa_src/datasets/utils/weighted_sampler.py +jepa_src/datasets/utils/video/__init__.py +jepa_src/datasets/utils/video/functional.py +jepa_src/datasets/utils/video/randaugment.py +jepa_src/datasets/utils/video/randerase.py +jepa_src/datasets/utils/video/transforms.py +jepa_src/datasets/utils/video/volume_transforms.py +jepa_src/masks/__init__.py +jepa_src/masks/default.py +jepa_src/masks/multiblock3d.py +jepa_src/masks/random_tube.py +jepa_src/masks/utils.py +jepa_src/models/__init__.py +jepa_src/models/attentive_pooler.py +jepa_src/models/predictor.py +jepa_src/models/vision_transformer.py +jepa_src/models/utils/__init__.py +jepa_src/models/utils/modules.py +jepa_src/models/utils/multimask.py +jepa_src/models/utils/patch_embed.py +jepa_src/models/utils/pos_embs.py +jepa_src/utils/__init__.py +jepa_src/utils/distributed.py +jepa_src/utils/logging.py +jepa_src/utils/monitoring.py +jepa_src/utils/schedulers.py +jepa_src/utils/tensors.py +vjepa_encoder/__init__.py +vjepa_encoder/vision_encoder.py +vjepa_encoder.egg-info/PKG-INFO +vjepa_encoder.egg-info/SOURCES.txt +vjepa_encoder.egg-info/dependency_links.txt +vjepa_encoder.egg-info/requires.txt +vjepa_encoder.egg-info/top_level.txt +vjepa_encoder/vjepa/__init__.py +vjepa_encoder/vjepa/train.py +vjepa_encoder/vjepa/transforms.py +vjepa_encoder/vjepa/utils.py \ No newline at end of file diff --git a/vjepa_encoder.egg-info/dependency_links.txt b/vjepa_encoder.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/vjepa_encoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/vjepa_encoder.egg-info/requires.txt b/vjepa_encoder.egg-info/requires.txt new file mode 100644 index 0000000..386919b --- /dev/null +++ b/vjepa_encoder.egg-info/requires.txt @@ -0,0 +1,11 @@ +pyyaml +numpy +opencv-python +submitit +braceexpand +webdataset +timm +decord +pandas +einops +beartype diff --git a/vjepa_encoder.egg-info/top_level.txt b/vjepa_encoder.egg-info/top_level.txt new file mode 100644 index 0000000..b7a0b20 --- /dev/null +++ b/vjepa_encoder.egg-info/top_level.txt @@ -0,0 +1,2 @@ +jepa_src +vjepa_encoder diff --git a/vjepa_encoder/__init__.py b/vjepa_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vjepa_encoder/vision_encoder.py b/vjepa_encoder/vision_encoder.py new file mode 100644 index 0000000..7d74393 --- /dev/null +++ b/vjepa_encoder/vision_encoder.py @@ -0,0 +1,327 @@ +# Extension of Jepa by Robot Perception and Action Laboratory, USF +# +# Non-Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
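A minimal sketch of how the JepaEncoder wrapper introduced below is intended to be used once the package is installed; the config path and image file are hypothetical, and the YAML must point at an existing pretrained checkpoint (load_model asserts that meta.load_checkpoint is set and raises if no checkpoint file is found):

    from vjepa_encoder.vision_encoder import JepaEncoder

    # Hypothetical config; see the 'meta', 'mask', 'model', 'data' and
    # 'logging' sections read by JepaEncoder.load_model below.
    encoder = JepaEncoder.load_model("params-encoder.yaml")

    embeddings = encoder.embed_image("example_frame.jpg")   # path, array, PIL image or tensor
    print(embeddings.shape)   # (1, num_patches, embed_dim)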
+# + +from typing import List, Optional, Any +import multiprocessing as mp + +import pprint +import yaml +import os + +import torch + +from jepa_src.utils.distributed import init_distributed + +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple + +from vjepa_encoder.vjepa.utils import init_video_model +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +# from torch.nn.parallel import DistributedDataParallel +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import get_logger + +from vjepa_encoder.vjepa.utils import init_video_model + +import torch +from torchvision import transforms +from PIL import Image +import numpy as np + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +import logging +from jepa_src.utils.logging import get_logger +logger = get_logger(force=True) +logger.setLevel(logging.INFO) + +class JepaEncoder(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.encoder, self.predictor = None, None + + def preprocess_image(self, input_data: Any): + """ + Preprocess the input image data. + + Args: + input_data (Any): Input data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Preprocessed image data as a tensor. + - If the input is a batch of images, the output will have shape (batch_size, channels, height, width). + - If the input is a single image, the output will have shape (1, channels, height, width). + + Raises: + ValueError: If the input type is not supported. + """ + if isinstance(input_data, str): + img = Image.open(input_data).convert('RGB') + + elif isinstance(input_data, list): + imgs = [ + self.preprocess_image(i).squeeze() for i in input_data + ] + preprocessed_input = torch.stack(imgs) + return preprocessed_input + + elif isinstance(input_data, np.ndarray): + if len(input_data.shape) == 4: + input_data = input_data.transpose(0, 3, 1, 2) + preprocessed_input = torch.from_numpy(input_data).float() + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + img = Image.fromarray(input_data.astype(np.uint8)) + + elif isinstance(input_data, Image.Image): + img = input_data + + elif isinstance(input_data, torch.Tensor): + preprocessed_input = input_data + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + else: + raise ValueError("Unsupported input type. 
Expected image path, image array, or PIL Image.") + + # Define the preprocessing transforms + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Apply preprocessing transforms + preprocessed_input = preprocess(img) + + preprocessed_input = preprocessed_input.unsqueeze(0) # Add batch dimension + return preprocessed_input + + def embed_image(self, x): + """ + Generate embeddings for the input image data. + + Args: + x (Any): Input image data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Embeddings for the input image data. + - If the input is a batch of images, the output will have shape (batch_size, num_patches, embedding_size). + - If the input is a single image, the output will have shape (1, num_patches, embedding_size). + + Notes: + - The input image data is preprocessed using the `preprocess_image` method before generating embeddings. + - If the preprocessed input has fewer than 5 dimensions, an additional dimension is added to represent the time dimension. + - The embeddings are generated using the forward pass of the model. + - The computation is performed on the available device (GPU if available, otherwise CPU). 
+ """ + x = self.preprocess_image(x) + + # Unsqueeze along the time Dimension + if len(x.shape) < 5: + x = x.unsqueeze(2) + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + + x = x.to(device) + + with torch.no_grad(): + embeddings = self.forward(x) + + return embeddings + + def load_encoder_checkpoint( + self, + r_path, + encoder, + ): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + try: + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return encoder + + + def forward(self, clips: torch.Tensor, masks_enc: List[torch.Tensor], masks_pred: List[torch.Tensor]) -> List[torch.Tensor]: + z = self.encoder(clips, masks_enc) + h = self._forward_target(clips, masks_pred) + z = self.predictor(z, h, masks_enc, masks_pred) + return z + + def freeze_encoder(self): + for p in self.encoder.parameters(): + p.requires_grad = False + + def forward(self, x): + return self.encoder(x) + + @classmethod + def load_model(cls, config_file_path: str, device: Optional[List[str]] = None) -> "JepaEncoder": + # TODO: Fix this so it works properly + # os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + args = None + with open(config_file_path, 'r') as y_file: + args = yaml.load(y_file, Loader=yaml.FullLoader) + logger.info('loaded params...') + + pprint.PrettyPrinter(indent=4).pprint(args) + dump = os.path.join(args['logging']['folder'], 'params-encoder.yaml') + with open(dump, 'w') as f: + yaml.dump(args, f) + + + model = cls(args) + + world_size, rank = init_distributed() + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') + assert load_model, "Cannot load model without checkpoint file specified" + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = cfgs_data.get('patch_size') + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = 
cfgs_logging.get('write_tag') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = r_file + if not os.path.exists(load_path): + raise RuntimeError("Cannot load model. Ensure you specify the path to the model .tar file in the input config.") + + # -- Attempt to initialize model + model.encoder, model.predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + + # model.encoder = DistributedDataParallel(model.encoder, static_graph=True) + + # -- load training checkpoint + model.encoder = model.load_encoder_checkpoint( + load_path, model.encoder + ) + + return model + + diff --git a/vjepa_encoder/vjepa/__init__.py b/vjepa_encoder/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vjepa_encoder/vjepa/train.py b/vjepa_encoder/vjepa/train.py new file mode 100644 index 0000000..ccb2e75 --- /dev/null +++ b/vjepa_encoder/vjepa/train.py @@ -0,0 +1,586 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
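For reference, the shape of the configuration that load_model above reads from the YAML file, written as the equivalent Python dictionary; every value is a placeholder, and only a representative subset of the keys read by the code is shown (entries fetched with a default may be omitted):

    args = {
        'meta': {
            'load_checkpoint': True,       # required truthy flag
            'read_checkpoint': None,       # optional explicit checkpoint path
            'use_sdpa': False,
            'dtype': 'bfloat16',           # 'bfloat16', 'float16', anything else -> float32
        },
        'mask': [{}, {}],                  # one entry per mask configuration
        'model': {
            'model_name': 'vit_large',
            'pred_depth': 12,
            'pred_embed_dim': 384,
        },
        'data': {
            'num_clips': 1,
            'num_frames': 16,
            'tubelet_size': 2,
            'sampling_rate': 4,
            'crop_size': 224,
            'patch_size': 16,
        },
        'logging': {
            'folder': './logs',            # params-encoder.yaml is also dumped here
            'write_tag': 'jepa',
        },
    }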
+# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import copy +import time +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel + +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter) +from jepa_src.utils.tensors import repeat_interleave_batch + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + + +def main(args, resume_preempt=False): + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') or resume_preempt + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + dataset_type = cfgs_data.get('dataset_type', 'videodataset') + mask_type = cfgs_data.get('mask_type', 'multiblock3d') + dataset_paths = cfgs_data.get('datasets', []) + datasets_weights = cfgs_data.get('datasets_weights', None) + if datasets_weights is not None: + assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' + batch_size = cfgs_data.get('batch_size') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = 
cfgs_data.get('patch_size') + pin_mem = cfgs_data.get('pin_mem', False) + num_workers = cfgs_data.get('num_workers', 1) + filter_short_videos = cfgs_data.get('filter_short_videos', False) + decode_one_clip = cfgs_data.get('decode_one_clip', True) + log_resource_util_data = cfgs_data.get('log_resource_utilization', False) + + # -- DATA AUGS + cfgs_data_aug = args.get('data_aug') + ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) + rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) + motion_shift = cfgs_data_aug.get('motion_shift', False) + reprob = cfgs_data_aug.get('reprob', 0.) + use_aa = cfgs_data_aug.get('auto_augment', False) + + # -- LOSS + cfgs_loss = args.get('loss') + loss_exp = cfgs_loss.get('loss_exp') + reg_coeff = cfgs_loss.get('reg_coeff') + + # -- OPTIMIZATION + cfgs_opt = args.get('optimization') + ipe = cfgs_opt.get('ipe', None) + ipe_scale = cfgs_opt.get('ipe_scale', 1.0) + clip_grad = cfgs_opt.get('clip_grad', None) + wd = float(cfgs_opt.get('weight_decay')) + final_wd = float(cfgs_opt.get('final_weight_decay')) + num_epochs = cfgs_opt.get('epochs') + warmup = cfgs_opt.get('warmup') + start_lr = cfgs_opt.get('start_lr') + lr = cfgs_opt.get('lr') + final_lr = cfgs_opt.get('final_lr') + ema = cfgs_opt.get('ema') + betas = cfgs_opt.get('betas', (0.9, 0.999)) + eps = cfgs_opt.get('eps', 1.e-8) + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = cfgs_logging.get('write_tag') + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method('spawn') + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ('%d', 'epoch'), + ('%d', 'itr'), + ('%.5f', 'loss'), + ('%.5f', 'loss-jepa'), + ('%.5f', 'reg-loss'), + ('%.5f', 'enc-grad-norm'), + ('%.5f', 'pred-grad-norm'), + ('%d', 'gpu-time(ms)'), + ('%d', 'wall-time(ms)'), + ) + + # -- init model + encoder, predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == 'multiblock3d': + logger.info('Initializing basic multi-block mask') + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + else: + 
logger.info('Initializing random tube mask') + mask_collator = TubeMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size) + + # -- init data-loaders/samplers + (unsupervised_loader, + unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in range(int(ipe*num_epochs*ipe_scale)+1)) + + start_epoch = 0 + # -- load training checkpoint + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + 'encoder': encoder.state_dict(), + 'predictor': predictor.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'target_encoder': target_encoder.state_dict(), + 'epoch': epoch, + 'loss': loss_meter.avg, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + + logger.info('Initializing loader...') + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f'Skip {skip_batches} batches') + unsupervised_sampler.set_epoch(start_epoch) + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f'Skip {itr}/{skip_batches} batches') + try: + udata = next(loader) + except Exception: + loader = iter(unsupervised_loader) + udata = next(loader) + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + 
logger.info('Epoch %d' % (epoch + 1)) + + # -- update distributed-data-loader epoch + unsupervised_sampler.set_epoch(epoch) + + loss_meter = AverageMeter() + input_var_meter = AverageMeter() + input_var_min_meter = AverageMeter() + jepa_loss_meter = AverageMeter() + reg_loss_meter = AverageMeter() + mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))] + gpu_time_meter = AverageMeter() + wall_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + try: + udata, masks_enc, masks_pred = next(loader) + except Exception: + logger.info('Exhausted data loaders. Refreshing...') + loader = iter(unsupervised_loader) + udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len(masks_pred), \ + 'Currently require num encoder masks = num predictor masks' + + def load_clips(): + # -- unsupervised video clips + # Put each clip on the GPU and concatenate along batch + # dimension + clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) + + # Put each mask-enc/mask-pred pair on the GPU and reuse the + # same mask pair for each clip + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _me = repeat_interleave_batch(_me, batch_size, repeat=num_clips) + _mp = repeat_interleave_batch(_mp, batch_size, repeat=num_clips) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() + + for _i, m in enumerate(mask_meters): + m.update(masks_enc[_i][0].size(-1)) + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + # -- + + def forward_target(c): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + with torch.no_grad(): + h = target_encoder(c) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] + # -- create targets (masked regions of h) + h = apply_masks(h, masks_pred, concat=False) + return h + + def forward_context(c, h): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + z = encoder(c, masks_enc) + z = predictor(z, h, masks_enc, masks_pred) + return z + + def loss_fn(z, h): + loss = 0. + # Compute loss and accumulate for each mask-enc/mask-pred pair + for zi, hi in zip(z, h): + loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss /= len(masks_pred) + return loss + + def reg_fn(z): + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) + + # Step 1. Forward + loss_jepa, loss_reg = 0., 0. + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h = forward_target(clips) + z = forward_context(clips, h) + loss_jepa = loss_fn(z, h) # jepa prediction loss + pstd_z = reg_fn(z) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.-pstd_z)) + loss = loss_jepa + reg_coeff * loss_reg + + # Step 2. Backward & step + _enc_norm, _pred_norm = 0., 0. 
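+                # When mixed precision is enabled, the loss is scaled before
+                # backward() and the gradients are unscaled again before
+                # clipping, so clip_grad applies to true gradient magnitudes.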
+ if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if (epoch > warmup) and (clip_grad is not None): + _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) + _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + grad_stats = grad_logger(encoder.named_parameters()) + grad_stats.global_norm = float(_enc_norm) + grad_stats_pred = grad_logger(predictor.named_parameters()) + grad_stats_pred.global_norm = float(_pred_norm) + optimizer.zero_grad() + optim_stats = adamw_logger(optimizer) + + # Step 3. momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. + loss_meter.update(loss) + input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) + input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' + 'input_var: %.3f %.3f | ' + 'masks: %s ' + '[wd: %.2e] [lr: %.2e] ' + '[mem: %.2e] ' + '[gpu: %.1f ms]' + '[wall: %.1f ms]' + % (epoch + 1, itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg)) + + if optim_stats is not None: + logger.info( + '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' + % (epoch + 1, itr, + optim_stats.get('exp_avg').avg, + optim_stats.get('exp_avg').min, + optim_stats.get('exp_avg').max, + optim_stats.get('exp_avg_sq').avg, + optim_stats.get('exp_avg_sq').min, + optim_stats.get('exp_avg_sq').max)) + + if grad_stats is not None: + logger.info( + '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm)) + + if grad_stats_pred is not None: + logger.info( + '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm)) + log_stats() + assert not np.isnan(loss), 'loss is nan' + + # -- Save Checkpoint + logger.info('avg. 
loss %.3f' % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f'{tag}-e{epoch}.pth.tar' + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/vjepa_encoder/vjepa/transforms.py b/vjepa_encoder/vjepa/transforms.py new file mode 100644 index 0000000..ba62555 --- /dev/null +++ b/vjepa_encoder/vjepa/transforms.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torchvision.transforms as transforms + +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing + + +def make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) +): + + _frames_augmentation = VideoTransform( + random_horizontal_flip=random_horizontal_flip, + random_resize_aspect_ratio=random_resize_aspect_ratio, + random_resize_scale=random_resize_scale, + reprob=reprob, + auto_augment=auto_augment, + motion_shift=motion_shift, + crop_size=crop_size, + normalize=normalize, + ) + return _frames_augmentation + + +class VideoTransform(object): + + def __init__( + self, + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ): + + self.random_horizontal_flip = random_horizontal_flip + self.random_resize_aspect_ratio = random_resize_aspect_ratio + self.random_resize_scale = random_resize_scale + self.auto_augment = auto_augment + self.motion_shift = motion_shift + self.crop_size = crop_size + self.mean = torch.tensor(normalize[0], dtype=torch.float32) + self.std = torch.tensor(normalize[1], dtype=torch.float32) + if not self.auto_augment: + # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. + self.mean *= 255. + self.std *= 255. 
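+        # The auto-augment pipeline below is constructed unconditionally but
+        # is only applied in __call__ when auto_augment=True.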
+ + self.autoaug_transform = video_transforms.create_random_augment( + input_size=(crop_size, crop_size), + auto_augment='rand-m7-n4-mstd0.5-inc1', + interpolation='bicubic', + ) + + self.spatial_transform = video_transforms.random_resized_crop_with_shift \ + if motion_shift else video_transforms.random_resized_crop + + self.reprob = reprob + self.erase_transform = RandomErasing( + reprob, + mode='pixel', + max_count=1, + num_splits=1, + device='cpu', + ) + + def __call__(self, buffer): + + if self.auto_augment: + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + buffer = self.autoaug_transform(buffer) + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + else: + buffer = torch.tensor(buffer, dtype=torch.float32) + + buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W + + buffer = self.spatial_transform( + images=buffer, + target_height=self.crop_size, + target_width=self.crop_size, + scale=self.random_resize_scale, + ratio=self.random_resize_aspect_ratio, + ) + if self.random_horizontal_flip: + buffer, _ = video_transforms.horizontal_flip(0.5, buffer) + + buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) + if self.reprob > 0: + buffer = buffer.permute(1, 0, 2, 3) + buffer = self.erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) + + return buffer + + +def tensor_normalize(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if type(mean) == list: + mean = torch.tensor(mean) + if type(std) == list: + std = torch.tensor(std) + tensor = tensor - mean + tensor = tensor / std + return tensor + + +def _tensor_normalize_inplace(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize (with dimensions C, T, H, W). + mean (tensor): mean value to subtract (in 0 to 255 floats). + std (tensor): std to divide (in 0 to 255 floats). + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + + C, T, H, W = tensor.shape + tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension + tensor.sub_(mean).div_(std) + tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front + return tensor diff --git a/vjepa_encoder/vjepa/utils.py b/vjepa_encoder/vjepa/utils.py new file mode 100644 index 0000000..2636ed7 --- /dev/null +++ b/vjepa_encoder/vjepa/utils.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
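The in-place normalization above keeps video buffers in 0-255 space and divides by 255-scaled statistics; a self-contained illustration of the same arithmetic on a dummy clip (sizes are placeholders):

    import torch

    # ImageNet statistics rescaled to 0-255, as in VideoTransform.__init__.
    mean = torch.tensor([0.485, 0.456, 0.406]) * 255.0
    std = torch.tensor([0.229, 0.224, 0.225]) * 255.0

    clip = torch.randint(0, 256, (3, 16, 224, 224)).float()   # C, T, H, W in 0-255

    # Same reshuffle as _tensor_normalize_inplace: channels last, normalize, restore.
    flat = clip.view(3, -1).permute(1, 0)
    flat.sub_(mean).div_(std)
    clip = flat.permute(1, 0).view(3, 16, 224, 224)
    print(clip.shape)   # still (C, T, H, W), now normalized in place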
+# + +import logging +import sys +import warnings +import yaml + + +import torch + +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule) +from jepa_src.utils.tensors import trunc_normal_ + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger() + + +def load_checkpoint( + r_path, + encoder, + predictor, + target_encoder, + opt, + scaler, +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + epoch = 0 + try: + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + # -- loading predictor + pretrained_dict = checkpoint['predictor'] + msg = predictor.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + + # -- loading target_encoder + if target_encoder is not None: + print(list(checkpoint.keys())) + pretrained_dict = checkpoint['target_encoder'] + msg = target_encoder.load_state_dict(pretrained_dict) + logger.info( + f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + ) + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: {r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return ( + encoder, + predictor, + target_encoder, + opt, + scaler, + epoch, + ) + + +def init_video_model( + device, + patch_size=16, + num_frames=16, + tubelet_size=2, + model_name='vit_base', + crop_size=224, + pred_depth=6, + pred_embed_dim=384, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_sdpa=False, +): + encoder = video_vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + ) + encoder = MultiMaskWrapper(encoder) + predictor = vit_pred.__dict__['vit_predictor']( + img_size=crop_size, + use_mask_tokens=use_mask_tokens, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + embed_dim=encoder.backbone.embed_dim, + predictor_embed_dim=pred_embed_dim, + depth=pred_depth, + num_heads=encoder.backbone.num_heads, + uniform_power=uniform_power, + num_mask_tokens=num_mask_tokens, + zero_init_mask_tokens=zero_init_mask_tokens, + use_sdpa=use_sdpa, + ) + predictor = PredictorMultiMaskWrapper(predictor) + + def init_weights(m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + for m in encoder.modules(): + init_weights(m) + + for m in predictor.modules(): + init_weights(m) + + encoder.to(device) + predictor.to(device) + logger.info(encoder) + logger.info(predictor) + + def count_parameters(model): + return sum(p.numel() for p in 
model.parameters() if p.requires_grad) + + logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') + logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + + return encoder, predictor + + +def init_opt( + encoder, + predictor, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + mixed_precision=False, + ipe_scale=1.25, + betas=(0.9, 0.999), + eps=1e-8, + zero_init_bias_wd=True, +): + param_groups = [ + { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=int(warmup * iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + scaler = torch.cuda.amp.GradScaler() if mixed_precision else None + return optimizer, scaler, scheduler, wd_scheduler From 01522ee322282e0341e8aa01e18b9236da46acc5 Mon Sep 17 00:00:00 2001 From: Johnnykoch02 Date: Tue, 16 Apr 2024 21:17:37 -0400 Subject: [PATCH 2/4] Update python compatibility: 3.7>= --- .gitignore | 8 + README.md | 434 +----- app/main.py | 4 +- app/main_distributed.py | 2 +- app/vjepa/train.py | 16 +- app/vjepa/transforms.py | 4 +- app/vjepa/utils.py | 10 +- {src => build/lib}/datasets/data_manager.py | 4 +- {src => build/lib}/datasets/image_dataset.py | 0 .../lib}/datasets/utils/video/functional.py | 0 .../lib}/datasets/utils/video/randaugment.py | 0 .../lib}/datasets/utils/video/randerase.py | 0 .../lib}/datasets/utils/video/transforms.py | 4 +- .../datasets/utils/video/volume_transforms.py | 0 .../lib}/datasets/utils/weighted_sampler.py | 0 {src => build/lib}/datasets/video_dataset.py | 8 +- build/lib/jepa_src/__init__.py | 0 build/lib/jepa_src/datasets/__init__.py | 0 build/lib/jepa_src/datasets/data_manager.py | 91 ++ build/lib/jepa_src/datasets/image_dataset.py | 79 ++ build/lib/jepa_src/datasets/utils/__init__.py | 0 .../jepa_src/datasets/utils/video/__init__.py | 0 .../datasets/utils/video/functional.py | 96 ++ .../datasets/utils/video/randaugment.py | 518 ++++++++ .../datasets/utils/video/randerase.py | 180 +++ .../datasets/utils/video/transforms.py | 1184 +++++++++++++++++ .../datasets/utils/video/volume_transforms.py | 151 +++ .../datasets/utils/weighted_sampler.py | 97 ++ build/lib/jepa_src/datasets/video_dataset.py | 272 ++++ build/lib/jepa_src/masks/__init__.py | 0 {src => build/lib/jepa_src}/masks/default.py | 0 .../lib/jepa_src}/masks/multiblock3d.py | 0 .../lib/jepa_src}/masks/random_tube.py | 0 {src => build/lib/jepa_src}/masks/utils.py | 0 build/lib/jepa_src/models/__init__.py | 0 .../lib/jepa_src}/models/attentive_pooler.py | 4 +- .../lib/jepa_src}/models/predictor.py | 8 +- build/lib/jepa_src/models/utils/__init__.py | 0 
build/lib/jepa_src/models/utils/functional.py | 30 + .../lib/jepa_src}/models/utils/modules.py | 5 +- .../lib/jepa_src}/models/utils/multimask.py | 0 .../lib/jepa_src}/models/utils/patch_embed.py | 0 .../lib/jepa_src}/models/utils/pos_embs.py | 0 .../jepa_src}/models/vision_transformer.py | 10 +- build/lib/jepa_src/utils/__init__.py | 0 .../lib/jepa_src}/utils/distributed.py | 0 build/lib/jepa_src/utils/functional.py | 30 + {src => build/lib/jepa_src}/utils/logging.py | 0 .../lib/jepa_src}/utils/monitoring.py | 0 .../lib/jepa_src}/utils/schedulers.py | 0 {src => build/lib/jepa_src}/utils/tensors.py | 0 build/lib/masks/default.py | 20 + build/lib/masks/multiblock3d.py | 203 +++ build/lib/masks/random_tube.py | 117 ++ build/lib/masks/utils.py | 23 + build/lib/models/attentive_pooler.py | 136 ++ build/lib/models/predictor.py | 246 ++++ build/lib/models/utils/modules.py | 185 +++ build/lib/models/utils/multimask.py | 48 + build/lib/models/utils/patch_embed.py | 57 + build/lib/models/utils/pos_embs.py | 99 ++ build/lib/models/vision_transformer.py | 307 +++++ build/lib/utils/distributed.py | 113 ++ build/lib/utils/logging.py | 118 ++ build/lib/utils/monitoring.py | 175 +++ build/lib/utils/schedulers.py | 76 ++ build/lib/utils/tensors.py | 71 + build/lib/vjepa_encoder/__init__.py | 0 build/lib/vjepa_encoder/vision_encoder.py | 329 +++++ build/lib/vjepa_encoder/vjepa/__init__.py | 0 build/lib/vjepa_encoder/vjepa/train.py | 586 ++++++++ build/lib/vjepa_encoder/vjepa/transforms.py | 153 +++ build/lib/vjepa_encoder/vjepa/utils.py | 210 +++ demo_jepa_encoder.py | 22 + evals/image_classification_frozen/eval.py | 12 +- evals/main.py | 2 +- evals/video_classification_frozen/eval.py | 12 +- evals/video_classification_frozen/utils.py | 10 +- fair_documentation.md | 407 ++++++ jepa_encoder.egg-info/PKG-INFO | 17 + jepa_encoder.egg-info/SOURCES.txt | 10 + jepa_encoder.egg-info/dependency_links.txt | 1 + jepa_encoder.egg-info/requires.txt | 11 + jepa_encoder.egg-info/top_level.txt | 1 + jepa_src/__init__.py | 0 jepa_src/datasets/__init__.py | 0 jepa_src/datasets/data_manager.py | 91 ++ jepa_src/datasets/image_dataset.py | 79 ++ jepa_src/datasets/utils/__init__.py | 0 jepa_src/datasets/utils/video/__init__.py | 0 jepa_src/datasets/utils/video/functional.py | 96 ++ jepa_src/datasets/utils/video/randaugment.py | 518 ++++++++ jepa_src/datasets/utils/video/randerase.py | 180 +++ jepa_src/datasets/utils/video/transforms.py | 1184 +++++++++++++++++ .../datasets/utils/video/volume_transforms.py | 151 +++ jepa_src/datasets/utils/weighted_sampler.py | 97 ++ jepa_src/datasets/video_dataset.py | 272 ++++ jepa_src/masks/__init__.py | 0 jepa_src/masks/default.py | 20 + jepa_src/masks/multiblock3d.py | 203 +++ jepa_src/masks/random_tube.py | 117 ++ jepa_src/masks/utils.py | 23 + jepa_src/models/__init__.py | 0 jepa_src/models/attentive_pooler.py | 136 ++ jepa_src/models/predictor.py | 246 ++++ jepa_src/models/utils/__init__.py | 0 jepa_src/models/utils/modules.py | 184 +++ jepa_src/models/utils/multimask.py | 48 + jepa_src/models/utils/patch_embed.py | 57 + jepa_src/models/utils/pos_embs.py | 99 ++ jepa_src/models/vision_transformer.py | 307 +++++ jepa_src/utils/__init__.py | 0 jepa_src/utils/distributed.py | 113 ++ jepa_src/utils/functional.py | 30 + jepa_src/utils/logging.py | 118 ++ jepa_src/utils/monitoring.py | 175 +++ jepa_src/utils/schedulers.py | 76 ++ jepa_src/utils/tensors.py | 71 + requirements.txt | 2 - setup.py | 13 +- vjepa_encoder.egg-info/PKG-INFO | 11 + vjepa_encoder.egg-info/SOURCES.txt | 48 + 
vjepa_encoder.egg-info/dependency_links.txt | 1 + vjepa_encoder.egg-info/requires.txt | 11 + vjepa_encoder.egg-info/top_level.txt | 2 + vjepa_encoder/__init__.py | 0 vjepa_encoder/vision_encoder.py | 327 +++++ vjepa_encoder/vjepa/__init__.py | 0 vjepa_encoder/vjepa/train.py | 586 ++++++++ vjepa_encoder/vjepa/transforms.py | 153 +++ vjepa_encoder/vjepa/utils.py | 210 +++ 131 files changed, 12640 insertions(+), 441 deletions(-) rename {src => build/lib}/datasets/data_manager.py (94%) rename {src => build/lib}/datasets/image_dataset.py (100%) rename {src => build/lib}/datasets/utils/video/functional.py (100%) rename {src => build/lib}/datasets/utils/video/randaugment.py (100%) rename {src => build/lib}/datasets/utils/video/randerase.py (100%) rename {src => build/lib}/datasets/utils/video/transforms.py (99%) rename {src => build/lib}/datasets/utils/video/volume_transforms.py (100%) rename {src => build/lib}/datasets/utils/weighted_sampler.py (100%) rename {src => build/lib}/datasets/video_dataset.py (97%) create mode 100644 build/lib/jepa_src/__init__.py create mode 100644 build/lib/jepa_src/datasets/__init__.py create mode 100644 build/lib/jepa_src/datasets/data_manager.py create mode 100644 build/lib/jepa_src/datasets/image_dataset.py create mode 100644 build/lib/jepa_src/datasets/utils/__init__.py create mode 100644 build/lib/jepa_src/datasets/utils/video/__init__.py create mode 100644 build/lib/jepa_src/datasets/utils/video/functional.py create mode 100644 build/lib/jepa_src/datasets/utils/video/randaugment.py create mode 100644 build/lib/jepa_src/datasets/utils/video/randerase.py create mode 100644 build/lib/jepa_src/datasets/utils/video/transforms.py create mode 100644 build/lib/jepa_src/datasets/utils/video/volume_transforms.py create mode 100644 build/lib/jepa_src/datasets/utils/weighted_sampler.py create mode 100644 build/lib/jepa_src/datasets/video_dataset.py create mode 100644 build/lib/jepa_src/masks/__init__.py rename {src => build/lib/jepa_src}/masks/default.py (100%) rename {src => build/lib/jepa_src}/masks/multiblock3d.py (100%) rename {src => build/lib/jepa_src}/masks/random_tube.py (100%) rename {src => build/lib/jepa_src}/masks/utils.py (100%) create mode 100644 build/lib/jepa_src/models/__init__.py rename {src => build/lib/jepa_src}/models/attentive_pooler.py (97%) rename {src => build/lib/jepa_src}/models/predictor.py (97%) create mode 100644 build/lib/jepa_src/models/utils/__init__.py create mode 100644 build/lib/jepa_src/models/utils/functional.py rename {src => build/lib/jepa_src}/models/utils/modules.py (96%) rename {src => build/lib/jepa_src}/models/utils/multimask.py (100%) rename {src => build/lib/jepa_src}/models/utils/patch_embed.py (100%) rename {src => build/lib/jepa_src}/models/utils/pos_embs.py (100%) rename {src => build/lib/jepa_src}/models/vision_transformer.py (96%) create mode 100644 build/lib/jepa_src/utils/__init__.py rename {src => build/lib/jepa_src}/utils/distributed.py (100%) create mode 100644 build/lib/jepa_src/utils/functional.py rename {src => build/lib/jepa_src}/utils/logging.py (100%) rename {src => build/lib/jepa_src}/utils/monitoring.py (100%) rename {src => build/lib/jepa_src}/utils/schedulers.py (100%) rename {src => build/lib/jepa_src}/utils/tensors.py (100%) create mode 100644 build/lib/masks/default.py create mode 100644 build/lib/masks/multiblock3d.py create mode 100644 build/lib/masks/random_tube.py create mode 100644 build/lib/masks/utils.py create mode 100644 build/lib/models/attentive_pooler.py create mode 100644 
build/lib/models/predictor.py create mode 100644 build/lib/models/utils/modules.py create mode 100644 build/lib/models/utils/multimask.py create mode 100644 build/lib/models/utils/patch_embed.py create mode 100644 build/lib/models/utils/pos_embs.py create mode 100644 build/lib/models/vision_transformer.py create mode 100644 build/lib/utils/distributed.py create mode 100644 build/lib/utils/logging.py create mode 100644 build/lib/utils/monitoring.py create mode 100644 build/lib/utils/schedulers.py create mode 100644 build/lib/utils/tensors.py create mode 100644 build/lib/vjepa_encoder/__init__.py create mode 100644 build/lib/vjepa_encoder/vision_encoder.py create mode 100644 build/lib/vjepa_encoder/vjepa/__init__.py create mode 100644 build/lib/vjepa_encoder/vjepa/train.py create mode 100644 build/lib/vjepa_encoder/vjepa/transforms.py create mode 100644 build/lib/vjepa_encoder/vjepa/utils.py create mode 100644 demo_jepa_encoder.py create mode 100644 fair_documentation.md create mode 100644 jepa_encoder.egg-info/PKG-INFO create mode 100644 jepa_encoder.egg-info/SOURCES.txt create mode 100644 jepa_encoder.egg-info/dependency_links.txt create mode 100644 jepa_encoder.egg-info/requires.txt create mode 100644 jepa_encoder.egg-info/top_level.txt create mode 100644 jepa_src/__init__.py create mode 100644 jepa_src/datasets/__init__.py create mode 100644 jepa_src/datasets/data_manager.py create mode 100644 jepa_src/datasets/image_dataset.py create mode 100644 jepa_src/datasets/utils/__init__.py create mode 100644 jepa_src/datasets/utils/video/__init__.py create mode 100644 jepa_src/datasets/utils/video/functional.py create mode 100644 jepa_src/datasets/utils/video/randaugment.py create mode 100644 jepa_src/datasets/utils/video/randerase.py create mode 100644 jepa_src/datasets/utils/video/transforms.py create mode 100644 jepa_src/datasets/utils/video/volume_transforms.py create mode 100644 jepa_src/datasets/utils/weighted_sampler.py create mode 100644 jepa_src/datasets/video_dataset.py create mode 100644 jepa_src/masks/__init__.py create mode 100644 jepa_src/masks/default.py create mode 100644 jepa_src/masks/multiblock3d.py create mode 100644 jepa_src/masks/random_tube.py create mode 100644 jepa_src/masks/utils.py create mode 100644 jepa_src/models/__init__.py create mode 100644 jepa_src/models/attentive_pooler.py create mode 100644 jepa_src/models/predictor.py create mode 100644 jepa_src/models/utils/__init__.py create mode 100644 jepa_src/models/utils/modules.py create mode 100644 jepa_src/models/utils/multimask.py create mode 100644 jepa_src/models/utils/patch_embed.py create mode 100644 jepa_src/models/utils/pos_embs.py create mode 100644 jepa_src/models/vision_transformer.py create mode 100644 jepa_src/utils/__init__.py create mode 100644 jepa_src/utils/distributed.py create mode 100644 jepa_src/utils/functional.py create mode 100644 jepa_src/utils/logging.py create mode 100644 jepa_src/utils/monitoring.py create mode 100644 jepa_src/utils/schedulers.py create mode 100644 jepa_src/utils/tensors.py create mode 100644 vjepa_encoder.egg-info/PKG-INFO create mode 100644 vjepa_encoder.egg-info/SOURCES.txt create mode 100644 vjepa_encoder.egg-info/dependency_links.txt create mode 100644 vjepa_encoder.egg-info/requires.txt create mode 100644 vjepa_encoder.egg-info/top_level.txt create mode 100644 vjepa_encoder/__init__.py create mode 100644 vjepa_encoder/vision_encoder.py create mode 100644 vjepa_encoder/vjepa/__init__.py create mode 100644 vjepa_encoder/vjepa/train.py create mode 100644 
vjepa_encoder/vjepa/transforms.py create mode 100644 vjepa_encoder/vjepa/utils.py diff --git a/.gitignore b/.gitignore index 3bb2efd..bbe5ec0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,10 @@ .*.swp *.pyc +# *.tar + +bin/ +dist/ +.vscode/ +logs/ + +jepa_src/jepa.egg-info/ \ No newline at end of file diff --git a/README.md b/README.md index a3579e1..5643126 100644 --- a/README.md +++ b/README.md @@ -1,407 +1,85 @@ -# V-JEPA: Video Joint Embedding Predictive Architecture - -Official PyTorch codebase for the _video joint-embedding predictive architecture_, V-JEPA, a method for self-supervised learning of visual representations from video. - -**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** - -Adrien Bardes, Quentin Garrido, Jean Ponce, Xinlei Chen, Michael Rabbat, Yann LeCun, Mahmoud Assran*, Nicolas Ballas* - -[\[Blog\]](https://ai.meta.com/blog/v-jepa-yann-lecun-ai-model-video-joint-embedding-predictive-architecture/) -[\[Paper\]](https://ai.meta.com/research/publications/revisiting-feature-prediction-for-learning-visual-representations-from-video/) -[\[Yannic Kilcher's Video\]](https://www.youtube.com/watch?v=7UkJPwz_N_0) - -V-JEPA models are trained by passively watching video pixels from the VideoMix2M dataset, and produce versatile visual representations that perform well on downstream video and image tasks, without adaption of the model’s parameters; e.g., using a frozen backbone and only a light-weight task-specific attentive probe. - -## Method -V-JEPA pretraining is based solely on an unsupervised feature prediction objective, and does not utilize pretrained image encoders, text, negative examples, human annotations, or pixel-level reconstruction. - - - -      - - - - -## Visualizations -As opposed to generative methods that have a pixel decoder, V-JEPA has a predictor that makes predictions in latent space. -We train a conditional diffusion model to decode the V-JEPA feature-space predictions to interpretable pixels; the pretrained V-JEPA encoder and predictor networks are kept frozen in this process. -The decoder is only fed the representations predicted for the missing regions of the video, and does not have access to the unmasked regions of the video. - -The V-JEPA feature predictions are indeed grounded, and exhibit spatio-temporal consistency with the unmasked regions of the video. - - -
-
-## MODEL ZOO
-
-#### Pretrained models
-
-| model | patch size | resolution | iterations | batch size | data | download |
-| --- | --- | --- | --- | --- | --- | --- |
-| ViT-L | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
-| ViT-H | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
-| ViT-H | 2x16x16 | 384x384 | 90K | 2400 | VideoMix2M | checkpoint / configs |
-
-#### K400 Attentive probes
-
-| model | resolution | accuracy (16x8x3) | download |
-| --- | --- | --- | --- |
-| ViT-L/16 | 224x224 | 80.8 | attentive probe checkpoint / configs |
-| ViT-H/16 | 224x224 | 82.0 | attentive probe checkpoint / configs |
-| ViT-H/16 | 384x384 | 81.9 | attentive probe checkpoint / configs |
-
-#### SSv2 Attentive probes
-
-| model | resolution | accuracy (16x2x3) | download |
-| --- | --- | --- | --- |
-| ViT-L/16 | 224x224 | 69.5 | attentive probe checkpoint / configs |
-| ViT-H/16 | 224x224 | 71.4 | attentive probe checkpoint / configs |
-| ViT-H/16 | 384x384 | 72.2 | attentive probe checkpoint / configs |
-
-#### ImageNet1K Attentive probes
-
-| model | resolution | accuracy | download |
-| --- | --- | --- | --- |
-| ViT-L/16 | 224x224 | 74.8 | attentive probe checkpoint / configs |
-| ViT-H/16 | 224x224 | 75.9 | attentive probe checkpoint / configs |
-| ViT-H/16 | 384x384 | 77.4 | attentive probe checkpoint / configs |
-
-#### Places205 Attentive probes
-
-| model | resolution | accuracy | download |
-| --- | --- | --- | --- |
-| ViT-L/16 | 224x224 | 60.3 | attentive probe checkpoint / configs |
-| ViT-H/16 | 224x224 | 61.7 | attentive probe checkpoint / configs |
-| ViT-H/16 | 384x384 | 62.8 | attentive probe checkpoint / configs |
-
-#### iNat21 Attentive probes
-
-| model | resolution | accuracy | download |
-| --- | --- | --- | --- |
-| ViT-L/16 | 224x224 | 67.8 | attentive probe checkpoint / configs |
-| ViT-H/16 | 224x224 | 67.9 | attentive probe checkpoint / configs |
-| ViT-H/16 | 384x384 | 72.6 | attentive probe checkpoint / configs |
- -## Code Structure - -**Config files:** -All experiment parameters are specified in config files (as opposed to command-line arguments). See the [configs/](configs/) directory for example config files. Note, before launching an experiment, you must update the paths in the config file to point to your own directories, indicating where to save the logs and checkpoints and where to find the training data. + VJEPA Encoder +The VJEPA Encoder is a Python package that provides an implementation of the encoder component from the JEPA (Joint Encoding for Prediction and Alignment) architecture proposed by Facebook AI Research. The encoder is designed to extract meaningful representations from visual data. I do not own the rights or lay claim to the copyright of this software. This package is an adaptation to `facebookresearch/jepa` to enable ease of use of the Jepa Architecture built with Vision Transformers. -``` -. -├── app # the only place where training loops are allowed -│ ├── vjepa # Video JEPA pre-training -│ ├── main_distributed.py # entrypoint for launching app on slurm cluster -│ └── main.py # entrypoint for launching app locally on your machine for debugging -├── evals # the only place where evaluation of 'apps' are allowed -│ ├── image_classification # training an attentive probe for image classification with frozen backbone -│ ├── video_classification # training an attentive probe for video classification with frozen backbone -│ ├── main_distributed.py # entrypoint for launching distributed evaluations on slurm cluster -│ └── main.py # entrypoint for launching evaluations locally on your machine for debugging -├── src # the package -│ ├── datasets # datasets, data loaders, ... -│ ├── models # model definitions -│ ├── masks # mask collators, masking utilities, ... -│ └── utils # shared utilities -└── configs # the only place where config files are allowed (specify experiment params for app/eval runs) - ├── evals # configs for launching vjepa frozen evaluations - └── pretrain # configs for launching vjepa pretraining - -``` +## Installation -## Data preparation +To install the VJEPA Encoder package, you can use pip: -### Video Datasets -V-JEPA pretraining and evaluations work with many standard video formats. -To make a video dataset compatible with the V-JEPA codebase, you simply need to create a `.csv` file with the following format and then specify the path to this CSV file in your config. ``` -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -/absolute_file_path.[mp4, webvid, etc.] $integer_class_label -... +pip install vjepa_encoder ``` -Since V-JEPA is entirely unsupervised, the pretraining code will disregard the `$integer_class_label` in the CSV file. -Thus, feel free to put a random value in this column. -However, if you wish to run a supervised video classification evaluation on your video dataset, you must replace ```$integer_class_label``` with the ground truth label for each video. -### Image Datasets -We use the standard PyTorch ```ImageFolder``` class in our image classification evals. -Thus, to set up an image dataset for the image classification evaluation, first create a directory to store your image datasets ```$your_directory_containing_image_datasets```. -Next, download your image datasets into this directory in a format compatible with [PyTorch ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html). 
+## Usage -For example, suppose we have a directory called ``my_image_datasets``. We would then download our image datasets into this directory so that we end up with the following file tree -``` -. -└── /my_image_datasets/ # where we store image datasets - ├── places205/121517/pytorch/ # Places205 - │ └── [...] - ├── iNaturalist-2021/110421/ # iNaturalist21 - │ └── [...] - ├── [...] # Other Image Datasets - │ └── [...] - └── imagenet_full_size/061417/ # ImageNet1k - └── train - │ ├── $class_1 - │ │ ├── xxx.[png, jpeg, etc.] - │ │ ├── [...] - │ │ └── xxz.[png, jpeg, etc.] - │ ├── [...] - │ └── $class_n - │ ├── abc.[png, jpeg, etc.] - │ ├── [...] - │ └── abz.[png, jpeg, etc.] - └── val - ├── $class_1 - │ ├── xxx.[png, jpeg, etc.] - │ ├── [...] - │ └── xxz.[png, jpeg, etc.] - ├── [...] - └── $class_n - ├── abc.[png, jpeg, etc.] - ├── [...] - └── abz.[png, jpeg, etc.] -``` +To use the VJEPA Encoder in your Python code, you can import it as follows: +```python +from vjepa_encoder.vision_encoder import JepaEncoder +``` -## Launching V-JEPA pretraining +### Loading the Encoder -### Local training -If you wish to debug your code or setup before launching a distributed training run, we provide the functionality to do so by running the pretraining script locally on a multi-GPU (or single-GPU) machine, however, reproducing our results requires launching distributed training. +To load the pre-trained encoder, you can use the `load_model` function: -The single-machine implementation starts from the [app/main.py](appmain.py), which parses the experiment config file and runs the pretraining locally on a multi-GPU (or single-GPU) machine. -For example, to run V-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: -```bash -python -m app.main \ - --fname configs/pretrain/vitl16.yaml \ - --devices cuda:0 cuda:1 cuda:2 +```python +config_file_path = "./params-encoder.yaml" +devices = ["cuda:0"] +encoder = JepaEncoder.load_model(config_file_path, devices) ``` -### Distributed training -To launch a distributed training run, the implementation starts from [app/main_distributed.py](app/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. +- `config_file_path`: Path to the configuration file (YAML) containing the model settings. +- `devices`: List of devices (e.g., `['cuda:0']`) to use for distributed training. If not provided, the model will be loaded on the CPU. -For example, to launch a distributed pre-training experiment using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: -```bash -python -m app.main_distributed \ - --fname configs/pretrain/vitl16.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition -``` -## Launching Evaluations +#### Important Notes about the Config File: + +- the config file provided in this repo provides the basics for loading and using the encoder. 
The most important settings to note in this file are `r_checkpoint`, which points at the `.tar` file for the JEPA checkpoint, and `tubelet_size`, which controls how many frames are grouped into each temporal patch: set it to `1` if you plan on embedding single images, or to `N` if your data has a temporal dimension of `N` frames. -### Local training -If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine, however, reproducing the full eval would require launching distributed training. -The single-machine implementation starts from the [eval/main.py](eval/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. +### Preprocessing Data -For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: -```bash -python -m evals.main \ - --fname configs/eval/vitl16_in1k.yaml \ - --devices cuda:0 cuda:1 cuda:2 +The VJEPA Encoder provides a `preprocess_data` function to preprocess input data before feeding it to the encoder: + +```python +preprocessed_data = encoder.preprocess_data(input_data) ``` +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. -### Distributed training -To launch a distributed evaluation run, the implementation starts from [eval/main_distributed.py](eval/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. +### Embedding Images -For example, to launch a distributed ImageNet image classification experiment using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: -```bash -python -m evals.main_distributed \ - --fname configs/eval/vitl16_in1k.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition +To obtain the embeddings for an image, you can use the `embed_image` function: -Similarly, to launch a distributed K400 video classification experiment using the config [configs/eval/vitl16_k400.yaml](configs/eval/vitl16_k400.yaml), type the command: -```bash -python -m evals.main_distributed \ - --fname configs/eval/vitl16_k400.yaml \ - --folder $path_to_save_stderr_and_stdout \ - --partition $slurm_partition +```python +embeddings = encoder.embed_image(input_data) ``` --- +- `input_data`: Input data, which can be an image path, image array, PIL Image, or PyTorch tensor. -### Setup +The function returns the embeddings generated by the encoder. -Run: -```bash -conda create -n jepa python=3.9 pip -conda activate jepa -python setup.py install -``` +## Configuration + +The VJEPA Encoder requires a configuration file in YAML format to specify the model settings. The configuration file should include the following sections (a minimal sanity-check sketch follows the list): + +- `meta`: General settings such as the checkpoint file path, random seed, etc. +- `mask`: Settings related to masking. +- `model`: Model architecture settings. +- `data`: Data-related settings such as crop size, patch size, etc. +- `logging`: Logging settings.
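+
+The sketch below is a hypothetical sanity check, not part of the package: it loads the YAML with PyYAML and verifies that the sections listed above are present before the path is handed to `JepaEncoder.load_model`. The exact keys inside each section (for example, where `r_checkpoint` and `tubelet_size` sit) depend on your `params-encoder.yaml` and are assumptions here.
+
+```python
+import yaml
+
+config_file_path = "./params-encoder.yaml"
+with open(config_file_path) as f:
+    cfg = yaml.safe_load(f)
+
+# Top-level sections this README describes.
+for section in ("meta", "mask", "model", "data", "logging"):
+    if section not in cfg:
+        raise KeyError(f"config is missing the '{section}' section")
+
+# With the sections present, the same path can be passed to the loader:
+# encoder = JepaEncoder.load_model(config_file_path, ["cuda:0"])
+```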
+ +Please refer to the provided configuration file template for more details. ## License -See the [LICENSE](./LICENSE) file for details about the license under which this code is made available. - -## Citation -If you find this repository useful in your research, please consider giving a star :star: and a citation -```bibtex -@article{bardes2024revisiting, - title={Revisiting Feature Prediction for Learning Visual Representations from Video}, - author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael, and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas}, - journal={arXiv preprint}, - year={2024} -} + +The VJEPA Encoder is released under the [MIT License](LICENSE). + +## Acknowledgments + +The VJEPA Encoder is based on the research work conducted by Facebook AI Research. We would like to acknowledge their contributions to the field of computer vision and representation learning. + +## Contact + +If you have any questions or suggestions regarding the VJEPA Encoder, please feel free to contact us at johnnykoch02@gmail.com. + +--- \ No newline at end of file diff --git a/app/main.py b/app/main.py index 52e1596..9f66229 100644 --- a/app/main.py +++ b/app/main.py @@ -13,7 +13,7 @@ import yaml from app.scaffold import main as app_main -from src.utils.distributed import init_distributed +from jepa_src.utils.distributed import init_distributed parser = argparse.ArgumentParser() parser.add_argument( @@ -30,7 +30,7 @@ def process_main(rank, fname, world_size, devices): os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) import logging - from src.utils.logging import get_logger + from jepa_src.utils.logging import get_logger logger = get_logger(force=True) if rank == 0: logger.setLevel(logging.INFO) diff --git a/app/main_distributed.py b/app/main_distributed.py index 11ac3a2..fe2e160 100644 --- a/app/main_distributed.py +++ b/app/main_distributed.py @@ -13,7 +13,7 @@ import submitit from app.scaffold import main as app_main -from src.utils.logging import get_logger +from jepa_src.utils.logging import get_logger logger = get_logger(force=True) diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 2b55616..ccb2e75 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -26,19 +26,19 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel -from src.datasets.data_manager import init_data -from src.masks.random_tube import MaskCollator as TubeMaskCollator -from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator -from src.masks.utils import apply_masks -from src.utils.distributed import init_distributed, AllReduce -from src.utils.logging import ( +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( CSVLogger, gpu_timer, get_logger, grad_logger, adamw_logger, AverageMeter) -from src.utils.tensors import repeat_interleave_batch +from jepa_src.utils.tensors import repeat_interleave_batch from app.vjepa.utils import ( load_checkpoint, @@ -77,7 +77,7 @@ def main(args, resume_preempt=False): skip_batches = cfgs_meta.get('skip_batches', -1) use_sdpa = cfgs_meta.get('use_sdpa', False) which_dtype = cfgs_meta.get('dtype') - logger.info(f'{which_dtype=}') + logger.info(f'{which_dtype}') if which_dtype.lower() == 'bfloat16': 
dtype = torch.bfloat16 mixed_precision = True diff --git a/app/vjepa/transforms.py b/app/vjepa/transforms.py index 0854dd9..ba62555 100644 --- a/app/vjepa/transforms.py +++ b/app/vjepa/transforms.py @@ -8,8 +8,8 @@ import torch import torchvision.transforms as transforms -import src.datasets.utils.video.transforms as video_transforms -from src.datasets.utils.video.randerase import RandomErasing +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing def make_transforms( diff --git a/app/vjepa/utils.py b/app/vjepa/utils.py index dc8668d..2636ed7 100644 --- a/app/vjepa/utils.py +++ b/app/vjepa/utils.py @@ -13,13 +13,13 @@ import torch -import src.models.vision_transformer as video_vit -import src.models.predictor as vit_pred -from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper -from src.utils.schedulers import ( +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule) -from src.utils.tensors import trunc_normal_ +from jepa_src.utils.tensors import trunc_normal_ logging.basicConfig(stream=sys.stdout, level=logging.INFO) logger = logging.getLogger() diff --git a/src/datasets/data_manager.py b/build/lib/datasets/data_manager.py similarity index 94% rename from src/datasets/data_manager.py rename to build/lib/datasets/data_manager.py index cdb7ade..cf53940 100644 --- a/src/datasets/data_manager.py +++ b/build/lib/datasets/data_manager.py @@ -48,7 +48,7 @@ def init_data( if (data.lower() == 'imagenet') \ or (data.lower() == 'inat21') \ or (data.lower() == 'places205'): - from src.datasets.image_dataset import make_imagedataset + from jepa_src.datasets.image_dataset import make_imagedataset dataset, data_loader, dist_sampler = make_imagedataset( transform=transform, batch_size=batch_size, @@ -66,7 +66,7 @@ def init_data( subset_file=subset_file) elif data.lower() == 'videodataset': - from src.datasets.video_dataset import make_videodataset + from jepa_src.datasets.video_dataset import make_videodataset dataset, data_loader, dist_sampler = make_videodataset( data_paths=root_path, batch_size=batch_size, diff --git a/src/datasets/image_dataset.py b/build/lib/datasets/image_dataset.py similarity index 100% rename from src/datasets/image_dataset.py rename to build/lib/datasets/image_dataset.py diff --git a/src/datasets/utils/video/functional.py b/build/lib/datasets/utils/video/functional.py similarity index 100% rename from src/datasets/utils/video/functional.py rename to build/lib/datasets/utils/video/functional.py diff --git a/src/datasets/utils/video/randaugment.py b/build/lib/datasets/utils/video/randaugment.py similarity index 100% rename from src/datasets/utils/video/randaugment.py rename to build/lib/datasets/utils/video/randaugment.py diff --git a/src/datasets/utils/video/randerase.py b/build/lib/datasets/utils/video/randerase.py similarity index 100% rename from src/datasets/utils/video/randerase.py rename to build/lib/datasets/utils/video/randerase.py diff --git a/src/datasets/utils/video/transforms.py b/build/lib/datasets/utils/video/transforms.py similarity index 99% rename from src/datasets/utils/video/transforms.py rename to build/lib/datasets/utils/video/transforms.py index ffa8e61..979985d 100644 --- a/src/datasets/utils/video/transforms.py +++ 
b/build/lib/datasets/utils/video/transforms.py @@ -17,8 +17,8 @@ import torchvision.transforms.functional as F from torchvision import transforms -import src.datasets.utils.video.functional as FF -from src.datasets.utils.video.randaugment import rand_augment_transform +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform _pil_interpolation_to_str = { diff --git a/src/datasets/utils/video/volume_transforms.py b/build/lib/datasets/utils/video/volume_transforms.py similarity index 100% rename from src/datasets/utils/video/volume_transforms.py rename to build/lib/datasets/utils/video/volume_transforms.py diff --git a/src/datasets/utils/weighted_sampler.py b/build/lib/datasets/utils/weighted_sampler.py similarity index 100% rename from src/datasets/utils/weighted_sampler.py rename to build/lib/datasets/utils/weighted_sampler.py diff --git a/src/datasets/video_dataset.py b/build/lib/datasets/video_dataset.py similarity index 97% rename from src/datasets/video_dataset.py rename to build/lib/datasets/video_dataset.py index b05cc70..82cee52 100644 --- a/src/datasets/video_dataset.py +++ b/build/lib/datasets/video_dataset.py @@ -18,7 +18,7 @@ import torch -from src.datasets.utils.weighted_sampler import DistributedWeightedSampler +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler _GLOBAL_SEED = 0 logger = getLogger() @@ -188,15 +188,15 @@ def loadvideo_decord(self, sample): fname = sample if not os.path.exists(fname): - warnings.warn(f'video path not found {fname=}') + warnings.warn(f'video path not found {fname}') return [], None _fsize = os.path.getsize(fname) if _fsize < 1 * 1024: # avoid hanging issue - warnings.warn(f'video too short {fname=}') + warnings.warn(f'video too short {fname}') return [], None if _fsize > self.filter_long_videos: - warnings.warn(f'skipping long video of size {_fsize=} (bytes)') + warnings.warn(f'skipping long video of size {_fsize} (bytes)') return [], None try: diff --git a/build/lib/jepa_src/__init__.py b/build/lib/jepa_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/__init__.py b/build/lib/jepa_src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/data_manager.py b/build/lib/jepa_src/datasets/data_manager.py new file mode 100644 index 0000000..cf53940 --- /dev/null +++ b/build/lib/jepa_src/datasets/data_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from logging import getLogger + + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def init_data( + batch_size, + transform=None, + shared_transform=None, + data='ImageNet', + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + tokenize_txt=True, + subset_file=None, + clip_len=8, + frame_sample_rate=2, + duration=None, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(1e9), + decode_one_clip=True, + datasets_weights=None, + persistent_workers=False, + repeat_wds=False, + ipe=300, + log_dir=None, +): + + if (data.lower() == 'imagenet') \ + or (data.lower() == 'inat21') \ + or (data.lower() == 'places205'): + from jepa_src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( + transform=transform, + batch_size=batch_size, + collator=collator, + pin_mem=pin_mem, + training=training, + num_workers=num_workers, + world_size=world_size, + rank=rank, + root_path=root_path, + image_folder=image_folder, + persistent_workers=persistent_workers, + copy_data=copy_data, + drop_last=drop_last, + subset_file=subset_file) + + elif data.lower() == 'videodataset': + from jepa_src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( + data_paths=root_path, + batch_size=batch_size, + frames_per_clip=clip_len, + frame_step=frame_sample_rate, + duration=duration, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + shared_transform=shared_transform, + transform=transform, + datasets_weights=datasets_weights, + collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + drop_last=drop_last, + log_dir=log_dir) + + return (data_loader, dist_sampler) diff --git a/build/lib/jepa_src/datasets/image_dataset.py b/build/lib/jepa_src/datasets/image_dataset.py new file mode 100644 index 0000000..84e9b08 --- /dev/null +++ b/build/lib/jepa_src/datasets/image_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import os + +from logging import getLogger + +import torch +import torchvision + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class ImageFolder(torchvision.datasets.ImageFolder): + + def __init__( + self, + root, + image_folder='imagenet_full_size/061417/', + transform=None, + train=True, + ): + """ + ImageFolder + :param root: root network directory for ImageFolder data + :param image_folder: path to images inside root network directory + :param train: whether to load train data (or validation) + """ + + suffix = 'train/' if train else 'val/' + data_path = os.path.join(root, image_folder, suffix) + logger.info(f'data-path {data_path}') + super(ImageFolder, self).__init__(root=data_path, transform=transform) + logger.info('Initialized ImageFolder') + + +def make_imagedataset( + transform, + batch_size, + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + persistent_workers=False, + subset_file=None +): + dataset = ImageFolder( + root=root_path, + image_folder=image_folder, + transform=transform, + train=training) + logger.info('ImageFolder dataset created') + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, + num_replicas=world_size, + rank=rank) + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=persistent_workers) + logger.info('ImageFolder unsupervised data loader created') + + return dataset, data_loader, dist_sampler diff --git a/build/lib/jepa_src/datasets/utils/__init__.py b/build/lib/jepa_src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/utils/video/__init__.py b/build/lib/jepa_src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/datasets/utils/video/functional.py b/build/lib/jepa_src/datasets/utils/video/functional.py new file mode 100644 index 0000000..a91d15d --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/functional.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import numbers +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/build/lib/jepa_src/datasets/utils/video/randaugment.py b/build/lib/jepa_src/datasets/utils/video/randaugment.py new file mode 100644 index 0000000..4c80a99 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/randaugment.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. +""" + +import math +import numpy as np +import random +import re +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. 
+_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs + ) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs + ) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def 
_rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": 
_translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] + if "img_mean" in hparams + else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None + else () + ) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
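# [Editor's sketch, not part of this patch] A minimal illustration of how the
# AugmentOp class defined above is applied to a clip represented as a list of
# PIL frames. The frame size, op name, and magnitude are arbitrary choices for
# the example; the level arguments are sampled once per call, so the same
# rotation angle is applied to every frame, keeping the clip temporally
# consistent.
#
#     from PIL import Image
#     frames = [Image.new("RGB", (320, 240), (90, 120, 150)) for _ in range(8)]
#     op = AugmentOp("Rotate", prob=1.0, magnitude=9)
#     augmented = op(frames)          # same rotation applied to all 8 frames
#     assert len(augmented) == len(frames)
#
# The experimental per-op choice weights mentioned in the comment above follow
# below.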
+_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + choice_weights = ( + None if weight_idx is None else _select_rand_weights(weight_idx) + ) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git 
a/build/lib/jepa_src/datasets/utils/video/randerase.py b/build/lib/jepa_src/datasets/utils/video/randerase.py new file mode 100644 index 0000000..d1f185c --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/randerase.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. +""" +import math +import random +import torch + + +def _get_pixels( + per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" +): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty( + (patch_size[0], 1, 1), dtype=dtype, device=device + ).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
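    Example (editor's sketch, not in the original docstring): applying the
    transform to a normalized clip tensor of dummy data. ``device="cpu"`` is
    chosen so the snippet does not need a GPU, and ``probability=1.0`` forces
    the erase to trigger.
        >>> import torch
        >>> eraser = RandomErasing(probability=1.0, mode="pixel", device="cpu")
        >>> clip = torch.randn(16, 3, 224, 224)   # (T, C, H, W)
        >>> out = eraser(clip)   # with cube=True the same box is erased in every frame
        >>> tuple(out.shape)
        (16, 3, 224, 224)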
+ """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[ + :, top:top + h, left:left + w + ] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0 + ) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/build/lib/jepa_src/datasets/utils/video/transforms.py b/build/lib/jepa_src/datasets/utils/video/transforms.py new file mode 100644 index 0000000..979985d --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/transforms.py @@ -0,0 +1,1184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +import numpy as np +import random +import numbers +import PIL +from PIL import Image + +import torch +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms + +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, boxes=None, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) + ) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or ( + height <= width and height == size + ): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode='bilinear', + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. 
The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
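    Example (editor's sketch with a dummy tensor, not in the original
    docstring): taking the center crop (``spatial_idx=1``) of a clip.
        >>> import torch
        >>> frames = torch.rand(8, 3, 256, 320)   # T x C x H x W
        >>> center, _ = uniform_crop(frames, size=224, spatial_idx=1)
        >>> tuple(center.shape)
        (8, 3, 224, 224)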
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode='bilinear', + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum( + width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) + ) + clipped_boxes[:, [1, 3]] = np.minimum( + height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) + ) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = ( + 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + ) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. 
+ """ + + jitter = [] + if img_brightness != 0: + jitter.append('brightness') + if img_contrast != 0: + jitter.append('contrast') + if img_saturation != 0: + jitter.append('saturation') + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == 'brightness': + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == 'contrast': + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == 'saturation': + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[0] + ), 'channel stddev not computed properly' + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[1] + ), 'channel stddev not computed properly' + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind:ind + 1, + i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation='bilinear', +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {'translate_const': int(img_size_min * 0.45)} + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)] + ) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
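    Example (editor's sketch with a dummy image tensor, not in the original
    docstring):
        >>> import torch
        >>> img = torch.rand(3, 256, 320)   # C x H x W
        >>> crop = random_sized_crop_img(img, size=224)
        >>> tuple(crop.shape)
        (3, 224, 224)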
+ """ + assert ( + len(im.shape) == 3 + ), 'Currently only support image for random_sized_crop' + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode='bilinear', + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation='bilinear', + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale) + ) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio) + ) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire 
clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """ + Randomly change the brightness, contrast and saturation and hue of the clip + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
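    Example (editor's sketch, not in the original docstring; only PIL clips
    are supported here, since the numpy branch below raises TypeError):
        >>> from PIL import Image
        >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        >>> clip = [Image.new('RGB', (320, 240)) for _ in range(8)]
        >>> out = jitter(clip)   # list of jittered PIL frames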
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/build/lib/jepa_src/datasets/utils/video/volume_transforms.py b/build/lib/jepa_src/datasets/utils/video/volume_transforms.py new file mode 100644 index 0000000..0a01bb3 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/video/volume_transforms.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np +from PIL import Image + +import torch + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. 
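        Example (editor's sketch with dummy uint8 frames, not in the original
        docstring; values are shifted and scaled by 127.5 below, so the output
        lies in [-1, 1]):
            >>> import numpy as np
            >>> clip = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(16)]
            >>> tensor = ClipToTensor_K()(clip)
            >>> tuple(tensor.shape)
            (3, 16, 240, 320)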
+ """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor""" + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/build/lib/jepa_src/datasets/utils/weighted_sampler.py b/build/lib/jepa_src/datasets/utils/weighted_sampler.py new file mode 100644 index 0000000..fd40825 --- /dev/null +++ b/build/lib/jepa_src/datasets/utils/weighted_sampler.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Iterator, Optional +from operator import itemgetter +import numpy as np + +import torch +from torch.utils.data import ( + Dataset, + Sampler, + DistributedSampler, + WeightedRandomSampler +) + + +class DatasetFromSampler(Dataset): + + def __init__(self, sampler: Sampler): + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ Convert any Pytorch Sampler to a DistributedSampler """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class CustomWeightedRandomSampler(WeightedRandomSampler): + """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + rand_tensor = np.random.choice( + range(0, len(self.weights)), + size=self.num_samples, + p=self.weights.numpy() / torch.sum(self.weights).numpy(), + replace=self.replacement + ) + rand_tensor = torch.from_numpy(rand_tensor) + return iter(rand_tensor.tolist()) + + +class DistributedWeightedSampler(DistributedSamplerWrapper): + + def __init__( + self, + weights, + num_replicas: 
Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + weighted_sampler = CustomWeightedRandomSampler( + weights=weights, + num_samples=len(weights), + replacement=False) + + super(DistributedWeightedSampler, self).__init__( + sampler=weighted_sampler, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) diff --git a/build/lib/jepa_src/datasets/video_dataset.py b/build/lib/jepa_src/datasets/video_dataset.py new file mode 100644 index 0000000..82cee52 --- /dev/null +++ b/build/lib/jepa_src/datasets/video_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import pathlib +import warnings + +from logging import getLogger + +import numpy as np +import pandas as pd + +from decord import VideoReader, cpu + +import torch + +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def make_videodataset( + data_paths, + batch_size, + frames_per_clip=8, + frame_step=4, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + transform=None, + shared_transform=None, + rank=0, + world_size=1, + datasets_weights=None, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, + duration=None, + log_dir=None, +): + dataset = VideoDataset( + data_paths=data_paths, + datasets_weights=datasets_weights, + frames_per_clip=frames_per_clip, + frame_step=frame_step, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + duration=duration, + shared_transform=shared_transform, + transform=transform) + + logger.info('VideoDataset dataset created') + if datasets_weights is not None: + dist_sampler = DistributedWeightedSampler( + dataset.sample_weights, + num_replicas=world_size, + rank=rank, + shuffle=True) + else: + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0) + logger.info('VideoDataset unsupervised data loader created') + + return dataset, data_loader, dist_sampler + + +class VideoDataset(torch.utils.data.Dataset): + """ Video classification dataset. 
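    Example (editor's sketch, not in the original docstring): a typical way to
    build the dataset through ``make_videodataset`` above. The CSV path is
    hypothetical and is expected to hold space-delimited
    ``<video_path> <label>`` rows, matching the parsing in ``__init__`` below.

        dataset, loader, sampler = make_videodataset(
            data_paths=['/path/to/train_videos.csv'],   # hypothetical annotation file
            batch_size=4,
            frames_per_clip=16,
            frame_step=4,
            num_clips=1,
            rank=0,
            world_size=1,
            num_workers=2,
        )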
""" + + def __init__( + self, + data_paths, + datasets_weights=None, + frames_per_clip=16, + frame_step=4, + num_clips=1, + transform=None, + shared_transform=None, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + duration=None, # duration in seconds + ): + self.data_paths = data_paths + self.datasets_weights = datasets_weights + self.frames_per_clip = frames_per_clip + self.frame_step = frame_step + self.num_clips = num_clips + self.transform = transform + self.shared_transform = shared_transform + self.random_clip_sampling = random_clip_sampling + self.allow_clip_overlap = allow_clip_overlap + self.filter_short_videos = filter_short_videos + self.filter_long_videos = filter_long_videos + self.duration = duration + + if VideoReader is None: + raise ImportError('Unable to import "decord" which is required to read videos.') + + # Load video paths and labels + samples, labels = [], [] + self.num_samples_per_dataset = [] + for data_path in self.data_paths: + + if data_path[-4:] == '.csv': + data = pd.read_csv(data_path, header=None, delimiter=" ") + samples += list(data.values[:, 0]) + labels += list(data.values[:, 1]) + num_samples = len(data) + self.num_samples_per_dataset.append(num_samples) + + elif data_path[-4:] == '.npy': + data = np.load(data_path, allow_pickle=True) + data = list(map(lambda x: repr(x)[1:-1], data)) + samples += data + labels += [0] * len(data) + num_samples = len(data) + self.num_samples_per_dataset.append(len(data)) + + # [Optional] Weights for each sample to be used by downstream + # weighted video sampler + self.sample_weights = None + if self.datasets_weights is not None: + self.sample_weights = [] + for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): + self.sample_weights += [dw / ns] * ns + + self.samples = samples + self.labels = labels + + def __getitem__(self, index): + sample = self.samples[index] + + # Keep trying to load videos until you find a valid sample + loaded_video = False + while not loaded_video: + buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] + loaded_video = len(buffer) > 0 + if not loaded_video: + index = np.random.randint(self.__len__()) + sample = self.samples[index] + + # Label/annotations for video + label = self.labels[index] + + def split_into_clips(video): + """ Split video into a list of clips """ + fpc = self.frames_per_clip + nc = self.num_clips + return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + + # Parse video into frames & apply data augmentations + if self.shared_transform is not None: + buffer = self.shared_transform(buffer) + buffer = split_into_clips(buffer) + if self.transform is not None: + buffer = [self.transform(clip) for clip in buffer] + + return buffer, label, clip_indices + + def loadvideo_decord(self, sample): + """ Load video content using Decord """ + + fname = sample + if not os.path.exists(fname): + warnings.warn(f'video path not found {fname}') + return [], None + + _fsize = os.path.getsize(fname) + if _fsize < 1 * 1024: # avoid hanging issue + warnings.warn(f'video too short {fname}') + return [], None + if _fsize > self.filter_long_videos: + warnings.warn(f'skipping long video of size {_fsize} (bytes)') + return [], None + + try: + vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) + except Exception: + return [], None + + fpc = self.frames_per_clip + fstp = self.frame_step + if self.duration is not None: + try: + fps = vr.get_avg_fps() + fstp = int(self.duration * fps / fpc) + except Exception as 
e: + warnings.warn(e) + clip_len = int(fpc * fstp) + + if self.filter_short_videos and len(vr) < clip_len: + warnings.warn(f'skipping video of length {len(vr)}') + return [], None + + vr.seek(0) # Go to start of video before sampling frames + + # Partition video into equal sized segments and sample each clip + # from a different segment + partition_len = len(vr) // self.num_clips + + all_indices, clip_indices = [], [] + for i in range(self.num_clips): + + if partition_len > clip_len: + # If partition_len > clip len, then sample a random window of + # clip_len frames within the segment + end_indx = clip_len + if self.random_clip_sampling: + end_indx = np.random.randint(clip_len, partition_len) + start_indx = end_indx - clip_len + indices = np.linspace(start_indx, end_indx, num=fpc) + indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + # -- + indices = indices + i * partition_len + else: + # If partition overlap not allowed and partition_len < clip_len + # then repeatedly append the last frame in the segment until + # we reach the desired clip length + if not self.allow_clip_overlap: + indices = np.linspace(0, partition_len, num=partition_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) + indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + # -- + indices = indices + i * partition_len + + # If partition overlap is allowed and partition_len < clip_len + # then start_indx of segment i+1 will lie within segment i + else: + sample_len = min(clip_len, len(vr)) - 1 + indices = np.linspace(0, sample_len, num=sample_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) + indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + # -- + clip_step = 0 + if len(vr) > clip_len: + clip_step = (len(vr) - clip_len) // (self.num_clips - 1) + indices = indices + i * clip_step + + clip_indices.append(indices) + all_indices.extend(list(indices)) + + buffer = vr.get_batch(all_indices).asnumpy() + return buffer, clip_indices + + def __len__(self): + return len(self.samples) diff --git a/build/lib/jepa_src/masks/__init__.py b/build/lib/jepa_src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/masks/default.py b/build/lib/jepa_src/masks/default.py similarity index 100% rename from src/masks/default.py rename to build/lib/jepa_src/masks/default.py diff --git a/src/masks/multiblock3d.py b/build/lib/jepa_src/masks/multiblock3d.py similarity index 100% rename from src/masks/multiblock3d.py rename to build/lib/jepa_src/masks/multiblock3d.py diff --git a/src/masks/random_tube.py b/build/lib/jepa_src/masks/random_tube.py similarity index 100% rename from src/masks/random_tube.py rename to build/lib/jepa_src/masks/random_tube.py diff --git a/src/masks/utils.py b/build/lib/jepa_src/masks/utils.py similarity index 100% rename from src/masks/utils.py rename to build/lib/jepa_src/masks/utils.py diff --git a/build/lib/jepa_src/models/__init__.py b/build/lib/jepa_src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/attentive_pooler.py b/build/lib/jepa_src/models/attentive_pooler.py similarity index 97% rename from src/models/attentive_pooler.py rename to build/lib/jepa_src/models/attentive_pooler.py index ecd9986..26b0e0e 100644 --- a/src/models/attentive_pooler.py +++ b/build/lib/jepa_src/models/attentive_pooler.py @@ -10,12 +10,12 @@ import torch import torch.nn as nn -from src.models.utils.modules import ( +from 
jepa_src.models.utils.modules import ( Block, CrossAttention, CrossAttentionBlock ) -from src.utils.tensors import trunc_normal_ +from jepa_src.utils.tensors import trunc_normal_ class AttentivePooler(nn.Module): diff --git a/src/models/predictor.py b/build/lib/jepa_src/models/predictor.py similarity index 97% rename from src/models/predictor.py rename to build/lib/jepa_src/models/predictor.py index 2dd9a38..95f6bc0 100644 --- a/src/models/predictor.py +++ b/build/lib/jepa_src/models/predictor.py @@ -11,13 +11,13 @@ import torch import torch.nn as nn -from src.models.utils.modules import Block -from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import ( +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( trunc_normal_, repeat_interleave_batch ) -from src.masks.utils import apply_masks +from jepa_src.masks.utils import apply_masks class VisionTransformerPredictor(nn.Module): diff --git a/build/lib/jepa_src/models/utils/__init__.py b/build/lib/jepa_src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/jepa_src/models/utils/functional.py b/build/lib/jepa_src/models/utils/functional.py new file mode 100644 index 0000000..27d1b42 --- /dev/null +++ b/build/lib/jepa_src/models/utils/functional.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F + +def scaled_dot_product_attention(q, k, v, dropout_p=0.0): + """ + Computes scaled dot product attention. + + Args: + q (torch.Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim). + k (torch.Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim). + v (torch.Tensor): Value tensor of shape (batch_size, num_heads, seq_len_v, head_dim). + dropout_p (float, optional): Dropout probability. Default is 0.0. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_heads, seq_len_q, head_dim). 
+ """ + # Compute attention scores + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + attn_scores = attn_scores / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)) + + # Apply softmax to attention scores + attn_probs = F.softmax(attn_scores, dim=-1) + + # Apply dropout to attention probabilities + attn_probs = F.dropout(attn_probs, p=dropout_p) + + # Compute attention output + attn_output = torch.matmul(attn_probs, v) + + return attn_output \ No newline at end of file diff --git a/src/models/utils/modules.py b/build/lib/jepa_src/models/utils/modules.py similarity index 96% rename from src/models/utils/modules.py rename to build/lib/jepa_src/models/utils/modules.py index dc470d9..2412b7a 100644 --- a/src/models/utils/modules.py +++ b/build/lib/jepa_src/models/utils/modules.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +import jepa_src.utils.functional as JF class MLP(nn.Module): def __init__( @@ -65,7 +66,7 @@ def forward(self, x, mask=None): if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): - x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + x = JF.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) attn = None else: attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] @@ -147,7 +148,7 @@ def forward(self, q, x): if self.use_sdpa: with torch.backends.cuda.sdp_kernel(): - q = F.scaled_dot_product_attention(q, k, v) + q = JF.scaled_dot_product_attention(q, k, v) else: xattn = (q @ k.transpose(-2, -1)) * self.scale xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) diff --git a/src/models/utils/multimask.py b/build/lib/jepa_src/models/utils/multimask.py similarity index 100% rename from src/models/utils/multimask.py rename to build/lib/jepa_src/models/utils/multimask.py diff --git a/src/models/utils/patch_embed.py b/build/lib/jepa_src/models/utils/patch_embed.py similarity index 100% rename from src/models/utils/patch_embed.py rename to build/lib/jepa_src/models/utils/patch_embed.py diff --git a/src/models/utils/pos_embs.py b/build/lib/jepa_src/models/utils/pos_embs.py similarity index 100% rename from src/models/utils/pos_embs.py rename to build/lib/jepa_src/models/utils/pos_embs.py diff --git a/src/models/vision_transformer.py b/build/lib/jepa_src/models/vision_transformer.py similarity index 96% rename from src/models/vision_transformer.py rename to build/lib/jepa_src/models/vision_transformer.py index a8748df..946246e 100644 --- a/src/models/vision_transformer.py +++ b/build/lib/jepa_src/models/vision_transformer.py @@ -11,11 +11,11 @@ import torch import torch.nn as nn -from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D -from src.models.utils.modules import Block -from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from src.utils.tensors import trunc_normal_ -from src.masks.utils import apply_masks +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks class VisionTransformer(nn.Module): diff --git a/build/lib/jepa_src/utils/__init__.py b/build/lib/jepa_src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/distributed.py b/build/lib/jepa_src/utils/distributed.py similarity index 100% rename from src/utils/distributed.py 
rename to build/lib/jepa_src/utils/distributed.py diff --git a/build/lib/jepa_src/utils/functional.py b/build/lib/jepa_src/utils/functional.py new file mode 100644 index 0000000..27d1b42 --- /dev/null +++ b/build/lib/jepa_src/utils/functional.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F + +def scaled_dot_product_attention(q, k, v, dropout_p=0.0): + """ + Computes scaled dot product attention. + + Args: + q (torch.Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim). + k (torch.Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim). + v (torch.Tensor): Value tensor of shape (batch_size, num_heads, seq_len_v, head_dim). + dropout_p (float, optional): Dropout probability. Default is 0.0. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_heads, seq_len_q, head_dim). + """ + # Compute attention scores + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + attn_scores = attn_scores / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)) + + # Apply softmax to attention scores + attn_probs = F.softmax(attn_scores, dim=-1) + + # Apply dropout to attention probabilities + attn_probs = F.dropout(attn_probs, p=dropout_p) + + # Compute attention output + attn_output = torch.matmul(attn_probs, v) + + return attn_output \ No newline at end of file diff --git a/src/utils/logging.py b/build/lib/jepa_src/utils/logging.py similarity index 100% rename from src/utils/logging.py rename to build/lib/jepa_src/utils/logging.py diff --git a/src/utils/monitoring.py b/build/lib/jepa_src/utils/monitoring.py similarity index 100% rename from src/utils/monitoring.py rename to build/lib/jepa_src/utils/monitoring.py diff --git a/src/utils/schedulers.py b/build/lib/jepa_src/utils/schedulers.py similarity index 100% rename from src/utils/schedulers.py rename to build/lib/jepa_src/utils/schedulers.py diff --git a/src/utils/tensors.py b/build/lib/jepa_src/utils/tensors.py similarity index 100% rename from src/utils/tensors.py rename to build/lib/jepa_src/utils/tensors.py diff --git a/build/lib/masks/default.py b/build/lib/masks/default.py new file mode 100644 index 0000000..2810c0a --- /dev/null +++ b/build/lib/masks/default.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class DefaultCollator(object): + + def __call__(self, batch): + collated_batch = torch.utils.data.default_collate(batch) + return collated_batch, None, None diff --git a/build/lib/masks/multiblock3d.py b/build/lib/masks/multiblock3d.py new file mode 100644 index 0000000..a7bbc3e --- /dev/null +++ b/build/lib/masks/multiblock3d.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
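The scaled_dot_product_attention helper added above computes softmax(QK^T / sqrt(d)) V with optional dropout, i.e. the same quantity as PyTorch's fused kernel that it replaces in modules.py. A minimal equivalence check is sketched below; it is illustrative only, and assumes PyTorch >= 2.0 and that jepa_src.utils.functional is importable once the package is installed, as the build/lib copies above suggest.

import torch
import torch.nn.functional as F
from jepa_src.utils.functional import scaled_dot_product_attention  # helper introduced by this patch

# Small random tensors in the [batch, heads, seq_len, head_dim] layout the docstring describes
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

ours = scaled_dot_product_attention(q, k, v, dropout_p=0.0)
ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # PyTorch >= 2.0 fused implementation
assert torch.allclose(ours, ref, atol=1e-4), 'shim should match the fused kernel when dropout is disabled'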
+# + +import math + +from multiprocessing import Value + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + spatial_pred_mask_scale=m.get('spatial_scale'), + temporal_pred_mask_scale=m.get('temporal_scale'), + aspect_ratio=m.get('aspect_ratio'), + npred=m.get('num_blocks'), + max_context_frames_ratio=m.get('max_temporal_keep', 1.0), + max_keep=m.get('max_keep', None), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + spatial_pred_mask_scale=(0.2, 0.8), + temporal_pred_mask_scale=(1.0, 1.0), + aspect_ratio=(0.3, 3.0), + npred=1, + max_context_frames_ratio=1.0, + max_keep=None, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.aspect_ratio = aspect_ratio + self.spatial_pred_mask_scale = spatial_pred_mask_scale + self.temporal_pred_mask_scale = temporal_pred_mask_scale + self.npred = npred + self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_keep = max_keep # maximum number of patches to keep in context + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size( + self, + generator, + temporal_scale, + spatial_scale, + aspect_ratio_scale + ): + # -- Sample temporal block mask scale + _rand = torch.rand(1, generator=generator).item() + min_t, max_t = temporal_scale + temporal_mask_scale = min_t + _rand * (max_t - min_t) + t = max(1, int(self.duration * temporal_mask_scale)) + + # -- Sample spatial block mask scale + _rand = torch.rand(1, generator=generator).item() + min_s, max_s = spatial_scale + spatial_mask_scale = min_s + _rand * (max_s - min_s) + spatial_num_keep = int(self.height * self.width * spatial_mask_scale) + + # -- Sample block aspect-ratio + _rand = torch.rand(1, generator=generator).item() + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + + # -- Compute block height and width (given scale and aspect-ratio) + h = 
int(round(math.sqrt(spatial_num_keep * aspect_ratio))) + w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) + h = min(h, self.height) + w = min(w, self.width) + + return (t, h, w) + + def _sample_block_mask(self, b_size): + t, h, w = b_size + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) + start = torch.randint(0, self.duration - t + 1, (1,)) + + mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask[start:start+t, top:top+h, left:left+w] = 0 + + # Context mask will only span the first X frames + # (X=self.max_context_frames) + if self.max_context_duration < self.duration: + mask[self.max_context_duration:, :, :] = 0 + + # -- + return mask + + def __call__(self, batch_size): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample pred block size using seed + # 2. sample several pred block locations for each image (w/o seed) + # 3. return pred masks and complement (enc mask) + """ + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + temporal_scale=self.temporal_pred_mask_scale, + spatial_scale=self.spatial_pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_enc = min_keep_pred = self.duration * self.height * self.width + for _ in range(batch_size): + + empty_context = True + while empty_context: + + mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + for _ in range(self.npred): + mask_e *= self._sample_block_mask(p_size) + mask_e = mask_e.flatten() + + mask_p = torch.argwhere(mask_e == 0).squeeze() + mask_e = torch.nonzero(mask_e).squeeze() + + empty_context = len(mask_e) == 0 + if not empty_context: + min_keep_pred = min(min_keep_pred, len(mask_p)) + min_keep_enc = min(min_keep_enc, len(mask_e)) + collated_masks_pred.append(mask_p) + collated_masks_enc.append(mask_e) + + if self.max_keep is not None: + min_keep_enc = min(min_keep_enc, self.max_keep) + + collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/random_tube.py b/build/lib/masks/random_tube.py new file mode 100644 index 0000000..84c0640 --- /dev/null +++ b/build/lib/masks/random_tube.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
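The MaskCollator above is meant to be used as a DataLoader collate_fn, so that encoder and predictor masks are sampled alongside each batch. The sketch below shows the three values it returns per batch; the mask configuration numbers are hypothetical (chosen only to exercise the m.get(...) keys read in __init__), the import path assumes the jepa_src packaging from this patch, and patch_size is passed as an int because _MaskGenerator divides crop_size by it directly.

import torch
from torch.utils.data import DataLoader, TensorDataset
from jepa_src.masks.multiblock3d import MaskCollator  # same class as the build/lib copy above

# Hypothetical single-entry mask config; keys mirror the m.get(...) lookups in MaskCollator.__init__
cfgs_mask = [{
    'spatial_scale': (0.15, 0.15),
    'temporal_scale': (1.0, 1.0),
    'aspect_ratio': (0.75, 1.5),
    'num_blocks': 8,
    'max_temporal_keep': 1.0,
}]
collator = MaskCollator(cfgs_mask, crop_size=(224, 224), num_frames=16, patch_size=16, tubelet_size=2)

dummy_clips = TensorDataset(torch.randn(8, 3, 16, 224, 224))  # stand-in for a video dataset
loader = DataLoader(dummy_clips, batch_size=4, collate_fn=collator)
batch, masks_enc, masks_pred = next(iter(loader))  # one collated mask tensor per entry in cfgs_mask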
+# + +from multiprocessing import Value + +from logging import getLogger + +import torch +import numpy as np + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + ratio=m.get('ratio'), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + ratio=0.9, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches_spatial = self.height*self.width + + self.ratio = ratio + + self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep = self.num_keep_spatial * self.duration + + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def __call__(self, batch_size): + def sample_mask(): + mask = np.hstack([ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ]) + np.random.shuffle(mask) + mask = torch.tensor(np.tile(mask, (self.duration, 1))) + mask = mask.flatten() + mask_p = torch.argwhere(mask == 0).squeeze() + mask_e = torch.nonzero(mask).squeeze() + return mask_e, mask_p + + collated_masks_pred, collated_masks_enc = [], [] + for _ in range(batch_size): + mask_e, mask_p = sample_mask() + collated_masks_enc.append(mask_e) + collated_masks_pred.append(mask_p) + + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/utils.py b/build/lib/masks/utils.py new file mode 100644 index 0000000..ca04af1 --- /dev/null +++ b/build/lib/masks/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
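As a quick check on the tube-mask bookkeeping above, with values assumed to match the defaults in this file (224x224 crops, 16x16 patches, 16 frames, tubelet size 2, ratio 0.9), the generator keeps 19 of 196 spatial positions and therefore 152 context tokens per clip:

# Worked numbers for _MaskGenerator's bookkeeping (assumed values matching the defaults above)
height = width = 224 // 16                                 # 14 x 14 grid of spatial patches
duration = 16 // 2                                         # 8 temporal positions after tubelet embedding
num_patches_spatial = height * width                       # 196
num_keep_spatial = int(num_patches_spatial * (1. - 0.9))   # 19 spatial positions kept per time step
num_keep = num_keep_spatial * duration                     # 152 context tokens in total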
+# + +import torch + + +def apply_masks(x, masks, concat=True): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + if not concat: + return all_x + + return torch.cat(all_x, dim=0) diff --git a/build/lib/models/attentive_pooler.py b/build/lib/models/attentive_pooler.py new file mode 100644 index 0000000..26b0e0e --- /dev/null +++ b/build/lib/models/attentive_pooler.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import ( + Block, + CrossAttention, + CrossAttentionBlock +) +from jepa_src.utils.tensors import trunc_normal_ + + +class AttentivePooler(nn.Module): + """ Attentive Pooler """ + def __init__( + self, + num_queries=1, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + complete_block=True + ): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) + + self.complete_block = complete_block + if complete_block: + self.cross_attention_block = CrossAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer) + else: + self.cross_attention_block = CrossAttention( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.blocks = None + if depth > 1: + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer) + for i in range(depth-1)]) + + self.init_std = init_std + trunc_normal_(self.query_tokens, std=self.init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + if self.complete_block: + rescale(self.cross_attention_block.xattn.proj.weight.data, 1) + rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) + else: + rescale(self.cross_attention_block.proj.weight.data, 1) + if self.blocks is not None: + for layer_id, layer in enumerate(self.blocks, 1): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + q = self.query_tokens.repeat(len(x), 1, 1) + q = self.cross_attention_block(q, x) + if self.blocks is not None: + for blk in self.blocks: + q = blk(q) + return q + + +class AttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + num_classes=1000, + complete_block=True, + ): + 
super().__init__() + self.pooler = AttentivePooler( + num_queries=1, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + depth=depth, + norm_layer=norm_layer, + init_std=init_std, + qkv_bias=qkv_bias, + complete_block=complete_block, + ) + self.linear = nn.Linear(embed_dim, num_classes, bias=True) + + def forward(self, x): + x = self.pooler(x).squeeze(1) + x = self.linear(x) + return x diff --git a/build/lib/models/predictor.py b/build/lib/models/predictor.py new file mode 100644 index 0000000..95f6bc0 --- /dev/null +++ b/build/lib/models/predictor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( + trunc_normal_, + repeat_interleave_batch +) +from jepa_src.masks.utils import apply_masks + + +class VisionTransformerPredictor(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + depth=6, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + **kwargs + ): + super().__init__() + # Map input to predictor dimension + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ]) + + # Determine positional embedding + self.input_size = img_size + self.patch_size = patch_size + # -- + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + if self.is_video: + self.num_patches = num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.num_patches = num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + # Position embedding + self.uniform_power = uniform_power + self.predictor_pos_embed = None + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + + # Attention Blocks + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer) + for i in range(depth)]) + + # Normalize & project back to input dimension + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + + # ------ initialize weights + if self.predictor_pos_embed is not None: + self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed + self.init_std = init_std + if not 
zero_init_mask_tokens: + for mt in self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.predictor_blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): + + # Prepare diffusion noise schedule + b1, b2 = noise_beta + beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + alpha_scheduler = [] + _alpha = 1.0 + for _beta in beta_scheduler: + _alpha *= 1.-_beta + alpha_scheduler += [_alpha] + + # Sample diffusion time step + T = torch.randint(0, steps, (len(x),)) + alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + + # Normalize features and apply noise + x = torch.nn.functional.layer_norm(x, (x.size(-1),)) + x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + return x + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): + """ + :param ctxt: context tokens + :param tgt: target tokens + :param masks_ctxt: indices of context tokens in input + :params masks_tgt: indices of target tokens in input + """ + + assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_ctxt, list): + masks_ctxt = [masks_ctxt] + + if not isinstance(masks_tgt, list): + masks_tgt = [masks_tgt] + + # Batch Size + B = len(ctxt) // len(masks_ctxt) + + # Map context tokens to pedictor dimensions + x = self.predictor_embed(ctxt) + _, N_ctxt, D = x.shape + + # Add positional embedding to ctxt tokens + if self.predictor_pos_embed is not None: + ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(ctxt_pos_embed, masks_ctxt) + + # Map target tokens to predictor dimensions & add noise (fwd diffusion) + if self.mask_tokens is None: + pred_tokens = self.predictor_embed(tgt) + pred_tokens = self.diffusion(pred_tokens) + else: + mask_index = mask_index % self.num_mask_tokens + pred_tokens = self.mask_tokens[mask_index] + pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) + pred_tokens = apply_masks(pred_tokens, masks_tgt) + + # Add positional embedding to target tokens + if self.predictor_pos_embed is not None: + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_tgt) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) + pred_tokens += pos_embs + + # Concatenate context & target tokens + x = x.repeat(len(masks_tgt), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # FIXME: this implementation currently 
assumes masks_ctxt and masks_tgt + # are alligned 1:1 (ok with MultiMask wrapper on predictor but + # otherwise will break) + masks_ctxt = torch.cat(masks_ctxt, dim=0) + masks_tgt = torch.cat(masks_tgt, dim=0) + masks = torch.cat([masks_ctxt, masks_tgt], dim=1) + + # Fwd prop + for blk in self.predictor_blocks: + x = blk(x, mask=masks) + x = self.predictor_norm(x) + + # Return output corresponding to target tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +def vit_predictor(**kwargs): + model = VisionTransformerPredictor( + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model diff --git a/build/lib/models/utils/modules.py b/build/lib/models/utils/modules.py new file mode 100644 index 0000000..c78ffc0 --- /dev/null +++ b/build/lib/models/utils/modules.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import jepa_src.utils.functional as JF + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = JF.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + grid_size=None, + grid_depth=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, return_attention=False, mask=None): + y, attn 
= self.attn(self.norm1(x), mask=mask) + if return_attention: + return attn + x = x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads=12, + qkv_bias=False, + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.use_sdpa = use_sdpa + + def forward(self, q, x): + B, n, C = q.shape + q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + B, N, C = x.shape + kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + q = JF.scaled_dot_product_attention(q, k, v) + else: + xattn = (q @ k.transpose(-2, -1)) * self.scale + xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) + q = (xattn @ v) + + q = q.transpose(1, 2).reshape(B, n, C) + q = self.proj(q) + + return q + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, q, x): + y = self.xattn(q, self.norm1(x)) + q = q + y + q = q + self.mlp(self.norm2(q)) + return q diff --git a/build/lib/models/utils/multimask.py b/build/lib/models/utils/multimask.py new file mode 100644 index 0000000..d480086 --- /dev/null +++ b/build/lib/models/utils/multimask.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch.nn as nn + + +class MultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, masks=None): + if masks is None: + return self.backbone(x) + + if (masks is not None) and not isinstance(masks, list): + masks = [masks] + outs = [] + for m in masks: + outs += [self.backbone(x, masks=m)] + return outs + + +class PredictorMultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): + if type(ctxt) is not list: + ctxt = [ctxt] + if type(tgt) is not list: + tgt = [tgt] + if type(masks_ctxt) is not list: + masks_ctxt = [masks_ctxt] + if type(masks_tgt) is not list: + masks_tgt = [masks_tgt] + + outs = [] + for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): + outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] + return outs diff --git a/build/lib/models/utils/patch_embed.py b/build/lib/models/utils/patch_embed.py new file mode 100644 index 0000000..4ff4de5 --- /dev/null +++ b/build/lib/models/utils/patch_embed.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
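CrossAttention above is what AttentivePooler uses to read a small set of query tokens out of the full token sequence, so its output keeps the query length rather than the sequence length. A minimal shape check, assuming the jepa_src packaging from this patch is installed (use_sdpa is disabled here only to keep the sketch backend-agnostic):

import torch
from jepa_src.models.utils.modules import CrossAttention  # same class as the build/lib copy above

xattn = CrossAttention(dim=768, num_heads=12, use_sdpa=False)
queries = torch.zeros(2, 1, 768)    # e.g. one learnable query token per sample, as in AttentivePooler
tokens = torch.randn(2, 1568, 768)  # stand-in encoder output (8 x 14 x 14 tokens for a 16-frame 224px clip)
pooled = xattn(queries, tokens)
assert pooled.shape == (2, 1, 768)  # one pooled token per query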
+# + +import torch.nn as nn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768 + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + patch_size=16, + tubelet_size=2, + in_chans=3, + embed_dim=768, + ): + super().__init__() + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + self.proj = nn.Conv3d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size, patch_size), + stride=(tubelet_size, patch_size, patch_size), + ) + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/build/lib/models/utils/pos_embs.py b/build/lib/models/utils/pos_embs.py new file mode 100644 index 0000000..d1d82e2 --- /dev/null +++ b/build/lib/models/utils/pos_embs.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=False +): + """ + grid_size: int of the grid height and width + grid_depth: int of the grid depth + returns: + pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_d = np.arange(grid_depth, dtype=float) + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + + if not uniform_power: + h_embed_dim = embed_dim // 4 + w_embed_dim = embed_dim // 4 + d_embed_dim = embed_dim // 2 + else: + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + + emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) + emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) + emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) + pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) + pos_embed = pos_embed[:, :embed_dim] + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + returns: + pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) + pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 
+ """ + embed_dim: output dimension for each position + grid_size: int of the grid length + returns: + pos_embed: [grid_size, embed_dim] (w/o cls_token) + or [1+grid_size, embed_dim] (w/ cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + returns: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/build/lib/models/vision_transformer.py b/build/lib/models/vision_transformer.py new file mode 100644 index 0000000..946246e --- /dev/null +++ b/build/lib/models/vision_transformer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + out_layers=None, + uniform_power=False, + **kwargs + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + + self.input_size = img_size + self.patch_size = patch_size + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + # Tokenize pixels with convolution + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + tubelet_size=tubelet_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + + # Position embedding + self.uniform_power = uniform_power + self.pos_embed = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, embed_dim), + requires_grad=False) + + # Attention Blocks + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + 
grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # ------ initialize weights + if self.pos_embed is not None: + self._init_pos_embed(self.pos_embed.data) # sincos pos-embed + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv3d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {} + + def forward(self, x, masks=None): + """ + :param x: input image/video + :param masks: indices of patch tokens to mask (remove) + """ + + if masks is not None and not isinstance(masks, list): + masks = [masks] + + # Tokenize input + pos_embed = self.pos_embed + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: + x += pos_embed + B, N, D = x.shape + + # Mask away unwanted tokens (if masks provided) + if masks is not None: + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Fwd prop + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x, mask=masks) + if self.out_layers is not None and i in self.out_layers: + outs.append(self.norm(x)) + + if self.out_layers is not None: + return outs + + if self.norm is not None: + x = self.norm(x) + + return x + + def interpolate_pos_encoding(self, x, pos_embed): + + _, N, dim = pos_embed.shape + + if self.is_video: + + # If pos_embed already corret size, just return + _, _, T, H, W = x.shape + if H == self.input_size and W == self.input_size and T == self.num_frames: + return pos_embed + + # Convert depth, height, width of input to be measured in patches + # instead of pixels/frames + T = T // self.tubelet_size + H = H // self.patch_size + W = W // self.patch_size + + # Compute the initialized shape of the positional embedding measured + # in patches + N_t = self.num_frames // self.tubelet_size + N_h = N_w = self.input_size // self.patch_size + assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + + # Compute scale factor for spatio-temporal interpolation + scale_factor = (T/N_t, H/N_h, W/N_w) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), + 
scale_factor=scale_factor, + mode='trilinear') + pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) + return pos_embed + + else: + + # If pos_embed already corret size, just return + _, _, H, W = x.shape + if H == self.input_size and W == self.input_size: + return pos_embed + + # Compute scale factor for spatial interpolation + npatch = (H // self.patch_size) * (W // self.patch_size) + scale_factor = math.sqrt(npatch / N) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=scale_factor, + mode='bicubic') + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_large(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_huge(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_giant(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_gigantic(patch_size=14, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs + ) + return model + + +VIT_EMBED_DIMS = { + 'vit_tiny': 192, + 'vit_small': 384, + 'vit_base': 768, + 'vit_large': 1024, + 'vit_huge': 1280, + 'vit_giant': 1408, + 'vit_gigantic': 1664, +} diff --git a/build/lib/utils/distributed.py b/build/lib/utils/distributed.py new file mode 100644 index 0000000..cfba444 --- /dev/null +++ b/build/lib/utils/distributed.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
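The factory functions and the VIT_EMBED_DIMS table above make the encoder sizes easy to sanity-check. A small smoke test, assuming the jepa_src packaging from this patch is installed and enough memory for a ViT-S, verifies the video token count of 16/2 x 224/16 x 224/16 = 1568:

import torch
from jepa_src.models.vision_transformer import vit_small  # packaged twin of the build/lib copy above

encoder = vit_small(patch_size=16, img_size=224, num_frames=16, tubelet_size=2)
clip = torch.randn(1, 3, 16, 224, 224)  # [B, C, T, H, W]
with torch.no_grad():
    tokens = encoder(clip)
print(tokens.shape)  # expected: torch.Size([1, 1568, 384]), since VIT_EMBED_DIMS['vit_small'] == 384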
+# + +import os + +import torch +import torch.distributed as dist + +from logging import getLogger + +logger = getLogger() + + +def init_distributed(port=37123, rank_and_world_size=(None, None)): + + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + + rank, world_size = rank_and_world_size + os.environ['MASTER_ADDR'] = 'localhost' + + if (rank is None) or (world_size is None): + try: + world_size = int(os.environ['SLURM_NTASKS']) + rank = int(os.environ['SLURM_PROCID']) + os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + except Exception: + logger.info('SLURM vars not set (distributed training not available)') + world_size, rank = 1, 0 + return world_size, rank + + try: + os.environ['MASTER_PORT'] = str(port) + torch.distributed.init_process_group( + backend='nccl', + world_size=world_size, + rank=rank + ) + except Exception as e: + world_size, rank = 1, 0 + logger.info(f'Rank: {rank}. Distributed training not available {e}') + + return world_size, rank + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(outputs, x) + return torch.cat(outputs, 0) + return x + + @staticmethod + def backward(ctx, grads): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() + e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) + grads = grads.contiguous() + dist.all_reduce(grads) + return grads[s:e] + return grads + + +class AllReduceSum(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads + + +class AllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() / dist.get_world_size() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads diff --git a/build/lib/utils/logging.py b/build/lib/utils/logging.py new file mode 100644 index 0000000..fcdd3fa --- /dev/null +++ b/build/lib/utils/logging.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import sys + +import torch + + +def gpu_timer(closure, log_timings=True): + """ Helper to time gpu-time to execute closure() """ + log_timings = log_timings and torch.cuda.is_available() + + elapsed_time = -1. 
+ if log_timings: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + result = closure() + + if log_timings: + end.record() + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + + return result, elapsed_time + + +LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def get_logger(name=None, force=False): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + return logging.getLogger(name=name) + + +class CSVLogger(object): + + def __init__(self, fname, *argv): + self.fname = fname + self.types = [] + # -- print headers + with open(self.fname, '+a') as f: + for i, v in enumerate(argv, 1): + self.types.append(v[0]) + if i < len(argv): + print(v[1], end=',', file=f) + else: + print(v[1], end='\n', file=f) + + def log(self, *argv): + with open(self.fname, '+a') as f: + for i, tv in enumerate(zip(self.types, argv), 1): + end = ',' if i < len(argv) else '\n' + print(tv[0] % tv[1], end=end, file=f) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.max = float('-inf') + self.min = float('inf') + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + try: + self.max = max(val, self.max) + self.min = min(val, self.min) + except Exception: + pass + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def grad_logger(named_params): + stats = AverageMeter() + stats.first_layer = None + stats.last_layer = None + for n, p in named_params: + if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + grad_norm = float(torch.norm(p.grad.data)) + stats.update(grad_norm) + if 'qkv' in n: + stats.last_layer = grad_norm + if stats.first_layer is None: + stats.first_layer = grad_norm + if stats.first_layer is None or stats.last_layer is None: + stats.first_layer = stats.last_layer = 0. + return stats + + +def adamw_logger(optimizer): + """ logging magnitude of first and second momentum buffers in adamw """ + # TODO: assert that optimizer is instance of torch.optim.AdamW + state = optimizer.state_dict().get('state') + exp_avg_stats = AverageMeter() + exp_avg_sq_stats = AverageMeter() + for key in state: + s = state.get(key) + exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) + exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) + return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/build/lib/utils/monitoring.py b/build/lib/utils/monitoring.py new file mode 100644 index 0000000..95a7845 --- /dev/null +++ b/build/lib/utils/monitoring.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
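The logging helpers above follow a small tuple convention: each CSVLogger column is declared as a (printf-format, header) pair, and gpu_timer returns the closure's result together with the elapsed CUDA time in milliseconds (or -1. when CUDA is unavailable). An illustrative sketch, with an arbitrary output path and a dummy closure standing in for a training step:

import torch
from jepa_src.utils.logging import AverageMeter, CSVLogger, gpu_timer  # packaged twins of the helpers above

csv_logger = CSVLogger('/tmp/train_stats.csv',  # hypothetical path
                       ('%d', 'epoch'), ('%d', 'itr'), ('%.5f', 'loss'))
loss_meter = AverageMeter()

loss, etime_ms = gpu_timer(lambda: torch.randn(8, 8).pow(2).mean())  # dummy stand-in for a training step
loss_meter.update(float(loss))
csv_logger.log(0, 0, loss_meter.avg)  # writes one row: epoch, iteration, running-average loss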
+# + +import dataclasses +import threading +from typing import Dict, Tuple + +import psutil + + +@dataclasses.dataclass +class ResourceStatsSample: + timestamp: float + cpu_percent: float + read_count: int + write_count: int + read_bytes: int + write_bytes: int + read_chars: int + write_chars: int + cpu_times_user: float + cpu_times_system: float + cpu_times_children_user: float + cpu_times_children_system: float + cpu_times_iowait: float + cpu_affinity: str + cpu_num: int + num_threads: int + num_voluntary_ctx_switches: int + num_involuntary_ctx_switches: int + + def as_tuple(self) -> Dict: + """Return values mirroring fields.""" + return dataclasses.astuple(self) + + def fields(self) -> Tuple[dataclasses.Field, ...]: + """Return fields in this dataclass.""" + return dataclasses.fields(self.__class__) + + +class ResourceMonitoringThread(threading.Thread): + def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): + """Starts a thread to monitor pid every refresh_interval seconds. + + Passes a ResourceStatsSample object to the callback.""" + super(ResourceMonitoringThread, self).__init__() + if refresh_interval is None: + refresh_interval = 5 + self.is_running_event = threading.Event() + self.p = psutil.Process(pid) + self.refresh_interval = refresh_interval + if stats_callback_fn is None: + # Default callback + def stats_callback_fn(resource_sample: ResourceStatsSample): + print( + f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): + raise ValueError("Callback needs to be callable, got {}".format( + type(stats_callback_fn))) + self.stats_callback_fn = stats_callback_fn + + def stop(self) -> None: + self.is_running_event.set() + + def run(self) -> None: + while not self.is_running_event.is_set(): + self.sample_counters() + self.is_running_event.wait(self.refresh_interval) + + def log_sample(self, resource_sample: ResourceStatsSample) -> None: + self.stats_callback_fn(resource_sample) + + def sample_counters(self) -> None: + if not self.p.is_running(): + self.stop() + return + + with self.p.oneshot(): + cpu_percent = self.p.cpu_percent() + cpu_times = self.p.cpu_times() + io_counters = self.p.io_counters() + cpu_affinity = self.p.cpu_affinity() + cpu_num = self.p.cpu_num() + num_threads = self.p.num_threads() + num_ctx_switches = self.p.num_ctx_switches() + timestamp = time.time() + + read_count = io_counters.read_count + write_count = io_counters.write_count + read_bytes = io_counters.read_bytes + write_bytes = io_counters.write_bytes + read_chars = io_counters.read_chars + write_chars = io_counters.write_chars + + def compress_cpu_affinity(cpu_affinity): + """Change list representation to interval/range representation.""" + if not cpu_affinity: + return "" + cpu_affinity_compressed = [] + min_x = None + max_x = None + last_x = None + + # Find contiguous ranges + for x in cpu_affinity: + if last_x is None: + # Start interval + min_x = x + max_x = x + last_x = x + continue + elif x == (last_x + 1): + # Move interval up + max_x = x + elif max_x is not None: + # Interval ended, start again + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + min_x = x + max_x = x + last_x = x + # Terminate last range + if max_x is not None: + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + + # Concat + cpu_affinity_compressed = 
",".join(cpu_affinity_compressed) + + return cpu_affinity_compressed + + cpu_affinity = compress_cpu_affinity(cpu_affinity) + + resource_sample = ResourceStatsSample( + timestamp=timestamp, + cpu_percent=cpu_percent, + read_count=read_count, + write_count=write_count, + read_bytes=read_bytes, + write_bytes=write_bytes, + read_chars=read_chars, + write_chars=write_chars, + cpu_times_user=cpu_times.user, + cpu_times_system=cpu_times.system, + cpu_times_children_user=cpu_times.children_user, + cpu_times_children_system=cpu_times.children_system, + cpu_times_iowait=cpu_times.iowait, + cpu_affinity=cpu_affinity, + cpu_num=cpu_num, + num_threads=num_threads, + num_voluntary_ctx_switches=num_ctx_switches.voluntary, + num_involuntary_ctx_switches=num_ctx_switches.involuntary, + ) + self.log_sample(resource_sample) + + +if __name__ == "__main__": + import multiprocessing + import time + pid = multiprocessing.current_process().pid + monitor_thread = ResourceMonitoringThread(pid, 1) + monitor_thread.start() + time.sleep(5) + print("Shutdown") + monitor_thread.stop() diff --git a/build/lib/utils/schedulers.py b/build/lib/utils/schedulers.py new file mode 100644 index 0000000..df02e2b --- /dev/null +++ b/build/lib/utils/schedulers.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + + +class WarmupCosineSchedule(object): + + def __init__( + self, + optimizer, + warmup_steps, + start_lr, + ref_lr, + T_max, + last_epoch=-1, + final_lr=0. + ): + self.optimizer = optimizer + self.start_lr = start_lr + self.ref_lr = ref_lr + self.final_lr = final_lr + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + self._step = 0. + + def step(self): + self._step += 1 + if self._step < self.warmup_steps: + progress = float(self._step) / float(max(1, self.warmup_steps)) + new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) + else: + # -- progress after warmup + progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) + new_lr = max(self.final_lr, + self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + return new_lr + + +class CosineWDSchedule(object): + + def __init__( + self, + optimizer, + ref_wd, + T_max, + final_wd=0. + ): + self.optimizer = optimizer + self.ref_wd = ref_wd + self.final_wd = final_wd + self.T_max = T_max + self._step = 0. + + def step(self): + self._step += 1 + progress = self._step / self.T_max + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + + if self.final_wd <= self.ref_wd: + new_wd = max(self.final_wd, new_wd) + else: + new_wd = min(self.final_wd, new_wd) + + for group in self.optimizer.param_groups: + if ('WD_exclude' not in group) or not group['WD_exclude']: + group['weight_decay'] = new_wd + return new_wd diff --git a/build/lib/utils/tensors.py b/build/lib/utils/tensors.py new file mode 100644 index 0000000..6ae2850 --- /dev/null +++ b/build/lib/utils/tensors.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import math + +import torch + +from logging import getLogger + +logger = getLogger() + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches [0,N) to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x diff --git a/build/lib/vjepa_encoder/__init__.py b/build/lib/vjepa_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/vjepa_encoder/vision_encoder.py b/build/lib/vjepa_encoder/vision_encoder.py new file mode 100644 index 0000000..1f473eb --- /dev/null +++ b/build/lib/vjepa_encoder/vision_encoder.py @@ -0,0 +1,329 @@ +# Extension of Jepa by Robot Perception and Action Laboratory, USF +# +# Non-Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from typing import List, Optional, Any +import multiprocessing as mp + +import pprint +import yaml +import os + +import torch + +from jepa_src.utils.distributed import init_distributed + +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple + +from vjepa_encoder.vjepa.utils import init_video_model +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +# from torch.nn.parallel import DistributedDataParallel +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import get_logger + +from vjepa_encoder.vjepa.utils import init_video_model + +import torch +from torchvision import transforms +from PIL import Image +import numpy as np + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +from jepa_src.models.vision_transformer import VIT_EMBED_DIMS as JEPA_DIM_SIZE + +import logging +from jepa_src.utils.logging import get_logger +logger = get_logger(force=True) +logger.setLevel(logging.INFO) + +class JepaEncoder(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.encoder, self.predictor = None, None + + def preprocess_image(self, input_data: Any): + """ + Preprocess the input image data. + + Args: + input_data (Any): Input data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Preprocessed image data as a tensor. + - If the input is a batch of images, the output will have shape (batch_size, channels, height, width). + - If the input is a single image, the output will have shape (1, channels, height, width). + + Raises: + ValueError: If the input type is not supported. + """ + if isinstance(input_data, str): + img = Image.open(input_data).convert('RGB') + + elif isinstance(input_data, list): + imgs = [ + self.preprocess_image(i).squeeze() for i in input_data + ] + preprocessed_input = torch.stack(imgs) + return preprocessed_input + + elif isinstance(input_data, np.ndarray): + if len(input_data.shape) == 4: + input_data = input_data.transpose(0, 3, 1, 2) + preprocessed_input = torch.from_numpy(input_data).float() + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + img = Image.fromarray(input_data.astype(np.uint8)) + + elif isinstance(input_data, Image.Image): + img = input_data + + elif isinstance(input_data, torch.Tensor): + preprocessed_input = input_data + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + else: + raise ValueError("Unsupported input type. 
Expected image path, image array, or PIL Image.") + + # Define the preprocessing transforms + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Apply preprocessing transforms + preprocessed_input = preprocess(img) + + preprocessed_input = preprocessed_input.unsqueeze(0) # Add batch dimension + return preprocessed_input + + def embed_image(self, x): + """ + Generate embeddings for the input image data. + + Args: + x (Any): Input image data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Embeddings for the input image data. + - If the input is a batch of images, the output will have shape (batch_size, num_patches, embedding_size). + - If the input is a single image, the output will have shape (1, num_patches, embedding_size). + + Notes: + - The input image data is preprocessed using the `preprocess_image` method before generating embeddings. + - If the preprocessed input has fewer than 5 dimensions, an additional dimension is added to represent the time dimension. + - The embeddings are generated using the forward pass of the model. + - The computation is performed on the available device (GPU if available, otherwise CPU). 
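+
+        Example (illustrative, mirroring demo_jepa_encoder.py; assumes a config at `logs/params-encoder.yaml`):
+            >>> encoder = JepaEncoder.load_model("logs/params-encoder.yaml")
+            >>> emb = encoder.embed_image(np.random.random(size=(360, 480, 3)))
+            >>> emb.shape  # (1, num_patches, embedding_size)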
+ """ + x = self.preprocess_image(x) + + # Unsqueeze along the time Dimension + if len(x.shape) < 5: + x = x.unsqueeze(2) + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + + x = x.to(device) + + with torch.no_grad(): + embeddings = self.forward(x) + + return embeddings + + def load_encoder_checkpoint( + self, + r_path, + encoder, + ): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + try: + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return encoder + + + def forward(self, clips: torch.Tensor, masks_enc: List[torch.Tensor], masks_pred: List[torch.Tensor]) -> List[torch.Tensor]: + z = self.encoder(clips, masks_enc) + h = self._forward_target(clips, masks_pred) + z = self.predictor(z, h, masks_enc, masks_pred) + return z + + def freeze_encoder(self): + for p in self.encoder.parameters(): + p.requires_grad = False + + def forward(self, x): + return self.encoder(x) + + @classmethod + def load_model(cls, config_file_path: str, device: Optional[List[str]] = None) -> "JepaEncoder": + # TODO: Fix this so it works properly + # os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + args = None + with open(config_file_path, 'r') as y_file: + args = yaml.load(y_file, Loader=yaml.FullLoader) + logger.info('loaded params...') + + pprint.PrettyPrinter(indent=4).pprint(args) + dump = os.path.join(args['logging']['folder'], 'params-encoder.yaml') + with open(dump, 'w') as f: + yaml.dump(args, f) + + + model = cls(args) + + world_size, rank = init_distributed() + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') + assert load_model, "Cannot load model without checkpoint file specified" + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = cfgs_data.get('patch_size') + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = 
cfgs_logging.get('write_tag') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = r_file + if not os.path.exists(load_path): + raise RuntimeError("Cannot load model. Ensure you specify the path to the model .tar file in the input config.") + + # -- Attempt to initialize model + model.encoder, model.predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + + # model.encoder = DistributedDataParallel(model.encoder, static_graph=True) + + # -- load training checkpoint + model.encoder = model.load_encoder_checkpoint( + load_path, model.encoder + ) + + return model + + diff --git a/build/lib/vjepa_encoder/vjepa/__init__.py b/build/lib/vjepa_encoder/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/vjepa_encoder/vjepa/train.py b/build/lib/vjepa_encoder/vjepa/train.py new file mode 100644 index 0000000..ccb2e75 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/train.py @@ -0,0 +1,586 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import copy +import time +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel + +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter) +from jepa_src.utils.tensors import repeat_interleave_batch + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + + +def main(args, resume_preempt=False): + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') or resume_preempt + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + dataset_type = cfgs_data.get('dataset_type', 'videodataset') + mask_type = cfgs_data.get('mask_type', 'multiblock3d') + dataset_paths = cfgs_data.get('datasets', []) + datasets_weights = cfgs_data.get('datasets_weights', None) + if datasets_weights is not None: + assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' + batch_size = cfgs_data.get('batch_size') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = 
cfgs_data.get('patch_size') + pin_mem = cfgs_data.get('pin_mem', False) + num_workers = cfgs_data.get('num_workers', 1) + filter_short_videos = cfgs_data.get('filter_short_videos', False) + decode_one_clip = cfgs_data.get('decode_one_clip', True) + log_resource_util_data = cfgs_data.get('log_resource_utilization', False) + + # -- DATA AUGS + cfgs_data_aug = args.get('data_aug') + ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) + rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) + motion_shift = cfgs_data_aug.get('motion_shift', False) + reprob = cfgs_data_aug.get('reprob', 0.) + use_aa = cfgs_data_aug.get('auto_augment', False) + + # -- LOSS + cfgs_loss = args.get('loss') + loss_exp = cfgs_loss.get('loss_exp') + reg_coeff = cfgs_loss.get('reg_coeff') + + # -- OPTIMIZATION + cfgs_opt = args.get('optimization') + ipe = cfgs_opt.get('ipe', None) + ipe_scale = cfgs_opt.get('ipe_scale', 1.0) + clip_grad = cfgs_opt.get('clip_grad', None) + wd = float(cfgs_opt.get('weight_decay')) + final_wd = float(cfgs_opt.get('final_weight_decay')) + num_epochs = cfgs_opt.get('epochs') + warmup = cfgs_opt.get('warmup') + start_lr = cfgs_opt.get('start_lr') + lr = cfgs_opt.get('lr') + final_lr = cfgs_opt.get('final_lr') + ema = cfgs_opt.get('ema') + betas = cfgs_opt.get('betas', (0.9, 0.999)) + eps = cfgs_opt.get('eps', 1.e-8) + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = cfgs_logging.get('write_tag') + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method('spawn') + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ('%d', 'epoch'), + ('%d', 'itr'), + ('%.5f', 'loss'), + ('%.5f', 'loss-jepa'), + ('%.5f', 'reg-loss'), + ('%.5f', 'enc-grad-norm'), + ('%.5f', 'pred-grad-norm'), + ('%d', 'gpu-time(ms)'), + ('%d', 'wall-time(ms)'), + ) + + # -- init model + encoder, predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == 'multiblock3d': + logger.info('Initializing basic multi-block mask') + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + else: + 
logger.info('Initializing random tube mask') + mask_collator = TubeMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size) + + # -- init data-loaders/samplers + (unsupervised_loader, + unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in range(int(ipe*num_epochs*ipe_scale)+1)) + + start_epoch = 0 + # -- load training checkpoint + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + 'encoder': encoder.state_dict(), + 'predictor': predictor.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'target_encoder': target_encoder.state_dict(), + 'epoch': epoch, + 'loss': loss_meter.avg, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + + logger.info('Initializing loader...') + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f'Skip {skip_batches} batches') + unsupervised_sampler.set_epoch(start_epoch) + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f'Skip {itr}/{skip_batches} batches') + try: + udata = next(loader) + except Exception: + loader = iter(unsupervised_loader) + udata = next(loader) + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + 
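+        # -- one epoch = `ipe` iterations; the meters below are re-created each epoch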
logger.info('Epoch %d' % (epoch + 1)) + + # -- update distributed-data-loader epoch + unsupervised_sampler.set_epoch(epoch) + + loss_meter = AverageMeter() + input_var_meter = AverageMeter() + input_var_min_meter = AverageMeter() + jepa_loss_meter = AverageMeter() + reg_loss_meter = AverageMeter() + mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))] + gpu_time_meter = AverageMeter() + wall_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + try: + udata, masks_enc, masks_pred = next(loader) + except Exception: + logger.info('Exhausted data loaders. Refreshing...') + loader = iter(unsupervised_loader) + udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len(masks_pred), \ + 'Currently require num encoder masks = num predictor masks' + + def load_clips(): + # -- unsupervised video clips + # Put each clip on the GPU and concatenate along batch + # dimension + clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) + + # Put each mask-enc/mask-pred pair on the GPU and reuse the + # same mask pair for each clip + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _me = repeat_interleave_batch(_me, batch_size, repeat=num_clips) + _mp = repeat_interleave_batch(_mp, batch_size, repeat=num_clips) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() + + for _i, m in enumerate(mask_meters): + m.update(masks_enc[_i][0].size(-1)) + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + # -- + + def forward_target(c): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + with torch.no_grad(): + h = target_encoder(c) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] + # -- create targets (masked regions of h) + h = apply_masks(h, masks_pred, concat=False) + return h + + def forward_context(c, h): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + z = encoder(c, masks_enc) + z = predictor(z, h, masks_enc, masks_pred) + return z + + def loss_fn(z, h): + loss = 0. + # Compute loss and accumulate for each mask-enc/mask-pred pair + for zi, hi in zip(z, h): + loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss /= len(masks_pred) + return loss + + def reg_fn(z): + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) + + # Step 1. Forward + loss_jepa, loss_reg = 0., 0. + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h = forward_target(clips) + z = forward_context(clips, h) + loss_jepa = loss_fn(z, h) # jepa prediction loss + pstd_z = reg_fn(z) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.-pstd_z)) + loss = loss_jepa + reg_coeff * loss_reg + + # Step 2. Backward & step + _enc_norm, _pred_norm = 0., 0. 
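+                # With mixed precision, the loss is scaled before backward and the
+                # gradients are un-scaled prior to clipping and the optimizer step.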
+ if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if (epoch > warmup) and (clip_grad is not None): + _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) + _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + grad_stats = grad_logger(encoder.named_parameters()) + grad_stats.global_norm = float(_enc_norm) + grad_stats_pred = grad_logger(predictor.named_parameters()) + grad_stats_pred.global_norm = float(_pred_norm) + optimizer.zero_grad() + optim_stats = adamw_logger(optimizer) + + # Step 3. momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. + loss_meter.update(loss) + input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) + input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' + 'input_var: %.3f %.3f | ' + 'masks: %s ' + '[wd: %.2e] [lr: %.2e] ' + '[mem: %.2e] ' + '[gpu: %.1f ms]' + '[wall: %.1f ms]' + % (epoch + 1, itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg)) + + if optim_stats is not None: + logger.info( + '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' + % (epoch + 1, itr, + optim_stats.get('exp_avg').avg, + optim_stats.get('exp_avg').min, + optim_stats.get('exp_avg').max, + optim_stats.get('exp_avg_sq').avg, + optim_stats.get('exp_avg_sq').min, + optim_stats.get('exp_avg_sq').max)) + + if grad_stats is not None: + logger.info( + '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm)) + + if grad_stats_pred is not None: + logger.info( + '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm)) + log_stats() + assert not np.isnan(loss), 'loss is nan' + + # -- Save Checkpoint + logger.info('avg. 
loss %.3f' % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f'{tag}-e{epoch}.pth.tar' + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/build/lib/vjepa_encoder/vjepa/transforms.py b/build/lib/vjepa_encoder/vjepa/transforms.py new file mode 100644 index 0000000..ba62555 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/transforms.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torchvision.transforms as transforms + +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing + + +def make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) +): + + _frames_augmentation = VideoTransform( + random_horizontal_flip=random_horizontal_flip, + random_resize_aspect_ratio=random_resize_aspect_ratio, + random_resize_scale=random_resize_scale, + reprob=reprob, + auto_augment=auto_augment, + motion_shift=motion_shift, + crop_size=crop_size, + normalize=normalize, + ) + return _frames_augmentation + + +class VideoTransform(object): + + def __init__( + self, + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ): + + self.random_horizontal_flip = random_horizontal_flip + self.random_resize_aspect_ratio = random_resize_aspect_ratio + self.random_resize_scale = random_resize_scale + self.auto_augment = auto_augment + self.motion_shift = motion_shift + self.crop_size = crop_size + self.mean = torch.tensor(normalize[0], dtype=torch.float32) + self.std = torch.tensor(normalize[1], dtype=torch.float32) + if not self.auto_augment: + # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. + self.mean *= 255. + self.std *= 255. 
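+        # self.autoaug_transform is built unconditionally, but only applied in
+        # __call__ when auto_augment is True.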
+ + self.autoaug_transform = video_transforms.create_random_augment( + input_size=(crop_size, crop_size), + auto_augment='rand-m7-n4-mstd0.5-inc1', + interpolation='bicubic', + ) + + self.spatial_transform = video_transforms.random_resized_crop_with_shift \ + if motion_shift else video_transforms.random_resized_crop + + self.reprob = reprob + self.erase_transform = RandomErasing( + reprob, + mode='pixel', + max_count=1, + num_splits=1, + device='cpu', + ) + + def __call__(self, buffer): + + if self.auto_augment: + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + buffer = self.autoaug_transform(buffer) + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + else: + buffer = torch.tensor(buffer, dtype=torch.float32) + + buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W + + buffer = self.spatial_transform( + images=buffer, + target_height=self.crop_size, + target_width=self.crop_size, + scale=self.random_resize_scale, + ratio=self.random_resize_aspect_ratio, + ) + if self.random_horizontal_flip: + buffer, _ = video_transforms.horizontal_flip(0.5, buffer) + + buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) + if self.reprob > 0: + buffer = buffer.permute(1, 0, 2, 3) + buffer = self.erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) + + return buffer + + +def tensor_normalize(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if type(mean) == list: + mean = torch.tensor(mean) + if type(std) == list: + std = torch.tensor(std) + tensor = tensor - mean + tensor = tensor / std + return tensor + + +def _tensor_normalize_inplace(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize (with dimensions C, T, H, W). + mean (tensor): mean value to subtract (in 0 to 255 floats). + std (tensor): std to divide (in 0 to 255 floats). + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + + C, T, H, W = tensor.shape + tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension + tensor.sub_(mean).div_(std) + tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front + return tensor diff --git a/build/lib/vjepa_encoder/vjepa/utils.py b/build/lib/vjepa_encoder/vjepa/utils.py new file mode 100644 index 0000000..2636ed7 --- /dev/null +++ b/build/lib/vjepa_encoder/vjepa/utils.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import logging +import sys +import warnings +import yaml + + +import torch + +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule) +from jepa_src.utils.tensors import trunc_normal_ + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger() + + +def load_checkpoint( + r_path, + encoder, + predictor, + target_encoder, + opt, + scaler, +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + epoch = 0 + try: + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + # -- loading predictor + pretrained_dict = checkpoint['predictor'] + msg = predictor.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + + # -- loading target_encoder + if target_encoder is not None: + print(list(checkpoint.keys())) + pretrained_dict = checkpoint['target_encoder'] + msg = target_encoder.load_state_dict(pretrained_dict) + logger.info( + f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + ) + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: {r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return ( + encoder, + predictor, + target_encoder, + opt, + scaler, + epoch, + ) + + +def init_video_model( + device, + patch_size=16, + num_frames=16, + tubelet_size=2, + model_name='vit_base', + crop_size=224, + pred_depth=6, + pred_embed_dim=384, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_sdpa=False, +): + encoder = video_vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + ) + encoder = MultiMaskWrapper(encoder) + predictor = vit_pred.__dict__['vit_predictor']( + img_size=crop_size, + use_mask_tokens=use_mask_tokens, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + embed_dim=encoder.backbone.embed_dim, + predictor_embed_dim=pred_embed_dim, + depth=pred_depth, + num_heads=encoder.backbone.num_heads, + uniform_power=uniform_power, + num_mask_tokens=num_mask_tokens, + zero_init_mask_tokens=zero_init_mask_tokens, + use_sdpa=use_sdpa, + ) + predictor = PredictorMultiMaskWrapper(predictor) + + def init_weights(m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + for m in encoder.modules(): + init_weights(m) + + for m in predictor.modules(): + init_weights(m) + + encoder.to(device) + predictor.to(device) + logger.info(encoder) + logger.info(predictor) + + def count_parameters(model): + return sum(p.numel() for p in 
model.parameters() if p.requires_grad) + + logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') + logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + + return encoder, predictor + + +def init_opt( + encoder, + predictor, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + mixed_precision=False, + ipe_scale=1.25, + betas=(0.9, 0.999), + eps=1e-8, + zero_init_bias_wd=True, +): + param_groups = [ + { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=int(warmup * iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + scaler = torch.cuda.amp.GradScaler() if mixed_precision else None + return optimizer, scaler, scheduler, wd_scheduler diff --git a/demo_jepa_encoder.py b/demo_jepa_encoder.py new file mode 100644 index 0000000..9a8842f --- /dev/null +++ b/demo_jepa_encoder.py @@ -0,0 +1,22 @@ +from vjepa_encoder.vision_encoder import JepaEncoder + +encoder = JepaEncoder.load_model( + "logs/params-encoder.yaml" +) + +import numpy +import torch +img = numpy.random.random(size=(360, 480, 3)) + +x = torch.rand((32, 3, 256, 900)) + +print("Input Img:", img.shape) +embedding = encoder.embed_image(img) + +print(embedding) +print(embedding.shape) + + +embedding = encoder.embed_image(x) +print(embedding) +print(embedding.shape) \ No newline at end of file diff --git a/evals/image_classification_frozen/eval.py b/evals/image_classification_frozen/eval.py index 56d2f28..248d6aa 100644 --- a/evals/image_classification_frozen/eval.py +++ b/evals/image_classification_frozen/eval.py @@ -30,20 +30,20 @@ from timm.data import create_transform as timm_make_transforms -import src.models.vision_transformer as vit -from src.models.attentive_pooler import AttentiveClassifier -from src.datasets.data_manager import ( +import jepa_src.models.vision_transformer as vit +from jepa_src.models.attentive_pooler import AttentiveClassifier +from jepa_src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( +from jepa_src.utils.distributed import ( init_distributed, AllReduce ) -from src.utils.schedulers import ( +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( +from jepa_src.utils.logging import ( AverageMeter, CSVLogger ) diff --git a/evals/main.py b/evals/main.py index c614edb..2efa2a0 100644 --- a/evals/main.py +++ b/evals/main.py @@ -12,7 +12,7 @@ import pprint import yaml -from src.utils.distributed import init_distributed +from jepa_src.utils.distributed import init_distributed from evals.scaffold import main as eval_main diff --git 
a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py index f81f526..91af6e7 100644 --- a/evals/video_classification_frozen/eval.py +++ b/evals/video_classification_frozen/eval.py @@ -28,20 +28,20 @@ from torch.nn.parallel import DistributedDataParallel -import src.models.vision_transformer as vit -from src.models.attentive_pooler import AttentiveClassifier -from src.datasets.data_manager import ( +import jepa_src.models.vision_transformer as vit +from jepa_src.models.attentive_pooler import AttentiveClassifier +from jepa_src.datasets.data_manager import ( init_data, ) -from src.utils.distributed import ( +from jepa_src.utils.distributed import ( init_distributed, AllReduce ) -from src.utils.schedulers import ( +from jepa_src.utils.schedulers import ( WarmupCosineSchedule, CosineWDSchedule, ) -from src.utils.logging import ( +from jepa_src.utils.logging import ( AverageMeter, CSVLogger ) diff --git a/evals/video_classification_frozen/utils.py b/evals/video_classification_frozen/utils.py index 450f799..6853588 100644 --- a/evals/video_classification_frozen/utils.py +++ b/evals/video_classification_frozen/utils.py @@ -11,13 +11,13 @@ import torch.nn as nn import torchvision.transforms as transforms -import src.datasets.utils.video.transforms as video_transforms -import src.datasets.utils.video.volume_transforms as volume_transforms +import jepa_src.datasets.utils.video.transforms as video_transforms +import jepa_src.datasets.utils.video.volume_transforms as volume_transforms -from src.datasets.utils.video.randerase import RandomErasing +from jepa_src.datasets.utils.video.randerase import RandomErasing -from src.models.utils.pos_embs import get_1d_sincos_pos_embed -from src.masks.utils import apply_masks +from jepa_src.models.utils.pos_embs import get_1d_sincos_pos_embed +from jepa_src.masks.utils import apply_masks class FrameAggregation(nn.Module): diff --git a/fair_documentation.md b/fair_documentation.md new file mode 100644 index 0000000..a3579e1 --- /dev/null +++ b/fair_documentation.md @@ -0,0 +1,407 @@ +# V-JEPA: Video Joint Embedding Predictive Architecture + +Official PyTorch codebase for the _video joint-embedding predictive architecture_, V-JEPA, a method for self-supervised learning of visual representations from video. + +**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** + +Adrien Bardes, Quentin Garrido, Jean Ponce, Xinlei Chen, Michael Rabbat, Yann LeCun, Mahmoud Assran*, Nicolas Ballas* + +[\[Blog\]](https://ai.meta.com/blog/v-jepa-yann-lecun-ai-model-video-joint-embedding-predictive-architecture/) +[\[Paper\]](https://ai.meta.com/research/publications/revisiting-feature-prediction-for-learning-visual-representations-from-video/) +[\[Yannic Kilcher's Video\]](https://www.youtube.com/watch?v=7UkJPwz_N_0) + +V-JEPA models are trained by passively watching video pixels from the VideoMix2M dataset, and produce versatile visual representations that perform well on downstream video and image tasks, without adaption of the model’s parameters; e.g., using a frozen backbone and only a light-weight task-specific attentive probe. + +## Method +V-JEPA pretraining is based solely on an unsupervised feature prediction objective, and does not utilize pretrained image encoders, text, negative examples, human annotations, or pixel-level reconstruction. + + + +      + + + + +## Visualizations +As opposed to generative methods that have a pixel decoder, V-JEPA has a predictor that makes predictions in latent space. 
+We train a conditional diffusion model to decode the V-JEPA feature-space predictions to interpretable pixels; the pretrained V-JEPA encoder and predictor networks are kept frozen in this process.
+The decoder is only fed the representations predicted for the missing regions of the video, and does not have access to the unmasked regions of the video.
+
+The V-JEPA feature predictions are indeed grounded, and exhibit spatio-temporal consistency with the unmasked regions of the video.
+
+## MODEL ZOO
+
+#### Pretrained models
+
+| model | patch size | resolution | iterations | batch size | data | download |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| ViT-L | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
+| ViT-H | 2x16x16 | 224x224 | 90K | 3072 | VideoMix2M | checkpoint / configs |
+| ViT-H | 2x16x16 | 384x384 | 90K | 2400 | VideoMix2M | checkpoint / configs |
+
+#### K400 Attentive probes
+
+| model | resolution | accuracy (16x8x3) | download |
+| :---: | :---: | :---: | :---: |
+| ViT-L/16 | 224x224 | 80.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 82.0 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 81.9 | attentive probe checkpoint / configs |
+
+#### SSv2 Attentive probes
+
+| model | resolution | accuracy (16x2x3) | download |
+| :---: | :---: | :---: | :---: |
+| ViT-L/16 | 224x224 | 69.5 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 71.4 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 72.2 | attentive probe checkpoint / configs |
+
+#### ImageNet1K Attentive probes
+
+| model | resolution | accuracy | download |
+| :---: | :---: | :---: | :---: |
+| ViT-L/16 | 224x224 | 74.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 75.9 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 77.4 | attentive probe checkpoint / configs |
+
+#### Places205 Attentive probes
+
+| model | resolution | accuracy | download |
+| :---: | :---: | :---: | :---: |
+| ViT-L/16 | 224x224 | 60.3 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 61.7 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 62.8 | attentive probe checkpoint / configs |
+
+#### iNat21 Attentive probes
+
+| model | resolution | accuracy | download |
+| :---: | :---: | :---: | :---: |
+| ViT-L/16 | 224x224 | 67.8 | attentive probe checkpoint / configs |
+| ViT-H/16 | 224x224 | 67.9 | attentive probe checkpoint / configs |
+| ViT-H/16 | 384x384 | 72.6 | attentive probe checkpoint / configs |
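+
+As a quick sanity check, a pretrained checkpoint from the tables above can be loaded through the `JepaEncoder` wrapper added in this repository (a minimal sketch based on `demo_jepa_encoder.py`; it assumes your `params-encoder.yaml` config points at the downloaded checkpoint and log folder):
+
+```python
+import numpy as np
+from vjepa_encoder.vision_encoder import JepaEncoder
+
+# Build the encoder from a YAML config and load the frozen pretrained weights
+encoder = JepaEncoder.load_model("logs/params-encoder.yaml")
+
+# Embed a single RGB image (H, W, C); returns patch embeddings of shape
+# (1, num_patches, embedding_size)
+img = np.random.random(size=(360, 480, 3))
+embedding = encoder.embed_image(img)
+print(embedding.shape)
+```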
+ +## Code Structure + +**Config files:** +All experiment parameters are specified in config files (as opposed to command-line arguments). See the [configs/](configs/) directory for example config files. Note, before launching an experiment, you must update the paths in the config file to point to your own directories, indicating where to save the logs and checkpoints and where to find the training data. + + +``` +. +├── app # the only place where training loops are allowed +│ ├── vjepa # Video JEPA pre-training +│ ├── main_distributed.py # entrypoint for launching app on slurm cluster +│ └── main.py # entrypoint for launching app locally on your machine for debugging +├── evals # the only place where evaluation of 'apps' are allowed +│ ├── image_classification # training an attentive probe for image classification with frozen backbone +│ ├── video_classification # training an attentive probe for video classification with frozen backbone +│ ├── main_distributed.py # entrypoint for launching distributed evaluations on slurm cluster +│ └── main.py # entrypoint for launching evaluations locally on your machine for debugging +├── src # the package +│ ├── datasets # datasets, data loaders, ... +│ ├── models # model definitions +│ ├── masks # mask collators, masking utilities, ... +│ └── utils # shared utilities +└── configs # the only place where config files are allowed (specify experiment params for app/eval runs) + ├── evals # configs for launching vjepa frozen evaluations + └── pretrain # configs for launching vjepa pretraining + +``` + +## Data preparation + +### Video Datasets +V-JEPA pretraining and evaluations work with many standard video formats. +To make a video dataset compatible with the V-JEPA codebase, you simply need to create a `.csv` file with the following format and then specify the path to this CSV file in your config. +``` +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +/absolute_file_path.[mp4, webvid, etc.] $integer_class_label +... +``` +Since V-JEPA is entirely unsupervised, the pretraining code will disregard the `$integer_class_label` in the CSV file. +Thus, feel free to put a random value in this column. +However, if you wish to run a supervised video classification evaluation on your video dataset, you must replace ```$integer_class_label``` with the ground truth label for each video. + +### Image Datasets +We use the standard PyTorch ```ImageFolder``` class in our image classification evals. +Thus, to set up an image dataset for the image classification evaluation, first create a directory to store your image datasets ```$your_directory_containing_image_datasets```. +Next, download your image datasets into this directory in a format compatible with [PyTorch ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html). + +For example, suppose we have a directory called ``my_image_datasets``. We would then download our image datasets into this directory so that we end up with the following file tree +``` +. +└── /my_image_datasets/ # where we store image datasets + ├── places205/121517/pytorch/ # Places205 + │ └── [...] + ├── iNaturalist-2021/110421/ # iNaturalist21 + │ └── [...] + ├── [...] # Other Image Datasets + │ └── [...] + └── imagenet_full_size/061417/ # ImageNet1k + └── train + │ ├── $class_1 + │ │ ├── xxx.[png, jpeg, etc.] + │ │ ├── [...] + │ │ └── xxz.[png, jpeg, etc.] + │ ├── [...] + │ └── $class_n + │ ├── abc.[png, jpeg, etc.] + │ ├── [...] 
+ │ └── abz.[png, jpeg, etc.] + └── val + ├── $class_1 + │ ├── xxx.[png, jpeg, etc.] + │ ├── [...] + │ └── xxz.[png, jpeg, etc.] + ├── [...] + └── $class_n + ├── abc.[png, jpeg, etc.] + ├── [...] + └── abz.[png, jpeg, etc.] +``` + + +## Launching V-JEPA pretraining + +### Local training +If you wish to debug your code or setup before launching a distributed training run, we provide the functionality to do so by running the pretraining script locally on a multi-GPU (or single-GPU) machine; however, reproducing our results requires launching distributed training. + +The single-machine implementation starts from [app/main.py](app/main.py), which parses the experiment config file and runs the pretraining locally on a multi-GPU (or single-GPU) machine. +For example, to run V-JEPA pretraining on GPUs "0", "1", and "2" on a local machine using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: +```bash +python -m app.main \ + --fname configs/pretrain/vitl16.yaml \ + --devices cuda:0 cuda:1 cuda:2 +``` + +### Distributed training +To launch a distributed training run, the implementation starts from [app/main_distributed.py](app/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster. + +For example, to launch a distributed pre-training experiment using the config [configs/pretrain/vitl16.yaml](configs/pretrain/vitl16.yaml), type the command: +```bash +python -m app.main_distributed \ + --fname configs/pretrain/vitl16.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +## Launching Evaluations + +### Local training +If you wish to debug your eval code or setup before launching a distributed training run, we provide the functionality to do so by running the evaluation script locally on a multi-GPU (or single-GPU) machine; however, reproducing the full eval would require launching distributed training. +The single-machine implementation starts from [evals/main.py](evals/main.py), which parses the experiment config file and runs the eval locally on a multi-GPU (or single-GPU) machine. + +For example, to run ImageNet image classification on GPUs "0", "1", and "2" on a local machine using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +```bash +python -m evals.main \ + --fname configs/eval/vitl16_in1k.yaml \ + --devices cuda:0 cuda:1 cuda:2 +``` + + +### Distributed training +To launch a distributed evaluation run, the implementation starts from [evals/main_distributed.py](evals/main_distributed.py), which, in addition to parsing the config file, also allows for specifying details about distributed training. For distributed training, we use the popular open-source [submitit](https://github.com/facebookincubator/submitit) tool and provide examples for a SLURM cluster.
+ +For example, to launch a distributed ImageNet image classification experiment using the config [configs/eval/vitl16_in1k.yaml](configs/eval/vitl16_in1k.yaml), type the command: +```bash +python -m evals.main_distributed \ + --fname configs/eval/vitl16_in1k.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +Similarly, to launch a distributed K400 video classification experiment using the config [configs/eval/vitl16_k400.yaml](configs/eval/vitl16_k400.yaml), type the command: +```bash +python -m evals.main_distributed \ + --fname configs/eval/vitl16_k400.yaml \ + --folder $path_to_save_stderr_and_stdout \ + --partition $slurm_partition +``` + +--- + +### Setup + +Run: +```bash +conda create -n jepa python=3.9 pip +conda activate jepa +python setup.py install +``` + +## License +See the [LICENSE](./LICENSE) file for details about the license under which this code is made available. + +## Citation +If you find this repository useful in your research, please consider giving a star :star: and a citation: +```bibtex +@article{bardes2024revisiting, + title={Revisiting Feature Prediction for Learning Visual Representations from Video}, + author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas}, + journal={arXiv preprint}, + year={2024} +} +``` diff --git a/jepa_encoder.egg-info/PKG-INFO b/jepa_encoder.egg-info/PKG-INFO new file mode 100644 index 0000000..6a3951f --- /dev/null +++ b/jepa_encoder.egg-info/PKG-INFO @@ -0,0 +1,17 @@ +Metadata-Version: 2.1 +Name: jepa-encoder +Version: 0.0.1 +Summary: JEPA research code. +Requires-Python: >=3.9 +License-File: LICENSE +Requires-Dist: pyyaml +Requires-Dist: numpy +Requires-Dist: opencv-python +Requires-Dist: submitit +Requires-Dist: braceexpand +Requires-Dist: webdataset +Requires-Dist: timm +Requires-Dist: decord +Requires-Dist: pandas +Requires-Dist: einops +Requires-Dist: beartype diff --git a/jepa_encoder.egg-info/SOURCES.txt b/jepa_encoder.egg-info/SOURCES.txt new file mode 100644 index 0000000..00be8b0 --- /dev/null +++ b/jepa_encoder.egg-info/SOURCES.txt @@ -0,0 +1,10 @@ +LICENSE +README.md +setup.py +jepa_encoder.egg-info/PKG-INFO +jepa_encoder.egg-info/SOURCES.txt +jepa_encoder.egg-info/dependency_links.txt +jepa_encoder.egg-info/requires.txt +jepa_encoder.egg-info/top_level.txt +vjepa_encoder/__init__.py +vjepa_encoder/vision_encoder.py \ No newline at end of file diff --git a/jepa_encoder.egg-info/dependency_links.txt b/jepa_encoder.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jepa_encoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/jepa_encoder.egg-info/requires.txt b/jepa_encoder.egg-info/requires.txt new file mode 100644 index 0000000..386919b --- /dev/null +++ b/jepa_encoder.egg-info/requires.txt @@ -0,0 +1,11 @@ +pyyaml +numpy +opencv-python +submitit +braceexpand +webdataset +timm +decord +pandas +einops +beartype diff --git a/jepa_encoder.egg-info/top_level.txt b/jepa_encoder.egg-info/top_level.txt new file mode 100644 index 0000000..cca3137 --- /dev/null +++ b/jepa_encoder.egg-info/top_level.txt @@ -0,0 +1 @@ +vjepa_encoder diff --git a/jepa_src/__init__.py b/jepa_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/__init__.py b/jepa_src/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/data_manager.py b/jepa_src/datasets/data_manager.py new file mode 100644 index
0000000..cf53940 --- /dev/null +++ b/jepa_src/datasets/data_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def init_data( + batch_size, + transform=None, + shared_transform=None, + data='ImageNet', + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + tokenize_txt=True, + subset_file=None, + clip_len=8, + frame_sample_rate=2, + duration=None, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(1e9), + decode_one_clip=True, + datasets_weights=None, + persistent_workers=False, + repeat_wds=False, + ipe=300, + log_dir=None, +): + + if (data.lower() == 'imagenet') \ + or (data.lower() == 'inat21') \ + or (data.lower() == 'places205'): + from jepa_src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( + transform=transform, + batch_size=batch_size, + collator=collator, + pin_mem=pin_mem, + training=training, + num_workers=num_workers, + world_size=world_size, + rank=rank, + root_path=root_path, + image_folder=image_folder, + persistent_workers=persistent_workers, + copy_data=copy_data, + drop_last=drop_last, + subset_file=subset_file) + + elif data.lower() == 'videodataset': + from jepa_src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( + data_paths=root_path, + batch_size=batch_size, + frames_per_clip=clip_len, + frame_step=frame_sample_rate, + duration=duration, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + shared_transform=shared_transform, + transform=transform, + datasets_weights=datasets_weights, + collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + drop_last=drop_last, + log_dir=log_dir) + + return (data_loader, dist_sampler) diff --git a/jepa_src/datasets/image_dataset.py b/jepa_src/datasets/image_dataset.py new file mode 100644 index 0000000..84e9b08 --- /dev/null +++ b/jepa_src/datasets/image_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
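# --- Editor's note: illustrative usage sketch, not part of the patch --------
# A minimal, hedged example of how init_data() from
# jepa_src/datasets/data_manager.py (shown above) might be called to build a
# video data loader. The CSV path, clip settings, and single-process world
# size are placeholder assumptions, not values taken from the repository.
from jepa_src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data='VideoDataset',            # routes to make_videodataset()
    root_path=['/path/to/train_videos.csv'],  # hypothetical CSV in the README format
    batch_size=16,
    clip_len=16,                    # frames per clip
    frame_sample_rate=4,            # temporal stride between sampled frames
    num_clips=1,
    transform=None,                 # plug a video transform pipeline in here
    collator=None,
    num_workers=8,
    world_size=1,                   # single process; no torch.distributed setup assumed
    rank=0,
    training=True)

for batch in data_loader:
    pass  # feed clips (and masks, if a mask collator is used) to the model
# ----------------------------------------------------------------------------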
+# + +import os + +from logging import getLogger + +import torch +import torchvision + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class ImageFolder(torchvision.datasets.ImageFolder): + + def __init__( + self, + root, + image_folder='imagenet_full_size/061417/', + transform=None, + train=True, + ): + """ + ImageFolder + :param root: root network directory for ImageFolder data + :param image_folder: path to images inside root network directory + :param train: whether to load train data (or validation) + """ + + suffix = 'train/' if train else 'val/' + data_path = os.path.join(root, image_folder, suffix) + logger.info(f'data-path {data_path}') + super(ImageFolder, self).__init__(root=data_path, transform=transform) + logger.info('Initialized ImageFolder') + + +def make_imagedataset( + transform, + batch_size, + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + persistent_workers=False, + subset_file=None +): + dataset = ImageFolder( + root=root_path, + image_folder=image_folder, + transform=transform, + train=training) + logger.info('ImageFolder dataset created') + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, + num_replicas=world_size, + rank=rank) + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=persistent_workers) + logger.info('ImageFolder unsupervised data loader created') + + return dataset, data_loader, dist_sampler diff --git a/jepa_src/datasets/utils/__init__.py b/jepa_src/datasets/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/utils/video/__init__.py b/jepa_src/datasets/utils/video/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/datasets/utils/video/functional.py b/jepa_src/datasets/utils/video/functional.py new file mode 100644 index 0000000..a91d15d --- /dev/null +++ b/jepa_src/datasets/utils/video/functional.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
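# --- Editor's note: illustrative usage sketch, not part of the patch --------
# A hedged example of building an image-classification eval loader with
# make_imagedataset() from jepa_src/datasets/image_dataset.py (shown above).
# The dataset root and the torchvision transform are placeholder assumptions
# that mirror the ImageNet layout described in the README.
import torchvision.transforms as T
from jepa_src.datasets.image_dataset import make_imagedataset

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

dataset, data_loader, dist_sampler = make_imagedataset(
    transform=transform,
    batch_size=64,
    root_path='/my_image_datasets',             # hypothetical dataset root
    image_folder='imagenet_full_size/061417/',  # default ImageNet folder layout
    training=True,
    world_size=1,   # explicit values, so torch.distributed need not be initialized
    rank=0)
# ----------------------------------------------------------------------------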
+# + +import numbers +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/jepa_src/datasets/utils/video/randaugment.py b/jepa_src/datasets/utils/video/randaugment.py new file mode 100644 index 0000000..4c80a99 --- /dev/null +++ b/jepa_src/datasets/utils/video/randaugment.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. +""" + +import math +import numpy as np +import random +import re +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. 
+_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs + ) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs + ) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def 
_rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": 
_translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] + if "img_mean" in hparams + else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None + else () + ) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
+_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + choice_weights = ( + None if weight_idx is None else _select_rand_weights(weight_idx) + ) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/jepa_src/datasets/utils/video/randerase.py 
b/jepa_src/datasets/utils/video/randerase.py new file mode 100644 index 0000000..d1f185c --- /dev/null +++ b/jepa_src/datasets/utils/video/randerase.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. +""" +import math +import random +import torch + + +def _get_pixels( + per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" +): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty( + (patch_size[0], 1, 1), dtype=dtype, device=device + ).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
+ """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[ + :, top:top + h, left:left + w + ] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0 + ) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/jepa_src/datasets/utils/video/transforms.py b/jepa_src/datasets/utils/video/transforms.py new file mode 100644 index 0000000..979985d --- /dev/null +++ b/jepa_src/datasets/utils/video/transforms.py @@ -0,0 +1,1184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +import numpy as np +import random +import numbers +import PIL +from PIL import Image + +import torch +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms + +import jepa_src.datasets.utils.video.functional as FF +from jepa_src.datasets.utils.video.randaugment import rand_augment_transform + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, boxes=None, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) + ) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or ( + height <= width and height == size + ): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode='bilinear', + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. 
The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode='bilinear', + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum( + width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) + ) + clipped_boxes[:, [1, 3]] = np.minimum( + height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) + ) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = ( + 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + ) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. 
+ """ + + jitter = [] + if img_brightness != 0: + jitter.append('brightness') + if img_contrast != 0: + jitter.append('contrast') + if img_saturation != 0: + jitter.append('saturation') + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == 'brightness': + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == 'contrast': + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == 'saturation': + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[0] + ), 'channel stddev not computed properly' + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[1] + ), 'channel stddev not computed properly' + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind:ind + 1, + i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation='bilinear', +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {'translate_const': int(img_size_min * 0.45)} + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)] + ) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
+ """ + assert ( + len(im.shape) == 3 + ), 'Currently only support image for random_sized_crop' + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode='bilinear', + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation='bilinear', + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale) + ) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio) + ) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire 
clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """ + Randomly change the brightness, contrast and saturation and hue of the clip + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
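+
+    Note: the selected adjustments are applied in a random order, and inputs
+    must be PIL Images (numpy clips raise a TypeError).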
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/jepa_src/datasets/utils/video/volume_transforms.py b/jepa_src/datasets/utils/video/volume_transforms.py new file mode 100644 index 0000000..0a01bb3 --- /dev/null +++ b/jepa_src/datasets/utils/video/volume_transforms.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
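+#
+# ClipToTensor converts a list of T frames (H x W x C, values in [0, 255])
+# into a float tensor of shape (C, T, H, W) scaled to [0, 1];
+# ClipToTensor_K instead normalizes values to [-1, 1].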
+# + +import numpy as np +from PIL import Image + +import torch + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. 
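+            Unlike ClipToTensor, values are mapped to [-1, 1] via
+            (x - 127.5) / 127.5 when div_255 is enabled.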
+ """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor""" + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/jepa_src/datasets/utils/weighted_sampler.py b/jepa_src/datasets/utils/weighted_sampler.py new file mode 100644 index 0000000..fd40825 --- /dev/null +++ b/jepa_src/datasets/utils/weighted_sampler.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Iterator, Optional +from operator import itemgetter +import numpy as np + +import torch +from torch.utils.data import ( + Dataset, + Sampler, + DistributedSampler, + WeightedRandomSampler +) + + +class DatasetFromSampler(Dataset): + + def __init__(self, sampler: Sampler): + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ Convert any Pytorch Sampler to a DistributedSampler """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class CustomWeightedRandomSampler(WeightedRandomSampler): + """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + rand_tensor = np.random.choice( + range(0, len(self.weights)), + size=self.num_samples, + p=self.weights.numpy() / torch.sum(self.weights).numpy(), + replace=self.replacement + ) + rand_tensor = torch.from_numpy(rand_tensor) + return iter(rand_tensor.tolist()) + + +class DistributedWeightedSampler(DistributedSamplerWrapper): + + def __init__( + self, + weights, + num_replicas: Optional[int] = None, + rank: 
Optional[int] = None, + shuffle: bool = True, + ): + weighted_sampler = CustomWeightedRandomSampler( + weights=weights, + num_samples=len(weights), + replacement=False) + + super(DistributedWeightedSampler, self).__init__( + sampler=weighted_sampler, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) diff --git a/jepa_src/datasets/video_dataset.py b/jepa_src/datasets/video_dataset.py new file mode 100644 index 0000000..82cee52 --- /dev/null +++ b/jepa_src/datasets/video_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import pathlib +import warnings + +from logging import getLogger + +import numpy as np +import pandas as pd + +from decord import VideoReader, cpu + +import torch + +from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def make_videodataset( + data_paths, + batch_size, + frames_per_clip=8, + frame_step=4, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + transform=None, + shared_transform=None, + rank=0, + world_size=1, + datasets_weights=None, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, + duration=None, + log_dir=None, +): + dataset = VideoDataset( + data_paths=data_paths, + datasets_weights=datasets_weights, + frames_per_clip=frames_per_clip, + frame_step=frame_step, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + duration=duration, + shared_transform=shared_transform, + transform=transform) + + logger.info('VideoDataset dataset created') + if datasets_weights is not None: + dist_sampler = DistributedWeightedSampler( + dataset.sample_weights, + num_replicas=world_size, + rank=rank, + shuffle=True) + else: + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0) + logger.info('VideoDataset unsupervised data loader created') + + return dataset, data_loader, dist_sampler + + +class VideoDataset(torch.utils.data.Dataset): + """ Video classification dataset. 
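+
+    Samples are read from the .csv (space-delimited "path label") or .npy
+    path lists given in data_paths. __getitem__ returns
+    (buffer, label, clip_indices), where buffer holds num_clips clips of
+    frames_per_clip frames each.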
""" + + def __init__( + self, + data_paths, + datasets_weights=None, + frames_per_clip=16, + frame_step=4, + num_clips=1, + transform=None, + shared_transform=None, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + duration=None, # duration in seconds + ): + self.data_paths = data_paths + self.datasets_weights = datasets_weights + self.frames_per_clip = frames_per_clip + self.frame_step = frame_step + self.num_clips = num_clips + self.transform = transform + self.shared_transform = shared_transform + self.random_clip_sampling = random_clip_sampling + self.allow_clip_overlap = allow_clip_overlap + self.filter_short_videos = filter_short_videos + self.filter_long_videos = filter_long_videos + self.duration = duration + + if VideoReader is None: + raise ImportError('Unable to import "decord" which is required to read videos.') + + # Load video paths and labels + samples, labels = [], [] + self.num_samples_per_dataset = [] + for data_path in self.data_paths: + + if data_path[-4:] == '.csv': + data = pd.read_csv(data_path, header=None, delimiter=" ") + samples += list(data.values[:, 0]) + labels += list(data.values[:, 1]) + num_samples = len(data) + self.num_samples_per_dataset.append(num_samples) + + elif data_path[-4:] == '.npy': + data = np.load(data_path, allow_pickle=True) + data = list(map(lambda x: repr(x)[1:-1], data)) + samples += data + labels += [0] * len(data) + num_samples = len(data) + self.num_samples_per_dataset.append(len(data)) + + # [Optional] Weights for each sample to be used by downstream + # weighted video sampler + self.sample_weights = None + if self.datasets_weights is not None: + self.sample_weights = [] + for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): + self.sample_weights += [dw / ns] * ns + + self.samples = samples + self.labels = labels + + def __getitem__(self, index): + sample = self.samples[index] + + # Keep trying to load videos until you find a valid sample + loaded_video = False + while not loaded_video: + buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] + loaded_video = len(buffer) > 0 + if not loaded_video: + index = np.random.randint(self.__len__()) + sample = self.samples[index] + + # Label/annotations for video + label = self.labels[index] + + def split_into_clips(video): + """ Split video into a list of clips """ + fpc = self.frames_per_clip + nc = self.num_clips + return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + + # Parse video into frames & apply data augmentations + if self.shared_transform is not None: + buffer = self.shared_transform(buffer) + buffer = split_into_clips(buffer) + if self.transform is not None: + buffer = [self.transform(clip) for clip in buffer] + + return buffer, label, clip_indices + + def loadvideo_decord(self, sample): + """ Load video content using Decord """ + + fname = sample + if not os.path.exists(fname): + warnings.warn(f'video path not found {fname}') + return [], None + + _fsize = os.path.getsize(fname) + if _fsize < 1 * 1024: # avoid hanging issue + warnings.warn(f'video too short {fname}') + return [], None + if _fsize > self.filter_long_videos: + warnings.warn(f'skipping long video of size {_fsize} (bytes)') + return [], None + + try: + vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) + except Exception: + return [], None + + fpc = self.frames_per_clip + fstp = self.frame_step + if self.duration is not None: + try: + fps = vr.get_avg_fps() + fstp = int(self.duration * fps / fpc) + except Exception as 
e: + warnings.warn(e) + clip_len = int(fpc * fstp) + + if self.filter_short_videos and len(vr) < clip_len: + warnings.warn(f'skipping video of length {len(vr)}') + return [], None + + vr.seek(0) # Go to start of video before sampling frames + + # Partition video into equal sized segments and sample each clip + # from a different segment + partition_len = len(vr) // self.num_clips + + all_indices, clip_indices = [], [] + for i in range(self.num_clips): + + if partition_len > clip_len: + # If partition_len > clip len, then sample a random window of + # clip_len frames within the segment + end_indx = clip_len + if self.random_clip_sampling: + end_indx = np.random.randint(clip_len, partition_len) + start_indx = end_indx - clip_len + indices = np.linspace(start_indx, end_indx, num=fpc) + indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + # -- + indices = indices + i * partition_len + else: + # If partition overlap not allowed and partition_len < clip_len + # then repeatedly append the last frame in the segment until + # we reach the desired clip length + if not self.allow_clip_overlap: + indices = np.linspace(0, partition_len, num=partition_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) + indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + # -- + indices = indices + i * partition_len + + # If partition overlap is allowed and partition_len < clip_len + # then start_indx of segment i+1 will lie within segment i + else: + sample_len = min(clip_len, len(vr)) - 1 + indices = np.linspace(0, sample_len, num=sample_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) + indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + # -- + clip_step = 0 + if len(vr) > clip_len: + clip_step = (len(vr) - clip_len) // (self.num_clips - 1) + indices = indices + i * clip_step + + clip_indices.append(indices) + all_indices.extend(list(indices)) + + buffer = vr.get_batch(all_indices).asnumpy() + return buffer, clip_indices + + def __len__(self): + return len(self.samples) diff --git a/jepa_src/masks/__init__.py b/jepa_src/masks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/masks/default.py b/jepa_src/masks/default.py new file mode 100644 index 0000000..2810c0a --- /dev/null +++ b/jepa_src/masks/default.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class DefaultCollator(object): + + def __call__(self, batch): + collated_batch = torch.utils.data.default_collate(batch) + return collated_batch, None, None diff --git a/jepa_src/masks/multiblock3d.py b/jepa_src/masks/multiblock3d.py new file mode 100644 index 0000000..a7bbc3e --- /dev/null +++ b/jepa_src/masks/multiblock3d.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
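+#
+# MaskCollator collates a batch with torch.utils.data.default_collate and, for
+# each mask config in cfgs_mask, additionally returns index tensors for the
+# context (encoder) and predicted patch tokens. An instance is intended to be
+# passed as the DataLoader collate_fn (see make_videodataset), so each batch
+# arrives as (collated_batch, masks_enc, masks_pred).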
+# + +import math + +from multiprocessing import Value + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + spatial_pred_mask_scale=m.get('spatial_scale'), + temporal_pred_mask_scale=m.get('temporal_scale'), + aspect_ratio=m.get('aspect_ratio'), + npred=m.get('num_blocks'), + max_context_frames_ratio=m.get('max_temporal_keep', 1.0), + max_keep=m.get('max_keep', None), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + spatial_pred_mask_scale=(0.2, 0.8), + temporal_pred_mask_scale=(1.0, 1.0), + aspect_ratio=(0.3, 3.0), + npred=1, + max_context_frames_ratio=1.0, + max_keep=None, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.aspect_ratio = aspect_ratio + self.spatial_pred_mask_scale = spatial_pred_mask_scale + self.temporal_pred_mask_scale = temporal_pred_mask_scale + self.npred = npred + self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_keep = max_keep # maximum number of patches to keep in context + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size( + self, + generator, + temporal_scale, + spatial_scale, + aspect_ratio_scale + ): + # -- Sample temporal block mask scale + _rand = torch.rand(1, generator=generator).item() + min_t, max_t = temporal_scale + temporal_mask_scale = min_t + _rand * (max_t - min_t) + t = max(1, int(self.duration * temporal_mask_scale)) + + # -- Sample spatial block mask scale + _rand = torch.rand(1, generator=generator).item() + min_s, max_s = spatial_scale + spatial_mask_scale = min_s + _rand * (max_s - min_s) + spatial_num_keep = int(self.height * self.width * spatial_mask_scale) + + # -- Sample block aspect-ratio + _rand = torch.rand(1, generator=generator).item() + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + + # -- Compute block height and width (given scale and aspect-ratio) + h = 
int(round(math.sqrt(spatial_num_keep * aspect_ratio))) + w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) + h = min(h, self.height) + w = min(w, self.width) + + return (t, h, w) + + def _sample_block_mask(self, b_size): + t, h, w = b_size + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) + start = torch.randint(0, self.duration - t + 1, (1,)) + + mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask[start:start+t, top:top+h, left:left+w] = 0 + + # Context mask will only span the first X frames + # (X=self.max_context_frames) + if self.max_context_duration < self.duration: + mask[self.max_context_duration:, :, :] = 0 + + # -- + return mask + + def __call__(self, batch_size): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample pred block size using seed + # 2. sample several pred block locations for each image (w/o seed) + # 3. return pred masks and complement (enc mask) + """ + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + temporal_scale=self.temporal_pred_mask_scale, + spatial_scale=self.spatial_pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_enc = min_keep_pred = self.duration * self.height * self.width + for _ in range(batch_size): + + empty_context = True + while empty_context: + + mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + for _ in range(self.npred): + mask_e *= self._sample_block_mask(p_size) + mask_e = mask_e.flatten() + + mask_p = torch.argwhere(mask_e == 0).squeeze() + mask_e = torch.nonzero(mask_e).squeeze() + + empty_context = len(mask_e) == 0 + if not empty_context: + min_keep_pred = min(min_keep_pred, len(mask_p)) + min_keep_enc = min(min_keep_enc, len(mask_e)) + collated_masks_pred.append(mask_p) + collated_masks_enc.append(mask_e) + + if self.max_keep is not None: + min_keep_enc = min(min_keep_enc, self.max_keep) + + collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_masks_enc, collated_masks_pred diff --git a/jepa_src/masks/random_tube.py b/jepa_src/masks/random_tube.py new file mode 100644 index 0000000..84c0640 --- /dev/null +++ b/jepa_src/masks/random_tube.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
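+#
+# Tube masking: one spatial mask is sampled per clip and repeated across every
+# temporal slice, so masked patches form tubes through time. A fraction
+# `ratio` of the spatial patches is dropped (to be predicted) and the
+# remaining (1 - ratio) fraction is kept as encoder context.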
+# + +from multiprocessing import Value + +from logging import getLogger + +import torch +import numpy as np + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + ratio=m.get('ratio'), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + ratio=0.9, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches_spatial = self.height*self.width + + self.ratio = ratio + + self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep = self.num_keep_spatial * self.duration + + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def __call__(self, batch_size): + def sample_mask(): + mask = np.hstack([ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ]) + np.random.shuffle(mask) + mask = torch.tensor(np.tile(mask, (self.duration, 1))) + mask = mask.flatten() + mask_p = torch.argwhere(mask == 0).squeeze() + mask_e = torch.nonzero(mask).squeeze() + return mask_e, mask_p + + collated_masks_pred, collated_masks_enc = [], [] + for _ in range(batch_size): + mask_e, mask_p = sample_mask() + collated_masks_enc.append(mask_e) + collated_masks_pred.append(mask_p) + + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + + return collated_masks_enc, collated_masks_pred diff --git a/jepa_src/masks/utils.py b/jepa_src/masks/utils.py new file mode 100644 index 0000000..ca04af1 --- /dev/null +++ b/jepa_src/masks/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
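+#
+# apply_masks gathers the patch tokens selected by each mask:
+# x has shape [B, N, D] and each mask is a [B, K] index tensor; the result is
+# the concatenated [len(masks) * B, K, D] tensor (or a list when concat=False).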
+# + +import torch + + +def apply_masks(x, masks, concat=True): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + if not concat: + return all_x + + return torch.cat(all_x, dim=0) diff --git a/jepa_src/models/__init__.py b/jepa_src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/models/attentive_pooler.py b/jepa_src/models/attentive_pooler.py new file mode 100644 index 0000000..26b0e0e --- /dev/null +++ b/jepa_src/models/attentive_pooler.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import ( + Block, + CrossAttention, + CrossAttentionBlock +) +from jepa_src.utils.tensors import trunc_normal_ + + +class AttentivePooler(nn.Module): + """ Attentive Pooler """ + def __init__( + self, + num_queries=1, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + complete_block=True + ): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) + + self.complete_block = complete_block + if complete_block: + self.cross_attention_block = CrossAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer) + else: + self.cross_attention_block = CrossAttention( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.blocks = None + if depth > 1: + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer) + for i in range(depth-1)]) + + self.init_std = init_std + trunc_normal_(self.query_tokens, std=self.init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + if self.complete_block: + rescale(self.cross_attention_block.xattn.proj.weight.data, 1) + rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) + else: + rescale(self.cross_attention_block.proj.weight.data, 1) + if self.blocks is not None: + for layer_id, layer in enumerate(self.blocks, 1): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + q = self.query_tokens.repeat(len(x), 1, 1) + q = self.cross_attention_block(q, x) + if self.blocks is not None: + for blk in self.blocks: + q = blk(q) + return q + + +class AttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + 
norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + num_classes=1000, + complete_block=True, + ): + super().__init__() + self.pooler = AttentivePooler( + num_queries=1, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + depth=depth, + norm_layer=norm_layer, + init_std=init_std, + qkv_bias=qkv_bias, + complete_block=complete_block, + ) + self.linear = nn.Linear(embed_dim, num_classes, bias=True) + + def forward(self, x): + x = self.pooler(x).squeeze(1) + x = self.linear(x) + return x diff --git a/jepa_src/models/predictor.py b/jepa_src/models/predictor.py new file mode 100644 index 0000000..95f6bc0 --- /dev/null +++ b/jepa_src/models/predictor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import ( + trunc_normal_, + repeat_interleave_batch +) +from jepa_src.masks.utils import apply_masks + + +class VisionTransformerPredictor(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + depth=6, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + **kwargs + ): + super().__init__() + # Map input to predictor dimension + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ]) + + # Determine positional embedding + self.input_size = img_size + self.patch_size = patch_size + # -- + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + if self.is_video: + self.num_patches = num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.num_patches = num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + # Position embedding + self.uniform_power = uniform_power + self.predictor_pos_embed = None + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + + # Attention Blocks + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer) + for i in range(depth)]) + + # Normalize & project back to input dimension + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + + # ------ initialize weights + if self.predictor_pos_embed is not None: + 
self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed + self.init_std = init_std + if not zero_init_mask_tokens: + for mt in self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.predictor_blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): + + # Prepare diffusion noise schedule + b1, b2 = noise_beta + beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + alpha_scheduler = [] + _alpha = 1.0 + for _beta in beta_scheduler: + _alpha *= 1.-_beta + alpha_scheduler += [_alpha] + + # Sample diffusion time step + T = torch.randint(0, steps, (len(x),)) + alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + + # Normalize features and apply noise + x = torch.nn.functional.layer_norm(x, (x.size(-1),)) + x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + return x + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): + """ + :param ctxt: context tokens + :param tgt: target tokens + :param masks_ctxt: indices of context tokens in input + :params masks_tgt: indices of target tokens in input + """ + + assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_ctxt, list): + masks_ctxt = [masks_ctxt] + + if not isinstance(masks_tgt, list): + masks_tgt = [masks_tgt] + + # Batch Size + B = len(ctxt) // len(masks_ctxt) + + # Map context tokens to pedictor dimensions + x = self.predictor_embed(ctxt) + _, N_ctxt, D = x.shape + + # Add positional embedding to ctxt tokens + if self.predictor_pos_embed is not None: + ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(ctxt_pos_embed, masks_ctxt) + + # Map target tokens to predictor dimensions & add noise (fwd diffusion) + if self.mask_tokens is None: + pred_tokens = self.predictor_embed(tgt) + pred_tokens = self.diffusion(pred_tokens) + else: + mask_index = mask_index % self.num_mask_tokens + pred_tokens = self.mask_tokens[mask_index] + pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) + pred_tokens = apply_masks(pred_tokens, masks_tgt) + + # Add positional embedding to target tokens + if self.predictor_pos_embed is not None: + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_tgt) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) + pred_tokens += pos_embs + + # Concatenate context & target tokens + x = 
x.repeat(len(masks_tgt), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # FIXME: this implementation currently assumes masks_ctxt and masks_tgt + # are alligned 1:1 (ok with MultiMask wrapper on predictor but + # otherwise will break) + masks_ctxt = torch.cat(masks_ctxt, dim=0) + masks_tgt = torch.cat(masks_tgt, dim=0) + masks = torch.cat([masks_ctxt, masks_tgt], dim=1) + + # Fwd prop + for blk in self.predictor_blocks: + x = blk(x, mask=masks) + x = self.predictor_norm(x) + + # Return output corresponding to target tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +def vit_predictor(**kwargs): + model = VisionTransformerPredictor( + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model diff --git a/jepa_src/models/utils/__init__.py b/jepa_src/models/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jepa_src/models/utils/modules.py b/jepa_src/models/utils/modules.py new file mode 100644 index 0000000..2412b7a --- /dev/null +++ b/jepa_src/models/utils/modules.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import jepa_src.utils.functional as JF + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = JF.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + grid_size=None, + grid_depth=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.norm2 
= norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, return_attention=False, mask=None): + y, attn = self.attn(self.norm1(x), mask=mask) + if return_attention: + return attn + x = x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads=12, + qkv_bias=False, + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.use_sdpa = use_sdpa + + def forward(self, q, x): + B, n, C = q.shape + q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + B, N, C = x.shape + kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + q = JF.scaled_dot_product_attention(q, k, v) + else: + xattn = (q @ k.transpose(-2, -1)) * self.scale + xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) + q = (xattn @ v) + + q = q.transpose(1, 2).reshape(B, n, C) + q = self.proj(q) + + return q + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, q, x): + y = self.xattn(q, self.norm1(x)) + q = q + y + q = q + self.mlp(self.norm2(q)) + return q diff --git a/jepa_src/models/utils/multimask.py b/jepa_src/models/utils/multimask.py new file mode 100644 index 0000000..d480086 --- /dev/null +++ b/jepa_src/models/utils/multimask.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
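+#
+# Thin wrappers that let the encoder and predictor accept a list of masks and
+# return one output per mask, calling the wrapped backbone once for each mask
+# set.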
+# + +import torch.nn as nn + + +class MultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, masks=None): + if masks is None: + return self.backbone(x) + + if (masks is not None) and not isinstance(masks, list): + masks = [masks] + outs = [] + for m in masks: + outs += [self.backbone(x, masks=m)] + return outs + + +class PredictorMultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): + if type(ctxt) is not list: + ctxt = [ctxt] + if type(tgt) is not list: + tgt = [tgt] + if type(masks_ctxt) is not list: + masks_ctxt = [masks_ctxt] + if type(masks_tgt) is not list: + masks_tgt = [masks_tgt] + + outs = [] + for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): + outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] + return outs diff --git a/jepa_src/models/utils/patch_embed.py b/jepa_src/models/utils/patch_embed.py new file mode 100644 index 0000000..4ff4de5 --- /dev/null +++ b/jepa_src/models/utils/patch_embed.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch.nn as nn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768 + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + patch_size=16, + tubelet_size=2, + in_chans=3, + embed_dim=768, + ): + super().__init__() + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + self.proj = nn.Conv3d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size, patch_size), + stride=(tubelet_size, patch_size, patch_size), + ) + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/jepa_src/models/utils/pos_embs.py b/jepa_src/models/utils/pos_embs.py new file mode 100644 index 0000000..d1d82e2 --- /dev/null +++ b/jepa_src/models/utils/pos_embs.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
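+#
+# Sine/cosine positional embeddings. get_2d_sincos_pos_embed returns a
+# (grid_size * grid_size, embed_dim) array; get_3d_sincos_pos_embed factors
+# the embedding across depth (time), height and width for video inputs.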
+# + +import numpy as np + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=False +): + """ + grid_size: int of the grid height and width + grid_depth: int of the grid depth + returns: + pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_d = np.arange(grid_depth, dtype=float) + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + + if not uniform_power: + h_embed_dim = embed_dim // 4 + w_embed_dim = embed_dim // 4 + d_embed_dim = embed_dim // 2 + else: + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + + emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) + emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) + emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) + pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) + pos_embed = pos_embed[:, :embed_dim] + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + returns: + pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) + pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + embed_dim: output dimension for each position + grid_size: int of the grid length + returns: + pos_embed: [grid_size, embed_dim] (w/o cls_token) + or [1+grid_size, embed_dim] (w/ cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + returns: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/jepa_src/models/vision_transformer.py b/jepa_src/models/vision_transformer.py new file mode 100644 index 0000000..946246e --- /dev/null +++ b/jepa_src/models/vision_transformer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from jepa_src.models.utils.modules import Block +from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from jepa_src.utils.tensors import trunc_normal_ +from jepa_src.masks.utils import apply_masks + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + out_layers=None, + uniform_power=False, + **kwargs + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + + self.input_size = img_size + self.patch_size = patch_size + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + # Tokenize pixels with convolution + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + tubelet_size=tubelet_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + + # Position embedding + self.uniform_power = uniform_power + self.pos_embed = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, embed_dim), + requires_grad=False) + + # Attention Blocks + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # ------ initialize weights + if self.pos_embed is not None: + self._init_pos_embed(self.pos_embed.data) # sincos pos-embed + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv3d): + 
trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {} + + def forward(self, x, masks=None): + """ + :param x: input image/video + :param masks: indices of patch tokens to mask (remove) + """ + + if masks is not None and not isinstance(masks, list): + masks = [masks] + + # Tokenize input + pos_embed = self.pos_embed + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: + x += pos_embed + B, N, D = x.shape + + # Mask away unwanted tokens (if masks provided) + if masks is not None: + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Fwd prop + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x, mask=masks) + if self.out_layers is not None and i in self.out_layers: + outs.append(self.norm(x)) + + if self.out_layers is not None: + return outs + + if self.norm is not None: + x = self.norm(x) + + return x + + def interpolate_pos_encoding(self, x, pos_embed): + + _, N, dim = pos_embed.shape + + if self.is_video: + + # If pos_embed already corret size, just return + _, _, T, H, W = x.shape + if H == self.input_size and W == self.input_size and T == self.num_frames: + return pos_embed + + # Convert depth, height, width of input to be measured in patches + # instead of pixels/frames + T = T // self.tubelet_size + H = H // self.patch_size + W = W // self.patch_size + + # Compute the initialized shape of the positional embedding measured + # in patches + N_t = self.num_frames // self.tubelet_size + N_h = N_w = self.input_size // self.patch_size + assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + + # Compute scale factor for spatio-temporal interpolation + scale_factor = (T/N_t, H/N_h, W/N_w) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), + scale_factor=scale_factor, + mode='trilinear') + pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) + return pos_embed + + else: + + # If pos_embed already corret size, just return + _, _, H, W = x.shape + if H == self.input_size and W == self.input_size: + return pos_embed + + # Compute scale factor for spatial interpolation + npatch = (H // self.patch_size) * (W // self.patch_size) + scale_factor = math.sqrt(npatch / N) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=scale_factor, + mode='bicubic') + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return pos_embed + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, 
num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_large(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_huge(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_giant(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_gigantic(patch_size=14, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
+    )
+    return model
+
+
+VIT_EMBED_DIMS = {
+    'vit_tiny': 192,
+    'vit_small': 384,
+    'vit_base': 768,
+    'vit_large': 1024,
+    'vit_huge': 1280,
+    'vit_giant': 1408,
+    'vit_gigantic': 1664,
+}
diff --git a/jepa_src/utils/__init__.py b/jepa_src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jepa_src/utils/distributed.py b/jepa_src/utils/distributed.py
new file mode 100644
index 0000000..cfba444
--- /dev/null
+++ b/jepa_src/utils/distributed.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import os
+
+import torch
+import torch.distributed as dist
+
+from logging import getLogger
+
+logger = getLogger()
+
+
+def init_distributed(port=37123, rank_and_world_size=(None, None)):
+
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size(), dist.get_rank()
+
+    rank, world_size = rank_and_world_size
+    os.environ['MASTER_ADDR'] = 'localhost'
+
+    if (rank is None) or (world_size is None):
+        try:
+            world_size = int(os.environ['SLURM_NTASKS'])
+            rank = int(os.environ['SLURM_PROCID'])
+            os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
+        except Exception:
+            logger.info('SLURM vars not set (distributed training not available)')
+            world_size, rank = 1, 0
+            return world_size, rank
+
+    try:
+        os.environ['MASTER_PORT'] = str(port)
+        torch.distributed.init_process_group(
+            backend='nccl',
+            world_size=world_size,
+            rank=rank
+        )
+    except Exception as e:
+        world_size, rank = 1, 0
+        logger.info(f'Rank: {rank}.
Distributed training not available {e}') + + return world_size, rank + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(outputs, x) + return torch.cat(outputs, 0) + return x + + @staticmethod + def backward(ctx, grads): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() + e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) + grads = grads.contiguous() + dist.all_reduce(grads) + return grads[s:e] + return grads + + +class AllReduceSum(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads + + +class AllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() / dist.get_world_size() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads diff --git a/jepa_src/utils/functional.py b/jepa_src/utils/functional.py new file mode 100644 index 0000000..27d1b42 --- /dev/null +++ b/jepa_src/utils/functional.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F + +def scaled_dot_product_attention(q, k, v, dropout_p=0.0): + """ + Computes scaled dot product attention. + + Args: + q (torch.Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim). + k (torch.Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim). + v (torch.Tensor): Value tensor of shape (batch_size, num_heads, seq_len_v, head_dim). + dropout_p (float, optional): Dropout probability. Default is 0.0. + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_heads, seq_len_q, head_dim). + """ + # Compute attention scores + attn_scores = torch.matmul(q, k.transpose(-2, -1)) + attn_scores = attn_scores / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)) + + # Apply softmax to attention scores + attn_probs = F.softmax(attn_scores, dim=-1) + + # Apply dropout to attention probabilities + attn_probs = F.dropout(attn_probs, p=dropout_p) + + # Compute attention output + attn_output = torch.matmul(attn_probs, v) + + return attn_output \ No newline at end of file diff --git a/jepa_src/utils/logging.py b/jepa_src/utils/logging.py new file mode 100644 index 0000000..fcdd3fa --- /dev/null +++ b/jepa_src/utils/logging.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import sys + +import torch + + +def gpu_timer(closure, log_timings=True): + """ Helper to time gpu-time to execute closure() """ + log_timings = log_timings and torch.cuda.is_available() + + elapsed_time = -1. 
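# Usage sketch (illustration only) for scaled_dot_product_attention defined in
# jepa_src/utils/functional.py above. With dropout_p=0 it should agree with
# PyTorch's built-in F.scaled_dot_product_attention (assumes torch >= 2.0).
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 16, 64)              # [batch, heads, seq_q, head_dim]
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = scaled_dot_product_attention(q, k, v, dropout_p=0.0)
assert out.shape == (2, 8, 16, 64)
assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-5)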
+ if log_timings: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + result = closure() + + if log_timings: + end.record() + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + + return result, elapsed_time + + +LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def get_logger(name=None, force=False): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + return logging.getLogger(name=name) + + +class CSVLogger(object): + + def __init__(self, fname, *argv): + self.fname = fname + self.types = [] + # -- print headers + with open(self.fname, '+a') as f: + for i, v in enumerate(argv, 1): + self.types.append(v[0]) + if i < len(argv): + print(v[1], end=',', file=f) + else: + print(v[1], end='\n', file=f) + + def log(self, *argv): + with open(self.fname, '+a') as f: + for i, tv in enumerate(zip(self.types, argv), 1): + end = ',' if i < len(argv) else '\n' + print(tv[0] % tv[1], end=end, file=f) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.max = float('-inf') + self.min = float('inf') + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + try: + self.max = max(val, self.max) + self.min = min(val, self.min) + except Exception: + pass + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def grad_logger(named_params): + stats = AverageMeter() + stats.first_layer = None + stats.last_layer = None + for n, p in named_params: + if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + grad_norm = float(torch.norm(p.grad.data)) + stats.update(grad_norm) + if 'qkv' in n: + stats.last_layer = grad_norm + if stats.first_layer is None: + stats.first_layer = grad_norm + if stats.first_layer is None or stats.last_layer is None: + stats.first_layer = stats.last_layer = 0. + return stats + + +def adamw_logger(optimizer): + """ logging magnitude of first and second momentum buffers in adamw """ + # TODO: assert that optimizer is instance of torch.optim.AdamW + state = optimizer.state_dict().get('state') + exp_avg_stats = AverageMeter() + exp_avg_sq_stats = AverageMeter() + for key in state: + s = state.get(key) + exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) + exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) + return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/jepa_src/utils/monitoring.py b/jepa_src/utils/monitoring.py new file mode 100644 index 0000000..95a7845 --- /dev/null +++ b/jepa_src/utils/monitoring.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
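# Usage sketch (illustration only) for the logging helpers above (CSVLogger,
# AverageMeter, gpu_timer); the csv path is an arbitrary placeholder.
import torch
csv = CSVLogger('/tmp/train_stats.csv', ('%d', 'epoch'), ('%.5f', 'loss'))
meter = AverageMeter()

def closure():
    # any workload; gpu_timer only records CUDA time when a GPU is available
    return torch.randn(128, 128).sum()

result, gpu_ms = gpu_timer(closure)
meter.update(float(result))
csv.log(1, meter.avg)                      # appends one row: epoch, loss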
+# + +import dataclasses +import threading +from typing import Dict, Tuple + +import psutil + + +@dataclasses.dataclass +class ResourceStatsSample: + timestamp: float + cpu_percent: float + read_count: int + write_count: int + read_bytes: int + write_bytes: int + read_chars: int + write_chars: int + cpu_times_user: float + cpu_times_system: float + cpu_times_children_user: float + cpu_times_children_system: float + cpu_times_iowait: float + cpu_affinity: str + cpu_num: int + num_threads: int + num_voluntary_ctx_switches: int + num_involuntary_ctx_switches: int + + def as_tuple(self) -> Dict: + """Return values mirroring fields.""" + return dataclasses.astuple(self) + + def fields(self) -> Tuple[dataclasses.Field, ...]: + """Return fields in this dataclass.""" + return dataclasses.fields(self.__class__) + + +class ResourceMonitoringThread(threading.Thread): + def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): + """Starts a thread to monitor pid every refresh_interval seconds. + + Passes a ResourceStatsSample object to the callback.""" + super(ResourceMonitoringThread, self).__init__() + if refresh_interval is None: + refresh_interval = 5 + self.is_running_event = threading.Event() + self.p = psutil.Process(pid) + self.refresh_interval = refresh_interval + if stats_callback_fn is None: + # Default callback + def stats_callback_fn(resource_sample: ResourceStatsSample): + print( + f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") + elif not callable(stats_callback_fn): + raise ValueError("Callback needs to be callable, got {}".format( + type(stats_callback_fn))) + self.stats_callback_fn = stats_callback_fn + + def stop(self) -> None: + self.is_running_event.set() + + def run(self) -> None: + while not self.is_running_event.is_set(): + self.sample_counters() + self.is_running_event.wait(self.refresh_interval) + + def log_sample(self, resource_sample: ResourceStatsSample) -> None: + self.stats_callback_fn(resource_sample) + + def sample_counters(self) -> None: + if not self.p.is_running(): + self.stop() + return + + with self.p.oneshot(): + cpu_percent = self.p.cpu_percent() + cpu_times = self.p.cpu_times() + io_counters = self.p.io_counters() + cpu_affinity = self.p.cpu_affinity() + cpu_num = self.p.cpu_num() + num_threads = self.p.num_threads() + num_ctx_switches = self.p.num_ctx_switches() + timestamp = time.time() + + read_count = io_counters.read_count + write_count = io_counters.write_count + read_bytes = io_counters.read_bytes + write_bytes = io_counters.write_bytes + read_chars = io_counters.read_chars + write_chars = io_counters.write_chars + + def compress_cpu_affinity(cpu_affinity): + """Change list representation to interval/range representation.""" + if not cpu_affinity: + return "" + cpu_affinity_compressed = [] + min_x = None + max_x = None + last_x = None + + # Find contiguous ranges + for x in cpu_affinity: + if last_x is None: + # Start interval + min_x = x + max_x = x + last_x = x + continue + elif x == (last_x + 1): + # Move interval up + max_x = x + elif max_x is not None: + # Interval ended, start again + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + min_x = x + max_x = x + last_x = x + # Terminate last range + if max_x is not None: + if min_x == max_x: + cpu_affinity_compressed.append("{}".format(min_x)) + else: + cpu_affinity_compressed.append( + "{}-{}".format(min_x, max_x)) + + # Concat + cpu_affinity_compressed = 
",".join(cpu_affinity_compressed) + + return cpu_affinity_compressed + + cpu_affinity = compress_cpu_affinity(cpu_affinity) + + resource_sample = ResourceStatsSample( + timestamp=timestamp, + cpu_percent=cpu_percent, + read_count=read_count, + write_count=write_count, + read_bytes=read_bytes, + write_bytes=write_bytes, + read_chars=read_chars, + write_chars=write_chars, + cpu_times_user=cpu_times.user, + cpu_times_system=cpu_times.system, + cpu_times_children_user=cpu_times.children_user, + cpu_times_children_system=cpu_times.children_system, + cpu_times_iowait=cpu_times.iowait, + cpu_affinity=cpu_affinity, + cpu_num=cpu_num, + num_threads=num_threads, + num_voluntary_ctx_switches=num_ctx_switches.voluntary, + num_involuntary_ctx_switches=num_ctx_switches.involuntary, + ) + self.log_sample(resource_sample) + + +if __name__ == "__main__": + import multiprocessing + import time + pid = multiprocessing.current_process().pid + monitor_thread = ResourceMonitoringThread(pid, 1) + monitor_thread.start() + time.sleep(5) + print("Shutdown") + monitor_thread.stop() diff --git a/jepa_src/utils/schedulers.py b/jepa_src/utils/schedulers.py new file mode 100644 index 0000000..df02e2b --- /dev/null +++ b/jepa_src/utils/schedulers.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + + +class WarmupCosineSchedule(object): + + def __init__( + self, + optimizer, + warmup_steps, + start_lr, + ref_lr, + T_max, + last_epoch=-1, + final_lr=0. + ): + self.optimizer = optimizer + self.start_lr = start_lr + self.ref_lr = ref_lr + self.final_lr = final_lr + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + self._step = 0. + + def step(self): + self._step += 1 + if self._step < self.warmup_steps: + progress = float(self._step) / float(max(1, self.warmup_steps)) + new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) + else: + # -- progress after warmup + progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) + new_lr = max(self.final_lr, + self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + return new_lr + + +class CosineWDSchedule(object): + + def __init__( + self, + optimizer, + ref_wd, + T_max, + final_wd=0. + ): + self.optimizer = optimizer + self.ref_wd = ref_wd + self.final_wd = final_wd + self.T_max = T_max + self._step = 0. + + def step(self): + self._step += 1 + progress = self._step / self.T_max + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + + if self.final_wd <= self.ref_wd: + new_wd = max(self.final_wd, new_wd) + else: + new_wd = min(self.final_wd, new_wd) + + for group in self.optimizer.param_groups: + if ('WD_exclude' not in group) or not group['WD_exclude']: + group['weight_decay'] = new_wd + return new_wd diff --git a/jepa_src/utils/tensors.py b/jepa_src/utils/tensors.py new file mode 100644 index 0000000..6ae2850 --- /dev/null +++ b/jepa_src/utils/tensors.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import math + +import torch + +from logging import getLogger + +logger = getLogger() + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches [0,N) to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x diff --git a/requirements.txt b/requirements.txt index d297071..386919b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -torch>=2 -torchvision pyyaml numpy opencv-python diff --git a/setup.py b/setup.py index 82de1e0..5865e1a 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,9 @@ import os from setuptools import setup -VERSION = "0.0.1" +VERSION = "0.0.4" + +from setuptools import setup, find_packages def get_requirements(): with open("./requirements.txt") as reqsf: @@ -17,9 +19,12 @@ def get_requirements(): if __name__ == "__main__": setup( - name="jepa", + name="vjepa_encoder", version=VERSION, description="JEPA research code.", - python_requires=">=3.9", + author="Jonathan Koch", + author_email="johnnykoch02@gmail.com", + python_requires=">=3.7", + packages=find_packages(), install_requires=get_requirements(), - ) + ) \ No newline at end of file diff --git a/vjepa_encoder.egg-info/PKG-INFO b/vjepa_encoder.egg-info/PKG-INFO new file mode 100644 index 0000000..ee525c7 --- /dev/null +++ b/vjepa_encoder.egg-info/PKG-INFO @@ -0,0 +1,11 @@ +Metadata-Version: 1.2 +Name: vjepa-encoder +Version: 0.0.4 +Summary: JEPA research code. 
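# Usage sketch (illustration only) for apply_masks above: each mask is a [B, K]
# tensor of patch indices to keep, and the outputs for all masks are
# concatenated along the batch dimension.
import torch
x = torch.randn(2, 16, 8)                  # [B, N, D]
masks = [torch.randint(0, 16, (2, 4)), torch.randint(0, 16, (2, 4))]
kept = apply_masks(x, masks)
assert kept.shape == (4, 4, 8)             # [len(masks) * B, K, D]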
+Home-page: UNKNOWN +Author: Jonathan Koch +Author-email: johnnykoch02@gmail.com +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN +Requires-Python: >=3.7 diff --git a/vjepa_encoder.egg-info/SOURCES.txt b/vjepa_encoder.egg-info/SOURCES.txt new file mode 100644 index 0000000..33d81e9 --- /dev/null +++ b/vjepa_encoder.egg-info/SOURCES.txt @@ -0,0 +1,48 @@ +LICENSE +README.md +setup.py +jepa_src/__init__.py +jepa_src/datasets/__init__.py +jepa_src/datasets/data_manager.py +jepa_src/datasets/image_dataset.py +jepa_src/datasets/video_dataset.py +jepa_src/datasets/utils/__init__.py +jepa_src/datasets/utils/weighted_sampler.py +jepa_src/datasets/utils/video/__init__.py +jepa_src/datasets/utils/video/functional.py +jepa_src/datasets/utils/video/randaugment.py +jepa_src/datasets/utils/video/randerase.py +jepa_src/datasets/utils/video/transforms.py +jepa_src/datasets/utils/video/volume_transforms.py +jepa_src/masks/__init__.py +jepa_src/masks/default.py +jepa_src/masks/multiblock3d.py +jepa_src/masks/random_tube.py +jepa_src/masks/utils.py +jepa_src/models/__init__.py +jepa_src/models/attentive_pooler.py +jepa_src/models/predictor.py +jepa_src/models/vision_transformer.py +jepa_src/models/utils/__init__.py +jepa_src/models/utils/modules.py +jepa_src/models/utils/multimask.py +jepa_src/models/utils/patch_embed.py +jepa_src/models/utils/pos_embs.py +jepa_src/utils/__init__.py +jepa_src/utils/distributed.py +jepa_src/utils/functional.py +jepa_src/utils/logging.py +jepa_src/utils/monitoring.py +jepa_src/utils/schedulers.py +jepa_src/utils/tensors.py +vjepa_encoder/__init__.py +vjepa_encoder/vision_encoder.py +vjepa_encoder.egg-info/PKG-INFO +vjepa_encoder.egg-info/SOURCES.txt +vjepa_encoder.egg-info/dependency_links.txt +vjepa_encoder.egg-info/requires.txt +vjepa_encoder.egg-info/top_level.txt +vjepa_encoder/vjepa/__init__.py +vjepa_encoder/vjepa/train.py +vjepa_encoder/vjepa/transforms.py +vjepa_encoder/vjepa/utils.py \ No newline at end of file diff --git a/vjepa_encoder.egg-info/dependency_links.txt b/vjepa_encoder.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/vjepa_encoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/vjepa_encoder.egg-info/requires.txt b/vjepa_encoder.egg-info/requires.txt new file mode 100644 index 0000000..386919b --- /dev/null +++ b/vjepa_encoder.egg-info/requires.txt @@ -0,0 +1,11 @@ +pyyaml +numpy +opencv-python +submitit +braceexpand +webdataset +timm +decord +pandas +einops +beartype diff --git a/vjepa_encoder.egg-info/top_level.txt b/vjepa_encoder.egg-info/top_level.txt new file mode 100644 index 0000000..b7a0b20 --- /dev/null +++ b/vjepa_encoder.egg-info/top_level.txt @@ -0,0 +1,2 @@ +jepa_src +vjepa_encoder diff --git a/vjepa_encoder/__init__.py b/vjepa_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vjepa_encoder/vision_encoder.py b/vjepa_encoder/vision_encoder.py new file mode 100644 index 0000000..7d74393 --- /dev/null +++ b/vjepa_encoder/vision_encoder.py @@ -0,0 +1,327 @@ +# Extension of Jepa by Robot Perception and Action Laboratory, USF +# +# Non-Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
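# Smoke-test sketch (illustration only): per top_level.txt above, installing the
# package (e.g. an editable `pip install -e .`, an assumed workflow) exposes the
# two import roots listed there.
from jepa_src.models.vision_transformer import vit_large, VIT_EMBED_DIMS
from vjepa_encoder.vision_encoder import JepaEncoder
assert VIT_EMBED_DIMS['vit_large'] == 1024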
+# + +from typing import List, Optional, Any +import multiprocessing as mp + +import pprint +import yaml +import os + +import torch + +from jepa_src.utils.distributed import init_distributed + +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple + +from vjepa_encoder.vjepa.utils import init_video_model +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +# from torch.nn.parallel import DistributedDataParallel +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import get_logger + +from vjepa_encoder.vjepa.utils import init_video_model + +import torch +from torchvision import transforms +from PIL import Image +import numpy as np + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + +import logging +from jepa_src.utils.logging import get_logger +logger = get_logger(force=True) +logger.setLevel(logging.INFO) + +class JepaEncoder(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.encoder, self.predictor = None, None + + def preprocess_image(self, input_data: Any): + """ + Preprocess the input image data. + + Args: + input_data (Any): Input data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Preprocessed image data as a tensor. + - If the input is a batch of images, the output will have shape (batch_size, channels, height, width). + - If the input is a single image, the output will have shape (1, channels, height, width). + + Raises: + ValueError: If the input type is not supported. + """ + if isinstance(input_data, str): + img = Image.open(input_data).convert('RGB') + + elif isinstance(input_data, list): + imgs = [ + self.preprocess_image(i).squeeze() for i in input_data + ] + preprocessed_input = torch.stack(imgs) + return preprocessed_input + + elif isinstance(input_data, np.ndarray): + if len(input_data.shape) == 4: + input_data = input_data.transpose(0, 3, 1, 2) + preprocessed_input = torch.from_numpy(input_data).float() + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + img = Image.fromarray(input_data.astype(np.uint8)) + + elif isinstance(input_data, Image.Image): + img = input_data + + elif isinstance(input_data, torch.Tensor): + preprocessed_input = input_data + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + preprocessed_input = preprocess(preprocessed_input) + return preprocessed_input + + else: + raise ValueError("Unsupported input type. 
Expected image path, image array, or PIL Image.") + + # Define the preprocessing transforms + preprocess = transforms.Compose([ + transforms.Resize(self.args['data']['crop_size']), + transforms.CenterCrop(self.args['data']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # Apply preprocessing transforms + preprocessed_input = preprocess(img) + + preprocessed_input = preprocessed_input.unsqueeze(0) # Add batch dimension + return preprocessed_input + + def embed_image(self, x): + """ + Generate embeddings for the input image data. + + Args: + x (Any): Input image data in various formats. + - str: Path to the image file. + - list: List of image data (numpy arrays, PIL Images, or tensors). + - numpy.ndarray: Image data as a numpy array. + - If the array has shape (batch_size, height, width, channels), it will be treated as a batch of images. + - If the array has shape (height, width, channels), it will be treated as a single image. + - PIL.Image.Image: Image data as a PIL Image object. + - torch.Tensor: Image data as a PyTorch tensor. + + Returns: + torch.Tensor: Embeddings for the input image data. + - If the input is a batch of images, the output will have shape (batch_size, num_patches, embedding_size). + - If the input is a single image, the output will have shape (1, num_patches, embedding_size). + + Notes: + - The input image data is preprocessed using the `preprocess_image` method before generating embeddings. + - If the preprocessed input has fewer than 5 dimensions, an additional dimension is added to represent the time dimension. + - The embeddings are generated using the forward pass of the model. + - The computation is performed on the available device (GPU if available, otherwise CPU). 
+ """ + x = self.preprocess_image(x) + + # Unsqueeze along the time Dimension + if len(x.shape) < 5: + x = x.unsqueeze(2) + + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + + x = x.to(device) + + with torch.no_grad(): + embeddings = self.forward(x) + + return embeddings + + def load_encoder_checkpoint( + self, + r_path, + encoder, + ): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + try: + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return encoder + + + def forward(self, clips: torch.Tensor, masks_enc: List[torch.Tensor], masks_pred: List[torch.Tensor]) -> List[torch.Tensor]: + z = self.encoder(clips, masks_enc) + h = self._forward_target(clips, masks_pred) + z = self.predictor(z, h, masks_enc, masks_pred) + return z + + def freeze_encoder(self): + for p in self.encoder.parameters(): + p.requires_grad = False + + def forward(self, x): + return self.encoder(x) + + @classmethod + def load_model(cls, config_file_path: str, device: Optional[List[str]] = None) -> "JepaEncoder": + # TODO: Fix this so it works properly + # os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) + + args = None + with open(config_file_path, 'r') as y_file: + args = yaml.load(y_file, Loader=yaml.FullLoader) + logger.info('loaded params...') + + pprint.PrettyPrinter(indent=4).pprint(args) + dump = os.path.join(args['logging']['folder'], 'params-encoder.yaml') + with open(dump, 'w') as f: + yaml.dump(args, f) + + + model = cls(args) + + world_size, rank = init_distributed() + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') + assert load_model, "Cannot load model without checkpoint file specified" + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = cfgs_data.get('patch_size') + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = 
cfgs_logging.get('write_tag') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = r_file + if not os.path.exists(load_path): + raise RuntimeError("Cannot load model. Ensure you specify the path to the model .tar file in the input config.") + + # -- Attempt to initialize model + model.encoder, model.predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + + # model.encoder = DistributedDataParallel(model.encoder, static_graph=True) + + # -- load training checkpoint + model.encoder = model.load_encoder_checkpoint( + load_path, model.encoder + ) + + return model + + diff --git a/vjepa_encoder/vjepa/__init__.py b/vjepa_encoder/vjepa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vjepa_encoder/vjepa/train.py b/vjepa_encoder/vjepa/train.py new file mode 100644 index 0000000..ccb2e75 --- /dev/null +++ b/vjepa_encoder/vjepa/train.py @@ -0,0 +1,586 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
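# Usage sketch (illustration only) for JepaEncoder.load_model / embed_image
# defined above. The yaml path and image array are placeholders; the config must
# contain the meta/mask/model/data/logging sections read by load_model and point
# at an existing checkpoint .tar file.
import numpy as np
from vjepa_encoder.vision_encoder import JepaEncoder

encoder = JepaEncoder.load_model("configs/params-encoder.yaml")   # placeholder path
frame = np.zeros((224, 224, 3), dtype=np.uint8)                   # single RGB image
emb = encoder.embed_image(frame)                                  # [1, num_patches, embed_dim]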
+# + +import os + +# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS +try: + # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE + # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE + # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE + # -- TO EACH PROCESS + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] +except Exception: + pass + +import copy +import time +import numpy as np + +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel + +from jepa_src.datasets.data_manager import init_data +from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator +from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator +from jepa_src.masks.utils import apply_masks +from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.logging import ( + CSVLogger, + gpu_timer, + get_logger, + grad_logger, + adamw_logger, + AverageMeter) +from jepa_src.utils.tensors import repeat_interleave_batch + +from app.vjepa.utils import ( + load_checkpoint, + init_video_model, + init_opt, +) +from app.vjepa.transforms import make_transforms + + +# -- +log_timings = True +log_freq = 10 +checkpoint_freq = 1 +# -- + +_GLOBAL_SEED = 0 +np.random.seed(_GLOBAL_SEED) +torch.manual_seed(_GLOBAL_SEED) +torch.backends.cudnn.benchmark = True + + +logger = get_logger(__name__) + + +def main(args, resume_preempt=False): + # ----------------------------------------------------------------------- # + # PASSED IN PARAMS FROM CONFIG FILE + # ----------------------------------------------------------------------- # + + # -- META + cfgs_meta = args.get('meta') + load_model = cfgs_meta.get('load_checkpoint') or resume_preempt + r_file = cfgs_meta.get('read_checkpoint', None) + seed = cfgs_meta.get('seed', _GLOBAL_SEED) + save_every_freq = cfgs_meta.get('save_every_freq', -1) + skip_batches = cfgs_meta.get('skip_batches', -1) + use_sdpa = cfgs_meta.get('use_sdpa', False) + which_dtype = cfgs_meta.get('dtype') + logger.info(f'{which_dtype}') + if which_dtype.lower() == 'bfloat16': + dtype = torch.bfloat16 + mixed_precision = True + elif which_dtype.lower() == 'float16': + dtype = torch.float16 + mixed_precision = True + else: + dtype = torch.float32 + mixed_precision = False + + # -- MASK + cfgs_mask = args.get('mask') + + # -- MODEL + cfgs_model = args.get('model') + model_name = cfgs_model.get('model_name') + pred_depth = cfgs_model.get('pred_depth') + pred_embed_dim = cfgs_model.get('pred_embed_dim') + uniform_power = cfgs_model.get('uniform_power', True) + use_mask_tokens = cfgs_model.get('use_mask_tokens', True) + zero_init_mask_tokens = cfgs_model.get('zero_init_mask_tokens', True) + + # -- DATA + cfgs_data = args.get('data') + dataset_type = cfgs_data.get('dataset_type', 'videodataset') + mask_type = cfgs_data.get('mask_type', 'multiblock3d') + dataset_paths = cfgs_data.get('datasets', []) + datasets_weights = cfgs_data.get('datasets_weights', None) + if datasets_weights is not None: + assert len(datasets_weights) == len(dataset_paths), 'Must have one sampling weight specified for each dataset' + batch_size = cfgs_data.get('batch_size') + num_clips = cfgs_data.get('num_clips') + num_frames = cfgs_data.get('num_frames') + tubelet_size = cfgs_data.get('tubelet_size') + sampling_rate = cfgs_data.get('sampling_rate') + duration = cfgs_data.get('clip_duration', None) + crop_size = cfgs_data.get('crop_size', 224) + patch_size = 
cfgs_data.get('patch_size') + pin_mem = cfgs_data.get('pin_mem', False) + num_workers = cfgs_data.get('num_workers', 1) + filter_short_videos = cfgs_data.get('filter_short_videos', False) + decode_one_clip = cfgs_data.get('decode_one_clip', True) + log_resource_util_data = cfgs_data.get('log_resource_utilization', False) + + # -- DATA AUGS + cfgs_data_aug = args.get('data_aug') + ar_range = cfgs_data_aug.get('random_resize_aspect_ratio', [3/4, 4/3]) + rr_scale = cfgs_data_aug.get('random_resize_scale', [0.3, 1.0]) + motion_shift = cfgs_data_aug.get('motion_shift', False) + reprob = cfgs_data_aug.get('reprob', 0.) + use_aa = cfgs_data_aug.get('auto_augment', False) + + # -- LOSS + cfgs_loss = args.get('loss') + loss_exp = cfgs_loss.get('loss_exp') + reg_coeff = cfgs_loss.get('reg_coeff') + + # -- OPTIMIZATION + cfgs_opt = args.get('optimization') + ipe = cfgs_opt.get('ipe', None) + ipe_scale = cfgs_opt.get('ipe_scale', 1.0) + clip_grad = cfgs_opt.get('clip_grad', None) + wd = float(cfgs_opt.get('weight_decay')) + final_wd = float(cfgs_opt.get('final_weight_decay')) + num_epochs = cfgs_opt.get('epochs') + warmup = cfgs_opt.get('warmup') + start_lr = cfgs_opt.get('start_lr') + lr = cfgs_opt.get('lr') + final_lr = cfgs_opt.get('final_lr') + ema = cfgs_opt.get('ema') + betas = cfgs_opt.get('betas', (0.9, 0.999)) + eps = cfgs_opt.get('eps', 1.e-8) + + # -- LOGGING + cfgs_logging = args.get('logging') + folder = cfgs_logging.get('folder') + tag = cfgs_logging.get('write_tag') + + # ----------------------------------------------------------------------- # + # ----------------------------------------------------------------------- # + + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + try: + mp.set_start_method('spawn') + except Exception: + pass + + # -- init torch distributed backend + world_size, rank = init_distributed() + logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') + + # -- set device + if not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda:0') + torch.cuda.set_device(device) + + # -- log/checkpointing paths + log_file = os.path.join(folder, f'{tag}_r{rank}.csv') + latest_file = f'{tag}-latest.pth.tar' + latest_path = os.path.join(folder, latest_file) + load_path = None + if load_model: + load_path = os.path.join(folder, r_file) if r_file is not None else latest_path + if not os.path.exists(load_path): + load_path = None + load_model = False + + # -- make csv_logger + csv_logger = CSVLogger( + log_file, + ('%d', 'epoch'), + ('%d', 'itr'), + ('%.5f', 'loss'), + ('%.5f', 'loss-jepa'), + ('%.5f', 'reg-loss'), + ('%.5f', 'enc-grad-norm'), + ('%.5f', 'pred-grad-norm'), + ('%d', 'gpu-time(ms)'), + ('%d', 'wall-time(ms)'), + ) + + # -- init model + encoder, predictor = init_video_model( + uniform_power=uniform_power, + use_mask_tokens=use_mask_tokens, + num_mask_tokens=len(cfgs_mask), + zero_init_mask_tokens=zero_init_mask_tokens, + device=device, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + model_name=model_name, + crop_size=crop_size, + pred_depth=pred_depth, + pred_embed_dim=pred_embed_dim, + use_sdpa=use_sdpa, + ) + target_encoder = copy.deepcopy(encoder) + + # -- make data transforms + if mask_type == 'multiblock3d': + logger.info('Initializing basic multi-block mask') + mask_collator = MB3DMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + else: + 
logger.info('Initializing random tube mask') + mask_collator = TubeMaskCollator( + crop_size=crop_size, + num_frames=num_frames, + patch_size=patch_size, + tubelet_size=tubelet_size, + cfgs_mask=cfgs_mask) + transform = make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=ar_range, + random_resize_scale=rr_scale, + reprob=reprob, + auto_augment=use_aa, + motion_shift=motion_shift, + crop_size=crop_size) + + # -- init data-loaders/samplers + (unsupervised_loader, + unsupervised_sampler) = init_data( + data=dataset_type, + root_path=dataset_paths, + batch_size=batch_size, + training=True, + clip_len=num_frames, + frame_sample_rate=sampling_rate, + filter_short_videos=filter_short_videos, + decode_one_clip=decode_one_clip, + duration=duration, + num_clips=num_clips, + transform=transform, + datasets_weights=datasets_weights, + collator=mask_collator, + num_workers=num_workers, + world_size=world_size, + pin_mem=pin_mem, + rank=rank, + log_dir=folder if log_resource_util_data else None) + try: + _dlen = len(unsupervised_loader) + except Exception: # Different interface for webdataset + _dlen = unsupervised_loader.num_batches + if ipe is None: + ipe = _dlen + logger.info(f'iterations per epoch/dataest length: {ipe}/{_dlen}') + + # -- init optimizer and scheduler + optimizer, scaler, scheduler, wd_scheduler = init_opt( + encoder=encoder, + predictor=predictor, + wd=wd, + final_wd=final_wd, + start_lr=start_lr, + ref_lr=lr, + final_lr=final_lr, + iterations_per_epoch=ipe, + warmup=warmup, + num_epochs=num_epochs, + ipe_scale=ipe_scale, + mixed_precision=mixed_precision, + betas=betas, + eps=eps) + encoder = DistributedDataParallel(encoder, static_graph=True) + predictor = DistributedDataParallel(predictor, static_graph=True) + target_encoder = DistributedDataParallel(target_encoder) + for p in target_encoder.parameters(): + p.requires_grad = False + + # -- momentum schedule + momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in range(int(ipe*num_epochs*ipe_scale)+1)) + + start_epoch = 0 + # -- load training checkpoint + if load_model or os.path.exists(latest_path): + ( + encoder, + predictor, + target_encoder, + optimizer, + scaler, + start_epoch, + ) = load_checkpoint( + r_path=load_path, + encoder=encoder, + predictor=predictor, + target_encoder=target_encoder, + opt=optimizer, + scaler=scaler) + for _ in range(start_epoch * ipe): + scheduler.step() + wd_scheduler.step() + next(momentum_scheduler) + mask_collator.step() + + def save_checkpoint(epoch, path): + if rank != 0: + return + save_dict = { + 'encoder': encoder.state_dict(), + 'predictor': predictor.state_dict(), + 'opt': optimizer.state_dict(), + 'scaler': None if scaler is None else scaler.state_dict(), + 'target_encoder': target_encoder.state_dict(), + 'epoch': epoch, + 'loss': loss_meter.avg, + 'batch_size': batch_size, + 'world_size': world_size, + 'lr': lr, + } + try: + torch.save(save_dict, path) + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + + logger.info('Initializing loader...') + loader = iter(unsupervised_loader) + + if skip_batches > 0: + logger.info(f'Skip {skip_batches} batches') + unsupervised_sampler.set_epoch(start_epoch) + for itr in range(skip_batches): + if itr % 10 == 0: + logger.info(f'Skip {itr}/{skip_batches} batches') + try: + udata = next(loader) + except Exception: + loader = iter(unsupervised_loader) + udata = next(loader) + + # -- TRAINING LOOP + for epoch in range(start_epoch, num_epochs): + 
logger.info('Epoch %d' % (epoch + 1)) + + # -- update distributed-data-loader epoch + unsupervised_sampler.set_epoch(epoch) + + loss_meter = AverageMeter() + input_var_meter = AverageMeter() + input_var_min_meter = AverageMeter() + jepa_loss_meter = AverageMeter() + reg_loss_meter = AverageMeter() + mask_meters = [AverageMeter() for _ in range(len(cfgs_mask))] + gpu_time_meter = AverageMeter() + wall_time_meter = AverageMeter() + + for itr in range(ipe): + itr_start_time = time.time() + + try: + udata, masks_enc, masks_pred = next(loader) + except Exception: + logger.info('Exhausted data loaders. Refreshing...') + loader = iter(unsupervised_loader) + udata, masks_enc, masks_pred = next(loader) + assert len(masks_enc) == len(masks_pred), \ + 'Currently require num encoder masks = num predictor masks' + + def load_clips(): + # -- unsupervised video clips + # Put each clip on the GPU and concatenate along batch + # dimension + clips = torch.cat([u.to(device, non_blocking=True) for u in udata[0]], dim=0) + + # Put each mask-enc/mask-pred pair on the GPU and reuse the + # same mask pair for each clip + _masks_enc, _masks_pred = [], [] + for _me, _mp in zip(masks_enc, masks_pred): + _me = _me.to(device, non_blocking=True) + _mp = _mp.to(device, non_blocking=True) + _me = repeat_interleave_batch(_me, batch_size, repeat=num_clips) + _mp = repeat_interleave_batch(_mp, batch_size, repeat=num_clips) + _masks_enc.append(_me) + _masks_pred.append(_mp) + + return (clips, _masks_enc, _masks_pred) + clips, masks_enc, masks_pred = load_clips() + + for _i, m in enumerate(mask_meters): + m.update(masks_enc[_i][0].size(-1)) + + def train_step(): + _new_lr = scheduler.step() + _new_wd = wd_scheduler.step() + # -- + + def forward_target(c): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + with torch.no_grad(): + h = target_encoder(c) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim [B, N, D] + # -- create targets (masked regions of h) + h = apply_masks(h, masks_pred, concat=False) + return h + + def forward_context(c, h): + """ + Returns list of tensors of shape [B, N, D], one for each + mask-pred. + """ + z = encoder(c, masks_enc) + z = predictor(z, h, masks_enc, masks_pred) + return z + + def loss_fn(z, h): + loss = 0. + # Compute loss and accumulate for each mask-enc/mask-pred pair + for zi, hi in zip(z, h): + loss += torch.mean(torch.abs(zi - hi)**loss_exp) / loss_exp + loss /= len(masks_pred) + return loss + + def reg_fn(z): + return sum([torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z]) / len(z) + + # Step 1. Forward + loss_jepa, loss_reg = 0., 0. + with torch.cuda.amp.autocast(dtype=dtype, enabled=mixed_precision): + h = forward_target(clips) + z = forward_context(clips, h) + loss_jepa = loss_fn(z, h) # jepa prediction loss + pstd_z = reg_fn(z) # predictor variance across patches + loss_reg += torch.mean(F.relu(1.-pstd_z)) + loss = loss_jepa + reg_coeff * loss_reg + + # Step 2. Backward & step + _enc_norm, _pred_norm = 0., 0. 
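# Standalone illustration (not tied to the loop above) of the two loss terms:
# with loss_exp=1 (an assumed config value) the JEPA term is an L1 distance
# between predictor outputs and masked target-encoder features, and the
# regularizer penalizes low per-patch variance of the predictions.
import torch
import torch.nn.functional as F
loss_exp = 1.0
z = [torch.randn(4, 32, 768)]              # predictor outputs, one entry per mask
h = [torch.randn(4, 32, 768)]              # target-encoder features, one per mask
loss_jepa = sum(torch.mean(torch.abs(zi - hi) ** loss_exp) / loss_exp
                for zi, hi in zip(z, h)) / len(h)
pstd = sum(torch.sqrt(zi.var(dim=1) + 0.0001) for zi in z) / len(z)
loss_reg = torch.mean(F.relu(1. - pstd))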
+ if mixed_precision: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + else: + loss.backward() + if (epoch > warmup) and (clip_grad is not None): + _enc_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip_grad) + _pred_norm = torch.nn.utils.clip_grad_norm_(predictor.parameters(), clip_grad) + if mixed_precision: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + grad_stats = grad_logger(encoder.named_parameters()) + grad_stats.global_norm = float(_enc_norm) + grad_stats_pred = grad_logger(predictor.named_parameters()) + grad_stats_pred.global_norm = float(_pred_norm) + optimizer.zero_grad() + optim_stats = adamw_logger(optimizer) + + # Step 3. momentum update of target encoder + m = next(momentum_scheduler) + with torch.no_grad(): + for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + + return ( + float(loss), + float(loss_jepa), + float(loss_reg), + _new_lr, + _new_wd, + grad_stats, + grad_stats_pred, + optim_stats, + ) + (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = gpu_timer(train_step) + iter_elapsed_time_ms = (time.time() - itr_start_time) * 1000. + loss_meter.update(loss) + input_var = float(AllReduce.apply(clips.view(clips.shape[0], -1).var(dim=1).mean(dim=0))) + input_var_min = float(AllReduce.apply(torch.min(clips.view(clips.shape[0], -1).var(dim=1)))) + input_var_meter.update(input_var) + input_var_min_meter.update(input_var_min) + jepa_loss_meter.update(loss_jepa) + reg_loss_meter.update(loss_reg) + gpu_time_meter.update(gpu_etime_ms) + wall_time_meter.update(iter_elapsed_time_ms) + + # -- Logging + def log_stats(): + csv_logger.log( + epoch + 1, + itr, + loss, + loss_jepa, + loss_reg, + grad_stats.global_norm, + grad_stats_pred.global_norm, + gpu_etime_ms, + iter_elapsed_time_ms) + if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): + logger.info( + '[%d, %5d] loss: %.3f | p%.3f r%.3f | ' + 'input_var: %.3f %.3f | ' + 'masks: %s ' + '[wd: %.2e] [lr: %.2e] ' + '[mem: %.2e] ' + '[gpu: %.1f ms]' + '[wall: %.1f ms]' + % (epoch + 1, itr, + loss_meter.avg, + jepa_loss_meter.avg, + reg_loss_meter.avg, + input_var_meter.avg, + input_var_min_meter.avg, + '[' + ', '.join(['%.1f' % m.avg for m in mask_meters]) + ']', + _new_wd, + _new_lr, + torch.cuda.max_memory_allocated() / 1024.0**2, + gpu_time_meter.avg, + wall_time_meter.avg)) + + if optim_stats is not None: + logger.info( + '[%d, %5d] first moment: %.2e [%.2e %.2e] second moment: %.2e [%.2e %.2e]' + % (epoch + 1, itr, + optim_stats.get('exp_avg').avg, + optim_stats.get('exp_avg').min, + optim_stats.get('exp_avg').max, + optim_stats.get('exp_avg_sq').avg, + optim_stats.get('exp_avg_sq').min, + optim_stats.get('exp_avg_sq').max)) + + if grad_stats is not None: + logger.info( + '[%d, %5d] enc_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats.first_layer, + grad_stats.last_layer, + grad_stats.min, + grad_stats.max, + grad_stats.global_norm)) + + if grad_stats_pred is not None: + logger.info( + '[%d, %5d] pred_grad_stats: f/l[%.2e %.2e] mn/mx(%.2e, %.2e) %.2e' + % (epoch + 1, itr, + grad_stats_pred.first_layer, + grad_stats_pred.last_layer, + grad_stats_pred.min, + grad_stats_pred.max, + grad_stats_pred.global_norm)) + log_stats() + assert not np.isnan(loss), 'loss is nan' + + # -- Save Checkpoint + logger.info('avg. 
loss %.3f' % loss_meter.avg) + # -- Save Last + if epoch % checkpoint_freq == 0 or epoch == (num_epochs - 1): + save_checkpoint(epoch + 1, latest_path) + if save_every_freq > 0 and epoch % save_every_freq == 0: + save_every_file = f'{tag}-e{epoch}.pth.tar' + save_every_path = os.path.join(folder, save_every_file) + save_checkpoint(epoch + 1, save_every_path) diff --git a/vjepa_encoder/vjepa/transforms.py b/vjepa_encoder/vjepa/transforms.py new file mode 100644 index 0000000..ba62555 --- /dev/null +++ b/vjepa_encoder/vjepa/transforms.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torchvision.transforms as transforms + +import jepa_src.datasets.utils.video.transforms as video_transforms +from jepa_src.datasets.utils.video.randerase import RandomErasing + + +def make_transforms( + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) +): + + _frames_augmentation = VideoTransform( + random_horizontal_flip=random_horizontal_flip, + random_resize_aspect_ratio=random_resize_aspect_ratio, + random_resize_scale=random_resize_scale, + reprob=reprob, + auto_augment=auto_augment, + motion_shift=motion_shift, + crop_size=crop_size, + normalize=normalize, + ) + return _frames_augmentation + + +class VideoTransform(object): + + def __init__( + self, + random_horizontal_flip=True, + random_resize_aspect_ratio=(3/4, 4/3), + random_resize_scale=(0.3, 1.0), + reprob=0.0, + auto_augment=False, + motion_shift=False, + crop_size=224, + normalize=((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ): + + self.random_horizontal_flip = random_horizontal_flip + self.random_resize_aspect_ratio = random_resize_aspect_ratio + self.random_resize_scale = random_resize_scale + self.auto_augment = auto_augment + self.motion_shift = motion_shift + self.crop_size = crop_size + self.mean = torch.tensor(normalize[0], dtype=torch.float32) + self.std = torch.tensor(normalize[1], dtype=torch.float32) + if not self.auto_augment: + # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. + self.mean *= 255. + self.std *= 255. 
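For context, scaling the statistics by 255 keeps normalization consistent between the two code paths: without auto-augment the frames stay in uint8 value range (0-255), while the auto-augment path converts frames to [0, 1] tensors. A quick check of the equivalence, using illustrative values that are not part of the patch:

```
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

frame_01 = torch.rand(3, 8, 8)   # frame with values in [0, 1]
frame_255 = frame_01 * 255.      # same frame in 0-255 range

a = (frame_01 - mean) / std                   # normalize in [0, 1] space
b = (frame_255 - 255. * mean) / (255. * std)  # normalize in 0-255 space with scaled statistics
print(torch.allclose(a, b, atol=1e-5))        # True: both paths yield the same normalized frame
```

The rest of the constructor below then builds the random-augment, spatial-crop, and random-erase transforms.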
+ + self.autoaug_transform = video_transforms.create_random_augment( + input_size=(crop_size, crop_size), + auto_augment='rand-m7-n4-mstd0.5-inc1', + interpolation='bicubic', + ) + + self.spatial_transform = video_transforms.random_resized_crop_with_shift \ + if motion_shift else video_transforms.random_resized_crop + + self.reprob = reprob + self.erase_transform = RandomErasing( + reprob, + mode='pixel', + max_count=1, + num_splits=1, + device='cpu', + ) + + def __call__(self, buffer): + + if self.auto_augment: + buffer = [transforms.ToPILImage()(frame) for frame in buffer] + buffer = self.autoaug_transform(buffer) + buffer = [transforms.ToTensor()(img) for img in buffer] + buffer = torch.stack(buffer) # T C H W + buffer = buffer.permute(0, 2, 3, 1) # T H W C + else: + buffer = torch.tensor(buffer, dtype=torch.float32) + + buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W + + buffer = self.spatial_transform( + images=buffer, + target_height=self.crop_size, + target_width=self.crop_size, + scale=self.random_resize_scale, + ratio=self.random_resize_aspect_ratio, + ) + if self.random_horizontal_flip: + buffer, _ = video_transforms.horizontal_flip(0.5, buffer) + + buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) + if self.reprob > 0: + buffer = buffer.permute(1, 0, 2, 3) + buffer = self.erase_transform(buffer) + buffer = buffer.permute(1, 0, 2, 3) + + return buffer + + +def tensor_normalize(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if type(mean) == list: + mean = torch.tensor(mean) + if type(std) == list: + std = torch.tensor(std) + tensor = tensor - mean + tensor = tensor / std + return tensor + + +def _tensor_normalize_inplace(tensor, mean, std): + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize (with dimensions C, T, H, W). + mean (tensor): mean value to subtract (in 0 to 255 floats). + std (tensor): std to divide (in 0 to 255 floats). + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + + C, T, H, W = tensor.shape + tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension + tensor.sub_(mean).div_(std) + tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front + return tensor diff --git a/vjepa_encoder/vjepa/utils.py b/vjepa_encoder/vjepa/utils.py new file mode 100644 index 0000000..2636ed7 --- /dev/null +++ b/vjepa_encoder/vjepa/utils.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import logging +import sys +import warnings +import yaml + + +import torch + +import jepa_src.models.vision_transformer as video_vit +import jepa_src.models.predictor as vit_pred +from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper +from jepa_src.utils.schedulers import ( + WarmupCosineSchedule, + CosineWDSchedule) +from jepa_src.utils.tensors import trunc_normal_ + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging.getLogger() + + +def load_checkpoint( + r_path, + encoder, + predictor, + target_encoder, + opt, + scaler, +): + try: + checkpoint = torch.load(r_path, map_location=torch.device('cpu')) + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + + epoch = 0 + try: + epoch = checkpoint['epoch'] + + # -- loading encoder + pretrained_dict = checkpoint['encoder'] + msg = encoder.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + + # -- loading predictor + pretrained_dict = checkpoint['predictor'] + msg = predictor.load_state_dict(pretrained_dict) + logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') + + # -- loading target_encoder + if target_encoder is not None: + print(list(checkpoint.keys())) + pretrained_dict = checkpoint['target_encoder'] + msg = target_encoder.load_state_dict(pretrained_dict) + logger.info( + f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' + ) + + # -- loading optimizer + opt.load_state_dict(checkpoint['opt']) + if scaler is not None: + scaler.load_state_dict(checkpoint['scaler']) + logger.info(f'loaded optimizers from epoch {epoch}') + logger.info(f'read-path: {r_path}') + del checkpoint + + except Exception as e: + logger.info(f'Encountered exception when loading checkpoint {e}') + epoch = 0 + + return ( + encoder, + predictor, + target_encoder, + opt, + scaler, + epoch, + ) + + +def init_video_model( + device, + patch_size=16, + num_frames=16, + tubelet_size=2, + model_name='vit_base', + crop_size=224, + pred_depth=6, + pred_embed_dim=384, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + use_sdpa=False, +): + encoder = video_vit.__dict__[model_name]( + img_size=crop_size, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + uniform_power=uniform_power, + use_sdpa=use_sdpa, + ) + encoder = MultiMaskWrapper(encoder) + predictor = vit_pred.__dict__['vit_predictor']( + img_size=crop_size, + use_mask_tokens=use_mask_tokens, + patch_size=patch_size, + num_frames=num_frames, + tubelet_size=tubelet_size, + embed_dim=encoder.backbone.embed_dim, + predictor_embed_dim=pred_embed_dim, + depth=pred_depth, + num_heads=encoder.backbone.num_heads, + uniform_power=uniform_power, + num_mask_tokens=num_mask_tokens, + zero_init_mask_tokens=zero_init_mask_tokens, + use_sdpa=use_sdpa, + ) + predictor = PredictorMultiMaskWrapper(predictor) + + def init_weights(m): + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + for m in encoder.modules(): + init_weights(m) + + for m in predictor.modules(): + init_weights(m) + + encoder.to(device) + predictor.to(device) + logger.info(encoder) + logger.info(predictor) + + def count_parameters(model): + return sum(p.numel() for p in 
model.parameters() if p.requires_grad) + + logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') + logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') + + return encoder, predictor + + +def init_opt( + encoder, + predictor, + iterations_per_epoch, + start_lr, + ref_lr, + warmup, + num_epochs, + wd=1e-6, + final_wd=1e-6, + final_lr=0.0, + mixed_precision=False, + ipe_scale=1.25, + betas=(0.9, 0.999), + eps=1e-8, + zero_init_bias_wd=True, +): + param_groups = [ + { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' not in n) and (len(p.shape) != 1)) + }, { + 'params': (p for n, p in encoder.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, { + 'params': (p for n, p in predictor.named_parameters() + if ('bias' in n) or (len(p.shape) == 1)), + 'WD_exclude': zero_init_bias_wd, + 'weight_decay': 0, + }, + ] + + logger.info('Using AdamW') + optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) + scheduler = WarmupCosineSchedule( + optimizer, + warmup_steps=int(warmup * iterations_per_epoch), + start_lr=start_lr, + ref_lr=ref_lr, + final_lr=final_lr, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + wd_scheduler = CosineWDSchedule( + optimizer, + ref_wd=wd, + final_wd=final_wd, + T_max=int(ipe_scale * num_epochs * iterations_per_epoch), + ) + scaler = torch.cuda.amp.GradScaler() if mixed_precision else None + return optimizer, scaler, scheduler, wd_scheduler From 9c2caadc77b3f1f9a509f05c2395f2b8bce5b5a8 Mon Sep 17 00:00:00 2001 From: Johnnykoch02 Date: Tue, 16 Apr 2024 21:23:07 -0400 Subject: [PATCH 3/4] Update docs --- README.md | 2 +- huggingface | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 160000 huggingface diff --git a/README.md b/README.md index 5643126..c3bc93d 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ The VJEPA Encoder is a Python package that provides an implementation of the enc To install the VJEPA Encoder package, you can use pip: ``` -pip install vjepa_encoder +pip install vjepa-encoder ``` ## Usage diff --git a/huggingface b/huggingface new file mode 160000 index 0000000..f4a9049 --- /dev/null +++ b/huggingface @@ -0,0 +1 @@ +Subproject commit f4a9049954dc6e9eb389e857ac5d9f3e52f2bdd3 From dd9c81c223efd2cf1d51c88895894d5441c8a9ef Mon Sep 17 00:00:00 2001 From: Johnnykoch02 Date: Thu, 18 Apr 2024 03:00:19 -0400 Subject: [PATCH 4/4] new jepa functionality --- build/lib/datasets/data_manager.py | 91 -- build/lib/datasets/image_dataset.py | 79 -- build/lib/datasets/utils/video/functional.py | 96 -- build/lib/datasets/utils/video/randaugment.py | 518 -------- build/lib/datasets/utils/video/randerase.py | 180 --- build/lib/datasets/utils/video/transforms.py | 1184 ----------------- .../datasets/utils/video/volume_transforms.py | 151 --- build/lib/datasets/utils/weighted_sampler.py | 97 -- build/lib/datasets/video_dataset.py | 272 ---- build/lib/jepa_src/models/utils/functional.py | 30 - build/lib/masks/default.py | 20 - build/lib/masks/multiblock3d.py | 203 --- build/lib/masks/random_tube.py | 117 -- build/lib/masks/utils.py | 23 - build/lib/models/attentive_pooler.py | 136 -- build/lib/models/predictor.py | 246 ---- build/lib/models/utils/modules.py | 185 --- build/lib/models/utils/multimask.py | 48 - build/lib/models/utils/patch_embed.py | 57 - build/lib/models/utils/pos_embs.py | 99 
-- build/lib/models/vision_transformer.py | 307 ----- build/lib/utils/distributed.py | 113 -- build/lib/utils/logging.py | 118 -- build/lib/utils/monitoring.py | 175 --- build/lib/utils/schedulers.py | 76 -- build/lib/utils/tensors.py | 71 - build/lib/vjepa_encoder/__init__.py | 6 + build/lib/vjepa_encoder/vision_encoder.py | 17 +- demo_jepa_encoder.py | 4 +- jepa_encoder.egg-info/PKG-INFO | 17 - jepa_encoder.egg-info/SOURCES.txt | 10 - jepa_encoder.egg-info/dependency_links.txt | 1 - jepa_encoder.egg-info/requires.txt | 11 - jepa_encoder.egg-info/top_level.txt | 1 - setup.py | 2 +- vjepa_encoder.egg-info/PKG-INFO | 9 +- vjepa_encoder/__init__.py | 6 + vjepa_encoder/vision_encoder.py | 15 +- 38 files changed, 45 insertions(+), 4746 deletions(-) delete mode 100644 build/lib/datasets/data_manager.py delete mode 100644 build/lib/datasets/image_dataset.py delete mode 100644 build/lib/datasets/utils/video/functional.py delete mode 100644 build/lib/datasets/utils/video/randaugment.py delete mode 100644 build/lib/datasets/utils/video/randerase.py delete mode 100644 build/lib/datasets/utils/video/transforms.py delete mode 100644 build/lib/datasets/utils/video/volume_transforms.py delete mode 100644 build/lib/datasets/utils/weighted_sampler.py delete mode 100644 build/lib/datasets/video_dataset.py delete mode 100644 build/lib/jepa_src/models/utils/functional.py delete mode 100644 build/lib/masks/default.py delete mode 100644 build/lib/masks/multiblock3d.py delete mode 100644 build/lib/masks/random_tube.py delete mode 100644 build/lib/masks/utils.py delete mode 100644 build/lib/models/attentive_pooler.py delete mode 100644 build/lib/models/predictor.py delete mode 100644 build/lib/models/utils/modules.py delete mode 100644 build/lib/models/utils/multimask.py delete mode 100644 build/lib/models/utils/patch_embed.py delete mode 100644 build/lib/models/utils/pos_embs.py delete mode 100644 build/lib/models/vision_transformer.py delete mode 100644 build/lib/utils/distributed.py delete mode 100644 build/lib/utils/logging.py delete mode 100644 build/lib/utils/monitoring.py delete mode 100644 build/lib/utils/schedulers.py delete mode 100644 build/lib/utils/tensors.py delete mode 100644 jepa_encoder.egg-info/PKG-INFO delete mode 100644 jepa_encoder.egg-info/SOURCES.txt delete mode 100644 jepa_encoder.egg-info/dependency_links.txt delete mode 100644 jepa_encoder.egg-info/requires.txt delete mode 100644 jepa_encoder.egg-info/top_level.txt diff --git a/build/lib/datasets/data_manager.py b/build/lib/datasets/data_manager.py deleted file mode 100644 index cf53940..0000000 --- a/build/lib/datasets/data_manager.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -from logging import getLogger - - -_GLOBAL_SEED = 0 -logger = getLogger() - - -def init_data( - batch_size, - transform=None, - shared_transform=None, - data='ImageNet', - collator=None, - pin_mem=True, - num_workers=8, - world_size=1, - rank=0, - root_path=None, - image_folder=None, - training=True, - copy_data=False, - drop_last=True, - tokenize_txt=True, - subset_file=None, - clip_len=8, - frame_sample_rate=2, - duration=None, - num_clips=1, - random_clip_sampling=True, - allow_clip_overlap=False, - filter_short_videos=False, - filter_long_videos=int(1e9), - decode_one_clip=True, - datasets_weights=None, - persistent_workers=False, - repeat_wds=False, - ipe=300, - log_dir=None, -): - - if (data.lower() == 'imagenet') \ - or (data.lower() == 'inat21') \ - or (data.lower() == 'places205'): - from jepa_src.datasets.image_dataset import make_imagedataset - dataset, data_loader, dist_sampler = make_imagedataset( - transform=transform, - batch_size=batch_size, - collator=collator, - pin_mem=pin_mem, - training=training, - num_workers=num_workers, - world_size=world_size, - rank=rank, - root_path=root_path, - image_folder=image_folder, - persistent_workers=persistent_workers, - copy_data=copy_data, - drop_last=drop_last, - subset_file=subset_file) - - elif data.lower() == 'videodataset': - from jepa_src.datasets.video_dataset import make_videodataset - dataset, data_loader, dist_sampler = make_videodataset( - data_paths=root_path, - batch_size=batch_size, - frames_per_clip=clip_len, - frame_step=frame_sample_rate, - duration=duration, - num_clips=num_clips, - random_clip_sampling=random_clip_sampling, - allow_clip_overlap=allow_clip_overlap, - filter_short_videos=filter_short_videos, - filter_long_videos=filter_long_videos, - shared_transform=shared_transform, - transform=transform, - datasets_weights=datasets_weights, - collator=collator, - num_workers=num_workers, - world_size=world_size, - rank=rank, - drop_last=drop_last, - log_dir=log_dir) - - return (data_loader, dist_sampler) diff --git a/build/lib/datasets/image_dataset.py b/build/lib/datasets/image_dataset.py deleted file mode 100644 index 84e9b08..0000000 --- a/build/lib/datasets/image_dataset.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import os - -from logging import getLogger - -import torch -import torchvision - -_GLOBAL_SEED = 0 -logger = getLogger() - - -class ImageFolder(torchvision.datasets.ImageFolder): - - def __init__( - self, - root, - image_folder='imagenet_full_size/061417/', - transform=None, - train=True, - ): - """ - ImageFolder - :param root: root network directory for ImageFolder data - :param image_folder: path to images inside root network directory - :param train: whether to load train data (or validation) - """ - - suffix = 'train/' if train else 'val/' - data_path = os.path.join(root, image_folder, suffix) - logger.info(f'data-path {data_path}') - super(ImageFolder, self).__init__(root=data_path, transform=transform) - logger.info('Initialized ImageFolder') - - -def make_imagedataset( - transform, - batch_size, - collator=None, - pin_mem=True, - num_workers=8, - world_size=1, - rank=0, - root_path=None, - image_folder=None, - training=True, - copy_data=False, - drop_last=True, - persistent_workers=False, - subset_file=None -): - dataset = ImageFolder( - root=root_path, - image_folder=image_folder, - transform=transform, - train=training) - logger.info('ImageFolder dataset created') - dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset=dataset, - num_replicas=world_size, - rank=rank) - data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=collator, - sampler=dist_sampler, - batch_size=batch_size, - drop_last=drop_last, - pin_memory=pin_mem, - num_workers=num_workers, - persistent_workers=persistent_workers) - logger.info('ImageFolder unsupervised data loader created') - - return dataset, data_loader, dist_sampler diff --git a/build/lib/datasets/utils/video/functional.py b/build/lib/datasets/utils/video/functional.py deleted file mode 100644 index a91d15d..0000000 --- a/build/lib/datasets/utils/video/functional.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import numbers -import cv2 -import numpy as np -import PIL -import torch - - -def _is_tensor_clip(clip): - return torch.is_tensor(clip) and clip.ndimension() == 4 - - -def crop_clip(clip, min_h, min_w, h, w): - if isinstance(clip[0], np.ndarray): - cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] - - elif isinstance(clip[0], PIL.Image.Image): - cropped = [ - img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip - ] - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - return cropped - - -def resize_clip(clip, size, interpolation='bilinear'): - if isinstance(clip[0], np.ndarray): - if isinstance(size, numbers.Number): - im_h, im_w, im_c = clip[0].shape - # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): - return clip - new_h, new_w = get_resize_sizes(im_h, im_w, size) - size = (new_w, new_h) - else: - size = size[0], size[1] - if interpolation == 'bilinear': - np_inter = cv2.INTER_LINEAR - else: - np_inter = cv2.INTER_NEAREST - scaled = [ - cv2.resize(img, size, interpolation=np_inter) for img in clip - ] - elif isinstance(clip[0], PIL.Image.Image): - if isinstance(size, numbers.Number): - im_w, im_h = clip[0].size - # Min spatial dim already matches minimal size - if (im_w <= im_h and im_w == size) or (im_h <= im_w - and im_h == size): - return clip - new_h, new_w = get_resize_sizes(im_h, im_w, size) - size = (new_w, new_h) - else: - size = size[1], size[0] - if interpolation == 'bilinear': - pil_inter = PIL.Image.BILINEAR - else: - pil_inter = PIL.Image.NEAREST - scaled = [img.resize(size, pil_inter) for img in clip] - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - return scaled - - -def get_resize_sizes(im_h, im_w, size): - if im_w < im_h: - ow = size - oh = int(size * im_h / im_w) - else: - oh = size - ow = int(size * im_w / im_h) - return oh, ow - - -def normalize(clip, mean, std, inplace=False): - if not _is_tensor_clip(clip): - raise TypeError('tensor is not a torch clip.') - - if not inplace: - clip = clip.clone() - - dtype = clip.dtype - mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) - std = torch.as_tensor(std, dtype=dtype, device=clip.device) - clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) - - return clip diff --git a/build/lib/datasets/utils/video/randaugment.py b/build/lib/datasets/utils/video/randaugment.py deleted file mode 100644 index 4c80a99..0000000 --- a/build/lib/datasets/utils/video/randaugment.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -""" -This implementation is based on -https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py -pulished under an Apache License 2.0. -""" - -import math -import numpy as np -import random -import re -import PIL -from PIL import Image, ImageEnhance, ImageOps - -_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) - -_FILL = (128, 128, 128) - -# This signifies the max integer that the controller RNN could predict for the -# augmentation scheme. 
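Concretely, each op's requested magnitude is divided by this maximum (the _MAX_LEVEL constant defined just below) and rescaled into that op's own parameter range by the *_level_to_arg helpers further down in this file. A hedged illustration of the mapping for a magnitude of 7 (the numbers are only an example):

```
MAX_LEVEL = 10.0   # mirrors _MAX_LEVEL below
magnitude = 7      # e.g. the "m7" in "rand-m7-n4-mstd0.5-inc1"

rotate_deg = (magnitude / MAX_LEVEL) * 30.0    # Rotate: up to +/-30 degrees  -> 21.0
shear = (magnitude / MAX_LEVEL) * 0.3          # ShearX/ShearY: up to +/-0.3  -> 0.21
enhance = (magnitude / MAX_LEVEL) * 1.8 + 0.1  # Color/Contrast/Brightness/Sharpness: 0.1-1.9 -> 1.36
print(rotate_deg, shear, enhance)
```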
-_MAX_LEVEL = 10.0 - -_HPARAMS_DEFAULT = { - "translate_const": 250, - "img_mean": _FILL, -} - -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) - - -def _interpolation(kwargs): - interpolation = kwargs.pop("resample", Image.BILINEAR) - if isinstance(interpolation, (list, tuple)): - return random.choice(interpolation) - else: - return interpolation - - -def _check_args_tf(kwargs): - if "fillcolor" in kwargs and _PIL_VER < (5, 0): - kwargs.pop("fillcolor") - kwargs["resample"] = _interpolation(kwargs) - - -def shear_x(img, factor, **kwargs): - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs - ) - - -def shear_y(img, factor, **kwargs): - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs - ) - - -def translate_x_rel(img, pct, **kwargs): - pixels = pct * img.size[0] - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) - - -def translate_y_rel(img, pct, **kwargs): - pixels = pct * img.size[1] - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) - - -def translate_x_abs(img, pixels, **kwargs): - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs - ) - - -def translate_y_abs(img, pixels, **kwargs): - _check_args_tf(kwargs) - return img.transform( - img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs - ) - - -def rotate(img, degrees, **kwargs): - _check_args_tf(kwargs) - if _PIL_VER >= (5, 2): - return img.rotate(degrees, **kwargs) - elif _PIL_VER >= (5, 0): - w, h = img.size - post_trans = (0, 0) - rotn_center = (w / 2.0, h / 2.0) - angle = -math.radians(degrees) - matrix = [ - round(math.cos(angle), 15), - round(math.sin(angle), 15), - 0.0, - round(-math.sin(angle), 15), - round(math.cos(angle), 15), - 0.0, - ] - - def transform(x, y, matrix): - (a, b, c, d, e, f) = matrix - return a * x + b * y + c, d * x + e * y + f - - matrix[2], matrix[5] = transform( - -rotn_center[0] - post_trans[0], - -rotn_center[1] - post_trans[1], - matrix, - ) - matrix[2] += rotn_center[0] - matrix[5] += rotn_center[1] - return img.transform(img.size, Image.AFFINE, matrix, **kwargs) - else: - return img.rotate(degrees, resample=kwargs["resample"]) - - -def auto_contrast(img, **__): - return ImageOps.autocontrast(img) - - -def invert(img, **__): - return ImageOps.invert(img) - - -def equalize(img, **__): - return ImageOps.equalize(img) - - -def solarize(img, thresh, **__): - return ImageOps.solarize(img, thresh) - - -def solarize_add(img, add, thresh=128, **__): - lut = [] - for i in range(256): - if i < thresh: - lut.append(min(255, i + add)) - else: - lut.append(i) - if img.mode in ("L", "RGB"): - if img.mode == "RGB" and len(lut) == 256: - lut = lut + lut + lut - return img.point(lut) - else: - return img - - -def posterize(img, bits_to_keep, **__): - if bits_to_keep >= 8: - return img - return ImageOps.posterize(img, bits_to_keep) - - -def contrast(img, factor, **__): - return ImageEnhance.Contrast(img).enhance(factor) - - -def color(img, factor, **__): - return ImageEnhance.Color(img).enhance(factor) - - -def brightness(img, factor, **__): - return ImageEnhance.Brightness(img).enhance(factor) - - -def sharpness(img, factor, **__): - return ImageEnhance.Sharpness(img).enhance(factor) - - -def _randomly_negate(v): - """With 50% prob, negate the value""" - return -v if random.random() > 0.5 else v - - -def 
_rotate_level_to_arg(level, _hparams): - # range [-30, 30] - level = (level / _MAX_LEVEL) * 30.0 - level = _randomly_negate(level) - return (level,) - - -def _enhance_level_to_arg(level, _hparams): - # range [0.1, 1.9] - return ((level / _MAX_LEVEL) * 1.8 + 0.1,) - - -def _enhance_increasing_level_to_arg(level, _hparams): - # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend - # range [0.1, 1.9] - level = (level / _MAX_LEVEL) * 0.9 - level = 1.0 + _randomly_negate(level) - return (level,) - - -def _shear_level_to_arg(level, _hparams): - # range [-0.3, 0.3] - level = (level / _MAX_LEVEL) * 0.3 - level = _randomly_negate(level) - return (level,) - - -def _translate_abs_level_to_arg(level, hparams): - translate_const = hparams["translate_const"] - level = (level / _MAX_LEVEL) * float(translate_const) - level = _randomly_negate(level) - return (level,) - - -def _translate_rel_level_to_arg(level, hparams): - # default range [-0.45, 0.45] - translate_pct = hparams.get("translate_pct", 0.45) - level = (level / _MAX_LEVEL) * translate_pct - level = _randomly_negate(level) - return (level,) - - -def _posterize_level_to_arg(level, _hparams): - # As per Tensorflow TPU EfficientNet impl - # range [0, 4], 'keep 0 up to 4 MSB of original image' - # intensity/severity of augmentation decreases with level - return (int((level / _MAX_LEVEL) * 4),) - - -def _posterize_increasing_level_to_arg(level, hparams): - # As per Tensorflow models research and UDA impl - # range [4, 0], 'keep 4 down to 0 MSB of original image', - # intensity/severity of augmentation increases with level - return (4 - _posterize_level_to_arg(level, hparams)[0],) - - -def _posterize_original_level_to_arg(level, _hparams): - # As per original AutoAugment paper description - # range [4, 8], 'keep 4 up to 8 MSB of image' - # intensity/severity of augmentation decreases with level - return (int((level / _MAX_LEVEL) * 4) + 4,) - - -def _solarize_level_to_arg(level, _hparams): - # range [0, 256] - # intensity/severity of augmentation decreases with level - return (int((level / _MAX_LEVEL) * 256),) - - -def _solarize_increasing_level_to_arg(level, _hparams): - # range [0, 256] - # intensity/severity of augmentation increases with level - return (256 - _solarize_level_to_arg(level, _hparams)[0],) - - -def _solarize_add_level_to_arg(level, _hparams): - # range [0, 110] - return (int((level / _MAX_LEVEL) * 110),) - - -LEVEL_TO_ARG = { - "AutoContrast": None, - "Equalize": None, - "Invert": None, - "Rotate": _rotate_level_to_arg, - # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers - "Posterize": _posterize_level_to_arg, - "PosterizeIncreasing": _posterize_increasing_level_to_arg, - "PosterizeOriginal": _posterize_original_level_to_arg, - "Solarize": _solarize_level_to_arg, - "SolarizeIncreasing": _solarize_increasing_level_to_arg, - "SolarizeAdd": _solarize_add_level_to_arg, - "Color": _enhance_level_to_arg, - "ColorIncreasing": _enhance_increasing_level_to_arg, - "Contrast": _enhance_level_to_arg, - "ContrastIncreasing": _enhance_increasing_level_to_arg, - "Brightness": _enhance_level_to_arg, - "BrightnessIncreasing": _enhance_increasing_level_to_arg, - "Sharpness": _enhance_level_to_arg, - "SharpnessIncreasing": _enhance_increasing_level_to_arg, - "ShearX": _shear_level_to_arg, - "ShearY": _shear_level_to_arg, - "TranslateX": _translate_abs_level_to_arg, - "TranslateY": _translate_abs_level_to_arg, - "TranslateXRel": 
_translate_rel_level_to_arg, - "TranslateYRel": _translate_rel_level_to_arg, -} - - -NAME_TO_OP = { - "AutoContrast": auto_contrast, - "Equalize": equalize, - "Invert": invert, - "Rotate": rotate, - "Posterize": posterize, - "PosterizeIncreasing": posterize, - "PosterizeOriginal": posterize, - "Solarize": solarize, - "SolarizeIncreasing": solarize, - "SolarizeAdd": solarize_add, - "Color": color, - "ColorIncreasing": color, - "Contrast": contrast, - "ContrastIncreasing": contrast, - "Brightness": brightness, - "BrightnessIncreasing": brightness, - "Sharpness": sharpness, - "SharpnessIncreasing": sharpness, - "ShearX": shear_x, - "ShearY": shear_y, - "TranslateX": translate_x_abs, - "TranslateY": translate_y_abs, - "TranslateXRel": translate_x_rel, - "TranslateYRel": translate_y_rel, -} - - -class AugmentOp: - """ - Apply for video. - """ - - def __init__(self, name, prob=0.5, magnitude=10, hparams=None): - hparams = hparams or _HPARAMS_DEFAULT - self.aug_fn = NAME_TO_OP[name] - self.level_fn = LEVEL_TO_ARG[name] - self.prob = prob - self.magnitude = magnitude - self.hparams = hparams.copy() - self.kwargs = { - "fillcolor": hparams["img_mean"] - if "img_mean" in hparams - else _FILL, - "resample": hparams["interpolation"] - if "interpolation" in hparams - else _RANDOM_INTERPOLATION, - } - - # If magnitude_std is > 0, we introduce some randomness - # in the usually fixed policy and sample magnitude from a normal distribution - # with mean `magnitude` and std-dev of `magnitude_std`. - # NOTE This is my own hack, being tested, not in papers or reference impls. - self.magnitude_std = self.hparams.get("magnitude_std", 0) - - def __call__(self, img_list): - if self.prob < 1.0 and random.random() > self.prob: - return img_list - magnitude = self.magnitude - if self.magnitude_std and self.magnitude_std > 0: - magnitude = random.gauss(magnitude, self.magnitude_std) - magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range - level_args = ( - self.level_fn(magnitude, self.hparams) - if self.level_fn is not None - else () - ) - - if isinstance(img_list, list): - return [ - self.aug_fn(img, *level_args, **self.kwargs) for img in img_list - ] - else: - return self.aug_fn(img_list, *level_args, **self.kwargs) - - -_RAND_TRANSFORMS = [ - "AutoContrast", - "Equalize", - "Invert", - "Rotate", - "Posterize", - "Solarize", - "SolarizeAdd", - "Color", - "Contrast", - "Brightness", - "Sharpness", - "ShearX", - "ShearY", - "TranslateXRel", - "TranslateYRel", -] - - -_RAND_INCREASING_TRANSFORMS = [ - "AutoContrast", - "Equalize", - "Invert", - "Rotate", - "PosterizeIncreasing", - "SolarizeIncreasing", - "SolarizeAdd", - "ColorIncreasing", - "ContrastIncreasing", - "BrightnessIncreasing", - "SharpnessIncreasing", - "ShearX", - "ShearY", - "TranslateXRel", - "TranslateYRel", -] - - -# These experimental weights are based loosely on the relative improvements mentioned in paper. -# They may not result in increased performance, but could likely be tuned to so. 
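For context, the weight table defined just below is normalized into a probability vector by _select_rand_weights and handed to RandAugment, which draws num_layers ops per clip via np.random.choice (see those helpers later in this file). A hedged sketch of that selection step with made-up weights:

```
import numpy as np

ops = ["Rotate", "ShearX", "ShearY", "AutoContrast"]  # illustrative subset of the op names
weights = np.array([0.3, 0.2, 0.2, 0.025])
probs = weights / weights.sum()                       # normalize into a probability vector

# pick 2 ops for this clip; sampling is without replacement when weights are supplied
chosen = np.random.choice(ops, 2, replace=False, p=probs)
print(chosen)
```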
-_RAND_CHOICE_WEIGHTS_0 = { - "Rotate": 0.3, - "ShearX": 0.2, - "ShearY": 0.2, - "TranslateXRel": 0.1, - "TranslateYRel": 0.1, - "Color": 0.025, - "Sharpness": 0.025, - "AutoContrast": 0.025, - "Solarize": 0.005, - "SolarizeAdd": 0.005, - "Contrast": 0.005, - "Brightness": 0.005, - "Equalize": 0.005, - "Posterize": 0, - "Invert": 0, -} - - -def _select_rand_weights(weight_idx=0, transforms=None): - transforms = transforms or _RAND_TRANSFORMS - assert weight_idx == 0 # only one set of weights currently - rand_weights = _RAND_CHOICE_WEIGHTS_0 - probs = [rand_weights[k] for k in transforms] - probs /= np.sum(probs) - return probs - - -def rand_augment_ops(magnitude=10, hparams=None, transforms=None): - hparams = hparams or _HPARAMS_DEFAULT - transforms = transforms or _RAND_TRANSFORMS - return [ - AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) - for name in transforms - ] - - -class RandAugment: - def __init__(self, ops, num_layers=2, choice_weights=None): - self.ops = ops - self.num_layers = num_layers - self.choice_weights = choice_weights - - def __call__(self, img): - # no replacement when using weighted choice - ops = np.random.choice( - self.ops, - self.num_layers, - replace=self.choice_weights is None, - p=self.choice_weights, - ) - for op in ops: - img = op(img) - return img - - -def rand_augment_transform(config_str, hparams): - """ - RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 - - Create a RandAugment transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude of rand augment - 'n' - integer num layers (number of transform ops selected per image) - 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) - 'mstd' - float std deviation of magnitude noise applied - 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) - Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 - 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 - :param hparams: Other hparams (kwargs) for the RandAugmentation scheme - :return: A PyTorch compatible Transform - """ - magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) - num_layers = 2 # default to 2 ops per image - weight_idx = None # default to no probability weights for op choice - transforms = _RAND_TRANSFORMS - config = config_str.split("-") - assert config[0] == "rand" - config = config[1:] - for c in config: - cs = re.split(r"(\d.*)", c) - if len(cs) < 2: - continue - key, val = cs[:2] - if key == "mstd": - # noise param injected via hparams for now - hparams.setdefault("magnitude_std", float(val)) - elif key == "inc": - if bool(val): - transforms = _RAND_INCREASING_TRANSFORMS - elif key == "m": - magnitude = int(val) - elif key == "n": - num_layers = int(val) - elif key == "w": - weight_idx = int(val) - else: - assert NotImplementedError - ra_ops = rand_augment_ops( - magnitude=magnitude, hparams=hparams, transforms=transforms - ) - choice_weights = ( - None if weight_idx is None else _select_rand_weights(weight_idx) - ) - return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/build/lib/datasets/utils/video/randerase.py 
b/build/lib/datasets/utils/video/randerase.py deleted file mode 100644 index d1f185c..0000000 --- a/build/lib/datasets/utils/video/randerase.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -""" -This implementation is based on -https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py -pulished under an Apache License 2.0. -""" -import math -import random -import torch - - -def _get_pixels( - per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" -): - # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() - # paths, flip the order so normal is run on CPU if this becomes a problem - # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 - if per_pixel: - return torch.empty(patch_size, dtype=dtype, device=device).normal_() - elif rand_color: - return torch.empty( - (patch_size[0], 1, 1), dtype=dtype, device=device - ).normal_() - else: - return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) - - -class RandomErasing: - """Randomly selects a rectangle region in an image and erases its pixels. - 'Random Erasing Data Augmentation' by Zhong et al. - See https://arxiv.org/pdf/1708.04896.pdf - This variant of RandomErasing is intended to be applied to either a batch - or single image tensor after it has been normalized by dataset mean and std. - Args: - probability: Probability that the Random Erasing operation will be performed. - min_area: Minimum percentage of erased area wrt input image area. - max_area: Maximum percentage of erased area wrt input image area. - min_aspect: Minimum aspect ratio of erased area. - mode: pixel color mode, one of 'const', 'rand', or 'pixel' - 'const' - erase block is constant color of 0 for all channels - 'rand' - erase block is same per-channel random (normal) color - 'pixel' - erase block is per-pixel random (normal) color - max_count: maximum number of erasing blocks per image, area per box is scaled by count. - per-image count is randomly chosen between 1 and this value. 
- """ - - def __init__( - self, - probability=0.5, - min_area=0.02, - max_area=1 / 3, - min_aspect=0.3, - max_aspect=None, - mode="const", - min_count=1, - max_count=None, - num_splits=0, - device="cuda", - cube=True, - ): - self.probability = probability - self.min_area = min_area - self.max_area = max_area - max_aspect = max_aspect or 1 / min_aspect - self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) - self.min_count = min_count - self.max_count = max_count or min_count - self.num_splits = num_splits - mode = mode.lower() - self.rand_color = False - self.per_pixel = False - self.cube = cube - if mode == "rand": - self.rand_color = True # per block random normal - elif mode == "pixel": - self.per_pixel = True # per pixel random normal - else: - assert not mode or mode == "const" - self.device = device - - def _erase(self, img, chan, img_h, img_w, dtype): - if random.random() > self.probability: - return - area = img_h * img_w - count = ( - self.min_count - if self.min_count == self.max_count - else random.randint(self.min_count, self.max_count) - ) - for _ in range(count): - for _ in range(10): - target_area = ( - random.uniform(self.min_area, self.max_area) * area / count - ) - aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, - self.rand_color, - (chan, h, w), - dtype=dtype, - device=self.device, - ) - break - - def _erase_cube( - self, - img, - batch_start, - batch_size, - chan, - img_h, - img_w, - dtype, - ): - if random.random() > self.probability: - return - area = img_h * img_w - count = ( - self.min_count - if self.min_count == self.max_count - else random.randint(self.min_count, self.max_count) - ) - for _ in range(count): - for _ in range(100): - target_area = ( - random.uniform(self.min_area, self.max_area) * area / count - ) - aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - for i in range(batch_start, batch_size): - img_instance = img[i] - img_instance[ - :, top:top + h, left:left + w - ] = _get_pixels( - self.per_pixel, - self.rand_color, - (chan, h, w), - dtype=dtype, - device=self.device, - ) - break - - def __call__(self, input): - if len(input.size()) == 3: - self._erase(input, *input.size(), input.dtype) - else: - batch_size, chan, img_h, img_w = input.size() - # skip first slice of batch if num_splits is set (for clean portion of samples) - batch_start = ( - batch_size // self.num_splits if self.num_splits > 1 else 0 - ) - if self.cube: - self._erase_cube( - input, - batch_start, - batch_size, - chan, - img_h, - img_w, - input.dtype, - ) - else: - for i in range(batch_start, batch_size): - self._erase(input[i], chan, img_h, img_w, input.dtype) - return input diff --git a/build/lib/datasets/utils/video/transforms.py b/build/lib/datasets/utils/video/transforms.py deleted file mode 100644 index 979985d..0000000 --- a/build/lib/datasets/utils/video/transforms.py +++ /dev/null @@ -1,1184 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math -import numpy as np -import random -import numbers -import PIL -from PIL import Image - -import torch -import torchvision -import torchvision.transforms.functional as F -from torchvision import transforms - -import jepa_src.datasets.utils.video.functional as FF -from jepa_src.datasets.utils.video.randaugment import rand_augment_transform - - -_pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', -} - - -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) - - -def _pil_interp(method): - if method == 'bicubic': - return Image.BICUBIC - elif method == 'lanczos': - return Image.LANCZOS - elif method == 'hamming': - return Image.HAMMING - else: - return Image.BILINEAR - - -def random_short_side_scale_jitter( - images, min_size, max_size, boxes=None, inverse_uniform_sampling=False -): - """ - Perform a spatial short scale jittering on the given images and - corresponding boxes. - Args: - images (tensor): images to perform scale jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - min_size (int): the minimal size to scale the frames. - max_size (int): the maximal size to scale the frames. - boxes (ndarray): optional. Corresponding boxes to images. - Dimension is `num boxes` x 4. - inverse_uniform_sampling (bool): if True, sample uniformly in - [1 / max_scale, 1 / min_scale] and take a reciprocal to get the - scale. If False, take a uniform sample from [min_scale, max_scale]. - Returns: - (tensor): the scaled images with dimension of - `num frames` x `channel` x `new height` x `new width`. - (ndarray or None): the scaled boxes with dimension of - `num boxes` x 4. - """ - if inverse_uniform_sampling: - size = int( - round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) - ) - else: - size = int(round(np.random.uniform(min_size, max_size))) - - height = images.shape[2] - width = images.shape[3] - if (width <= height and width == size) or ( - height <= width and height == size - ): - return images, boxes - new_width = size - new_height = size - if width < height: - new_height = int(math.floor((float(height) / width) * size)) - if boxes is not None: - boxes = boxes * float(new_height) / height - else: - new_width = int(math.floor((float(width) / height) * size)) - if boxes is not None: - boxes = boxes * float(new_width) / width - - return ( - torch.nn.functional.interpolate( - images, - size=(new_height, new_width), - mode='bilinear', - align_corners=False, - ), - boxes, - ) - - -def crop_boxes(boxes, x_offset, y_offset): - """ - Peform crop on the bounding boxes given the offsets. - Args: - boxes (ndarray or None): bounding boxes to peform crop. The dimension - is `num boxes` x 4. - x_offset (int): cropping offset in the x axis. - y_offset (int): cropping offset in the y axis. - Returns: - cropped_boxes (ndarray or None): the cropped boxes with dimension of - `num boxes` x 4. - """ - cropped_boxes = boxes.copy() - cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset - cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset - - return cropped_boxes - - -def random_crop(images, size, boxes=None): - """ - Perform random spatial crop on the given images and corresponding boxes. - Args: - images (tensor): images to perform random crop. 
The dimension is - `num frames` x `channel` x `height` x `width`. - size (int): the size of height and width to crop on the image. - boxes (ndarray or None): optional. Corresponding boxes to images. - Dimension is `num boxes` x 4. - Returns: - cropped (tensor): cropped images with dimension of - `num frames` x `channel` x `size` x `size`. - cropped_boxes (ndarray or None): the cropped boxes with dimension of - `num boxes` x 4. - """ - if images.shape[2] == size and images.shape[3] == size: - return images - height = images.shape[2] - width = images.shape[3] - y_offset = 0 - if height > size: - y_offset = int(np.random.randint(0, height - size)) - x_offset = 0 - if width > size: - x_offset = int(np.random.randint(0, width - size)) - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] - - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) - - return cropped, cropped_boxes - - -def horizontal_flip(prob, images, boxes=None): - """ - Perform horizontal flip on the given images and corresponding boxes. - Args: - prob (float): probility to flip the images. - images (tensor): images to perform horizontal flip, the dimension is - `num frames` x `channel` x `height` x `width`. - boxes (ndarray or None): optional. Corresponding boxes to images. - Dimension is `num boxes` x 4. - Returns: - images (tensor): images with dimension of - `num frames` x `channel` x `height` x `width`. - flipped_boxes (ndarray or None): the flipped boxes with dimension of - `num boxes` x 4. - """ - if boxes is None: - flipped_boxes = None - else: - flipped_boxes = boxes.copy() - - if np.random.uniform() < prob: - images = images.flip((-1)) - - if len(images.shape) == 3: - width = images.shape[2] - elif len(images.shape) == 4: - width = images.shape[3] - else: - raise NotImplementedError("Dimension does not supported") - if boxes is not None: - flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 - - return images, flipped_boxes - - -def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): - """ - Perform uniform spatial sampling on the images and corresponding boxes. - Args: - images (tensor): images to perform uniform crop. The dimension is - `num frames` x `channel` x `height` x `width`. - size (int): size of height and weight to crop the images. - spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width - is larger than height. Or 0, 1, or 2 for top, center, and bottom - crop if height is larger than width. - boxes (ndarray or None): optional. Corresponding boxes to images. - Dimension is `num boxes` x 4. - scale_size (int): optinal. If not None, resize the images to scale_size before - performing any crop. - Returns: - cropped (tensor): images with dimension of - `num frames` x `channel` x `size` x `size`. - cropped_boxes (ndarray or None): the cropped boxes with dimension of - `num boxes` x 4. 
- """ - assert spatial_idx in [0, 1, 2] - ndim = len(images.shape) - if ndim == 3: - images = images.unsqueeze(0) - height = images.shape[2] - width = images.shape[3] - - if scale_size is not None: - if width <= height: - width, height = scale_size, int(height / width * scale_size) - else: - width, height = int(width / height * scale_size), scale_size - images = torch.nn.functional.interpolate( - images, - size=(height, width), - mode='bilinear', - align_corners=False, - ) - - y_offset = int(math.ceil((height - size) / 2)) - x_offset = int(math.ceil((width - size) / 2)) - - if height > width: - if spatial_idx == 0: - y_offset = 0 - elif spatial_idx == 2: - y_offset = height - size - else: - if spatial_idx == 0: - x_offset = 0 - elif spatial_idx == 2: - x_offset = width - size - cropped = images[ - :, :, y_offset:y_offset + size, x_offset:x_offset + size - ] - cropped_boxes = ( - crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None - ) - if ndim == 3: - cropped = cropped.squeeze(0) - return cropped, cropped_boxes - - -def clip_boxes_to_image(boxes, height, width): - """ - Clip an array of boxes to an image with the given height and width. - Args: - boxes (ndarray): bounding boxes to perform clipping. - Dimension is `num boxes` x 4. - height (int): given image height. - width (int): given image width. - Returns: - clipped_boxes (ndarray): the clipped boxes with dimension of - `num boxes` x 4. - """ - clipped_boxes = boxes.copy() - clipped_boxes[:, [0, 2]] = np.minimum( - width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) - ) - clipped_boxes[:, [1, 3]] = np.minimum( - height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) - ) - return clipped_boxes - - -def blend(images1, images2, alpha): - """ - Blend two images with a given weight alpha. - Args: - images1 (tensor): the first images to be blended, the dimension is - `num frames` x `channel` x `height` x `width`. - images2 (tensor): the second images to be blended, the dimension is - `num frames` x `channel` x `height` x `width`. - alpha (float): the blending weight. - Returns: - (tensor): blended images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - return images1 * alpha + images2 * (1 - alpha) - - -def grayscale(images): - """ - Get the grayscale for the input images. The channels of images should be - in order BGR. - Args: - images (tensor): the input images for getting grayscale. Dimension is - `num frames` x `channel` x `height` x `width`. - Returns: - img_gray (tensor): blended images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - # R -> 0.299, G -> 0.587, B -> 0.114. - img_gray = torch.tensor(images) - gray_channel = ( - 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] - ) - img_gray[:, 0] = gray_channel - img_gray[:, 1] = gray_channel - img_gray[:, 2] = gray_channel - return img_gray - - -def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): - """ - Perfrom a color jittering on the input images. The channels of images - should be in order BGR. - Args: - images (tensor): images to perform color jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - img_brightness (float): jitter ratio for brightness. - img_contrast (float): jitter ratio for contrast. - img_saturation (float): jitter ratio for saturation. - Returns: - images (tensor): the jittered images, the dimension is - `num frames` x `channel` x `height` x `width`. 
- """ - - jitter = [] - if img_brightness != 0: - jitter.append('brightness') - if img_contrast != 0: - jitter.append('contrast') - if img_saturation != 0: - jitter.append('saturation') - - if len(jitter) > 0: - order = np.random.permutation(np.arange(len(jitter))) - for idx in range(0, len(jitter)): - if jitter[order[idx]] == 'brightness': - images = brightness_jitter(img_brightness, images) - elif jitter[order[idx]] == 'contrast': - images = contrast_jitter(img_contrast, images) - elif jitter[order[idx]] == 'saturation': - images = saturation_jitter(img_saturation, images) - return images - - -def brightness_jitter(var, images): - """ - Perfrom brightness jittering on the input images. The channels of images - should be in order BGR. - Args: - var (float): jitter ratio for brightness. - images (tensor): images to perform color jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - Returns: - images (tensor): the jittered images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - alpha = 1.0 + np.random.uniform(-var, var) - - img_bright = torch.zeros(images.shape) - images = blend(images, img_bright, alpha) - return images - - -def contrast_jitter(var, images): - """ - Perfrom contrast jittering on the input images. The channels of images - should be in order BGR. - Args: - var (float): jitter ratio for contrast. - images (tensor): images to perform color jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - Returns: - images (tensor): the jittered images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - alpha = 1.0 + np.random.uniform(-var, var) - - img_gray = grayscale(images) - img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) - images = blend(images, img_gray, alpha) - return images - - -def saturation_jitter(var, images): - """ - Perfrom saturation jittering on the input images. The channels of images - should be in order BGR. - Args: - var (float): jitter ratio for saturation. - images (tensor): images to perform color jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - Returns: - images (tensor): the jittered images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - alpha = 1.0 + np.random.uniform(-var, var) - img_gray = grayscale(images) - images = blend(images, img_gray, alpha) - - return images - - -def lighting_jitter(images, alphastd, eigval, eigvec): - """ - Perform AlexNet-style PCA jitter on the given images. - Args: - images (tensor): images to perform lighting jitter. Dimension is - `num frames` x `channel` x `height` x `width`. - alphastd (float): jitter ratio for PCA jitter. - eigval (list): eigenvalues for PCA jitter. - eigvec (list[list]): eigenvectors for PCA jitter. - Returns: - out_images (tensor): the jittered images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - if alphastd == 0: - return images - # generate alpha1, alpha2, alpha3. 
- alpha = np.random.normal(0, alphastd, size=(1, 3)) - eig_vec = np.array(eigvec) - eig_val = np.reshape(eigval, (1, 3)) - rgb = np.sum( - eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), - axis=1, - ) - out_images = torch.zeros_like(images) - if len(images.shape) == 3: - # C H W - channel_dim = 0 - elif len(images.shape) == 4: - # T C H W - channel_dim = 1 - else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') - - for idx in range(images.shape[channel_dim]): - # C H W - if len(images.shape) == 3: - out_images[idx] = images[idx] + rgb[2 - idx] - # T C H W - elif len(images.shape) == 4: - out_images[:, idx] = images[:, idx] + rgb[2 - idx] - else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) - - return out_images - - -def color_normalization(images, mean, stddev): - """ - Perform color nomration on the given images. - Args: - images (tensor): images to perform color normalization. Dimension is - `num frames` x `channel` x `height` x `width`. - mean (list): mean values for normalization. - stddev (list): standard deviations for normalization. - - Returns: - out_images (tensor): the noramlized images, the dimension is - `num frames` x `channel` x `height` x `width`. - """ - if len(images.shape) == 3: - assert ( - len(mean) == images.shape[0] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[0] - ), 'channel stddev not computed properly' - elif len(images.shape) == 4: - assert ( - len(mean) == images.shape[1] - ), 'channel mean not computed properly' - assert ( - len(stddev) == images.shape[1] - ), 'channel stddev not computed properly' - else: - raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') - - out_images = torch.zeros_like(images) - for idx in range(len(mean)): - # C H W - if len(images.shape) == 3: - out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] - elif len(images.shape) == 4: - out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] - else: - raise NotImplementedError( - f'Unsupported dimension {len(images.shape)}' - ) - return out_images - - -def _get_param_spatial_crop( - scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False -): - """ - Given scale, ratio, height and width, return sampled coordinates of the videos. - """ - for _ in range(num_repeat): - area = height * width - target_area = random.uniform(*scale) * area - if log_scale: - log_ratio = (math.log(ratio[0]), math.log(ratio[1])) - aspect_ratio = math.exp(random.uniform(*log_ratio)) - else: - aspect_ratio = random.uniform(*ratio) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if np.random.uniform() < 0.5 and switch_hw: - w, h = h, w - - if 0 < w <= width and 0 < h <= height: - i = random.randint(0, height - h) - j = random.randint(0, width - w) - return i, j, h, w - - # Fallback to central crop - in_ratio = float(width) / float(height) - if in_ratio < min(ratio): - w = width - h = int(round(w / min(ratio))) - elif in_ratio > max(ratio): - h = height - w = int(round(h * max(ratio))) - else: # whole image - w = width - h = height - i = (height - h) // 2 - j = (width - w) // 2 - return i, j, h, w - - -def random_resized_crop( - images, - target_height, - target_width, - scale=(0.8, 1.0), - ratio=(3.0 / 4.0, 4.0 / 3.0), -): - """ - Crop the given images to random size and aspect ratio. 
A crop of random - size (default: of 0.08 to 1.0) of the original size and a random aspect - ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This - crop is finally resized to given size. This is popularly used to train the - Inception networks. - - Args: - images: Images to perform resizing and cropping. - target_height: Desired height after cropping. - target_width: Desired width after cropping. - scale: Scale range of Inception-style area based random resizing. - ratio: Aspect ratio range of Inception-style area based random resizing. - """ - - height = images.shape[2] - width = images.shape[3] - - i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) - cropped = images[:, :, i:i + h, j:j + w] - return torch.nn.functional.interpolate( - cropped, - size=(target_height, target_width), - mode='bilinear', - align_corners=False, - ) - - -def random_resized_crop_with_shift( - images, - target_height, - target_width, - scale=(0.8, 1.0), - ratio=(3.0 / 4.0, 4.0 / 3.0), -): - """ - This is similar to random_resized_crop. However, it samples two different - boxes (for cropping) for the first and last frame. It then linearly - interpolates the two boxes for other frames. - - Args: - images: Images to perform resizing and cropping. - target_height: Desired height after cropping. - target_width: Desired width after cropping. - scale: Scale range of Inception-style area based random resizing. - ratio: Aspect ratio range of Inception-style area based random resizing. - """ - t = images.shape[1] - height = images.shape[2] - width = images.shape[3] - - i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) - i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) - i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] - j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] - h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] - w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] - out = torch.zeros((3, t, target_height, target_width)) - for ind in range(t): - out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( - images[ - :, - ind:ind + 1, - i_s[ind]:i_s[ind] + h_s[ind], - j_s[ind]:j_s[ind] + w_s[ind], - ], - size=(target_height, target_width), - mode='bilinear', - align_corners=False, - ) - return out - - -def create_random_augment( - input_size, - auto_augment=None, - interpolation='bilinear', -): - """ - Get video randaug transform. - - Args: - input_size: The size of the input video in tuple. - auto_augment: Parameters for randaug. An example: - "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number - of operations to apply). - interpolation: Interpolation method. - """ - if isinstance(input_size, tuple): - img_size = input_size[-2:] - else: - img_size = input_size - - if auto_augment: - assert isinstance(auto_augment, str) - if isinstance(img_size, tuple): - img_size_min = min(img_size) - else: - img_size_min = img_size - aa_params = {'translate_const': int(img_size_min * 0.45)} - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - return transforms.Compose( - [rand_augment_transform(auto_augment, aa_params)] - ) - raise NotImplementedError - - -def random_sized_crop_img( - im, - size, - jitter_scale=(0.08, 1.0), - jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), - max_iter=10, -): - """ - Performs Inception-style cropping (used for training). 
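create_random_augment above only builds the transform; the policy string follows the timm RandAugment format quoted in its docstring. A small construction sketch (policy and interpolation values are illustrative):

# Build the video RandAugment transform; 'm7' is the magnitude, 'n4' the
# number of ops applied per sample (the example string from the docstring).
randaug = create_random_augment(
    input_size=(224, 224),
    auto_augment='rand-m7-n4-mstd0.5-inc1',
    interpolation='bilinear',
)
# randaug is a transforms.Compose wrapping rand_augment_transform.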
- """ - assert ( - len(im.shape) == 3 - ), 'Currently only support image for random_sized_crop' - h, w = im.shape[1:3] - i, j, h, w = _get_param_spatial_crop( - scale=jitter_scale, - ratio=jitter_aspect, - height=h, - width=w, - num_repeat=max_iter, - log_scale=False, - switch_hw=True, - ) - cropped = im[:, i:i + h, j:j + w] - return torch.nn.functional.interpolate( - cropped.unsqueeze(0), - size=(size, size), - mode='bilinear', - align_corners=False, - ).squeeze(0) - - -# The following code are modified based on timm lib, we will replace the following -# contents with dependency from PyTorchVideo. -# https://github.com/facebookresearch/pytorchvideo -class RandomResizedCropAndInterpolation: - """Crop the given PIL Image to random size and aspect ratio with random interpolation. - A crop of random size (default: of 0.08 to 1.0) of the original size and a random - aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop - is finally resized to given size. - This is popularly used to train the Inception networks. - Args: - size: expected output size of each edge - scale: range of size of the origin size cropped - ratio: range of aspect ratio of the origin aspect ratio cropped - interpolation: Default: PIL.Image.BILINEAR - """ - - def __init__( - self, - size, - scale=(0.08, 1.0), - ratio=(3.0 / 4.0, 4.0 / 3.0), - interpolation='bilinear', - ): - if isinstance(size, tuple): - self.size = size - else: - self.size = (size, size) - if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): - print('range should be of kind (min, max)') - - if interpolation == 'random': - self.interpolation = _RANDOM_INTERPOLATION - else: - self.interpolation = _pil_interp(interpolation) - self.scale = scale - self.ratio = ratio - - @staticmethod - def get_params(img, scale, ratio): - """Get parameters for ``crop`` for a random sized crop. - Args: - img (PIL Image): Image to be cropped. - scale (tuple): range of size of the origin size cropped - ratio (tuple): range of aspect ratio of the origin aspect ratio cropped - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for a random - sized crop. - """ - area = img.size[0] * img.size[1] - - for _ in range(10): - target_area = random.uniform(*scale) * area - log_ratio = (math.log(ratio[0]), math.log(ratio[1])) - aspect_ratio = math.exp(random.uniform(*log_ratio)) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if w <= img.size[0] and h <= img.size[1]: - i = random.randint(0, img.size[1] - h) - j = random.randint(0, img.size[0] - w) - return i, j, h, w - - # Fallback to central crop - in_ratio = img.size[0] / img.size[1] - if in_ratio < min(ratio): - w = img.size[0] - h = int(round(w / min(ratio))) - elif in_ratio > max(ratio): - h = img.size[1] - w = int(round(h * max(ratio))) - else: # whole image - w = img.size[0] - h = img.size[1] - i = (img.size[1] - h) // 2 - j = (img.size[0] - w) // 2 - return i, j, h, w - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be cropped and resized. - Returns: - PIL Image: Randomly cropped and resized image. 
- """ - i, j, h, w = self.get_params(img, self.scale, self.ratio) - if isinstance(self.interpolation, (tuple, list)): - interpolation = random.choice(self.interpolation) - else: - interpolation = self.interpolation - return F.resized_crop(img, i, j, h, w, self.size, interpolation) - - def __repr__(self): - if isinstance(self.interpolation, (tuple, list)): - interpolate_str = ' '.join( - [_pil_interpolation_to_str[x] for x in self.interpolation] - ) - else: - interpolate_str = _pil_interpolation_to_str[self.interpolation] - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format( - tuple(round(s, 4) for s in self.scale) - ) - format_string += ', ratio={0}'.format( - tuple(round(r, 4) for r in self.ratio) - ) - format_string += ', interpolation={0})'.format(interpolate_str) - return format_string - - -class Compose(object): - """Composes several transforms - Args: - transforms (list of ``Transform`` objects): list of transforms - to compose - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, clip): - for t in self.transforms: - clip = t(clip) - return clip - - -class RandomHorizontalFlip(object): - """Horizontally flip the list of given images randomly - with a probability 0.5 - """ - - def __call__(self, clip): - """ - Args: - img (PIL.Image or numpy.ndarray): List of images to be cropped - in format (h, w, c) in numpy.ndarray - Returns: - PIL.Image or numpy.ndarray: Randomly flipped clip - """ - if random.random() < 0.5: - if isinstance(clip[0], np.ndarray): - return [np.fliplr(img) for img in clip] - elif isinstance(clip[0], PIL.Image.Image): - return [ - img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip - ] - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - ' but got list of {0}'.format(type(clip[0]))) - return clip - - -class RandomResize(object): - """Resizes a list of (H x W x C) numpy.ndarray to the final size - The larger the original image is, the more times it takes to - interpolate - Args: - interpolation (str): Can be one of 'nearest', 'bilinear' - defaults to nearest - size (tuple): (widht, height) - """ - - def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): - self.ratio = ratio - self.interpolation = interpolation - - def __call__(self, clip): - scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) - - if isinstance(clip[0], np.ndarray): - im_h, im_w, im_c = clip[0].shape - elif isinstance(clip[0], PIL.Image.Image): - im_w, im_h = clip[0].size - - new_w = int(im_w * scaling_factor) - new_h = int(im_h * scaling_factor) - new_size = (new_w, new_h) - resized = FF.resize_clip( - clip, new_size, interpolation=self.interpolation) - return resized - - -class Resize(object): - """Resizes a list of (H x W x C) numpy.ndarray to the final size - The larger the original image is, the more times it takes to - interpolate - Args: - interpolation (str): Can be one of 'nearest', 'bilinear' - defaults to nearest - size (tuple): (widht, height) - """ - - def __init__(self, size, interpolation='nearest'): - self.size = size - self.interpolation = interpolation - - def __call__(self, clip): - resized = FF.resize_clip( - clip, self.size, interpolation=self.interpolation) - return resized - - -class RandomCrop(object): - """Extract random crop at the same location for a list of images - Args: - size (sequence or int): Desired output size for the - crop in format (h, w) - """ - - def __init__(self, size): - if isinstance(size, numbers.Number): - size = (size, size) - - self.size = size - - def __call__(self, clip): - """ - Args: - img (PIL.Image or numpy.ndarray): List of images to be cropped - in format (h, w, c) in numpy.ndarray - Returns: - PIL.Image or numpy.ndarray: Cropped list of images - """ - h, w = self.size - if isinstance(clip[0], np.ndarray): - im_h, im_w, im_c = clip[0].shape - elif isinstance(clip[0], PIL.Image.Image): - im_w, im_h = clip[0].size - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - if w > im_w or h > im_h: - error_msg = ( - 'Initial image size should be larger then ' - 'cropped size but got cropped sizes : ({w}, {h}) while ' - 'initial image is ({im_w}, {im_h})'.format( - im_w=im_w, im_h=im_h, w=w, h=h)) - raise ValueError(error_msg) - - x1 = random.randint(0, im_w - w) - y1 = random.randint(0, im_h - h) - cropped = FF.crop_clip(clip, y1, x1, h, w) - - return cropped - - -class ThreeCrop(object): - """Extract random crop at the same location for a list of images - Args: - size (sequence or int): Desired output size for the - crop in format (h, w) - """ - - def __init__(self, size): - if isinstance(size, numbers.Number): - size = (size, size) - - self.size = size - - def __call__(self, clip): - """ - Args: - img (PIL.Image or numpy.ndarray): List of images to be cropped - in format (h, w, c) in numpy.ndarray - Returns: - PIL.Image or numpy.ndarray: Cropped list of images - """ - h, w = self.size - if isinstance(clip[0], np.ndarray): - im_h, im_w, im_c = clip[0].shape - elif isinstance(clip[0], PIL.Image.Image): - im_w, im_h = clip[0].size - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - if w != im_w and h != im_h: - clip = FF.resize_clip(clip, self.size, interpolation="bilinear") - im_h, im_w, im_c = clip[0].shape - - step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) - cropped = [] - for i in range(3): - if (im_h > self.size[0]): - x1 = 0 - y1 = i * step - cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) - else: - x1 = i * step - y1 = 0 - cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) - return cropped - - -class RandomRotation(object): - """Rotate entire 
clip randomly by a random angle within - given bounds - Args: - degrees (sequence or int): Range of degrees to select from - If degrees is a number instead of sequence like (min, max), - the range of degrees, will be (-degrees, +degrees). - """ - - def __init__(self, degrees): - if isinstance(degrees, numbers.Number): - if degrees < 0: - raise ValueError('If degrees is a single number,' - 'must be positive') - degrees = (-degrees, degrees) - else: - if len(degrees) != 2: - raise ValueError('If degrees is a sequence,' - 'it must be of len 2.') - - self.degrees = degrees - - def __call__(self, clip): - """ - Args: - img (PIL.Image or numpy.ndarray): List of images to be cropped - in format (h, w, c) in numpy.ndarray - Returns: - PIL.Image or numpy.ndarray: Cropped list of images - """ - import skimage - angle = random.uniform(self.degrees[0], self.degrees[1]) - if isinstance(clip[0], np.ndarray): - rotated = [skimage.transform.rotate(img, angle) for img in clip] - elif isinstance(clip[0], PIL.Image.Image): - rotated = [img.rotate(angle) for img in clip] - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - - return rotated - - -class CenterCrop(object): - """Extract center crop at the same location for a list of images - Args: - size (sequence or int): Desired output size for the - crop in format (h, w) - """ - - def __init__(self, size): - if isinstance(size, numbers.Number): - size = (size, size) - - self.size = size - - def __call__(self, clip): - """ - Args: - img (PIL.Image or numpy.ndarray): List of images to be cropped - in format (h, w, c) in numpy.ndarray - Returns: - PIL.Image or numpy.ndarray: Cropped list of images - """ - h, w = self.size - if isinstance(clip[0], np.ndarray): - im_h, im_w, im_c = clip[0].shape - elif isinstance(clip[0], PIL.Image.Image): - im_w, im_h = clip[0].size - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - if w > im_w or h > im_h: - error_msg = ( - 'Initial image size should be larger then ' - 'cropped size but got cropped sizes : ({w}, {h}) while ' - 'initial image is ({im_w}, {im_h})'.format( - im_w=im_w, im_h=im_h, w=w, h=h)) - raise ValueError(error_msg) - - x1 = int(round((im_w - w) / 2.)) - y1 = int(round((im_h - h) / 2.)) - cropped = FF.crop_clip(clip, y1, x1, h, w) - - return cropped - - -class ColorJitter(object): - """ - Randomly change the brightness, contrast and saturation and hue of the clip - - Args: - brightness (float): How much to jitter brightness. brightness_factor - is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. - contrast (float): How much to jitter contrast. contrast_factor - is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. - saturation (float): How much to jitter saturation. saturation_factor - is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. - hue(float): How much to jitter hue. hue_factor is chosen uniformly from - [-hue, hue]. Should be >=0 and <= 0.5. 
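The clip-level transforms above (Compose, Resize, RandomCrop, RandomHorizontalFlip, ...) operate on a list of frames rather than a single image; a minimal pipeline sketch, with illustrative sizes:

import numpy as np

# A clip is a list of (H, W, C) numpy frames (PIL Images also work).
clip = [np.zeros((256, 320, 3), dtype=np.uint8) for _ in range(16)]

pipeline = Compose([
    Resize((256, 256), interpolation='bilinear'),
    RandomCrop(224),
    RandomHorizontalFlip(),
])
out = pipeline(clip)              # still a list of 16 frames, now 224 x 224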
- """ - - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - self.brightness = brightness - self.contrast = contrast - self.saturation = saturation - self.hue = hue - - def get_params(self, brightness, contrast, saturation, hue): - if brightness > 0: - brightness_factor = random.uniform( - max(0, 1 - brightness), 1 + brightness) - else: - brightness_factor = None - - if contrast > 0: - contrast_factor = random.uniform( - max(0, 1 - contrast), 1 + contrast) - else: - contrast_factor = None - - if saturation > 0: - saturation_factor = random.uniform( - max(0, 1 - saturation), 1 + saturation) - else: - saturation_factor = None - - if hue > 0: - hue_factor = random.uniform(-hue, hue) - else: - hue_factor = None - return brightness_factor, contrast_factor, saturation_factor, hue_factor - - def __call__(self, clip): - """ - Args: - clip (list): list of PIL.Image - Returns: - list PIL.Image : list of transformed PIL.Image - """ - if isinstance(clip[0], np.ndarray): - raise TypeError( - 'Color jitter not yet implemented for numpy arrays') - elif isinstance(clip[0], PIL.Image.Image): - brightness, contrast, saturation, hue = self.get_params( - self.brightness, self.contrast, self.saturation, self.hue) - - # Create img transform function sequence - img_transforms = [] - if brightness is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) - if saturation is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) - if hue is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) - if contrast is not None: - img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) - random.shuffle(img_transforms) - - # Apply to all images - jittered_clip = [] - for img in clip: - for func in img_transforms: - jittered_img = func(img) - jittered_clip.append(jittered_img) - - else: - raise TypeError('Expected numpy.ndarray or PIL.Image' + - 'but got list of {0}'.format(type(clip[0]))) - return jittered_clip - - -class Normalize(object): - """Normalize a clip with mean and standard deviation. - Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform - will normalize each channel of the input ``torch.*Tensor`` i.e. - ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` - .. note:: - This transform acts out of place, i.e., it does not mutates the input tensor. - Args: - mean (sequence): Sequence of means for each channel. - std (sequence): Sequence of standard deviations for each channel. - """ - - def __init__(self, mean, std): - self.mean = mean - self.std = std - - def __call__(self, clip): - """ - Args: - clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. - Returns: - Tensor: Normalized Tensor clip. - """ - return FF.normalize(clip, self.mean, self.std) - - def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/build/lib/datasets/utils/video/volume_transforms.py b/build/lib/datasets/utils/video/volume_transforms.py deleted file mode 100644 index 0a01bb3..0000000 --- a/build/lib/datasets/utils/video/volume_transforms.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import numpy as np -from PIL import Image - -import torch - - -def convert_img(img): - """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" - if len(img.shape) == 3: - img = img.transpose(2, 0, 1) - if len(img.shape) == 2: - img = np.expand_dims(img, 0) - return img - - -class ClipToTensor(object): - """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] - to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] - """ - - def __init__(self, channel_nb=3, div_255=True, numpy=False): - self.channel_nb = channel_nb - self.div_255 = div_255 - self.numpy = numpy - - def __call__(self, clip): - """ - Args: clip (list of numpy.ndarray): clip (list of images) - to be converted to tensor. - """ - # Retrieve shape - if isinstance(clip[0], np.ndarray): - h, w, ch = clip[0].shape - assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) - elif isinstance(clip[0], Image.Image): - w, h = clip[0].size - else: - raise TypeError( - "Expected numpy.ndarray or PIL.Image\ - but got list of {0}".format( - type(clip[0]) - ) - ) - - np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) - - # Convert - for img_idx, img in enumerate(clip): - if isinstance(img, np.ndarray): - pass - elif isinstance(img, Image.Image): - img = np.array(img, copy=False) - else: - raise TypeError( - "Expected numpy.ndarray or PIL.Image\ - but got list of {0}".format( - type(clip[0]) - ) - ) - img = convert_img(img) - np_clip[:, img_idx, :, :] = img - if self.numpy: - if self.div_255: - np_clip = np_clip / 255.0 - return np_clip - - else: - tensor_clip = torch.from_numpy(np_clip) - - if not isinstance(tensor_clip, torch.FloatTensor): - tensor_clip = tensor_clip.float() - if self.div_255: - tensor_clip = torch.div(tensor_clip, 255) - return tensor_clip - - -# Note this norms data to -1/1 -class ClipToTensor_K(object): - """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] - to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] - """ - - def __init__(self, channel_nb=3, div_255=True, numpy=False): - self.channel_nb = channel_nb - self.div_255 = div_255 - self.numpy = numpy - - def __call__(self, clip): - """ - Args: clip (list of numpy.ndarray): clip (list of images) - to be converted to tensor. 
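ClipToTensor above converts a list of T frames of shape (H, W, C) in [0, 255] into one float tensor of shape (C, T, H, W); with div_255=True the values land in [0, 1], while the ClipToTensor_K variant maps them to [-1, 1]. A quick usage sketch:

import numpy as np

frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(16)]
to_tensor = ClipToTensor(channel_nb=3, div_255=True)
clip = to_tensor(frames)
assert tuple(clip.shape) == (3, 16, 224, 224)   # C x T x H x W
assert 0.0 <= clip.min().item() <= clip.max().item() <= 1.0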
- """ - # Retrieve shape - if isinstance(clip[0], np.ndarray): - h, w, ch = clip[0].shape - assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) - elif isinstance(clip[0], Image.Image): - w, h = clip[0].size - else: - raise TypeError( - "Expected numpy.ndarray or PIL.Image\ - but got list of {0}".format( - type(clip[0]) - ) - ) - - np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) - - # Convert - for img_idx, img in enumerate(clip): - if isinstance(img, np.ndarray): - pass - elif isinstance(img, Image.Image): - img = np.array(img, copy=False) - else: - raise TypeError( - "Expected numpy.ndarray or PIL.Image\ - but got list of {0}".format( - type(clip[0]) - ) - ) - img = convert_img(img) - np_clip[:, img_idx, :, :] = img - if self.numpy: - if self.div_255: - np_clip = (np_clip - 127.5) / 127.5 - return np_clip - - else: - tensor_clip = torch.from_numpy(np_clip) - - if not isinstance(tensor_clip, torch.FloatTensor): - tensor_clip = tensor_clip.float() - if self.div_255: - tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) - return tensor_clip - - -class ToTensor(object): - """Converts numpy array to tensor""" - - def __call__(self, array): - tensor = torch.from_numpy(array) - return tensor diff --git a/build/lib/datasets/utils/weighted_sampler.py b/build/lib/datasets/utils/weighted_sampler.py deleted file mode 100644 index fd40825..0000000 --- a/build/lib/datasets/utils/weighted_sampler.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -from typing import Iterator, Optional -from operator import itemgetter -import numpy as np - -import torch -from torch.utils.data import ( - Dataset, - Sampler, - DistributedSampler, - WeightedRandomSampler -) - - -class DatasetFromSampler(Dataset): - - def __init__(self, sampler: Sampler): - self.sampler = sampler - self.sampler_list = None - - def __getitem__(self, index: int): - if self.sampler_list is None: - self.sampler_list = list(self.sampler) - return self.sampler_list[index] - - def __len__(self) -> int: - return len(self.sampler) - - -class DistributedSamplerWrapper(DistributedSampler): - """ Convert any Pytorch Sampler to a DistributedSampler """ - - def __init__( - self, - sampler, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - ): - super(DistributedSamplerWrapper, self).__init__( - DatasetFromSampler(sampler), - num_replicas=num_replicas, - rank=rank, - shuffle=shuffle, - ) - self.sampler = sampler - - def __iter__(self) -> Iterator[int]: - self.dataset = DatasetFromSampler(self.sampler) - indexes_of_indexes = super().__iter__() - subsampler_indexes = self.dataset - return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) - - -class CustomWeightedRandomSampler(WeightedRandomSampler): - """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __iter__(self): - rand_tensor = np.random.choice( - range(0, len(self.weights)), - size=self.num_samples, - p=self.weights.numpy() / torch.sum(self.weights).numpy(), - replace=self.replacement - ) - rand_tensor = torch.from_numpy(rand_tensor) - return iter(rand_tensor.tolist()) - - -class DistributedWeightedSampler(DistributedSamplerWrapper): - - def __init__( - self, - weights, - num_replicas: Optional[int] = None, 
- rank: Optional[int] = None, - shuffle: bool = True, - ): - weighted_sampler = CustomWeightedRandomSampler( - weights=weights, - num_samples=len(weights), - replacement=False) - - super(DistributedWeightedSampler, self).__init__( - sampler=weighted_sampler, - num_replicas=num_replicas, - rank=rank, - shuffle=shuffle, - ) diff --git a/build/lib/datasets/video_dataset.py b/build/lib/datasets/video_dataset.py deleted file mode 100644 index 82cee52..0000000 --- a/build/lib/datasets/video_dataset.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import os -import pathlib -import warnings - -from logging import getLogger - -import numpy as np -import pandas as pd - -from decord import VideoReader, cpu - -import torch - -from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler - -_GLOBAL_SEED = 0 -logger = getLogger() - - -def make_videodataset( - data_paths, - batch_size, - frames_per_clip=8, - frame_step=4, - num_clips=1, - random_clip_sampling=True, - allow_clip_overlap=False, - filter_short_videos=False, - filter_long_videos=int(10**9), - transform=None, - shared_transform=None, - rank=0, - world_size=1, - datasets_weights=None, - collator=None, - drop_last=True, - num_workers=10, - pin_mem=True, - duration=None, - log_dir=None, -): - dataset = VideoDataset( - data_paths=data_paths, - datasets_weights=datasets_weights, - frames_per_clip=frames_per_clip, - frame_step=frame_step, - num_clips=num_clips, - random_clip_sampling=random_clip_sampling, - allow_clip_overlap=allow_clip_overlap, - filter_short_videos=filter_short_videos, - filter_long_videos=filter_long_videos, - duration=duration, - shared_transform=shared_transform, - transform=transform) - - logger.info('VideoDataset dataset created') - if datasets_weights is not None: - dist_sampler = DistributedWeightedSampler( - dataset.sample_weights, - num_replicas=world_size, - rank=rank, - shuffle=True) - else: - dist_sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=True) - - data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=collator, - sampler=dist_sampler, - batch_size=batch_size, - drop_last=drop_last, - pin_memory=pin_mem, - num_workers=num_workers, - persistent_workers=num_workers > 0) - logger.info('VideoDataset unsupervised data loader created') - - return dataset, data_loader, dist_sampler - - -class VideoDataset(torch.utils.data.Dataset): - """ Video classification dataset. 
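make_videodataset above wires the VideoDataset, an (optionally weighted) distributed sampler, and a DataLoader together; a minimal single-process invocation sketch (the csv path and loader settings are placeholders):

dataset, loader, sampler = make_videodataset(
    data_paths=['/path/to/train_paths.csv'],  # hypothetical: "<video_path> <label>" per line, space-delimited
    batch_size=8,
    frames_per_clip=16,
    frame_step=4,
    num_clips=1,
    rank=0,
    world_size=1,
    collator=None,                            # falls back to the DataLoader default collation
    num_workers=4,
)
sampler.set_epoch(0)                          # standard DistributedSampler epoch seeding
for clips, labels, clip_indices in loader:
    break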
""" - - def __init__( - self, - data_paths, - datasets_weights=None, - frames_per_clip=16, - frame_step=4, - num_clips=1, - transform=None, - shared_transform=None, - random_clip_sampling=True, - allow_clip_overlap=False, - filter_short_videos=False, - filter_long_videos=int(10**9), - duration=None, # duration in seconds - ): - self.data_paths = data_paths - self.datasets_weights = datasets_weights - self.frames_per_clip = frames_per_clip - self.frame_step = frame_step - self.num_clips = num_clips - self.transform = transform - self.shared_transform = shared_transform - self.random_clip_sampling = random_clip_sampling - self.allow_clip_overlap = allow_clip_overlap - self.filter_short_videos = filter_short_videos - self.filter_long_videos = filter_long_videos - self.duration = duration - - if VideoReader is None: - raise ImportError('Unable to import "decord" which is required to read videos.') - - # Load video paths and labels - samples, labels = [], [] - self.num_samples_per_dataset = [] - for data_path in self.data_paths: - - if data_path[-4:] == '.csv': - data = pd.read_csv(data_path, header=None, delimiter=" ") - samples += list(data.values[:, 0]) - labels += list(data.values[:, 1]) - num_samples = len(data) - self.num_samples_per_dataset.append(num_samples) - - elif data_path[-4:] == '.npy': - data = np.load(data_path, allow_pickle=True) - data = list(map(lambda x: repr(x)[1:-1], data)) - samples += data - labels += [0] * len(data) - num_samples = len(data) - self.num_samples_per_dataset.append(len(data)) - - # [Optional] Weights for each sample to be used by downstream - # weighted video sampler - self.sample_weights = None - if self.datasets_weights is not None: - self.sample_weights = [] - for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): - self.sample_weights += [dw / ns] * ns - - self.samples = samples - self.labels = labels - - def __getitem__(self, index): - sample = self.samples[index] - - # Keep trying to load videos until you find a valid sample - loaded_video = False - while not loaded_video: - buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] - loaded_video = len(buffer) > 0 - if not loaded_video: - index = np.random.randint(self.__len__()) - sample = self.samples[index] - - # Label/annotations for video - label = self.labels[index] - - def split_into_clips(video): - """ Split video into a list of clips """ - fpc = self.frames_per_clip - nc = self.num_clips - return [video[i*fpc:(i+1)*fpc] for i in range(nc)] - - # Parse video into frames & apply data augmentations - if self.shared_transform is not None: - buffer = self.shared_transform(buffer) - buffer = split_into_clips(buffer) - if self.transform is not None: - buffer = [self.transform(clip) for clip in buffer] - - return buffer, label, clip_indices - - def loadvideo_decord(self, sample): - """ Load video content using Decord """ - - fname = sample - if not os.path.exists(fname): - warnings.warn(f'video path not found {fname}') - return [], None - - _fsize = os.path.getsize(fname) - if _fsize < 1 * 1024: # avoid hanging issue - warnings.warn(f'video too short {fname}') - return [], None - if _fsize > self.filter_long_videos: - warnings.warn(f'skipping long video of size {_fsize} (bytes)') - return [], None - - try: - vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) - except Exception: - return [], None - - fpc = self.frames_per_clip - fstp = self.frame_step - if self.duration is not None: - try: - fps = vr.get_avg_fps() - fstp = int(self.duration * fps / fpc) - except Exception as 
e: - warnings.warn(e) - clip_len = int(fpc * fstp) - - if self.filter_short_videos and len(vr) < clip_len: - warnings.warn(f'skipping video of length {len(vr)}') - return [], None - - vr.seek(0) # Go to start of video before sampling frames - - # Partition video into equal sized segments and sample each clip - # from a different segment - partition_len = len(vr) // self.num_clips - - all_indices, clip_indices = [], [] - for i in range(self.num_clips): - - if partition_len > clip_len: - # If partition_len > clip len, then sample a random window of - # clip_len frames within the segment - end_indx = clip_len - if self.random_clip_sampling: - end_indx = np.random.randint(clip_len, partition_len) - start_indx = end_indx - clip_len - indices = np.linspace(start_indx, end_indx, num=fpc) - indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) - # -- - indices = indices + i * partition_len - else: - # If partition overlap not allowed and partition_len < clip_len - # then repeatedly append the last frame in the segment until - # we reach the desired clip length - if not self.allow_clip_overlap: - indices = np.linspace(0, partition_len, num=partition_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) - indices = np.clip(indices, 0, partition_len-1).astype(np.int64) - # -- - indices = indices + i * partition_len - - # If partition overlap is allowed and partition_len < clip_len - # then start_indx of segment i+1 will lie within segment i - else: - sample_len = min(clip_len, len(vr)) - 1 - indices = np.linspace(0, sample_len, num=sample_len // fstp) - indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) - indices = np.clip(indices, 0, sample_len-1).astype(np.int64) - # -- - clip_step = 0 - if len(vr) > clip_len: - clip_step = (len(vr) - clip_len) // (self.num_clips - 1) - indices = indices + i * clip_step - - clip_indices.append(indices) - all_indices.extend(list(indices)) - - buffer = vr.get_batch(all_indices).asnumpy() - return buffer, clip_indices - - def __len__(self): - return len(self.samples) diff --git a/build/lib/jepa_src/models/utils/functional.py b/build/lib/jepa_src/models/utils/functional.py deleted file mode 100644 index 27d1b42..0000000 --- a/build/lib/jepa_src/models/utils/functional.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -import torch.nn.functional as F - -def scaled_dot_product_attention(q, k, v, dropout_p=0.0): - """ - Computes scaled dot product attention. - - Args: - q (torch.Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim). - k (torch.Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim). - v (torch.Tensor): Value tensor of shape (batch_size, num_heads, seq_len_v, head_dim). - dropout_p (float, optional): Dropout probability. Default is 0.0. - - Returns: - torch.Tensor: Output tensor of shape (batch_size, num_heads, seq_len_q, head_dim). 
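A shape-level usage sketch for scaled_dot_product_attention above (sizes are illustrative):

import torch

B, H, Nq, Nk, D = 2, 12, 196, 196, 64     # illustrative batch / heads / tokens / head dim
q = torch.randn(B, H, Nq, D)
k = torch.randn(B, H, Nk, D)
v = torch.randn(B, H, Nk, D)
out = scaled_dot_product_attention(q, k, v, dropout_p=0.0)
assert out.shape == (B, H, Nq, D)          # one output row per query token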
- """ - # Compute attention scores - attn_scores = torch.matmul(q, k.transpose(-2, -1)) - attn_scores = attn_scores / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32)) - - # Apply softmax to attention scores - attn_probs = F.softmax(attn_scores, dim=-1) - - # Apply dropout to attention probabilities - attn_probs = F.dropout(attn_probs, p=dropout_p) - - # Compute attention output - attn_output = torch.matmul(attn_probs, v) - - return attn_output \ No newline at end of file diff --git a/build/lib/masks/default.py b/build/lib/masks/default.py deleted file mode 100644 index 2810c0a..0000000 --- a/build/lib/masks/default.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -from logging import getLogger - -import torch - -_GLOBAL_SEED = 0 -logger = getLogger() - - -class DefaultCollator(object): - - def __call__(self, batch): - collated_batch = torch.utils.data.default_collate(batch) - return collated_batch, None, None diff --git a/build/lib/masks/multiblock3d.py b/build/lib/masks/multiblock3d.py deleted file mode 100644 index a7bbc3e..0000000 --- a/build/lib/masks/multiblock3d.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math - -from multiprocessing import Value - -from logging import getLogger - -import torch - -_GLOBAL_SEED = 0 -logger = getLogger() - - -class MaskCollator(object): - - def __init__( - self, - cfgs_mask, - crop_size=(224, 224), - num_frames=16, - patch_size=(16, 16), - tubelet_size=2, - ): - super(MaskCollator, self).__init__() - - self.mask_generators = [] - for m in cfgs_mask: - mask_generator = _MaskGenerator( - crop_size=crop_size, - num_frames=num_frames, - spatial_patch_size=patch_size, - temporal_patch_size=tubelet_size, - spatial_pred_mask_scale=m.get('spatial_scale'), - temporal_pred_mask_scale=m.get('temporal_scale'), - aspect_ratio=m.get('aspect_ratio'), - npred=m.get('num_blocks'), - max_context_frames_ratio=m.get('max_temporal_keep', 1.0), - max_keep=m.get('max_keep', None), - ) - self.mask_generators.append(mask_generator) - - def step(self): - for mask_generator in self.mask_generators: - mask_generator.step() - - def __call__(self, batch): - - batch_size = len(batch) - collated_batch = torch.utils.data.default_collate(batch) - - collated_masks_pred, collated_masks_enc = [], [] - for i, mask_generator in enumerate(self.mask_generators): - masks_enc, masks_pred = mask_generator(batch_size) - collated_masks_enc.append(masks_enc) - collated_masks_pred.append(masks_pred) - - return collated_batch, collated_masks_enc, collated_masks_pred - - -class _MaskGenerator(object): - - def __init__( - self, - crop_size=(224, 224), - num_frames=16, - spatial_patch_size=(16, 16), - temporal_patch_size=2, - spatial_pred_mask_scale=(0.2, 0.8), - temporal_pred_mask_scale=(1.0, 1.0), - aspect_ratio=(0.3, 3.0), - npred=1, - max_context_frames_ratio=1.0, - max_keep=None, - ): - super(_MaskGenerator, self).__init__() - if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 - self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size - self.duration = num_frames // temporal_patch_size - - 
self.spatial_patch_size = spatial_patch_size - self.temporal_patch_size = temporal_patch_size - - self.aspect_ratio = aspect_ratio - self.spatial_pred_mask_scale = spatial_pred_mask_scale - self.temporal_pred_mask_scale = temporal_pred_mask_scale - self.npred = npred - self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask - self.max_keep = max_keep # maximum number of patches to keep in context - self._itr_counter = Value('i', -1) # collator is shared across worker processes - - def step(self): - i = self._itr_counter - with i.get_lock(): - i.value += 1 - v = i.value - return v - - def _sample_block_size( - self, - generator, - temporal_scale, - spatial_scale, - aspect_ratio_scale - ): - # -- Sample temporal block mask scale - _rand = torch.rand(1, generator=generator).item() - min_t, max_t = temporal_scale - temporal_mask_scale = min_t + _rand * (max_t - min_t) - t = max(1, int(self.duration * temporal_mask_scale)) - - # -- Sample spatial block mask scale - _rand = torch.rand(1, generator=generator).item() - min_s, max_s = spatial_scale - spatial_mask_scale = min_s + _rand * (max_s - min_s) - spatial_num_keep = int(self.height * self.width * spatial_mask_scale) - - # -- Sample block aspect-ratio - _rand = torch.rand(1, generator=generator).item() - min_ar, max_ar = aspect_ratio_scale - aspect_ratio = min_ar + _rand * (max_ar - min_ar) - - # -- Compute block height and width (given scale and aspect-ratio) - h = int(round(math.sqrt(spatial_num_keep * aspect_ratio))) - w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) - h = min(h, self.height) - w = min(w, self.width) - - return (t, h, w) - - def _sample_block_mask(self, b_size): - t, h, w = b_size - top = torch.randint(0, self.height - h + 1, (1,)) - left = torch.randint(0, self.width - w + 1, (1,)) - start = torch.randint(0, self.duration - t + 1, (1,)) - - mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) - mask[start:start+t, top:top+h, left:left+w] = 0 - - # Context mask will only span the first X frames - # (X=self.max_context_frames) - if self.max_context_duration < self.duration: - mask[self.max_context_duration:, :, :] = 0 - - # -- - return mask - - def __call__(self, batch_size): - """ - Create encoder and predictor masks when collating imgs into a batch - # 1. sample pred block size using seed - # 2. sample several pred block locations for each image (w/o seed) - # 3. 
return pred masks and complement (enc mask) - """ - seed = self.step() - g = torch.Generator() - g.manual_seed(seed) - p_size = self._sample_block_size( - generator=g, - temporal_scale=self.temporal_pred_mask_scale, - spatial_scale=self.spatial_pred_mask_scale, - aspect_ratio_scale=self.aspect_ratio, - ) - - collated_masks_pred, collated_masks_enc = [], [] - min_keep_enc = min_keep_pred = self.duration * self.height * self.width - for _ in range(batch_size): - - empty_context = True - while empty_context: - - mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) - for _ in range(self.npred): - mask_e *= self._sample_block_mask(p_size) - mask_e = mask_e.flatten() - - mask_p = torch.argwhere(mask_e == 0).squeeze() - mask_e = torch.nonzero(mask_e).squeeze() - - empty_context = len(mask_e) == 0 - if not empty_context: - min_keep_pred = min(min_keep_pred, len(mask_p)) - min_keep_enc = min(min_keep_enc, len(mask_e)) - collated_masks_pred.append(mask_p) - collated_masks_enc.append(mask_e) - - if self.max_keep is not None: - min_keep_enc = min(min_keep_enc, self.max_keep) - - collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] - collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) - # -- - collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] - collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) - - return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/random_tube.py b/build/lib/masks/random_tube.py deleted file mode 100644 index 84c0640..0000000 --- a/build/lib/masks/random_tube.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
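With the defaults above (crop 224, spatial patch 16, 16 frames, tubelet 2), the multiblock3d MaskCollator works on an 8 x 14 x 14 grid of tubelet patches. A small end-to-end sketch; the mask config is a plain dict carrying the fields read via m.get(...) above, and its values are only illustrative:

import torch

cfgs_mask = [{
    'spatial_scale': (0.2, 0.8),
    'temporal_scale': (1.0, 1.0),
    'aspect_ratio': (0.3, 3.0),
    'num_blocks': 1,
    'max_temporal_keep': 1.0,
    'max_keep': None,
}]
collator = MaskCollator(cfgs_mask, crop_size=(224, 224), num_frames=16,
                        patch_size=16, tubelet_size=2)
batch = [torch.zeros(3, 16, 224, 224) for _ in range(4)]   # dummy clips
collated, masks_enc, masks_pred = collator(batch)
# one entry per mask generator; each is a [B, num_keep] tensor of patch indices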
-# - -from multiprocessing import Value - -from logging import getLogger - -import torch -import numpy as np - -_GLOBAL_SEED = 0 -logger = getLogger() - - -class MaskCollator(object): - - def __init__( - self, - cfgs_mask, - crop_size=(224, 224), - num_frames=16, - patch_size=(16, 16), - tubelet_size=2, - ): - super(MaskCollator, self).__init__() - - self.mask_generators = [] - for m in cfgs_mask: - mask_generator = _MaskGenerator( - crop_size=crop_size, - num_frames=num_frames, - spatial_patch_size=patch_size, - temporal_patch_size=tubelet_size, - ratio=m.get('ratio'), - ) - self.mask_generators.append(mask_generator) - - def step(self): - for mask_generator in self.mask_generators: - mask_generator.step() - - def __call__(self, batch): - - batch_size = len(batch) - collated_batch = torch.utils.data.default_collate(batch) - - collated_masks_pred, collated_masks_enc = [], [] - for i, mask_generator in enumerate(self.mask_generators): - masks_enc, masks_pred = mask_generator(batch_size) - collated_masks_enc.append(masks_enc) - collated_masks_pred.append(masks_pred) - - return collated_batch, collated_masks_enc, collated_masks_pred - - -class _MaskGenerator(object): - - def __init__( - self, - crop_size=(224, 224), - num_frames=16, - spatial_patch_size=(16, 16), - temporal_patch_size=2, - ratio=0.9, - ): - super(_MaskGenerator, self).__init__() - if not isinstance(crop_size, tuple): - crop_size = (crop_size, ) * 2 - self.crop_size = crop_size - self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size - self.duration = num_frames // temporal_patch_size - - self.spatial_patch_size = spatial_patch_size - self.temporal_patch_size = temporal_patch_size - self.num_patches_spatial = self.height*self.width - - self.ratio = ratio - - self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) - self.num_keep = self.num_keep_spatial * self.duration - - self._itr_counter = Value('i', -1) # collator is shared across worker processes - - def step(self): - i = self._itr_counter - with i.get_lock(): - i.value += 1 - v = i.value - return v - - def __call__(self, batch_size): - def sample_mask(): - mask = np.hstack([ - np.zeros(self.num_patches_spatial - self.num_keep_spatial), - np.ones(self.num_keep_spatial), - ]) - np.random.shuffle(mask) - mask = torch.tensor(np.tile(mask, (self.duration, 1))) - mask = mask.flatten() - mask_p = torch.argwhere(mask == 0).squeeze() - mask_e = torch.nonzero(mask).squeeze() - return mask_e, mask_p - - collated_masks_pred, collated_masks_enc = [], [] - for _ in range(batch_size): - mask_e, mask_p = sample_mask() - collated_masks_enc.append(mask_e) - collated_masks_pred.append(mask_p) - - collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) - collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) - - return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/utils.py b/build/lib/masks/utils.py deleted file mode 100644 index ca04af1..0000000 --- a/build/lib/masks/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import torch - - -def apply_masks(x, masks, concat=True): - """ - :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] - :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep - """ - all_x = [] - for m in masks: - mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) - all_x += [torch.gather(x, dim=1, index=mask_keep)] - if not concat: - return all_x - - return torch.cat(all_x, dim=0) diff --git a/build/lib/models/attentive_pooler.py b/build/lib/models/attentive_pooler.py deleted file mode 100644 index 26b0e0e..0000000 --- a/build/lib/models/attentive_pooler.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math - -import torch -import torch.nn as nn - -from jepa_src.models.utils.modules import ( - Block, - CrossAttention, - CrossAttentionBlock -) -from jepa_src.utils.tensors import trunc_normal_ - - -class AttentivePooler(nn.Module): - """ Attentive Pooler """ - def __init__( - self, - num_queries=1, - embed_dim=768, - num_heads=12, - mlp_ratio=4.0, - depth=1, - norm_layer=nn.LayerNorm, - init_std=0.02, - qkv_bias=True, - complete_block=True - ): - super().__init__() - self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) - - self.complete_block = complete_block - if complete_block: - self.cross_attention_block = CrossAttentionBlock( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - norm_layer=norm_layer) - else: - self.cross_attention_block = CrossAttention( - dim=embed_dim, - num_heads=num_heads, - qkv_bias=qkv_bias) - - self.blocks = None - if depth > 1: - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=False, - norm_layer=norm_layer) - for i in range(depth-1)]) - - self.init_std = init_std - trunc_normal_(self.query_tokens, std=self.init_std) - self.apply(self._init_weights) - self._rescale_blocks() - - def _rescale_blocks(self): - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - if self.complete_block: - rescale(self.cross_attention_block.xattn.proj.weight.data, 1) - rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) - else: - rescale(self.cross_attention_block.proj.weight.data, 1) - if self.blocks is not None: - for layer_id, layer in enumerate(self.blocks, 1): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=self.init_std) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - trunc_normal_(m.weight, std=self.init_std) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x): - q = self.query_tokens.repeat(len(x), 1, 1) - q = self.cross_attention_block(q, x) - if self.blocks is not None: - for blk in self.blocks: - q = blk(q) - return q - - -class AttentiveClassifier(nn.Module): - """ Attentive Classifier """ - def __init__( - self, - embed_dim=768, - num_heads=12, - mlp_ratio=4.0, - depth=1, - norm_layer=nn.LayerNorm, - init_std=0.02, - qkv_bias=True, - num_classes=1000, - complete_block=True, - ): - 
super().__init__() - self.pooler = AttentivePooler( - num_queries=1, - embed_dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - depth=depth, - norm_layer=norm_layer, - init_std=init_std, - qkv_bias=qkv_bias, - complete_block=complete_block, - ) - self.linear = nn.Linear(embed_dim, num_classes, bias=True) - - def forward(self, x): - x = self.pooler(x).squeeze(1) - x = self.linear(x) - return x diff --git a/build/lib/models/predictor.py b/build/lib/models/predictor.py deleted file mode 100644 index 95f6bc0..0000000 --- a/build/lib/models/predictor.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math -from functools import partial - -import torch -import torch.nn as nn - -from jepa_src.models.utils.modules import Block -from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from jepa_src.utils.tensors import ( - trunc_normal_, - repeat_interleave_batch -) -from jepa_src.masks.utils import apply_masks - - -class VisionTransformerPredictor(nn.Module): - """ Vision Transformer """ - def __init__( - self, - img_size=224, - patch_size=16, - num_frames=1, - tubelet_size=2, - embed_dim=768, - predictor_embed_dim=384, - depth=6, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - norm_layer=nn.LayerNorm, - init_std=0.02, - uniform_power=False, - use_mask_tokens=False, - num_mask_tokens=2, - zero_init_mask_tokens=True, - **kwargs - ): - super().__init__() - # Map input to predictor dimension - self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) - - # Mask tokens - self.mask_tokens = None - self.num_mask_tokens = 0 - if use_mask_tokens: - self.num_mask_tokens = num_mask_tokens - self.mask_tokens = nn.ParameterList([ - nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) - for i in range(num_mask_tokens) - ]) - - # Determine positional embedding - self.input_size = img_size - self.patch_size = patch_size - # -- - self.num_frames = num_frames - self.tubelet_size = tubelet_size - self.is_video = num_frames > 1 - - grid_size = self.input_size // self.patch_size - grid_depth = self.num_frames // self.tubelet_size - - if self.is_video: - self.num_patches = num_patches = ( - (num_frames // tubelet_size) - * (img_size // patch_size) - * (img_size // patch_size) - ) - else: - self.num_patches = num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) - ) - # Position embedding - self.uniform_power = uniform_power - self.predictor_pos_embed = None - self.predictor_pos_embed = nn.Parameter( - torch.zeros(1, num_patches, predictor_embed_dim), - requires_grad=False) - - # Attention Blocks - self.predictor_blocks = nn.ModuleList([ - Block( - dim=predictor_embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - attn_drop=attn_drop_rate, - grid_size=grid_size, - grid_depth=grid_depth, - norm_layer=norm_layer) - for i in range(depth)]) - - # Normalize & project back to input dimension - self.predictor_norm = norm_layer(predictor_embed_dim) - self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) - - # ------ initialize weights - if self.predictor_pos_embed is not None: - self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed - self.init_std = init_std - if not 
zero_init_mask_tokens: - for mt in self.mask_tokens: - trunc_normal_(mt, std=init_std) - self.apply(self._init_weights) - self._rescale_blocks() - - def _init_pos_embed(self, pos_embed): - embed_dim = pos_embed.size(-1) - grid_size = self.input_size // self.patch_size - if self.is_video: - grid_depth = self.num_frames // self.tubelet_size - sincos = get_3d_sincos_pos_embed( - embed_dim, - grid_size, - grid_depth, - cls_token=False, - uniform_power=self.uniform_power - ) - else: - sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) - pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=self.init_std) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def _rescale_blocks(self): - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - for layer_id, layer in enumerate(self.predictor_blocks): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) - - def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): - - # Prepare diffusion noise schedule - b1, b2 = noise_beta - beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) - alpha_scheduler = [] - _alpha = 1.0 - for _beta in beta_scheduler: - _alpha *= 1.-_beta - alpha_scheduler += [_alpha] - - # Sample diffusion time step - T = torch.randint(0, steps, (len(x),)) - alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) - - # Normalize features and apply noise - x = torch.nn.functional.layer_norm(x, (x.size(-1),)) - x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) - return x - - def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): - """ - :param ctxt: context tokens - :param tgt: target tokens - :param masks_ctxt: indices of context tokens in input - :params masks_tgt: indices of target tokens in input - """ - - assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' - - if not isinstance(masks_ctxt, list): - masks_ctxt = [masks_ctxt] - - if not isinstance(masks_tgt, list): - masks_tgt = [masks_tgt] - - # Batch Size - B = len(ctxt) // len(masks_ctxt) - - # Map context tokens to pedictor dimensions - x = self.predictor_embed(ctxt) - _, N_ctxt, D = x.shape - - # Add positional embedding to ctxt tokens - if self.predictor_pos_embed is not None: - ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) - x += apply_masks(ctxt_pos_embed, masks_ctxt) - - # Map target tokens to predictor dimensions & add noise (fwd diffusion) - if self.mask_tokens is None: - pred_tokens = self.predictor_embed(tgt) - pred_tokens = self.diffusion(pred_tokens) - else: - mask_index = mask_index % self.num_mask_tokens - pred_tokens = self.mask_tokens[mask_index] - pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) - pred_tokens = apply_masks(pred_tokens, masks_tgt) - - # Add positional embedding to target tokens - if self.predictor_pos_embed is not None: - pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) - pos_embs = apply_masks(pos_embs, masks_tgt) - pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) - pred_tokens += pos_embs - - # Concatenate context & target tokens - x = x.repeat(len(masks_tgt), 1, 1) - x = torch.cat([x, pred_tokens], dim=1) - - # FIXME: this implementation currently 
assumes masks_ctxt and masks_tgt - # are alligned 1:1 (ok with MultiMask wrapper on predictor but - # otherwise will break) - masks_ctxt = torch.cat(masks_ctxt, dim=0) - masks_tgt = torch.cat(masks_tgt, dim=0) - masks = torch.cat([masks_ctxt, masks_tgt], dim=1) - - # Fwd prop - for blk in self.predictor_blocks: - x = blk(x, mask=masks) - x = self.predictor_norm(x) - - # Return output corresponding to target tokens - x = x[:, N_ctxt:] - x = self.predictor_proj(x) - - return x - - -def vit_predictor(**kwargs): - model = VisionTransformerPredictor( - mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), - **kwargs) - return model diff --git a/build/lib/models/utils/modules.py b/build/lib/models/utils/modules.py deleted file mode 100644 index c78ffc0..0000000 --- a/build/lib/models/utils/modules.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import jepa_src.utils.functional as JF - - -class MLP(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0. - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0., - proj_drop=0., - use_sdpa=True - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop_prob = proj_drop - self.proj_drop = nn.Dropout(proj_drop) - self.use_sdpa = use_sdpa - - def forward(self, x, mask=None): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] - - if self.use_sdpa: - with torch.backends.cuda.sdp_kernel(): - x = JF.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) - attn = None - else: - attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v) - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x, attn - - -class Block(nn.Module): - def __init__( - self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - qk_scale=None, - drop=0., - attn_drop=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - grid_size=None, - grid_depth=None, - ): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop) - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLP( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop) - - def forward(self, x, return_attention=False, mask=None): - y, 
attn = self.attn(self.norm1(x), mask=mask) - if return_attention: - return attn - x = x + y - x = x + self.mlp(self.norm2(x)) - return x - - -class CrossAttention(nn.Module): - def __init__( - self, - dim, - num_heads=12, - qkv_bias=False, - use_sdpa=True - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - self.use_sdpa = use_sdpa - - def forward(self, q, x): - B, n, C = q.shape - q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - B, N, C = x.shape - kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) - - if self.use_sdpa: - with torch.backends.cuda.sdp_kernel(): - q = JF.scaled_dot_product_attention(q, k, v) - else: - xattn = (q @ k.transpose(-2, -1)) * self.scale - xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) - q = (xattn @ v) - - q = q.transpose(1, 2).reshape(B, n, C) - q = self.proj(q) - - return q - - -class CrossAttentionBlock(nn.Module): - def __init__( - self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm - ): - super().__init__() - self.norm1 = norm_layer(dim) - self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) - - def forward(self, q, x): - y = self.xattn(q, self.norm1(x)) - q = q + y - q = q + self.mlp(self.norm2(q)) - return q diff --git a/build/lib/models/utils/multimask.py b/build/lib/models/utils/multimask.py deleted file mode 100644 index d480086..0000000 --- a/build/lib/models/utils/multimask.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import torch.nn as nn - - -class MultiMaskWrapper(nn.Module): - - def __init__(self, backbone): - super().__init__() - self.backbone = backbone - - def forward(self, x, masks=None): - if masks is None: - return self.backbone(x) - - if (masks is not None) and not isinstance(masks, list): - masks = [masks] - outs = [] - for m in masks: - outs += [self.backbone(x, masks=m)] - return outs - - -class PredictorMultiMaskWrapper(nn.Module): - - def __init__(self, backbone): - super().__init__() - self.backbone = backbone - - def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): - if type(ctxt) is not list: - ctxt = [ctxt] - if type(tgt) is not list: - tgt = [tgt] - if type(masks_ctxt) is not list: - masks_ctxt = [masks_ctxt] - if type(masks_tgt) is not list: - masks_tgt = [masks_tgt] - - outs = [] - for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): - outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] - return outs diff --git a/build/lib/models/utils/patch_embed.py b/build/lib/models/utils/patch_embed.py deleted file mode 100644 index 4ff4de5..0000000 --- a/build/lib/models/utils/patch_embed.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import torch.nn as nn - - -class PatchEmbed(nn.Module): - """ - Image to Patch Embedding - """ - def __init__( - self, - patch_size=16, - in_chans=3, - embed_dim=768 - ): - super().__init__() - self.patch_size = patch_size - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - B, C, H, W = x.shape - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - -class PatchEmbed3D(nn.Module): - """ - Image to Patch Embedding - """ - - def __init__( - self, - patch_size=16, - tubelet_size=2, - in_chans=3, - embed_dim=768, - ): - super().__init__() - self.patch_size = patch_size - self.tubelet_size = tubelet_size - - self.proj = nn.Conv3d( - in_channels=in_chans, - out_channels=embed_dim, - kernel_size=(tubelet_size, patch_size, patch_size), - stride=(tubelet_size, patch_size, patch_size), - ) - - def forward(self, x, **kwargs): - B, C, T, H, W = x.shape - x = self.proj(x).flatten(2).transpose(1, 2) - return x diff --git a/build/lib/models/utils/pos_embs.py b/build/lib/models/utils/pos_embs.py deleted file mode 100644 index d1d82e2..0000000 --- a/build/lib/models/utils/pos_embs.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import numpy as np - - -def get_3d_sincos_pos_embed( - embed_dim, - grid_size, - grid_depth, - cls_token=False, - uniform_power=False -): - """ - grid_size: int of the grid height and width - grid_depth: int of the grid depth - returns: - pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) - or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) - """ - grid_d = np.arange(grid_depth, dtype=float) - grid_h = np.arange(grid_size, dtype=float) - grid_w = np.arange(grid_size, dtype=float) - grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] - - if not uniform_power: - h_embed_dim = embed_dim // 4 - w_embed_dim = embed_dim // 4 - d_embed_dim = embed_dim // 2 - else: - h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) - - emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) - emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) - emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) - pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) - pos_embed = pos_embed[:, :embed_dim] - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - returns: - pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) - or [1+grid_size*grid_size, embed_dim] (w/ cls_token) - """ - grid_h = np.arange(grid_size, dtype=float) - grid_w = np.arange(grid_size, dtype=float) - grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] - - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) - pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - if cls_token: - pos_embed = 
np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - embed_dim: output dimension for each position - grid_size: int of the grid length - returns: - pos_embed: [grid_size, embed_dim] (w/o cls_token) - or [1+grid_size, embed_dim] (w/ cls_token) - """ - grid = np.arange(grid_size, dtype=float) - pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - returns: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb diff --git a/build/lib/models/vision_transformer.py b/build/lib/models/vision_transformer.py deleted file mode 100644 index 946246e..0000000 --- a/build/lib/models/vision_transformer.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math -from functools import partial - -import torch -import torch.nn as nn - -from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D -from jepa_src.models.utils.modules import Block -from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed -from jepa_src.utils.tensors import trunc_normal_ -from jepa_src.masks.utils import apply_masks - - -class VisionTransformer(nn.Module): - """ Vision Transformer """ - def __init__( - self, - img_size=224, - patch_size=16, - num_frames=1, - tubelet_size=2, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - norm_layer=nn.LayerNorm, - init_std=0.02, - out_layers=None, - uniform_power=False, - **kwargs - ): - super().__init__() - self.num_features = self.embed_dim = embed_dim - self.num_heads = num_heads - self.out_layers = out_layers - - self.input_size = img_size - self.patch_size = patch_size - - self.num_frames = num_frames - self.tubelet_size = tubelet_size - self.is_video = num_frames > 1 - - grid_size = self.input_size // self.patch_size - grid_depth = self.num_frames // self.tubelet_size - - # Tokenize pixels with convolution - if self.is_video: - self.patch_embed = PatchEmbed3D( - patch_size=patch_size, - tubelet_size=tubelet_size, - in_chans=in_chans, - embed_dim=embed_dim) - self.num_patches = ( - (num_frames // tubelet_size) - * (img_size // patch_size) - * (img_size // patch_size) - ) - else: - self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim) - self.num_patches = ( - (img_size // patch_size) - * (img_size // patch_size) - ) - - # Position embedding - self.uniform_power = uniform_power - self.pos_embed = None - self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_patches, embed_dim), - requires_grad=False) - - # Attention Blocks - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, 
- num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=nn.GELU, - grid_size=grid_size, - grid_depth=grid_depth, - attn_drop=attn_drop_rate, - norm_layer=norm_layer) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) - - # ------ initialize weights - if self.pos_embed is not None: - self._init_pos_embed(self.pos_embed.data) # sincos pos-embed - self.init_std = init_std - self.apply(self._init_weights) - self._rescale_blocks() - - def _init_pos_embed(self, pos_embed): - embed_dim = pos_embed.size(-1) - grid_size = self.input_size // self.patch_size - if self.is_video: - grid_depth = self.num_frames // self.tubelet_size - sincos = get_3d_sincos_pos_embed( - embed_dim, - grid_size, - grid_depth, - cls_token=False, - uniform_power=self.uniform_power - ) - else: - sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) - pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=self.init_std) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - elif isinstance(m, nn.Conv2d): - trunc_normal_(m.weight, std=self.init_std) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Conv3d): - trunc_normal_(m.weight, std=self.init_std) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def _rescale_blocks(self): - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - for layer_id, layer in enumerate(self.blocks): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) - - def get_num_layers(self): - return len(self.blocks) - - def no_weight_decay(self): - return {} - - def forward(self, x, masks=None): - """ - :param x: input image/video - :param masks: indices of patch tokens to mask (remove) - """ - - if masks is not None and not isinstance(masks, list): - masks = [masks] - - # Tokenize input - pos_embed = self.pos_embed - if pos_embed is not None: - pos_embed = self.interpolate_pos_encoding(x, pos_embed) - x = self.patch_embed(x) - if pos_embed is not None: - x += pos_embed - B, N, D = x.shape - - # Mask away unwanted tokens (if masks provided) - if masks is not None: - x = apply_masks(x, masks) - masks = torch.cat(masks, dim=0) - - # Fwd prop - outs = [] - for i, blk in enumerate(self.blocks): - x = blk(x, mask=masks) - if self.out_layers is not None and i in self.out_layers: - outs.append(self.norm(x)) - - if self.out_layers is not None: - return outs - - if self.norm is not None: - x = self.norm(x) - - return x - - def interpolate_pos_encoding(self, x, pos_embed): - - _, N, dim = pos_embed.shape - - if self.is_video: - - # If pos_embed already corret size, just return - _, _, T, H, W = x.shape - if H == self.input_size and W == self.input_size and T == self.num_frames: - return pos_embed - - # Convert depth, height, width of input to be measured in patches - # instead of pixels/frames - T = T // self.tubelet_size - H = H // self.patch_size - W = W // self.patch_size - - # Compute the initialized shape of the positional embedding measured - # in patches - N_t = self.num_frames // self.tubelet_size - N_h = N_w = self.input_size // self.patch_size - assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' - - # Compute scale factor for spatio-temporal interpolation 
- scale_factor = (T/N_t, H/N_h, W/N_w) - - pos_embed = nn.functional.interpolate( - pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), - scale_factor=scale_factor, - mode='trilinear') - pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) - return pos_embed - - else: - - # If pos_embed already corret size, just return - _, _, H, W = x.shape - if H == self.input_size and W == self.input_size: - return pos_embed - - # Compute scale factor for spatial interpolation - npatch = (H // self.patch_size) * (W // self.patch_size) - scale_factor = math.sqrt(npatch / N) - - pos_embed = nn.functional.interpolate( - pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=scale_factor, - mode='bicubic') - pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return pos_embed - - -def vit_tiny(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_small(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_base(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_large(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_huge(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_giant(patch_size=16, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - return model - - -def vit_gigantic(patch_size=14, **kwargs): - model = VisionTransformer( - patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mpl_ratio=64/13, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs - ) - return model - - -VIT_EMBED_DIMS = { - 'vit_tiny': 192, - 'vit_small': 384, - 'vit_base': 768, - 'vit_large': 1024, - 'vit_huge': 1280, - 'vit_giant': 1408, - 'vit_gigantic': 1664, -} diff --git a/build/lib/utils/distributed.py b/build/lib/utils/distributed.py deleted file mode 100644 index cfba444..0000000 --- a/build/lib/utils/distributed.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import os - -import torch -import torch.distributed as dist - -from logging import getLogger - -logger = getLogger() - - -def init_distributed(port=37123, rank_and_world_size=(None, None)): - - if dist.is_available() and dist.is_initialized(): - return dist.get_world_size(), dist.get_rank() - - rank, world_size = rank_and_world_size - os.environ['MASTER_ADDR'] = 'localhost' - - if (rank is None) or (world_size is None): - try: - world_size = int(os.environ['SLURM_NTASKS']) - rank = int(os.environ['SLURM_PROCID']) - os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] - except Exception: - logger.info('SLURM vars not set (distributed training not available)') - world_size, rank = 1, 0 - return world_size, rank - - try: - os.environ['MASTER_PORT'] = str(port) - torch.distributed.init_process_group( - backend='nccl', - world_size=world_size, - rank=rank - ) - except Exception as e: - world_size, rank = 1, 0 - logger.info(f'Rank: {rank}. Distributed training not available {e}') - - return world_size, rank - - -class AllGather(torch.autograd.Function): - - @staticmethod - def forward(ctx, x): - if ( - dist.is_available() - and dist.is_initialized() - and (dist.get_world_size() > 1) - ): - x = x.contiguous() - outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] - dist.all_gather(outputs, x) - return torch.cat(outputs, 0) - return x - - @staticmethod - def backward(ctx, grads): - if ( - dist.is_available() - and dist.is_initialized() - and (dist.get_world_size() > 1) - ): - s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() - e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) - grads = grads.contiguous() - dist.all_reduce(grads) - return grads[s:e] - return grads - - -class AllReduceSum(torch.autograd.Function): - - @staticmethod - def forward(ctx, x): - if ( - dist.is_available() - and dist.is_initialized() - and (dist.get_world_size() > 1) - ): - x = x.contiguous() - dist.all_reduce(x) - return x - - @staticmethod - def backward(ctx, grads): - return grads - - -class AllReduce(torch.autograd.Function): - - @staticmethod - def forward(ctx, x): - if ( - dist.is_available() - and dist.is_initialized() - and (dist.get_world_size() > 1) - ): - x = x.contiguous() / dist.get_world_size() - dist.all_reduce(x) - return x - - @staticmethod - def backward(ctx, grads): - return grads diff --git a/build/lib/utils/logging.py b/build/lib/utils/logging.py deleted file mode 100644 index fcdd3fa..0000000 --- a/build/lib/utils/logging.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import logging -import sys - -import torch - - -def gpu_timer(closure, log_timings=True): - """ Helper to time gpu-time to execute closure() """ - log_timings = log_timings and torch.cuda.is_available() - - elapsed_time = -1. 
- if log_timings: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - - result = closure() - - if log_timings: - end.record() - torch.cuda.synchronize() - elapsed_time = start.elapsed_time(end) - - return result, elapsed_time - - -LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" -DATE_FORMAT = "%Y-%m-%d %H:%M:%S" - - -def get_logger(name=None, force=False): - logging.basicConfig(stream=sys.stdout, level=logging.INFO, - format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) - return logging.getLogger(name=name) - - -class CSVLogger(object): - - def __init__(self, fname, *argv): - self.fname = fname - self.types = [] - # -- print headers - with open(self.fname, '+a') as f: - for i, v in enumerate(argv, 1): - self.types.append(v[0]) - if i < len(argv): - print(v[1], end=',', file=f) - else: - print(v[1], end='\n', file=f) - - def log(self, *argv): - with open(self.fname, '+a') as f: - for i, tv in enumerate(zip(self.types, argv), 1): - end = ',' if i < len(argv) else '\n' - print(tv[0] % tv[1], end=end, file=f) - - -class AverageMeter(object): - """computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.max = float('-inf') - self.min = float('inf') - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - try: - self.max = max(val, self.max) - self.min = min(val, self.min) - except Exception: - pass - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def grad_logger(named_params): - stats = AverageMeter() - stats.first_layer = None - stats.last_layer = None - for n, p in named_params: - if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): - grad_norm = float(torch.norm(p.grad.data)) - stats.update(grad_norm) - if 'qkv' in n: - stats.last_layer = grad_norm - if stats.first_layer is None: - stats.first_layer = grad_norm - if stats.first_layer is None or stats.last_layer is None: - stats.first_layer = stats.last_layer = 0. - return stats - - -def adamw_logger(optimizer): - """ logging magnitude of first and second momentum buffers in adamw """ - # TODO: assert that optimizer is instance of torch.optim.AdamW - state = optimizer.state_dict().get('state') - exp_avg_stats = AverageMeter() - exp_avg_sq_stats = AverageMeter() - for key in state: - s = state.get(key) - exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) - exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) - return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/build/lib/utils/monitoring.py b/build/lib/utils/monitoring.py deleted file mode 100644 index 95a7845..0000000 --- a/build/lib/utils/monitoring.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import dataclasses -import threading -from typing import Dict, Tuple - -import psutil - - -@dataclasses.dataclass -class ResourceStatsSample: - timestamp: float - cpu_percent: float - read_count: int - write_count: int - read_bytes: int - write_bytes: int - read_chars: int - write_chars: int - cpu_times_user: float - cpu_times_system: float - cpu_times_children_user: float - cpu_times_children_system: float - cpu_times_iowait: float - cpu_affinity: str - cpu_num: int - num_threads: int - num_voluntary_ctx_switches: int - num_involuntary_ctx_switches: int - - def as_tuple(self) -> Dict: - """Return values mirroring fields.""" - return dataclasses.astuple(self) - - def fields(self) -> Tuple[dataclasses.Field, ...]: - """Return fields in this dataclass.""" - return dataclasses.fields(self.__class__) - - -class ResourceMonitoringThread(threading.Thread): - def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): - """Starts a thread to monitor pid every refresh_interval seconds. - - Passes a ResourceStatsSample object to the callback.""" - super(ResourceMonitoringThread, self).__init__() - if refresh_interval is None: - refresh_interval = 5 - self.is_running_event = threading.Event() - self.p = psutil.Process(pid) - self.refresh_interval = refresh_interval - if stats_callback_fn is None: - # Default callback - def stats_callback_fn(resource_sample: ResourceStatsSample): - print( - f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") - elif not callable(stats_callback_fn): - raise ValueError("Callback needs to be callable, got {}".format( - type(stats_callback_fn))) - self.stats_callback_fn = stats_callback_fn - - def stop(self) -> None: - self.is_running_event.set() - - def run(self) -> None: - while not self.is_running_event.is_set(): - self.sample_counters() - self.is_running_event.wait(self.refresh_interval) - - def log_sample(self, resource_sample: ResourceStatsSample) -> None: - self.stats_callback_fn(resource_sample) - - def sample_counters(self) -> None: - if not self.p.is_running(): - self.stop() - return - - with self.p.oneshot(): - cpu_percent = self.p.cpu_percent() - cpu_times = self.p.cpu_times() - io_counters = self.p.io_counters() - cpu_affinity = self.p.cpu_affinity() - cpu_num = self.p.cpu_num() - num_threads = self.p.num_threads() - num_ctx_switches = self.p.num_ctx_switches() - timestamp = time.time() - - read_count = io_counters.read_count - write_count = io_counters.write_count - read_bytes = io_counters.read_bytes - write_bytes = io_counters.write_bytes - read_chars = io_counters.read_chars - write_chars = io_counters.write_chars - - def compress_cpu_affinity(cpu_affinity): - """Change list representation to interval/range representation.""" - if not cpu_affinity: - return "" - cpu_affinity_compressed = [] - min_x = None - max_x = None - last_x = None - - # Find contiguous ranges - for x in cpu_affinity: - if last_x is None: - # Start interval - min_x = x - max_x = x - last_x = x - continue - elif x == (last_x + 1): - # Move interval up - max_x = x - elif max_x is not None: - # Interval ended, start again - if min_x == max_x: - cpu_affinity_compressed.append("{}".format(min_x)) - else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) - min_x = x - max_x = x - last_x = x - # Terminate last range - if max_x is not None: - if min_x == max_x: - cpu_affinity_compressed.append("{}".format(min_x)) - else: - cpu_affinity_compressed.append( - "{}-{}".format(min_x, max_x)) - - # Concat - cpu_affinity_compressed = 
",".join(cpu_affinity_compressed) - - return cpu_affinity_compressed - - cpu_affinity = compress_cpu_affinity(cpu_affinity) - - resource_sample = ResourceStatsSample( - timestamp=timestamp, - cpu_percent=cpu_percent, - read_count=read_count, - write_count=write_count, - read_bytes=read_bytes, - write_bytes=write_bytes, - read_chars=read_chars, - write_chars=write_chars, - cpu_times_user=cpu_times.user, - cpu_times_system=cpu_times.system, - cpu_times_children_user=cpu_times.children_user, - cpu_times_children_system=cpu_times.children_system, - cpu_times_iowait=cpu_times.iowait, - cpu_affinity=cpu_affinity, - cpu_num=cpu_num, - num_threads=num_threads, - num_voluntary_ctx_switches=num_ctx_switches.voluntary, - num_involuntary_ctx_switches=num_ctx_switches.involuntary, - ) - self.log_sample(resource_sample) - - -if __name__ == "__main__": - import multiprocessing - import time - pid = multiprocessing.current_process().pid - monitor_thread = ResourceMonitoringThread(pid, 1) - monitor_thread.start() - time.sleep(5) - print("Shutdown") - monitor_thread.stop() diff --git a/build/lib/utils/schedulers.py b/build/lib/utils/schedulers.py deleted file mode 100644 index df02e2b..0000000 --- a/build/lib/utils/schedulers.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# - -import math - - -class WarmupCosineSchedule(object): - - def __init__( - self, - optimizer, - warmup_steps, - start_lr, - ref_lr, - T_max, - last_epoch=-1, - final_lr=0. - ): - self.optimizer = optimizer - self.start_lr = start_lr - self.ref_lr = ref_lr - self.final_lr = final_lr - self.warmup_steps = warmup_steps - self.T_max = T_max - warmup_steps - self._step = 0. - - def step(self): - self._step += 1 - if self._step < self.warmup_steps: - progress = float(self._step) / float(max(1, self.warmup_steps)) - new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) - else: - # -- progress after warmup - progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) - new_lr = max(self.final_lr, - self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) - - for group in self.optimizer.param_groups: - group['lr'] = new_lr - - return new_lr - - -class CosineWDSchedule(object): - - def __init__( - self, - optimizer, - ref_wd, - T_max, - final_wd=0. - ): - self.optimizer = optimizer - self.ref_wd = ref_wd - self.final_wd = final_wd - self.T_max = T_max - self._step = 0. - - def step(self): - self._step += 1 - progress = self._step / self.T_max - new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) - - if self.final_wd <= self.ref_wd: - new_wd = max(self.final_wd, new_wd) - else: - new_wd = min(self.final_wd, new_wd) - - for group in self.optimizer.param_groups: - if ('WD_exclude' not in group) or not group['WD_exclude']: - group['weight_decay'] = new_wd - return new_wd diff --git a/build/lib/utils/tensors.py b/build/lib/utils/tensors.py deleted file mode 100644 index 6ae2850..0000000 --- a/build/lib/utils/tensors.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# - -import math - -import torch - -from logging import getLogger - -logger = getLogger() - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): - # type: (Tensor, float, float, float, float) -> Tensor - return _no_grad_trunc_normal_(tensor, mean, std, a, b) - - -def apply_masks(x, masks): - """ - :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] - :param masks: list of tensors containing indices of patches [0,N) to keep - """ - all_x = [] - for m in masks: - mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) - all_x += [torch.gather(x, dim=1, index=mask_keep)] - return torch.cat(all_x, dim=0) - - -def repeat_interleave_batch(x, B, repeat): - N = len(x) // B - x = torch.cat([ - torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) - for i in range(N) - ], dim=0) - return x diff --git a/build/lib/vjepa_encoder/__init__.py b/build/lib/vjepa_encoder/__init__.py index e69de29..e5f6838 100644 --- a/build/lib/vjepa_encoder/__init__.py +++ b/build/lib/vjepa_encoder/__init__.py @@ -0,0 +1,6 @@ +from vjepa_encoder.vision_encoder import JepaEncoder + +__all__ = [ + "JepaEncoder", + +] \ No newline at end of file diff --git a/build/lib/vjepa_encoder/vision_encoder.py b/build/lib/vjepa_encoder/vision_encoder.py index 1f473eb..67f0b7a 100644 --- a/build/lib/vjepa_encoder/vision_encoder.py +++ b/build/lib/vjepa_encoder/vision_encoder.py @@ -29,7 +29,7 @@ import torch.multiprocessing as mp import torch.nn.functional as F # from torch.nn.parallel import DistributedDataParallel -from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.distributed import init_distributed from jepa_src.utils.logging import get_logger from vjepa_encoder.vjepa.utils import init_video_model @@ -44,8 +44,6 @@ torch.manual_seed(_GLOBAL_SEED) torch.backends.cudnn.benchmark = True -from jepa_src.models.vision_transformer import VIT_EMBED_DIMS as JEPA_DIM_SIZE - import logging from jepa_src.utils.logging import get_logger logger = get_logger(force=True) @@ -57,6 +55,17 @@ def __init__(self, args): self.args = args self.encoder, self.predictor = None, None + def save_checkpoint(self, path): + save_dict = { + 'encoder': self.encoder.state_dict(), + } + try: + torch.save(save_dict, path) + logger.info(f'Saved encoder state to {path}') + + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + def preprocess_image(self, input_data: Any): """ Preprocess the input image data. 
@@ -192,7 +201,7 @@ def load_encoder_checkpoint( # -- loading encoder pretrained_dict = checkpoint['encoder'] msg = encoder.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + logger.info(f'loaded pretrained encoder from {r_path} with msg: {msg}') except Exception as e: logger.info(f'Encountered exception when loading checkpoint {e}') diff --git a/demo_jepa_encoder.py b/demo_jepa_encoder.py index 9a8842f..c432cbf 100644 --- a/demo_jepa_encoder.py +++ b/demo_jepa_encoder.py @@ -19,4 +19,6 @@ embedding = encoder.embed_image(x) print(embedding) -print(embedding.shape) \ No newline at end of file +print(embedding.shape) + +encoder.save_checkpoint("./test_jepa_model.tar") \ No newline at end of file diff --git a/jepa_encoder.egg-info/PKG-INFO b/jepa_encoder.egg-info/PKG-INFO deleted file mode 100644 index 6a3951f..0000000 --- a/jepa_encoder.egg-info/PKG-INFO +++ /dev/null @@ -1,17 +0,0 @@ -Metadata-Version: 2.1 -Name: jepa-encoder -Version: 0.0.1 -Summary: JEPA research code. -Requires-Python: >=3.9 -License-File: LICENSE -Requires-Dist: pyyaml -Requires-Dist: numpy -Requires-Dist: opencv-python -Requires-Dist: submitit -Requires-Dist: braceexpand -Requires-Dist: webdataset -Requires-Dist: timm -Requires-Dist: decord -Requires-Dist: pandas -Requires-Dist: einops -Requires-Dist: beartype diff --git a/jepa_encoder.egg-info/SOURCES.txt b/jepa_encoder.egg-info/SOURCES.txt deleted file mode 100644 index 00be8b0..0000000 --- a/jepa_encoder.egg-info/SOURCES.txt +++ /dev/null @@ -1,10 +0,0 @@ -LICENSE -README.md -setup.py -jepa_encoder.egg-info/PKG-INFO -jepa_encoder.egg-info/SOURCES.txt -jepa_encoder.egg-info/dependency_links.txt -jepa_encoder.egg-info/requires.txt -jepa_encoder.egg-info/top_level.txt -vjepa_encoder/__init__.py -vjepa_encoder/vision_encoder.py \ No newline at end of file diff --git a/jepa_encoder.egg-info/dependency_links.txt b/jepa_encoder.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/jepa_encoder.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/jepa_encoder.egg-info/requires.txt b/jepa_encoder.egg-info/requires.txt deleted file mode 100644 index 386919b..0000000 --- a/jepa_encoder.egg-info/requires.txt +++ /dev/null @@ -1,11 +0,0 @@ -pyyaml -numpy -opencv-python -submitit -braceexpand -webdataset -timm -decord -pandas -einops -beartype diff --git a/jepa_encoder.egg-info/top_level.txt b/jepa_encoder.egg-info/top_level.txt deleted file mode 100644 index cca3137..0000000 --- a/jepa_encoder.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -vjepa_encoder diff --git a/setup.py b/setup.py index 5865e1a..c31890f 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ import os from setuptools import setup -VERSION = "0.0.4" +VERSION = "0.1.1" from setuptools import setup, find_packages diff --git a/vjepa_encoder.egg-info/PKG-INFO b/vjepa_encoder.egg-info/PKG-INFO index ee525c7..e7b09bf 100644 --- a/vjepa_encoder.egg-info/PKG-INFO +++ b/vjepa_encoder.egg-info/PKG-INFO @@ -1,11 +1,8 @@ -Metadata-Version: 1.2 +Metadata-Version: 2.1 Name: vjepa-encoder -Version: 0.0.4 +Version: 0.1.1 Summary: JEPA research code. 
-Home-page: UNKNOWN Author: Jonathan Koch Author-email: johnnykoch02@gmail.com -License: UNKNOWN -Description: UNKNOWN -Platform: UNKNOWN Requires-Python: >=3.7 +License-File: LICENSE diff --git a/vjepa_encoder/__init__.py b/vjepa_encoder/__init__.py index e69de29..e5f6838 100644 --- a/vjepa_encoder/__init__.py +++ b/vjepa_encoder/__init__.py @@ -0,0 +1,6 @@ +from vjepa_encoder.vision_encoder import JepaEncoder + +__all__ = [ + "JepaEncoder", + +] \ No newline at end of file diff --git a/vjepa_encoder/vision_encoder.py b/vjepa_encoder/vision_encoder.py index 7d74393..67f0b7a 100644 --- a/vjepa_encoder/vision_encoder.py +++ b/vjepa_encoder/vision_encoder.py @@ -29,7 +29,7 @@ import torch.multiprocessing as mp import torch.nn.functional as F # from torch.nn.parallel import DistributedDataParallel -from jepa_src.utils.distributed import init_distributed, AllReduce +from jepa_src.utils.distributed import init_distributed from jepa_src.utils.logging import get_logger from vjepa_encoder.vjepa.utils import init_video_model @@ -55,6 +55,17 @@ def __init__(self, args): self.args = args self.encoder, self.predictor = None, None + def save_checkpoint(self, path): + save_dict = { + 'encoder': self.encoder.state_dict(), + } + try: + torch.save(save_dict, path) + logger.info(f'Saved encoder state to {path}') + + except Exception as e: + logger.info(f'Encountered exception when saving checkpoint: {e}') + def preprocess_image(self, input_data: Any): """ Preprocess the input image data. @@ -190,7 +201,7 @@ def load_encoder_checkpoint( # -- loading encoder pretrained_dict = checkpoint['encoder'] msg = encoder.load_state_dict(pretrained_dict) - logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') + logger.info(f'loaded pretrained encoder from {r_path} with msg: {msg}') except Exception as e: logger.info(f'Encountered exception when loading checkpoint {e}')
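
Usage sketch for the `save_checkpoint` method added in this patch: the file written by `encoder.save_checkpoint("./test_jepa_model.tar")` in demo_jepa_encoder.py is a plain `torch.save` dictionary with a single `'encoder'` key, so it can be inspected and reloaded with standard PyTorch calls. The snippet below is illustrative only; the checkpoint path and the `'encoder'` key follow the diff above, while the reload into an already-constructed `JepaEncoder` instance (whose construction is not shown in this section) is an assumption, mirroring the `load_encoder_checkpoint` logic.

    import torch

    # Inspect the checkpoint produced by JepaEncoder.save_checkpoint()
    checkpoint = torch.load("./test_jepa_model.tar", map_location="cpu")
    print(list(checkpoint.keys()))  # expected: ['encoder']

    # Assumed reload path: given a JepaEncoder instance `encoder`, the backbone
    # module is stored in `encoder.encoder`, so its weights can be restored the
    # same way load_encoder_checkpoint() does above:
    # msg = encoder.encoder.load_state_dict(checkpoint["encoder"])
    # print(msg)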