Skip to content

Commit

Permalink
initialize
Browse files Browse the repository at this point in the history
  • Loading branch information
lxa9867 committed May 31, 2023
0 parents commit d91dd9e
Show file tree
Hide file tree
Showing 72 changed files with 10,748 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions .idea/remote-mappings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/rrvos-main.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Robust Referring VOS
![avatar](illustration.jpg)
Referring Video Object Segmentation (R-VOS) is a challenging task that aims to segment an object in a video based on a linguistic expression. Most existing R-VOS methods have a critical assumption: the object referred to must appear in the video. This assumption, which we refer to as "semantic consensus", is often violated in real-world scenarios, where the expression may be queried against false videos. In this work, we highlight the need for a robust R-VOS model that can handle semantic mismatches. Accordingly, we propose an extended task called Robust R-VOS, which accepts unpaired video-text inputs. We tackle this problem by jointly modeling the primary R-VOS problem and its dual (text reconstruction). A structural text-to-text cycle constraint is introduced to discriminate semantic consensus between video-text pairs and impose it in positive pairs, thereby achieving multi-modal alignment from both positive and negative pairs. Our structural constraint effectively addresses the challenge posed by linguistic diversity, overcoming the limitations of previous methods that relied on the point-wise constraint. A new evaluation dataset, R2-YTVOS is constructed to measure the model robustness. Our model achieves state-of-the-art performance on R-VOS benchmarks, Ref-DAVIS17 and Ref-Youtube-VOS, and also our R2-YTVOS dataset.

# Install

```
conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
pip install -r requirements.txt
pip install 'git+https://github.com/facebookresearch/fvcore'
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
cd models/ops
python setup.py build install
cd ../..
```

