Skip to content

Commit

Permalink
Merge pull request #437 from mipt-npm/commandertvis/double-specialized
Browse files Browse the repository at this point in the history
Provide specializations of `AsmBuilder` for `Double`, `Long`, `Int`
  • Loading branch information
altavir authored Nov 16, 2021
2 parents b1911eb + 7b50400 commit e25827e
Show file tree
Hide file tree
Showing 25 changed files with 961 additions and 425 deletions.
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,12 @@ One can still use generic algebras though.
> **Maturity**: DEVELOPMENT
<hr/>
* ### [kmath-multik](kmath-multik)
>
>
> **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-nd4j](kmath-nd4j)
>
>
Expand All @@ -252,6 +258,12 @@ One can still use generic algebras though.
<hr/>

* ### [kmath-optimization](kmath-optimization)
>
>
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-stat](kmath-stat)
>
>
Expand Down Expand Up @@ -319,8 +331,8 @@ repositories {
}

dependencies {
api("space.kscience:kmath-core:0.3.0-dev-14")
// api("space.kscience:kmath-core-jvm:0.3.0-dev-14") for jvm-specific version
api("space.kscience:kmath-core:0.3.0-dev-17")
// api("space.kscience:kmath-core-jvm:0.3.0-dev-17") for jvm-specific version
}
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
Expand All @@ -35,7 +36,14 @@ internal class ExpressionsInterpretersBenchmark {
* Benchmark case for [Expression] created with [compileToExpression].
*/
@Benchmark
fun asmExpression(blackhole: Blackhole) = invokeAndSum(asm, blackhole)
fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole)


/**
* Benchmark case for [Expression] created with [compileToExpression].
*/
@Benchmark
fun asmPrimitiveExpression(blackhole: Blackhole) = invokeAndSum(asmPrimitive, blackhole)

