Skip to content

Commit

Permalink
Move guidanceScale as generation parameter (apple#46)
Browse files Browse the repository at this point in the history
* Move guidanceScale as generation parameter

* Added guidanceScale in CLI

* Reverted indentation change
  • Loading branch information
Wanaldino authored Dec 24, 2022
1 parent 877ccb9 commit e07c4d0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
16 changes: 6 additions & 10 deletions swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil

/// Controls the influence of the text prompt on sampling process (0=random images)
var guidanceScale: Float = 7.5

/// Reports whether this pipeline can perform safety checks
public var canSafetyCheck: Bool {
safetyChecker != nil
Expand All @@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - unet: Model for noise prediction on latent samples
/// - decoder: Model for decoding latent sample to image
/// - safetyChecker: Optional model for checking safety of generated images
/// - guidanceScale: Influence of the text prompt on generation process
/// - reduceMemory: Option to enable reduced memory mode
/// - Returns: Pipeline ready for image generation
public init(textEncoder: TextEncoder,
unet: Unet,
decoder: Decoder,
safetyChecker: SafetyChecker? = nil,
guidanceScale: Float = 7.5,
reduceMemory: Bool = false) {
self.textEncoder = textEncoder
self.unet = unet
self.decoder = decoder
self.safetyChecker = safetyChecker
self.guidanceScale = guidanceScale
self.reduceMemory = reduceMemory
}

Expand Down Expand Up @@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed
/// - guidanceScale: Controls the influence of the text prompt on sampling process (0=random images)
/// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
/// - progressHandler: Callback to perform after each step, stops on receiving false response
/// - Returns: An array of `imageCount` optional images.
Expand All @@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
guidanceScale: Float = 7.5,
disableSafety: Bool = false,
scheduler: StableDiffusionScheduler = .pndmScheduler,
progressHandler: (Progress) -> Bool = { _ in true }
Expand Down Expand Up @@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
hiddenStates: hiddenStates
)

noise = performGuidance(noise)
noise = performGuidance(noise, guidanceScale)

// Have the scheduler compute the previous (t-1) latent
// sample given the predicted noise and current sample
Expand Down Expand Up @@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
return states
}

func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
noise.map { performGuidance($0) }
func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
noise.map { performGuidance($0, guidanceScale) }
}

func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {

let blankNoiseScalars = noise[0].scalars
let textNoiseScalars = noise[1].scalars
Expand Down
4 changes: 4 additions & 0 deletions swift/StableDiffusionCLI/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Random seed")
var seed: UInt32 = 93

@Option(help: "Controls the influence of the text prompt on sampling process (0=random images)")
var guidanceScale: Float = 7.5

@Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
var computeUnits: ComputeUnits = .all

Expand Down Expand Up @@ -92,6 +95,7 @@ struct StableDiffusionSample: ParsableCommand {
imageCount: imageCount,
stepCount: stepCount,
seed: seed,
guidanceScale: guidanceScale,
scheduler: scheduler.stableDiffusionScheduler
) { progress in
sampleTimer.stop()
Expand Down

0 comments on commit e07c4d0

Please sign in to comment.