Commit

Fix inference bugs
VikParuchuri committed Apr 19, 2024
1 parent b5cc561 commit f82a06e
Showing 4 changed files with 82 additions and 55 deletions.
26 changes: 19 additions & 7 deletions surya/model/ordering/decoder.py
@@ -245,10 +245,9 @@ def __init__(self, config):
self.h_embed = nn.Embedding(config.max_height, config.d_model)
self.cx_embed = nn.Embedding(config.max_width, config.d_model)
self.cy_embed = nn.Embedding(config.max_height, config.d_model)
self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.d_model)

self.layernorm = nn.LayerNorm(config.d_model, eps=1e-5)

def forward(self, boxes: torch.LongTensor):
def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
x1, y1, x2, y2 = boxes.unbind(dim=-1)
# Shape is (batch_size, num_boxes/seq len, d_model)
w = x2 - x1
@@ -261,7 +260,15 @@ def forward(self, boxes: torch.LongTensor):

coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
embedded = self.layernorm(embedded)

# Add in positional embeddings for the boxes
if past_key_values_length == 0:
for j in range(embedded.shape[0]):
box_start = input_box_counts[j, 0]
box_end = input_box_counts[j, 1] - 1 # Skip the sep token
box_count = box_end - box_start
embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]

return embedded


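(Illustrative aside, not part of the commit: a minimal sketch of what the new per-box positional embeddings do, using toy shapes and a hypothetical input_box_counts holding [box_start, box_end] per sequence.)

import torch

d_model = 8
box_pos_embed = torch.randn(32, d_model)   # stand-in for the new nn.Embedding weight
embedded = torch.zeros(2, 6, d_model)      # stand-in for the summed coordinate embeddings

# Boxes are left-padded; each row is [box_start, box_end], with the sep token at box_end - 1.
input_box_counts = torch.tensor([[2, 6], [0, 6]])

for j in range(embedded.shape[0]):
    box_start = input_box_counts[j, 0]
    box_end = input_box_counts[j, 1] - 1    # skip the sep token
    box_count = box_end - box_start
    # Position 0 always lines up with the first real (non-pad) box in each sequence.
    embedded[j, box_start:box_end] += box_pos_embed[:box_count]
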
@@ -332,7 +339,7 @@ def forward(
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_boxes) * self.embed_scale
inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale

if self._use_flash_attention_2:
# 2d mask is passed through the layers
@@ -344,8 +351,14 @@
)

if past_key_values_length == 0:
box_ends = input_boxes_counts[:, 1]
box_starts = input_boxes_counts[:, 0]
input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
# Enable all boxes to attend to each other (before the sep token)
boxes_mask = torch.arange(input_shape[1], device=attention_mask.device)[None, :] < input_boxes_counts[:, None]
# Ensure that the boxes are not attending to the padding tokens
boxes_end_mask = input_shape_arranged < box_ends[:, None]
boxes_start_mask = input_shape_arranged >= box_starts[:, None]
boxes_mask = boxes_end_mask & boxes_start_mask
boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting
attention_mask = attention_mask.masked_fill(boxes_mask, 0)

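(Another non-diff aside: a worked toy example of the padding-aware box mask built above; the True positions cover the left-padded box block, and those entries of attention_mask are then filled with 0 so all boxes can attend to each other.)

import torch

seq_len = 6
# One sequence: 2 pad slots, then 3 boxes, then the sep token -> [box_start, box_end] = [2, 6].
input_boxes_counts = torch.tensor([[2, 6]])

positions = torch.arange(seq_len)[None, :]                          # shape (1, seq_len)
boxes_end_mask = positions < input_boxes_counts[:, 1][:, None]      # before the end of the box block
boxes_start_mask = positions >= input_boxes_counts[:, 0][:, None]   # past the left padding
boxes_mask = boxes_end_mask & boxes_start_mask
# boxes_mask -> [[False, False, True, True, True, True]]
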
@@ -482,7 +495,6 @@ def __init__(self, config):
self.model = MBartOrderDecoderWrapper(config)

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.output_scale = config.output_scale

# Initialize weights and apply final processing
self.post_init()
28 changes: 17 additions & 11 deletions surya/model/ordering/processor.py
@@ -16,10 +16,11 @@
def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
processor = OrderImageProcessor.from_pretrained(checkpoint)
processor.size = settings.ORDER_IMAGE_SIZE
box_size = 1000
processor.token_sep_id = 257 + box_size
processor.token_pad_id = 258 + box_size
processor.max_boxes = settings.ORDER_MAX_BOXES
box_size = 1024
max_tokens = 256
processor.token_sep_id = max_tokens + box_size + 1
processor.token_pad_id = max_tokens + box_size + 2
processor.max_boxes = settings.ORDER_MAX_BOXES - 1
processor.box_size = {"height": box_size, "width": box_size}
return processor

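(Side note on the revised special-token layout, not part of the diff: coordinates occupy 0-1024 and rank labels 0-255, so the sep/pad ids sit just above both ranges, and the rank token fed back during decoding (pred + box_size + 1 in surya/ordering.py below) cannot collide with them.)

box_size = 1024    # coordinates are rescaled into [0, box_size]
max_tokens = 256   # maximum number of rank labels the decoder can emit

token_sep_id = max_tokens + box_size + 1   # 1281
token_pad_id = max_tokens + box_size + 2   # 1282

# A predicted rank r is fed back as r + box_size + 1, i.e. at most 1280,
# which stays clear of coordinate values (<= 1024) and of the sep/pad ids (>= 1281).
assert max_tokens - 1 + box_size + 1 < token_sep_id
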
@@ -60,19 +61,24 @@ def process_inner(self, images: List[List]):
return np_images

def process_boxes(self, boxes):
max_boxes = max(len(b) for b in boxes) + 1
padded_boxes = []
box_masks = []
box_counts = []
for b in boxes:
# Left pad for generation
padded_b = deepcopy(b)
padded_b.append([self.token_sep_id] * 4) # Sep token to indicate start of label predictions
box_mask = [0] * (max_boxes - len(b)) + [1] * len(b)
padded_box = [[self.token_pad_id] * 4] * (max_boxes - len(b)) + b
padded_boxes.append(padded_box)
padded_boxes.append(padded_b)

max_boxes = max(len(b) for b in padded_boxes)
for i in range(len(padded_boxes)):
pad_len = max_boxes - len(padded_boxes[i])
box_len = len(padded_boxes[i])
box_mask = [0] * pad_len + [1] * box_len
padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
padded_boxes[i] = padded_box
box_masks.append(box_mask)
box_counts.append(len(b))
box_counts.append([pad_len, max_boxes])

return padded_boxes, box_masks, box_counts

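(Toy illustration of the reworked padding above, not part of the diff: each page's boxes first get a trailing sep token, then every page is left-padded to the longest length, and box_counts records [pad_len, max_boxes], i.e. where the real boxes start and where the padded sequence ends.)

from copy import deepcopy

token_sep_id, token_pad_id = 1281, 1282   # ids as computed in load_processor above

def pad_boxes(boxes):
    padded_boxes, box_masks, box_counts = [], [], []
    for b in boxes:
        padded_b = deepcopy(b)
        padded_b.append([token_sep_id] * 4)   # sep token closes the box block
        padded_boxes.append(padded_b)

    max_boxes = max(len(b) for b in padded_boxes)
    for i in range(len(padded_boxes)):
        pad_len = max_boxes - len(padded_boxes[i])
        box_masks.append([0] * pad_len + [1] * len(padded_boxes[i]))
        padded_boxes[i] = [[token_pad_id] * 4] * pad_len + padded_boxes[i]
        box_counts.append([pad_len, max_boxes])
    return padded_boxes, box_masks, box_counts

# One page with 1 box and one with 3: the shorter page is left-padded by 2 slots.
_, masks, counts = pad_boxes([[[0, 0, 5, 5]],
                              [[0, 0, 5, 5], [1, 1, 6, 6], [2, 2, 7, 7]]])
assert masks == [[0, 0, 1, 1], [1, 1, 1, 1]]
assert counts == [[2, 4], [0, 4]]
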
@@ -87,7 +93,7 @@ def resize_img_and_boxes(self, img, boxes):
width, height = orig_dim
box_width, box_height = self.box_size["width"], self.box_size["height"]
for box in boxes:
# Rescale to 0-1000
# Rescale to 0-1024
box[0] = box[0] / width * box_width
box[1] = box[1] / height * box_height
box[2] = box[2] / width * box_width
@@ -136,7 +142,7 @@ def preprocess(
new_images = []
new_boxes = []
for img, box in zip(images, boxes):
if len(box) > self.processor.max_boxes:
if len(box) > self.max_boxes:
raise ValueError(f"Too many boxes, max is {self.max_boxes}")
img, box = self.resize_img_and_boxes(img, box)
new_images.append(img)
79 changes: 44 additions & 35 deletions surya/ordering.py
@@ -57,14 +57,15 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)

with torch.inference_mode():
token_count = 0
past_key_values = None
encoder_outputs = None
batch_predictions = [[] for _ in range(len(batch_images))]
done = [False for _ in range(len(batch_images))]
while token_count < settings.ORDER_MAX_BOXES:
batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

token_count = 0
past_key_values = None
encoder_outputs = None
batch_predictions = [[] for _ in range(len(batch_images))]
done = [False for _ in range(len(batch_images))]
while token_count < settings.ORDER_MAX_BOXES:
with torch.inference_mode():
return_dict = model(
pixel_values=batch_pixel_values,
decoder_input_boxes=batch_bboxes,
@@ -75,38 +76,46 @@
)
logits = return_dict["logits"].detach().cpu()

last_tokens = []
last_token_mask = []
            for j in range(logits.shape[0]):
new_logits = logits[j].clone()
new_logits[batch_predictions[j]] = -1e9 # Mask out already predicted tokens, we can only predict each token once
pred = int(torch.argmax(logits[j], dim=-1).item())

last_tokens.append([pred] * 4)
if pred == processor.token_pad_id:
last_token_mask.append([0])
done[j] = True
else:
last_token_mask.append([1])
batch_predictions[j].append(pred - processor.box_size["height"]) # Get rank prediction for given position

# Break when we finished generating all sequences
if all(done):
break

past_key_values = return_dict["past_key_values"]
encoder_outputs = (return_dict["encoder_last_hidden_state"],)

batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
batch_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
token_count += 1
last_tokens = []
last_token_mask = []
for j in range(logits.shape[0]):
label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1 # Subtract 1 for the sep token
new_logits = logits[j, -1].clone()
new_logits[batch_predictions[j]] = -1e9 # Mask out already predicted tokens, we can only predict each token once
new_logits[label_count:] = -1e9 # Mask out all logit positions above the number of bboxes
pred = int(torch.argmax(new_logits, dim=-1).item())

# Add one to avoid colliding with the 1000 height/width token for bboxes
last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
if len(batch_predictions[j]) == label_count - 1: # Minus one since we're appending the final label
last_token_mask.append([0])
batch_predictions[j].append(pred)
done[j] = True
else:
last_token_mask.append([1])
batch_predictions[j].append(pred) # Get rank prediction for given position

# Break when we finished generating all sequences
if all(done):
break

past_key_values = return_dict["past_key_values"]
encoder_outputs = (return_dict["encoder_last_hidden_state"],)

batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
token_count += 1

for j, row_pred in enumerate(batch_predictions):
row_bboxes = bboxes[i+j]
assert len(row_pred) == len(row_bboxes), "Mismatch between logits and bboxes."
assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

orig_size = orig_sizes[j]
ranks = rank_elements(row_pred)
ranks = [0] * len(row_bboxes)

for box_idx in range(len(row_bboxes)):
ranks[row_pred[box_idx]] = box_idx

if labels is not None:
# This is to force headers/footers into the proper order
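(One more non-diff aside: the new ranks construction is a plain permutation inversion; if row_pred[t] == k then ranks[k] == t, sketched here with toy values.)

# Invert a predicted permutation: row_pred[t] == k  =>  ranks[k] == t.
row_pred = [2, 0, 1]            # toy prediction over 3 boxes

ranks = [0] * len(row_pred)
for box_idx in range(len(row_pred)):
    ranks[row_pred[box_idx]] = box_idx

assert ranks == [1, 2, 0]
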
4 changes: 2 additions & 2 deletions surya/settings.py
@@ -72,8 +72,8 @@ def TORCH_DEVICE_DETECTION(self) -> str:
LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"

# Ordering
ORDER_MODEL_CHECKPOINT: str = "vikp/order_hr"
ORDER_IMAGE_SIZE: Dict = {"height": 1280, "width": 1280}
ORDER_MODEL_CHECKPOINT: str = "vikp/surya_order"
ORDER_IMAGE_SIZE: Dict = {"height": 1024, "width": 1024}
ORDER_MAX_BOXES: int = 256
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
