Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce MonadError & refactor Doobie classes #113

Merged
merged 11 commits into from
Feb 21, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
with:
paths: >
${{ github.workspace }}/core/target/scala-${{ matrix.scalaShort }}/jacoco/report/jacoco.xml,
${{ github.workspace }}/slick/target/scala-${{ matrix.scalaShort }}/jacoco/report/jacoco.xml
${{ github.workspace }}/slick/target/scala-${{ matrix.scalaShort }}/jacoco/report/jacoco.xml,
${{ github.workspace }}/doobie/target/scala-${{ matrix.scalaShort }}/jacoco/report/jacoco.xml
# examples don't need code coverage - at least not now
token: ${{ secrets.GITHUB_TOKEN }}
Expand Down
39 changes: 23 additions & 16 deletions core/src/main/scala/za/co/absa/fadb/DBFunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package za.co.absa.fadb

import cats.MonadError
import cats.implicits.toFlatMapOps
import za.co.absa.fadb.exceptions.StatusException
import za.co.absa.fadb.status.handling.StatusHandling

Expand All @@ -31,8 +33,8 @@ import scala.language.higherKinds
* @tparam E - The type of the [[DBEngine]] engine.
* @tparam F - The type of the context in which the database function is executed.
*/
abstract class DBFunction[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Option[String] = None)(
implicit override val schema: DBSchema,
abstract class DBFunction[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Option[String] = None)(implicit
override val schema: DBSchema,
val dBEngine: E
) extends DBFunctionFabric(functionNameOverride) {

Expand All @@ -47,29 +49,33 @@ abstract class DBFunction[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Op
* @param values - The values to pass over to the database function.
* @return - A sequence of results from the database function.
*/
protected def multipleResults(values: I): F[Seq[R]] = dBEngine.fetchAll(query(values))
protected def multipleResults(values: I)(implicit me: MonadError[F, Throwable]): F[Seq[R]] =
query(values).flatMap(q => dBEngine.fetchAll(q))

/**
* Executes the database function and returns a single result.
* @param values - The values to pass over to the database function.
* @return - A single result from the database function.
*/
protected def singleResult(values: I): F[R] = dBEngine.fetchHead(query(values))
protected def singleResult(values: I)(implicit me: MonadError[F, Throwable]): F[R] =
query(values).flatMap(q => dBEngine.fetchHead(q))

/**
* Executes the database function and returns an optional result.
* @param values - The values to pass over to the database function.
* @return - An optional result from the database function.
*/
protected def optionalResult(values: I): F[Option[R]] = dBEngine.fetchHeadOption(query(values))
protected def optionalResult(values: I)(implicit me: MonadError[F, Throwable]): F[Option[R]] = {
query(values).flatMap(q => dBEngine.fetchHeadOption(q))
}

/**
* Function to create the DB function call specific to the provided [[DBEngine]].
* Expected to be implemented by the DBEngine specific mix-in.
* @param values - The values to pass over to the database function.
* @return - The SQL query in the format specific to the provided [[DBEngine]].
*/
protected def query(values: I): dBEngine.QueryType[R]
protected def query(values: I)(implicit me: MonadError[F, Throwable]): F[dBEngine.QueryType[R]]
}

/**
Expand All @@ -83,8 +89,8 @@ abstract class DBFunction[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Op
* @tparam E - The type of the [[DBEngine]] engine.
* @tparam F - The type of the context in which the database function is executed.
*/
abstract class DBFunctionWithStatus[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Option[String] = None)(
implicit override val schema: DBSchema,
abstract class DBFunctionWithStatus[I, R, E <: DBEngine[F], F[_]](functionNameOverride: Option[String] = None)(implicit
override val schema: DBSchema,
val dBEngine: E
) extends DBFunctionFabric(functionNameOverride)
with StatusHandling {
Expand All @@ -100,10 +106,12 @@ abstract class DBFunctionWithStatus[I, R, E <: DBEngine[F], F[_]](functionNameOv

/**
* Executes the database function and returns multiple results.
* @param values
* @param values The values to pass over to the database function.
* @return A sequence of results from the database function.
*/
def apply(values: I): F[Either[StatusException, R]] = dBEngine.runWithStatus(query(values))
def apply(values: I)(implicit me: MonadError[F, Throwable]): F[Either[StatusException, R]] = {
query(values).flatMap(q => dBEngine.runWithStatus(q))
}

/**
* The fields to select from the database function call
Expand All @@ -117,12 +125,11 @@ abstract class DBFunctionWithStatus[I, R, E <: DBEngine[F], F[_]](functionNameOv
}

/**
* Function to create the DB function call specific to the provided [[DBEngine]]. Expected to be implemented by the
* DBEngine specific mix-in.
* Function to create the DB function call specific to the provided [[DBEngine]].
* @param values the values to pass over to the database function
* @return the SQL query in the format specific to the provided [[DBEngine]]
*/
protected def query(values: I): dBEngine.QueryWithStatusType[R]
protected def query(values: I)(implicit me: MonadError[F, Throwable]): F[dBEngine.QueryWithStatusType[R]]

// To be provided by an implementation of QueryStatusHandling
override def checkStatus[A](statusWithData: FunctionStatusWithData[A]): Either[StatusException, A]
Expand Down Expand Up @@ -151,7 +158,7 @@ object DBFunction {
* @return - a sequence of values, each coming from a row returned from the DB function transformed to scala
* type `R`
*/
def apply(values: I): F[Seq[R]] = multipleResults(values)
def apply(values: I)(implicit me: MonadError[F, Throwable]): F[Seq[R]] = multipleResults(values)
}

/**
Expand All @@ -174,7 +181,7 @@ object DBFunction {
* @param values - the values to pass over to the database function
* @return - the value returned from the DB function transformed to scala type `R`
*/
def apply(values: I): F[R] = singleResult(values)
def apply(values: I)(implicit me: MonadError[F, Throwable]): F[R] = singleResult(values)
}

/**
Expand All @@ -197,6 +204,6 @@ object DBFunction {
* @param values - the values to pass over to the database function
* @return - the value returned from the DB function transformed to scala type `R` if a row is returned, otherwise `None`
*/
def apply(values: I): F[Option[R]] = optionalResult(values)
def apply(values: I)(implicit me: MonadError[F, Throwable]): F[Option[R]] = optionalResult(values)
}
}
8 changes: 6 additions & 2 deletions core/src/test/scala/za/co/absa/fadb/DBFunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package za.co.absa.fadb

import cats.MonadError
import cats.implicits._
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBFunction.DBSingleResultFunction
Expand Down Expand Up @@ -44,7 +45,9 @@ class DBFunctionSuite extends AnyFunSuite {
class MyFunction(implicit override val schema: DBSchema, val dbEngine: EngineThrow)
extends DBSingleResultFunction[Unit, Unit, EngineThrow, Future](None) {

override protected def query(values: Unit): dBEngine.QueryType[Unit] = neverHappens
override protected def query(values: Unit)
(implicit me: MonadError[Future, Throwable]
): Future[dBEngine.QueryType[Unit]] = neverHappens
}

val fnc1 = new MyFunction()(FooNamed, new EngineThrow)
Expand All @@ -58,7 +61,8 @@ class DBFunctionSuite extends AnyFunSuite {
class MyFunction(implicit override val schema: DBSchema, val dbEngine: EngineThrow)
extends DBSingleResultFunction[Unit, Unit, EngineThrow, Future](Some("bar")) {

override protected def query(values: Unit): dBEngine.QueryType[Unit] = neverHappens
override protected def query(values: Unit)(implicit me: MonadError[Future, Throwable]
): Future[dBEngine.QueryType[Unit]] = neverHappens
}

val fnc1 = new MyFunction()(FooNamed, new EngineThrow)
Expand Down
6 changes: 6 additions & 0 deletions doobie/src/it/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
### How to do testing

In order to execute tests in this module you need to:
- deploy all sql code from database folder into postgres instance of your choice
- make sure you have data in your tables as tests expect populated tables (unfortunately as this point this is not automated)
- set up connection to your database in DoobieTest trait
40 changes: 16 additions & 24 deletions doobie/src/it/scala/za/co/absa/fadb/doobie/DatesTimesTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package za.co.absa.fadb.doobie
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import doobie.implicits.toSqlInterpolator
import doobie.util.Read
import doobie.util.fragment.Fragment
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBSchema
import za.co.absa.fadb.doobie.DoobieFunction.{DoobieSingleResultFunction, DoobieSingleResultFunctionWithStatus}
Expand All @@ -46,30 +44,24 @@ class DatesTimesTest extends AnyFunSuite with DoobieTest {
)

class GetAllDateTimeTypes(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieSingleResultFunction[Int, DatesTimes, IO] {

override def sql(values: Int)(implicit read: Read[DatesTimes]): Fragment =
sql"SELECT * FROM ${Fragment.const(functionName)}($values)"
}
extends DoobieSingleResultFunction[Int, DatesTimes, IO](values => Seq(fr"$values"))

class InsertDatesTimes(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieSingleResultFunctionWithStatus[DatesTimes, Int, IO] with StandardStatusHandling {

override def sql(values: DatesTimes)(implicit read: Read[StatusWithData[Int]]): Fragment =
sql"""
SELECT * FROM ${Fragment.const(functionName)}(
${values.offsetDateTime},
${values.instant},
${values.zonedDateTime},
${values.localDateTime},
${values.localDate},
${values.localTime},
${values.sqlDate},
${values.sqlTime},
${values.sqlTimestamp},
${values.utilDate}
)
"""
extends DoobieSingleResultFunctionWithStatus[DatesTimes, Int, IO] (
values => Seq(
fr"${values.offsetDateTime}",
fr"${values.instant}",
fr"${values.zonedDateTime}",
fr"${values.localDateTime}",
fr"${values.localDate}",
fr"${values.localTime}",
fr"${values.sqlDate}",
fr"${values.sqlTime}",
fr"${values.sqlTimestamp}",
fr"${values.utilDate}"
)
) with StandardStatusHandling {
override def fieldsToSelect: Seq[String] = super.fieldsToSelect ++ Seq("o_id")
lsulak marked this conversation as resolved.
Show resolved Hide resolved
}

private val getAllDateTimeTypes = new GetAllDateTimeTypes()(Runs, new DoobieEngine(transactor))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,39 @@

package za.co.absa.fadb.doobie

import cats.Semigroup
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import cats.implicits.catsSyntaxSemigroup
import doobie.Fragment
import doobie.implicits.toSqlInterpolator
import doobie.util.Read
import doobie.util.fragment.Fragment
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBSchema
import za.co.absa.fadb.doobie.DoobieFunction.DoobieMultipleResultFunction

class DoobieMultipleResultFunctionTest extends AnyFunSuite with DoobieTest {

class GetActors(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieMultipleResultFunction[GetActorsQueryParameters, Actor, IO] {
implicit def toFragmentsFunctionSemigroup[T]: Semigroup[T => Seq[Fragment]] = {
(f1: T => Seq[Fragment], f2: T => Seq[Fragment]) => (params: T) => f1(params) ++ f2(params)
}

private val firstNameFragment: GetActorsQueryParameters => Seq[Fragment] = params => Seq(fr"${params.firstName}")
private val lastNameFragment: GetActorsQueryParameters => Seq[Fragment] = params => Seq(fr"${params.lastName}")

private val combinedQueryFragments: GetActorsQueryParameters => Seq[Fragment] =
params => firstNameFragment(params) ++ lastNameFragment(params)

override def sql(values: GetActorsQueryParameters)(implicit read: Read[Actor]): Fragment =
sql"SELECT actor_id, first_name, last_name FROM ${Fragment.const(functionName)}(${values.firstName}, ${values.lastName})"
// using Semigroup's combine method, |+| is syntactical sugar for combine method
private val combinedUsingSemigroup = firstNameFragment |+| lastNameFragment

// not combined, defined as one function
private val getActorsQueryFragments: GetActorsQueryParameters => Seq[Fragment] = {
values => Seq(fr"${values.firstName}", fr"${values.lastName}")
}

class GetActors(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieMultipleResultFunction[GetActorsQueryParameters, Actor, IO](combinedUsingSemigroup)

private val getActors = new GetActors()(Runs, new DoobieEngine(transactor))

test("Retrieving actor from database") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@ package za.co.absa.fadb.doobie
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import doobie.implicits.toSqlInterpolator
import doobie.util.Read
import doobie.util.fragment.Fragment
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBSchema
import za.co.absa.fadb.doobie.DoobieFunction.DoobieOptionalResultFunction

class DoobieOptionalResultFunctionTest extends AnyFunSuite with DoobieTest {

class GetActorById(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieOptionalResultFunction[Int, Actor, IO] {

override def sql(values: Int)(implicit read: Read[Actor]): Fragment =
sql"SELECT actor_id, first_name, last_name FROM ${Fragment.const(functionName)}($values)"
}
extends DoobieOptionalResultFunction[Int, Actor, IO](id => Seq(fr"$id"))

private val createActor = new GetActorById()(Runs, new DoobieEngine(transactor))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package za.co.absa.fadb.doobie
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import doobie.implicits.toSqlInterpolator
import doobie.util.Read
import doobie.util.fragment.Fragment
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBSchema
import za.co.absa.fadb.doobie.DoobieFunction.{DoobieSingleResultFunction, DoobieSingleResultFunctionWithStatus}
Expand Down Expand Up @@ -49,32 +47,25 @@ class DoobieOtherTypesTest extends AnyFunSuite with DoobieTest {
arrayCol: Array[Int]
)


class ReadOtherTypes(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieSingleResultFunction[Int, OtherTypesData, IO] {

override def sql(values: Int)(implicit read: Read[OtherTypesData]): Fragment =
sql"SELECT * FROM ${Fragment.const(functionName)}($values)"
}
extends DoobieSingleResultFunction[Int, OtherTypesData, IO] (values => Seq(fr"$values"))

class InsertOtherTypes(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieSingleResultFunctionWithStatus[OtherTypesData, Option[Int], IO] with StandardStatusHandling {

override def sql(values: OtherTypesData)(implicit read: Read[StatusWithData[Option[Int]]]): Fragment =
sql"""
SELECT * FROM ${Fragment.const(functionName)}(
${values.id},
${values.ltreeCol}::ltree,
${values.inetCol}::inet,
${values.macaddrCol}::macaddr,
${values.hstoreCol}::hstore,
${values.cidrCol}::cidr,
${values.jsonCol}::json,
${values.jsonbCol}::jsonb,
${values.uuidCol}::uuid,
${values.arrayCol}::integer[]
)
"""
extends DoobieSingleResultFunctionWithStatus[OtherTypesData, Option[Int], IO] (
values => Seq(
fr"${values.id}",
fr"${values.ltreeCol}::ltree",
fr"${values.inetCol}::inet",
fr"${values.macaddrCol}::macaddr",
fr"${values.hstoreCol}::hstore",
fr"${values.cidrCol}::cidr",
fr"${values.jsonCol}::json",
fr"${values.jsonbCol}::jsonb",
fr"${values.uuidCol}::uuid",
fr"${values.arrayCol}::integer[]"
)
) with StandardStatusHandling {
override def fieldsToSelect: Seq[String] = super.fieldsToSelect ++ Seq("o_id")
}

private val readOtherTypes = new ReadOtherTypes()(Runs, new DoobieEngine(transactor))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,26 @@ package za.co.absa.fadb.doobie
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import doobie.implicits.toSqlInterpolator
import doobie.util.Read
import doobie.util.fragment.Fragment
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.fadb.DBSchema
import za.co.absa.fadb.doobie.DoobieFunction.DoobieSingleResultFunction

class DoobieSingleResultFunctionTest extends AnyFunSuite with DoobieTest {

class CreateActor(implicit schema: DBSchema, dbEngine: DoobieEngine[IO])
extends DoobieSingleResultFunction[CreateActorRequestBody, Int, IO] {

// do not remove the example below
// override def fieldsToSelect: Seq[String] = super.fieldsToSelect ++ Seq("o_actor_id")

override def sql(values: CreateActorRequestBody)(implicit read: Read[Int]): Fragment =
sql"SELECT o_actor_id FROM ${Fragment.const(functionName)}(${values.firstName}, ${values.lastName})"
// do not remove the example below, it has to be used with the override def fieldsToSelect
// sql"SELECT ${Fragment.const(selectEntry)} FROM ${Fragment.const(functionName)}(${values.firstName}, ${values.lastName}) ${Fragment.const(alias)}"
extends DoobieSingleResultFunction[CreateActorRequestBody, Int, IO] (
values => {
throw new Exception("boom")
Seq(fr"${values.firstName}", fr"${values.lastName}")
}
) {
override def fieldsToSelect: Seq[String] = super.fieldsToSelect ++ Seq("o_actor_id")
}

private val createActor = new CreateActor()(Runs, new DoobieEngine(transactor))

test("Inserting an actor into database") {
assert(createActor(CreateActorRequestBody("Pavel", "Marek")).unsafeRunSync().isInstanceOf[Int])
test("Inserting an actor into database & handling an error") {
val result = createActor(CreateActorRequestBody("Pavel", "Marek")).handleErrorWith(_ => IO(Int.MaxValue)).unsafeRunSync()
assert(result == Int.MaxValue)
}
}
Loading
Loading