
Commit 511df29: code refactor

Tang Haodong committed Oct 23, 2019
Signed-off-by: Haodong Tang <[email protected]>
1 parent b982bd7

Showing 24 changed files with 629 additions and 1,273 deletions.


core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java (10 changes: 5 additions & 5 deletions)
@@ -13,7 +13,7 @@ public class PmemBuffer {
private native long nativeGetPmemBufferDataAddr(long pmBuffer);
private native long nativeDeletePmemBuffer(long pmBuffer);

private boolean closed = false;
long pmBuffer;
PmemBuffer() {
pmBuffer = nativeNewPmemBuffer();
@@ -55,9 +55,9 @@ long getDirectAddr() {
}

synchronized void close() {
if (!closed) {
nativeDeletePmemBuffer(pmBuffer);
closed = true;
}
}
}
core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
@@ -5,25 +5,17 @@ import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

import com.intel.hpnl.core._
-import org.apache.spark.SparkConf
import org.apache.spark.shuffle.pmof.PmofShuffleManager
+import org.apache.spark.util.configuration.pmof.PmofConf

-import scala.collection.mutable.ArrayBuffer

-class ClientFactory(conf: SparkConf) {
-final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE
-final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16)
-final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1)
-
-final val eqService = new EqService(workers, BUFFER_NUM, false).init()
-final val cqService = new CqService(eqService).init()
-
-final val conArray: ArrayBuffer[Connection] = ArrayBuffer()
-final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]()
-final val conMap = new ConcurrentHashMap[Connection, Client]()
+class ClientFactory(pmofConf: PmofConf) {
+final val eqService = new EqService(pmofConf.clientWorkerNums, pmofConf.clientBufferNums, false).init()
+private[this] final val cqService = new CqService(eqService).init()
+private[this] final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]()
+private[this] final val conMap = new ConcurrentHashMap[Connection, Client]()

def init(): Unit = {
-eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2)
+eqService.initBufferPool(pmofConf.clientBufferNums, pmofConf.networkBufferSize, pmofConf.clientBufferNums * 2)
val clientRecvHandler = new ClientRecvHandler
val clientReadHandler = new ClientReadHandler
eqService.setRecvCallback(clientRecvHandler)
Expand Down Expand Up @@ -62,16 +54,16 @@ class ClientFactory(conf: SparkConf) {

class ClientRecvHandler() extends Handler {
override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = {
-val buffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId)
-val rpcMessage: ByteBuffer = buffer.get(blockBufferSize)
-val seq = buffer.getSeq
-val msgType = buffer.getType
+val hpnlBuffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId)
+val byteBuffer: ByteBuffer = hpnlBuffer.get(blockBufferSize)
+val seq = hpnlBuffer.getSeq
+val msgType = hpnlBuffer.getType
val callback = conMap.get(con).outstandingReceiveFetches.get(seq)
-if (msgType == 0.toByte) {
+if (msgType == 0.toByte) { // got an ACK from the driver: the block info has been saved in driver memory
callback.onSuccess(null)
-} else {
+} else { // got block info from the driver; deserialize it into Scala objects
val metadataResolver = conMap.get(con).shuffleManager.metadataResolver
-val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage)
+val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(byteBuffer)
callback.onSuccess(blockInfoArray)
}
}
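The PmofConf class that this refactor threads through ClientFactory, Server and PmofTransferService is not part of the excerpt above. Below is a rough sketch of what such a configuration holder could centralize: the field names follow the pmofConf.* references in this diff, while the config keys and defaults are carried over from the SparkConf lookups in the removed code, so the keys marked as guesses are assumptions rather than the repository's actual definitions.

package org.apache.spark.util.configuration.pmof

import org.apache.spark.SparkConf

// Sketch only, not the actual source of PmofConf. The "enable_rdma", "enable_pmem"
// and "client_pool_size" keys are guesses; the other keys and defaults come from the
// SparkConf lookups that the pre-refactor code used.
class PmofConf(conf: SparkConf) {
  val enableRdma: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_rdma", defaultValue = true)
  val enablePmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true)
  val clientBufferNums: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16)
  val clientWorkerNums: Int = conf.getInt("spark.shuffle.pmof.client_pool_size", 1)
  val serverBufferNums: Int = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256)
  val serverWorkerNums: Int = conf.getInt("spark.shuffle.pmof.server_pool_size", 1)
  val networkBufferSize: Int = conf.getInt("spark.shuffle.pmof.chunk_size", 4096 * 3)
  val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43")
  val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000)
  val shuffleNodes: Array[Array[String]] =
    conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-"))
}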
core/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala
@@ -8,16 +8,16 @@ import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
import org.apache.spark.shuffle.pmof.{MetadataResolver, PmofShuffleManager}
import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId}
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.util.configuration.pmof.PmofConf

import scala.collection.mutable

-class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager,
+class PmofTransferService(val pmofConf: PmofConf, val shuffleManager: PmofShuffleManager,
val hostname: String, var port: Int) extends TransferService {
private[this] final val metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver
final var server: Server = _
final private var clientFactory: ClientFactory = _
private var nextReqId: AtomicLong = _
final val metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver
private[this] final var clientFactory: ClientFactory = _
private[this] var nextReqId: AtomicLong = _

override def fetchBlocks(host: String,
port: Int,
@@ -33,12 +33,12 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage
}

def fetchBlockInfo(blockIds: Array[BlockId], receivedCallback: ReceivedCallback): Unit = {
-val shuffleBlockIds = blockIds.map(blockId=>blockId.asInstanceOf[ShuffleBlockId])
+val shuffleBlockIds = blockIds.map(blockId => blockId.asInstanceOf[ShuffleBlockId])
metadataResolver.fetchBlockInfo(shuffleBlockIds, receivedCallback)
}

-def syncBlocksInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte,
-callback: ReceivedCallback): Unit = {
+def pushBlockInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte,
+callback: ReceivedCallback): Unit = {
clientFactory.createClient(shuffleManager, host, port).
send(byteBuffer, nextReqId.getAndIncrement(), msgType, callback, isDeferred = false)
}
@@ -59,8 +59,8 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage
}

def init(): Unit = {
-this.server = new Server(conf, shuffleManager, hostname, port)
-this.clientFactory = new ClientFactory(conf)
+this.server = new Server(pmofConf, shuffleManager, hostname, port)
+this.clientFactory = new ClientFactory(pmofConf)
this.server.init()
this.server.start()
this.clientFactory.init()
@@ -73,30 +73,24 @@ class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage
}

object PmofTransferService {
-final val env: SparkEnv = SparkEnv.get
-final val conf: SparkConf = env.conf
-final val CHUNKSIZE: Int = conf.getInt("spark.shuffle.pmof.chunk_size", 4096*3)
-final val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43")
-final val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000)
-final val shuffleNodes: Array[Array[String]] =
-conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-"))
final val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()
-for (array <- shuffleNodes) {
-shuffleNodesMap.put(array(0), array(1))
-}
-private val initialized = new AtomicBoolean(false)
-private var transferService: PmofTransferService = _
-def getTransferServiceInstance(blockManager: BlockManager, shuffleManager: PmofShuffleManager = null,
+private[this] final val initialized = new AtomicBoolean(false)
+private[this] var transferService: PmofTransferService = _

+def getTransferServiceInstance(pmofConf: PmofConf, blockManager: BlockManager, shuffleManager: PmofShuffleManager = null,
isDriver: Boolean = false): PmofTransferService = {
if (!initialized.get()) {
PmofTransferService.this.synchronized {
if (initialized.get()) return transferService
if (isDriver) {
transferService =
-new PmofTransferService(conf, shuffleManager, driverHost, driverPort)
+new PmofTransferService(pmofConf, shuffleManager, pmofConf.driverHost, pmofConf.driverPort)
} else {
+for (array <- pmofConf.shuffleNodes) {
+shuffleNodesMap.put(array(0), array(1))
+}
transferService =
-new PmofTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0)
+new PmofTransferService(pmofConf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0)
}
transferService.init()
initialized.set(true)
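getTransferServiceInstance now takes the PmofConf explicitly instead of reading SparkConf through SparkEnv. A minimal sketch of an executor-side call site follows; the helper itself is illustrative and not part of the commit, and the import paths are assumptions based on how these classes are used in this diff, with pmofConf, blockManager, shuffleManager and the callback assumed to already exist.

import org.apache.spark.network.pmof.{PmofTransferService, ReceivedCallback}
import org.apache.spark.shuffle.pmof.PmofShuffleManager
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.configuration.pmof.PmofConf

// Illustrative helper: resolves the per-executor transfer service singleton and
// asks the driver for the metadata of the given shuffle blocks.
def fetchShuffleBlockInfo(pmofConf: PmofConf,
                          blockManager: BlockManager,
                          shuffleManager: PmofShuffleManager,
                          blockIds: Array[BlockId],
                          callback: ReceivedCallback): Unit = {
  val transferService =
    PmofTransferService.getTransferServiceInstance(pmofConf, blockManager, shuffleManager)
  transferService.fetchBlockInfo(blockIds, callback)
}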
core/src/main/scala/org/apache/spark/network/pmof/Server.scala (33 changes: 15 additions & 18 deletions)
@@ -4,25 +4,22 @@ import java.nio.ByteBuffer
import java.util

import com.intel.hpnl.core._
-import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.pmof.PmofShuffleManager
+import org.apache.spark.util.configuration.pmof.PmofConf

-class Server(conf: SparkConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int) {
+class Server(pmofConf: PmofConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int) {
if (port == 0) {
port = Utils.getPort
}
-final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE
-final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256)
-final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1)

-final val eqService = new EqService(workers, BUFFER_NUM, true).init()
-final val cqService = new CqService(eqService).init()
+private[this] final val eqService = new EqService(pmofConf.serverWorkerNums, pmofConf.serverBufferNums, true).init()
+private[this] final val cqService = new CqService(eqService).init()

-val conList = new util.ArrayList[Connection]()
+private[this] final val conList = new util.ArrayList[Connection]()

def init(): Unit = {
-eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2)
+eqService.initBufferPool(pmofConf.serverBufferNums, pmofConf.networkBufferSize, pmofConf.serverBufferNums * 2)
val recvHandler = new ServerRecvHandler(this)
val connectedHandler = new ServerConnectedHandler(this)
eqService.setConnectedCallback(connectedHandler)
@@ -62,17 +59,17 @@ class ServerRecvHandler(server: Server) extends Handler with Logging {
}

override def handle(con: Connection, bufferId: Int, blockBufferSize: Int): Unit = synchronized {
-val buffer: HpnlBuffer = con.getRecvBuffer(bufferId)
-val message: ByteBuffer = buffer.get(blockBufferSize)
-val seq = buffer.getSeq
-val msgType = buffer.getType
+val hpnlBuffer: HpnlBuffer = con.getRecvBuffer(bufferId)
+val byteBuffer: ByteBuffer = hpnlBuffer.get(blockBufferSize)
+val seq = hpnlBuffer.getSeq
+val msgType = hpnlBuffer.getType
val metadataResolver = server.shuffleManager.metadataResolver
-if (msgType == 0.toByte) {
-metadataResolver.addShuffleBlockInfo(message)
+if (msgType == 0.toByte) { // block info pushed from an executor; save it to driver memory
+metadataResolver.saveShuffleBlockInfo(byteBuffer)
sendMetadata(con, byteBufferTmp, 0.toByte, seq, isDeferred = false)
-} else {
-val bufferArray = metadataResolver.serializeShuffleBlockInfo(message)
-for (buffer <- bufferArray) {
+} else { // look up block info in memory, then send it back to the executor
+val blockInfoArray = metadataResolver.serializeShuffleBlockInfo(byteBuffer)
+for (buffer <- blockInfoArray) {
sendMetadata(con, buffer, 1.toByte, seq, isDeferred = false)
}
}
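The client and server handlers above branch on a raw message-type byte: 0 marks block info pushed from an executor (acknowledged by the driver with the same type), 1 marks a request for block info (answered with serialized entries). The commit keeps using 0.toByte and 1.toByte inline; named constants along these lines, purely as an illustration, would capture the convention:

// Illustration only, not part of this commit.
object RpcMsgType {
  val PushBlockInfo: Byte = 0.toByte  // executor -> driver: save block info, reply with an empty ACK
  val FetchBlockInfo: Byte = 1.toByte // executor -> driver: look up block info, reply with serialized entries
}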
@@ -18,27 +18,28 @@
package org.apache.spark.shuffle

import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.internal.{Logging, config}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.pmof.PmemExternalSorter
+import org.apache.spark.util.configuration.pmof.PmofConf

/**
-* Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
-* requesting them from other nodes' block stores.
-*/
-private[spark] class PmemShuffleReader[K, C](
-handle: BaseShuffleHandle[K, _, C],
-startPartition: Int,
-endPartition: Int,
-context: TaskContext,
-serializerManager: SerializerManager = SparkEnv.get.serializerManager,
-blockManager: BlockManager = SparkEnv.get.blockManager,
-mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+* Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
+* requesting them from other nodes' block stores.
+*/
+private[spark] class BaseShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
+startPartition: Int,
+endPartition: Int,
+context: TaskContext,
+pmofConf: PmofConf,
+serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+blockManager: BlockManager = SparkEnv.get.blockManager,
+mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {

-private val dep = handle.dependency
+private[this] val dep = handle.dependency

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
@@ -97,10 +98,11 @@ private[spark] class PmemShuffleReader[K, C](
// Sort the output if there is a sort ordering defined.
val resultIter = dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
+assert(pmofConf.enablePmem == true)
// Create an ExternalSorter to sort the data.
val sorter =
-new PmemExternalSorter[K, C, C](context, handle, ordering = Some(keyOrd), serializer = dep.serializer)
-logDebug("call PmemExternalSorter.insertAll for shuffle_0_" + handle.shuffleId + "_[" + startPartition + "," + endPartition + "]" )
+new PmemExternalSorter[K, C, C](context, handle, pmofConf, ordering = Some(keyOrd), serializer = dep.serializer)
+logDebug("call PmemExternalSorter.insertAll for shuffle_0_" + handle.shuffleId + "_[" + startPartition + "," + endPartition + "]")
sorter.insertAll(aggregatedIter)
// Use completion callback to stop sorter if task was finished/cancelled.
context.addTaskCompletionListener(_ => {
@@ -25,31 +25,27 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, S
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.configuration.pmof.PmofConf

-private[spark] class BaseShuffleWriter[K, V, C](
-shuffleBlockResolver: IndexShuffleBlockResolver,
-metadataResolver: MetadataResolver,
-handle: BaseShuffleHandle[K, V, C],
-mapId: Int,
-context: TaskContext,
-enable_rdma: Boolean)
+private[spark] class BaseShuffleWriter[K, V, C](shuffleBlockResolver: IndexShuffleBlockResolver,
+metadataResolver: MetadataResolver,
+handle: BaseShuffleHandle[K, V, C],
+mapId: Int,
+context: TaskContext,
+pmofConf: PmofConf)
extends ShuffleWriter[K, V] with Logging {

private val dep = handle.dependency

private val blockManager = SparkEnv.get.blockManager

private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
private var sorter: ExternalSorter[K, V, _] = _

// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
// we don't try deleting files, etc twice.
private var stopping = false

private var mapStatus: MapStatus = _

private val writeMetrics = context.taskMetrics().shuffleWriteMetrics

/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
@@ -76,11 +72,11 @@ private[spark] class BaseShuffleWriter[K, V, C](
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)

val shuffleServerId = blockManager.shuffleServerId
-if (enable_rdma) {
-metadataResolver.commitBlockInfo(dep.shuffleId, mapId, partitionLengths)
+if (pmofConf.enableRdma) {
+metadataResolver.pushFileBlockInfo(dep.shuffleId, mapId, partitionLengths)
val blockManagerId: BlockManagerId =
-BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host),
-PmofTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo)
+BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host),
+PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, shuffleServerId.topologyInfo)
mapStatus = MapStatus(blockManagerId, partitionLengths)
} else {
mapStatus = MapStatus(shuffleServerId, partitionLengths)
@@ -115,16 +111,3 @@
}
}
}

-private[spark] object BaseShuffleWriter {
-def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
-// We cannot bypass sorting if we need to do map-side aggregation.
-if (dep.mapSideCombine) {
-require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
-false
-} else {
-val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-dep.partitioner.numPartitions <= bypassMergeThreshold
-}
-}
-}
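The deleted companion object above duplicated the bypass-merge-sort check that Spark itself ships in SortShuffleWriter. Presumably call sites can rely on Spark's own helper instead; the sketch below shows that equivalent call under that assumption, since the replacement call site is not part of this excerpt.

package org.apache.spark.shuffle.pmof

import org.apache.spark.{ShuffleDependency, SparkConf}
import org.apache.spark.shuffle.sort.SortShuffleWriter

// Equivalent check via Spark's built-in helper (reachable here because the package
// sits under org.apache.spark); whether the refactored code actually switches to it
// is an assumption, not something shown in this commit.
object BypassCheck {
  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean =
    SortShuffleWriter.shouldBypassMergeSort(conf, dep)
}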