Skip to content

Commit

Permalink
Refactoring Measurement and a lot of stuff from Agent side
Browse files Browse the repository at this point in the history
  • Loading branch information
lsulak committed Nov 1, 2023
1 parent 0942d07 commit 7b013f9
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 151 deletions.
16 changes: 6 additions & 10 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package za.co.absa.atum.agent

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.model.Measurement.MeasurementByAtum
import za.co.absa.atum.agent.model._
import za.co.absa.atum.model.dto._

Expand All @@ -28,9 +27,6 @@ import scala.collection.immutable.ListMap

/**
* This class provides the methods to measure Spark `Dataframe`. Also allows to add and remove measures.
* @param atumPartitions
* @param agent
* @param measures
*/
class AtumContext private[agent] (
val atumPartitions: AtumPartitions,
Expand All @@ -45,16 +41,16 @@ class AtumContext private[agent] (
agent.getOrCreateAtumSubContext(atumPartitions ++ subPartitions)(this)
}

private def takeMeasurements(df: DataFrame): Set[MeasurementByAtum] = {
private def takeMeasurements(df: DataFrame): Set[MeasurementDTO] = {
measures.map { m =>
val measurementResult = m.function(df)
MeasurementByAtum(m, measurementResult.result, measurementResult.resultType)
val measureResult = m.function(df)
MeasurementBuilder.buildMeasurementDTO(m, measureResult)
}
}

def createCheckpoint(checkpointName: String, author: String, dataToMeasure: DataFrame): AtumContext = {
val startTime = OffsetDateTime.now()
val measurements = takeMeasurements(dataToMeasure)
val measurementDTOs = takeMeasurements(dataToMeasure)
val endTime = OffsetDateTime.now()

val checkpointDTO = CheckpointDTO(
Expand All @@ -65,7 +61,7 @@ class AtumContext private[agent] (
partitioning = AtumPartitions.toSeqPartitionDTO(this.atumPartitions),
processStartTime = startTime,
processEndTime = Some(endTime),
measurements = measurements.map(MeasurementBuilder.buildMeasurementDTO).toSeq
measurements = measurementDTOs.toSeq
)

agent.saveCheckpoint(checkpointDTO)
Expand All @@ -82,7 +78,7 @@ class AtumContext private[agent] (
partitioning = AtumPartitions.toSeqPartitionDTO(this.atumPartitions),
processStartTime = offsetDateTimeNow,
processEndTime = Some(offsetDateTimeNow),
measurements = measurements.map(MeasurementBuilder.buildMeasurementDTO)
measurements = MeasurementBuilder.buildMeasurementDTO(measurements)
)

agent.saveCheckpoint(checkpointDTO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,12 @@ package za.co.absa.atum.agent.core

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.atum.agent.model.MeasureResult

trait MeasurementProcessor {

def function: MeasurementFunction

}

object MeasurementProcessor {
/**
* The raw result of a measurement is always going to be a string, because we want to avoid some floating point issues
* (overflows, consistent representation of numbers - whether they are coming from the Java or Scala world, and more),
* but the actual type is stored alongside the computation because we don't want to lose this information.
*/
final case class ResultOfMeasurement(result: String, resultType: ResultValueType.ResultValueType)

type MeasurementFunction = DataFrame => ResultOfMeasurement

type MeasurementFunction = DataFrame => MeasureResult
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ package za.co.absa.atum.agent.exception

sealed abstract class AtumAgentException extends Exception

case class MeasurementProvidedException(msg: String) extends AtumAgentException
case class MeasurementException(msg: String) extends AtumAgentException
case class MeasureException(msg: String) extends AtumAgentException
21 changes: 10 additions & 11 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

Expand Down Expand Up @@ -55,10 +54,10 @@ object Measure {
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction =
override def function: MeasurementProcessor.MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(controlCol)).count().toString
ResultOfMeasurement(resultValue, resultValueType)
MeasureResult(resultValue, resultValueType)
}
}
object RecordCount extends MeasureType {
Expand All @@ -74,10 +73,10 @@ object Measure {
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction =
override def function: MeasurementProcessor.MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(controlCol)).distinct().count().toString
ResultOfMeasurement(resultValue, resultValueType)
MeasureResult(resultValue, resultValueType)
}
}
object DistinctRecordCount extends MeasureType {
Expand All @@ -95,10 +94,10 @@ object Measure {
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
override def function: MeasurementProcessor.MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(col(valueColumnName))
val resultValue = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
MeasureResult(resultValue, resultValueType)
}
}
object SumOfValuesOfColumn extends MeasureType {
Expand All @@ -116,10 +115,10 @@ object Measure {
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
override def function: MeasurementProcessor.MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(abs(col(valueColumnName)))
val resultValue = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
MeasureResult(resultValue, resultValueType)
}
}
object AbsSumOfValuesOfColumn extends MeasureType {
Expand All @@ -137,15 +136,15 @@ object Measure {
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
override def function: MeasurementProcessor.MeasurementFunction = (ds: DataFrame) => {

val aggregatedColumnName = ds.schema.getClosestUniqueName("sum_of_hashes")
val value = ds
.withColumn(aggregatedColumnName, crc32(col(controlCol).cast("String")))
.agg(sum(col(aggregatedColumnName)))
.collect()(0)(0)
val resultValue = if (value == null) "" else value.toString
ResultOfMeasurement(resultValue, ResultValueType.String)
MeasureResult(resultValue, ResultValueType.String)
}
}
object SumOfHashesOfColumn extends MeasureType {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.atum.agent.model

import za.co.absa.atum.agent.exception.MeasurementException
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

/**
 * Represents the outcome of a single measure computation: the raw value together with
 * the declared type of that value (see [[za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType]]).
 */
trait MeasureResult {
  // The raw result; its runtime type is described by `resultType` below.
  val resultValue: Any
  // The declared value type of `resultValue` (Long, Double, BigDecimal, or String).
  val resultType: ResultValueType.ResultValueType
}

object MeasureResult {

  /** Internal carrier pairing a typed raw value with its declared result-value type. */
  private final case class MeasureResultWithType[T](resultValue: T, resultType: ResultValueType.ResultValueType)
    extends MeasureResult

  /**
   * Builds a [[MeasureResult]] from a string-encoded value produced by the Atum Agent itself (via Spark).
   *
   * Agent-computed results are always carried as strings to sidestep floating-point pitfalls
   * (overflows, and consistent number representation across the Java and Scala worlds, among others),
   * while the declared type is kept alongside so that information is not lost.
   */
  def apply(resultValue: String, resultType: ResultValueType.ResultValueType): MeasureResult =
    MeasureResultWithType[String](resultValue, resultType)

  /**
   * Builds a [[MeasureResult]] from a value supplied directly by the application/user of the Atum Agent.
   * Here the runtime type is precise, so no string conversion is performed; the result type is derived
   * from the value itself.
   *
   * @throws MeasurementException when the value is not a Long, Double, BigDecimal, or String
   */
  def apply(resultValue: Any): MeasureResult = resultValue match {
    case longValue: Long =>
      MeasureResultWithType[Long](longValue, ResultValueType.Long)
    case doubleValue: Double =>
      MeasureResultWithType[Double](doubleValue, ResultValueType.Double)
    case decimalValue: BigDecimal =>
      MeasureResultWithType[BigDecimal](decimalValue, ResultValueType.BigDecimal)
    case stringValue: String =>
      MeasureResultWithType[String](stringValue, ResultValueType.String)
    case other =>
      // Unsupported runtime type: surface it by simple class name for a readable error.
      val className = other.getClass.getSimpleName
      throw MeasurementException(
        s"Unsupported type of measurement: $className for provided result: $resultValue")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,65 +16,4 @@

package za.co.absa.atum.agent.model

import za.co.absa.atum.agent.exception.MeasurementProvidedException
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

trait Measurement {
val measure: Measure
val resultValue: Any
val resultType: ResultValueType.ResultValueType
}

object Measurement {

/**
* When the application/user of Atum Agent provides actual results by himself, the type is precise and we don't need
* to do any adjustments.
*/
case class MeasurementProvided[T](measure: Measure, resultValue: T, resultType: ResultValueType.ResultValueType)
extends Measurement

object MeasurementProvided {

private def handleSpecificType[T](
measure: Measure, resultValue: T, requiredType: ResultValueType.ResultValueType
): MeasurementProvided[T] = {

val actualType = measure.resultValueType
if (actualType != requiredType)
throw MeasurementProvidedException(
s"Type of a given provided measurement result and type that a given measure supports are not compatible! " +
s"Got $actualType but should be $requiredType"
)
MeasurementProvided[T](measure, resultValue, requiredType)
}

def apply[T](measure: Measure, resultValue: T): Measurement = {
resultValue match {
case l: Long =>
handleSpecificType[Long](measure, l, ResultValueType.Long)
case d: Double =>
handleSpecificType[Double](measure, d, ResultValueType.Double)
case bd: BigDecimal =>
handleSpecificType[BigDecimal](measure, bd, ResultValueType.BigDecimal)
case s: String =>
handleSpecificType[String](measure, s, ResultValueType.String)

case unsupportedType =>
val className = unsupportedType.getClass.getSimpleName
throw MeasurementProvidedException(
s"Unsupported type of measurement for measure ${measure.measureName}: $className " +
s"for provided result: $resultValue"
)
}
}
}

/**
* When the Atum Agent itself performs the measurements, using Spark, then in some cases some adjustments are
* needed - thus we are converting the results to strings always - but we need to keep the information about
* the actual type as well.
*/
case class MeasurementByAtum(measure: Measure, resultValue: String, resultType: ResultValueType.ResultValueType)
extends Measurement
}
final case class Measurement(measure: Measure, result: MeasureResult)
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,52 @@

package za.co.absa.atum.agent.model

import za.co.absa.atum.agent.exception.MeasurementException
import za.co.absa.atum.model.dto.{MeasureDTO, MeasureResultDTO, MeasurementDTO}
import za.co.absa.atum.model.dto.MeasureResultDTO.TypedValue

private [agent] object MeasurementBuilder {

private [agent] def buildMeasurementDTO(measurement: Measurement): MeasurementDTO = {
val measureName = measurement.measure.measureName
val controlCols = Seq(measurement.measure.controlCol)
val measureDTO = MeasureDTO(measureName, controlCols)
private def validateMeasurement(measure: Measure, result: MeasureResult): Unit = {
val actualType = result.resultType
val requiredType = measure.resultValueType

if (actualType != requiredType)
throw MeasurementException(
s"Type of a given provided measurement result and type that a given measure supports are not compatible! " +
s"Got $actualType but should be $requiredType"
)
}

private def validateMeasuresUniqueness(measures: Seq[Measure]): Unit = {
val originalMeasureCnt = measures.size
val uniqueMeasuresCnt = measures.map(m => Tuple2(m.measureName, m.controlCol)).distinct.size

val areMeasuresUnique = originalMeasureCnt == uniqueMeasuresCnt

require(areMeasuresUnique, s"Measures must be unique, i.e. they cannot repeat! Got: $measures")
}

val measureResultDTO = MeasureResultDTO(TypedValue(measurement.resultValue.toString, measurement.resultType))
private[agent] def buildMeasurementDTO(measurements: Seq[Measurement]): Seq[MeasurementDTO] = {
val allMeasures = measurements.map(_.measure)
validateMeasuresUniqueness(allMeasures)

measurements.map(m => buildMeasurementDTO(m.measure, m.result))
}

private[agent] def buildMeasurementDTO(measurement: Measurement): MeasurementDTO = {
buildMeasurementDTO(measurement.measure, measurement.result)
}

private[agent] def buildMeasurementDTO(measure: Measure, measureResult: MeasureResult): MeasurementDTO = {
val measureName = measure.measureName
val controlCols = Seq(measure.controlCol)

validateMeasurement(measure, measureResult)

val measureDTO = MeasureDTO(measureName, controlCols)
val measureResultDTO = MeasureResultDTO(
MeasureResultDTO.TypedValue(measureResult.resultValue.toString, measureResult.resultType)
)
MeasurementDTO(measureDTO, measureResultDTO)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.model.Measure.{RecordCount, SumOfValuesOfColumn}
import za.co.absa.atum.agent.model.MeasurementBuilder
import za.co.absa.atum.agent.model.Measurement.MeasurementProvided
import za.co.absa.atum.agent.model.{Measurement, MeasurementBuilder, MeasureResult}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import za.co.absa.atum.model.dto.CheckpointDTO
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
Expand Down Expand Up @@ -107,8 +106,8 @@ class AtumContextTest extends AnyFlatSpec with Matchers {
val atumContext: AtumContext = new AtumContext(atumPartitions, mockAgent)

val measurements = Seq(
MeasurementProvided(RecordCount("col"), 1L),
MeasurementProvided(SumOfValuesOfColumn("col"), BigDecimal(1))
Measurement(RecordCount("col"), MeasureResult(1L)),
Measurement(SumOfValuesOfColumn("col"), MeasureResult(BigDecimal(1)))
)

atumContext.createCheckpointOnProvidedData(
Expand Down Expand Up @@ -192,5 +191,4 @@ class AtumContextTest extends AnyFlatSpec with Matchers {

assert(atumContext.currentAdditionalData == expectedAdditionalData)
}

}
Loading

0 comments on commit 7b013f9

Please sign in to comment.