Skip to content

Commit

Permalink
LUP cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
altavir committed Feb 18, 2024
1 parent fbee95a commit 79642a8
Showing 1 changed file with 21 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
*/

@file:Suppress("UnusedReceiverParameter")
@file:OptIn(PerformancePitfall::class)

package space.kscience.kmath.linear

import space.kscience.attributes.Attributes
import space.kscience.attributes.PolymorphicAttribute
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
Expand Down Expand Up @@ -78,22 +80,22 @@ internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
/**
* Create a lup decomposition of generic matrix.
*/
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
public fun <T : Comparable<T>> Field<T>.lup(
matrix: Matrix<T>,
checkSingular: (T) -> Boolean,
): GenericLupDecomposition<T> = elementAlgebra {
): GenericLupDecomposition<T> {
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
val m = matrix.colNum
val pivot = IntArray(matrix.rowNum)

val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))

val lu: MutableStructure2D<T> = MutableBufferND(
val lu = MutableBufferND(
strides,
bufferAlgebra.buffer(strides.linearSize) { offset ->
matrix[strides.index(offset)]
}
).as2D()
)


// Initialize the permutation array and parity
Expand All @@ -108,7 +110,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
// upper
for (row in 0 until col) {
var sum = lu[row, col]
for (i in 0 until row) sum -= lu[row, i] * lu[i, col]
for (i in 0 until row){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum
}

Expand All @@ -118,7 +122,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(

for (row in col until m) {
var sum = lu[row, col]
for (i in 0 until col) sum -= lu[row, i] * lu[i, col]
for (i in 0 until col){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum

// maintain the best permutation choice
Expand Down Expand Up @@ -151,31 +157,29 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
}


return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even)
return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even)

}


public fun LinearSpace<Double, Float64Field>.lup(
public fun Field<Float64>.lup(
matrix: Matrix<Double>,
singularityThreshold: Double = 1e-11,
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }

internal fun <T> LinearSpace<T, Field<T>>.solve(
private fun <T> Field<T>.solve(
lup: LupDecomposition<T>,
matrix: Matrix<T>,
): Matrix<T> = elementAlgebra {
): Matrix<T> {
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }

// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {

val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))

// Apply permutations to b
val bp: MutableStructure2D<T> = MutableBufferND(
val bp = MutableBufferND(
strides,
bufferAlgebra.buffer(strides.linearSize) { offset -> zero }
).as2D()
bufferAlgebra.buffer(strides.linearSize) { _ -> zero }
)


for (row in 0 until matrix.rowNum) {
Expand Down Expand Up @@ -211,8 +215,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
}
}

return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] }

return bp.as2D()
}


Expand All @@ -223,7 +226,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
singularityCheck: (T) -> Boolean,
): LinearSolver<T> = object : LinearSolver<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> = elementAlgebra{
// Use existing decomposition if it is provided by matrix or linear space itself
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
return solve(decomposition, b)
Expand Down

0 comments on commit 79642a8

Please sign in to comment.