Skip to content

Commit

Permalink
Deepspeed terminate (huggingface#11)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mrs303 authored Jan 17, 2024
1 parent c459c86 commit 8523f7e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
5 changes: 2 additions & 3 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,7 @@ fn shard_manager(

// We received a shutdown signal
if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap();
let _ = p.wait();
terminate("Shard", p, Duration::from_secs(30)).unwrap();
tracing::info!("Shard terminated");
return;
}
Expand Down Expand Up @@ -923,7 +922,7 @@ fn spawn_shards(
drop(shutdown_sender);

// Wait for shard to start
let mut shard_ready = 0;
let mut shard_ready = 0;
while running.load(Ordering::SeqCst) {
match status_receiver.try_recv() {
Ok(ShardStatus::Ready) => {
Expand Down
36 changes: 35 additions & 1 deletion server/text_generation_server/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import psutil
import signal
import sys
import typer

Expand Down Expand Up @@ -76,7 +78,39 @@ def serve(
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
proc.wait()
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)

signal.signal(signal.SIGTERM, terminate_handler)

finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()

do_terminate = False

proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True

sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
Expand Down
2 changes: 1 addition & 1 deletion server/text_generation_server/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def __init__(
}

world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK"), 0)
rank = int(os.getenv("RANK", "0"))
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"

Expand Down

0 comments on commit 8523f7e

Please sign in to comment.