generated from fofr/cog-face-to-many
-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathsafety_checker.py
52 lines (43 loc) · 1.66 KB
/
safety_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import subprocess

import numpy as np
import PIL
import PIL.Image
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from transformers import CLIPImageProcessor
# Local directory holding the CLIP image-processor config for the checker.
FEATURE_EXTRACTOR = "./feature-extractor"
# Directory where the downloaded safety-checker weights are unpacked.
SAFETY_CACHE = "./safety-cache"
# Tarball of safety-checker weights, fetched with `pget` on first run.
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
class SafetyChecker:
    """Runs Stable Diffusion's NSFW safety checker over a batch of output images."""

    def __init__(self):
        # Download and unpack the safety-checker weights on first use.
        if not os.path.exists(SAFETY_CACHE):
            subprocess.check_call(
                ["pget", "-xf", SAFETY_URL, SAFETY_CACHE],
                close_fds=False,
            )
        # fp16 on CUDA — clip_input below is cast to float16 to match.
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE, torch_dtype=torch.float16
        ).to("cuda")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

    def load_image(self, image_path):
        """Load an image from disk, normalized to RGB (drops alpha/palette modes)."""
        return PIL.Image.open(image_path).convert("RGB")

    def run(self, image_paths):
        """Classify each image for NSFW content.

        Args:
            image_paths: paths of images to check.

        Returns:
            A list of per-image booleans, True where NSFW was detected
            (empty list for an empty batch).

        Raises:
            Exception: when every image in a non-empty batch is flagged NSFW.
        """
        # Guard: `all([])` is True, so an empty batch would otherwise hit
        # the "all outputs NSFW" error (or crash in the feature extractor).
        if not image_paths:
            return []
        images = [self.load_image(image_path) for image_path in image_paths]
        safety_checker_input = self.feature_extractor(images, return_tensors="pt").to(
            "cuda"
        )
        np_images = [np.array(val) for val in images]
        _, is_nsfw = self.safety_checker(
            images=np_images,
            clip_input=safety_checker_input.pixel_values.to(torch.float16),
        )
        for i, nsfw in enumerate(is_nsfw):
            if nsfw:
                print(f"NSFW content detected in image {i}")
        if all(is_nsfw):
            raise Exception(
                "NSFW content detected in all outputs. Try running it again, or try a different prompt."
            )
        return is_nsfw