mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
NSFW filter in preview v3
Code improvements
This commit is contained in:
parent
5591d2a9d5
commit
3084ce9053
55
roop/ui.py
55
roop/ui.py
@ -10,9 +10,9 @@ from roop.face_analyser import get_one_face
|
||||
from roop.capturer import get_video_frame, get_video_frame_total
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import is_image, is_video, resolve_relative_path, get_temp_directory_path, create_temp
|
||||
import numpy as np
|
||||
import numpy
|
||||
from PIL import Image
|
||||
import opennsfw2 as n2
|
||||
import opennsfw2
|
||||
|
||||
|
||||
WINDOW_HEIGHT = 700
|
||||
@ -205,43 +205,24 @@ def update_preview(frame_number: int = 0) -> None:
|
||||
if roop.globals.source_path and roop.globals.target_path:
|
||||
temp_frame = get_video_frame(roop.globals.target_path, frame_number)
|
||||
if predict_temp_frame(temp_frame):
|
||||
quit()
|
||||
else:
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
temp_frame = frame_processor.process_frame(
|
||||
get_one_face(cv2.imread(roop.globals.source_path)),
|
||||
temp_frame
|
||||
)
|
||||
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
||||
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
|
||||
image = ctk.CTkImage(image, size=image.size)
|
||||
preview_label.configure(image=image)
|
||||
quit()
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
temp_frame = frame_processor.process_frame(
|
||||
get_one_face(cv2.imread(roop.globals.source_path)),
|
||||
temp_frame
|
||||
)
|
||||
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
||||
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
|
||||
image = ctk.CTkImage(image, size=image.size)
|
||||
preview_label.configure(image=image)
|
||||
|
||||
|
||||
|
||||
|
||||
def predict_temp_frame(temp_frame: Image) -> bool:
|
||||
|
||||
# Load and preprocess image.
|
||||
pil_image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
||||
image = n2.preprocess_image(pil_image, n2.Preprocessing.YAHOO)
|
||||
# The preprocessed image is a NumPy array of shape (224, 224, 3).
|
||||
|
||||
# Create the model.
|
||||
# By default, this call will search for the pre-trained weights file from path:
|
||||
# $HOME/.opennsfw2/weights/open_nsfw_weights.h5
|
||||
# If not exists, the file will be downloaded from this repository.
|
||||
# The model is a `tf.keras.Model` object.
|
||||
model = n2.make_open_nsfw_model()
|
||||
|
||||
# Make predictions.
|
||||
inputs = np.expand_dims(image, axis=0) # Add batch axis (for single image).
|
||||
predictions = model.predict(inputs)
|
||||
|
||||
# The shape of predictions is (num_images, 2).
|
||||
# Each row gives [sfw_probability, nsfw_probability] of an input image, e.g.:
|
||||
sfw_probability, nsfw_probability = predictions[0]
|
||||
if nsfw_probability > 0.85:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
image = Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
||||
image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
|
||||
model = opennsfw2.make_open_nsfw_model()
|
||||
inputs = numpy.expand_dims(image, axis=0)
|
||||
_, probability = model.predict(inputs)[0]
|
||||
return probability > 0.85
|
||||
|
||||
Loading…
Reference in New Issue
Block a user