Introduce predicter

This commit is contained in:
henryruhs 2023-06-18 09:19:39 +02:00
parent b7762123fe
commit 460a20fa88
3 changed files with 30 additions and 4 deletions

View File

@ -16,10 +16,10 @@ import argparse
import torch
import onnxruntime
import tensorflow
from opennsfw2 import predict_video_frames, predict_image
import roop.globals
import roop.ui as ui
from roop.predicter import predict_image, predict_video
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
@ -164,7 +164,7 @@ def start() -> None:
return
# process image to image
if has_image_extension(roop.globals.target_path):
if predict_image(roop.globals.target_path) > 0.85:
if predict_image(roop.globals.target_path):
destroy()
# todo: this needs a temp path for images to work with multiple frame processors
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
@ -177,8 +177,7 @@ def start() -> None:
update_status('Processing to image failed!')
return
# process image to videos
seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100)
if any(probability > 0.85 for probability in probabilities):
if predict_video(roop.globals.target_path):
destroy()
update_status('Creating temp resources...')
create_temp(roop.globals.target_path)

24
roop/predicter.py Normal file
View File

@ -0,0 +1,24 @@
import numpy
import opennsfw2
from PIL import Image
from opennsfw2 import predict_video_frames
MAX_PROBABILITY = 0.85
def predict_frame(target_frame: Image) -> bool:
image = Image.fromarray(target_frame)
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
model = opennsfw2.make_open_nsfw_model()
views = numpy.expand_dims(image, axis=0)
_, probability = model.predict(views)[0]
return probability > MAX_PROBABILITY
def predict_image(target_path: str) -> bool:
return predict_image(target_path) > MAX_PROBABILITY
def predict_video(target_path: str) -> bool:
_, probabilities = predict_video_frames(video_path=target_path, frame_interval=100)
return any(probability > MAX_PROBABILITY for probability in probabilities)

View File

@ -8,6 +8,7 @@ from PIL import Image, ImageOps
import roop.globals
from roop.face_analyser import get_one_face
from roop.capturer import get_video_frame, get_video_frame_total
from roop.predicter import predict_frame
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import is_image, is_video, resolve_relative_path
@ -200,6 +201,8 @@ def init_preview() -> None:
def update_preview(frame_number: int = 0) -> None:
if roop.globals.source_path and roop.globals.target_path:
temp_frame = get_video_frame(roop.globals.target_path, frame_number)
if predict_frame(temp_frame):
quit()
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
temp_frame = frame_processor.process_frame(
get_one_face(cv2.imread(roop.globals.source_path)),