Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(reasoner): support value type inference in udf #132

Merged
merged 9 commits into from
Mar 5, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

package com.antgroup.openspg.reasoner.common.types

import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException

trait KgType {
def isNullable: Boolean = false
}
Expand Down Expand Up @@ -64,3 +66,16 @@
* @param elementType
*/
final case class KTMultiVersion(elementType: KgType) extends KgType

object KgType {

def getNumberSeq(kgType: KgType): Int = {
kgType match {
case KTInteger => 1
case KTLong => 2
case KTDouble => 3
case _ => throw UnsupportedOperationException(s"cannot support number type $kgType")

Check warning on line 77 in reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala

View check run for this annotation

Codecov / codecov/patch

reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala#L77

Added line #L77 was not covered by tests
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
IRProperty(s.alias, propertyName) ->
ProjectRule(
IRProperty(s.alias, propertyName),
propertyType,
Ref(ddlBlockWithNodes._3.target.alias)))))
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
case AddPredicate(predicate) =>
Expand Down Expand Up @@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
ProjectBlock(
List.apply(opBlock),
ProjectFields(Map.apply(lValueName ->
ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain))))
ProjectRule(lValueName, opChain))))
}
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
ProjectBlock(
Expand All @@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
lValueName ->
ProjectRule(
lValueName,
exprParser.parseRetType(opChain.curExpr),
opChain.curExpr))))
case _ => null
}
Expand Down Expand Up @@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
List.empty)
case _ =>
rule match {
case ProjectRule(_, lvalueType, _) =>
val projectRule = ProjectRule(lvalueFiled, lvalueType, expr)
case ProjectRule(_, _) =>
val projectRule = ProjectRule(lvalueFiled, expr)
ProjectBlock(
List.apply(preBlock),
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
Expand Down Expand Up @@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
(ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true)
(ProjectRule(IRVariable(defaultName), expr), columnName, true)
}

def parseGraphStructure(
Expand All @@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
(
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}
Expand Down Expand Up @@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
(
ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
}
}

def parseRetType(expr: Expr): KgType = {
KTObject
}

def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
ctx.getChild(0) match {
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
Expand All @@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
if (ctx.property_name() != null) {
ProjectRule(
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
parseRetType(expr),
expr)
} else {
ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr)
ProjectRule(IRVariable(ctx.identifier().getText), expr)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser

import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph._
Expand Down Expand Up @@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}

it("addproperies with constraint") {
Expand All @@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}

it("addproperies2") {
Expand All @@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addproperies") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
Expand All @@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addNode") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
Expand Down Expand Up @@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
.asInstanceOf[AddPredicate]
.predicate
.fields
.keySet should contain ("same_domain_num")
.keySet should contain("same_domain_num")

blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
blocks(1)
Expand Down Expand Up @@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o)))))
| └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
Expand Down Expand Up @@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
Expand Down
4 changes: 4 additions & 0 deletions reasoner/lube-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-common</artifactId>
</dependency>
<dependency>
<groupId>com.antgroup.openspg.reasoner</groupId>
<artifactId>reasoner-udf</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

package com.antgroup.openspg.reasoner.lube.catalog

import scala.collection.mutable

import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
import scala.collection.mutable
import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}


/**
Expand All @@ -27,6 +29,8 @@ import scala.collection.mutable
*/
abstract class Catalog() extends Serializable {
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
@transient
private val udfRepo = UdfMngFactory.getUdfMng
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()

/**
Expand Down Expand Up @@ -96,6 +100,8 @@ abstract class Catalog() extends Serializable {
graphRepository.get(graphName).orNull
}

def getUdfRepo: UdfMng = udfRepo

/**
* Get schema from knowledge graph
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

package com.antgroup.openspg.reasoner.lube.common.rule

import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.IRField

Expand All @@ -39,13 +38,6 @@ trait Rule extends Cloneable{
*/
def getExpr: Expr


/**
* get lvalue type
* @return
*/
def getLvalueType: KgType

/**
* get dependencies
* @return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
*/
override def getExpr: Expr = expr

/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = KTBoolean

override def andRule(rule: Rule): Rule = {
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)

Expand Down Expand Up @@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
* @param lvalueType
* @param expr
*/
final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
final case class ProjectRule(output: IRField, expr: Expr)
extends DependencyRule {

/**
Expand Down Expand Up @@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
*/
override def getExpr: Expr = expr

/**
* get lvalue type
*
* @return
*/
override def getLvalueType: KgType = lvalueType

override def andRule(rule: Rule): Rule = {
throw UnsupportedOperationException("ProjectRule cannot support andRule")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
case _ =>
ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr)
ProjectRule(IRVariable(ruleNameStr), expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {
Expand All @@ -162,7 +162,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(rule.getName, logicRule.ruleExplain, expr)
case _ =>
ProjectRule(rule.getOutput, rule.getLvalueType, expr)
ProjectRule(rule.getOutput, expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
false)))),
ProjectFields(
Map.apply(IRVariable("total_domain_num") ->
ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o")))))
ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
val p = BlockUtils.transBlock2Graph(block)
p.size should equal(1)
p.head.graphPattern.nodes.size should equal(2)
Expand Down Expand Up @@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
it("rename_rule") {
val rule = ProjectRule(
IRVariable("a"),
KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("birthDate"), Ref("e")),
Expand Down Expand Up @@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
it("variable_rule") {
val rule = ProjectRule(
IRVariable("a"),
KTObject,
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))

val rule2 = ProjectRule(
IRVariable("b"),
KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("attr1"), Ref("e")),
Expand All @@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
def getDependenceRule(): Rule = {
val r0 = ProjectRule(
IRVariable("r0"),
KTLong,
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
)
val r1 = LogicRule(
Expand Down Expand Up @@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
val r0 = LogicRule("tmp", "",
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
val r = ProjectRule(IRVariable("g"),
KTLong,
OpChainExpr(
GraphAggregatorExpr(
"unresolved_default_path",
Expand Down
Loading
Loading