diff --git a/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala b/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala index 28a170c6c..ad2762c59 100644 --- a/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala +++ b/reasoner/common/src/main/scala/com/antgroup/openspg/reasoner/common/types/KgType.scala @@ -13,6 +13,8 @@ package com.antgroup.openspg.reasoner.common.types +import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException + trait KgType { def isNullable: Boolean = false } @@ -64,3 +66,16 @@ final case class KTAdvanced(label: String) extends KgType * @param elementType */ final case class KTMultiVersion(elementType: KgType) extends KgType + +object KgType { + + def getNumberSeq(kgType: KgType): Int = { + kgType match { + case KTInteger => 1 + case KTLong => 2 + case KTDouble => 3 + case _ => throw UnsupportedOperationException(s"cannot support number type $kgType") + } + } + +} diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala index c87a20017..5f3705dd6 100644 --- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala +++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParser.scala @@ -158,7 +158,6 @@ class OpenSPGDslParser extends ParserInterface { IRProperty(s.alias, propertyName) -> ProjectRule( IRProperty(s.alias, propertyName), - propertyType, Ref(ddlBlockWithNodes._3.target.alias))))) DDLBlock(Set.apply(ddlBlockOp), List.apply(prjBlk)) case AddPredicate(predicate) => @@ -399,7 +398,7 @@ class OpenSPGDslParser extends ParserInterface { ProjectBlock( List.apply(opBlock), ProjectFields(Map.apply(lValueName -> - ProjectRule(lValueName, exprParser.parseRetType(opChain.curExpr), opChain)))) + ProjectRule(lValueName, opChain)))) } case AggIfOpExpr(_, _) | AggOpExpr(_, _) => ProjectBlock( @@ -409,7 +408,6 @@ class OpenSPGDslParser extends ParserInterface { lValueName -> ProjectRule( lValueName, - exprParser.parseRetType(opChain.curExpr), opChain.curExpr)))) case _ => null } @@ -461,8 +459,8 @@ class OpenSPGDslParser extends ParserInterface { List.empty) case _ => rule match { - case ProjectRule(_, lvalueType, _) => - val projectRule = ProjectRule(lvalueFiled, lvalueType, expr) + case ProjectRule(_, _) => + val projectRule = ProjectRule(lvalueFiled, expr) ProjectBlock( List.apply(preBlock), ProjectFields(Map.apply(lvalueFiled -> projectRule))) @@ -727,7 +725,7 @@ class OpenSPGDslParser extends ParserInterface { exprParser.parseUnbrokenCharacterStringLiteral(ctx.unbroken_character_string_literal())) val defaultName = "const_output_" + patternParser.getDefaultAliasNum val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultName) - (ProjectRule(IRVariable(defaultName), KTString, expr), columnName, true) + (ProjectRule(IRVariable(defaultName), expr), columnName, true) } def parseGraphStructure( @@ -744,7 +742,7 @@ class OpenSPGDslParser extends ParserInterface { val defaultColumnName = parseExpr2ElementStr(expr) val columnName = parseAsAliasWithComment(ctx.as_alias_with_comment(), defaultColumnName) ( - ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr), + ProjectRule(IRVariable(defaultColumnName), expr), columnName, false) } @@ -861,7 +859,7 @@ class OpenSPGDslParser extends ParserInterface { val defaultColumnName = parseExpr2ElementStr(expr) val columnName = parseReturnAlias(ctx.return_item_alias(), defaultColumnName) ( - ProjectRule(IRVariable(defaultColumnName), exprParser.parseRetType(expr), expr), + ProjectRule(IRVariable(defaultColumnName), expr), columnName, false) } diff --git a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala index 94a2ee600..6b8bb5027 100644 --- a/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala +++ b/reasoner/kgdsl-parser/src/main/scala/com/antgroup/openspg/reasoner/parser/expr/RuleExprParser.scala @@ -849,10 +849,6 @@ class RuleExprParser extends Serializable { } } - def parseRetType(expr: Expr): KgType = { - KTObject - } - def parseRuleExpression(ctx: Rule_expressionContext): Rule = { ctx.getChild(0) match { case c: Logic_rule_expressionContext => parseLogicRuleExpression(c) @@ -878,10 +874,9 @@ class RuleExprParser extends Serializable { if (ctx.property_name() != null) { ProjectRule( IRProperty(ctx.identifier().getText, ctx.property_name().getText), - parseRetType(expr), expr) } else { - ProjectRule(IRVariable(ctx.identifier().getText), parseRetType(expr), expr) + ProjectRule(IRVariable(ctx.identifier().getText), expr) } } diff --git a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala index 3a84b0c47..721f99ff9 100644 --- a/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala +++ b/reasoner/kgdsl-parser/src/test/scala/com/antgroup/openspg/reasoner/parser/OpenSPGDslParserTest.scala @@ -15,7 +15,6 @@ package com.antgroup.openspg.reasoner.parser import com.antgroup.openspg.reasoner.common.constants.Constants import com.antgroup.openspg.reasoner.common.exception.{KGDSLGrammarException, KGDSLInvalidTokenException, KGDSLOneTaskException} -import com.antgroup.openspg.reasoner.common.types.{KTInteger, KTString} import com.antgroup.openspg.reasoner.lube.block._ import com.antgroup.openspg.reasoner.lube.common.expr._ import com.antgroup.openspg.reasoner.lube.common.graph._ @@ -296,8 +295,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { print(block.pretty) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) val proj = block.dependencies.head.asInstanceOf[ProjectBlock] - proj.projects.items.head._2 should equal( - ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o"))) + proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o"))) } it("addproperies with constraint") { @@ -314,8 +312,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { print(block.pretty) block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) val proj = block.dependencies.head.asInstanceOf[ProjectBlock] - proj.projects.items.head._2 should equal( - ProjectRule(IRProperty("s", "totalText"), KTString, Ref("o"))) + proj.projects.items.head._2 should equal(ProjectRule(IRProperty("s", "totalText"), Ref("o"))) } it("addproperies2") { @@ -334,7 +331,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) val proj = block.dependencies.head.asInstanceOf[ProjectBlock] proj.projects.items.head._2 should equal( - ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o"))) + ProjectRule(IRProperty("s", "total_domain_num"), Ref("o"))) } it("addproperies") { val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) { @@ -352,7 +349,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { block.dependencies.head.isInstanceOf[ProjectBlock] should equal(true) val proj = block.dependencies.head.asInstanceOf[ProjectBlock] proj.projects.items.head._2 should equal( - ProjectRule(IRProperty("s", "total_domain_num"), KTInteger, Ref("o"))) + ProjectRule(IRProperty("s", "total_domain_num"), Ref("o"))) } it("addNode") { val dsl = """Define (s:DomainFamily)-[p:total_domain_num]->(o:Int) { @@ -661,7 +658,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { .asInstanceOf[AddPredicate] .predicate .fields - .keySet should contain ("same_domain_num") + .keySet should contain("same_domain_num") blocks(1).asInstanceOf[DDLBlock].ddlOp.head.isInstanceOf[AddProperty] should equal(true) blocks(1) @@ -1048,7 +1045,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { val block = parser.parse(dsl) print(block.pretty) val text = """└─DDLBlock(ddlOp=Set(AddProperty((s:CustFundKG.Account),aggTransAmountNumByDay,KTBoolean))) - | └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),KTBoolean,Ref(refName=o))))) + | └─ProjectBlock(projects=ProjectFields(Map(IRProperty(s,aggTransAmountNumByDay) -> ProjectRule(IRProperty(s,aggTransAmountNumByDay),Ref(refName=o))))) | └─AggregationBlock(aggregations=Aggregations(Map(IRVariable(o) -> AggOpExpr(name=AggUdf(groupByAttrDoCount,List(VString(value=tranDate), VLong(value=50)))))), group=List(IRNode(s,Set()))) | └─FilterBlock(rules=LogicRule(R1,当月交易,BinaryOpExpr(name=BNotSmallerThan))) | └─MatchBlock(patterns=Map(unresolved_default_path -> GraphPath(unresolved_default_path,GraphPattern(s,Map(u -> (u:CustFundKG.Account), s -> (s:CustFundKG.Account)),Map(u -> Set((u)<-[t:accountFundContact]-(s)))),Map(u -> Set(), s -> Set(), t -> Set(transDate))),false))) @@ -1081,7 +1078,7 @@ class OpenSPGDslParserTest extends AnyFunSpec { val block = parser.parse(dsl) print(block.pretty) val text = """└─TableResultBlock(selectList=OrderedFields(List(IRProperty(s,id), IRVariable(o))), asList=List(s.id, b)) - * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),KTObject,FunctionExpr(name=rule_value))))) + * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(o) -> ProjectRule(IRVariable(o),FunctionExpr(name=rule_value))))) * └─FilterBlock(rules=LogicRule(R6,长得高,BinaryOpExpr(name=BOr))) * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R5) -> LogicRule(R5,颜值高,BinaryOpExpr(name=BGreaterThan))))) * └─ProjectBlock(projects=ProjectFields(Map(IRVariable(R4) -> LogicRule(R4,女性,BinaryOpExpr(name=BEqual))))) diff --git a/reasoner/lube-api/pom.xml b/reasoner/lube-api/pom.xml index 87d0aa2e8..5377885b0 100644 --- a/reasoner/lube-api/pom.xml +++ b/reasoner/lube-api/pom.xml @@ -33,6 +33,10 @@ com.antgroup.openspg.reasoner reasoner-common + + com.antgroup.openspg.reasoner + reasoner-udf + org.scala-lang scala-library diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala index 3d55ecaa0..171d46a48 100644 --- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala +++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/catalog/Catalog.scala @@ -13,10 +13,12 @@ package com.antgroup.openspg.reasoner.lube.catalog +import scala.collection.mutable + import com.antgroup.openspg.reasoner.common.exception.{ConnectionNotFoundException, GraphAlreadyExistsException, GraphNotFoundException} import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.common.graph.IRGraph -import scala.collection.mutable +import com.antgroup.openspg.reasoner.udf.{UdfMng, UdfMngFactory} /** @@ -27,6 +29,7 @@ import scala.collection.mutable */ abstract class Catalog() extends Serializable { protected val graphRepository = new mutable.HashMap[String, SemanticPropertyGraph]() + @transient private val udfRepo = UdfMngFactory.getUdfMng private val connections = new mutable.HashMap[String, mutable.HashSet[AbstractConnection]]() /** @@ -96,6 +99,8 @@ abstract class Catalog() extends Serializable { graphRepository.get(graphName).orNull } + def getUdfRepo: UdfMng = udfRepo + /** * Get schema from knowledge graph */ diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala index eb48b7f50..99e475c2a 100644 --- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala +++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/Rule.scala @@ -13,7 +13,6 @@ package com.antgroup.openspg.reasoner.lube.common.rule -import com.antgroup.openspg.reasoner.common.types.KgType import com.antgroup.openspg.reasoner.lube.common.expr.Expr import com.antgroup.openspg.reasoner.lube.common.graph.IRField @@ -39,13 +38,6 @@ trait Rule extends Cloneable{ */ def getExpr: Expr - - /** - * get lvalue type - * @return - */ - def getLvalueType: KgType - /** * get dependencies * @return diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala index 322416a0a..07a0ccf06 100644 --- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala +++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/common/rule/RuleImpl.scala @@ -71,13 +71,6 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr) */ override def getExpr: Expr = expr - /** - * get lvalue type - * - * @return - */ - override def getLvalueType: KgType = KTBoolean - override def andRule(rule: Rule): Rule = { val andExpr = BinaryOpExpr(BAnd, getExpr, rule.getExpr) @@ -129,7 +122,7 @@ final case class LogicRule(ruleName: String, ruleExplain: String, expr: Expr) * @param lvalueType * @param expr */ -final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr) +final case class ProjectRule(output: IRField, expr: Expr) extends DependencyRule { /** @@ -158,12 +151,6 @@ final case class ProjectRule(output: IRField, lvalueType: KgType, expr: Expr) */ override def getExpr: Expr = expr - /** - * get lvalue type - * - * @return - */ - override def getLvalueType: KgType = lvalueType override def andRule(rule: Rule): Rule = { throw UnsupportedOperationException("ProjectRule cannot support andRule") diff --git a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala index 84b9b0dbe..ecb6f726e 100644 --- a/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala +++ b/reasoner/lube-api/src/main/scala/com/antgroup/openspg/reasoner/lube/utils/RuleUtils.scala @@ -138,7 +138,7 @@ object RuleUtils { case logicRule: LogicRule => LogicRule(ruleNameStr, logicRule.ruleExplain, expr) case _ => - ProjectRule(IRVariable(ruleNameStr), rule.getLvalueType, expr) + ProjectRule(IRVariable(ruleNameStr), expr) } val oldDependencies = rule.getDependencies if (oldDependencies != null) { @@ -162,7 +162,7 @@ object RuleUtils { case logicRule: LogicRule => LogicRule(rule.getName, logicRule.ruleExplain, expr) case _ => - ProjectRule(rule.getOutput, rule.getLvalueType, expr) + ProjectRule(rule.getOutput, expr) } val oldDependencies = rule.getDependencies if (oldDependencies != null) { diff --git a/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala b/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala index 71308867a..408763376 100644 --- a/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala +++ b/reasoner/lube-api/src/test/scala/com/antgroup/openspg/reasoner/parser/TransformerTest.scala @@ -50,7 +50,7 @@ class TransformerTest extends AnyFunSpec { false)))), ProjectFields( Map.apply(IRVariable("total_domain_num") -> - ProjectRule(IRVariable("total_domain_num"), KTInteger, Ref("o"))))) + ProjectRule(IRVariable("total_domain_num"), Ref("o"))))) val p = BlockUtils.transBlock2Graph(block) p.size should equal(1) p.head.graphPattern.nodes.size should equal(2) @@ -120,7 +120,6 @@ class TransformerTest extends AnyFunSpec { it("rename_rule") { val rule = ProjectRule( IRVariable("a"), - KTObject, BinaryOpExpr( BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), @@ -153,12 +152,10 @@ class TransformerTest extends AnyFunSpec { it("variable_rule") { val rule = ProjectRule( IRVariable("a"), - KTObject, BinaryOpExpr(BEqual, UnaryOpExpr(GetField("birthDate"), Ref("e")), Ref("b"))) val rule2 = ProjectRule( IRVariable("b"), - KTObject, BinaryOpExpr( BEqual, UnaryOpExpr(GetField("attr1"), Ref("e")), @@ -174,7 +171,6 @@ class TransformerTest extends AnyFunSpec { def getDependenceRule(): Rule = { val r0 = ProjectRule( IRVariable("r0"), - KTLong, BinaryOpExpr(BAssign, Ref("r0"), VLong("123")) ) val r1 = LogicRule( @@ -230,7 +226,6 @@ class TransformerTest extends AnyFunSpec { val r0 = LogicRule("tmp", "", BinaryOpExpr(BGreaterThan, UnaryOpExpr(GetField("amount"), Ref("E1")), VLong("10"))) val r = ProjectRule(IRVariable("g"), - KTLong, OpChainExpr( GraphAggregatorExpr( "unresolved_default_path", diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala index 041bf3d99..1ddc1da3a 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtil.scala @@ -13,9 +13,16 @@ package com.antgroup.openspg.reasoner.lube.logical +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException import com.antgroup.openspg.reasoner.common.trees.BottomUp +import com.antgroup.openspg.reasoner.common.types._ import com.antgroup.openspg.reasoner.lube.common.expr._ +import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable} import com.antgroup.openspg.reasoner.lube.common.rule.Rule +import com.antgroup.openspg.reasoner.udf.UdfMng object ExprUtil { @@ -46,15 +53,13 @@ object ExprUtil { } - def needResolved(rule: Expr): Boolean = { !getReferProperties(rule).filter(_._1 == null).isEmpty } def transExpr(rule: Expr, replaceVar: Map[String, PropertyVar]): Expr = { - def rewriter: PartialFunction[Expr, Expr] = { - case Ref(refName) => + def rewriter: PartialFunction[Expr, Expr] = { case Ref(refName) => if (replaceVar.contains(refName)) { val propertyVar = replaceVar(refName) UnaryOpExpr(GetField(propertyVar.field.name), Ref(propertyVar.name)) @@ -76,4 +81,104 @@ object ExprUtil { newRule } + def getTargetType(expr: Expr, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = { + expr match { + case Ref(name) => + if (referVars.contains(IRVariable(name))) { + referVars(IRVariable(name)) + } else { + KTObject + } + case UnaryOpExpr(GetField(name), Ref(alis)) => referVars(IRProperty(alis, name)) + case BinaryOpExpr(name, l, r) => + name match { + case BAnd | BEqual | BNotEqual | BGreaterThan | BNotGreaterThan | BSmallerThan | + BNotSmallerThan | BOr | BIn | BLike | BRLike | BAssign => + KTBoolean + case BAdd | BSub | BMul | BDiv | BMod => + val left = getTargetType(l, referVars, udfRepo) + val right = getTargetType(r, referVars, udfRepo) + getUpperType(left, right) + case _ => throw UnsupportedOperationException(s"express cannot support ${name}") + } + case UnaryOpExpr(name, arg) => + name match { + case Not | Exists => KTBoolean + case Abs | Neg => getTargetType(arg, referVars, udfRepo) + case Floor | Ceil => KTDouble + case _ => throw UnsupportedOperationException(s"express cannot support ${name}") + } + case FunctionExpr(name, funcArgs) => + val types = funcArgs.map(getTargetType(_, referVars, udfRepo)) + name match { + case "rule_value" => types(1) + case "cast_type" | "Cast" => + funcArgs(1).asInstanceOf[VString].value match { + case "int" | "bigint" | "long" => KTLong + case "float" | "double" => KTDouble + case "varchar" | "string" => KTString + case _ => + throw UnsupportedOperationException(s"cannot support ${name} to ${funcArgs(1)}") + } + case _ => + val udf = udfRepo.getUdfMeta(name, types.asJava) + if (udf != null) { + udf.getResultType + } else { + throw UnsupportedOperationException(s"cannot find UDF: ${name}") + } + } + + case AggOpExpr(name, args) => + name match { + case Min | Max | Sum | Avg | First | Accumulate(_) => + getTargetType(args, referVars, udfRepo) + case StrJoin(_) => KTString + case Count => KTLong + case AggUdf(name, _) => + val types = getTargetType(args.head, referVars, udfRepo) + val udf = udfRepo.getUdafMeta(name, types) + if (udf != null) { + udf.getResultType + } else { + throw UnsupportedOperationException(s"cannot find UDAF ${name}") + } + case _ => throw UnsupportedOperationException(s"express cannot support ${name}") + } + case OpChainExpr(curExpr, _) => getTargetType(curExpr, referVars, udfRepo) + case ListOpExpr(name, _) => + name match { + case Reduce(_, _, _, initValue) => getTargetType(initValue, referVars, udfRepo) + case Constraint(_, _, _) => KTBoolean + case Get(_) | Slice(_, _) => KTObject + } + case AggIfOpExpr(op, _) => getTargetType(op, referVars, udfRepo) + case VNull | VString(_) => KTString + case VLong(_) => KTLong + case VDouble(_) => KTDouble + case VBoolean(_) => KTBoolean + case VList(_, listType) => KTList(listType) + case _ => throw UnsupportedOperationException(s"express cannot support ${expr.pretty}") + } + } + + def getTargetType(rule: Rule, referVars: Map[IRField, KgType], udfRepo: UdfMng): KgType = { + val newReferVars = new mutable.HashMap[IRField, KgType] + newReferVars.++=(referVars) + for (r <- rule.getDependencies) { + newReferVars.put(r.getOutput, getTargetType(r, referVars, udfRepo)) + } + getTargetType(rule.getExpr, newReferVars.toMap, udfRepo) + } + + private def getUpperType(left: KgType, right: KgType): KgType = { + val l = KgType.getNumberSeq(left) + val r = KgType.getNumberSeq(right) + if (l >= r) { + left + } else { + right + } + } + } diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala index 9c1358e1e..ae7df8f90 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/AggregationPlanner.scala @@ -16,7 +16,7 @@ package com.antgroup.openspg.reasoner.lube.logical.planning import scala.collection.mutable import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException -import com.antgroup.openspg.reasoner.common.types.KTObject +import com.antgroup.openspg.reasoner.common.types.{KgType, KTObject} import com.antgroup.openspg.reasoner.lube.block.Aggregations import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator @@ -26,7 +26,8 @@ import com.antgroup.openspg.reasoner.lube.logical.operators.{Aggregate, LogicalL import com.antgroup.openspg.reasoner.lube.utils.ExprUtils import org.apache.commons.lang3.StringUtils -class AggregationPlanner(group: List[IRField], aggregations: Aggregations) { +class AggregationPlanner(group: List[IRField], aggregations: Aggregations)(implicit + context: LogicalPlannerContext) { def plan(dependency: LogicalOperator): LogicalOperator = { val groupVar: List[Var] = group.map(toVar(_, dependency.solved)) @@ -47,6 +48,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) { v } }) + + val referTypes = new mutable.HashMap[IRField, KgType]() + for (v <- ruleFields) { + resolved.getVar(v.name) match { + case p: PropertyVar => referTypes.put(v, p.field.kgType) + case node: NodeVar => + node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType)) + case edge: EdgeVar => + edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType)) + case _ => throw UnsupportedOperationException(s"cannot support $v") + } + } + referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType))) + val ruleRetType = ExprUtil.getTargetType(p._2, referTypes.toMap, context.catalog.getUdfRepo) + val renameVar = ruleFields .filter(_.isInstanceOf[IRVariable]) .map(v => (v, propertyVarToIr(resolved.tmpFields(v.asInstanceOf[IRVariable])))) @@ -56,21 +72,21 @@ class AggregationPlanner(group: List[IRField], aggregations: Aggregations) { val field = getAggregateTarget(referFields, resolved, dependency) field match { case IRNode(alias, _) => - val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true)) + val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true)) aggMap.put(propertyVar, newAggExpr) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) case IREdge(alias, _) => if (resolved.getVar(alias).isInstanceOf[RepeatPathVar]) { aggMap.put(resolved.getVar(alias).asInstanceOf[RepeatPathVar].pathVar, newAggExpr) } else { - val propertyVar = PropertyVar(alias, new Field(p._1.name, KTObject, true)) + val propertyVar = PropertyVar(alias, new Field(p._1.name, ruleRetType, true)) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) aggMap.put(propertyVar, newAggExpr) } case IRVariable(alias) => val tmpPropertyVar = resolved.tmpFields(IRVariable(alias)) val propertyVar = - PropertyVar(tmpPropertyVar.name, new Field(p._1.name, KTObject, true)) + PropertyVar(tmpPropertyVar.name, new Field(p._1.name, ruleRetType, true)) aggMap.put(propertyVar, newAggExpr) resolved = resolved.addField((p._1.asInstanceOf[IRVariable], propertyVar)) case _ => diff --git a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala index 259789c70..6dd6042a0 100644 --- a/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala +++ b/reasoner/lube-logical/src/main/scala/com/antgroup/openspg/reasoner/lube/logical/planning/ProjectPlanner.scala @@ -16,18 +16,19 @@ package com.antgroup.openspg.reasoner.lube.logical.planning import scala.collection.mutable import com.antgroup.openspg.reasoner.common.exception.UnsupportedOperationException -import com.antgroup.openspg.reasoner.common.types.KTObject +import com.antgroup.openspg.reasoner.common.types.KgType import com.antgroup.openspg.reasoner.lube.block.ProjectFields import com.antgroup.openspg.reasoner.lube.catalog.struct.Field import com.antgroup.openspg.reasoner.lube.common.expr.{Directly, Expr} import com.antgroup.openspg.reasoner.lube.common.graph._ -import com.antgroup.openspg.reasoner.lube.logical.{ExprUtil, PropertyVar, SolvedModel, Var} +import com.antgroup.openspg.reasoner.lube.common.rule.Rule +import com.antgroup.openspg.reasoner.lube.logical._ import com.antgroup.openspg.reasoner.lube.logical.operators.{LogicalOperator, Project, StackingLogicalOperator} import com.antgroup.openspg.reasoner.lube.utils.RuleUtils import com.antgroup.openspg.reasoner.lube.utils.transformer.impl.Rule2ExprTransformer import org.apache.commons.lang3.StringUtils -class ProjectPlanner(projects: ProjectFields) { +class ProjectPlanner(projects: ProjectFields)(implicit context: LogicalPlannerContext) { def plan(dependency: LogicalOperator): LogicalOperator = { val projectMap = new mutable.HashMap[Var, Expr]() @@ -49,7 +50,7 @@ class ProjectPlanner(projects: ProjectFields) { v } }) - val propertyVar = getTarget(rule._1, referVars, resolved, dependency) + val propertyVar = getTarget(rule._1, referVars, rule._2, resolved, dependency) val transformer = new Rule2ExprTransformer() val reference = ruleReferVars.filter(_.isInstanceOf[IRVariable]) val replaceVar = reference @@ -68,12 +69,27 @@ class ProjectPlanner(projects: ProjectFields) { private def getTarget( left: IRField, referVars: List[IRField], + rule: Rule, resolved: SolvedModel, dependency: LogicalOperator): PropertyVar = { + val referTypes = new mutable.HashMap[IRField, KgType]() + for (v <- referVars) { + resolved.getVar(v.name) match { + case p: PropertyVar => referTypes.put(v, p.field.kgType) + case node: NodeVar => + node.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType)) + case edge: EdgeVar => + edge.fields.foreach(f => referTypes.put(IRProperty(v.name, f.name), f.kgType)) + case _ => + } + } + referTypes.++=(resolved.tmpFields.map(p => (p._1, p._2.field.kgType))) + val ruleRetType = ExprUtil.getTargetType(rule, referTypes.toMap, context.catalog.getUdfRepo) + left match { case IRVariable(name) => if (referVars.size == 1) { - PropertyVar(referVars.head.name, new Field(name, KTObject, true)) + PropertyVar(referVars.head.name, new Field(name, ruleRetType, true)) } else { val aliasSet = new mutable.HashSet[String]() for (rVar <- referVars) { @@ -84,11 +100,11 @@ class ProjectPlanner(projects: ProjectFields) { } } val targetAlias = getTargetAlias(aliasSet.toSet, dependency) - PropertyVar(targetAlias, new Field(left.name, KTObject, true)) + PropertyVar(targetAlias, new Field(left.name, ruleRetType, true)) } - case IRProperty(name, field) => PropertyVar(name, new Field(field, KTObject, true)) - case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true)) - case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, KTObject, true)) + case IRProperty(name, field) => PropertyVar(name, new Field(field, ruleRetType, true)) + case IRNode(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true)) + case IREdge(name, fields) => PropertyVar(name, new Field(fields.head, ruleRetType, true)) case _ => throw UnsupportedOperationException(s"cannot support $left") } diff --git a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala index 432097636..83ebd1b06 100644 --- a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala +++ b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/ExprUtilTests.scala @@ -13,11 +13,13 @@ package com.antgroup.openspg.reasoner.lube.logical -import com.antgroup.openspg.reasoner.common.types.KTInteger +import com.antgroup.openspg.reasoner.common.types.{KgType, KTBoolean, KTDouble, KTInteger, KTLong, KTString} import com.antgroup.openspg.reasoner.lube.catalog.struct.Field -import com.antgroup.openspg.reasoner.lube.common.expr.{GetField, Ref, UnaryOpExpr} -import com.antgroup.openspg.reasoner.lube.common.graph.IRVariable +import com.antgroup.openspg.reasoner.lube.common.expr.{Expr, Ref, UnaryOpExpr} +import com.antgroup.openspg.reasoner.lube.common.graph.{IRField, IRProperty, IRVariable} import com.antgroup.openspg.reasoner.lube.common.rule.{ProjectRule, Rule} +import com.antgroup.openspg.reasoner.parser.expr.RuleExprParser +import com.antgroup.openspg.reasoner.udf.UdfMngFactory import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal} @@ -26,16 +28,52 @@ class ExprUtilTests extends AnyFunSpec { val replaceMap = Map.apply( "a" -> PropertyVar("b", new Field("id", KTInteger, true)), "c" -> PropertyVar("d", new Field("id", KTInteger, true))) - val r1 = ProjectRule(IRVariable("test"), - KTInteger, Ref("a")) - val r2 = ProjectRule(IRVariable("test"), - KTInteger, Ref("c")) + val r1 = ProjectRule(IRVariable("test"), Ref("a")) + val r2 = ProjectRule(IRVariable("test"), Ref("c")) r1.addDependency(r2) val newRule: Rule = ExprUtil.transExpr(r1, replaceMap) print(newRule.getExpr.pretty) newRule.getExpr.isInstanceOf[UnaryOpExpr] should equal(true) newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.isInstanceOf[Ref] should equal(true) - newRule.getExpr.asInstanceOf[UnaryOpExpr] - .arg.asInstanceOf[Ref].refName should equal("b") + newRule.getExpr.asInstanceOf[UnaryOpExpr].arg.asInstanceOf[Ref].refName should equal("b") + } + + it("test expr output type") { + val parser = new RuleExprParser() + val udfRepo = UdfMngFactory.getUdfMng + val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]] + + var expr: Expr = parser.parse("A.age") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger) + + expr = parser.parse("A.age + 1") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTLong) + + expr = parser.parse("A.age > 10") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTBoolean) + + expr = parser.parse("floor(A.age)") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTDouble) + + expr = parser.parse("abs(A.age)") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTInteger) + + expr = parser.parse("concat(A.age, \",\")") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString) + + expr = parser.parse("cast_type(A.age, 'string')") + ExprUtil.getTargetType(expr, map, udfRepo) should equal(KTString) + } + + it("test rule output type") { + val parser = new RuleExprParser() + val udfRepo = UdfMngFactory.getUdfMng + val map = Map.apply(IRProperty("A", "age") -> KTInteger).asInstanceOf[Map[IRField, KgType]] + + val rule = ProjectRule(IRVariable("newAge"), parser.parse("age * 10")) + val r1 = ProjectRule(IRVariable("age"), parser.parse("A.age")) + rule.addDependency(r1) + + ExprUtil.getTargetType(rule, map, udfRepo) should equal(KTLong) } } diff --git a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala index de18e507b..7c717637b 100644 --- a/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala +++ b/reasoner/lube-logical/src/test/scala/com/antgroup/openspg/reasoner/lube/logical/LogicalPlannerTests.scala @@ -163,66 +163,6 @@ class LogicalPlannerTests extends AnyFunSpec { println(optimizedLogicalPlan.pretty) } - it("online") { - val dsl = """Define (s:User)-[p:redPacket]->(o:Int) { - | GraphStructure { - | (s) - | } - | Rule { - |LatestHighFrequencyMonthPayCount=s.ngfe_tag__pay_cnt_m - |Latest30DayPayCount=s.ngfe_tag__pay_cnt_d - |Latest7DayPayCount=s.ngfe_tag__pay_cnt_d - |LatestTTT = Latest7DayPayCount.accumulate(+) - |LatestHighFrequencyMonthAveragePayCount=get_first_notnull(maximum(LatestHighFrequencyMonthPayCount), 0.0) / 30.0 - |Latest7DayPayCountSum=Latest7DayPayCount - |Latest7DayPayCountAverage=Latest7DayPayCountSum / 7.0 - |HighReduceValue=(LatestHighFrequencyMonthAveragePayCount - Latest7DayPayCountAverage)/LatestHighFrequencyMonthAveragePayCount - |HighLost("高频降频100%"):HighReduceValue == 1 - |HighReduce80("高频降频80%"):HighReduceValue >= 0.8 and HighReduceValue < 1 - |HighReduce50("高频降频50%"):HighReduceValue >= 0.5 and HighReduceValue < 0.8 - |HighReduce30("高频降频30%"):HighReduceValue >= 0.3 and HighReduceValue < 0.5 - |HighReduce10("高频降频10%"):HighReduceValue >= 0.1 and HighReduceValue < 0.3 - |Latest3060DayPayCount=s.ngfe_tag__pay_cnt_d - |Latest30DayPayDayCount=size(Latest30DayPayCount) - |Latest3060DayPayDayCount=size(Latest3060DayPayCount) - |High1("高频用户1"):Latest3060DayPayDayCount < 13 and Latest30DayPayDayCount >= 13 - |High2("高频用户2"):Latest3060DayPayDayCount > 12 and Latest30DayPayDayCount >= 13 - |Middle1("中频用户1"):Latest3060DayPayDayCount == 0 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12 - |Middle2("中频用户2"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12 - |Middle3("中频用户3"):Latest3060DayPayDayCount >= 4 and Latest30DayPayDayCount >= 4 and Latest30DayPayDayCount <= 12 - |Low1("低频用户1"):Latest3060DayPayDayCount >= 1 and Latest3060DayPayDayCount <= 3 and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3 - |Low2("低频用户2"):(Latest3060DayPayDayCount > 3 or Latest3060DayPayDayCount == 0) and Latest30DayPayDayCount >= 1 and Latest30DayPayDayCount <= 3 - |Latest6090DayPayCount=s.ngfe_tag__pay_cnt_d - |Latest6090DayPayDayCount=size(Latest6090DayPayCount) - |Latest60DayPayCount=s.ngfe_tag__pay_cnt_d - |Latest60DayPayDayCount=size(Latest60DayPayCount) - |Sleep1("沉睡用户1"):Latest6090DayPayDayCount > 0 and Latest60DayPayDayCount == 0 - |Sleep2("沉睡用户2"):Latest3060DayPayDayCount > 0 and Latest30DayPayDayCount == 0 - |HistoricallyPay=s.ngfe_tag__pay_cnt_total - |HistoricallyPayCount=size(HistoricallyPay) - |New("新用户"):HistoricallyPayCount == 0 and Latest30DayPayDayCount == 0 - |Latest90DayPayCount=s.ngfe_tag__pay_cnt_d - |Latest90DayPayDayCount=size(Latest90DayPayCount) - |Lost("流失用户"):HistoricallyPayCount > 0 and Latest90DayPayDayCount == 0 - |o=get_first_notnull(rule_value(HighLost, "high_lost"), rule_value(HighReduce80, "high_reduce_80"),rule_value(HighReduce50, "high_reduce_50"), rule_value(HighReduce30, "high_reduce_30"), rule_value(HighReduce10, "high_reduce_10"), rule_value(High1, "high_1"), rule_value(High2, "high_2"), rule_value(Middle1, "middle_1"), rule_value(Middle2, "middle_2"), rule_value(Middle3, "middle_3"), rule_value(Low1, "low_1"), rule_value(Low2, "low_2"), rule_value(Sleep1, "sleep_1"), rule_value(Sleep2, "sleep_2"), rule_value(New, "new"), rule_value(Lost, "lost")) - | } - |}""".stripMargin - val parser = new OpenSPGDslParser() - val block = parser.parse(dsl) - println(block.pretty) - val schema: Map[String, Set[String]] = Map.apply( - "User" -> Set - .apply("ngfe_tag__pay_cnt_m", "ngfe_tag__pay_cnt_total", "ngfe_tag__pay_cnt_d")) - val catalog = new PropertyGraphCatalog(schema) - catalog.init() - implicit val context: LogicalPlannerContext = - LogicalPlannerContext(catalog, parser, Map.empty) - val logicalPlan = LogicalPlanner.plan(block) - println(logicalPlan.head.pretty) - val optimizedLogicalPlan = LogicalOptimizer.optimize(logicalPlan.head) - println(optimizedLogicalPlan.pretty) - } - it("test start flag") { val dsl = """ @@ -394,7 +334,7 @@ class LogicalPlannerTests extends AnyFunSpec { | (s)-[p2:followPM]->(o) |} |Rule { - | c = rule_value(p.avgProfit > 0, 1,0 ) + rule_value(p2.times>3, 1,0) + | c = rule_value(p.avgProfit > 0, 1,0 ) && rule_value(p2.times>3, 1,0) | |} |Action { diff --git a/reasoner/pom.xml b/reasoner/pom.xml index bcf6620c2..0f70033a3 100644 --- a/reasoner/pom.xml +++ b/reasoner/pom.xml @@ -41,7 +41,7 @@ 4.8 1.2.71_noneautotype - 27.0 + 28.0 2.0.0 2.7.2 3.1.0 diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java index 6fcbd4822..8c5a03aa9 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerABMLocalTest.java @@ -137,7 +137,7 @@ public void test11() { + " \t(s)-[:p3]->(t:Attribute1.Name142)\n" + " }\n" + " Rule {\n" - + " \tv = t.stock/t.total\n" + + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n" + " R1(\"必须大于20%\"): v > 0.2\n" + " o = v\n" + " }\n" @@ -197,7 +197,7 @@ public void test9() { + " \t(s)-[:p3]->(t:Attribute1.Name142)\n" + " }\n" + " Rule {\n" - + " \tv = t.stock/t.total\n" + + " \tv = cast_type(t.stock, 'double')/cast_type(t.total,'double')\n" + " R1(\"必须大于20%\"): v > 0.2\n" + " o = v\n" + " }\n" diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java index d0681d48e..443e66241 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerAliasSetKFilmTest.java @@ -50,9 +50,12 @@ private void doTest1() { + "R1: A.id == $idSet1\n" + "R2: B.id in $idSet2\n" + "R3: C.id in $idSet2\n" - + "totalTrans1 = group(A,B,C).sum(p1.amount)\n" - + "totalTrans2 = group(A,B,C).sum(p2.amount)\n" - + "totalTrans3 = group(A,B,C).sum(p3.amount)\n" + + "p1_amt = cast_type(p1.amount,'long')\n" + + "p2_amt = cast_type(p2.amount,'long')\n" + + "p3_amt = cast_type(p3.amount,'long')\n" + + "totalTrans1 = group(A,B,C).sum(p1_amt)\n" + + "totalTrans2 = group(A,B,C).sum(p2_amt)\n" + + "totalTrans3 = group(A,B,C).sum(p3_amt)\n" + "totalTrans = totalTrans1 + totalTrans2 + totalTrans3\n" + "R2('取top2'): top(totalTrans, 2)" + "}\n" @@ -90,7 +93,7 @@ private void doTest2() { + "R1: A.id in $idSet1\n" + "R2: B.id in $idSet2\n" + "R3: C.id in $idSet2\n" - + "totalTrans = p1.amount + p2.amount + p3.amount\n" + + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n" + "R2('取top2'): top(totalTrans, 3)" + "}\n" + "Action {\n" @@ -127,7 +130,7 @@ private void doTest3() { + "R1: A.id == $idSet1\n" + "R2: B.id == $idSet2\n" + "R3: C.id == $idSet3\n" - + "totalTrans = p1.amount + p2.amount + p3.amount\n" + + "totalTrans = cast_type(p1.amount,'long') + cast_type(p2.amount,'long') + cast_type(p3.amount,'long')\n" + "R2('取top2'): top(totalTrans, 3)" + "}\n" + "Action {\n" @@ -169,9 +172,12 @@ private void doTest4() { + "R1: A.id in $idSet1\n" + "R2: B.id in $idSet2\n" + "R3: C.id in $idSet2\n" - + "t1 = group(A,B,C).sum(p1.amount)\n" - + "t2 = group(A,B,C).sum(p2.amount)\n" - + "t3 = group(A,B,C).sum(p3.amount)\n" + + "p1_amt = cast_type(p1.amount,'long')\n" + + "p2_amt = cast_type(p2.amount,'long')\n" + + "p3_amt = cast_type(p3.amount,'long')\n" + + "t1 = group(A,B,C).sum(p1_amt)\n" + + "t2 = group(A,B,C).sum(p2_amt)\n" + + "t3 = group(A,B,C).sum(p3_amt)\n" + "totalSum = t1 + t2 + t3" + "}\n" + "Action {\n" diff --git a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java index 6e675da2f..1572df587 100644 --- a/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java +++ b/reasoner/runner/local-runner/src/test/java/com/antgroup/openspg/reasoner/runner/local/main/KgReasonerTopKFilmTest.java @@ -403,7 +403,7 @@ private void doTest12() { + " o->star [starOfFilm] as sf2\n" + "}\n" + "Rule {\n" - + "total = sf.joinTs + sf2.joinTs\n" + + "total = cast_type(sf.joinTs, 'bigint') + cast_type(sf2.joinTs, 'bigint')\n" + "R2: top(total, 1)\n" + "}\n" + "Action {\n" diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java index 689f2e199..ed2ec729c 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/KgGraphAggregateImpl.java @@ -25,13 +25,7 @@ import com.antgroup.openspg.reasoner.kggraph.KgGraph; import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphImpl; import com.antgroup.openspg.reasoner.kggraph.impl.KgGraphSplitStaticParameters; -import com.antgroup.openspg.reasoner.lube.common.expr.AggIfOpExpr; -import com.antgroup.openspg.reasoner.lube.common.expr.AggOpExpr; -import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator; -import com.antgroup.openspg.reasoner.lube.common.expr.Expr; -import com.antgroup.openspg.reasoner.lube.common.expr.GetField; -import com.antgroup.openspg.reasoner.lube.common.expr.Ref; -import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr; +import com.antgroup.openspg.reasoner.lube.common.expr.*; import com.antgroup.openspg.reasoner.lube.common.pattern.Pattern; import com.antgroup.openspg.reasoner.lube.logical.EdgeVar; import com.antgroup.openspg.reasoner.lube.logical.NodeVar; @@ -41,6 +35,7 @@ import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggIfOpProcessBaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.AggOpProcessBaseGroupProcess; import com.antgroup.openspg.reasoner.rdg.common.groupProcess.BaseGroupProcess; +import com.antgroup.openspg.reasoner.rdg.common.groupProcess.ParsedAggEle; import com.antgroup.openspg.reasoner.udf.model.BaseUdaf; import com.antgroup.openspg.reasoner.udf.model.UdafMeta; import com.antgroup.openspg.reasoner.udf.rule.RuleRunner; @@ -353,26 +348,14 @@ private Object doAggregation( if (null != udafInitParams) { udaf.initialize(udafInitParams); } - - String sourceAlias = null; - String sourcePropertyName = null; + ParsedAggEle parsedAggEle; Set aliasList = aggInfo.getExprUseAliasSet(); if (aliasList.size() <= 1) { - Expr sourceExpr = aggInfo.getAggEle(); - // aggregate by vertex subgraph - if (sourceExpr instanceof Ref) { - Ref sourceRef = (Ref) sourceExpr; - sourceAlias = sourceRef.refName(); - } else if (sourceExpr instanceof UnaryOpExpr) { - UnaryOpExpr expr = (UnaryOpExpr) sourceExpr; - GetField getField = (GetField) expr.name(); - sourceAlias = ((Ref) expr.arg()).refName(); - sourcePropertyName = getField.fieldName(); - } + parsedAggEle = aggInfo.getParsedAggEle(); if (!StringUtils.isEmpty(DEBUG_VERTEX_ALIAS)) { for (KgGraph valueFiltered : valueFilteredList) { if (valueFiltered.hasFocusVertexId(DEBUG_VERTEX_ALIAS, DEBUG_VERTEX_ID_SET)) { - StringBuffer sb = new StringBuffer(); + StringBuilder sb = new StringBuilder(); for (KgGraph valueFiltered2 : valueFilteredList) { sb.append(valueFiltered2).append("## "); } @@ -381,19 +364,41 @@ private Object doAggregation( } } } - String finalSourcePropertyName = sourcePropertyName; + String finalSourcePropertyName = parsedAggEle.getSourcePropertyName(); for (KgGraph valueFiltered : valueFilteredList) { - if (valueFiltered.getVertexAlias().contains(sourceAlias)) { - List> vertexList = valueFiltered.getVertex(sourceAlias); - if (sourcePropertyName == null) { + if (valueFiltered.getVertexAlias().contains(parsedAggEle.getSourceAlias())) { + List> vertexList = + valueFiltered.getVertex(parsedAggEle.getSourceAlias()); + if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) { + vertexList.forEach( + vertex -> { + Map context = + RunnerUtil.vertexContext(vertex, parsedAggEle.getSourceAlias()); + Object value = + RuleRunner.getInstance() + .executeExpression(context, parsedAggEle.getExprStrList(), taskId); + udaf.update(value); + }); + } else if (finalSourcePropertyName == null) { vertexList.forEach(udaf::update); } else { vertexList.forEach( v -> updateUdafDataFromProperty(udaf, v.getValue(), finalSourcePropertyName)); } } else { - List> edgeList = valueFiltered.getEdge(sourceAlias); - if (sourcePropertyName == null) { + List> edgeList = + valueFiltered.getEdge(parsedAggEle.getSourceAlias()); + if (CollectionUtils.isNotEmpty(parsedAggEle.getExprStrList())) { + edgeList.forEach( + edge -> { + Map context = + RunnerUtil.edgeContext(edge, null, parsedAggEle.getSourceAlias()); + Object value = + RuleRunner.getInstance() + .executeExpression(context, parsedAggEle.getExprStrList(), taskId); + udaf.update(value); + }); + } else if (finalSourcePropertyName == null) { edgeList.forEach(udaf::update); } else { edgeList.forEach( diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java index 8dc8e73fb..89fc848f0 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggIfOpProcessBaseGroupProcess.java @@ -23,6 +23,7 @@ import java.util.List; public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable { + private final ParsedAggEle parsedAggEle; /** * constructor @@ -33,6 +34,7 @@ public class AggIfOpProcessBaseGroupProcess extends BaseGroupProcess implements */ public AggIfOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) { super(taskId, var, aggregator); + parsedAggEle = parsedAggEle(); } /** @@ -58,4 +60,9 @@ public AggregatorOpSet getAggOpSet() { public Expr getAggEle() { return getAggIfOpExpr().aggOpExpr().aggEleExpr(); } + + @Override + public ParsedAggEle getParsedAggEle() { + return parsedAggEle; + } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java index 91a21dd0a..aeab4b026 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/AggOpProcessBaseGroupProcess.java @@ -22,9 +22,11 @@ import java.util.List; public class AggOpProcessBaseGroupProcess extends BaseGroupProcess implements Serializable { + private final ParsedAggEle parsedAggEle; public AggOpProcessBaseGroupProcess(String taskId, Var var, Aggregator aggregator) { super(taskId, var, aggregator); + parsedAggEle = parsedAggEle(); } public AggOpExpr getAggOpExpr() { @@ -45,4 +47,9 @@ public AggregatorOpSet getAggOpSet() { public Expr getAggEle() { return getAggOpExpr().aggEleExpr(); } + + @Override + public ParsedAggEle getParsedAggEle() { + return this.parsedAggEle; + } } diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java index 0dc5e2124..564afd51b 100644 --- a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/BaseGroupProcess.java @@ -19,6 +19,9 @@ import com.antgroup.openspg.reasoner.lube.common.expr.Aggregator; import com.antgroup.openspg.reasoner.lube.common.expr.AggregatorOpSet; import com.antgroup.openspg.reasoner.lube.common.expr.Expr; +import com.antgroup.openspg.reasoner.lube.common.expr.GetField; +import com.antgroup.openspg.reasoner.lube.common.expr.Ref; +import com.antgroup.openspg.reasoner.lube.common.expr.UnaryOpExpr; import com.antgroup.openspg.reasoner.lube.logical.PropertyVar; import com.antgroup.openspg.reasoner.lube.logical.Var; import com.antgroup.openspg.reasoner.lube.utils.ExprUtils; @@ -139,6 +142,9 @@ protected UdafMeta parseUdafMeta() { */ public abstract Expr getAggEle(); + /** get parsed agg ele */ + public abstract ParsedAggEle getParsedAggEle(); + public Set parseExprUseAliasSet() { scala.collection.immutable.List aliasList = ExprUtils.getRefVariableByExpr(getAggEle()); return new HashSet<>(JavaConversions.seqAsJavaList(aliasList)); @@ -148,6 +154,26 @@ public List parseExprRuleList() { return WareHouseUtils.getRuleList(getAggEle()); } + protected ParsedAggEle parsedAggEle() { + String sourceAlias = null; + String sourcePropertyName = null; + List exprStrList = null; + Expr aggEle = getAggEle(); + if (aggEle instanceof Ref) { + Ref sourceRef = (Ref) aggEle; + sourceAlias = sourceRef.refName(); + } else if (aggEle instanceof UnaryOpExpr) { + UnaryOpExpr expr = (UnaryOpExpr) aggEle; + GetField getField = (GetField) expr.name(); + sourceAlias = ((Ref) expr.arg()).refName(); + sourcePropertyName = getField.fieldName(); + } else if (1 == this.exprUseAliasSet.size()) { + sourceAlias = this.exprUseAliasSet.iterator().next(); + exprStrList = this.exprRuleString; + } + return new ParsedAggEle(sourceAlias, sourcePropertyName, exprStrList); + } + /** * getter * diff --git a/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java new file mode 100644 index 000000000..08026e6e1 --- /dev/null +++ b/reasoner/runner/runner-common/src/main/java/com/antgroup/openspg/reasoner/rdg/common/groupProcess/ParsedAggEle.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 OpenSPG Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. + */ + +package com.antgroup.openspg.reasoner.rdg.common.groupProcess; + +import java.util.List; + +public class ParsedAggEle { + private final String sourceAlias; + private final String sourcePropertyName; + private final List exprStrList; + + public ParsedAggEle(String sourceAlias, String sourcePropertyName, List exprStrList) { + this.sourceAlias = sourceAlias; + this.sourcePropertyName = sourcePropertyName; + this.exprStrList = exprStrList; + } + + public String getSourceAlias() { + return sourceAlias; + } + + public String getSourcePropertyName() { + return sourcePropertyName; + } + + public List getExprStrList() { + return exprStrList; + } +}