Skip to content

Commit

Permalink
fix dot bug introduced in the last refactor. Add test for parallel li…
Browse files Browse the repository at this point in the history
…near algebra.
  • Loading branch information
altavir committed Feb 18, 2024
1 parent 79642a8 commit 41a325d
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
- Explicit `mutableStructureND` builders for mutable structures.
- `Buffer.asList()` zero-copy transformation.
- Wasm support.
- Parallel implementation of `LinearSpace` for Float64
- Parallel buffer factories

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val rows = this@dot.rows.map { it.linearize() }
val columns = other.columns.map { it.linearize() }
val indices = 0 until this.rowNum
val indices = 0 until this.colNum
return buildMatrix(rowNum, other.colNum) { i, j ->
val r = rows[i]
val c = columns[j]
Expand All @@ -70,7 +70,7 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
override fun Matrix<Double>.dot(vector: Point<Double>): Float64Buffer {
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
val rows = this@dot.rows.map { it.linearize() }
val indices = 0 until this.rowNum
val indices = 0 until this.colNum
return Float64Buffer(rowNum) { i ->
val r = rows[i]
var res = 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class DoubleLUSolverTest {
}

@Test
fun testDecomposition() = Double.algebra.linearSpace.run {
fun testDecomposition() = with(Double.algebra.linearSpace){
val matrix = matrix(2, 2)(
3.0, 1.0,
2.0, 3.0
)

val lup = lup(matrix)
val lup = elementAlgebra.lup(matrix)

//Check determinant
// assertEquals(7.0, lup.determinant)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package space.kscience.kmath.linear
import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.algebra
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -58,7 +59,7 @@ class MatrixTest {
}

@Test
fun test2DDot() = Double.algebra.linearSpace.run {
fun test2DDot() = Float64Field.linearSpace {
val firstMatrix = buildMatrix(2, 3) { i, j -> (i + j).toDouble() }
val secondMatrix = buildMatrix(3, 2) { i, j -> (i + j).toDouble() }

Expand All @@ -70,6 +71,5 @@ class MatrixTest {
assertEquals(8.0, result[0, 1])
assertEquals(8.0, result[1, 0])
assertEquals(14.0, result[1, 1])

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
val rows = this@dot.rows.map { it.linearize() }
val columns = other.columns.map { it.linearize() }
val indices = 0 until this.rowNum
val indices = 0 until this.colNum
return buildMatrix(rowNum, other.colNum) { i, j ->
val r = rows[i]
val c = columns[j]
Expand All @@ -85,7 +85,7 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
override fun Matrix<Double>.dot(vector: Point<Double>): Float64Buffer {
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
val rows = this@dot.rows.map { it.linearize() }
val indices = 0 until this.rowNum
val indices = 0 until this.colNum
return Float64Buffer(rowNum) { i ->
val r = rows[i]
var res = 0.0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/

package space.kscience.kmath.linear

import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Float64Field
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

@UnstableKMathAPI
@OptIn(PerformancePitfall::class)
@Suppress("UNUSED_VARIABLE")
class ParallelMatrixTest {

@Test
fun testTranspose() = Float64Field.linearSpace.parallel{
val matrix = one(3, 3)
val transposed = matrix.transposed()
assertTrue { StructureND.contentEquals(matrix, transposed) }
}

@Test
fun testBuilder() = Float64Field.linearSpace.parallel{
val matrix = matrix(2, 3)(
1.0, 0.0, 0.0,
0.0, 1.0, 2.0
)

assertEquals(2.0, matrix[1, 2])
}

@Test
fun testMatrixExtension() = Float64Field.linearSpace.parallel{
val transitionMatrix: Matrix<Double> = VirtualMatrix(type,6, 6) { row, col ->
when {
col == 0 -> .50
row + 1 == col -> .50
row == 5 && col == 5 -> 1.0
else -> 0.0
}
}

infix fun Matrix<Double>.pow(power: Int): Matrix<Double> {
var res = this
repeat(power - 1) {
res = res dot this@pow
}
return res
}

val toTenthPower = transitionMatrix pow 10
}

@Test
fun test2DDot() = Float64Field.linearSpace.parallel {
val firstMatrix = buildMatrix(2, 3) { i, j -> (i + j).toDouble() }
val secondMatrix = buildMatrix(3, 2) { i, j -> (i + j).toDouble() }

// val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() }
// val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() }
val result = firstMatrix dot secondMatrix
assertEquals(2, result.rowNum)
assertEquals(2, result.colNum)
assertEquals(8.0, result[0, 1])
assertEquals(8.0, result[1, 0])
assertEquals(14.0, result[1, 1])
}
}

0 comments on commit 41a325d

Please sign in to comment.