Skip to content

Commit

Permalink
Merge pull request #1024 from slaclab/ESROGUE-683
Browse files Browse the repository at this point in the history
Improve variable update performance
  • Loading branch information
slacrherbst authored Sep 27, 2024
2 parents eb0d921 + 6598340 commit b47e9cd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
51 changes: 35 additions & 16 deletions python/pyrogue/_Root.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ def decrement(self):
self._check()

def _check(self):
if len(self._list) != 0 and (self._count == 0 or (self._period != 0 and (time.time() - self._last) > self._period)):
#print(f"Update fired {time.time()}")
self._last = time.time()
self._q.put(self._list)
self._list = {}
if self._count == 0 or (self._period != 0 and (time.time() - self._last) > self._period):
if len(self._list) != 0:
#print(f"Update fired {time.time()}")
self._last = time.time()
self._q.put(self._list)
self._list = {}

def update(self,var):
"""
Expand Down Expand Up @@ -482,19 +483,19 @@ def updateGroup(self, period=0):
tid = threading.get_ident()

# At with call
with self._updateLock:
if tid not in self._updateTrack:
self._updateTrack[tid] = UpdateTracker(self._updateQueue)

try:
self._updateTrack[tid].increment(period)
except Exception:
with self._updateLock:
self._updateTrack[tid] = UpdateTracker(self._updateQueue)
self._updateTrack[tid].increment(period)

try:
yield
finally:

# After with is done
with self._updateLock:
self._updateTrack[tid].decrement()
self._updateTrack[tid].decrement()

@contextmanager
def pollBlock(self):
Expand Down Expand Up @@ -1000,10 +1001,19 @@ def _queueUpdates(self,var):
"""
tid = threading.get_ident()

with self._updateLock:
if tid not in self._updateTrack:
self._updateTrack[tid] = UpdateTracker(self._updateQueue)
try:
self._updateTrack[tid].update(var)
except Exception:
with self._updateLock:
self._updateTrack[tid] = UpdateTracker(self._updateQueue)
self._updateTrack[tid].update(var)

# Recursively add listeners to update list
def _recurseAddListeners(self, nvars, var):
for vl in var._listeners:
nvars[vl.path] = vl

self._recurseAddListeners(nvars, vl)

# Worker thread
def _updateWorker(self):
Expand All @@ -1022,10 +1032,19 @@ def _updateWorker(self):
# Process list
elif len(uvars) > 0:
self._log.debug(F'Process update group. Length={len(uvars)}. Entry={list(uvars.keys())[0]}')

# Copy list and add listeners
nvars = uvars.copy()
for p,v in uvars.items():
self._recurseAddListeners(nvars, v)

# Process the new list
for p,v in nvars.items():

# Process updates
val = v._doUpdate()

# Call listener functions,
# Call root listener functions,
with self._varListenLock:
for func,doneFunc,incGroups,excGroups in self._varListeners:
if v.filterByGroup(incGroups, excGroups):
Expand All @@ -1040,7 +1059,7 @@ def _updateWorker(self):
else:
pr.logException(self._log,e)

# Finalize listeners
# Finalize root listeners
with self._varListenLock:
for func,doneFunc,incGroups,excGroups in self._varListeners:
if doneFunc is not None:
Expand Down
3 changes: 0 additions & 3 deletions python/pyrogue/_Variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,6 @@ def _queueUpdate(self):
""" """
self._root._queueUpdates(self)

for var in self._listeners:
var._queueUpdate()

def _doUpdate(self):
""" """
val = VariableValue(self)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import time
import hwcounter

#import cProfile, pstats, io
#from pstats import SortKey

#rogue.Logging.setLevel(rogue.Logging.Debug)
#import logging
#logger = logging.getLogger('pyrogue')
Expand Down Expand Up @@ -105,6 +108,9 @@ def __init__(self):

def test_rate():

#pr = cProfile.Profile()
#pr.enable()

with DummyTree() as root:
count = 100000
resultRate = {}
Expand Down Expand Up @@ -178,5 +184,14 @@ def test_rate():
if passed is False:
raise AssertionError('Rate check failed')


#pr.disable()

#s = io.StringIO()
#sortby = SortKey.CUMULATIVE
#ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
#ps.print_stats()
#print(s.getvalue())

if __name__ == "__main__":
test_rate()

0 comments on commit b47e9cd

Please sign in to comment.