Skip to content

Commit

Permalink
Fix autoscaler creating too many instances; have new instances self-delete if there is a problem
Browse files Browse the repository at this point in the history
  • Loading branch information
EricTendian committed Dec 5, 2024
1 parent 5fba7c5 commit c3d43ef
Show file tree
Hide file tree
Showing 5 changed files with 909 additions and 858 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,6 @@ docker-compose.override.yml
config/*.json
!config/transcript_cleanup.json
scratch

# Made by the Vast.ai CLI
gpu_names_cache.json
34 changes: 29 additions & 5 deletions app/bin/autoscale-vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,36 @@ def _update_running_instances(self, instances: list[dict]):
)
]

def _update_pending_instances(self, instances: list[dict]):
pending_instances = {}
for instance in list(
filter(
lambda i: i["next_state"] != "running"
and "deletion_reason" not in i,
instances,
)
):
hostname = None
concurrency = 0
for env in instance["extra_env"]:
key = env[0]
value = env[1]
if key == "CELERY_HOSTNAME":
hostname = value
continue
if key == "CELERY_CONCURRENCY":
concurrency = int(value)
continue
if hostname and concurrency:
pending_instances[hostname] = concurrency

self.pending_instances = pending_instances

def get_worker_status(self) -> list[dict]:
    """Query Celery for the stats of every worker currently online.

    Broadcasts an ``inspect().stats()`` with a 10 second timeout and
    returns one ``{"name": ..., "stats": ...}`` dict per worker that
    answered (an empty list if none did).

    Side effect: any worker that reports in is removed from
    ``self.pending_instances``, since it is evidently online now.
    """
    workers = []
    result = worker.celery.control.inspect(timeout=10).stats()
    if result:
        for name, stats in result.items():
            # The worker answered, so it is no longer pending; pop with a
            # default removes it in one lookup instead of a membership
            # check followed by `del`, and is a no-op if it was never pending.
            self.pending_instances.pop(name, None)
            workers.append({"name": name, "stats": stats})
    return workers

Expand Down Expand Up @@ -218,6 +240,7 @@ def get_current_instances(self) -> list[dict]:
)
)
self._update_running_instances(instances)
self._update_pending_instances(instances)
return instances

def create_instances(self, count: int) -> int:
Expand Down Expand Up @@ -294,8 +317,6 @@ def create_instances(self, count: int) -> int:
logging.info(
f"Started instance {instance_id}, a {instance['gpu_name']} for ${instance['dph_total'] if os.getenv('VAST_ONDEMAND') else body['price']}/hr"
)
# Add the instance to our list of pending instances so we can check when it comes online
self.pending_instances[self.envs["CELERY_HOSTNAME"]] = concurrency
# Update our other vars
self.running_instances.append(hostname)
instances_created += 1
Expand Down Expand Up @@ -387,6 +408,7 @@ def delete_instances(
)

self._update_running_instances(instances)
self._update_pending_instances(instances)

return len(deletable_instances)

Expand Down Expand Up @@ -443,6 +465,8 @@ def maybe_scale(self) -> int:
self.delete_instances(delete_exited=True, delete_errored=True)

current_instances, needed_instances = self.calculate_needed_instances()
current_instances += len(self.pending_instances)

target_instances = min(max(needed_instances, self.min), self.max)

if target_instances > current_instances:
Expand Down
27 changes: 19 additions & 8 deletions app/worker.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
#!/usr/bin/env python3

from hashlib import sha256
from multiprocessing.pool import AsyncResult
from typing import Optional
import json
import logging
import os
import signal
from hashlib import sha256
from multiprocessing.pool import AsyncResult
from typing import Optional

import requests
import sentry_sdk
from celery import Celery, signals, states
from celery.exceptions import Reject
from dotenv import load_dotenv
from sentry_sdk.integrations.celery import CeleryIntegration
import sentry_sdk

load_dotenv()

from app.utils.storage import fetch_audio
from app.whisper.base import TranscribeOptions, WhisperResult
from app.whisper.transcribe import transcribe
from app.geocoding.geocoding import lookup_geo
from app.models.metadata import Metadata
from app.notifications.notification import send_notifications
from app.search import search
from app.utils import api_client
from app.utils.exceptions import before_send
from app.search import search
from app.utils.storage import fetch_audio
from app.whisper.base import TranscribeOptions, WhisperResult
from app.whisper.exceptions import WhisperException
from app.whisper.task import API_IMPLEMENTATIONS, WhisperTask
from app.whisper.transcribe import transcribe

sentry_dsn = os.getenv("SENTRY_DSN")
if sentry_dsn:
Expand Down Expand Up @@ -96,6 +97,16 @@ def task_prerun(**kwargs): # type: ignore
logger.fatal(
"Exceeded job failure threshold, exiting...\n" + str(recent_job_results)
)
# If this is a vast.ai instance, delete itself since it must not be working properly
vast_api_key = os.getenv("CONTAINER_API_KEY")
vast_instance_id = os.getenv("CONTAINER_ID")
if vast_api_key and vast_instance_id:
logger.info("Deleting this vast.ai instance...")
requests.delete(
f"https://console.vast.ai/api/v0/instances/{vast_instance_id}/",
headers={"Authorization": f"Bearer {vast_api_key}"},
json={},
)
os.kill(
os.getppid(),
signal.SIGQUIT if hasattr(signal, "SIGQUIT") else signal.SIGTERM,
Expand Down
Loading

0 comments on commit c3d43ef

Please sign in to comment.