Switch RDD bindings to using inline-java under the hood.
Using inline-java slows down the compilation of sparkle, but is safer
because we thereby get the benefit of *both* type checkers (Java and
Haskell). In fact, the extra safety isn't just theoretical: this patch
also includes a fix to the binding for `treeAggregate`, which was
supplying its arguments in the wrong order.

This is preliminary work ahead of implementing #57, which we can do
with confidence now that the type checkers have our back.

This patch only switches over RDD for now. The rest can come later.
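To illustrate the benefit, here is a minimal inline-java sketch, independent
of sparkle (it assumes inline-java ~0.6 and a JVM available at run time;
build setup such as the gradle hooks is elided). The quasiquote body is real
Java, compiled by javac at Haskell build time, which is exactly the kind of
check that would have caught the `treeAggregate` argument-order bug:

    {-# LANGUAGE QuasiQuotes #-}
    module Main where

    import Data.Int (Int32)
    import Language.Java (withJVM)
    import Language.Java.Inline (java)

    -- The quasiquote body is type-checked by javac against the
    -- antiquoted variable $n, so an ill-typed Java expression fails
    -- the Haskell build rather than blowing up at run time.
    increment :: Int32 -> IO Int32
    increment n = [java| $n + 1 |]

    main :: IO ()
    main = withJVM [] (increment 41 >>= print)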
mboes committed Apr 9, 2017
1 parent 82c2dac commit 0ae7a03
Showing 4 changed files with 53 additions and 63 deletions.
15 changes: 2 additions & 13 deletions Setup.hs
@@ -1,15 +1,4 @@
 import Distribution.Simple
-import System.Process
-import System.Exit
+import Language.Java.Inline.Cabal
 
-main = defaultMainWithHooks simpleUserHooks { postBuild = buildJavaSource }
-
-buildJavaSource _ _ _ _ = do
-    executeShellCommand "gradle build"
-    return ()
-
-executeShellCommand cmd = system cmd >>= check
-  where
-    check ExitSuccess = return ()
-    check (ExitFailure n) =
-      error $ "Command " ++ cmd ++ " exited with failure code " ++ show n
+main = defaultMainWithHooks (gradleHooks simpleUserHooks)
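For context, `gradleHooks` comes from inline-java's `Language.Java.Inline.Cabal`.
A rough sketch of what it is presumed to do (an assumption for illustration,
not the actual implementation) is to hook `gradle build` into the Cabal build,
much like the hand-rolled `postBuild` hook deleted above:

    -- Hypothetical approximation of gradleHooks, for illustration only.
    import Distribution.Simple
    import System.Process (callCommand)

    gradleHooks :: UserHooks -> UserHooks
    gradleHooks hooks = hooks
      { postBuild = \args flags pd lbi -> do
          callCommand "gradle build"  -- throws on a non-zero exit code
          postBuild hooks args flags pd lbi
      }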
5 changes: 3 additions & 2 deletions sparkle.cabal
@@ -37,7 +37,7 @@ custom-setup
   setup-depends:
     base,
     Cabal >= 1.24,
-    process
+    inline-java >= 0.6.3
 
 library
   include-dirs: cbits
@@ -67,8 +67,9 @@ library
     bytestring >=0.10,
     choice >= 0.1,
     distributed-closure >=0.3,
+    inline-java >= 0.6.3,
     jni >=0.3.0,
-    jvm >=0.2.0,
+    jvm >=0.2.1,
     singletons >= 2.0,
     streaming >= 0.1,
     text >=1.2,
94 changes: 47 additions & 47 deletions src/Control/Distributed/Spark/RDD.hs
Expand Up @@ -9,6 +9,7 @@
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE StaticPointers #-}

@@ -40,13 +41,14 @@ module Control.Distributed.Spark.RDD

 import Prelude hiding (filter, map, subtract, take)
 import Control.Distributed.Closure
-import Control.Distributed.Spark.Closure ()
+import Control.Distributed.Spark.Closure (JFun1, JFun2)
 import Data.Choice (Choice)
 import qualified Data.Choice as Choice
 import Data.Int
 import qualified Data.Text as Text
 import Data.Typeable (Typeable)
 import Language.Java
+import Language.Java.Inline
 -- We don't need this instance. But import to bring it in scope transitively for users.
 #if MIN_VERSION_base(4,9,1)
 import Language.Java.Streaming ()
@@ -57,25 +59,25 @@ newtype RDD a = RDD (J ('Class "org.apache.spark.api.java.JavaRDD"))
 instance Coercible (RDD a) ('Class "org.apache.spark.api.java.JavaRDD")
 
 repartition :: Int32 -> RDD a -> IO (RDD a)
-repartition nbPart rdd = call rdd "repartition" [JInt nbPart]
+repartition n rdd = [java| $rdd.repartition($n) |]
 
 filter
   :: Reflect (Closure (a -> Bool)) ty
   => Closure (a -> Bool)
   -> RDD a
   -> IO (RDD a)
 filter clos rdd = do
-    f <- reflect clos
-    call rdd "filter" [coerce f]
+    f <- unsafeUngeneric <$> reflect clos
+    [java| $rdd.filter($f) |]
 
 map
-  :: Reflect (Closure (a -> b)) ty
+  :: Reflect (Closure (a -> b)) (JFun1 ty1 ty2)
   => Closure (a -> b)
   -> RDD a
   -> IO (RDD b)
 map clos rdd = do
-    f <- reflect clos
-    call rdd "map" [coerce f]
+    f <- unsafeUngeneric <$> reflect clos
+    [java| $rdd.map($f) |]
 
 mapPartitions
   :: (Reflect (Closure (Int32 -> Stream (Of a) IO () -> Stream (Of b) IO ())) ty, Typeable a, Typeable b)
@@ -93,54 +95,54 @@ mapPartitionsWithIndex
   -> RDD a
   -> IO (RDD b)
 mapPartitionsWithIndex preservePartitions clos rdd = do
-    f <- reflect clos
-    call rdd "mapPartitionsWithIndex" [coerce f, coerce (Choice.toBool preservePartitions)]
+    f <- unsafeUngeneric <$> reflect clos
+    [java| $rdd.mapPartitionsWithIndex($f, $preservePartitions) |]
 
 fold
-  :: (Reflect (Closure (a -> a -> a)) ty1, Reflect a ty2, Reify a ty2)
+  :: (Reflect (Closure (a -> a -> a)) (JFun2 ty ty ty), Reflect a ty, Reify a ty)
   => Closure (a -> a -> a)
   -> a
   -> RDD a
   -> IO a
 fold clos zero rdd = do
-    f <- reflect clos
+    f <- unsafeUngeneric <$> reflect clos
     jzero <- upcast <$> reflect zero
-    res :: JObject <- call rdd "fold" [coerce jzero, coerce f]
+    res :: JObject <- [java| $rdd.fold($jzero, $f) |]
     reify (unsafeCast res)
 
 reduce
-  :: (Reflect (Closure (a -> a -> a)) ty1, Reify a ty2, Reflect a ty2)
+  :: (Reflect (Closure (a -> a -> a)) (JFun2 ty ty ty), Reify a ty, Reflect a ty)
   => Closure (a -> a -> a)
   -> RDD a
   -> IO a
 reduce clos rdd = do
-    f <- reflect clos
-    res :: JObject <- call rdd "reduce" [coerce f]
+    f <- unsafeUngeneric <$> reflect clos
+    res :: JObject <- [java| $rdd.reduce($f) |]
     reify (unsafeCast res)
 
 aggregate
-  :: ( Reflect (Closure (b -> a -> b)) ty1
-     , Reflect (Closure (b -> b -> b)) ty2
-     , Reify b ty3
-     , Reflect b ty3
+  :: ( Reflect (Closure (b -> a -> b)) (JFun2 ty2 ty1 ty2)
+     , Reflect (Closure (b -> b -> b)) (JFun2 ty2 ty2 ty2)
+     , Reify b ty2
+     , Reflect b ty2
      )
   => Closure (b -> a -> b)
   -> Closure (b -> b -> b)
   -> b
   -> RDD a
   -> IO b
 aggregate seqOp combOp zero rdd = do
-    jseqOp <- reflect seqOp
-    jcombOp <- reflect combOp
+    jseqOp <- unsafeUngeneric <$> reflect seqOp
+    jcombOp <- unsafeUngeneric <$> reflect combOp
     jzero <- upcast <$> reflect zero
-    res :: JObject <- call rdd "aggregate" [coerce jzero, coerce jseqOp, coerce jcombOp]
+    res :: JObject <- [java| $rdd.aggregate($jzero, $jseqOp, $jcombOp) |]
     reify (unsafeCast res)
 
 treeAggregate
-  :: ( Reflect (Closure (b -> a -> b)) ty1
-     , Reflect (Closure (b -> b -> b)) ty2
-     , Reflect b ty3
-     , Reify b ty3
+  :: ( Reflect (Closure (b -> a -> b)) (JFun2 ty2 ty1 ty2)
+     , Reflect (Closure (b -> b -> b)) (JFun2 ty2 ty2 ty2)
+     , Reflect b ty2
+     , Reify b ty2
     )
   => Closure (b -> a -> b)
   -> Closure (b -> b -> b)
@@ -149,20 +151,17 @@ treeAggregate
   -> RDD a
   -> IO b
 treeAggregate seqOp combOp zero depth rdd = do
-    jseqOp <- reflect seqOp
-    jcombOp <- reflect combOp
+    jseqOp <- unsafeUngeneric <$> reflect seqOp
+    jcombOp <- unsafeUngeneric <$> reflect combOp
     jzero <- upcast <$> reflect zero
-    let jdepth = coerce depth
-    res :: JObject <-
-      call rdd "treeAggregate"
-        [ coerce jseqOp, coerce jcombOp, coerce jzero, jdepth ]
+    res :: JObject <- [java| $rdd.treeAggregate($jzero, $jseqOp, $jcombOp, $depth) |]
     reify (unsafeCast res)
 
 count :: RDD a -> IO Int64
-count rdd = call rdd "count" []
+count rdd = [java| $rdd.count() |]
 
 subtract :: RDD a -> RDD a -> IO (RDD a)
-subtract rdd rdds = call rdd "subtract" [coerce rdds]
+subtract rdd1 rdd2 = [java| $rdd1.subtract($rdd2) |]
 
 -- $reading_files
 --
@@ -183,43 +182,44 @@ subtract rdd rdds = call rdd "subtract" [coerce rdds]
 -- | See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 collect :: Reify a ty => RDD a -> IO [a]
 collect rdd = do
-    alst :: J ('Iface "java.util.List") <- call rdd "collect" []
-    arr :: JObjectArray <- call alst "toArray" []
+    res :: J ('Iface "java.util.List") <- [java| $rdd.collect() |]
+    arr :: JObjectArray <- [java| $res.toArray() |]
     reify (unsafeCast arr)
 
 -- | See Note [Reading Files] ("Control.Distributed.Spark.RDD#reading_files").
 take :: Reify a ty => RDD a -> Int32 -> IO [a]
 take rdd n = do
-    res :: J ('Class "java.util.List") <- call rdd "take" [JInt n]
-    arr :: JObjectArray <- call res "toArray" []
+    res :: J ('Class "java.util.List") <- [java| $rdd.take($n) |]
+    arr :: JObjectArray <- [java| $res.toArray() |]
     reify (unsafeCast arr)
 
 distinct :: RDD a -> IO (RDD a)
-distinct r = call r "distinct" []
+distinct rdd = [java| $rdd.distinct() |]
 
 intersection :: RDD a -> RDD a -> IO (RDD a)
-intersection r r' = call r "intersection" [coerce r']
+intersection rdd1 rdd2 = [java| $rdd1.intersection($rdd2) |]
 
 union :: RDD a -> RDD a -> IO (RDD a)
-union r r' = call r "union" [coerce r']
+union rdd1 rdd2 = [java| $rdd1.union($rdd2) |]
 
 sample
   :: RDD a
-  -> Choice "replacement" -- ^ sample with replacement
+  -> Choice "replacement" -- ^ Whether to sample with replacement
   -> Double -- ^ fraction of elements to keep
   -> IO (RDD a)
-sample rdd replacement frac = do
-    call rdd "sample" [jvalue (Choice.toBool replacement), jvalue frac]
+sample rdd replacement frac = [java| $rdd.sample($replacement, $frac) |]
 
 first :: Reify a ty => RDD a -> IO a
 first rdd = do
-    res :: JObject <- call rdd "first" []
+    res :: JObject <- [java| $rdd.first() |]
     reify (unsafeCast res)
 
 getNumPartitions :: RDD a -> IO Int32
-getNumPartitions rdd = call rdd "getNumPartitions" []
+getNumPartitions rdd = [java| $rdd.getNumPartitions() |]
 
 saveAsTextFile :: RDD a -> FilePath -> IO ()
 saveAsTextFile rdd fp = do
     jfp <- reflect (Text.pack fp)
-    call rdd "saveAsTextFile" [coerce jfp]
+    -- XXX workaround for inline-java-0.6 not supporting void return types.
+    _ :: JObject <- [java| { $rdd.saveAsTextFile($jfp); return null; } |]
+    return ()
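Call sites are unchanged by this patch; only the implementations moved into
quasiquotes. A hedged usage sketch (`SparkContext` and `parallelize` are
assumed from the rest of sparkle and are not part of this diff):

    {-# LANGUAGE StaticPointers #-}

    import Control.Distributed.Closure (closure)
    import Control.Distributed.Spark (SparkContext, parallelize)  -- assumed exports
    import qualified Control.Distributed.Spark.RDD as RDD
    import Data.Int (Int32)

    -- Double every element of a small RDD and fetch the results back
    -- to the driver, exercising the rewritten map/collect bindings.
    doubled :: SparkContext -> IO [Int32]
    doubled sc = do
      rdd  <- parallelize sc [1 .. 10 :: Int32]
      rdd' <- RDD.map (closure (static ((* 2) :: Int32 -> Int32))) rdd
      RDD.collect rdd'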
2 changes: 1 addition & 1 deletion stack.yaml
@@ -13,7 +13,7 @@ packages:

 extra-deps:
 - jni-0.3.0
-- jvm-0.2.0
+- jvm-0.2.1
 - jvm-streaming-0.2
 - inline-java-0.6.3

