PirouNet

Official PyTorch implementation of the paper “PirouNet: Creating Dance through Artist-Centric Deep Learning”

[Pre-print] [Paper], presented at [EAI ArtsIT 2022], where it received the Best Paper Award.

[Summary], presented at the [NeurIPS 2022 Workshop on Machine Learning for Creativity and Design]

Overview of PirouNet's LSTM+VAE architecture.

PirouNet is a semi-supervised conditional recurrent variational autoencoder. This code trains and evaluates the model. Labels must be created separately before training; for this, we provide a dance labeling web application that can be customized to the user's labeling needs.
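The semi-supervised objective can be illustrated with a minimal sketch in plain Python. It assumes the "M2" formulation of Kingma et al. (2014, "Semi-Supervised Learning with Deep Generative Models"), a common basis for semi-supervised conditional VAEs; PirouNet's actual loss terms, weights, and tensor code live in the repository and may differ. Labeled sequences use the negative ELBO directly, while unlabeled sequences average it over the classifier's label probabilities q(y|x) and subtract their entropy. Scalars stand in for PyTorch tensors, and every function name here is hypothetical.

```python
import math


def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def labeled_loss(recon_nll, mu, logvar, log_prior_y):
    """Negative ELBO for a labeled sequence:
    -log p(x|y,z) + KL(q(z|x,y) || p(z)) - log p(y)."""
    return recon_nll + gaussian_kl(mu, logvar) - log_prior_y


def unlabeled_loss(per_label_losses, class_probs):
    """Loss for an unlabeled sequence: the labeled loss for each candidate label y,
    weighted by the classifier probability q(y|x), minus the entropy of q(y|x)."""
    expected = sum(q * loss for q, loss in zip(class_probs, per_label_losses))
    entropy = -sum(q * math.log(q) for q in class_probs if q > 0)
    return expected - entropy


def total_loss(labeled, unlabeled, classifier_nll, alpha):
    """Sum of both bounds plus an alpha-weighted classification term on labeled data."""
    return sum(labeled) + sum(unlabeled) + alpha * sum(classifier_nll)
```

In a PyTorch training loop the same expressions would be computed on batched tensors; plain floats simply make each term easy to check by hand.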

🌎 BibTeX

If this code is useful to your research, please cite:

@InProceedings{papillon2023pirounet,
author="Papillon, Mathilde
and Pettee, Mariel
and Miolane, Nina",
editor="Brooks, Anthony L.",
title="PirouNet: Creating Dance Through Artist-Centric Deep Learning",
booktitle="ArtsIT, Interactivity and Game Creation",
year="2023",
publisher="Springer Nature Switzerland",
address="Cham",
pages="447--465",
isbn="978-3-031-28993-4"
}

Conditionally created dance sequences: Animated dance sequences conditionally created by PirouNet.

Reconstructed dance sequence: PirouNet reconstructs input dance.

[Video: making_dance_with_intention_03.mp4]

🏡 Installation

This code runs on Python 3.8. We recommend Anaconda for easy installation. After cloning this repository, create and activate the necessary conda environment by running:

cd pirounet
conda env create -f environment.yml
conda activate choreo
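After activating the environment, a quick sanity check can confirm that the expected interpreter and environment are in use. The sketch below is a hypothetical helper, not part of the repository; it relies on the CONDA_DEFAULT_ENV variable that conda activate sets, and on the Python 3.8 and choreo names stated above.

```python
import os
import sys

EXPECTED_ENV = "choreo"    # name used in `conda activate choreo` above
EXPECTED_PYTHON = (3, 8)   # version stated in this README


def environment_problems(environ=os.environ, version_info=sys.version_info):
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    active = environ.get("CONDA_DEFAULT_ENV")  # set by `conda activate`
    if active != EXPECTED_ENV:
        problems.append(f"active conda env is {active!r}, expected {EXPECTED_ENV!r}")
    found = tuple(version_info[:2])
    if found != EXPECTED_PYTHON:
        want = f"{EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]}"
        problems.append(f"Python {found[0]}.{found[1]} found, expected {want}")
    return problems


if __name__ == "__main__":
    # Print each problem, or a confirmation if there are none.
    for problem in environment_problems() or ["environment looks correct"]:
        print(problem)
```

Save it under any name (for example, a hypothetical check_env.py) and run it with python inside the activated environment.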

🚀 Training

To train a new model, follow the steps below. To load a saved model instead, see the “Load a saved model” section.

1. Set up wandb logging.

Weights & Biases (wandb) logs performance metrics during training, as well as animation artifacts. To use it, create an account, then run:

wandb login

to sign in to your account.

2. Specify hyperparameters in default_config.py.

For wandb: set “entity” to your wandb account and “project_name” to the title of the project. “run_name” sets the name of this specific run within the project.

If specified, “load_from_checkpoint” names the saved model to load. Leave it as “None” to train a new model.

Other hyperparameters are organized by category: hardware (choice of CUDA device), training, input data, LSTM VAE architecture, and classifier architecture.
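A minimal sketch of how the wandb fields described above relate to wandb's Python API. The config object is a hypothetical stand-in for default_config.py (the real file may define plain module-level variables and many more hyperparameters), and the placeholder values are not the repository's defaults. “project_name” maps to the project argument of wandb.init and “run_name” to its name argument; main.py presumably performs an equivalent call internally, so editing the config should be all that is needed.

```python
from types import SimpleNamespace

# Hypothetical stand-in for default_config.py. Only the field names quoted in
# this README are used; every value is a placeholder.
config = SimpleNamespace(
    entity="your-wandb-account",  # your wandb account
    project_name="pirounet",      # title of the wandb project
    run_name="my_first_run",      # name of this run within the project
    load_from_checkpoint=None,    # None: train a new model from scratch
)


def wandb_init_kwargs(cfg):
    """Translate the config fields into keyword arguments for wandb.init()."""
    return {"entity": cfg.entity, "project": cfg.project_name, "name": cfg.run_name}


def trains_from_scratch(cfg):
    """True when no checkpoint is requested (accepts None or the string "None")."""
    return cfg.load_from_checkpoint in (None, "None")


# With wandb installed and logged in, a run could then be started with:
#   import wandb
#   run = wandb.init(**wandb_init_kwargs(config))
```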

3. Train!

For a single run, use the command:

python main.py

For a hyperparameter sweep (multiple runs), follow wandb’s Quickstart guide for sweeps and run the resulting wandb sweep command.
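If you prefer to start a sweep from Python rather than the CLI, wandb also accepts the sweep configuration as a dictionary with the same keys as a sweep YAML file (method, metric, parameters). The sketch below builds one; the metric name and hyperparameter names are hypothetical placeholders, and how main.py reads swept values (for example through wandb.config) is an assumption to check against the code.

```python
VALID_METHODS = {"grid", "random", "bayes"}  # search strategies wandb sweeps support


def make_sweep_config(metric_name, parameter_values, method="random", goal="minimize"):
    """Build a wandb sweep configuration as a plain dict (same keys as a sweep YAML file)."""
    if method not in VALID_METHODS:
        raise ValueError(f"method must be one of {sorted(VALID_METHODS)}, got {method!r}")
    if goal not in {"minimize", "maximize"}:
        raise ValueError(f"goal must be 'minimize' or 'maximize', got {goal!r}")
    return {
        "method": method,
        "metric": {"name": metric_name, "goal": goal},
        "parameters": {name: {"values": list(vals)} for name, vals in parameter_values.items()},
    }


# Hypothetical metric and hyperparameter names; replace them with the keys that
# main.py actually logs and reads from default_config.py.
sweep_config = make_sweep_config(
    "valid_loss",
    {"learning_rate": [1e-4, 3e-4], "batch_size": [40, 80]},
)

# With wandb installed:
#   import wandb
#   sweep_id = wandb.sweep(sweep_config, project="pirounet")
#   wandb.agent(sweep_id, function=train, count=4)  # `train`: your training entry point
```

A random search is used by default because it needs no assumptions about which hyperparameters matter most; switch method to "grid" for small, exhaustive searches.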

📕 Load a saved model

There are three basic types of models to load:

  • $\text{PirouNet}_{\text{watch}}$. Copy the contents of the saved_models/pirounet_watch_config.py file into default_config.py.

  • $\text{PirouNet}_{\text{dance}}$. Copy the contents of the saved_models/pirounet_dance_config.py file into default_config.py.

  • Your new model. In default_config.py, set “load_from_checkpoint” to the name and epoch corresponding to your new model: “checkpoint_{run_name}_epoch{epoch}”. Make sure the rest of the hyperparameters match those you used during training.
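The configuration changes above can be scripted. This is a hypothetical helper sketch, not part of the repository: it assumes saved_models/ and default_config.py sit in the same directory (pass that directory as code_dir), backs up the current default_config.py before overwriting it, and builds the load_from_checkpoint value from the naming pattern quoted above.

```python
import shutil
from pathlib import Path

# Preset config files named in this README, assumed (hypothetically) to sit in
# the same directory as default_config.py.
PRESETS = {
    "watch": "saved_models/pirounet_watch_config.py",
    "dance": "saved_models/pirounet_dance_config.py",
}


def use_preset_config(preset, code_dir="."):
    """Replace default_config.py with a preset, keeping the old file as default_config.py.bak."""
    code_dir = Path(code_dir)
    target = code_dir / "default_config.py"
    if target.exists():
        shutil.copyfile(target, code_dir / "default_config.py.bak")
    shutil.copyfile(code_dir / PRESETS[preset], target)
    return target


def checkpoint_name(run_name, epoch):
    """Value for load_from_checkpoint, following checkpoint_{run_name}_epoch{epoch}."""
    return f"checkpoint_{run_name}_epoch{epoch}"
```

For example, use_preset_config("dance") switches to $\text{PirouNet}_{\text{dance}}$, while checkpoint_name("my_run", 100) gives the string to paste into “load_from_checkpoint” for your own model.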

Once this is done, there are two options:

  1. Continue training using this saved model as a starting point. See the “Training” section.
  2. Evaluate this saved model.

🕺 Evaluation

  1. Follow the “Load a saved model” instructions to configure default_config.py.
  2. Specify the evaluation parameters in eval_config.py. Set “plot_recognition_accuracy” to True only after a human labeler has blindly labeled PirouNet-generated dance sequences (using generate_for_blind_labeling and the web labeling app) and exported the CSV of labels to the pirounet/pirounet/data directory.
  3. Unzip the pre-saved classifier model in saved_models/classifier.
  4. Run the command:
python main_eval.py

This produces a subfolder in pirounet/results containing all the qualitative and quantitative metrics reported in our paper, as well as extra plots of the latent space and its entanglement. Two examples of the qualitative generation metrics are provided below.
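The file handling around evaluation can be sketched with the standard library. These helpers are hypothetical and not part of the repository: unzip_classifier assumes the pre-saved classifier ships as one or more .zip archives inside saved_models/classifier, count_blind_labels assumes the exported labels CSV has a single header row (its file name is up to you), and latest_results_dir picks the newest subfolder of pirounet/results by modification time, since the subfolder naming scheme is not described here. Paths are interpreted relative to the working directory; adjust them to your checkout.

```python
import csv
import zipfile
from pathlib import Path


def unzip_classifier(classifier_dir="saved_models/classifier"):
    """Extract every .zip archive in classifier_dir into that same directory.

    Returns the names of the extracted members.
    """
    directory = Path(classifier_dir)
    extracted = []
    for archive in sorted(directory.glob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(directory)
            extracted.extend(zf.namelist())
    return extracted


def count_blind_labels(csv_path):
    """Number of data rows in an exported labels CSV, or 0 if the file is missing.

    Assumes exactly one header row.
    """
    path = Path(csv_path)
    if not path.is_file():
        return 0
    with path.open(newline="") as f:
        n_rows = sum(1 for _ in csv.reader(f))
    return max(n_rows - 1, 0)


def latest_results_dir(results_root="pirounet/results"):
    """Most recently modified subfolder of results_root, or None if there is none."""
    root = Path(results_root)
    if not root.is_dir():
        return None
    subdirs = [p for p in root.iterdir() if p.is_dir()]
    return max(subdirs, key=lambda p: p.stat().st_mtime, default=None)


def list_outputs(folder):
    """Sorted relative paths of every file under folder (plots, metric files, ...)."""
    folder = Path(folder)
    return sorted(p.relative_to(folder).as_posix() for p in folder.rglob("*") if p.is_file())


# Hypothetical usage: only enable plot_recognition_accuracy in eval_config.py
# once your exported labels file actually contains rows.
#   enable = count_blind_labels("pirounet/pirounet/data/<your_labels>.csv") > 0
```

Checking the row count before setting “plot_recognition_accuracy” to True guards against running the recognition-accuracy plots before any blind labels exist.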

💃 Authors

Mathilde Papillon

Mariel Pettee

Nina Miolane