Skip to content

Commit

Permalink
Pad to batch size with xla
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Feb 4, 2025
1 parent a2c0a22 commit 9dbdf6f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,27 @@ def TORCH_DEVICE_MODEL(self) -> str:

@computed_field
def DETECTOR_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_DETECTOR
return self.COMPILE_ALL or self.COMPILE_DETECTOR or self.TORCH_DEVICE_MODEL == "xla" # We need to static cache and pad to batch size for XLA, since it will recompile otherwise

@computed_field
def RECOGNITION_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_RECOGNITION
return self.COMPILE_ALL or self.COMPILE_RECOGNITION or self.TORCH_DEVICE_MODEL == "xla"

@computed_field
def LAYOUT_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_LAYOUT
return self.COMPILE_ALL or self.COMPILE_LAYOUT or self.TORCH_DEVICE_MODEL == "xla"

@computed_field
def TABLE_REC_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_TABLE_REC
return self.COMPILE_ALL or self.COMPILE_TABLE_REC or self.TORCH_DEVICE_MODEL == "xla"

@computed_field
def OCR_ERROR_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_OCR_ERROR
return self.COMPILE_ALL or self.COMPILE_OCR_ERROR or self.TORCH_DEVICE_MODEL == "xla"

@computed_field
def TEXIFY_STATIC_CACHE(self) -> bool:
return self.COMPILE_ALL or self.COMPILE_TEXIFY
return self.COMPILE_ALL or self.COMPILE_TEXIFY or self.TORCH_DEVICE_MODEL == "xla"

@computed_field
def MODEL_DTYPE(self) -> torch.dtype:
Expand Down

0 comments on commit 9dbdf6f

Please sign in to comment.