NSFW filter in preview v3

Code improvements
This commit is contained in:
tfrymnn 2023-06-16 19:47:05 +02:00 committed by GitHub
parent 5591d2a9d5
commit 3084ce9053
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,9 +10,9 @@ from roop.face_analyser import get_one_face
from roop.capturer import get_video_frame, get_video_frame_total
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import is_image, is_video, resolve_relative_path, get_temp_directory_path, create_temp
import numpy as np
import numpy
from PIL import Image
import opennsfw2 as n2
import opennsfw2
WINDOW_HEIGHT = 700
@ -205,43 +205,24 @@ def update_preview(frame_number: int = 0) -> None:
if roop.globals.source_path and roop.globals.target_path:
temp_frame = get_video_frame(roop.globals.target_path, frame_number)
if predict_temp_frame(temp_frame):
quit()
else:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
temp_frame = frame_processor.process_frame(
get_one_face(cv2.imread(roop.globals.source_path)),
temp_frame
)
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
image = ctk.CTkImage(image, size=image.size)
preview_label.configure(image=image)
quit()
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
temp_frame = frame_processor.process_frame(
get_one_face(cv2.imread(roop.globals.source_path)),
temp_frame
)
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
image = ctk.CTkImage(image, size=image.size)
preview_label.configure(image=image)
def predict_temp_frame(temp_frame: Image) -> bool:
# Load and preprocess image.
pil_image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
image = n2.preprocess_image(pil_image, n2.Preprocessing.YAHOO)
# The preprocessed image is a NumPy array of shape (224, 224, 3).
# Create the model.
# By default, this call will search for the pre-trained weights file from path:
# $HOME/.opennsfw2/weights/open_nsfw_weights.h5
# If not exists, the file will be downloaded from this repository.
# The model is a `tf.keras.Model` object.
model = n2.make_open_nsfw_model()
# Make predictions.
inputs = np.expand_dims(image, axis=0) # Add batch axis (for single image).
predictions = model.predict(inputs)
# The shape of predictions is (num_images, 2).
# Each row gives [sfw_probability, nsfw_probability] of an input image, e.g.:
sfw_probability, nsfw_probability = predictions[0]
if nsfw_probability > 0.85:
return True
else:
return False
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
model = opennsfw2.make_open_nsfw_model()
inputs = numpy.expand_dims(image, axis=0)
_, probability = model.predict(inputs)[0]
return probability > 0.85