Commit

add mark step for recognition model [skip ci]
iammosespaulr committed Jan 31, 2025
1 parent 4bed591 commit c5650a7
Showing 1 changed file with 12 additions and 0 deletions.
surya/recognition/__init__.py (12 additions, 0 deletions)
@@ -7,6 +7,7 @@
 from tqdm import tqdm
 import torch.nn.functional as F

+from surya.common.util import mark_step
 from surya.common.predictor import BasePredictor
 from surya.detection import DetectionPredictor
 from surya.input.processing import convert_if_not_rgb, slice_polys_from_image, slice_bboxes_from_image
@@ -249,6 +250,7 @@ def batch_recognition(

 with settings.INFERENCE_MODE():
     encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state
+    mark_step()

     text_encoder_input_ids = torch.arange(
         self.model.text_encoder.config.query_token_count,
@@ -264,6 +266,7 @@
     encoder_attention_mask=None,
     use_cache=False
 ).hidden_states
+mark_step()
 del encoder_hidden_states

 if settings.RECOGNITION_STATIC_CACHE:
@@ -281,6 +284,7 @@
     use_cache=True,
     prefill=is_prefill
 )
+mark_step()

 decoder_position_ids = decoder_position_ids[-1:] + 1
 logits = return_dict["logits"][:current_batch_size] # Ignore batch padding
@@ -296,27 +300,35 @@
 scores = scores.masked_fill(all_done, 0)
 sequence_scores = torch.cat([sequence_scores, scores], dim=1)

+mark_step()
 if all_done.all():
     break

 batch_decoder_input = preds.unsqueeze(1)

 for j, (pred, status) in enumerate(zip(preds, all_done)):
+    mark_step()
     if not status:
         batch_predictions[j].append(int(pred))

 token_count += inference_token_count
+
+mark_step()
 inference_token_count = batch_decoder_input.shape[-1]
+
+mark_step()
 max_position_id = torch.max(decoder_position_ids).item()
 decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64,
                                        device=self.model.device).cumsum(0) - 1 + max_position_id

 if settings.RECOGNITION_STATIC_CACHE:
     batch_decoder_input = self.pad_to_batch_size(batch_decoder_input, batch_size)

+mark_step()
 sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
 detected_text = self.processor.tokenizer.batch_decode(batch_predictions)

+mark_step()
 # Convert sequence_scores to list for the current batch
 batch_confidences = sequence_scores.tolist()
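
For context: mark_step is imported from surya.common.util, which is not part of this diff. On lazy-execution backends such as torch_xla, a step marker cuts the pending lazy-tensor graph and dispatches it to the device, so calling it after the image encoder, after the text encoder, and inside the token-by-token decode loop keeps each traced graph small and bounded instead of accumulating the whole loop into one graph. A minimal sketch of what such a helper typically looks like, assuming it wraps torch_xla and no-ops on eager devices (the actual surya implementation is not shown here and may differ):

# Hypothetical sketch of surya.common.util.mark_step (not part of this commit).
try:
    import torch_xla.core.xla_model as xm  # present only when running on XLA
except ImportError:
    xm = None


def mark_step():
    """Flush the pending lazy-tensor graph on XLA devices; no-op on CPU/CUDA."""
    if xm is not None:
        xm.mark_step()

Under that reading, the calls added after each decoder forward pass bound how much work the XLA compiler has to trace per generated token, at the cost of more frequent graph launches; on eager backends they cost nothing.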

