diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 7154c4449..f9df2cf69 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple import torch @@ -44,7 +45,8 @@ def __init__( self.batch_size = batch_size # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) - self.block_size = 64 + default_block_size = 16 if device.type == "cpu" else 64 + self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size))) self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1