# Weight
Please download and put the [checkpoint.pth](https://drive.google.com/file/d/1gknDDMxWKqZ7yPuTh1fome1-Ba4f_G9K/view?usp=share_link) in the main folder.


# Run demo:
Inference on images in the demo/demo_examples.
```
python demo.py --with_box_refine --binary --freeze_text_encoder --output_dir=output/demo --resume=checkpoint.pth --backbone resnet50 --ngpu 1 --use_cycle --mix_query --neg_cls --is_eval --use_cls --demo_exp 'a big track on the road' --demo_path 'demo/demo_examples'
```


240 changes: 240 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
'''
Inference code for ReferFormer, on Ref-Youtube-VOS
Modified from DETR (https://github.com/facebookresearch/detr)
'''
import argparse
import json
import random
import sys
import time
from pathlib import Path

import numpy as np
import torch

import util.misc as utils
from models import build_model
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
import cv2
from PIL import Image, ImageDraw
import math
import torch.nn.functional as F
import json

import opts
from tqdm import tqdm

import multiprocessing as mp
import threading
import glob


from tools.colormap import colormap


# colormap
color_list = colormap()
color_list = color_list.astype('uint8').tolist()

def main(args):
args.masks = True
args.batch_size == 1
print("Inference only supports for batch size = 1")

global transform
transform = T.Compose([
T.Resize(args.inf_res),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# save path
output_dir = args.output_dir
save_path_prefix = os.path.join(output_dir)
if not os.path.exists(save_path_prefix):
os.makedirs(save_path_prefix)

global result_dict
result_dict = mp.Manager().dict()
frames = sorted(glob.glob(args.demo_path+'/*'))
sub_processor(0, args, args.demo_exp, frames, save_path_prefix)

result_dict = dict(result_dict)
num_all_frames_gpus = 0
for pid, num_all_frames in result_dict.items():
num_all_frames_gpus += num_all_frames

def sub_processor(pid, args, exp, frames, save_path_prefix):
torch.cuda.set_device(pid)

# model
model, criterion, _ = build_model(args)
device = args.device
model.to(device)

model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

if pid == 0:
print('number of params:', n_parameters)

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
if len(missing_keys) > 0:
print('Missing Keys: {}'.format(missing_keys))
if len(unexpected_keys) > 0:
print('Unexpected Keys: {}'.format(unexpected_keys))
else:
raise ValueError('Please specify the checkpoint for inference.')


# start inference
num_all_frames = 0
model.eval()

sentence_features = []
pseudo_sentence_features = []
video_name = 'demo'
# exp = meta[i]["exp"]
# # exp = 'a dog is with its puppies on the cloth'
# # TODO: temp
# frames = meta[i]["frames"]
# frames = [f'/home/mcg/ReferFormer/demo/frames_{fid}.jpg' for fid in range(1,2)]

video_len = len(frames)
# store images
imgs = []
for t in range(video_len):
frame = frames[t]
img_path = os.path.join(frame)
img = Image.open(img_path).convert('RGB')
origin_w, origin_h = img.size
imgs.append(transform(img)) # list[img]

imgs = torch.stack(imgs, dim=0).to(args.device) # [video_len, 3, h, w]
img_h, img_w = imgs.shape[-2:]
size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device)
target = {"size": size}

with torch.no_grad():
outputs = model([imgs], [exp], [target])

pred_logits = outputs["pred_logits"][0]
pred_boxes = outputs["pred_boxes"][0]
pred_masks = outputs["pred_masks"][0]
pred_ref_points = outputs["reference_points"][0]
text_sentence_features = outputs['sentence_feature']
if args.use_cycle:
pseudo_text_sentence_features = outputs['pseudo_sentence_feature']
# anchor = outputs['negative_anchor']
sentence_features.append(text_sentence_features)
pseudo_sentence_features.append(pseudo_text_sentence_features)
# print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2))
# print(anchor)
# according to pred_logits, select the query index
pred_scores = pred_logits.sigmoid() # [t, q, k]
pred_score = pred_scores
pred_scores = pred_scores.mean(0) # [q, k]
max_scores, _ = pred_scores.max(-1) # [q,]
# print(max_scores)
_, max_ind = max_scores.max(-1) # [1,]
max_inds = max_ind.repeat(video_len)
pred_masks = pred_masks[range(video_len), max_inds, ...] # [t, h, w]
pred_masks = pred_masks.unsqueeze(0)
pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
if args.save_prob:
pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy()
else:
pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy()
if args.use_score:
pred_score = pred_score[range(video_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1)
pred_masks *= (pred_score > 0.3).cpu().numpy() * pred_masks

# store the video results
all_pred_logits = pred_logits[range(video_len), max_inds].sigmoid().cpu().numpy()
all_pred_boxes = pred_boxes[range(video_len), max_inds]
all_pred_ref_points = pred_ref_points[range(video_len), max_inds]
all_pred_masks = pred_masks

save_path = os.path.join(save_path_prefix)
if not os.path.exists(save_path):
os.makedirs(save_path)
for j in range(video_len):
frame_name = frames[j]
confidence = all_pred_logits[j]
mask = all_pred_masks[j].astype(np.float32)
save_file = os.path.join(save_path, f"{j}" + ".png")
# print(save_file)
if 'pair_logits' in outputs.keys() and args.use_cls:
if outputs['pair_logits'].cpu().numpy() >= 0.5:
print('This is a negative pair, disalignment degree:', outputs['pair_logits'].cpu().numpy().item())
else:
print('This is a positive pair, disalignment degree:', outputs['pair_logits'].cpu().numpy().item())
mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1
mask = Image.fromarray(mask * 255).convert('L')
mask.save(save_file)
print(f'Results saved to {save_path}')
result_dict[str(pid)] = num_all_frames


# visuaize functions
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b


# Visualization functions
def draw_reference_points(draw, reference_points, img_size, color):
W, H = img_size
for i, ref_point in enumerate(reference_points):
init_x, init_y = ref_point
x, y = W * init_x, H * init_y
cur_color = color
draw.line((x-10, y, x+10, y), tuple(cur_color), width=4)
draw.line((x, y-10, x, y+10), tuple(cur_color), width=4)

def draw_sample_points(draw, sample_points, img_size, color_list):
alpha = 255
for i, samples in enumerate(sample_points):
for sample in samples:
x, y = sample
cur_color = color_list[i % len(color_list)][::-1]
cur_color += [alpha]
draw.ellipse((x-2, y-2, x+2, y+2),
fill=tuple(cur_color), outline=tuple(cur_color), width=1)

def vis_add_mask(img, mask, color):
origin_img = np.asarray(img.convert('RGB')).copy()
color = np.array(color)

mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np
mask = mask > 0.5

origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
origin_img = Image.fromarray(origin_img)
return origin_img



if __name__ == '__main__':
parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()])
args = parser.parse_args()
main(args)
Binary file added demo/demo_examples/00000.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00005.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00010.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00015.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00020.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00025.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00030.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00035.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00040.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00045.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00050.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00055.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00060.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00065.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00070.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00075.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00080.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00085.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/demo_examples/00090.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit d91dd9e

Please sign in to comment.