From d6d5bddd4f0f1764d47766351c7a82cf3baf02cf Mon Sep 17 00:00:00 2001 From: dan nelson Date: Fri, 17 Mar 2023 19:33:21 +0000 Subject: [PATCH] renaming and clarifying max_length parameter --- predict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/predict.py b/predict.py index 806b79ec..de407dcc 100644 --- a/predict.py +++ b/predict.py @@ -20,10 +20,10 @@ def predict( self, prompt: str = Input(description=f"Prompt to send to LLaMA."), n: int = Input(description="Number of output sequences to generate", default=1, ge=1, le=5), - max_length: int = Input( - description="Maximum number of tokens to generate. A word is generally 2-3 tokens", + total_tokens: int = Input( + description="Maximum number of tokens for input + generation. A word is generally 2-3 tokens", ge=1, - default=50 + default=2000 ), temperature: float = Input( description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value.", @@ -50,7 +50,7 @@ def predict( outputs = self.model.generate( input, num_return_sequences=n, - max_length=max_length, + max_length=total_tokens, do_sample=True, temperature=temperature, top_p=top_p,