Skip to content

Commit

Permalink
support caffe with gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
robertnishihara committed Mar 5, 2016
1 parent dba1f31 commit d2d4573
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ classpathTypes += "maven-plugin"

// resolvers += "Local Maven Repository" at "file://"+Path.userHome.absolutePath+"/.m2/repository"

resolvers += "javacpp" at "http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-01/"
resolvers += "javacpp" at "http://www.eecs.berkeley.edu/~rkn/snapshot-2016-03-05/"

libraryDependencies += "org.bytedeco" % "javacpp" % "1.2-SPARKNET"

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/apps/CifarApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object CifarApp {
solverParam.clear_net()
solverParam.set_allocated_net_param(netParam)

// Caffe.set_mode(Caffe.GPU)
Caffe.set_mode(Caffe.GPU)
val solver = new CaffeSolver(solverParam, schema, new DefaultPreprocessor(schema))
workerStore.put("netParam", netParam) // prevent netParam from being garbage collected
workerStore.put("solverParam", solverParam) // prevent solverParam from being garbage collected
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/apps/ImageNetApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object ImageNetApp {
ReadSolverParamsFromTextFileOrDie(sparkNetHome + "/models/bvlc_reference_caffenet/solver.prototxt", solverParam)
solverParam.clear_net()
solverParam.set_allocated_net_param(netParam)
// Caffe.set_mode(Caffe.GPU)
Caffe.set_mode(Caffe.GPU)
val solver = new CaffeSolver(solverParam, schema, new ImageNetPreprocessor(schema, meanImage, fullHeight, fullWidth, croppedHeight, croppedWidth))
workerStore.put("netParam", netParam) // prevent netParam from being garbage collected
workerStore.put("solverParam", solverParam) // prevent solverParam from being garbage collected
Expand Down
4 changes: 0 additions & 4 deletions src/main/scala/libs/CaffeNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preproc
private val layerNames = List.range(0, numLayers).map(i => caffeNet.layers.get(i).layer_param.name.getString)
private val numLayerBlobs = List.range(0, numLayers).map(i => caffeNet.layers.get(i).blobs().size.toInt)

// Caffe.set_mode(Caffe.GPU)

for (i <- 0 to inputSize - 1) {
val name = netParam.input(i).getString
transformations(i) = preprocessor.convert(name, JavaCPPUtils.getInputShape(netParam, i).drop(1)) // drop first index to ignore batchSize
Expand Down Expand Up @@ -86,7 +84,6 @@ class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preproc
}

def forward(rowIt: Iterator[Row], dataBlobNames: List[String] = List[String]()): Map[String, NDArray] = {
// Caffe.set_mode(Caffe.GPU)
transformInto(rowIt, inputs)
val tops = caffeNet.Forward(inputs)
val outputs = Map[String, NDArray]()
Expand All @@ -109,7 +106,6 @@ class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preproc
}

def forwardBackward(rowIt: Iterator[Row]) = {
// Caffe.set_mode(Caffe.GPU)
print("entering forwardBackward\n")
val t1 = System.currentTimeMillis()
transformInto(rowIt, inputs)
Expand Down

0 comments on commit d2d4573

Please sign in to comment.