mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
added optional target face
This commit is contained in:
parent
3d02b26766
commit
3fec88237a
16
roop/core.py
16
roop/core.py
@ -16,6 +16,7 @@ import argparse
|
||||
import torch
|
||||
import onnxruntime
|
||||
import tensorflow
|
||||
import cv2
|
||||
|
||||
import roop.globals
|
||||
import roop.metadata
|
||||
@ -23,6 +24,7 @@ import roop.ui as ui
|
||||
from roop.predicter import predict_image, predict_video
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
|
||||
from roop.face_analyser import get_one_face
|
||||
|
||||
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
||||
del torch
|
||||
@ -136,10 +138,24 @@ def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
|
||||
ui.update_status(message)
|
||||
|
||||
|
||||
def process_target_face() -> None:
|
||||
if roop.globals.target_face_path is None or roop.globals.many_faces:
|
||||
roop.globals.target_face = None
|
||||
else:
|
||||
update_status('Target face selected. Processing...')
|
||||
roop.globals.target_face = get_one_face(cv2.imread(roop.globals.target_face_path))
|
||||
if roop.globals.target_face is None:
|
||||
update_status('Failed to find face. Falling back to no target face')
|
||||
else:
|
||||
update_status('Successfully found target face')
|
||||
|
||||
|
||||
def start() -> None:
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
if not frame_processor.pre_start():
|
||||
return
|
||||
# process target face
|
||||
process_target_face()
|
||||
# process image to image
|
||||
if has_image_extension(roop.globals.target_path):
|
||||
if predict_image(roop.globals.target_path):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
import insightface
|
||||
import numpy
|
||||
|
||||
import roop.globals
|
||||
from roop.typing import Frame
|
||||
@ -25,6 +26,15 @@ def get_one_face(frame: Frame) -> Any:
|
||||
return min(face, key=lambda x: x.bbox[0])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_target_face(frame: Frame) -> Any:
|
||||
face = get_face_analyser().get(frame)
|
||||
try:
|
||||
target_embedding = roop.globals.target_face.embedding
|
||||
return max(face, key=lambda x: numpy.dot(target_embedding, x.embedding) / (numpy.linalg.norm(target_embedding) * numpy.linalg.norm(x.embedding)))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def get_many_faces(frame: Frame) -> Any:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
source_path = None
|
||||
target_face_path = None
|
||||
target_face = None
|
||||
target_path = None
|
||||
output_path = None
|
||||
frame_processors: List[str] = []
|
||||
|
||||
@ -6,7 +6,7 @@ import threading
|
||||
import roop.globals
|
||||
import roop.processors.frame.core
|
||||
from roop.core import update_status
|
||||
from roop.face_analyser import get_one_face, get_many_faces
|
||||
from roop.face_analyser import get_one_face, get_target_face, get_many_faces
|
||||
from roop.typing import Face, Frame
|
||||
from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
|
||||
|
||||
@ -61,7 +61,10 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
|
||||
for target_face in many_faces:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
else:
|
||||
target_face = get_one_face(temp_frame)
|
||||
if roop.globals.target_face is None:
|
||||
target_face = get_one_face(temp_frame)
|
||||
else:
|
||||
target_face = get_target_face(temp_frame)
|
||||
if target_face:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
return temp_frame
|
||||
|
||||
32
roop/ui.py
32
roop/ui.py
@ -28,6 +28,7 @@ RECENT_DIRECTORY_OUTPUT = None
|
||||
preview_label = None
|
||||
preview_slider = None
|
||||
source_label = None
|
||||
target_face_label = None
|
||||
target_label = None
|
||||
status_label = None
|
||||
|
||||
@ -42,7 +43,7 @@ def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
|
||||
|
||||
|
||||
def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
|
||||
global source_label, target_label, status_label
|
||||
global source_label, target_face_label, target_label, status_label
|
||||
|
||||
ctk.deactivate_automatic_dpi_awareness()
|
||||
ctk.set_appearance_mode('system')
|
||||
@ -55,16 +56,22 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
|
||||
root.protocol('WM_DELETE_WINDOW', lambda: destroy())
|
||||
|
||||
source_label = ctk.CTkLabel(root, text=None)
|
||||
source_label.place(relx=0.1, rely=0.1, relwidth=0.3, relheight=0.25)
|
||||
source_label.place(relx=0.1, rely=0.1, relwidth=0.23, relheight=0.25)
|
||||
|
||||
target_face_label = ctk.CTkLabel(root, text=None)
|
||||
target_face_label.place(relx=0.38, rely=0.1, relwidth=0.23, relheight=0.25)
|
||||
|
||||
target_label = ctk.CTkLabel(root, text=None)
|
||||
target_label.place(relx=0.6, rely=0.1, relwidth=0.3, relheight=0.25)
|
||||
target_label.place(relx=0.66, rely=0.1, relwidth=0.23, relheight=0.25)
|
||||
|
||||
source_button = ctk.CTkButton(root, text='Select a face', cursor='hand2', command=lambda: select_source_path())
|
||||
source_button.place(relx=0.1, rely=0.4, relwidth=0.3, relheight=0.1)
|
||||
source_button.place(relx=0.1, rely=0.4, relwidth=0.23, relheight=0.1)
|
||||
|
||||
target_face_button = ctk.CTkButton(root, text='Select a target face', cursor='hand2', command=lambda: select_target_face_path())
|
||||
target_face_button.place(relx=0.38, rely=0.4, relwidth=0.23, relheight=0.1)
|
||||
|
||||
target_button = ctk.CTkButton(root, text='Select a target', cursor='hand2', command=lambda: select_target_path())
|
||||
target_button.place(relx=0.6, rely=0.4, relwidth=0.3, relheight=0.1)
|
||||
target_button.place(relx=0.66, rely=0.4, relwidth=0.23, relheight=0.1)
|
||||
|
||||
keep_fps_value = ctk.BooleanVar(value=roop.globals.keep_fps)
|
||||
keep_fps_checkbox = ctk.CTkSwitch(root, text='Keep fps', variable=keep_fps_value, cursor='hand2', command=lambda: setattr(roop.globals, 'keep_fps', not roop.globals.keep_fps))
|
||||
@ -140,6 +147,21 @@ def select_source_path() -> None:
|
||||
source_label.configure(image=None)
|
||||
|
||||
|
||||
def select_target_face_path() -> None:
|
||||
global RECENT_DIRECTORY_SOURCE
|
||||
|
||||
PREVIEW.withdraw()
|
||||
target_face_path = ctk.filedialog.askopenfilename(title='select a target image', initialdir=RECENT_DIRECTORY_SOURCE)
|
||||
if is_image(target_face_path):
|
||||
roop.globals.target_face_path = target_face_path
|
||||
RECENT_DIRECTORY_SOURCE = os.path.dirname(target_face_path)
|
||||
image = render_image_preview(target_face_path, (200, 200))
|
||||
target_face_label.configure(image=image)
|
||||
else:
|
||||
roop.globals.target_face_path = None
|
||||
target_face_label.configure(image=None)
|
||||
|
||||
|
||||
def select_target_path() -> None:
|
||||
global RECENT_DIRECTORY_TARGET
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user