Skip to content

Commit

Permalink
Merge pull request #60 from civitaspo/develop
Browse files Browse the repository at this point in the history
v0.2.2
  • Loading branch information
civitaspo authored Jul 19, 2019
2 parents 30be2c0 + 07889e4 commit 6ac81c0
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 37 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
0.2.2 (2019-07-19)
==================

* [Enhancement] Use scala-logging for logging instead of using slf4j directly
* [Enhancement] Use workgroup default output location for athena query result output location.
* [Change - `athena.ctas>`] Introduce `location` option and `output` option become deprecated.


0.2.1 (2019-07-16)
==================

Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ _export:
repositories:
- https://jitpack.io
dependencies:
- pro.civitaspo:digdag-operator-athena:0.2.1
- pro.civitaspo:digdag-operator-athena:0.2.2
athena:
auth_method: profile

Expand Down Expand Up @@ -200,7 +200,8 @@ Nothing
- **database**: The database name for query execution context. (string, optional)
- **table**: The table name for the new table (string, default: `digdag_athena_ctas_${session_uuid.replaceAll("-", "")}_${random}`)
- **workgroup**: The name of the workgroup in which the query is being started. (string, optional)
- **output**: Output location for data created by CTAS (string, default: `"s3://aws-athena-query-results-${AWS_ACCOUNT_ID}-<AWS_REGION>/Unsaved/${YEAR}/${MONTH}/${DAY}/${athena_query_id}/"`)
- **output**: [**Deprecated**] Use **location** option instead.
- **location**: Output location for data created by CTAS (string, default: `"s3://aws-athena-query-results-${AWS_ACCOUNT_ID}-<AWS_REGION>/Unsaved/${YEAR}/${MONTH}/${DAY}/${athena_query_id}/"`)
- **format**: The data format for the CTAS query results, such as `"orc"`, `"parquet"`, `"avro"`, `"json"`, or `"textfile"`. (string, default: `"parquet"`)
- **compression**: The compression type to use for `"orc"` or `"parquet"`. (string, default: `"snappy"`)
- **field_delimiter**: The field delimiter for files in CSV, TSV, and text files. This option is applied only when **format** is specific to text-based data storage formats. (string, optional)
Expand Down
5 changes: 4 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ plugins {
}

group = 'pro.civitaspo'
version = '0.2.1'
version = '0.2.2'

def digdagVersion = '0.9.37'
def awsSdkVersion = "1.11.587"
Expand Down Expand Up @@ -34,6 +34,9 @@ dependencies {
compile group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: awsSdkVersion
// https://mvnrepository.com/artifact/com.amazonaws/aws-java-sdk-glue
compile group: 'com.amazonaws', name: 'aws-java-sdk-glue', version: awsSdkVersion
// https://mvnrepository.com/artifact/com.typesafe.scala-logging/scala-logging
compile group: 'com.typesafe.scala-logging', name: "scala-logging_$depScalaVersion", version: '3.9.2'

}

