Switch RDD bindings to using inline-java under the hood. #103

Merged: 2 commits, Apr 11, 2017
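The change follows one pattern throughout: each explicit call/callStatic invocation with coerced arguments becomes an inline-java quasiquote, which antiquotes Haskell variables ($conf, $rdd, $n, ...) directly into a Java expression that is checked and compiled at build time. A minimal sketch of the idiom, modeled on the inline-java README (a standalone example, not part of this PR; the exact re-exports varied across early inline-java versions):

    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE QuasiQuotes #-}
    {-# LANGUAGE ScopedTypeVariables #-}

    import qualified Data.Text as Text
    import Data.Text (Text)
    import Language.Java (withJVM, reflect, reify)
    import Language.Java.Inline (java)

    main :: IO ()
    main = withJVM [] $ do
      -- reflect marshals a Haskell value to a Java object...
      jname <- reflect (Text.pack "world")
      -- ...which the quasiquote antiquotes into a real Java expression.
      greeting :: Text <- reify =<< [java| "Hello, ".concat($jname) |]
      print greeting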
15 changes: 2 additions & 13 deletions Setup.hs
@@ -1,15 +1,4 @@
 import Distribution.Simple
-import System.Process
-import System.Exit
+import Language.Java.Inline.Cabal
 
-main = defaultMainWithHooks simpleUserHooks { postBuild = buildJavaSource }
-
-buildJavaSource _ _ _ _ = do
-    executeShellCommand "gradle build"
-    return ()
-
-executeShellCommand cmd = system cmd >>= check
-  where
-    check ExitSuccess = return ()
-    check (ExitFailure n) =
-      error $ "Command " ++ cmd ++ " exited with failure code " ++ show n
+main = defaultMainWithHooks (gradleHooks simpleUserHooks)
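Assembled from the hunk above, the entire Setup.hs now reads:

    import Distribution.Simple
    import Language.Java.Inline.Cabal

    -- gradleHooks adds a Gradle build step to the standard Cabal hooks,
    -- subsuming the hand-rolled "gradle build" postBuild hook deleted above.
    main = defaultMainWithHooks (gradleHooks simpleUserHooks)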
5 changes: 3 additions & 2 deletions sparkle.cabal
@@ -37,7 +37,7 @@ custom-setup
   setup-depends:
     base,
     Cabal >= 1.24,
-    process
+    inline-java >= 0.6.3
 
 library
   include-dirs: cbits
@@ -67,8 +67,9 @@ library
     bytestring >=0.10,
     choice >= 0.1,
     distributed-closure >=0.3,
+    inline-java >= 0.6.3,
     jni >=0.3.0,
-    jvm >=0.2.0,
+    jvm >=0.2.1,
     singletons >= 2.0,
     streaming >= 0.1,
     text >=1.2,
41 changes: 19 additions & 22 deletions src/Control/Distributed/Spark/Context.hs
@@ -7,6 +7,7 @@
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 
 module Control.Distributed.Spark.Context
@@ -33,21 +34,22 @@ import qualified Data.Text as Text
 import Data.Text (Text)
 import Control.Distributed.Spark.RDD
 import Language.Java
+import Language.Java.Inline
 
 newtype SparkConf = SparkConf (J ('Class "org.apache.spark.SparkConf"))
 instance Coercible SparkConf ('Class "org.apache.spark.SparkConf")
 
 newSparkConf :: Text -> IO SparkConf
 newSparkConf appname = do
   jname <- reflect appname
-  cnf :: SparkConf <- new []
-  call cnf "setAppName" [coerce jname]
+  conf :: SparkConf <- new []
+  [java| $conf.setAppName($jname) |]
 
 confSet :: SparkConf -> Text -> Text -> IO ()
 confSet conf key value = do
   jkey <- reflect key
   jval <- reflect value
-  _ :: SparkConf <- call conf "set" [coerce jkey, coerce jval]
+  _ :: SparkConf <- [java| $conf.set($jkey, $jval) |]
   return ()
 
 newtype SparkContext = SparkContext (J ('Class "org.apache.spark.api.java.JavaSparkContext"))
@@ -57,57 +59,52 @@ newSparkContext :: SparkConf -> IO SparkContext
 newSparkContext conf = new [coerce conf]
 
 getOrCreateSparkContext :: SparkConf -> IO SparkContext
-getOrCreateSparkContext cnf = do
+getOrCreateSparkContext conf = do
   scalaCtx :: J ('Class "org.apache.spark.SparkContext") <-
-    callStatic (sing :: Sing "org.apache.spark.SparkContext") "getOrCreate" [coerce cnf]
-
-  callStatic (sing :: Sing "org.apache.spark.api.java.JavaSparkContext") "fromSparkContext" [coerce scalaCtx]
+    [java| org.apache.spark.SparkContext.getOrCreate($conf) |]
+  [java| org.apache.spark.api.java.JavaSparkContext.fromSparkContext($scalaCtx) |]
 
 -- | Adds the given file to the pool of files to be downloaded
 -- on every worker node. Use 'getFile' on those nodes to
 -- get the (local) file path of that file in order to read it.
 addFile :: SparkContext -> FilePath -> IO ()
 addFile sc fp = do
   jfp <- reflect (Text.pack fp)
-  call sc "addFile" [coerce jfp]
+  -- XXX workaround for inline-java-0.6 not supporting void return types.
+  _ :: JObject <- [java| { $sc.addFile($jfp); return null; } |]
+  return ()
 
 -- | Returns the local filepath of the given filename that
 -- was "registered" using 'addFile'.
 getFile :: FilePath -> IO FilePath
 getFile filename = do
   jfilename <- reflect (Text.pack filename)
-  fmap Text.unpack . reify =<< callStatic (sing :: Sing "org.apache.spark.SparkFiles") "get" [coerce jfilename]
+  fmap Text.unpack . reify =<<
+    [java| org.apache.spark.SparkFiles.get($jfilename) |]
 
 master :: SparkContext -> IO Text
-master sc = do
-  res <- call sc "master" []
-  reify res
+master sc = reify =<< [java| $sc.master() |]
 
 -- | See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 textFile :: SparkContext -> FilePath -> IO (RDD Text)
 textFile sc path = do
   jpath <- reflect (Text.pack path)
-  call sc "textFile" [coerce jpath]
+  [java| $sc.textFile($jpath) |]
 
 -- | The record length must be provided in bytes.
 --
 -- See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 binaryRecords :: SparkContext -> FilePath -> Int32 -> IO (RDD ByteString)
 binaryRecords sc fp recordLength = do
   jpath <- reflect (Text.pack fp)
-  call sc "binaryRecords" [coerce jpath, coerce recordLength]
+  [java| $sc.binaryRecords($jpath, $recordLength) |]
 
 parallelize
   :: Reflect a ty
   => SparkContext
   -> [a]
   -> IO (RDD a)
 parallelize sc xs = do
-    jxs :: J ('Iface "java.util.List") <- arrayToList =<< reflect xs
-    call sc "parallelize" [coerce jxs]
-  where
-    arrayToList jxs =
-      callStatic
-        (sing :: Sing "java.util.Arrays")
-        "asList"
-        [coerce (unsafeCast jxs :: JObjectArray)]
+  jxs :: J ('Array ('Class "java.lang.Object")) <- unsafeCast <$> reflect xs
+  jlist :: J ('Iface "java.util.List") <- [java| java.util.Arrays.asList($jxs) |]
+  [java| $sc.parallelize($jlist) |]
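Driver-side usage is unchanged by the switch. A hypothetical fragment (the app name, master setting and input path are made up for illustration; a real sparkle app is still packaged and launched via spark-submit):

    {-# LANGUAGE OverloadedStrings #-}

    import Control.Distributed.Spark.Context

    driver :: IO ()
    driver = do
      conf <- newSparkConf "hello-sparkle"      -- illustrative app name
      confSet conf "spark.master" "local[2]"    -- run locally on two cores
      sc <- getOrCreateSparkContext conf
      _rdd <- textFile sc "input.txt"           -- made-up input path
      master sc >>= print                       -- e.g. prints "local[2]"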
94 changes: 47 additions & 47 deletions src/Control/Distributed/Spark/RDD.hs
@@ -9,6 +9,7 @@
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE StaticPointers #-}
 
@@ -40,13 +41,14 @@ module Control.Distributed.Spark.RDD
 
 import Prelude hiding (filter, map, subtract, take)
 import Control.Distributed.Closure
-import Control.Distributed.Spark.Closure ()
+import Control.Distributed.Spark.Closure (JFun1, JFun2)
 import Data.Choice (Choice)
 import qualified Data.Choice as Choice
 import Data.Int
 import qualified Data.Text as Text
 import Data.Typeable (Typeable)
 import Language.Java
+import Language.Java.Inline
 -- We don't need this instance. But import to bring it in scope transitively for users.
 #if MIN_VERSION_base(4,9,1)
 import Language.Java.Streaming ()
@@ -57,25 +59,25 @@ newtype RDD a = RDD (J ('Class "org.apache.spark.api.java.JavaRDD"))
 instance Coercible (RDD a) ('Class "org.apache.spark.api.java.JavaRDD")
 
 repartition :: Int32 -> RDD a -> IO (RDD a)
-repartition nbPart rdd = call rdd "repartition" [JInt nbPart]
+repartition n rdd = [java| $rdd.repartition($n) |]
 
 filter
   :: Reflect (Closure (a -> Bool)) ty
   => Closure (a -> Bool)
   -> RDD a
   -> IO (RDD a)
 filter clos rdd = do
-  f <- reflect clos
-  call rdd "filter" [coerce f]
+  f <- unsafeUngeneric <$> reflect clos
+  [java| $rdd.filter($f) |]
 
 map
-  :: Reflect (Closure (a -> b)) ty
+  :: Reflect (Closure (a -> b)) (JFun1 ty1 ty2)
   => Closure (a -> b)
   -> RDD a
   -> IO (RDD b)
 map clos rdd = do
-  f <- reflect clos
-  call rdd "map" [coerce f]
+  f <- unsafeUngeneric <$> reflect clos
+  [java| $rdd.map($f) |]
 
 mapPartitions
   :: (Reflect (Closure (Int32 -> Stream (Of a) IO () -> Stream (Of b) IO ())) ty, Typeable a, Typeable b)
@@ -93,54 +95,54 @@ mapPartitionsWithIndex
   -> RDD a
   -> IO (RDD b)
 mapPartitionsWithIndex preservePartitions clos rdd = do
-  f <- reflect clos
-  call rdd "mapPartitionsWithIndex" [coerce f, coerce (Choice.toBool preservePartitions)]
+  f <- unsafeUngeneric <$> reflect clos
+  [java| $rdd.mapPartitionsWithIndex($f, $preservePartitions) |]
 
 fold
-  :: (Reflect (Closure (a -> a -> a)) ty1, Reflect a ty2, Reify a ty2)
+  :: (Reflect (Closure (a -> a -> a)) (JFun2 ty ty ty), Reflect a ty, Reify a ty)
   => Closure (a -> a -> a)
   -> a
   -> RDD a
   -> IO a
 fold clos zero rdd = do
-  f <- reflect clos
+  f <- unsafeUngeneric <$> reflect clos
   jzero <- upcast <$> reflect zero
-  res :: JObject <- call rdd "fold" [coerce jzero, coerce f]
+  res :: JObject <- [java| $rdd.fold($jzero, $f) |]
   reify (unsafeCast res)
 
 reduce
-  :: (Reflect (Closure (a -> a -> a)) ty1, Reify a ty2, Reflect a ty2)
+  :: (Reflect (Closure (a -> a -> a)) (JFun2 ty ty ty), Reify a ty, Reflect a ty)
  => Closure (a -> a -> a)
   -> RDD a
   -> IO a
 reduce clos rdd = do
-  f <- reflect clos
-  res :: JObject <- call rdd "reduce" [coerce f]
+  f <- unsafeUngeneric <$> reflect clos
+  res :: JObject <- [java| $rdd.reduce($f) |]
   reify (unsafeCast res)
 
 aggregate
-  :: ( Reflect (Closure (b -> a -> b)) ty1
-     , Reflect (Closure (b -> b -> b)) ty2
-     , Reify b ty3
-     , Reflect b ty3
+  :: ( Reflect (Closure (b -> a -> b)) (JFun2 ty2 ty1 ty2)
+     , Reflect (Closure (b -> b -> b)) (JFun2 ty2 ty2 ty2)
+     , Reify b ty2
+     , Reflect b ty2
      )
   => Closure (b -> a -> b)
   -> Closure (b -> b -> b)
   -> b
   -> RDD a
   -> IO b
 aggregate seqOp combOp zero rdd = do
-  jseqOp <- reflect seqOp
-  jcombOp <- reflect combOp
+  jseqOp <- unsafeUngeneric <$> reflect seqOp
+  jcombOp <- unsafeUngeneric <$> reflect combOp
   jzero <- upcast <$> reflect zero
-  res :: JObject <- call rdd "aggregate" [coerce jzero, coerce jseqOp, coerce jcombOp]
+  res :: JObject <- [java| $rdd.aggregate($jzero, $jseqOp, $jcombOp) |]
   reify (unsafeCast res)
 
 treeAggregate
-  :: ( Reflect (Closure (b -> a -> b)) ty1
-     , Reflect (Closure (b -> b -> b)) ty2
-     , Reflect b ty3
-     , Reify b ty3
+  :: ( Reflect (Closure (b -> a -> b)) (JFun2 ty2 ty1 ty2)
+     , Reflect (Closure (b -> b -> b)) (JFun2 ty2 ty2 ty2)
+     , Reflect b ty2
+     , Reify b ty2
     )
   => Closure (b -> a -> b)
   -> Closure (b -> b -> b)
@@ -149,20 +151,17 @@ treeAggregate
   -> RDD a
   -> IO b
 treeAggregate seqOp combOp zero depth rdd = do
-  jseqOp <- reflect seqOp
-  jcombOp <- reflect combOp
+  jseqOp <- unsafeUngeneric <$> reflect seqOp
+  jcombOp <- unsafeUngeneric <$> reflect combOp
   jzero <- upcast <$> reflect zero
-  let jdepth = coerce depth
-  res :: JObject <-
-    call rdd "treeAggregate"
-      [ coerce jseqOp, coerce jcombOp, coerce jzero, jdepth ]
+  res :: JObject <- [java| $rdd.treeAggregate($jzero, $jseqOp, $jcombOp, $depth) |]
   reify (unsafeCast res)
 
 count :: RDD a -> IO Int64
-count rdd = call rdd "count" []
+count rdd = [java| $rdd.count() |]
 
 subtract :: RDD a -> RDD a -> IO (RDD a)
-subtract rdd rdds = call rdd "subtract" [coerce rdds]
+subtract rdd1 rdd2 = [java| $rdd1.subtract($rdd2) |]
 
 -- $reading_files
 --
@@ -183,43 +182,44 @@ subtract rdd rdds = call rdd "subtract" [coerce rdds]
 -- | See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 collect :: Reify a ty => RDD a -> IO [a]
 collect rdd = do
-  alst :: J ('Iface "java.util.List") <- call rdd "collect" []
-  arr :: JObjectArray <- call alst "toArray" []
+  res :: J ('Iface "java.util.List") <- [java| $rdd.collect() |]
+  arr :: JObjectArray <- [java| $res.toArray() |]
   reify (unsafeCast arr)
 
 -- | See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 take :: Reify a ty => RDD a -> Int32 -> IO [a]
 take rdd n = do
-  res :: J ('Class "java.util.List") <- call rdd "take" [JInt n]
-  arr :: JObjectArray <- call res "toArray" []
+  res :: J ('Class "java.util.List") <- [java| $rdd.take($n) |]
+  arr :: JObjectArray <- [java| $res.toArray() |]
   reify (unsafeCast arr)
 
 distinct :: RDD a -> IO (RDD a)
-distinct r = call r "distinct" []
+distinct rdd = [java| $rdd.distinct() |]
 
 intersection :: RDD a -> RDD a -> IO (RDD a)
-intersection r r' = call r "intersection" [coerce r']
+intersection rdd1 rdd2 = [java| $rdd1.intersection($rdd2) |]
 
 union :: RDD a -> RDD a -> IO (RDD a)
-union r r' = call r "union" [coerce r']
+union rdd1 rdd2 = [java| $rdd1.union($rdd2) |]
 
 sample
   :: RDD a
-  -> Choice "replacement" -- ^ sample with replacement
+  -> Choice "replacement" -- ^ Whether to sample with replacement
   -> Double -- ^ fraction of elements to keep
   -> IO (RDD a)
-sample rdd replacement frac = do
-    call rdd "sample" [jvalue (Choice.toBool replacement), jvalue frac]
+sample rdd replacement frac = [java| $rdd.sample($replacement, $frac) |]
 
 first :: Reify a ty => RDD a -> IO a
 first rdd = do
-  res :: JObject <- call rdd "first" []
+  res :: JObject <- [java| $rdd.first() |]
   reify (unsafeCast res)
 
 getNumPartitions :: RDD a -> IO Int32
-getNumPartitions rdd = call rdd "getNumPartitions" []
+getNumPartitions rdd = [java| $rdd.getNumPartitions() |]
 
 saveAsTextFile :: RDD a -> FilePath -> IO ()
 saveAsTextFile rdd fp = do
   jfp <- reflect (Text.pack fp)
-  call rdd "saveAsTextFile" [coerce jfp]
+  -- XXX workaround for inline-java-0.6 not supporting void return types.
+  _ :: JObject <- [java| { $rdd.saveAsTextFile($jfp); return null; } |]
+  return ()
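Call sites for the RDD combinators are likewise untouched; closures still travel through Reflect, now via the JFun1/JFun2 instances. A hypothetical fragment in the style of sparkle's hello demo (input path and length threshold are made up):

    {-# LANGUAGE StaticPointers #-}

    import Control.Distributed.Closure (closure)
    import Control.Distributed.Spark.Context (SparkContext, textFile)
    import qualified Control.Distributed.Spark.RDD as RDD
    import qualified Data.Text as Text
    import Data.Int (Int64)

    -- Count the lines longer than 80 characters in a (made-up) input file.
    countLongLines :: SparkContext -> IO Int64
    countLongLines sc = do
      rdd <- textFile sc "input.txt"
      long <- RDD.filter (closure (static (\l -> Text.length l > 80))) rdd
      RDD.count long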
2 changes: 1 addition & 1 deletion stack.yaml
@@ -13,7 +13,7 @@ packages:
 
 extra-deps:
 - jni-0.3.0
-- jvm-0.2.0
+- jvm-0.2.1
 - jvm-streaming-0.2
 - inline-java-0.6.3
 