diff --git a/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala b/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala
index 28a170c6c..ad2762c59 100644
--- a/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala
+++ b/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala
@@ -13,6 +13,8 @@
package com.antgroup.openspg.reasoner.common.types
+import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
+
trait KgType {
def isNullable: Boolean = false
}
@@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType
* @param elementType
*/
final case class KTMultiVersion(elementType: KgType) extends KgType
+
+object KgType {
+
+ def getNumberSeq(kgType: KgType): Int = {
+ kgType match {
+ case KTInteger => 1
+ case KTLong => 2
+ case KTDouble => 3
+ case _ => throw UnsupportedOperationException(s"cannot support number type $kgType")
+ }
+ }
+
+}
diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala
index c87a20017..5f3705dd6 100644
--- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala
+++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala
@@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface {
IRProperty(s.alias, propertyName) ->
ProjectRule(
IRProperty(s.alias, propertyName),
- propertyType,
Ref(ddlBlockWithNodes._3.target.alias)))))
DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk))
case AddPredicate(predicate) =>
@@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface {
ProjectBlock(
List.apply(opBlock),
ProjectFields(Map.apply(lValueName ->
- ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain))))
+ ProjectRule(lValueName, opChain))))
}
case AggIfOpExpr(_, _) | AggOpExpr(_, _) =>
ProjectBlock(
@@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface {
lValueName ->
ProjectRule(
lValueName,
- exprParser.parseRetType(opChain.curExpr),
opChain.curExpr))))
case _ => null
}
@@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface {
List.empty)
case _ =>
rule match {
- case ProjectRule(_, lvalueType, _) =>
- val projectRule = ProjectRule(lvalueFiled, lvalueType, expr)
+ case ProjectRule(_, _) =>
+ val projectRule = ProjectRule(lvalueFiled, expr)
ProjectBlock(
List.apply(preBlock),
ProjectFields(Map.apply(lvalueFiled -> projectRule)))
@@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface {
exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal()))
val defaultName = "const_output_" + patternParser.getDefaultAliasNum
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName)
- (ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true)
+ (ProjectRule(IRVariable(defaultName), expr), columnName, true)
}
def parseGraphStructure(
@@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName)
(
- ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
+ ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}
@@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface {
val defaultColumnName = parseExpr2ElementStr(expr)
val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName)
(
- ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr),
+ ProjectRule(IRVariable(defaultColumnName), expr),
columnName,
false)
}
diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala
index 94a2ee600..6b8bb5027 100644
--- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala
+++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala
@@ -849,10 +849,6 @@ class RuleExprParser extends Serializable {
}
}
- def parseRetType(expr: Expr): KgType = {
- KTObject
- }
-
def parseRuleExpression(ctx: Rule_expressionContext): Rule = {
ctx.getChild(0) match {
case c: Logic_rule_expressionContext => parseLogicRuleExpression(c)
@@ -878,10 +874,9 @@ class RuleExprParser extends Serializable {
if (ctx.property_name() != null) {
ProjectRule(
IRProperty(ctx.identifier().getText, ctx.property_name().getText),
- parseRetType(expr),
expr)
} else {
- ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr)
+ ProjectRule(IRVariable(ctx.identifier().getText), expr)
}
}
diff --git a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala
index 3a84b0c47..721f99ff9 100644
--- a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala
+++ b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala
@@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser
import com.antgroup.openspg.reasoner.common.constants.Constants
import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException}
-import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString}
import com.antgroup.openspg.reasoner.lube.block._
import com.antgroup.openspg.reasoner.lube.common.expr._
import com.antgroup.openspg.reasoner.lube.common.graph._
@@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
- proj.projects.items.head._2 should equal(
- ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
+ proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}
it("addproperies with constraint") {
@@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
print(block.pretty)
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
- proj.projects.items.head._2 should equal(
- ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o")))
+ proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o")))
}
it("addproperies2") {
@@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
- ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
+ ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addproperies") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true)
val proj = block.dependencies.head.asInstanceOf[ProjectBlock]
proj.projects.items.head._2 should equal(
- ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o")))
+ ProjectRule(IRProperty("s", "total_domain_num"), Ref("o")))
}
it("addNode") {
val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) {
@@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
.asInstanceOf[AddPredicate]
.predicate
.fields
- .keySet should contain ("same_domain_num")
+ .keySet should contain("same_domain_num")
blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true)
blocks(1)
@@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean)))
- | └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o)))))
+ | └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o)))))
| └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set())))
| └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan)))
| └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false)))
@@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec {
val block = parser.parse(dsl)
print(block.pretty)
val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b))
- * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value)))))
+ * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value)))))
* └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr)))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan)))))
* └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual)))))
diff --git a/reasoner/lube-api/pom.xml b/reasoner/lube-api/pom.xml
index 87d0aa2e8..5377885b0 100644
--- a/reasoner/lube-api/pom.xml
+++ b/reasoner/lube-api/pom.xml
@@ -33,6 +33,10 @@
com.antgroup.openspg.reasoner
reasoner-common
+
+ com.antgroup.openspg.reasoner
+ reasoner-udf
+
org.scala-lang
scala-library
diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala
index 3d55ecaa0..171d46a48 100644
--- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala
+++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala
@@ -13,10 +13,12 @@
package com.antgroup.openspg.reasoner.lube.catalog
+import scala.collection.mutable
+
import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph
-import scala.collection.mutable
+import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory}
/**
@@ -27,6 +29,7 @@ import scala.collection.mutable
*/
abstract class Catalog() extends Serializable {
protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]()
+ @transient private val udfRepo = UdfMngFactory.getUdfMng
private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]()
/**
@@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable {
graphRepository.get(graphName).orNull
}
+ def getUdfRepo: UdfMng = udfRepo
+
/**
* Get schema from knowledge graph
*/
diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala
index eb48b7f50..99e475c2a 100644
--- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala
+++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala
@@ -13,7 +13,6 @@
package com.antgroup.openspg.reasoner.lube.common.rule
-import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.common.expr.Expr
import com.antgroup.openspg.reasoner.lube.common.graph.IRField
@@ -39,13 +38,6 @@ trait Rule extends Cloneable{
*/
def getExpr: Expr
-
- /**
- * get lvalue type
- * @return
- */
- def getLvalueType: KgType
-
/**
* get dependencies
* @return
diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala
index 322416a0a..07a0ccf06 100644
--- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala
+++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala
@@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
*/
override def getExpr: Expr = expr
- /**
- * get lvalue type
- *
- * @return
- */
- override def getLvalueType: KgType = KTBoolean
-
override def andRule(rule: Rule): Rule = {
val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr)
@@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr)
* @param lvalueType
* @param expr
*/
-final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
+final case class ProjectRule(output: IRField, expr: Expr)
extends DependencyRule {
/**
@@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr)
*/
override def getExpr: Expr = expr
- /**
- * get lvalue type
- *
- * @return
- */
- override def getLvalueType: KgType = lvalueType
override def andRule(rule: Rule): Rule = {
throw UnsupportedOperationException("ProjectRule cannot support andRule")
diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala
index 84b9b0dbe..ecb6f726e 100644
--- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala
+++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala
@@ -138,7 +138,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(ruleNameStr, logicRule.ruleExplain, expr)
case _ =>
- ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr)
+ ProjectRule(IRVariable(ruleNameStr), expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {
@@ -162,7 +162,7 @@ object RuleUtils {
case logicRule: LogicRule =>
LogicRule(rule.getName, logicRule.ruleExplain, expr)
case _ =>
- ProjectRule(rule.getOutput, rule.getLvalueType, expr)
+ ProjectRule(rule.getOutput, expr)
}
val oldDependencies = rule.getDependencies
if (oldDependencies != null) {
diff --git a/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala b/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala
index 71308867a..408763376 100644
--- a/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala
+++ b/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala
@@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec {
false)))),
ProjectFields(
Map.apply(IRVariable("total_domain_num") ->
- ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o")))))
+ ProjectRule(IRVariable("total_domain_num"), Ref("o")))))
val p = BlockUtils.transBlock2Graph(block)
p.size should equal(1)
p.head.graphPattern.nodes.size should equal(2)
@@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec {
it("rename_rule") {
val rule = ProjectRule(
IRVariable("a"),
- KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("birthDate"), Ref("e")),
@@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec {
it("variable_rule") {
val rule = ProjectRule(
IRVariable("a"),
- KTObject,
BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b")))
val rule2 = ProjectRule(
IRVariable("b"),
- KTObject,
BinaryOpExpr(
BEqual,
UnaryOpExpr(GetField("attr1"), Ref("e")),
@@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec {
def getDependenceRule(): Rule = {
val r0 = ProjectRule(
IRVariable("r0"),
- KTLong,
BinaryOpExpr(BAssign, Ref("r0"), VLong("123"))
)
val r1 = LogicRule(
@@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec {
val r0 = LogicRule("tmp", "",
BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10")))
val r = ProjectRule(IRVariable("g"),
- KTLong,
OpChainExpr(
GraphAggregatorExpr(
"unresolved_default_path",
diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala
index 041bf3d99..1ddc1da3a 100644
--- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala
+++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala
@@ -13,9 +13,16 @@
package com.antgroup.openspg.reasoner.lube.logical
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
import com.antgroup.openspg.reasoner.common.trees.BottomUp
+import com.antgroup.openspg.reasoner.common.types._
import com.antgroup.openspg.reasoner.lube.common.expr._
+import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.Rule
+import com.antgroup.openspg.reasoner.udf.UdfMng
object ExprUtil {
@@ -46,15 +53,13 @@ object ExprUtil {
}
-
def needResolved(rule: Expr): Boolean = {
!getReferProperties(rule).filter(_._1 == null).isEmpty
}
def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = {
- def rewriter: PartialFunction[Expr, Expr] = {
- case Ref(refName) =>
+ def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) =>
if (replaceVar.contains(refName)) {
val propertyVar = replaceVar(refName)
UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name))
@@ -76,4 +81,104 @@ object ExprUtil {
newRule
}
+ def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
+ expr match {
+ case Ref(name) =>
+ if (referVars.contains(IRVariable(name))) {
+ referVars(IRVariable(name))
+ } else {
+ KTObject
+ }
+ case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name))
+ case BinaryOpExpr(name, l, r) =>
+ name match {
+ case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan |
+ BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign =>
+ KTBoolean
+ case BAdd | BSub | BMul | BDiv | BMod =>
+ val left = getTargetType(l, referVars, udfRepo)
+ val right = getTargetType(r, referVars, udfRepo)
+ getUpperType(left, right)
+ case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
+ }
+ case UnaryOpExpr(name, arg) =>
+ name match {
+ case Not | Exists => KTBoolean
+ case Abs | Neg => getTargetType(arg, referVars, udfRepo)
+ case Floor | Ceil => KTDouble
+ case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
+ }
+ case FunctionExpr(name, funcArgs) =>
+ val types = funcArgs.map(getTargetType(_, referVars, udfRepo))
+ name match {
+ case "rule_value" => types(1)
+ case "cast_type" | "Cast" =>
+ funcArgs(1).asInstanceOf[VString].value match {
+ case "int" | "bigint" | "long" => KTLong
+ case "float" | "double" => KTDouble
+ case "varchar" | "string" => KTString
+ case _ =>
+ throw UnsupportedOperationException(s"cannot support ${name} to ${funcArgs(1)}")
+ }
+ case _ =>
+ val udf = udfRepo.getUdfMeta(name, types.asJava)
+ if (udf != null) {
+ udf.getResultType
+ } else {
+ throw UnsupportedOperationException(s"cannot find UDF: ${name}")
+ }
+ }
+
+ case AggOpExpr(name, args) =>
+ name match {
+ case Min | Max | Sum | Avg | First | Accumulate(_) =>
+ getTargetType(args, referVars, udfRepo)
+ case StrJoin(_) => KTString
+ case Count => KTLong
+ case AggUdf(name, _) =>
+ val types = getTargetType(args.head, referVars, udfRepo)
+ val udf = udfRepo.getUdafMeta(name, types)
+ if (udf != null) {
+ udf.getResultType
+ } else {
+ throw UnsupportedOperationException(s"cannot find UDAF ${name}")
+ }
+ case _ => throw UnsupportedOperationException(s"express cannot support ${name}")
+ }
+ case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo)
+ case ListOpExpr(name, _) =>
+ name match {
+ case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo)
+ case Constraint(_, _, _) => KTBoolean
+ case Get(_) | Slice(_, _) => KTObject
+ }
+ case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo)
+ case VNull | VString(_) => KTString
+ case VLong(_) => KTLong
+ case VDouble(_) => KTDouble
+ case VBoolean(_) => KTBoolean
+ case VList(_, listType) => KTList(listType)
+ case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}")
+ }
+ }
+
+ def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = {
+ val newReferVars = new mutable.HashMap[IRField, KgType]
+ newReferVars.++=(referVars)
+ for (r <- rule.getDependencies) {
+ newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo))
+ }
+ getTargetType(rule.getExpr, newReferVars.toMap, udfRepo)
+ }
+
+ private def getUpperType(left: KgType, right: KgType): KgType = {
+ val l = KgType.getNumberSeq(left)
+ val r = KgType.getNumberSeq(right)
+ if (l >= r) {
+ left
+ } else {
+ right
+ }
+ }
+
}
diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala
index 9c1358e1e..ae7df8f90 100644
--- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala
+++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala
@@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
-import com.antgroup.openspg.reasoner.common.types.KTObject
+import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject}
import com.antgroup.openspg.reasoner.lube.block.Aggregations
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator
@@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils
import org.apache.commons.lang3.StringUtils
-class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
+class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit
+ context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = {
val groupVar: List[Var] = group.map(toVar(_, dependency.solved))
@@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
v
}
})
+
+ val referTypes = new mutable.HashMap[IRField, KgType]()
+ for (v <- ruleFields) {
+ resolved.getVar(v.name) match {
+ case p: PropertyVar => referTypes.put(v, p.field.kgType)
+ case node: NodeVar =>
+ node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
+ case edge: EdgeVar =>
+ edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
+ case _ => throw UnsupportedOperationException(s"cannot support $v")
+ }
+ }
+ referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
+ val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo)
+
val renameVar = ruleFields
.filter(_.isInstanceOf[IRVariable])
.map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable]))))
@@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) {
val field = getAggregateTarget(referFields, resolved, dependency)
field match {
case IRNode(alias, _) =>
- val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
+ val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case IREdge(alias, _) =>
if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) {
aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr)
} else {
- val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true))
+ val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true))
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
aggMap.put(propertyVar, newAggExpr)
}
case IRVariable(alias) =>
val tmpPropertyVar = resolved.tmpFields(IRVariable(alias))
val propertyVar =
- PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true))
+ PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true))
aggMap.put(propertyVar, newAggExpr)
resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar))
case _ =>
diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala
index 259789c70..6dd6042a0 100644
--- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala
+++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala
@@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning
import scala.collection.mutable
import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException
-import com.antgroup.openspg.reasoner.common.types.KTObject
+import com.antgroup.openspg.reasoner.common.types.KgType
import com.antgroup.openspg.reasoner.lube.block.ProjectFields
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr}
import com.antgroup.openspg.reasoner.lube.common.graph._
-import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var}
+import com.antgroup.openspg.reasoner.lube.common.rule.Rule
+import com.antgroup.openspg.reasoner.lube.logical._
import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator}
import com.antgroup.openspg.reasoner.lube.utils.RuleUtils
import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer
import org.apache.commons.lang3.StringUtils
-class ProjectPlanner(projects: ProjectFields) {
+class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) {
def plan(dependency: LogicalOperator): LogicalOperator = {
val projectMap = new mutable.HashMap[Var, Expr]()
@@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) {
v
}
})
- val propertyVar = getTarget(rule._1, referVars, resolved, dependency)
+ val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency)
val transformer = new Rule2ExprTransformer()
val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable])
val replaceVar = reference
@@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) {
private def getTarget(
left: IRField,
referVars: List[IRField],
+ rule: Rule,
resolved: SolvedModel,
dependency: LogicalOperator): PropertyVar = {
+ val referTypes = new mutable.HashMap[IRField, KgType]()
+ for (v <- referVars) {
+ resolved.getVar(v.name) match {
+ case p: PropertyVar => referTypes.put(v, p.field.kgType)
+ case node: NodeVar =>
+ node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
+ case edge: EdgeVar =>
+ edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType))
+ case _ =>
+ }
+ }
+ referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType)))
+ val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo)
+
left match {
case IRVariable(name) =>
if (referVars.size == 1) {
- PropertyVar(referVars.head.name, new Field(name, KTObject, true))
+ PropertyVar(referVars.head.name, new Field(name, ruleRetType, true))
} else {
val aliasSet = new mutable.HashSet[String]()
for (rVar <- referVars) {
@@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) {
}
}
val targetAlias = getTargetAlias(aliasSet.toSet, dependency)
- PropertyVar(targetAlias, new Field(left.name, KTObject, true))
+ PropertyVar(targetAlias, new Field(left.name, ruleRetType, true))
}
- case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true))
- case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
- case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true))
+ case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true))
+ case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
+ case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true))
case _ => throw UnsupportedOperationException(s"cannot support $left")
}
diff --git a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala
index 432097636..83ebd1b06 100644
--- a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala
+++ b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala
@@ -13,11 +13,13 @@
package com.antgroup.openspg.reasoner.lube.logical
-import com.antgroup.openspg.reasoner.common.types.KTInteger
+import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString}
import com.antgroup.openspg.reasoner.lube.catalog.struct.Field
-import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr}
-import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable
+import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr}
+import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable}
import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule}
+import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser
+import com.antgroup.openspg.reasoner.udf.UdfMngFactory
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal}
@@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec {
val replaceMap = Map.apply(
"a" -> PropertyVar("b", new Field("id", KTInteger, true)),
"c" -> PropertyVar("d", new Field("id", KTInteger, true)))
- val r1 = ProjectRule(IRVariable("test"),
- KTInteger, Ref("a"))
- val r2 = ProjectRule(IRVariable("test"),
- KTInteger, Ref("c"))
+ val r1 = ProjectRule(IRVariable("test"), Ref("a"))
+ val r2 = ProjectRule(IRVariable("test"), Ref("c"))
r1.addDependency(r2)
val newRule: Rule = ExprUtil.transExpr(r1, replaceMap)
print(newRule.getExpr.pretty)
newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true)
newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true)
- newRule.getExpr.asInstanceOf[UnaryOpExpr]
- .arg.asInstanceOf[Ref].refName should equal("b")
+ newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b")
+ }
+
+ it("test expr output type") {
+ val parser = new RuleExprParser()
+ val udfRepo = UdfMngFactory.getUdfMng
+ val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
+
+ var expr: Expr = parser.parse("A.age")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
+
+ expr = parser.parse("A.age + 1")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTLong)
+
+ expr = parser.parse("A.age > 10")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTBoolean)
+
+ expr = parser.parse("floor(A.age)")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTDouble)
+
+ expr = parser.parse("abs(A.age)")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger)
+
+ expr = parser.parse("concat(A.age, \",\")")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
+
+ expr = parser.parse("cast_type(A.age, 'string')")
+ ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString)
+ }
+
+ it("test rule output type") {
+ val parser = new RuleExprParser()
+ val udfRepo = UdfMngFactory.getUdfMng
+ val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]]
+
+ val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10"))
+ val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age"))
+ rule.addDependency(r1)
+
+ ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong)
}
}
diff --git a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala
index de18e507b..7c717637b 100644
--- a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala
+++ b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala
@@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec {
println(optimizedLogicalPlan.pretty)
}
- it("online") {
- val dsl = """Define (s:User)-[p:redPacket]->(o:Int) {
- | GraphStructure {
- | (s)
- | }
- | Rule {
- |LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m
- |Latest30DayPayCount=s.ngfe_tag__pay_cnt_d
- |Latest7DayPayCount=s.ngfe_tag__pay_cnt_d
- |LatestTTT = Latest7DayPayCount.accumulate(+)
- |LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0
- |Latest7DayPayCountSum=Latest7DayPayCount
- |Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0
- |HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount
- |HighLost("高频降频100%"):HighReduceValue == 1
- |HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1
- |HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8
- |HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5
- |HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3
- |Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d
- |Latest30DayPayDayCount=size(Latest30DayPayCount)
- |Latest3060DayPayDayCount=size(Latest3060DayPayCount)
- |High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13
- |High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13
- |Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
- |Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
- |Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12
- |Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
- |Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3
- |Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d
- |Latest6090DayPayDayCount=size(Latest6090DayPayCount)
- |Latest60DayPayCount=s.ngfe_tag__pay_cnt_d
- |Latest60DayPayDayCount=size(Latest60DayPayCount)
- |Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0
- |Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0
- |HistoricallyPay=s.ngfe_tag__pay_cnt_total
- |HistoricallyPayCount=size(HistoricallyPay)
- |New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0
- |Latest90DayPayCount=s.ngfe_tag__pay_cnt_d
- |Latest90DayPayDayCount=size(Latest90DayPayCount)
- |Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0
- |o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost"))
- | }
- |}""".stripMargin
- val parser = new OpenSPGDslParser()
- val block = parser.parse(dsl)
- println(block.pretty)
- val schema: Map[String, Set[String]] = Map.apply(
- "User" -> Set
- .apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d"))
- val catalog = new PropertyGraphCatalog(schema)
- catalog.init()
- implicit val context: LogicalPlannerContext =
- LogicalPlannerContext(catalog, parser, Map.empty)
- val logicalPlan = LogicalPlanner.plan(block)
- println(logicalPlan.head.pretty)
- val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head)
- println(optimizedLogicalPlan.pretty)
- }
-
it("test start flag") {
val dsl =
"""
@@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec {
| (s)-[p2:followPM]->(o)
|}
|Rule {
- | c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0)
+ | c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0)
|
|}
|Action {
diff --git a/reasoner/pom.xml b/reasoner/pom.xml
index bcf6620c2..0f70033a3 100644
--- a/reasoner/pom.xml
+++ b/reasoner/pom.xml
@@ -41,7 +41,7 @@
4.8
1.2.71_noneautotype
- 27.0
+ 28.0
2.0.0
2.7.2
3.1.0
diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java
index 6fcbd4822..8c5a03aa9 100644
--- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java
+++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java
@@ -137,7 +137,7 @@ public void test11() {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n"
+ " Rule {\n"
- + " \tv = t.stock/t.total\n"
+ + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n"
+ " }\n"
@@ -197,7 +197,7 @@ public void test9() {
+ " \t(s)-[:p3]->(t:Attribute1.Name142)\n"
+ " }\n"
+ " Rule {\n"
- + " \tv = t.stock/t.total\n"
+ + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n"
+ " R1(\"必须大于20%\"): v > 0.2\n"
+ " o = v\n"
+ " }\n"
diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java
index d0681d48e..443e66241 100644
--- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java
+++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java
@@ -50,9 +50,12 @@ private void doTest1() {
+ "R1: A.id == $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
- + "totalTrans1 = group(A,B,C).sum(p1.amount)\n"
- + "totalTrans2 = group(A,B,C).sum(p2.amount)\n"
- + "totalTrans3 = group(A,B,C).sum(p3.amount)\n"
+ + "p1_amt = cast_type(p1.amount,'long')\n"
+ + "p2_amt = cast_type(p2.amount,'long')\n"
+ + "p3_amt = cast_type(p3.amount,'long')\n"
+ + "totalTrans1 = group(A,B,C).sum(p1_amt)\n"
+ + "totalTrans2 = group(A,B,C).sum(p2_amt)\n"
+ + "totalTrans3 = group(A,B,C).sum(p3_amt)\n"
+ "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n"
+ "R2('取top2'): top(totalTrans, 2)"
+ "}\n"
@@ -90,7 +93,7 @@ private void doTest2() {
+ "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
- + "totalTrans = p1.amount + p2.amount + p3.amount\n"
+ + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)"
+ "}\n"
+ "Action {\n"
@@ -127,7 +130,7 @@ private void doTest3() {
+ "R1: A.id == $idSet1\n"
+ "R2: B.id == $idSet2\n"
+ "R3: C.id == $idSet3\n"
- + "totalTrans = p1.amount + p2.amount + p3.amount\n"
+ + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n"
+ "R2('取top2'): top(totalTrans, 3)"
+ "}\n"
+ "Action {\n"
@@ -169,9 +172,12 @@ private void doTest4() {
+ "R1: A.id in $idSet1\n"
+ "R2: B.id in $idSet2\n"
+ "R3: C.id in $idSet2\n"
- + "t1 = group(A,B,C).sum(p1.amount)\n"
- + "t2 = group(A,B,C).sum(p2.amount)\n"
- + "t3 = group(A,B,C).sum(p3.amount)\n"
+ + "p1_amt = cast_type(p1.amount,'long')\n"
+ + "p2_amt = cast_type(p2.amount,'long')\n"
+ + "p3_amt = cast_type(p3.amount,'long')\n"
+ + "t1 = group(A,B,C).sum(p1_amt)\n"
+ + "t2 = group(A,B,C).sum(p2_amt)\n"
+ + "t3 = group(A,B,C).sum(p3_amt)\n"
+ "totalSum = t1 + t2 + t3"
+ "}\n"
+ "Action {\n"
diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java
index 6e675da2f..1572df587 100644
--- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java
+++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java
@@ -403,7 +403,7 @@ private void doTest12() {
+ " o->star [starOfFilm] as sf2\n"
+ "}\n"
+ "Rule {\n"
- + "total = sf.joinTs + sf2.joinTs\n"
+ + "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n"
+ "R2: top(total, 1)\n"
+ "}\n"
+ "Action {\n"
diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java
index 689f2e199..ed2ec729c 100644
--- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java
+++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java
@@ -25,13 +25,7 @@
import com.antgroup.openspg.reasoner.kggraph.KgGraph;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl;
import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters;
-import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr;
-import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr;
-import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
-import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
-import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
-import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
-import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
+import com.antgroup.openspg.reasoner.lube.common.expr.*;
import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern;
import com.antgroup.openspg.reasoner.lube.logical.EdgeVar;
import com.antgroup.openspg.reasoner.lube.logical.NodeVar;
@@ -41,6 +35,7 @@
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess;
import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess;
+import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle;
import com.antgroup.openspg.reasoner.udf.model.BaseUdaf;
import com.antgroup.openspg.reasoner.udf.model.UdafMeta;
import com.antgroup.openspg.reasoner.udf.rule.RuleRunner;
@@ -353,26 +348,14 @@ private Object doAggregation(
if (null != udafInitParams) {
udaf.initialize(udafInitParams);
}
-
- String sourceAlias = null;
- String sourcePropertyName = null;
+ ParsedAggEle parsedAggEle;
Set aliasList = aggInfo.getExprUseAliasSet();
if (aliasList.size() <= 1) {
- Expr sourceExpr = aggInfo.getAggEle();
- // aggregate by vertex subgraph
- if (sourceExpr instanceof Ref) {
- Ref sourceRef = (Ref) sourceExpr;
- sourceAlias = sourceRef.refName();
- } else if (sourceExpr instanceof UnaryOpExpr) {
- UnaryOpExpr expr = (UnaryOpExpr) sourceExpr;
- GetField getField = (GetField) expr.name();
- sourceAlias = ((Ref) expr.arg()).refName();
- sourcePropertyName = getField.fieldName();
- }
+ parsedAggEle = aggInfo.getParsedAggEle();
if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) {
for (KgGraph valueFiltered : valueFilteredList) {
if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) {
- StringBuffer sb = new StringBuffer();
+ StringBuilder sb = new StringBuilder();
for (KgGraph valueFiltered2 : valueFilteredList) {
sb.append(valueFiltered2).append("## ");
}
@@ -381,19 +364,41 @@ private Object doAggregation(
}
}
}
- String finalSourcePropertyName = sourcePropertyName;
+ String finalSourcePropertyName = parsedAggEle.getSourcePropertyName();
for (KgGraph valueFiltered : valueFilteredList) {
- if (valueFiltered.getVertexAlias().contains(sourceAlias)) {
- List> vertexList = valueFiltered.getVertex(sourceAlias);
- if (sourcePropertyName == null) {
+ if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) {
+ List> vertexList =
+ valueFiltered.getVertex(parsedAggEle.getSourceAlias());
+ if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
+ vertexList.forEach(
+ vertex -> {
+ Map context =
+ RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias());
+ Object value =
+ RuleRunner.getInstance()
+ .executeExpression(context, parsedAggEle.getExprStrList(), taskId);
+ udaf.update(value);
+ });
+ } else if (finalSourcePropertyName == null) {
vertexList.forEach(udaf::update);
} else {
vertexList.forEach(
v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName));
}
} else {
- List> edgeList = valueFiltered.getEdge(sourceAlias);
- if (sourcePropertyName == null) {
+ List> edgeList =
+ valueFiltered.getEdge(parsedAggEle.getSourceAlias());
+ if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) {
+ edgeList.forEach(
+ edge -> {
+ Map context =
+ RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias());
+ Object value =
+ RuleRunner.getInstance()
+ .executeExpression(context, parsedAggEle.getExprStrList(), taskId);
+ udaf.update(value);
+ });
+ } else if (finalSourcePropertyName == null) {
edgeList.forEach(udaf::update);
} else {
edgeList.forEach(
diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java
index 8dc8e73fb..89fc848f0 100644
--- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java
+++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java
@@ -23,6 +23,7 @@
import java.util.List;
public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
+ private final ParsedAggEle parsedAggEle;
/**
* constructor
@@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements
*/
public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator);
+ parsedAggEle = parsedAggEle();
}
/**
@@ -58,4 +60,9 @@ public AggregatorOpSet getAggOpSet() {
public Expr getAggEle() {
return getAggIfOpExpr().aggOpExpr().aggEleExpr();
}
+
+ @Override
+ public ParsedAggEle getParsedAggEle() {
+ return parsedAggEle;
+ }
}
diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java
index 91a21dd0a..aeab4b026 100644
--- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java
+++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java
@@ -22,9 +22,11 @@
import java.util.List;
public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable {
+ private final ParsedAggEle parsedAggEle;
public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) {
super(taskId, var, aggregator);
+ parsedAggEle = parsedAggEle();
}
public AggOpExpr getAggOpExpr() {
@@ -45,4 +47,9 @@ public AggregatorOpSet getAggOpSet() {
public Expr getAggEle() {
return getAggOpExpr().aggEleExpr();
}
+
+ @Override
+ public ParsedAggEle getParsedAggEle() {
+ return this.parsedAggEle;
+ }
}
diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java
index 0dc5e2124..564afd51b 100644
--- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java
+++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java
@@ -19,6 +19,9 @@
import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator;
import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet;
import com.antgroup.openspg.reasoner.lube.common.expr.Expr;
+import com.antgroup.openspg.reasoner.lube.common.expr.GetField;
+import com.antgroup.openspg.reasoner.lube.common.expr.Ref;
+import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr;
import com.antgroup.openspg.reasoner.lube.logical.PropertyVar;
import com.antgroup.openspg.reasoner.lube.logical.Var;
import com.antgroup.openspg.reasoner.lube.utils.ExprUtils;
@@ -139,6 +142,9 @@ protected UdafMeta parseUdafMeta() {
*/
public abstract Expr getAggEle();
+ /** get parsed agg ele */
+ public abstract ParsedAggEle getParsedAggEle();
+
public Set parseExprUseAliasSet() {
scala.collection.immutable.List aliasList = ExprUtils.getRefVariableByExpr(getAggEle());
return new HashSet<>(JavaConversions.seqAsJavaList(aliasList));
@@ -148,6 +154,26 @@ public List parseExprRuleList() {
return WareHouseUtils.getRuleList(getAggEle());
}
+ protected ParsedAggEle parsedAggEle() {
+ String sourceAlias = null;
+ String sourcePropertyName = null;
+ List exprStrList = null;
+ Expr aggEle = getAggEle();
+ if (aggEle instanceof Ref) {
+ Ref sourceRef = (Ref) aggEle;
+ sourceAlias = sourceRef.refName();
+ } else if (aggEle instanceof UnaryOpExpr) {
+ UnaryOpExpr expr = (UnaryOpExpr) aggEle;
+ GetField getField = (GetField) expr.name();
+ sourceAlias = ((Ref) expr.arg()).refName();
+ sourcePropertyName = getField.fieldName();
+ } else if (1 == this.exprUseAliasSet.size()) {
+ sourceAlias = this.exprUseAliasSet.iterator().next();
+ exprStrList = this.exprRuleString;
+ }
+ return new ParsedAggEle(sourceAlias, sourcePropertyName, exprStrList);
+ }
+
/**
* getter
*
diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java
new file mode 100644
index 000000000..08026e6e1
--- /dev/null
+++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2023 OpenSPG Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+package com.antgroup.openspg.reasoner.rdg.common.groupProcess;
+
+import java.util.List;
+
+public class ParsedAggEle {
+ private final String sourceAlias;
+ private final String sourcePropertyName;
+ private final List exprStrList;
+
+ public ParsedAggEle(String sourceAlias, String sourcePropertyName, List exprStrList) {
+ this.sourceAlias = sourceAlias;
+ this.sourcePropertyName = sourcePropertyName;
+ this.exprStrList = exprStrList;
+ }
+
+ public String getSourceAlias() {
+ return sourceAlias;
+ }
+
+ public String getSourcePropertyName() {
+ return sourcePropertyName;
+ }
+
+ public List getExprStrList() {
+ return exprStrList;
+ }
+}