From 460a20fa885148fd634ca558d2142b539804a18e Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 18 Jun 2023 09:19:39 +0200 Subject: [PATCH] Introduce predicter --- roop/core.py | 7 +++---- roop/predicter.py | 24 ++++++++++++++++++++++++ roop/ui.py | 3 +++ 3 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 roop/predicter.py diff --git a/roop/core.py b/roop/core.py index aa73cb3..c8e9927 100755 --- a/roop/core.py +++ b/roop/core.py @@ -16,10 +16,10 @@ import argparse import torch import onnxruntime import tensorflow -from opennsfw2 import predict_video_frames, predict_image import roop.globals import roop.ui as ui +from roop.predicter import predict_image, predict_video from roop.processors.frame.core import get_frame_processors_modules from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path @@ -164,7 +164,7 @@ def start() -> None: return # process image to image if has_image_extension(roop.globals.target_path): - if predict_image(roop.globals.target_path) > 0.85: + if predict_image(roop.globals.target_path): destroy() # todo: this needs a temp path for images to work with multiple frame processors for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): @@ -177,8 +177,7 @@ def start() -> None: update_status('Processing to image failed!') return # process image to videos - seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100) - if any(probability > 0.85 for probability in probabilities): + if predict_video(roop.globals.target_path): destroy() update_status('Creating temp resources...') create_temp(roop.globals.target_path) diff --git a/roop/predicter.py b/roop/predicter.py new file mode 100644 index 0000000..0411ce3 --- /dev/null +++ b/roop/predicter.py @@ -0,0 +1,24 @@ +import numpy +import opennsfw2 +from PIL import Image +from opennsfw2 import predict_video_frames + +MAX_PROBABILITY = 0.85 + + +def predict_frame(target_frame: Image) -> bool: + image = Image.fromarray(target_frame) + image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO) + model = opennsfw2.make_open_nsfw_model() + views = numpy.expand_dims(image, axis=0) + _, probability = model.predict(views)[0] + return probability > MAX_PROBABILITY + + +def predict_image(target_path: str) -> bool: + return predict_image(target_path) > MAX_PROBABILITY + + +def predict_video(target_path: str) -> bool: + _, probabilities = predict_video_frames(video_path=target_path, frame_interval=100) + return any(probability > MAX_PROBABILITY for probability in probabilities) diff --git a/roop/ui.py b/roop/ui.py index 95133d1..134fb4a 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -8,6 +8,7 @@ from PIL import Image, ImageOps import roop.globals from roop.face_analyser import get_one_face from roop.capturer import get_video_frame, get_video_frame_total +from roop.predicter import predict_frame from roop.processors.frame.core import get_frame_processors_modules from roop.utilities import is_image, is_video, resolve_relative_path @@ -200,6 +201,8 @@ def init_preview() -> None: def update_preview(frame_number: int = 0) -> None: if roop.globals.source_path and roop.globals.target_path: temp_frame = get_video_frame(roop.globals.target_path, frame_number) + if predict_frame(temp_frame): + quit() for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): temp_frame = frame_processor.process_frame( get_one_face(cv2.imread(roop.globals.source_path)),