Skip to content

Commit

Permalink
Fix envvar plumbing when setting args to false (#739)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tgaddair authored Jan 13, 2025
1 parent 0f91a7e commit 095e892
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
6 changes: 4 additions & 2 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,12 +818,14 @@ fn shard_manager(

// Prefix caching
if let Some(prefix_caching) = prefix_caching {
- envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into()));
+ let prefix_caching = if prefix_caching { "1" } else { "0" };
+ envs.push(("PREFIX_CACHING".into(), prefix_caching.into()));
}

// Chunked prefill
if let Some(chunked_prefill) = chunked_prefill {
- envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into()));
+ let chunked_prefill = if chunked_prefill { "1" } else { "0" };
+ envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.into()));
}

// Compile max batch size and rank
Expand Down
4 changes: 2 additions & 2 deletions server/lorax_server/utils/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
try:
import punica_kernels as _kernels

- HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
+ HAS_SGMV = not bool(int(os.environ.get("DISABLE_SGMV", "0")))
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
_kernels = None
HAS_SGMV = False


- LORAX_PUNICA_TRITON_DISABLED = bool(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", ""))
+ LORAX_PUNICA_TRITON_DISABLED = bool(int(os.environ.get("LORAX_PUNICA_TRITON_DISABLED", "0")))
if LORAX_PUNICA_TRITON_DISABLED:
logger.info("LORAX_PUNICA_TRITON_DISABLED is set, disabling Punica Trion kernels.")

Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/utils/sources/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[


def get_hub_api(token: Optional[str] = None) -> HfApi:
- if token == "" and bool(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", 0)):
+ if token == "" and bool(int(os.environ.get("LORAX_USE_GLOBAL_HF_TOKEN", "0"))):
# User initialized LoRAX to fallback to global HF token if request token is empty
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
return HfApi(token=token)
6 changes: 3 additions & 3 deletions server/lorax_server/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None)
- PREFIX_CACHING = bool(os.environ.get("PREFIX_CACHING", ""))
- CHUNKED_PREFILL = bool(os.environ.get("CHUNKED_PREFILL", ""))
+ PREFIX_CACHING = bool(int(os.environ.get("PREFIX_CACHING", "0")))
+ CHUNKED_PREFILL = bool(int(os.environ.get("CHUNKED_PREFILL", "0")))
LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32))

# Always use flashinfer when prefix caching is enabled
- FLASH_INFER = bool(os.environ.get("FLASH_INFER", "")) or PREFIX_CACHING
+ FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or PREFIX_CACHING
if FLASH_INFER:
logger.info("Backend = flashinfer")
else:
Expand Down

0 comments on commit 095e892

Please sign in to comment.