Skip to content

Commit

Permalink
Enforce time limit in runner executor
Browse files Browse the repository at this point in the history
  • Loading branch information
chrhansk committed Jun 7, 2024
1 parent 735efd1 commit 4043696
Showing 1 changed file with 53 additions and 27 deletions.
80 changes: 53 additions & 27 deletions pygradflow/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
from multiprocessing import cpu_count
from multiprocessing import Process, Queue, cpu_count

import numpy as np

Expand All @@ -21,8 +21,6 @@
def solve_instance(instance, params, log_filename, verbose):
logger.handlers.clear()

np.seterr(divide="raise", over="raise", invalid="raise")

handler = logging.FileHandler(log_filename)
handler.setFormatter(formatter)
logger.addHandler(handler)
Expand All @@ -39,31 +37,40 @@ def solve_instance(instance, params, log_filename, verbose):
warn_logger.addHandler(handler)
warn_logger.setLevel(logging.WARN)

return instance.solve(params)
try:
np.seterr(divide="raise", over="raise", invalid="raise")
result = instance.solve(params)
return (instance, result)
except Exception as exc:
logger.error("Error solving %s", instance.name, exc_info=exc)
return (instance, "error")


def solve_instance_queue(queue, instance, params, log_filename, verbose):
    """Solve ``instance`` and put the outcome on ``queue``.

    This is the target of the worker process started by
    ``solve_instance_time_limit``; whatever ``solve_instance`` returns is
    forwarded unchanged.
    """
    instance_result = solve_instance(instance, params, log_filename, verbose)
    queue.put(instance_result)


def _receive_result(process, result_queue, timeout):
    """Wait up to ``timeout`` seconds for the worker to deliver a result.

    Returns ``(True, value)`` as soon as a value could be read from
    ``result_queue``, and ``(False, None)`` if the deadline passed or the
    worker exited without sending anything.
    """
    import time
    from queue import Empty

    poll_interval = 1.0
    deadline = time.monotonic() + timeout

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False, None

        # Read from the queue *before* joining the worker: a child that has
        # queued data cannot exit until that data has been flushed to the
        # pipe, so joining first can stall until the timeout expires.
        try:
            return True, result_queue.get(timeout=min(remaining, poll_interval))
        except Empty:
            pass

        if not process.is_alive():
            # The worker is gone. Anything it sent was flushed before it
            # exited, so a single non-blocking read is enough.
            try:
                return True, result_queue.get_nowait()
            except Empty:
                return False, None


def solve_instance_time_limit(instance, params, log_filename, verbose):
    """Solve ``instance``, enforcing ``params.time_limit`` from the outside.

    With an infinite time limit the instance is solved directly in the
    calling process. Otherwise the solve runs in a separate process which is
    terminated if it has not delivered a result within
    ``params.time_limit + 10.0`` seconds (the extra ten seconds presumably
    give the solver a chance to stop on its own first -- TODO confirm).

    Returns the value produced by ``solve_instance``, or
    ``(instance, "timeout")`` if the worker had to be terminated, or
    ``(instance, "error")`` if the worker exited without sending a result.
    """
    # No time limit: no need for a separate process.
    if params.time_limit == np.inf:
        return solve_instance(instance, params, log_filename, verbose)

    result_queue = Queue()
    process = Process(
        target=solve_instance_queue,
        args=(result_queue, instance, params, log_filename, verbose),
    )
    process.start()

    received, outcome = _receive_result(
        process, result_queue, params.time_limit + 10.0
    )

    if received:
        # The result has been drained, so the worker is free to exit.
        process.join()
        return outcome

    if process.is_alive():
        logger.error("Reached timeout, terminating process for %s", instance.name)
        process.terminate()
        # Reap the terminated worker so it does not linger as a zombie.
        process.join()
        logger.error("Terminated process for %s", instance.name)
        return (instance, "timeout")

    # The worker exited without sending anything (e.g. it crashed hard);
    # report an error rather than blocking forever on an empty queue.
    process.join()
    logger.error(
        "Process for %s exited with code %s without returning a result",
        instance.name,
        process.exitcode,
    )
    return (instance, "error")


# Previous name of the time-limited entry point, kept as an alias because
# call sites using it are still visible in this view.
try_solve_instance = solve_instance_time_limit


class Runner(ABC):
Expand Down Expand Up @@ -92,7 +99,7 @@ def solve_instances_sequential(self, instances, args, params):
verbose = args.verbose
for instance in instances:
log_filename = self.log_filename(args, instance)
yield try_solve_instance(instance, params, log_filename, verbose)
yield solve_instance_time_limit(instance, params, log_filename, verbose)

def solve_instances_parallel(self, instances, args, params):
verbose = args.verbose
Expand All @@ -113,16 +120,35 @@ def solve_instances_parallel(self, instances, args, params):
solve_args = zip(instances, all_params, all_log_filenames, all_verbose)

with ProcessPoolExecutor(num_procs, max_tasks_per_child=1) as pool:
futures = [
pool.submit(try_solve_instance, *solve_arg) for solve_arg in solve_args
]

while len(futures) != 0:
futures = []

for solve_arg in solve_args:
future = pool.submit(solve_instance_time_limit, *solve_arg)
future.instance = solve_arg[0]
futures.append(future)

run_logger.info("Started %d instances", len(futures))

while True:
done, futures = wait(futures, return_when=FIRST_COMPLETED)

for item in done:
yield item.result()

if len(futures) == 0:
run_logger.info("Solved all instances, terminating")
break

remaining_future = next(iter(futures))

run_logger.info(
"Finished %d instances, waiting for %d (including %s)",
len(done),
len(futures),
remaining_future.instance.name,
)

def solve_instances(self, instances, args):
# yields sequence of tuples of (instance, result) for each instance
run_logger.info("Solving %d instances", len(instances))
Expand Down

0 comments on commit 4043696

Please sign in to comment.