diff --git a/std/object/client_consume_seg.go b/std/object/client_consume_seg.go
index c3ae9f03..801b33db 100644
--- a/std/object/client_consume_seg.go
+++ b/std/object/client_consume_seg.go
@@ -1,6 +1,7 @@
 package object
 
 import (
+    "container/list"
     "fmt"
     "slices"
     "sync"
@@ -8,6 +9,9 @@ import (
     enc "github.com/named-data/ndnd/std/encoding"
     "github.com/named-data/ndnd/std/log"
     "github.com/named-data/ndnd/std/ndn"
+    spec "github.com/named-data/ndnd/std/ndn/spec_2022"
+    cong "github.com/named-data/ndnd/std/object/congestion"
+    "github.com/named-data/ndnd/std/utils"
 )
 
 // round-robin based segment fetcher
@@ -25,7 +29,21 @@ type rrSegFetcher struct {
     // number of outstanding interests
     outstanding int
     // window size
-    window int
+    window cong.CongestionWindow
+    // retransmission queue
+    retxQueue *list.List
+    // remaining segments to be transmitted, per consumer state
+    txCounter map[*ConsumeState]int
+    // maximum number of retries
+    maxRetries int
+}
+
+// retxEntry represents an entry in the retransmission queue.
+// It carries the consumer state, the segment number, and the number of retries left for that segment.
+type retxEntry struct {
+    state   *ConsumeState
+    seg     uint64
+    retries int
 }
 
 func newRrSegFetcher(client *Client) rrSegFetcher {
@@ -33,8 +51,11 @@ func newRrSegFetcher(client *Client) rrSegFetcher {
         mutex:       sync.RWMutex{},
         client:      client,
         streams:     make([]*ConsumeState, 0),
-        window:      100,
+        window:      cong.NewFixedCongestionWindow(100),
         outstanding: 0,
+        retxQueue:   list.New(),
+        txCounter:   make(map[*ConsumeState]int),
+        maxRetries:  3,
     }
 }
 
@@ -47,7 +68,7 @@ func (s *rrSegFetcher) String() string {
 func (s *rrSegFetcher) IsCongested() bool {
     s.mutex.RLock()
     defer s.mutex.RUnlock()
-    return s.outstanding >= s.window
+    return s.outstanding >= s.window.Size()
 }
 
 // add a stream to the fetch queue
@@ -71,10 +92,6 @@ func (s *rrSegFetcher) findWork() *ConsumeState {
     s.mutex.Lock()
     defer s.mutex.Unlock()
 
-    if s.outstanding >= s.window {
-        return nil
-    }
-
     // round-robin selection of the next stream to fetch
     next := func() *ConsumeState {
         if len(s.streams) == 0 {
@@ -132,34 +149,103 @@ func (s *rrSegFetcher) check() {
     for {
-        state := s.findWork()
-        if state == nil {
+        log.Debug(s, "Checking for work")
+
+        // check if the window is full
+        if s.IsCongested() {
+            log.Debug(s, "Window full", "size", s.window.Size())
+            return // no need to generate new interests
+        }
+
+        var (
+            state   *ConsumeState
+            seg     uint64
+            retries int = s.maxRetries // TODO: make it configurable
+        )
+
+        // if there are retransmissions, handle them first
+        if s.retxQueue.Len() > 0 {
+            log.Debug(s, "Retransmitting")
+
+            var retx *retxEntry
+
+            s.mutex.Lock()
+            front := s.retxQueue.Front()
+            if front != nil {
+                retx = s.retxQueue.Remove(front).(*retxEntry)
+                s.mutex.Unlock()
+            } else {
+                s.mutex.Unlock()
+                continue
+            }
+
+            state = retx.state
+            seg = retx.seg
+            retries = retx.retries
+
+        } else { // if no retransmissions, find a stream to work on
+            state = s.findWork()
+            if state == nil {
+                return
+            }
+
+            // update window parameters
+            s.mutex.Lock()
+            seg = uint64(state.wnd[2])
+            state.wnd[2]++
+            s.mutex.Unlock()
+        }
+
+        // build interest
+        name := state.fetchName.Append(enc.NewSegmentComponent(seg))
+        config := &ndn.InterestConfig{
+            MustBeFresh: false,
+            Nonce:       utils.ConvertNonce(s.client.engine.Timer().Nonce()), // new nonce for each call
+        }
+        var appParam enc.Wire = nil
+        var signer ndn.Signer = nil
+
+        log.Debug(s, "Building interest", "name", name, "config", config)
+        interest, err := s.client.Engine().Spec().MakeInterest(name, config, appParam, signer)
+        if err != nil {
+            s.handleResult(ndn.ExpressCallbackArgs{
+                Result: ndn.InterestResultError,
+                Error:  err,
+            }, state, seg, retries)
             return
         }
 
-        // update window parameters
-        seg := uint64(state.wnd[2])
+        // build express callback function
+        callback := func(args ndn.ExpressCallbackArgs) {
+            s.handleResult(args, state, seg, retries)
+        }
+
+        // express interest
+        log.Debug(s, "Expressing interest", "name", interest.FinalName)
+        err = s.client.Engine().Express(interest, callback)
+        if err != nil {
+            s.handleResult(ndn.ExpressCallbackArgs{
+                Result: ndn.InterestResultError,
+                Error:  err,
+            }, state, seg, retries)
+            return
+        }
+
+        // increment outstanding interest count
+        s.mutex.Lock()
         s.outstanding++
-        state.wnd[2]++
-
-        // queue outgoing interest for the next segment
-        s.client.ExpressR(ndn.ExpressRArgs{
-            Name: state.fetchName.Append(enc.NewSegmentComponent(seg)),
-            Config: &ndn.InterestConfig{
-                MustBeFresh: false,
-            },
-            Retries: 3,
-            Callback: func(args ndn.ExpressCallbackArgs) {
-                s.handleData(args, state)
-            },
-        })
+        s.mutex.Unlock()
     }
 }
 
-// handleData is called when a data packet is received.
+// handleResult is called when the result for an interest is ready.
 // It is necessary that this function be called only from one goroutine - the engine.
-// The notable exception here is when there is a timeout, which has a separate goroutine.
-func (s *rrSegFetcher) handleData(args ndn.ExpressCallbackArgs, state *ConsumeState) {
+func (s *rrSegFetcher) handleResult(args ndn.ExpressCallbackArgs, state *ConsumeState, seg uint64, retries int) {
+    // get the name of the interest
+    var interestName enc.Name = state.fetchName.Append(enc.NewSegmentComponent(seg))
+    log.Debug(s, "Parsing interest result", "name", interestName)
+
+    // decrement outstanding interest count
     s.mutex.Lock()
     s.outstanding--
     s.mutex.Unlock()
@@ -168,23 +254,50 @@ func (s *rrSegFetcher) handleData(args ndn.ExpressCallbackArgs, state *ConsumeSt
         return
     }
 
-    if args.Result == ndn.InterestResultError {
-        state.finalizeError(fmt.Errorf("%w: fetch seg failed: %w", ndn.ErrNetwork, args.Error))
-        return
-    }
+    // handle the result
+    switch args.Result {
+    case ndn.InterestResultTimeout:
+        log.Debug(s, "Interest timeout", "name", interestName)
+
+        s.window.HandleSignal(cong.SigLoss)
+        s.enqueueForRetransmission(state, seg, retries-1)
+
+    case ndn.InterestResultNack:
+        log.Debug(s, "Interest nack'd", "name", interestName)
+
+        switch args.NackReason {
+        case spec.NackReasonDuplicate:
+            // ignore Nack for duplicates
+        case spec.NackReasonCongestion:
+            // congestion signal
+            s.window.HandleSignal(cong.SigCongest)
+            s.enqueueForRetransmission(state, seg, retries-1)
+        default:
+            // treat as irrecoverable error for now
+            state.finalizeError(fmt.Errorf("%w: fetch seg failed with result: %s", ndn.ErrNetwork, args.Result))
+        }
+
+    case ndn.InterestResultData: // data is successfully retrieved
+        s.handleData(args, state)
+        s.window.HandleSignal(cong.SigData)
 
-    if args.Result != ndn.InterestResultData {
+    default: // treat as irrecoverable error for now
         state.finalizeError(fmt.Errorf("%w: fetch seg failed with result: %s", ndn.ErrNetwork, args.Result))
-        return
     }
 
+    s.check() // check for more work
+}
+
+// handleData is called when the interest result is processed and the data is ready to be validated.
+// It is necessary that this function be called only from one goroutine - the engine.
+// Timeouts never reach this function; they are handled in handleResult before validation.
+func (s *rrSegFetcher) handleData(args ndn.ExpressCallbackArgs, state *ConsumeState) {
     s.client.Validate(args.Data, args.SigCovered, func(valid bool, err error) {
         if !valid {
             state.finalizeError(fmt.Errorf("%w: validate seg failed: %w", ndn.ErrSecurity, err))
         } else {
             s.handleValidatedData(args, state)
         }
-        s.check()
     })
 }
 
@@ -203,6 +316,7 @@ func (s *rrSegFetcher) handleValidatedData(args ndn.ExpressCallbackArgs, state *
     }
 
     state.segCnt = int(fbId.NumberVal()) + 1
+    s.txCounter[state] = state.segCnt // number of segments to be transmitted for this state
     if state.segCnt > maxObjectSeg || state.segCnt <= 0 {
         state.finalizeError(fmt.Errorf("%w: invalid FinalBlockId=%d", ndn.ErrProtocol, state.segCnt))
         return
@@ -235,17 +349,23 @@ func (s *rrSegFetcher) handleValidatedData(args ndn.ExpressCallbackArgs, state *
         panic("[BUG] consume: nil data segment")
     }
 
+    // decrement the transmission counter
+    s.mutex.Lock()
+    s.txCounter[state]--
+    s.mutex.Unlock()
+
     // if this is the first outstanding segment, move windows
     if state.wnd[1] == segNum {
         for state.wnd[1] < state.segCnt && state.content[state.wnd[1]] != nil {
             state.wnd[1]++
         }
 
-        if state.wnd[1] == state.segCnt {
+        if state.wnd[1] == state.segCnt && s.txCounter[state] == 0 {
             log.Debug(s, "Stream completed successfully", "name", state.fetchName)
             s.mutex.Lock()
             s.remove(state)
+            delete(s.txCounter, state)
             s.mutex.Unlock()
 
             if !state.complete.Swap(true) {
@@ -265,3 +385,17 @@ func (s *rrSegFetcher) handleValidatedData(args ndn.ExpressCallbackArgs, state *
     //     s.outstanding)
     // }
 }
+
+// enqueueForRetransmission enqueues a segment for retransmission.
+// It records the remaining retries and treats an exhausted retry budget as an irrecoverable error.
+func (s *rrSegFetcher) enqueueForRetransmission(state *ConsumeState, seg uint64, retries int) {
+    if retries == 0 { // retransmissions exhausted
+        state.finalizeError(fmt.Errorf("%w: retries exhausted, segment number=%d", ndn.ErrNetwork, seg))
+        return
+    }
+
+    s.mutex.Lock()
+    defer s.mutex.Unlock()
+
+    s.retxQueue.PushBack(&retxEntry{state, seg, retries})
+}
\ No newline at end of file
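Reviewer note on the retransmission path above: the retry budget drains by one on every timeout or congestion NACK, and enqueueForRetransmission finalizes the stream with an error once the budget reaches zero. The standalone sketch below is not part of the patch; retxItem, enqueue, and maxRetries are stand-ins for retxEntry, enqueueForRetransmission, and rrSegFetcher.maxRetries, with the consumer state and locking omitted. It shows the queue round-trip when every retransmission times out again.

package main

import (
    "container/list"
    "fmt"
)

// retxItem is a stand-in for the patch's retxEntry, with the consumer state omitted.
type retxItem struct {
    seg     uint64
    retries int
}

// enqueue mirrors enqueueForRetransmission: a zero retry budget is an irrecoverable error.
func enqueue(q *list.List, seg uint64, retries int) {
    if retries == 0 {
        fmt.Printf("seg=%d: retries exhausted, stream finalized with error\n", seg)
        return
    }
    q.PushBack(&retxItem{seg: seg, retries: retries})
}

func main() {
    const maxRetries = 3 // mirrors rrSegFetcher.maxRetries
    q := list.New()

    // first timeout: enqueue the segment with one retry already spent
    enqueue(q, 7, maxRetries-1)

    for q.Len() > 0 {
        item := q.Remove(q.Front()).(*retxItem)
        fmt.Printf("seg=%d: retransmitting, %d retries left\n", item.seg, item.retries)
        // pretend every retransmission times out as well
        enqueue(q, item.seg, item.retries-1)
    }
}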
diff --git a/std/object/congestion/congestion_window.go b/std/object/congestion/congestion_window.go
new file mode 100644
index 00000000..0af5a6ea
--- /dev/null
+++ b/std/object/congestion/congestion_window.go
@@ -0,0 +1,30 @@
+package congestion
+
+import "time"
+
+// CongestionSignal represents signals to adjust the congestion window.
+type CongestionSignal int
+
+const (
+    SigData    CongestionSignal = iota // data is fetched
+    SigLoss                            // data loss detected
+    SigCongest                         // congestion detected (e.g. NACK with a reason of congestion)
+)
+
+// WindowEvent represents a congestion window change event.
+type WindowEvent struct {
+    Age  time.Time // time of the event
+    Cwnd int       // new window size
+}
+
+// CongestionWindow provides an interface for congestion control that manages a window.
+type CongestionWindow interface {
+    String() string
+
+    EventChannel() <-chan WindowEvent     // where window events are emitted
+    HandleSignal(signal CongestionSignal) // signal handler
+
+    Size() int
+    IncreaseWindow()
+    DecreaseWindow()
+}
\ No newline at end of file
diff --git a/std/object/congestion/congestion_window_aimd.go b/std/object/congestion/congestion_window_aimd.go
new file mode 100644
index 00000000..90a7339f
--- /dev/null
+++ b/std/object/congestion/congestion_window_aimd.go
@@ -0,0 +1,97 @@
+package congestion
+
+import (
+    "math"
+    "time"
+
+    "github.com/named-data/ndnd/std/log"
+)
+
+// AIMDCongestionWindow is an implementation of CongestionWindow using the Additive Increase Multiplicative Decrease (AIMD) algorithm.
+type AIMDCongestionWindow struct {
+    window  float64          // window size - float64 to allow fractional growth in the congestion avoidance phase
+    eventCh chan WindowEvent // channel for emitting window change events
+
+    initCwnd    float64 // initial window size
+    ssthresh    float64 // slow start threshold
+    minSsthresh float64 // minimum slow start threshold
+    aiStep      float64 // additive increase step
+    mdCoef      float64 // multiplicative decrease coefficient
+    resetCwnd   bool    // whether to reset cwnd after decrease
+}
+
+// TODO: should we bundle the parameters into an AIMDOption struct?
+
+func NewAIMDCongestionWindow(cwnd int) *AIMDCongestionWindow {
+    return &AIMDCongestionWindow{
+        window:  float64(cwnd),
+        eventCh: make(chan WindowEvent),
+
+        initCwnd:    float64(cwnd),
+        ssthresh:    math.MaxFloat64,
+        minSsthresh: 2.0,
+        aiStep:      1.0,
+        mdCoef:      0.5,
+        resetCwnd:   false, // defaults
+    }
+}
+
+// log identifier
+func (cw *AIMDCongestionWindow) String() string {
+    return "aimd-congestion-window"
+}
+
+func (cw *AIMDCongestionWindow) Size() int {
+    return int(math.Floor(cw.window))
+}
+
+func (cw *AIMDCongestionWindow) IncreaseWindow() {
+    if cw.window < cw.ssthresh {
+        cw.window += cw.aiStep // additive increase
+    } else {
+        cw.window += cw.aiStep / cw.window // congestion avoidance
+
+        // note: the congestion avoidance formula differs from RFC 5681 Section 3.1
+        // recommendations and is borrowed from ndn-tools/catchunks, check
+        // https://github.com/named-data/ndn-tools/blob/130975c4be69d126fede77d47a50580d5e8b25b0/tools/chunks/catchunks/pipeline-interests-aimd.cpp#L45
+    }
+
+    cw.EmitWindowEvent(time.Now(), cw.Size()) // window change signal
+}
+
+func (cw *AIMDCongestionWindow) DecreaseWindow() {
+    cw.ssthresh = math.Max(cw.window*cw.mdCoef, cw.minSsthresh)
+
+    if cw.resetCwnd {
+        cw.window = cw.initCwnd
+    } else {
+        cw.window = cw.ssthresh
+    }
+
+    cw.EmitWindowEvent(time.Now(), cw.Size()) // window change signal
+}
+
+func (cw *AIMDCongestionWindow) EventChannel() <-chan WindowEvent {
+    return cw.eventCh
+}
+
+func (cw *AIMDCongestionWindow) HandleSignal(signal CongestionSignal) {
+    switch signal {
+    case SigData:
+        cw.IncreaseWindow()
+    case SigLoss, SigCongest:
+        cw.DecreaseWindow()
+    default:
+        // no-op
+    }
+}
+
+func (cw *AIMDCongestionWindow) EmitWindowEvent(age time.Time, cwnd int) {
+    // non-blocking send to the channel
+    select {
+    case cw.eventCh <- WindowEvent{Age: age, Cwnd: cwnd}:
+    default:
+        // no receiver is ready; log the change event instead
+        log.Debug(cw, "Window size changed", "window", cw.window)
+    }
+}
\ No newline at end of file
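The AIMD parameters above give concrete, easily checked numbers. The hypothetical standalone driver below (not part of the patch; it assumes the patch is applied so the congestion package exists at the import path shown) feeds the window data and loss signals; the expected sizes follow from aiStep = 1.0, mdCoef = 0.5, and minSsthresh = 2.0.

package main

import (
    "fmt"

    cong "github.com/named-data/ndnd/std/object/congestion"
)

func main() {
    cw := cong.NewAIMDCongestionWindow(1)

    // slow start: every data signal adds aiStep (1.0) while window < ssthresh
    for i := 0; i < 4; i++ {
        cw.HandleSignal(cong.SigData)
    }
    fmt.Println(cw.Size()) // 5

    // loss: ssthresh = max(5 * 0.5, 2.0) = 2.5, and the window drops to ssthresh
    cw.HandleSignal(cong.SigLoss)
    fmt.Println(cw.Size()) // 2 (floor of 2.5)

    // at or above ssthresh: congestion avoidance, window += aiStep / window
    cw.HandleSignal(cong.SigData)
    fmt.Println(cw.Size()) // still 2 (2.5 + 1.0/2.5 = 2.9)
}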
diff --git a/std/object/congestion/congestion_window_fixed.go b/std/object/congestion/congestion_window_fixed.go
new file mode 100644
index 00000000..ea3fd934
--- /dev/null
+++ b/std/object/congestion/congestion_window_fixed.go
@@ -0,0 +1,39 @@
+package congestion
+
+// FixedCongestionWindow is an implementation of CongestionWindow with a constant window size.
+type FixedCongestionWindow struct {
+    window  int              // window size
+    eventCh chan WindowEvent // channel for emitting window change events
+}
+
+func NewFixedCongestionWindow(cwnd int) *FixedCongestionWindow {
+    return &FixedCongestionWindow{
+        window:  cwnd,
+        eventCh: make(chan WindowEvent),
+    }
+}
+
+// log identifier
+func (cw *FixedCongestionWindow) String() string {
+    return "fixed-congestion-window"
+}
+
+func (cw *FixedCongestionWindow) Size() int {
+    return cw.window
+}
+
+func (cw *FixedCongestionWindow) IncreaseWindow() {
+    // intentionally left blank: window size is fixed
+}
+
+func (cw *FixedCongestionWindow) DecreaseWindow() {
+    // intentionally left blank: window size is fixed
+}
+
+func (cw *FixedCongestionWindow) EventChannel() <-chan WindowEvent {
+    return cw.eventCh
+}
+
+func (cw *FixedCongestionWindow) HandleSignal(signal CongestionSignal) {
+    // intentionally left blank: fixed CW doesn't respond to signals
+}
\ No newline at end of file
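Both implementations emit WindowEvent values through an unbuffered channel with a non-blocking send, so events are dropped (and only logged) unless a receiver is ready. A hypothetical observer could look like the sketch below; it assumes the patch is applied with WindowEvent's fields exported (Age, Cwnd) as revised above, and the sleeps merely give the receiver goroutine a chance to be scheduled.

package main

import (
    "fmt"
    "time"

    cong "github.com/named-data/ndnd/std/object/congestion"
)

func main() {
    cw := cong.NewAIMDCongestionWindow(1)

    // drain window events as they are emitted
    go func() {
        for ev := range cw.EventChannel() {
            fmt.Printf("cwnd -> %d at %s\n", ev.Cwnd, ev.Age.Format(time.StampMilli))
        }
    }()

    // simulate a fetch: a burst of data signals, then a congestion NACK
    for i := 0; i < 8; i++ {
        cw.HandleSignal(cong.SigData)
        time.Sleep(time.Millisecond) // without a ready receiver the event is only logged
    }
    cw.HandleSignal(cong.SigCongest)
    time.Sleep(10 * time.Millisecond) // let the last event print before exiting
}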