Skip to content

Commit

Permalink
Change EJML implementation to match CM and Ojalgo
Browse files Browse the repository at this point in the history
  • Loading branch information
altavir committed Jan 12, 2025
1 parent 1ec3d1a commit b2b64f3
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 219 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- attributes-kt moved to a separate project, and the version used is 0.3.0
- Kotlin 2.1. Now use cross-compilation to deploy macOS targets.
- Changed `origin` to `cmMatrix` in kmath-commons to avoid property name clash. Expose bidirectional conversion in `CMLinearSpace`
- (BREAKING CHANGE) Changed implementations in `kmath-ejml` to match CM and ojalgo style. Specifically, provide bidirectional conversion for library types.

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.commons.linear.CMLinearSpace
import space.kscience.kmath.commons.linear.CMLinearSpace.dot
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
import space.kscience.kmath.linear.Float64ParallelLinearSpace
import space.kscience.kmath.linear.invoke
Expand Down Expand Up @@ -59,13 +58,13 @@ internal class DotBenchmark {
}

@Benchmark
fun cmDot(blackhole: Blackhole) = CMLinearSpace {
blackhole.consume(cmMatrix1 dot cmMatrix2)
fun cmDot(blackhole: Blackhole): Unit = CMLinearSpace {
blackhole.consume(cmMatrix1.asMatrix() dot cmMatrix2.asMatrix())
}

@Benchmark
fun ejmlDot(blackhole: Blackhole) = EjmlLinearSpaceDDRM {
blackhole.consume(ejmlMatrix1 dot ejmlMatrix2)
fun ejmlDot(blackhole: Blackhole): Unit = EjmlLinearSpaceDDRM {
blackhole.consume(ejmlMatrix1.asMatrix() dot ejmlMatrix2.asMatrix())
}

@Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import space.kscience.kmath.structures.Float64
* @param M the EJML matrix type.
* @author Iaroslav Postovalov
*/
public abstract class EjmlLinearSpace<T : Any, out A : Ring<T>, out M : org.ejml.data.Matrix> : LinearSpace<T, A> {
public interface EjmlLinearSpace<T : Any, out A : Ring<T>, M : org.ejml.data.Matrix> : LinearSpace<T, A> {
/**
* Converts this matrix to EJML one.
*/
public abstract fun Matrix<T>.toEjml(): EjmlMatrix<T, M>
public fun Matrix<T>.toEjml(): M

/**
* Converts this vector to EJML one.
*/
public abstract fun Point<T>.toEjml(): EjmlVector<T, M>
public fun Point<T>.toEjml(): M

public abstract override fun buildMatrix(
rows: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import space.kscience.kmath.nd.Structure2D
*
* @param T the type of elements contained in the buffer.
* @param M the type of EJML matrix.
* @property origin The underlying EJML matrix.
* @property ejmlMatrix The underlying EJML matrix.
* @author Iaroslav Postovalov
*/
public abstract class EjmlMatrix<out T, out M : Matrix>(public open val origin: M) : Structure2D<T> {
override val rowNum: Int get() = origin.numRows
override val colNum: Int get() = origin.numCols
public abstract class EjmlMatrix<out T, out M : Matrix>(public open val ejmlMatrix: M) : Structure2D<T> {
override val rowNum: Int get() = ejmlMatrix.numRows
override val colNum: Int get() = ejmlMatrix.numCols
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import space.kscience.kmath.linear.Point
*
* @param T the type of elements contained in the buffer.
* @param M the type of EJML matrix.
* @property origin The underlying matrix, must have only one row.
* @property ejmlVector The underlying matrix, must have only one row.
* @author Iaroslav Postovalov
*/
public abstract class EjmlVector<out T, out M : Matrix>(public open val origin: M) : Point<T> {
public abstract class EjmlVector<out T, out M : Matrix>(public open val ejmlVector: M) : Point<T> {
override val size: Int
get() = origin.numRows
get() = ejmlVector.numRows

override operator fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor: Int = 0
Expand All @@ -28,8 +28,8 @@ public abstract class EjmlVector<out T, out M : Matrix>(public open val origin:
return this@EjmlVector[cursor - 1]
}

override fun hasNext(): Boolean = cursor < origin.numCols * origin.numRows
override fun hasNext(): Boolean = cursor < ejmlVector.numCols * ejmlVector.numRows
}

override fun toString(): String = "EjmlVector(origin=$origin)"
override fun toString(): String = "EjmlVector(origin=$ejmlVector)"
}
Loading

0 comments on commit b2b64f3

Please sign in to comment.