diff --git a/Package.swift b/Package.swift
index 6ab4afaa..fa814d7f 100644
--- a/Package.swift
+++ b/Package.swift
@@ -6,9 +6,9 @@ import PackageDescription
let package = Package(
name: "stable-diffusion",
platforms: [
- .macOS(.v13),
- .iOS(.v16),
- ],
+ .macOS(.v11),
+ .iOS(.v14),
+ ],
products: [
.library(
name: "StableDiffusion",
diff --git a/README.md b/README.md
index 8e7e60e2..f108f360 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ This repository comprises:
- `python_coreml_stable_diffusion`, a Python package for converting PyTorch models to Core ML format and performing image generation with Hugging Face [diffusers](https://github.com/huggingface/diffusers) in Python
- `StableDiffusion`, a Swift package that developers can add to their Xcode projects as a dependency to deploy image generation capabilities in their apps. The Swift package relies on the Core ML model files generated by `python_coreml_stable_diffusion`
-If you run into issues during installation or runtime, please refer to the [FAQ](#FAQ) section.
+If you run into issues during installation or runtime, please refer to the [FAQ](#faq) section. Please check the [System Requirements](#system-requirements) section before getting started.
## Example Results
@@ -25,6 +25,80 @@ M2 MacBook Air 8GB Latency (s) | 18 | 23 | 23 |
Please see [Important Notes on Performance Benchmarks](#important-notes-on-performance-benchmarks) section for details.
+## System Requirements
+
+The following is recommended to use all the functionality in this repository:
+
+Python | macOS | Xcode | iPadOS, iOS |
+:------:|:------:|:------:|:------:|
+3.8 | 13.1 | 14.2 | 16.2 |
+
+## Using Ready-made Core ML Models from Hugging Face Hub
+
+<details>
+  <summary> Click to expand </summary>
+
+🤗 Hugging Face ran the [conversion procedure](#converting-models-to-coreml) on the following models and made the Core ML weights publicly available on the Hub. If you would like to convert a version of Stable Diffusion that is not already available on the Hub, please refer to the [Converting Models to Core ML](#converting-models-to-core-ml) section.
+
+* [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/apple/coreml-stable-diffusion-v1-4)
+* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/apple/coreml-stable-diffusion-v1-5)
+* [`stabilityai/stable-diffusion-2-base`](https://huggingface.co/apple/coreml-stable-diffusion-2-base)
+
+If you want to use any of those models, you may download the weights and proceed to [generate images with Python](#image-generation-with-python) or [Swift](#image-generation-with-swift).
+
+There are several variants in each model repository. You may clone the entire repos using `git` and `git lfs` to download all variants, or selectively download only the ones you need.
+
+To clone the repos using `git`, please follow this process:
+
+**Step 1:** Install the `git lfs` extension for your system.
+
+`git lfs` stores large files outside the main git repo and downloads them from the appropriate server after you clone or check out. It is available in most package managers; check [the installation page](https://git-lfs.com) for details.
+
+**Step 2:** Enable `git lfs` by running this command once:
+
+```bash
+git lfs install
+```
+
+**Step 3:** Use `git clone` to download a copy of the repo that includes all model variants. For Stable Diffusion version 1.4, you'd issue the following command in your terminal:
+
+```bash
+git clone https://huggingface.co/apple/coreml-stable-diffusion-v1-4
+```
+
+If you prefer to download specific variants instead of cloning the repos, you can use the `huggingface_hub` Python library. For example, to do generation in Python using the `ORIGINAL` attention implementation (read [this section](#converting-models-to-core-ml) for details), you could use the following helper code:
+
+```Python
+from huggingface_hub import snapshot_download
+from huggingface_hub.file_download import repo_folder_name
+from pathlib import Path
+import shutil
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/packages"
+
+def download_model(repo_id, variant, output_dir):
+ destination = Path(output_dir) / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+ if destination.exists():
+ raise Exception(f"Model already exists at {destination}")
+
+ # Download and copy without symlinks
+ downloaded = snapshot_download(repo_id, allow_patterns=f"{variant}/*", cache_dir=output_dir)
+ downloaded_bundle = Path(downloaded) / variant
+ shutil.copytree(downloaded_bundle, destination)
+
+ # Remove all downloaded files
+ cache_folder = Path(output_dir) / repo_folder_name(repo_id=repo_id, repo_type="model")
+ shutil.rmtree(cache_folder)
+ return destination
+
+model_path = download_model(repo_id, variant, output_dir="./models")
+print(f"Model downloaded at {model_path}")
+```
+
+`model_path` would be the path in your local filesystem where the checkpoint was saved. Please refer to [this post](https://huggingface.co/blog/diffusers-coreml) for additional details.
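+
+For example, after downloading the `original/packages` variant shown above, you could point the Python generation CLI at that folder. The flags below follow the [Image Generation with Python](#image-generation-with-python) section; the paths and compute unit are illustrative, so adjust them for your setup:
+
+```bash
+# ORIGINAL attention weights pair well with CPU_AND_GPU (see Converting Models to Core ML)
+python -m python_coreml_stable_diffusion.pipeline \
+    --prompt "a photo of an astronaut riding a horse on mars" \
+    -i ./models/coreml-stable-diffusion-v1-4_original_packages \
+    -o ./output_images \
+    --compute-unit CPU_AND_GPU \
+    --seed 93
+```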
+
+</details>
## Converting Models to Core ML
@@ -50,7 +124,7 @@ pip install -e .
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-text-encoder --convert-vae-decoder --convert-safety-checker -o
```
-**WARNING:** This command will download several GB worth of PyTorch checkpoints from Hugging Face.
+**WARNING:** This command will download several GB worth of PyTorch checkpoints from Hugging Face. Please ensure that you are on Wi-Fi and have enough disk space.
This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful execution, the 4 neural network models that comprise Stable Diffusion will have been converted from PyTorch to Core ML (`.mlpackage`) and saved into the specified ``. Some additional notable arguments:
@@ -59,9 +133,9 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi
- `--bundle-resources-for-swift-cli`: Compiles all 4 models and bundles them along with necessary resources for text tokenization into `/Resources` which should be provided as input to the Swift package. This flag is not necessary for the diffusers-based Python pipeline.
-- `--chunk-unet`: Splits the Unet model in two approximately equal chunks (each with less than 1GB of weights) for mobile-friendly deployment. This is **required** for ANE deployment on iOS and iPadOS. This is not required for macOS. Swift CLI is able to consume both the chunked and regular versions of the Unet model but prioritizes the former. Note that chunked unet is not compatible with the Python pipeline because Python pipeline is intended for macOS only. Chunking is for on-device deployment with Swift only.
+- `--chunk-unet`: Splits the Unet model into two approximately equal chunks (each with less than 1GB of weights) for mobile-friendly deployment. This is **required** for Neural Engine deployment on iOS and iPadOS. This is not required for macOS. The Swift CLI is able to consume both the chunked and regular versions of the Unet model but prioritizes the former. Note that the chunked Unet is not compatible with the Python pipeline because the Python pipeline is intended for macOS only. Chunking is for on-device deployment with Swift only.
-- `--attention-implementation`: Defaults to `SPLIT_EINSUM` which is the implementation described in [Deploying Transformers on the Apple Neural Engine](https://machinelearning.apple.com/research/neural-engine-transformers). `--attention-implementation ORIGINAL` will switch to an alternative that should be used for non-ANE deployment. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
+- `--attention-implementation`: Defaults to `SPLIT_EINSUM` which is the implementation described in [Deploying Transformers on the Apple Neural Engine](https://machinelearning.apple.com/research/neural-engine-transformers). `--attention-implementation ORIGINAL` will switch to an alternative that should be used for CPU or GPU deployment. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
- `--check-output-correctness`: Compares original PyTorch model's outputs to final Core ML model's outputs. This flag increases RAM consumption significantly so it is recommended only for debugging purposes.
@@ -86,27 +160,30 @@ Please refer to the help menu for all available arguments: `python -m python_cor
-## Image Generation with Swift
+## <a name="image-gen-swift"></a> Image Generation with Swift
<details>
  <summary> Click to expand </summary>
### System Requirements
-Building the Swift projects require:
-- macOS 13 or newer
-- Xcode 14.1 or newer with command line tools installed. Please check [developer.apple.com](https://developer.apple.com/download/all/?q=xcode) for the latest version.
-- Core ML models and tokenization resources. Please see `--bundle-resources-for-swift-cli` from the [Converting Models to Core ML](#converting-models-to-coreml) section above
-
-If deploying this model to:
-- iPhone
- - iOS 16.2 or newer
- - iPhone 12 or newer
-- iPad
- - iPadOS 16.2 or newer
- - M1 or newer
-- Mac
- - macOS 13.1 or newer
- - M1 or newer
+
+**Building** (recommended):
+
+- Xcode 14.2
+- Command Line Tools for Xcode 14.2
+
+Check [developer.apple.com](https://developer.apple.com/download/all/?q=xcode) for the latest versions.
+
+**Running** (minimum):
+
+| Mac | iPad\* | iPhone\* |
+|:------:|:------:|:------:|
+| macOS 13.1 | iPadOS 16.2 | iOS 16.2 |
+| M1 | M1 | iPhone 12 Pro |
+
+You will also need the resources generated by the `--bundle-resources-for-swift-cli` option described in [Converting Models to Core ML](#converting-models-to-coreml) (see the example below).
+
+\* Please see [FAQ](#faq) [Q6](#q-mobile-app) regarding deploying on iPad and iPhone.
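+
+For reference, this is a sketch of the conversion command from that section with the flag appended (`<output-mlpackages-directory>` is a placeholder for your output directory):
+
+```bash
+python -m python_coreml_stable_diffusion.torch2coreml \
+    --convert-unet --convert-text-encoder --convert-vae-decoder --convert-safety-checker \
+    --bundle-resources-for-swift-cli \
+    -o <output-mlpackages-directory>
+```
+
+The compiled models and tokenization resources will then be available in `<output-mlpackages-directory>/Resources`.
+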
### Example CLI Usage
```shell
@@ -123,8 +200,10 @@ Please use the `--help` flag to learn about batched generation and more.
import StableDiffusion
...
let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)
+try pipeline.loadResources()
let image = try pipeline.generateImages(prompt: prompt, seed: seed).first
```
+On iOS, the `reduceMemory` option should be set to `true` when constructing `StableDiffusionPipeline`, as shown in the sketch below.
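+
+A minimal sketch of a reduced-memory setup (the `resourceURL`, `prompt`, and `seed` values are placeholders; `.cpuAndNeuralEngine` is the compute unit recommended for iOS and iPadOS in the [FAQ](#faq)):
+
+```swift
+let config = MLModelConfiguration()
+config.computeUnits = .cpuAndNeuralEngine
+
+let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
+                                           configuration: config,
+                                           reduceMemory: true)
+try pipeline.loadResources()
+let image = try pipeline.generateImages(prompt: prompt, seed: seed).first
+```
+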
### Swift Package Details
@@ -149,6 +228,15 @@ Note that the chunked version of Unet is checked for first. Only if it is not pr
+## Example Swift App
+
+<details>
+  <summary> Click to expand </summary>
+
+🤗 Hugging Face created an [open-source demo app](https://github.com/huggingface/swift-coreml-diffusers) on top of this library. It's written in native Swift and SwiftUI, and runs on macOS, iOS and iPadOS. You can use the code as a starting point for your app, or to see how to integrate this library into your own projects.
+
+</details>
+
## Performance Benchmark
@@ -184,7 +272,7 @@ Please see [Important Notes on Performance Benchmarks](#important-notes-on-perfo
- The image generation procedure follows the standard configuration: 50 inference steps, 512x512 output image resolution, 77 text token sequence length, classifier-free guidance (batch size of 2 for unet).
- The actual prompt length does not impact performance because the Core ML model is converted with a static shape that computes the forward pass for all of the 77 elements (`tokenizer.model_max_length`) in the text token sequence regardless of the actual length of the input text.
- Pipelining across the 4 models is not optimized and these performance numbers are subject to variance under increased system load from other applications. Given these factors, we do not report sub-second variance in latency.
-- Weights and activations are in float16 precision for both the GPU and the ANE.
+- Weights and activations are in float16 precision for both the GPU and the Neural Engine.
- The Swift CLI program consumes a peak memory of approximately 2.6GB (without the safety checker), 2.1GB of which is model weights in float16 precision. We applied [8-bit weight quantization](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-affine-quantization) to reduce peak memory consumption by approximately 1GB. However, we observed that it had an adverse effect on generated image quality and we rolled it back. We encourage developers to experiment with other advanced weight compression techniques such as [palettization](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-a-lookup-table) and/or [pruning](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-sparse-representation) which may yield better results.
- In the [benchmark table](#performance-benchmark), we report the best performing `--compute-unit` and `--attention-implementation` values per device. The former does not modify the Core ML model and can be applied during runtime. The latter modifies the Core ML model. Note that the best performing compute unit is model version and hardware-specific.
@@ -208,7 +296,7 @@ Differences may be less or more pronounced for different inputs. Please see the
-## FAQ
+## <a name="faq"></a> FAQ
<details>
  <summary> Click to expand </summary>
@@ -228,7 +316,7 @@ Differences may be less or more pronounced for different inputs. Please see the
- Q3: My Mac has 8GB RAM and I am converting models to Core ML using the example command. The process is geting killed because of memory issues. How do I fix this issue?
+ Q3: My Mac has 8GB RAM and I am converting models to Core ML using the example command. The process is getting killed because of memory issues. How do I fix this issue?
A3: In order to minimize the memory impact of the model conversion process, please execute the following command instead:
@@ -257,12 +345,22 @@ python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --chunk-une
Q5: Every time I generate an image using the Python pipeline, loading all the Core ML models takes 2-3 minutes. Is this expected?
A5: Yes and using the Swift library reduces this to just a few seconds. The reason is that `coremltools` loads Core ML models (`.mlpackage`) and each model is compiled to be run on the requested compute unit during load time. Because of the size and number of operations of the unet model, it takes around 2-3 minutes to compile it for Neural Engine execution. Other models should take at most a few seconds. Note that `coremltools` does not cache the compiled model for later loads so each load takes equally long. In order to benefit from compilation caching, `StableDiffusion` Swift package by default relies on compiled Core ML models (`.mlmodelc`) which will be compiled down for the requested compute unit upon first load but then the cache will be reused on subsequent loads until it is purged due to lack of use.
+
+If you intend to use the Python pipeline in an application, we recommend initializing the pipeline once so that the load time is only incurred once. Afterwards, generating images using different prompts and random seeds will not incur the load time for the current session of your application.
+
+
- Q6: I want to deploy `StableDiffusion`, the Swift package, in my mobile app. What should I be aware of?"
+ Q6: I want to deploy `StableDiffusion`, the Swift package, in my mobile app. What should I be aware of?
- A6: [This section](#swift-requirements) describes the minimum SDK and OS versions as well as the device models supported by this package. In addition to these requirements, for best practice, we recommend testing the package on the device with the least amount of RAM available among your deployment targets. This is due to the fact that `StableDiffusion` consumes approximately 2.6GB of peak memory during runtime while using `.cpuAndNeuralEngine` (the Swift equivalent of `coremltools.ComputeUnit.CPU_AND_NE`). Other compute units may have a higher peak memory consumption so `.cpuAndNeuralEngine` is recommended for iOS and iPadOS deployment (Please refer to this [section](#swift-requirements) for minimum device model requirements). If your app crashes during image generation, please try adding the [Increased Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_increased-memory-limit) capability to your Xcode project which should significantly increase your app's memory limit.
+ A6: The [Image Generation with Swift](#image-gen-swift) section describes the minimum SDK and OS versions as well as the device models supported by this package. We recommend carefully testing the package on the device with the least amount of RAM available among your deployment targets.
+
+The image generation process in `StableDiffusion` can yield over 2 GB of peak memory during runtime depending on the compute units selected. On iPadOS, we recommend using `.cpuAndNeuralEngine` in your configuration and the `reduceMemory` option when constructing a `StableDiffusionPipeline` to minimize memory pressure.
+
+If your app crashes during image generation, consider adding the [Increased Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_increased-memory-limit) capability to inform the system that some of your app’s core features may perform better by exceeding the default app memory limit on supported devices.
+
+On iOS, depending on the iPhone model, Stable Diffusion model versions, selected compute units, system load and design of your app, this may still not be sufficient to keep your app's peak memory under the limit. Please remember that because the device shares memory between apps and iOS processes, one app using too much memory can compromise the user experience across the whole device.
@@ -291,7 +389,7 @@ python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --chunk-une
4. Weights and Activations Data Type
- When quantizing models from float32 to lower-precision data types such as float16, the generated images are [known to vary slightly](https://lambdalabs.com/blog/inference-benchmark-stable-diffusion) in semantics even when using the same PyTorch model. Core ML models generated by coremltools have float16 weights and activations by default [unless explicitly overriden](https://github.com/apple/coremltools/blob/main/coremltools/converters/_converters_entry.py#L256). This is not expected to be a major source of difference.
+ When quantizing models from float32 to lower-precision data types such as float16, the generated images are [known to vary slightly](https://lambdalabs.com/blog/inference-benchmark-stable-diffusion) in semantics even when using the same PyTorch model. Core ML models generated by coremltools have float16 weights and activations by default [unless explicitly overridden](https://github.com/apple/coremltools/blob/main/coremltools/converters/_converters_entry.py#L256). This is not expected to be a major source of difference.
@@ -302,4 +400,26 @@ python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --chunk-une
+
+ Q10: `Could not initialize NNPACK! Reason: Unsupported hardware`
+
+ A10: This warning is safe to ignore in the context of this repository.
+
+
+
+
+ Q11: `TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect`
+
+ A11: This warning is safe to ignore in the context of this repository.
+
+
+
+ Q12: `UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown`
+
+ A12: If this warning is printed right after `zsh: killed python -m python_coreml_stable_diffusion.torch2coreml ...`, then it is highly likely that your Mac has run out of memory while converting models to Core ML. Please see [Q3](#low-mem-conversion) from above for the solution.
+
+
+
+
+
diff --git a/python_coreml_stable_diffusion/pipeline.py b/python_coreml_stable_diffusion/pipeline.py
index 65d0deaa..b55e47cf 100644
--- a/python_coreml_stable_diffusion/pipeline.py
+++ b/python_coreml_stable_diffusion/pipeline.py
@@ -38,7 +38,7 @@
import time
import torch # Only used for `torch.from_tensor` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer
-from typing import Union, Optional
+from typing import List, Optional, Union
class CoreMLStableDiffusionPipeline(DiffusionPipeline):
diff --git a/python_coreml_stable_diffusion/torch2coreml.py b/python_coreml_stable_diffusion/torch2coreml.py
index f079d00f..6d6c2fad 100644
--- a/python_coreml_stable_diffusion/torch2coreml.py
+++ b/python_coreml_stable_diffusion/torch2coreml.py
@@ -576,7 +576,7 @@ def convert_unet(pipe, args):
# Set the output descriptions
coreml_unet.output_description["noise_pred"] = \
"Same shape and dtype as the `sample` input. " \
- "The predicted noise to faciliate the reverse diffusion (denoising) process"
+ "The predicted noise to facilitate the reverse diffusion (denoising) process"
_save_mlpackage(coreml_unet, out_path)
logger.info(f"Saved unet into {out_path}")
diff --git a/setup.py b/setup.py
index 88ecdea0..69c0e393 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,7 @@
"torch",
"transformers",
"scipy",
+ "numpy<1.24",
],
packages=find_packages(),
classifiers=[
diff --git a/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift b/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift
new file mode 100644
index 00000000..704472a8
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift
@@ -0,0 +1,182 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.
+
+import Accelerate
+import CoreML
+
+/// A scheduler used to compute a de-noised image
+///
+/// This implementation matches:
+/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
+///
+/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
+/// Limitations:
+/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver).
+/// - Second order only.
+/// - Assumes the model predicts epsilon.
+/// - No dynamic thresholding.
+/// - `midpoint` solver algorithm.
+@available(iOS 16.2, macOS 13.1, *)
+public final class DPMSolverMultistepScheduler: Scheduler {
+ public let trainStepCount: Int
+ public let inferenceStepCount: Int
+ public let betas: [Float]
+ public let alphas: [Float]
+ public let alphasCumProd: [Float]
+ public let timeSteps: [Int]
+
+ public let alpha_t: [Float]
+ public let sigma_t: [Float]
+ public let lambda_t: [Float]
+
+ public let solverOrder = 2
+ private(set) var lowerOrderStepped = 0
+
+ /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
+ /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
+ public let useLowerOrderFinal = true
+
+ // Stores solverOrder (2) items
+    private(set) var modelOutputs: [MLShapedArray<Float32>] = []
+
+ /// Create a scheduler that uses a second order DPM-Solver++ algorithm.
+ ///
+ /// - Parameters:
+ /// - stepCount: Number of inference steps to schedule
+ /// - trainStepCount: Number of training diffusion steps
+ /// - betaSchedule: Method to schedule betas from betaStart to betaEnd
+ /// - betaStart: The starting value of beta for inference
+ /// - betaEnd: The end value for beta for inference
+ /// - Returns: A scheduler ready for its first step
+ public init(
+ stepCount: Int = 50,
+ trainStepCount: Int = 1000,
+ betaSchedule: BetaSchedule = .scaledLinear,
+ betaStart: Float = 0.00085,
+ betaEnd: Float = 0.012
+ ) {
+ self.trainStepCount = trainStepCount
+ self.inferenceStepCount = stepCount
+
+ switch betaSchedule {
+ case .linear:
+ self.betas = linspace(betaStart, betaEnd, trainStepCount)
+ case .scaledLinear:
+ self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
+ }
+
+ self.alphas = betas.map({ 1.0 - $0 })
+ var alphasCumProd = self.alphas
+        for i in 1..<alphasCumProd.count {
+            alphasCumProd[i] *= alphasCumProd[i - 1]
+        }
+        self.alphasCumProd = alphasCumProd
+
+        // DPM-Solver++ (VP) noise schedule terms:
+        // alpha_t = sqrt(alphas_cumprod), sigma_t = sqrt(1 - alphas_cumprod), lambda_t = log(alpha_t) - log(sigma_t)
+        self.alpha_t = vForce.sqrt(self.alphasCumProd)
+        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
+        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
+
+        self.timeSteps = linspace(0, Float(self.trainStepCount - 1), stepCount + 1).dropFirst().reversed().map { Int(round($0)) }
+    }
+
+    /// Convert the model output (predicted noise) to the denoised sample used by the DPM-Solver++ update steps
+    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
+ assert(modelOutput.scalars.count == sample.scalars.count)
+ let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])
+
+ // This could be optimized with a Metal kernel if we find we need to
+ let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in
+ (s - m * sigma_t) / alpha_t
+ }
+ return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape)
+ }
+
+ /// One step for the first-order DPM-Solver (equivalent to DDIM).
+ /// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+ /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+ func firstOrderUpdate(
+        modelOutput: MLShapedArray<Float32>,
+        timestep: Int,
+        prevTimestep: Int,
+        sample: MLShapedArray<Float32>
+    ) -> MLShapedArray<Float32> {
+ let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
+ let p_alpha_t = Double(alpha_t[prevTimestep])
+ let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
+ let h = p_lambda_t - lambda_s
+ // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
+ let x_t = weightedSum(
+ [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
+ [sample, modelOutput]
+ )
+ return x_t
+ }
+
+ /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
+ /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+ func secondOrderUpdate(
+        modelOutputs: [MLShapedArray<Float32>],
+        timesteps: [Int],
+        prevTimestep t: Int,
+        sample: MLShapedArray<Float32>
+    ) -> MLShapedArray<Float32> {
+ let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
+ let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
+ let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
+ let p_alpha_t = Double(alpha_t[t])
+ let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
+ let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
+ let r0 = h_0 / h
+ let D0 = m0
+
+ // D1 = (1.0 / r0) * (m0 - m1)
+ let D1 = weightedSum(
+ [1/r0, -1/r0],
+ [m0, m1]
+ )
+
+ // See https://arxiv.org/abs/2211.01095 for detailed derivations
+ // x_t = (
+ // (sigma_t / sigma_s0) * sample
+ // - (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ // - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
+ // )
+ let x_t = weightedSum(
+ [p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
+ [sample, D0, D1]
+ )
+ return x_t
+ }
+
+    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
+ let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1
+ let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1]
+
+ let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15
+ let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15
+ let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond
+
+ let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
+ if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
+ modelOutputs.append(modelOutput)
+
+        let prevSample: MLShapedArray<Float32>
+ if lowerOrder {
+ prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
+ } else {
+ prevSample = secondOrderUpdate(
+ modelOutputs: modelOutputs,
+ timesteps: [timeSteps[stepIndex - 1], t],
+ prevTimestep: prevTimestep,
+ sample: sample
+ )
+ }
+ if lowerOrderStepped < solverOrder {
+ lowerOrderStepped += 1
+ }
+
+ return prevSample
+ }
+}
diff --git a/swift/StableDiffusion/pipeline/Decoder.swift b/swift/StableDiffusion/pipeline/Decoder.swift
index 2b55085d..04f04ba6 100644
--- a/swift/StableDiffusion/pipeline/Decoder.swift
+++ b/swift/StableDiffusion/pipeline/Decoder.swift
@@ -6,21 +6,31 @@ import CoreML
import Accelerate
/// A decoder model which produces RGB images from latent samples
-public struct Decoder {
+@available(iOS 16.2, macOS 13.1, *)
+public struct Decoder: ResourceManaging {
/// VAE decoder model
- var model: MLModel
+ var model: ManagedMLModel
/// Create decoder from Core ML model
///
- /// - Parameters
- /// - model: Core ML model for VAE decoder
- public init(model: MLModel) {
- self.model = model
+ /// - Parameters:
+ /// - url: Location of compiled VAE decoder Core ML model
+ /// - configuration: configuration to be used when the model is loaded
+ /// - Returns: A decoder that will lazily load its required resources when needed or requested
+ public init(modelAt url: URL, configuration: MLModelConfiguration) {
+ self.model = ManagedMLModel(modelAt: url, configuration: configuration)
}
- /// Prediction queue
- let queue = DispatchQueue(label: "decoder.predict")
+ /// Ensure the model has been loaded into memory
+ public func loadResources() throws {
+ try model.loadResources()
+ }
+
+ /// Unload the underlying model to free up memory
+ public func unloadResources() {
+ model.unloadResources()
+ }
/// Batch decode latent samples into images
///
@@ -42,7 +52,9 @@ public struct Decoder {
let batch = MLArrayBatchProvider(array: inputs)
// Batch predict with model
- let results = try queue.sync { try model.predictions(fromBatch: batch) }
+ let results = try model.perform { model in
+ try model.predictions(fromBatch: batch)
+ }
// Transform the outputs to CGImages
        let images: [CGImage] = (0..<results.count).map { i in
diff --git a/swift/StableDiffusion/pipeline/ManagedMLModel.swift b/swift/StableDiffusion/pipeline/ManagedMLModel.swift
new file mode 100644
index 00000000..5640a5f6
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/ManagedMLModel.swift
@@ -0,0 +1,77 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright (C) 2022 Apple Inc. All Rights Reserved.
+
+import CoreML
+
+/// A class to manage and gate access to a Core ML model
+///
+/// It will automatically load a model into memory when needed or requested
+/// It allows one to request to unload the model from memory
+@available(iOS 16.2, macOS 13.1, *)
+public final class ManagedMLModel: ResourceManaging {
+
+ /// The location of the model
+ var modelURL: URL
+
+ /// The configuration to be used when the model is loaded
+ var configuration: MLModelConfiguration
+
+ /// The loaded model (when loaded)
+ var loadedModel: MLModel?
+
+ /// Queue to protect access to loaded model
+ var queue: DispatchQueue
+
+ /// Create a managed model given its location and desired loaded configuration
+ ///
+ /// - Parameters:
+ /// - url: The location of the model
+ /// - configuration: The configuration to be used when the model is loaded/used
+ /// - Returns: A managed model that has not been loaded
+ public init(modelAt url: URL, configuration: MLModelConfiguration) {
+ self.modelURL = url
+ self.configuration = configuration
+ self.loadedModel = nil
+ self.queue = DispatchQueue(label: "managed.\(url.lastPathComponent)")
+ }
+
+    /// Instantiate and load the model into memory
+ public func loadResources() throws {
+ try queue.sync {
+ try loadModel()
+ }
+ }
+
+ /// Unload the model if it was loaded
+ public func unloadResources() {
+ queue.sync {
+ loadedModel = nil
+ }
+ }
+
+ /// Perform an operation with the managed model via a supplied closure.
+ /// The model will be loaded and supplied to the closure and should only be
+ /// used within the closure to ensure all resource management is synchronized
+ ///
+ /// - Parameters:
+    ///   - body: Closure which performs an action on a loaded model
+ /// - Returns: The result of the closure
+ /// - Throws: An error if the model cannot be loaded or if the closure throws
+    public func perform<R>(_ body: (MLModel) throws -> R) throws -> R {
+ return try queue.sync {
+ try autoreleasepool {
+ try loadModel()
+ return try body(loadedModel!)
+ }
+ }
+ }
+
+ private func loadModel() throws {
+ if loadedModel == nil {
+ loadedModel = try MLModel(contentsOf: modelURL,
+ configuration: configuration)
+ }
+ }
+
+
+}
diff --git a/swift/StableDiffusion/pipeline/Random.swift b/swift/StableDiffusion/pipeline/Random.swift
index 06846985..a1e8d355 100644
--- a/swift/StableDiffusion/pipeline/Random.swift
+++ b/swift/StableDiffusion/pipeline/Random.swift
@@ -9,6 +9,7 @@ import CoreML
/// This implementation matches:
/// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c)
///
+@available(iOS 16.2, macOS 13.1, *)
struct NumPyRandomSource: RandomNumberGenerator {
struct State {
diff --git a/swift/StableDiffusion/pipeline/ResourceManaging.swift b/swift/StableDiffusion/pipeline/ResourceManaging.swift
new file mode 100644
index 00000000..3813487a
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/ResourceManaging.swift
@@ -0,0 +1,20 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright (C) 2022 Apple Inc. All Rights Reserved.
+
+/// Protocol for managing internal resources
+public protocol ResourceManaging {
+
+ /// Request resources to be loaded and ready if possible
+ func loadResources() throws
+
+    /// Request resources are unloaded / removed from memory if possible
+ func unloadResources()
+}
+
+extension ResourceManaging {
+ /// Request resources are pre-warmed by loading and unloading
+ func prewarmResources() throws {
+ try loadResources()
+ unloadResources()
+ }
+}
diff --git a/swift/StableDiffusion/pipeline/SafetyChecker.swift b/swift/StableDiffusion/pipeline/SafetyChecker.swift
index e7b86418..fdc615e8 100644
--- a/swift/StableDiffusion/pipeline/SafetyChecker.swift
+++ b/swift/StableDiffusion/pipeline/SafetyChecker.swift
@@ -6,22 +6,31 @@ import CoreML
import Accelerate
/// Image safety checking model
-public struct SafetyChecker {
+@available(iOS 16.2, macOS 13.1, *)
+public struct SafetyChecker: ResourceManaging {
/// Safety checking Core ML model
- var model: MLModel
+ var model: ManagedMLModel
/// Creates safety checker
///
/// - Parameters:
- /// - model: Underlying model which performs the safety check
- /// - Returns: Safety checker ready from checks
- public init(model: MLModel) {
- self.model = model
+ /// - url: Location of compiled safety checking Core ML model
+ /// - configuration: configuration to be used when the model is loaded
+    /// - Returns: A safety checker that will lazily load its required resources when needed or requested
+ public init(modelAt url: URL, configuration: MLModelConfiguration) {
+ self.model = ManagedMLModel(modelAt: url, configuration: configuration)
}
- /// Prediction queue
- let queue = DispatchQueue(label: "safetycheker.predict")
+ /// Ensure the model has been loaded into memory
+ public func loadResources() throws {
+ try model.loadResources()
+ }
+
+ /// Unload the underlying model to free up memory
+ public func unloadResources() {
+ model.unloadResources()
+ }
    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x1 = vImage.PixelBuffer<vImage.Planar8>
@@ -49,7 +58,9 @@ public struct SafetyChecker {
let adjustmentName = "adjustment"
let imagesNames = "images"
- let inputInfo = model.modelDescription.inputDescriptionsByName
+ let inputInfo = try model.perform { model in
+ model.modelDescription.inputDescriptionsByName
+ }
let inputShape = inputInfo[inputName]!.multiArrayConstraint!.shape
let width = inputShape[2].intValue
@@ -74,7 +85,9 @@ public struct SafetyChecker {
throw SafetyCheckError.modelInputFailure
}
- let result = try queue.sync { try model.prediction(from: input) }
+ let result = try model.perform { model in
+ try model.prediction(from: input)
+ }
let output = result.featureValue(for: "has_nsfw_concepts")
diff --git a/swift/StableDiffusion/pipeline/SampleTimer.swift b/swift/StableDiffusion/pipeline/SampleTimer.swift
index c427b0fc..e14e2917 100644
--- a/swift/StableDiffusion/pipeline/SampleTimer.swift
+++ b/swift/StableDiffusion/pipeline/SampleTimer.swift
@@ -18,6 +18,7 @@ import Foundation
/// print(String(format: "mean: %.2f, var: %.2f",
/// timer.mean, timer.variance))
/// ```
+@available(iOS 16.2, macOS 13.1, *)
public final class SampleTimer: Codable {
var startTime: CFAbsoluteTime?
var sum: Double = 0.0
diff --git a/swift/StableDiffusion/pipeline/Scheduler.swift b/swift/StableDiffusion/pipeline/Scheduler.swift
index 45966972..0bd92840 100644
--- a/swift/StableDiffusion/pipeline/Scheduler.swift
+++ b/swift/StableDiffusion/pipeline/Scheduler.swift
@@ -3,33 +3,98 @@
import CoreML
-/// A scheduler used to compute a de-noised image
-///
-/// This implementation matches:
-/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
-///
-/// It uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
-public final class Scheduler {
+@available(iOS 16.2, macOS 13.1, *)
+public protocol Scheduler {
/// Number of diffusion steps performed during training
- public let trainStepCount: Int
+ var trainStepCount: Int { get }
/// Number of inference steps to be performed
- public let inferenceStepCount: Int
+ var inferenceStepCount: Int { get }
/// Training diffusion time steps index by inference time step
- public let timeSteps: [Int]
+ var timeSteps: [Int] { get }
/// Schedule of betas which controls the amount of noise added at each timestep
- public let betas: [Float]
+ var betas: [Float] { get }
/// 1 - betas
- let alphas: [Float]
+ var alphas: [Float] { get }
/// Cached cumulative product of alphas
- let alphasCumProd: [Float]
+ var alphasCumProd: [Float] { get }
/// Standard deviation of the initial noise distribution
- public let initNoiseSigma: Float
+ var initNoiseSigma: Float { get }
+
+ /// Compute a de-noised image sample and step scheduler state
+ ///
+ /// - Parameters:
+ /// - output: The predicted residual noise output of learned diffusion model
+ /// - timeStep: The current time step in the diffusion chain
+ /// - sample: The current input sample to the diffusion model
+ /// - Returns: Predicted de-noised sample at the previous time step
+ /// - Postcondition: The scheduler state is updated.
+ /// The state holds the current sample and history of model output noise residuals
+ func step(
+        output: MLShapedArray<Float32>,
+        timeStep t: Int,
+        sample s: MLShapedArray<Float32>
+    ) -> MLShapedArray<Float32>
+}
+
+@available(iOS 16.2, macOS 13.1, *)
+public extension Scheduler {
+ var initNoiseSigma: Float { 1 }
+}
+
+@available(iOS 16.2, macOS 13.1, *)
+public extension Scheduler {
+ /// Compute weighted sum of shaped arrays of equal shapes
+ ///
+ /// - Parameters:
+ /// - weights: The weights each array is multiplied by
+ /// - values: The arrays to be weighted and summed
+ /// - Returns: sum_i weights[i]*values[i]
+    func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
+ assert(weights.count > 1 && values.count == weights.count)
+ assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount }))
+ var w = Float(weights.first!)
+ var scalars = values.first!.scalars.map({ $0 * w })
+ for next in 1 ..< values.count {
+ w = Float(weights[next])
+ let nextScalars = values[next].scalars
+ for i in 0 ..< scalars.count {
+ scalars[i] += w * nextScalars[i]
+ }
+ }
+ return MLShapedArray(scalars: scalars, shape: values.first!.shape)
+ }
+}
+
+/// How to map a beta range to a sequence of betas to step over
+@available(iOS 16.2, macOS 13.1, *)
+public enum BetaSchedule {
+ /// Linear stepping between start and end
+ case linear
+ /// Steps using linspace(sqrt(start),sqrt(end))^2
+ case scaledLinear
+}
+
+
+/// A scheduler used to compute a de-noised image
+///
+/// This implementation matches:
+/// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py)
+///
+/// This scheduler uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps
+@available(iOS 16.2, macOS 13.1, *)
+public final class PNDMScheduler: Scheduler {
+ public let trainStepCount: Int
+ public let inferenceStepCount: Int
+ public let betas: [Float]
+ public let alphas: [Float]
+ public let alphasCumProd: [Float]
+ public let timeSteps: [Int]
// Internal state
var counter: Int
@@ -61,15 +126,12 @@ public final class Scheduler {
case .scaledLinear:
self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
}
-
self.alphas = betas.map({ 1.0 - $0 })
- self.initNoiseSigma = 1.0
var alphasCumProd = self.alphas
        for i in 1..<alphasCumProd.count {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd
-    func weightedSum(_ weights: [Double], _ values: [MLShapedArray<Float32>]) -> MLShapedArray<Float32> {
- assert(weights.count > 1 && values.count == weights.count)
- assert(values.allSatisfy({$0.scalarCount == values.first!.scalarCount}))
- var w = Float(weights.first!)
- var scalars = values.first!.scalars.map({ $0 * w })
- for next in 1 ..< values.count {
- w = Float(weights[next])
- let nextScalars = values[next].scalars
- for i in 0 ..< scalars.count {
- scalars[i] += w * nextScalars[i]
- }
- }
- return MLShapedArray(scalars: scalars, shape: values.first!.shape)
- }
-
/// Compute sample (denoised image) at previous step given a current time step
///
/// - Parameters:
@@ -224,16 +265,6 @@ public final class Scheduler {
}
}
-extension Scheduler {
- /// How to map a beta range to a sequence of betas to step over
- public enum BetaSchedule {
- /// Linear stepping between start and end
- case linear
- /// Steps using linspace(sqrt(start),sqrt(end))^2
- case scaledLinear
- }
-}
-
/// Evenly spaced floats between specified interval
///
/// - Parameters:
diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift
index 19c8bcf3..65c6e03c 100644
--- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift
+++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline+Resources.swift
@@ -4,8 +4,32 @@
import Foundation
import CoreML
+@available(iOS 16.2, macOS 13.1, *)
public extension StableDiffusionPipeline {
+ struct ResourceURLs {
+
+ public let textEncoderURL: URL
+ public let unetURL: URL
+ public let unetChunk1URL: URL
+ public let unetChunk2URL: URL
+ public let decoderURL: URL
+ public let safetyCheckerURL: URL
+ public let vocabURL: URL
+ public let mergesURL: URL
+
+ public init(resourcesAt baseURL: URL) {
+ textEncoderURL = baseURL.appending(path: "TextEncoder.mlmodelc")
+ unetURL = baseURL.appending(path: "Unet.mlmodelc")
+ unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc")
+ unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc")
+ decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc")
+ safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc")
+ vocabURL = baseURL.appending(path: "vocab.json")
+ mergesURL = baseURL.appending(path: "merges.txt")
+ }
+ }
+
/// Create stable diffusion pipeline using model resources at a
/// specified URL
///
@@ -14,55 +38,48 @@ public extension StableDiffusionPipeline {
/// and tokenization resources
/// - configuration: The configuration to load model resources with
/// - disableSafety: Load time disable of safety to save memory
+ /// - reduceMemory: Setup pipeline in reduced memory mode
/// - Returns:
/// Pipeline ready for image generation if all necessary resources loaded
init(resourcesAt baseURL: URL,
configuration config: MLModelConfiguration = .init(),
- disableSafety: Bool = false) throws {
+ disableSafety: Bool = false,
+ reduceMemory: Bool = false) throws {
/// Expect URL of each resource
- let textEncoderURL = baseURL.appending(path: "TextEncoder.mlmodelc")
- let unetURL = baseURL.appending(path: "Unet.mlmodelc")
- let unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc")
- let unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc")
- let decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc")
- let safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc")
- let vocabURL = baseURL.appending(path: "vocab.json")
- let mergesURL = baseURL.appending(path: "merges.txt")
+ let urls = ResourceURLs(resourcesAt: baseURL)
// Text tokenizer and encoder
- let tokenizer = try BPETokenizer(mergesAt: mergesURL, vocabularyAt: vocabURL)
- let textEncoderModel = try MLModel(contentsOf: textEncoderURL, configuration: config)
- let textEncoder = TextEncoder(tokenizer: tokenizer, model:textEncoderModel )
+ let tokenizer = try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL)
+ let textEncoder = TextEncoder(tokenizer: tokenizer,
+ modelAt: urls.textEncoderURL,
+ configuration: config)
// Unet model
let unet: Unet
- if FileManager.default.fileExists(atPath: unetChunk1URL.path) &&
- FileManager.default.fileExists(atPath: unetChunk2URL.path) {
- let chunk1 = try MLModel(contentsOf: unetChunk1URL, configuration: config)
- let chunk2 = try MLModel(contentsOf: unetChunk2URL, configuration: config)
- unet = Unet(chunks: [chunk1, chunk2])
+ if FileManager.default.fileExists(atPath: urls.unetChunk1URL.path) &&
+ FileManager.default.fileExists(atPath: urls.unetChunk2URL.path) {
+ unet = Unet(chunksAt: [urls.unetChunk1URL, urls.unetChunk2URL],
+ configuration: config)
} else {
- let unetModel = try MLModel(contentsOf: unetURL, configuration: config)
- unet = Unet(model: unetModel)
+ unet = Unet(modelAt: urls.unetURL, configuration: config)
}
// Image Decoder
- let decoderModel = try MLModel(contentsOf: decoderURL, configuration: config)
- let decoder = Decoder(model: decoderModel)
+ let decoder = Decoder(modelAt: urls.decoderURL, configuration: config)
// Optional safety checker
var safetyChecker: SafetyChecker? = nil
if !disableSafety &&
- FileManager.default.fileExists(atPath: safetyCheckerURL.path) {
- let checkerModel = try MLModel(contentsOf: safetyCheckerURL, configuration: config)
- safetyChecker = SafetyChecker(model: checkerModel)
+ FileManager.default.fileExists(atPath: urls.safetyCheckerURL.path) {
+ safetyChecker = SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config)
}
- // Construct pipelien
+ // Construct pipeline
self.init(textEncoder: textEncoder,
unet: unet,
decoder: decoder,
- safetyChecker: safetyChecker)
+ safetyChecker: safetyChecker,
+ reduceMemory: reduceMemory)
}
}
diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
index bc888cc3..ea654723 100644
--- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
+++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
@@ -6,11 +6,20 @@ import CoreML
import Accelerate
import CoreGraphics
+/// Schedulers compatible with StableDiffusionPipeline
+public enum StableDiffusionScheduler {
+ /// Scheduler that uses a pseudo-linear multi-step (PLMS) method
+ case pndmScheduler
+ /// Scheduler that uses a second order DPM-Solver++ algorithm
+ case dpmSolverMultistepScheduler
+}
+
/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
-public struct StableDiffusionPipeline {
+@available(iOS 16.2, macOS 13.1, *)
+public struct StableDiffusionPipeline: ResourceManaging {
/// Model to generate embeddings for tokenized input text
var textEncoder: TextEncoder
@@ -24,14 +33,19 @@ public struct StableDiffusionPipeline {
/// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil
- /// Controls the influence of the text prompt on sampling process (0=random images)
- var guidanceScale: Float = 7.5
-
/// Reports whether this pipeline can perform safety checks
public var canSafetyCheck: Bool {
safetyChecker != nil
}
+ /// Option to reduce memory during image generation
+ ///
+ /// If true, the pipeline will lazily load TextEncoder, Unet, Decoder, and SafetyChecker
+ /// when needed and aggressively unload their resources after
+ ///
+ /// This will increase latency in favor of reducing memory
+ var reduceMemory: Bool = false
+
/// Creates a pipeline using the specified models and tokenizer
///
/// - Parameters:
@@ -39,54 +53,100 @@ public struct StableDiffusionPipeline {
/// - unet: Model for noise prediction on latent samples
/// - decoder: Model for decoding latent sample to image
/// - safetyChecker: Optional model for checking safety of generated images
- /// - guidanceScale: Influence of the text prompt on generation process
+ /// - reduceMemory: Option to enable reduced memory mode
/// - Returns: Pipeline ready for image generation
public init(textEncoder: TextEncoder,
unet: Unet,
decoder: Decoder,
safetyChecker: SafetyChecker? = nil,
- guidanceScale: Float = 7.5) {
+ reduceMemory: Bool = false) {
self.textEncoder = textEncoder
self.unet = unet
self.decoder = decoder
self.safetyChecker = safetyChecker
- self.guidanceScale = guidanceScale
+ self.reduceMemory = reduceMemory
+ }
+
+ /// Load required resources for this pipeline
+ ///
+    /// If reduceMemory is true this will call prewarmResources instead
+ /// and let the pipeline lazily load resources as needed
+ public func loadResources() throws {
+ if reduceMemory {
+ try prewarmResources()
+ } else {
+ try textEncoder.loadResources()
+ try unet.loadResources()
+ try decoder.loadResources()
+ try safetyChecker?.loadResources()
+ }
+ }
+
+ /// Unload the underlying resources to free up memory
+ public func unloadResources() {
+ textEncoder.unloadResources()
+ unet.unloadResources()
+ decoder.unloadResources()
+ safetyChecker?.unloadResources()
+ }
+
+ // Prewarm resources one at a time
+ public func prewarmResources() throws {
+ try textEncoder.prewarmResources()
+ try unet.prewarmResources()
+ try decoder.prewarmResources()
+ try safetyChecker?.prewarmResources()
}
/// Text to image generation using stable diffusion
///
/// - Parameters:
/// - prompt: Text prompt to guide sampling
+ /// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
+ /// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
/// - progressHandler: Callback to perform after each step, stops on receiving false response
/// - Returns: An array of `imageCount` optional images.
/// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages(
prompt: String,
+ negativePrompt: String = "",
imageCount: Int = 1,
stepCount: Int = 50,
- seed: Int = 0,
+ seed: UInt32 = 0,
+ guidanceScale: Float = 7.5,
disableSafety: Bool = false,
+ scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {
- // Encode the input prompt as well as a blank unconditioned input
+ // Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt)
- let blankEmbedding = try textEncoder.encode("")
+ let negativePromptEmbedding = try textEncoder.encode(negativePrompt)
+
+ if reduceMemory {
+ textEncoder.unloadResources()
+ }
// Convert to Unet hidden state representation
+ // Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray(
- concatenating: [blankEmbedding, promptEmbedding],
+ concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0
)
let hiddenStates = toHiddenStates(concatEmbedding)
/// Setup schedulers
-        let scheduler = (0..<imageCount).map { _ in Scheduler(stepCount: stepCount) }
-    func generateLatentSamples(_ count: Int, stdev: Float, seed: Int) -> [MLShapedArray<Float32>] {
+    func generateLatentSamples(_ count: Int, stdev: Float, seed: UInt32) -> [MLShapedArray<Float32>] {
var sampleShape = unet.latentSampleShape
sampleShape[0] = 1
- var random = NumPyRandomSource(seed: UInt32(seed))
+ var random = NumPyRandomSource(seed: seed)
        let samples = (0..<count).map { _ in
            MLShapedArray<Float32>(
converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
@@ -168,11 +232,11 @@ public struct StableDiffusionPipeline {
return states
}
-    func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
-        noise.map { performGuidance($0) }
+    func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
+        noise.map { performGuidance($0, guidanceScale) }
    }
-    func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
+    func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
let blankNoiseScalars = noise[0].scalars
let textNoiseScalars = noise[1].scalars
@@ -192,8 +256,10 @@ public struct StableDiffusionPipeline {
    func decodeToImages(_ latents: [MLShapedArray<Float32>],
disableSafety: Bool) throws -> [CGImage?] {
-
let images = try decoder.decode(latents)
+ if reduceMemory {
+ decoder.unloadResources()
+ }
// If safety is disabled return what was decoded
if disableSafety {
@@ -210,11 +276,16 @@ public struct StableDiffusionPipeline {
try safetyChecker.isSafe(image) ? image : nil
}
+ if reduceMemory {
+ safetyChecker.unloadResources()
+ }
+
return safeImages
}
}
+@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
/// Sampling progress details
public struct Progress {
diff --git a/swift/StableDiffusion/pipeline/TextEncoder.swift b/swift/StableDiffusion/pipeline/TextEncoder.swift
index 1271cb40..b9497e26 100644
--- a/swift/StableDiffusion/pipeline/TextEncoder.swift
+++ b/swift/StableDiffusion/pipeline/TextEncoder.swift
@@ -5,22 +5,37 @@ import Foundation
import CoreML
/// A model for encoding text
-public struct TextEncoder {
+@available(iOS 16.2, macOS 13.1, *)
+public struct TextEncoder: ResourceManaging {
/// Text tokenizer
var tokenizer: BPETokenizer
/// Embedding model
- var model: MLModel
+ var model: ManagedMLModel
/// Creates text encoder which embeds a tokenized string
///
/// - Parameters:
/// - tokenizer: Tokenizer for input text
- /// - model: Model for encoding tokenized text
- public init(tokenizer: BPETokenizer, model: MLModel) {
+ /// - url: Location of compiled text encoding Core ML model
+ /// - configuration: configuration to be used when the model is loaded
+ /// - Returns: A text encoder that will lazily load its required resources when needed or requested
+ public init(tokenizer: BPETokenizer,
+ modelAt url: URL,
+ configuration: MLModelConfiguration) {
self.tokenizer = tokenizer
- self.model = model
+ self.model = ManagedMLModel(modelAt: url, configuration: configuration)
+ }
+
+ /// Ensure the model has been loaded into memory
+ public func loadResources() throws {
+ try model.loadResources()
+ }
+
+ /// Unload the underlying model to free up memory
+ public func unloadResources() {
+ model.unloadResources()
}
/// Encode input text/string
@@ -60,13 +75,18 @@ public struct TextEncoder {
let inputFeatures = try! MLDictionaryFeatureProvider(
dictionary: [inputName: MLMultiArray(inputArray)])
- let result = try queue.sync { try model.prediction(from: inputFeatures) }
+ let result = try model.perform { model in
+ try model.prediction(from: inputFeatures)
+ }
+
let embeddingFeature = result.featureValue(for: "last_hidden_state")
return MLShapedArray(converting: embeddingFeature!.multiArrayValue!)
}
var inputDescription: MLFeatureDescription {
- model.modelDescription.inputDescriptionsByName.first!.value
+ try! model.perform { model in
+ model.modelDescription.inputDescriptionsByName.first!.value
+ }
}
var inputShape: [Int] {
diff --git a/swift/StableDiffusion/pipeline/Unet.swift b/swift/StableDiffusion/pipeline/Unet.swift
index b1611779..bf873a2d 100644
--- a/swift/StableDiffusion/pipeline/Unet.swift
+++ b/swift/StableDiffusion/pipeline/Unet.swift
@@ -5,33 +5,63 @@ import Foundation
import CoreML
/// U-Net noise prediction model for stable diffusion
-public struct Unet {
+@available(iOS 16.2, macOS 13.1, *)
+public struct Unet: ResourceManaging {
/// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
///
/// It can be in the form of a single model or multiple stages
- var models: [MLModel]
+ var models: [ManagedMLModel]
/// Creates a U-Net noise prediction model
///
/// - Parameters:
- /// - model: U-Net held in single Core ML model
- /// - Returns: Ready for prediction
- public init(model: MLModel) {
- self.models = [model]
+ /// - url: Location of single U-Net compiled Core ML model
+ /// - configuration: Configuration to be used when the model is loaded
+ /// - Returns: U-net model that will lazily load its required resources when needed or requested
+ public init(modelAt url: URL,
+ configuration: MLModelConfiguration) {
+ self.models = [ManagedMLModel(modelAt: url, configuration: configuration)]
}
/// Creates a U-Net noise prediction model
///
/// - Parameters:
- /// - chunks: U-Net held chunked into multiple Core ML models
- /// - Returns: Ready for prediction
- public init(chunks: [MLModel]) {
- self.models = chunks
+ /// - urls: Location of chunked U-Net via urls to each compiled chunk
+ /// - configuration: Configuration to be used when the model is loaded
+ /// - Returns: U-net model that will lazily load its required resources when needed or requested
+ public init(chunksAt urls: [URL],
+ configuration: MLModelConfiguration) {
+ self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) }
+ }
+
+ /// Load resources.
+ public func loadResources() throws {
+ for model in models {
+ try model.loadResources()
+ }
+ }
+
+ /// Unload the underlying model to free up memory
+ public func unloadResources() {
+ for model in models {
+ model.unloadResources()
+ }
+ }
+
+ /// Pre-warm resources
+ public func prewarmResources() throws {
+ // Override default to pre-warm each model
+ for model in models {
+ try model.loadResources()
+ model.unloadResources()
+ }
}
var latentSampleDescription: MLFeatureDescription {
- models.first!.modelDescription.inputDescriptionsByName["sample"]!
+ try! models.first!.perform { model in
+ model.modelDescription.inputDescriptionsByName["sample"]!
+ }
}
/// The expected shape of the model's latent sample input
@@ -91,13 +121,10 @@ public struct Unet {
return noise
}
- /// Prediction queue
- let queue = DispatchQueue(label: "unet.predict")
-
func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {
- var results = try queue.sync {
- try models.first!.predictions(fromBatch: batch)
+ var results = try models.first!.perform { model in
+ try model.predictions(fromBatch: batch)
}
if models.count == 1 {
@@ -117,8 +144,8 @@ public struct Unet {
let nextBatch = MLArrayBatchProvider(array: next)
// Predict
- results = try queue.sync {
- try stage.predictions(fromBatch: nextBatch)
+ results = try stage.perform { model in
+ try model.predictions(fromBatch: nextBatch)
}
}
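
The chunked `Unet` initializer and `prewarmResources()` above suggest the following usage sketch; the chunk file names and compute-unit choice are assumptions for illustration.

```swift
import CoreML
import Foundation
import StableDiffusion

// Hypothetical helper: wraps a two-part chunked U-Net so each chunk is
// loaded lazily and can be pre-warmed before the first prediction.
@available(iOS 16.2, macOS 13.1, *)
func makeChunkedUnet(resourcesAt baseURL: URL) throws -> Unet {
    let configuration = MLModelConfiguration()
    configuration.computeUnits = .cpuAndNeuralEngine   // illustrative choice

    // Assumed chunk file names for a two-part U-Net.
    let chunkURLs = ["UnetChunk1.mlmodelc", "UnetChunk2.mlmodelc"]
        .map { baseURL.appendingPathComponent($0) }

    let unet = Unet(chunksAt: chunkURLs, configuration: configuration)

    // prewarmResources() loads and immediately unloads each chunk so the
    // first real prediction avoids the initial load cost.
    try unet.prewarmResources()
    return unet
}
```
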
diff --git a/swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift b/swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift
index 21c7ae51..cc8c91d9 100644
--- a/swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift
+++ b/swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift
@@ -3,6 +3,7 @@
import Foundation
+@available(iOS 16.2, macOS 13.1, *)
extension BPETokenizer {
enum FileReadError: Error {
case invalidMergeFileLine(Int)
diff --git a/swift/StableDiffusion/tokenizer/BPETokenizer.swift b/swift/StableDiffusion/tokenizer/BPETokenizer.swift
index 2789c979..799cbe56 100644
--- a/swift/StableDiffusion/tokenizer/BPETokenizer.swift
+++ b/swift/StableDiffusion/tokenizer/BPETokenizer.swift
@@ -4,6 +4,7 @@
import Foundation
/// A tokenizer based on byte pair encoding.
+@available(iOS 16.2, macOS 13.1, *)
public struct BPETokenizer {
/// A dictionary that maps pairs of tokens to the rank/order of the merge.
let merges: [TokenPair : Int]
@@ -166,6 +167,7 @@ public struct BPETokenizer {
}
}
+@available(iOS 16.2, macOS 13.1, *)
extension BPETokenizer {
/// A hashable tuple of strings
diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift
index 6b09af62..7471316f 100644
--- a/swift/StableDiffusionCLI/main.swift
+++ b/swift/StableDiffusionCLI/main.swift
@@ -8,6 +8,7 @@ import Foundation
import StableDiffusion
import UniformTypeIdentifiers
+@available(iOS 16.2, macOS 13.1, *)
struct StableDiffusionSample: ParsableCommand {
static let configuration = CommandConfiguration(
@@ -18,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
@Argument(help: "Input string prompt")
var prompt: String
+ @Option(help: "Input string negative prompt")
+ var negativePrompt: String = ""
+
@Option(
help: ArgumentHelp(
"Path to stable diffusion resources.",
@@ -47,14 +51,23 @@ struct StableDiffusionSample: ParsableCommand {
var outputPath: String = "./"
@Option(help: "Random seed")
- var seed: Int = 93
+ var seed: UInt32 = 93
+
+ @Option(help: "Controls the influence of the text prompt on sampling process (0=random images)")
+ var guidanceScale: Float = 7.5
@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
var computeUnits: ComputeUnits = .all
+ @Option(help: "Scheduler to use, one of {pndm, dpmpp}")
+ var scheduler: SchedulerOption = .pndm
+
@Flag(help: "Disable safety checking")
var disableSafety: Bool = false
+ @Flag(help: "Reduce memory usage")
+ var reduceMemory: Bool = false
+
mutating func run() throws {
guard FileManager.default.fileExists(atPath: resourcePath) else {
throw RunError.resources("Resource path does not exist \(resourcePath)")
@@ -68,7 +81,9 @@ struct StableDiffusionSample: ParsableCommand {
log("(Note: This can take a while the first time using these resources)\n")
let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
configuration: config,
- disableSafety: disableSafety)
+ disableSafety: disableSafety,
+ reduceMemory: reduceMemory)
+ try pipeline.loadResources()
log("Sampling ...\n")
let sampleTimer = SampleTimer()
@@ -76,9 +91,12 @@ struct StableDiffusionSample: ParsableCommand {
let images = try pipeline.generateImages(
prompt: prompt,
+ negativePrompt: negativePrompt,
imageCount: imageCount,
stepCount: stepCount,
- seed: seed
+ seed: seed,
+ guidanceScale: guidanceScale,
+ scheduler: scheduler.stableDiffusionScheduler
) { progress in
sampleTimer.stop()
handleProgress(progress,sampleTimer)
@@ -145,7 +163,8 @@ struct StableDiffusionSample: ParsableCommand {
}
func imageName(_ sample: Int, step: Int? = nil) -> String {
- var name = prompt.replacingOccurrences(of: " ", with: "_")
+ let fileCharLimit = 75
+ var name = prompt.prefix(fileCharLimit).replacingOccurrences(of: " ", with: "_")
if imageCount != 1 {
name += ".\(sample)"
}
@@ -171,6 +190,7 @@ enum RunError: Error {
case saving(String)
}
+@available(iOS 16.2, macOS 13.1, *)
enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine
var asMLComputeUnits: MLComputeUnits {
@@ -183,4 +203,19 @@ enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
}
}
-StableDiffusionSample.main()
+@available(iOS 16.2, macOS 13.1, *)
+enum SchedulerOption: String, ExpressibleByArgument {
+ case pndm, dpmpp
+ var stableDiffusionScheduler: StableDiffusionScheduler {
+ switch self {
+ case .pndm: return .pndmScheduler
+ case .dpmpp: return .dpmSolverMultistepScheduler
+ }
+ }
+}
+
+if #available(iOS 16.2, macOS 13.1, *) {
+ StableDiffusionSample.main()
+} else {
+ print("Unsupported OS")
+}
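
For reference, a library-side sketch of what the updated CLI performs with the new options, using only the pipeline calls shown in the hunk above; the resource path, prompt text, and parameter values are placeholders.

```swift
import CoreML
import Foundation
import StableDiffusion

@available(iOS 16.2, macOS 13.1, *)
func runSample() throws {
    let resourceURL = URL(fileURLWithPath: "/path/to/Resources")   // placeholder path

    let config = MLModelConfiguration()
    config.computeUnits = .cpuAndNeuralEngine   // illustrative choice

    // reduceMemory mirrors the CLI's --reduce-memory flag.
    let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
                                               configuration: config,
                                               disableSafety: false,
                                               reduceMemory: true)
    try pipeline.loadResources()

    let images = try pipeline.generateImages(
        prompt: "a photo of an astronaut riding a horse on mars",
        negativePrompt: "",
        imageCount: 1,
        stepCount: 50,
        seed: 93,
        guidanceScale: 7.5,
        scheduler: .dpmSolverMultistepScheduler
    ) { _ in true }   // progress handler; assumed to return true to continue sampling

    print("Generated \(images.count) image(s)")
}
```
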
diff --git a/swift/StableDiffusionTests/StableDiffusionTests.swift b/swift/StableDiffusionTests/StableDiffusionTests.swift
index c3b54cd3..15cf1a5a 100644
--- a/swift/StableDiffusionTests/StableDiffusionTests.swift
+++ b/swift/StableDiffusionTests/StableDiffusionTests.swift
@@ -5,6 +5,7 @@ import XCTest
import CoreML
@testable import StableDiffusion
+@available(iOS 16.2, macOS 13.1, *)
final class StableDiffusionTests: XCTestCase {
var vocabFileInBundleURL: URL {
diff --git a/tests/test_stable_diffusion.py b/tests/test_stable_diffusion.py
index b5e79dc4..a2e42a18 100644
--- a/tests/test_stable_diffusion.py
+++ b/tests/test_stable_diffusion.py
@@ -74,23 +74,23 @@ def test_torch_to_coreml_conversion(self):
with self.subTest(model="vae_decoder"):
logger.info("Converting vae_decoder")
torch2coreml.convert_vae_decoder(self.pytorch_pipe, self.cli_args)
- logger.info("Successfuly converted vae_decoder")
+ logger.info("Successfully converted vae_decoder")
with self.subTest(model="unet"):
logger.info("Converting unet")
torch2coreml.convert_unet(self.pytorch_pipe, self.cli_args)
- logger.info("Successfuly converted unet")
+ logger.info("Successfully converted unet")
with self.subTest(model="text_encoder"):
logger.info("Converting text_encoder")
torch2coreml.convert_text_encoder(self.pytorch_pipe, self.cli_args)
- logger.info("Successfuly converted text_encoder")
+ logger.info("Successfully converted text_encoder")
with self.subTest(model="safety_checker"):
logger.info("Converting safety_checker")
torch2coreml.convert_safety_checker(self.pytorch_pipe,
self.cli_args)
- logger.info("Successfuly converted safety_checker")
+ logger.info("Successfully converted safety_checker")
def test_end_to_end_image_generation_speed(self):
""" Tests: