Skip to content

Commit

Permalink
Scaladocs added (#129)
Browse files Browse the repository at this point in the history
Missing Scaladocs added.
  • Loading branch information
salamonpavel authored Nov 23, 2023
1 parent cefe63c commit 597aa46
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 33 deletions.
7 changes: 7 additions & 0 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ class AtumAgent private[agent] () {
getExistingOrNewContext(atumPartitions, atumContext)
}

/**
* Provides an AtumContext given a `AtumPartitions` instance for sub partitions.
* Retrieves the data from AtumService API.
* @param subPartitions Sub partitions based on which an Atum Context will be created or obtained.
* @param parentAtumContext Parent AtumContext.
* @return Atum context object
*/
def getOrCreateAtumSubContext(subPartitions: AtumPartitions)(implicit parentAtumContext: AtumContext): AtumContext = {
val authorIfNew = AtumAgent.currentUser
val newPartitions: AtumPartitions = parentAtumContext.atumPartitions ++ subPartitions
Expand Down
65 changes: 65 additions & 0 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,18 @@ class AtumContext private[agent] (
private var additionalData: Map[String, Option[String]] = Map.empty
) {

/**
 * Exposes the set of measures currently registered on this AtumContext.
 *
 * @return the measures registered at the time of the call
 */
def currentMeasures: Set[Measure] = {
  measures
}

/**
 * Obtains (or creates) an AtumContext for a sub-partitioning of this context.
 *
 * @param subPartitions partitions to append to this context's partitions
 * @return the AtumContext associated with the combined partitions
 */
def subPartitionContext(subPartitions: AtumPartitions): AtumContext = {
  val combinedPartitions = atumPartitions ++ subPartitions
  agent.getOrCreateAtumSubContext(combinedPartitions)(this)
}
Expand All @@ -54,6 +64,19 @@ class AtumContext private[agent] (
}
}

/**
* Creates a checkpoint in the AtumContext.
*
* This method is used to mark a specific point in the data processing pipeline where measurements of data
* completeness are taken.
* The checkpoint is identified by a name, which can be used later to retrieve the measurements taken at this point.
* After the checkpoint is created, the method returns the AtumContext for further operations.
*
* @param checkpointName the name of the checkpoint to be created. This name should be descriptive of the point in
* the data processing pipeline where the checkpoint is created.
* @return the AtumContext after the checkpoint has been created.
* This allows for method chaining in the data processing pipeline.
*/
def createCheckpoint(checkpointName: String, dataToMeasure: DataFrame): AtumContext = {
val startTime = ZonedDateTime.now()
val measurements = takeMeasurements(dataToMeasure)
Expand All @@ -74,6 +97,13 @@ class AtumContext private[agent] (
this
}

/**
* Creates a checkpoint with the specified name and provided measurements.
*
* @param checkpointName the name of the checkpoint to be created
* @param measurements the measurements to be included in the checkpoint
* @return the AtumContext after the checkpoint has been created
*/
def createCheckpointOnProvidedData(checkpointName: String, measurements: Seq[Measurement]): AtumContext = {
val dateTimeNow = ZonedDateTime.now()

Expand All @@ -91,24 +121,50 @@ class AtumContext private[agent] (
this
}

/**
 * Registers a key/value pair of additional data on this AtumContext.
 * An existing entry under the same key is overwritten.
 *
 * @param key the key identifying the additional data entry
 * @param value the value to be stored under the given key
 */
def addAdditionalData(key: String, value: String): Unit = {
  additionalData = additionalData.updated(key, Some(value))
}

/**
 * Provides the additional data currently held by this AtumContext.
 *
 * @return the additional data as an immutable map
 */
def currentAdditionalData: Map[String, Option[String]] = additionalData

/**
* Adds a single measure to the set of measures tracked by this AtumContext.
*
* @param newMeasure the measure to be added
* @return this AtumContext, allowing method chaining
*/
def addMeasure(newMeasure: Measure): AtumContext = {
measures = measures + newMeasure
this
}

/**
* Adds multiple measures to the set of measures tracked by this AtumContext.
*
* @param newMeasures the set of measures to be added
* @return this AtumContext, allowing method chaining
*/
def addMeasures(newMeasures: Set[Measure]): AtumContext = {
measures = measures ++ newMeasures
this
}

/**
* Removes a measure from the AtumContext.
*
* @param measureToRemove the measure to be removed
*/
def removeMeasure(measureToRemove: Measure): AtumContext = {
measures = measures - measureToRemove
this
Expand All @@ -125,8 +181,14 @@ class AtumContext private[agent] (
}

object AtumContext {
/**
* Type alias for Atum partitions.
*/
type AtumPartitions = ListMap[String, String]

/**
* Object contains helper methods to work with Atum partitions.
*/
object AtumPartitions {
def apply(elems: (String, String)): AtumPartitions = {
ListMap(elems)
Expand Down Expand Up @@ -154,6 +216,9 @@ object AtumContext {
)
}

/**
* Implicit class to add a method to DataFrame to create a checkpoint.
*/
implicit class DatasetWrapper(df: DataFrame) {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@ import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

/**
* This trait provides a contract for different measurement processors.
*/
trait MeasurementProcessor {

/**
* Provides the measurement function used to compute a measure on a Spark `DataFrame`.
* The returned function takes the `DataFrame` to be measured and yields the result
* of the measurement.
*
* @return the function that computes the measurement
*/
def function: MeasurementFunction

}

/**
* This companion object provides a set of types for measurement processors
*/
object MeasurementProcessor {
/**
* The raw result of measurement is always gonna be string, because we want to avoid some floating point issues
Expand All @@ -34,6 +45,11 @@ object MeasurementProcessor {
*/
final case class ResultOfMeasurement(result: String, resultType: ResultValueType.ResultValueType)

/**
* This type alias describes a function that is used to compute measure on Spark `Dataframe`.
* @param df: Spark `Dataframe` to be measured.
* @return Result of measurement.
*/
type MeasurementFunction = DataFrame => ResultOfMeasurement

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,21 @@ package za.co.absa.atum.agent.dispatcher

import za.co.absa.atum.model.dto.{AtumContextDTO, CheckpointDTO, PartitioningSubmitDTO}

/**
* This trait provides a contract for different dispatchers communicating with the Atum server.
*/
trait Dispatcher {
/**
* Ensures the server knows the given partitioning.
* As a response, the `AtumContextDTO` associated with the partitioning is fetched from the server.
*
* @param partitioning the partitioning to submit to the server
* @return the AtumContextDTO associated with the given partitioning
*/
def createPartitioning(partitioning: PartitioningSubmitDTO): AtumContextDTO

/**
* Saves a checkpoint to the server.
*
* @param checkpoint the checkpoint to be saved
*/
def saveCheckpoint(checkpoint: CheckpointDTO): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,35 @@

package za.co.absa.atum.agent.exception

/**
* Base class for all exceptions thrown by the Atum Agent.
*
* @param message A message describing the exception.
*/
abstract class AtumAgentException(message: String) extends Exception(message)

/**
* This object contains the possible exceptions thrown by the Atum Agent.
*/
object AtumAgentException {
/**
* This type represents an exception related to the creation of a provided measurement.
*
* @param message A message describing the exception.
*/
case class MeasurementProvidedException(message: String) extends AtumAgentException(message)

/**
* This type represents an exception thrown when a measure is not supported by the Atum Agent.
*
* @param message A message describing the exception.
*/
case class MeasureException(message: String) extends AtumAgentException(message)

/**
* This type represents an exception related to HTTP communication with the server.
*
* @param statusCode A status code of the HTTP response.
* @param message A message describing the exception.
*/
case class HttpException(statusCode: Int, message: String) extends AtumAgentException(message)
}



65 changes: 44 additions & 21 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,23 @@ import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

/**
* This trait represents a measure that can be applied to a column.
*/
sealed trait Measure extends MeasurementProcessor with MeasureType {
// Name of the column this measure is applied to.
val measuredColumn: String
}

/**
* This trait represents a measure type that can be applied to a column.
*/
trait MeasureType {
// Name identifying the measure type.
val measureName: String
// Type of the value produced by this measure type.
val resultValueType: ResultValueType.ResultValueType
}

/**
* This object contains all the possible measures that can be applied to a column.
*/
object Measure {

private val valueColumnName: String = "value"
Expand All @@ -50,9 +56,9 @@ object Measure {
val supportedMeasureNames: Seq[String] = supportedMeasures.map(_.measureName)

case class RecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction =
Expand All @@ -69,9 +75,9 @@ object Measure {
}

case class DistinctRecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction =
Expand All @@ -90,9 +96,9 @@ object Measure {
}

case class SumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
Expand All @@ -111,9 +117,9 @@ object Measure {
}

case class AbsSumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
Expand All @@ -132,9 +138,9 @@ object Measure {
}

case class SumOfHashesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

override def function: MeasurementFunction = (ds: DataFrame) => {
Expand All @@ -157,12 +163,21 @@ object Measure {
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.String
}

/**
* This method aggregates a column of a given data frame using a given aggregation expression.
* The result is converted to a string.
*
* @param df A data frame
* @param measureColumn A column to aggregate
* @param aggExpression An aggregation expression
* @return A string representation of the aggregated value
*/
private def aggregateColumn(
ds: DataFrame,
df: DataFrame,
measureColumn: String,
aggExpression: Column
): String = {
val dataType = ds.select(measureColumn).schema.fields(0).dataType
val dataType = df.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
Expand All @@ -171,14 +186,14 @@ object Measure {
// Converting to BigDecimal fixes the issue
// val ds2 = ds.select(col(measurement.measuredColumn).cast(DecimalType(38, 0)).as("value"))
// ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = ds.select(
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
case _: StringType =>
// Support for string type aggregation
val ds2 = ds.select(
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
Expand All @@ -188,14 +203,22 @@ object Measure {
value.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
val ds2 = ds.select(col(measureColumn).as(valueColumnName))
val ds2 = df.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
}
// check if total is required to be presented as larger type - big decimal
workaroundBigDecimalIssues(aggregatedValue)
}

/**
* This method converts a given value to string.
* It is a workaround for different serializers generating different JSONs for BigDecimal.
* See https://stackoverflow.com/questions/61973058/json-serialization-of-bigdecimal-returns-scientific-notation
*
* @param value A value to convert
* @return A string representation of the value
*/
private def workaroundBigDecimalIssues(value: Any): String =
// If aggregated value is java.math.BigDecimal, convert it to scala.math.BigDecimal
value match {
Expand Down
Loading

0 comments on commit 597aa46

Please sign in to comment.