diff --git a/CybersecurityBenchmarks/benchmark/llm.py b/CybersecurityBenchmarks/benchmark/llm.py index bcffed42..4cb3cb63 100644 --- a/CybersecurityBenchmarks/benchmark/llm.py +++ b/CybersecurityBenchmarks/benchmark/llm.py @@ -24,6 +24,7 @@ NUM_LLM_RETRIES = 100 +MAX_RETRY_TIMEOUT = 600 MAX_TOKENS = 1000 @@ -135,6 +136,7 @@ def _query_with_retries( func: Callable[..., str], *args: Union[str, bool, Optional[str]], retries: int = NUM_LLM_RETRIES, + max_retry_timeout: int = MAX_RETRY_TIMEOUT, backoff_factor: float = 0.5, ) -> str: last_exception = None @@ -145,8 +147,8 @@ def _query_with_retries( last_exception = exception sleep_time = ( backoff_factor * (2**retry) - if backoff_factor * (2**retry) <= 600 - else 600 + if backoff_factor * (2**retry) <= max_retry_timeout + else max_retry_timeout ) time.sleep(sleep_time) LOG.error( @@ -161,6 +163,7 @@ def _chat_with_retries( func: Callable[..., str], *args: Union[List[str], bool, Optional[str]], retries: int = NUM_LLM_RETRIES, + max_retry_timeout: int = MAX_RETRY_TIMEOUT, backoff_factor: float = 0.5, ) -> str: last_exception = None @@ -171,8 +174,8 @@ def _chat_with_retries( last_exception = exception sleep_time = ( backoff_factor * (2**retry) - if backoff_factor * (2**retry) <= 600 - else 600 + if backoff_factor * (2**retry) <= max_retry_timeout + else max_retry_timeout ) time.sleep(sleep_time) LOG.error( @@ -332,6 +335,7 @@ def query_multimodal_with_retries( image_path: Optional[str] = None, audio_path: Optional[str] = None, num_retries: int = NUM_LLM_RETRIES, + max_retry_timeout: int = MAX_RETRY_TIMEOUT, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE, top_p: float = DEFAULT_TOP_P, @@ -393,8 +397,8 @@ def query_multimodal_with_retries( last_exception = exception sleep_time = ( backoff_factor * (2**retry) - if backoff_factor * (2**retry) <= 600 - else 600 + if backoff_factor * (2**retry) <= max_retry_timeout + else max_retry_timeout ) time.sleep(sleep_time) LOG.error(