mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
Introduce predicter
This commit is contained in:
parent
b7762123fe
commit
460a20fa88
@ -16,10 +16,10 @@ import argparse
|
||||
import torch
|
||||
import onnxruntime
|
||||
import tensorflow
|
||||
from opennsfw2 import predict_video_frames, predict_image
|
||||
|
||||
import roop.globals
|
||||
import roop.ui as ui
|
||||
from roop.predicter import predict_image, predict_video
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
|
||||
|
||||
@ -164,7 +164,7 @@ def start() -> None:
|
||||
return
|
||||
# process image to image
|
||||
if has_image_extension(roop.globals.target_path):
|
||||
if predict_image(roop.globals.target_path) > 0.85:
|
||||
if predict_image(roop.globals.target_path):
|
||||
destroy()
|
||||
# todo: this needs a temp path for images to work with multiple frame processors
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
@ -177,8 +177,7 @@ def start() -> None:
|
||||
update_status('Processing to image failed!')
|
||||
return
|
||||
# process image to videos
|
||||
seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100)
|
||||
if any(probability > 0.85 for probability in probabilities):
|
||||
if predict_video(roop.globals.target_path):
|
||||
destroy()
|
||||
update_status('Creating temp resources...')
|
||||
create_temp(roop.globals.target_path)
|
||||
|
||||
24
roop/predicter.py
Normal file
24
roop/predicter.py
Normal file
@ -0,0 +1,24 @@
|
||||
import numpy
|
||||
import opennsfw2
|
||||
from PIL import Image
|
||||
from opennsfw2 import predict_video_frames
|
||||
|
||||
MAX_PROBABILITY = 0.85
|
||||
|
||||
|
||||
def predict_frame(target_frame: Image) -> bool:
|
||||
image = Image.fromarray(target_frame)
|
||||
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
|
||||
model = opennsfw2.make_open_nsfw_model()
|
||||
views = numpy.expand_dims(image, axis=0)
|
||||
_, probability = model.predict(views)[0]
|
||||
return probability > MAX_PROBABILITY
|
||||
|
||||
|
||||
def predict_image(target_path: str) -> bool:
|
||||
return predict_image(target_path) > MAX_PROBABILITY
|
||||
|
||||
|
||||
def predict_video(target_path: str) -> bool:
|
||||
_, probabilities = predict_video_frames(video_path=target_path, frame_interval=100)
|
||||
return any(probability > MAX_PROBABILITY for probability in probabilities)
|
||||
@ -8,6 +8,7 @@ from PIL import Image, ImageOps
|
||||
import roop.globals
|
||||
from roop.face_analyser import get_one_face
|
||||
from roop.capturer import get_video_frame, get_video_frame_total
|
||||
from roop.predicter import predict_frame
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import is_image, is_video, resolve_relative_path
|
||||
|
||||
@ -200,6 +201,8 @@ def init_preview() -> None:
|
||||
def update_preview(frame_number: int = 0) -> None:
|
||||
if roop.globals.source_path and roop.globals.target_path:
|
||||
temp_frame = get_video_frame(roop.globals.target_path, frame_number)
|
||||
if predict_frame(temp_frame):
|
||||
quit()
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
temp_frame = frame_processor.process_frame(
|
||||
get_one_face(cv2.imread(roop.globals.source_path)),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user