/**
* Benchmark case for [Expression] implemented manually with `kotlin.math` functions.
Expand Down Expand Up @@ -87,7 +95,8 @@ internal class ExpressionsInterpretersBenchmark {
}

private val mst = node.toExpression(DoubleField)
private val asm = node.compileToExpression(DoubleField)
private val asmPrimitive = node.compileToExpression(DoubleField)
private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>)

private val raw = Expression<Double> { args ->
val x = args[x]!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
package space.kscience.kmath.ast

import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke

fun main() {
val expr = MstField {
x * 2.0 + number(2.0) / x - 16.0
val expr = MstExtendedField {
x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x)
}.compileToExpression(DoubleField)

val m = HashMap<Symbol, Double>()
Expand Down
6 changes: 3 additions & 3 deletions kmath-ast/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Performance and visualization extensions to MST API.

## Artifact:

The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-14`.
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-17`.

**Gradle:**
```gradle
Expand All @@ -20,7 +20,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-ast:0.3.0-dev-14'
implementation 'space.kscience:kmath-ast:0.3.0-dev-17'
}
```
**Gradle Kotlin DSL:**
Expand All @@ -31,7 +31,7 @@ repositories {
}

dependencies {
implementation("space.kscience:kmath-ast:0.3.0-dev-14")
implementation("space.kscience:kmath-ast:0.3.0-dev-17")
}
```

Expand Down
4 changes: 4 additions & 0 deletions kmath-ast/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ tasks.dokkaHtml {
dependsOn(tasks.build)
}

tasks.jvmTest {
jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
}

readme {
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,30 @@ internal class TestCompilerOperations {
assertEquals(1.0, expr(x to 0.0))
}

@Test
fun testTangent() = runCompilerTest {
val expr = MstExtendedField { tan(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}

@Test
fun testArcSine() = runCompilerTest {
val expr = MstExtendedField { asin(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}

@Test
fun testArcCosine() = runCompilerTest {
val expr = MstExtendedField { acos(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 1.0))
}

@Test
fun testAreaHyperbolicSine() = runCompilerTest {
val expr = MstExtendedField { asinh(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}

@Test
fun testSubtract() = runCompilerTest {
val expr = MstExtendedField { x - x }.compileToExpression(DoubleField)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)")

@Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T>(
val binaryenType: Type,
val algebra: Algebra<T>,
val target: MST,
) where T : Number {
val keys: MutableList<Symbol> = mutableListOf()
internal sealed class WasmBuilder<T : Number>(
protected val binaryenType: Type,
protected val algebra: Algebra<T>,
protected val target: MST,
) {
protected val keys: MutableList<Symbol> = mutableListOf()
lateinit var ctx: BinaryenModule

open fun visitSymbolic(mst: Symbol): ExpressionRef {
Expand All @@ -41,30 +41,36 @@ internal sealed class WasmBuilder<T>(

abstract fun visitNumeric(mst: Numeric): ExpressionRef

open fun visitUnary(mst: Unary): ExpressionRef =
protected open fun visitUnary(mst: Unary): ExpressionRef =
error("Unary operation ${mst.operation} not defined in $this")

open fun visitBinary(mst: Binary): ExpressionRef =
protected open fun visitBinary(mst: Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this")

open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")

fun visit(mst: MST): ExpressionRef = when (mst) {
protected fun visit(mst: MST): ExpressionRef = when (mst) {
is Symbol -> visitSymbolic(mst)
is Numeric -> visitNumeric(mst)

is Unary -> when {
algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric(
Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value))))
Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))
)

else -> visitUnary(mst)
}

is Binary -> when {
algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(Numeric(
algebra.binaryOperationFunction(mst.operation)
.invoke(algebra.number((mst.left as Numeric).value), algebra.number((mst.right as Numeric).value))
))
algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(
Numeric(
algebra.binaryOperationFunction(mst.operation)
.invoke(
algebra.number((mst.left as Numeric).value),
algebra.number((mst.right as Numeric).value)
)
)
)

else -> visitBinary(mst)
}
Expand Down
18 changes: 0 additions & 18 deletions kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@ import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.wasm.internal.DoubleWasmBuilder
import space.kscience.kmath.wasm.internal.IntWasmBuilder

/**
* Compiles an [MST] to WASM in the context of reals.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun DoubleField.expression(mst: MST): Expression<Double> =
DoubleWasmBuilder(mst).instance

/**
* Compiles an [MST] to WASM in the context of integers.
*
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
public fun IntRing.expression(mst: MST): Expression<Int> =
IntWasmBuilder(mst).instance

/**
* Create a compiled expression with given [MST] and given [algebra].
*
Expand Down
90 changes: 82 additions & 8 deletions kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/

@file:Suppress("UNUSED_PARAMETER")

package space.kscience.kmath.asm

import space.kscience.kmath.asm.internal.AsmBuilder
import space.kscience.kmath.asm.internal.buildName
import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbolOrNull
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*

/**
* Compiles given MST to an Expression using AST compiler.
Expand All @@ -26,7 +26,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull
*/
@PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
is Symbol -> prepareVariable(node.identity)
is Unary -> variablesVisitor(node.value)

Expand All @@ -38,7 +38,7 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
else -> Unit
}

fun AsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
is Symbol -> {
val symbol = algebra.bindSymbolOrNull(node)

Expand Down Expand Up @@ -87,7 +87,7 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
}
}

return AsmBuilder<T>(
return GenericAsmBuilder<T>(
type,
buildName(this),
{ variablesVisitor(this@compileWith) },
Expand All @@ -114,3 +114,77 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments:
*/
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments)


/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = IntAsmBuilder(this).instance

/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments)

/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
compileToExpression(algebra)(*arguments)


/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAsmBuilder(this).instance


/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
compileToExpression(algebra).invoke(arguments)


/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long =
compileToExpression(algebra)(*arguments)


/**
* Create a compiled expression with given [MST] and given [algebra].
*
* @author Iaroslav Postovalov
*/
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleAsmBuilder(this).instance

/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments)

/**
* Compile given MST to expression and evaluate it against [arguments].
*
* @author Iaroslav Postovalov
*/
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments)
Loading

0 comments on commit e25827e

Please sign in to comment.