Update Replicate to use multiple outputs #12

Open · wants to merge 3 commits into base: master
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
+checkpoints/waveunet/
+.cog
6 changes: 5 additions & 1 deletion README.md
@@ -1,5 +1,5 @@
# Wave-U-Net (Pytorch)
-<a href="https://replicate.ai/f90/wave-u-net-pytorch"><img src="https://img.shields.io/static/v1?label=Replicate&message=Demo and Docker Image&color=darkgreen" height=20></a>
+<a href="https://replicate.com/f90/wave-u-net-pytorch"><img src="https://replicate.com/f90/wave-u-net-pytorch/badge"></a>

Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in Pytorch.

@@ -97,6 +97,10 @@ We provide the default model in a pre-trained form as download so you can separa
Download our pretrained model [here](https://www.dropbox.com/s/r374hce896g4xlj/models.7z?dl=1).
Extract the archive into the ``checkpoints`` subfolder in this repository, so that you have one subfolder for each model (e.g. ``REPO/checkpoints/waveunet``)

+If you have Docker installed, you can run this script to download the weights from [Replicate](https://replicate.com/f90/wave-u-net-pytorch):
+
+    $ script/download-weights

## Run pretrained model

To apply our pretrained model to any of your own songs, simply point to its audio file path using the ``input_path`` parameter:
4 changes: 2 additions & 2 deletions cog.yaml
@@ -1,5 +1,5 @@
build:
-  python_version: "3.6"
+  python_version: "3.7"
gpu: false
python_packages:
- future==0.18.2
@@ -17,4 +17,4 @@ build:
system_packages:
- libsndfile-dev
- ffmpeg
-predict: "cog_predict.py:waveunetPredictor"
+predict: "cog_predict.py:Predictor"
45 changes: 22 additions & 23 deletions cog_predict.py
@@ -1,16 +1,24 @@
-import argparse
 import os
-import cog
 import tempfile
-import zipfile
-from pathlib import Path
+import argparse
+
+from cog import BasePredictor, Input, Path, BaseModel
 
 import data.utils
 import model.utils as model_utils
-from test import predict_song
 from model.waveunet import Waveunet
+from test import predict_song
 
 
+class Output(BaseModel):
+    bass: Path
+    drums: Path
+    other: Path
+    vocals: Path
+
+
-class waveunetPredictor(cog.Predictor):
+class Predictor(BasePredictor):
def setup(self):
"""Init wave u net model"""
parser = argparse.ArgumentParser()
@@ -112,33 +120,24 @@ def setup(self):
         )
 
         if args.cuda:
-            self.model = model_utils.DataParallel(model)
+            self.model = model_utils.DataParallel(self.model)
             print("move model to gpu")
             self.model.cuda()
 
         print("Loading model from checkpoint " + str(args.load_model))
         state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
         print("Step", state["step"])
 
-    @cog.input("input", type=Path, help="audio mixture path")
-    def predict(self, input):
+    def predict(self, mix: Path = Input(description="audio mixture path")) -> Output:
         """Separate tracks from input mixture audio"""
 
-        out_path = Path(tempfile.mkdtemp())
-        zip_path = Path(tempfile.mkdtemp()) / "output.zip"
+        tmpdir = Path(tempfile.mkdtemp())
 
-        preds = predict_song(self.args, input, self.model)
-
-        out_names = []
+        preds = predict_song(self.args, mix, self.model)
+        output = {}
         for inst in preds.keys():
-            temp_n = os.path.join(
-                str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
-            )
-            data.utils.write_wav(temp_n, preds[inst], self.args.sr)
-            out_names.append(temp_n)
-
-        with zipfile.ZipFile(str(zip_path), "w") as zf:
-            for i in out_names:
-                zf.write(str(i))
+            path = tmpdir / (inst + ".wav")
+            data.utils.write_wav(path, preds[inst], self.args.sr)
+            output[inst] = path
 
-        return zip_path
+        return Output(**output)
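The gist of this diff: the predictor now writes one WAV per stem and returns them as named fields on a Cog `Output` model, rather than bundling everything into a single zip. A minimal stdlib-only sketch of that pattern, using a plain dataclass in place of Cog's `BaseModel` and raw byte writes in place of `data.utils.write_wav` (the `collect_stems` helper is hypothetical, not part of the PR):

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Output:
    # One named field per separated source, mirroring the PR's Output model.
    bass: Path
    drums: Path
    other: Path
    vocals: Path


def collect_stems(preds: dict) -> Output:
    # Write each predicted stem to its own file in a fresh temp dir,
    # then expose the files as named fields instead of a single archive.
    tmpdir = Path(tempfile.mkdtemp())
    paths = {}
    for inst, samples in preds.items():
        path = tmpdir / (inst + ".wav")
        path.write_bytes(samples)  # stand-in for data.utils.write_wav
        paths[inst] = path
    return Output(**paths)


stems = collect_stems({"bass": b"", "drums": b"", "other": b"", "vocals": b""})
print(stems.vocals.name)  # vocals.wav
```

Returning named fields means Replicate can surface each stem as a separate downloadable output, whereas the old zip forced clients to unpack the archive themselves.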
3 changes: 3 additions & 0 deletions script/download-weights
@@ -0,0 +1,3 @@
+#!/bin/bash
+id=$(docker create r8.im/f90/wave-u-net-pytorch)
+docker cp $id:/src/checkpoints ./