forked from apple/ml-stable-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge commit 'e07c4d00c387840f70fa3701fb3a51c2a32f37b8'
* commit 'e07c4d00c387840f70fa3701fb3a51c2a32f37b8': Move guidanceScale as generation parameter (apple#46) Add brief instructions to download weights from the Hub (apple#10) Adds Negative Prompts (apple#61) Changed seed type into UInt32 (apple#47) fixes apple#77 Update README.md (apple#66) Add Filename Character Limit (apple#19) Implement DPM-Solver++ scheduler (apple#59) Fix typos: Successfully facilitate getting pipeline overridden (apple#30) Undefined name: from typing import List (apple#31) Add Availability Annotations (apple#18) README improvements and reduceMemory option in Swift
- Loading branch information
Showing
22 changed files
with
812 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
"torch", | ||
"transformers", | ||
"scipy", | ||
"numpy<1.24", | ||
], | ||
packages=find_packages(), | ||
classifiers=[ | ||
|
182 changes: 182 additions & 0 deletions
182
swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved. | ||
|
||
import Accelerate | ||
import CoreML | ||
|
||
/// A scheduler used to compute a de-noised image | ||
/// | ||
/// This implementation matches: | ||
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py) | ||
/// | ||
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095). | ||
/// Limitations: | ||
/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver). | ||
/// - Second order only. | ||
/// - Assumes the model predicts epsilon. | ||
/// - No dynamic thresholding. | ||
/// - `midpoint` solver algorithm. | ||
@available(iOS 16.2, macOS 13.1, *) | ||
public final class DPMSolverMultistepScheduler: Scheduler { | ||
public let trainStepCount: Int | ||
public let inferenceStepCount: Int | ||
public let betas: [Float] | ||
public let alphas: [Float] | ||
public let alphasCumProd: [Float] | ||
public let timeSteps: [Int] | ||
|
||
public let alpha_t: [Float] | ||
public let sigma_t: [Float] | ||
public let lambda_t: [Float] | ||
|
||
public let solverOrder = 2 | ||
private(set) var lowerOrderStepped = 0 | ||
|
||
/// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps. | ||
/// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps. | ||
public let useLowerOrderFinal = true | ||
|
||
// Stores solverOrder (2) items | ||
private(set) var modelOutputs: [MLShapedArray<Float32>] = [] | ||
|
||
/// Create a scheduler that uses a second order DPM-Solver++ algorithm. | ||
/// | ||
/// - Parameters: | ||
/// - stepCount: Number of inference steps to schedule | ||
/// - trainStepCount: Number of training diffusion steps | ||
/// - betaSchedule: Method to schedule betas from betaStart to betaEnd | ||
/// - betaStart: The starting value of beta for inference | ||
/// - betaEnd: The end value for beta for inference | ||
/// - Returns: A scheduler ready for its first step | ||
public init( | ||
stepCount: Int = 50, | ||
trainStepCount: Int = 1000, | ||
betaSchedule: BetaSchedule = .scaledLinear, | ||
betaStart: Float = 0.00085, | ||
betaEnd: Float = 0.012 | ||
) { | ||
self.trainStepCount = trainStepCount | ||
self.inferenceStepCount = stepCount | ||
|
||
switch betaSchedule { | ||
case .linear: | ||
self.betas = linspace(betaStart, betaEnd, trainStepCount) | ||
case .scaledLinear: | ||
self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 }) | ||
} | ||
|
||
self.alphas = betas.map({ 1.0 - $0 }) | ||
var alphasCumProd = self.alphas | ||
for i in 1..<alphasCumProd.count { | ||
alphasCumProd[i] *= alphasCumProd[i - 1] | ||
} | ||
self.alphasCumProd = alphasCumProd | ||
|
||
// Currently we only support VP-type noise shedule | ||
self.alpha_t = vForce.sqrt(self.alphasCumProd) | ||
self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd)) | ||
self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) } | ||
|
||
self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount).reversed().map { Int(round($0)) } | ||
} | ||
|
||
/// Convert the model output to the corresponding type the algorithm needs. | ||
/// This implementation is for second-order DPM-Solver++ assuming epsilon prediction. | ||
func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> { | ||
assert(modelOutput.scalars.count == sample.scalars.count) | ||
let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep]) | ||
|
||
// This could be optimized with a Metal kernel if we find we need to | ||
let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in | ||
(s - m * sigma_t) / alpha_t | ||
} | ||
return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape) | ||
} | ||
|
||
/// One step for the first-order DPM-Solver (equivalent to DDIM). | ||
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation. | ||
/// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | ||
func firstOrderUpdate( | ||
modelOutput: MLShapedArray<Float32>, | ||
timestep: Int, | ||
prevTimestep: Int, | ||
sample: MLShapedArray<Float32> | ||
) -> MLShapedArray<Float32> { | ||
let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep])) | ||
let p_alpha_t = Double(alpha_t[prevTimestep]) | ||
let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep])) | ||
let h = p_lambda_t - lambda_s | ||
// x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output | ||
let x_t = weightedSum( | ||
[p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)], | ||
[sample, modelOutput] | ||
) | ||
return x_t | ||
} | ||
|
||
/// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method. | ||
/// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | ||
func secondOrderUpdate( | ||
modelOutputs: [MLShapedArray<Float32>], | ||
timesteps: [Int], | ||
prevTimestep t: Int, | ||
sample: MLShapedArray<Float32> | ||
) -> MLShapedArray<Float32> { | ||
let (s0, s1) = (timesteps[back: 1], timesteps[back: 2]) | ||
let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2]) | ||
let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1])) | ||
let p_alpha_t = Double(alpha_t[t]) | ||
let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0])) | ||
let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1) | ||
let r0 = h_0 / h | ||
let D0 = m0 | ||
|
||
// D1 = (1.0 / r0) * (m0 - m1) | ||
let D1 = weightedSum( | ||
[1/r0, -1/r0], | ||
[m0, m1] | ||
) | ||
|
||
// See https://arxiv.org/abs/2211.01095 for detailed derivations | ||
// x_t = ( | ||
// (sigma_t / sigma_s0) * sample | ||
// - (alpha_t * (torch.exp(-h) - 1.0)) * D0 | ||
// - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 | ||
// ) | ||
let x_t = weightedSum( | ||
[p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)], | ||
[sample, D0, D1] | ||
) | ||
return x_t | ||
} | ||
|
||
public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> { | ||
let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1 | ||
let prevTimestep = stepIndex == timeSteps.count - 1 ? 0 : timeSteps[stepIndex + 1] | ||
|
||
let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15 | ||
let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15 | ||
let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond | ||
|
||
let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample) | ||
if modelOutputs.count == solverOrder { modelOutputs.removeFirst() } | ||
modelOutputs.append(modelOutput) | ||
|
||
let prevSample: MLShapedArray<Float32> | ||
if lowerOrder { | ||
prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample) | ||
} else { | ||
prevSample = secondOrderUpdate( | ||
modelOutputs: modelOutputs, | ||
timesteps: [timeSteps[stepIndex - 1], t], | ||
prevTimestep: prevTimestep, | ||
sample: sample | ||
) | ||
} | ||
if lowerOrderStepped < solverOrder { | ||
lowerOrderStepped += 1 | ||
} | ||
|
||
return prevSample | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2022 Apple Inc. All Rights Reserved. | ||
|
||
import CoreML | ||
|
||
/// A class to manage and gate access to a Core ML model | ||
/// | ||
/// It will automatically load a model into memory when needed or requested | ||
/// It allows one to request to unload the model from memory | ||
@available(iOS 16.2, macOS 13.1, *) | ||
public final class ManagedMLModel: ResourceManaging { | ||
|
||
/// The location of the model | ||
var modelURL: URL | ||
|
||
/// The configuration to be used when the model is loaded | ||
var configuration: MLModelConfiguration | ||
|
||
/// The loaded model (when loaded) | ||
var loadedModel: MLModel? | ||
|
||
/// Queue to protect access to loaded model | ||
var queue: DispatchQueue | ||
|
||
/// Create a managed model given its location and desired loaded configuration | ||
/// | ||
/// - Parameters: | ||
/// - url: The location of the model | ||
/// - configuration: The configuration to be used when the model is loaded/used | ||
/// - Returns: A managed model that has not been loaded | ||
public init(modelAt url: URL, configuration: MLModelConfiguration) { | ||
self.modelURL = url | ||
self.configuration = configuration | ||
self.loadedModel = nil | ||
self.queue = DispatchQueue(label: "managed.\(url.lastPathComponent)") | ||
} | ||
|
||
/// Instantiation and load model into memory | ||
public func loadResources() throws { | ||
try queue.sync { | ||
try loadModel() | ||
} | ||
} | ||
|
||
/// Unload the model if it was loaded | ||
public func unloadResources() { | ||
queue.sync { | ||
loadedModel = nil | ||
} | ||
} | ||
|
||
/// Perform an operation with the managed model via a supplied closure. | ||
/// The model will be loaded and supplied to the closure and should only be | ||
/// used within the closure to ensure all resource management is synchronized | ||
/// | ||
/// - Parameters: | ||
/// - body: Closure which performs and action on a loaded model | ||
/// - Returns: The result of the closure | ||
/// - Throws: An error if the model cannot be loaded or if the closure throws | ||
public func perform<R>(_ body: (MLModel) throws -> R) throws -> R { | ||
return try queue.sync { | ||
try autoreleasepool { | ||
try loadModel() | ||
return try body(loadedModel!) | ||
} | ||
} | ||
} | ||
|
||
private func loadModel() throws { | ||
if loadedModel == nil { | ||
loadedModel = try MLModel(contentsOf: modelURL, | ||
configuration: configuration) | ||
} | ||
} | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// For licensing see accompanying LICENSE.md file. | ||
// Copyright (C) 2022 Apple Inc. All Rights Reserved. | ||
|
||
/// Protocol for managing internal resources | ||
public protocol ResourceManaging { | ||
|
||
/// Request resources to be loaded and ready if possible | ||
func loadResources() throws | ||
|
||
/// Request resources are unloaded / remove from memory if possible | ||
func unloadResources() | ||
} | ||
|
||
extension ResourceManaging { | ||
/// Request resources are pre-warmed by loading and unloading | ||
func prewarmResources() throws { | ||
try loadResources() | ||
unloadResources() | ||
} | ||
} |
Oops, something went wrong.