shadowJar {
Expand Down
6 changes: 3 additions & 3 deletions example/example.dig
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ _export:
- file://${repos}
# - https://jitpack.io
dependencies:
- pro.civitaspo:digdag-operator-athena:0.2.1
- pro.civitaspo:digdag-operator-athena:0.2.2
athena:
auth_method: profile
value: 5
Expand All @@ -22,7 +22,7 @@ _export:
athena.ctas>: template.sql
database: ${database}
table: hoge
output: ${output}
location: ${output}

+step5:
echo>: ${athena}
Expand All @@ -37,7 +37,7 @@ _export:
athena.ctas>: select 1 as a, 2 as b, 3 as c union all select 4 as a, 5 as b, 6 as c
database: ${database}
table: hoge
output: ${output}
location: ${output}
partitioned_by: [b, c]

+step8:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package pro.civitaspo.digdag.plugin.athena


import com.typesafe.scalalogging.LazyLogging
import io.digdag.client.config.{Config, ConfigFactory}
import io.digdag.spi.{OperatorContext, SecretProvider, TemplateEngine}
import io.digdag.util.{BaseOperator, DurationParam}
import org.slf4j.{Logger, LoggerFactory}
import pro.civitaspo.digdag.plugin.athena.aws.{Aws, AwsConf}


abstract class AbstractAthenaOperator(operatorName: String,
context: OperatorContext,
systemConfig: Config,
templateEngine: TemplateEngine)
extends BaseOperator(context)
with LazyLogging
{

protected val logger: Logger = LoggerFactory.getLogger(operatorName)
if (!logger.isDebugEnabled) {
if (!logger.underlying.isDebugEnabled) {
// NOTE: suppress aws-java-sdk logs because of a bit noisy logging.
System.setProperty("org.apache.commons.logging.Log", "org.apache.commons.logging.impl.NoOpLog")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package pro.civitaspo.digdag.plugin.athena.aws


import com.typesafe.scalalogging.LazyLogging


abstract class AwsService(aws: Aws)
extends LazyLogging
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package pro.civitaspo.digdag.plugin.athena.aws.athena


import com.amazonaws.services.athena.{AmazonAthena, AmazonAthenaClientBuilder}
import com.amazonaws.services.athena.model.{GetQueryExecutionRequest, GetQueryResultsRequest, QueryExecution, QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultSet, StartQueryExecutionRequest}
import com.amazonaws.services.athena.model.{GetQueryExecutionRequest, GetQueryResultsRequest, GetWorkGroupRequest, QueryExecution, QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultSet, StartQueryExecutionRequest}
import io.digdag.util.DurationParam
import org.slf4j.Logger
import pro.civitaspo.digdag.plugin.athena.aws.{Aws, AwsService}

import scala.util.{Failure, Success, Try}
import scala.util.chaining._


case class Athena(aws: Aws)
extends AwsService(aws)
Expand Down Expand Up @@ -35,13 +37,34 @@ case class Athena(aws: Aws)
req.setQueryString(query)
database.foreach(db => req.setQueryExecutionContext(new QueryExecutionContext().withDatabase(db)))
req.setWorkGroup(workGroup.getOrElse(DEFAULT_WORKGROUP))
// TODO: overwrite by workgroup configurations if workgroup is not "primary".
req.setResultConfiguration(new ResultConfiguration().withOutputLocation(outputLocation.getOrElse(DEFAULT_OUTPUT_LOCATION)))
req.setResultConfiguration(new ResultConfiguration()
.withOutputLocation(resolveWorkGroupOutputLocation(workGroup.getOrElse(DEFAULT_WORKGROUP))))
requestToken.foreach(req.setClientRequestToken)

withAthena(_.startQueryExecution(req)).getQueryExecutionId
}

def resolveWorkGroupOutputLocation(workGroup: String): String =
{
workGroup match {
case DEFAULT_WORKGROUP => DEFAULT_OUTPUT_LOCATION
case wg =>
val t = Try {
withAthena(_.getWorkGroup(new GetWorkGroupRequest().withWorkGroup(wg)))
.getWorkGroup
.getConfiguration
.getResultConfiguration
.getOutputLocation
}
t match {
case Success(outputLocation) => outputLocation
case Failure(ex) => DEFAULT_OUTPUT_LOCATION.tap { default =>
logger.warn(s"Use $default as athena output location because the workgroup output location cannot be resolved due to '${ex.getMessage}'.", ex)
}
}
}
}

def getQueryExecution(executionId: String): QueryExecution =
{
withAthena(_.getQueryExecution(new GetQueryExecutionRequest().withQueryExecutionId(executionId))).getQueryExecution
Expand All @@ -50,14 +73,12 @@ case class Athena(aws: Aws)
def waitQueryExecution(executionId: String,
successStates: Seq[QueryExecutionState],
failureStates: Seq[QueryExecutionState],
timeout: DurationParam,
loggerOption: Option[Logger] = None): Unit =
timeout: DurationParam): Unit =
{
val waiter = AthenaQueryWaiter(athena = this,
successStats = successStates,
failureStats = failureStates,
timeout = timeout,
loggerOption = loggerOption)
timeout = timeout)
waiter.wait(executionId)
}

Expand All @@ -68,8 +89,7 @@ case class Athena(aws: Aws)
requestToken: Option[String] = None,
successStates: Seq[QueryExecutionState],
failureStates: Seq[QueryExecutionState],
timeout: DurationParam,
loggerOption: Option[Logger] = None): QueryExecution =
timeout: DurationParam): QueryExecution =
{
val executionId: String = startQueryExecution(query = query,
database = database,
Expand All @@ -80,8 +100,7 @@ case class Athena(aws: Aws)
waitQueryExecution(executionId = executionId,
successStates = successStates,
failureStates = failureStates,
timeout = timeout,
loggerOption = loggerOption)
timeout = timeout)

getQueryExecution(executionId = executionId)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@ import java.util.concurrent.{Executors, ExecutorService}
import com.amazonaws.services.athena.model.{GetQueryExecutionRequest, GetQueryExecutionResult, QueryExecutionState}
import com.amazonaws.waiters.{PollingStrategy, PollingStrategyContext, SdkFunction, Waiter, WaiterAcceptor, WaiterBuilder, WaiterParameters, WaiterState}
import com.amazonaws.waiters.PollingStrategy.{DelayStrategy, RetryStrategy}
import com.typesafe.scalalogging.LazyLogging
import io.digdag.util.DurationParam
import org.slf4j.{Logger, LoggerFactory}


case class AthenaQueryWaiter(athena: Athena,
successStats: Seq[QueryExecutionState],
failureStats: Seq[QueryExecutionState],
executorService: ExecutorService = Executors.newFixedThreadPool(50),
timeout: DurationParam,
loggerOption: Option[Logger] = None)
timeout: DurationParam)
extends LazyLogging
{
val logger: Logger = loggerOption.getOrElse(LoggerFactory.getLogger(classOf[AthenaQueryWaiter]))

def wait(executionId: String): Unit =
{
newWaiter().run(newWaiterParameters(executionId = executionId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.amazonaws.services.securitytoken.{AWSSecurityTokenService, AWSSecurit
import com.amazonaws.services.securitytoken.model.{AssumeRoleRequest, GetCallerIdentityRequest, PolicyDescriptorType}
import pro.civitaspo.digdag.plugin.athena.aws.{Aws, AwsService}

import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters._


case class Sts(aws: Aws)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import pro.civitaspo.digdag.plugin.athena.AbstractAthenaOperator
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}


class AthenaCtasOperator(operatorName: String,
context: OperatorContext,
systemConfig: Config,
Expand Down Expand Up @@ -85,7 +86,20 @@ class AthenaCtasOperator(operatorName: String,
protected val database: Optional[String] = params.getOptional("database", classOf[String])
protected val table: String = params.get("table", classOf[String], defaultTableName)
protected val workGroup: Optional[String] = params.getOptional("workgroup", classOf[String])
@deprecated(message = "Use location option instead", since = "0.2.2")
protected val output: Optional[String] = params.getOptional("output", classOf[String])
protected val location: Optional[String] = {
val l = params.getOptional("location", classOf[String])
if (output.isPresent && l.isPresent) {
logger.warn(s"Use the value of location option: ${l.get()} although the value of output option (${output.get()}) is specified.")
l
}
else if (output.isPresent) {
logger.warn("output option is deprecated. Please use location option instead.")
output
}
else l
}
protected val format: String = params.get("format", classOf[String], "parquet")
protected val compression: String = params.get("compression", classOf[String], "snappy")
protected val fieldDelimiter: Optional[String] = params.getOptional("field_delimiter", classOf[String])
Expand Down Expand Up @@ -135,15 +149,15 @@ class AthenaCtasOperator(operatorName: String,
override def runTask(): TaskResult =
{
saveMode match {
case SaveMode.ErrorIfExists if output.isPresent && hasObjects(output.get) =>
throw new IllegalStateException(s"${output.get} already exists")
case SaveMode.Ignore if output.isPresent && hasObjects(output.get) =>
logger.info(s"${output.get} already exists, so ignore this session.")
case SaveMode.ErrorIfExists if location.isPresent && hasObjects(location.get) =>
throw new IllegalStateException(s"${location.get} already exists")
case SaveMode.Ignore if location.isPresent && hasObjects(location.get) =>
logger.info(s"${location.get} already exists, so ignore this session.")
return TaskResult.empty(request)
case SaveMode.Overwrite if output.isPresent =>
logger.info(s"Overwrite ${output.get}")
rmObjects(output.get)
case _ => // do nothing
case SaveMode.Overwrite if location.isPresent =>
logger.info(s"Overwrite ${location.get}")
rmObjects(location.get)
case _ => // do nothing
}

val subTask: Config = cf.create()
Expand All @@ -169,7 +183,7 @@ class AthenaCtasOperator(operatorName: String,
protected def generateCtasQuery(): String =
{
val propsBuilder = Map.newBuilder[String, String]
if (output.isPresent) propsBuilder += ("external_location" -> s"'${output.get}'")
if (location.isPresent) propsBuilder += ("external_location" -> s"'${location.get}'")
propsBuilder += ("format" -> s"'$format'")
format match {
case "parquet" => propsBuilder += ("parquet_compression" -> s"'$compression'")
Expand Down

0 comments on commit 6ac81c0

Please sign in to comment.