spark-tensorflow-connector

This repo contains a library for loading and storing TensorFlow records with Apache Spark. The library implements data import from the standard TensorFlow record format (TFRecords, see https://www.tensorflow.org/how_tos/reading_data/) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.
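For orientation, a TFRecord file is a sequence of length-prefixed, checksummed records. In TensorFlow, each record's payload is usually a serialized tf.train.Example protocol buffer. The sketch below illustrates only the record framing, assuming the commonly documented layout: an 8-byte little-endian length, a masked CRC32C of the length, the payload bytes, and a masked CRC32C of the payload. It is not this library's code. The object name TfRecordFraming is hypothetical, and java.util.zip.CRC32C requires Java 9 or later.

```scala
import java.nio.{ByteBuffer, ByteOrder}
import java.util.zip.CRC32C

// Hypothetical helper illustrating TFRecord framing; not part of spark-tensorflow-connector.
object TfRecordFraming {
  // Masking constant, assumed from the TFRecord format description.
  private val MaskDelta = 0xa282ead8L

  // CRC32C of the bytes, masked as assumed for TFRecord:
  // rotate the 32-bit CRC right by 15 bits, then add MaskDelta (mod 2^32).
  def maskedCrc32c(bytes: Array[Byte]): Long = {
    val crc = new CRC32C()
    crc.update(bytes, 0, bytes.length)
    val c = crc.getValue // unsigned 32-bit value stored in a Long
    val rotated = ((c >>> 15) | (c << 17)) & 0xffffffffL
    (rotated + MaskDelta) & 0xffffffffL
  }

  // Frames one payload: length (uint64 LE), masked CRC of the length (uint32 LE),
  // payload bytes, masked CRC of the payload (uint32 LE).
  def encode(data: Array[Byte]): Array[Byte] = {
    val lenBytes = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
      .putLong(data.length.toLong).array()
    val buf = ByteBuffer.allocate(8 + 4 + data.length + 4).order(ByteOrder.LITTLE_ENDIAN)
    buf.put(lenBytes)
    buf.putInt(maskedCrc32c(lenBytes).toInt)
    buf.put(data)
    buf.putInt(maskedCrc32c(data).toInt)
    buf.array()
  }

  // Reads one framed record back; throws IllegalArgumentException on a CRC mismatch.
  def decode(record: Array[Byte]): Array[Byte] = {
    val buf = ByteBuffer.wrap(record).order(ByteOrder.LITTLE_ENDIAN)
    val lenBytes = new Array[Byte](8)
    buf.get(lenBytes)
    require((buf.getInt & 0xffffffffL) == maskedCrc32c(lenBytes), "length CRC mismatch")
    val length = ByteBuffer.wrap(lenBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
    val data = new Array[Byte](length.toInt)
    buf.get(data)
    require((buf.getInt & 0xffffffffL) == maskedCrc32c(data), "payload CRC mismatch")
    data
  }
}
```

Checksumming the length separately presumably lets a reader reject a corrupt length before it tries to read the payload. A reader for a whole file would apply decode to consecutive records until the end of the stream.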

What's new

This is the initial release of the spark-tensorflow-connector repo.

Known issues

None.

Prerequisites

  1. Apache Spark 2.0 (or later)

  2. Apache Maven or SBT

Building the library

You can build the library with either Maven or SBT.

Maven

Build the library using Maven (3.3) as shown below:

mvn clean install

SBT

Build the library using SBT (0.13.13) as shown below:

sbt clean assembly

Using Spark Shell

To use this library in Spark, pass the built JARs with the --jars command-line option of spark-shell or spark-submit. For example:

Maven Jars

$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar

SBT Jars

$SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar

The following spark-shell snippet writes a DataFrame to TFRecords and reads it back.

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.types._

val path = "test-output.tfr"

// Two sample rows covering int, long, float, double, array and string columns
val testRows: Array[Row] = Array(
  Row(11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1"),
  Row(21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))
val schema = StructType(List(
  StructField("id", IntegerType),
  StructField("IntegerTypelabel", IntegerType),
  StructField("LongTypelabel", LongType),
  StructField("FloatTypelabel", FloatType),
  StructField("DoubleTypelabel", DoubleType),
  StructField("vectorlabel", ArrayType(DoubleType, true)),
  StructField("name", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

// Save DataFrame as TFRecords.
// Remove output left by a previous run first (local filesystem only),
// because by default save() fails if the path already exists.
FileUtils.deleteQuietly(new java.io.File(path))
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").save(path)

// Read TFRecords into a DataFrame.
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").load(path)
importedDf1.show()

// Read TFRecords into a DataFrame using a custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()