mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
Added Quality Scaler and Alert Box
This commit is contained in:
parent
86100e28da
commit
19c0948eb3
252
roop/core.py
252
roop/core.py
@ -2,11 +2,12 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# single thread doubles cuda performance - needs to be set before torch import
|
||||
if any(arg.startswith('--execution-provider') for arg in sys.argv):
|
||||
os.environ['OMP_NUM_THREADS'] = '1'
|
||||
if any(arg.startswith("--execution-provider") for arg in sys.argv):
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
# reduce tensorflow log level
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||
import warnings
|
||||
from typing import List
|
||||
import platform
|
||||
@ -15,7 +16,8 @@ import shutil
|
||||
import argparse
|
||||
import torch
|
||||
import onnxruntime
|
||||
if not 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
|
||||
|
||||
if not "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
del torch
|
||||
import tensorflow
|
||||
|
||||
@ -24,39 +26,136 @@ import roop.metadata
|
||||
import roop.ui as ui
|
||||
from roop.predictor import predict_image, predict_video
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
|
||||
from roop.utilities import (
|
||||
has_image_extension,
|
||||
is_image,
|
||||
is_video,
|
||||
detect_fps,
|
||||
create_video,
|
||||
extract_frames,
|
||||
get_temp_frame_paths,
|
||||
restore_audio,
|
||||
create_temp,
|
||||
move_temp,
|
||||
clean_temp,
|
||||
normalize_output_path,
|
||||
)
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
||||
warnings.filterwarnings("ignore", category=FutureWarning, module="insightface")
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
|
||||
def parse_args() -> None:
|
||||
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
|
||||
program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
|
||||
program.add_argument('-s', '--source', help='select an source image', dest='source_path')
|
||||
program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
|
||||
program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
|
||||
program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
|
||||
program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
|
||||
program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
|
||||
program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
|
||||
program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
|
||||
program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
|
||||
program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
|
||||
program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
|
||||
program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
|
||||
program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
|
||||
program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
|
||||
program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
|
||||
program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
|
||||
program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
|
||||
program = argparse.ArgumentParser(
|
||||
formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)
|
||||
)
|
||||
program.add_argument(
|
||||
"-s", "--source", help="select an source image", dest="source_path"
|
||||
)
|
||||
program.add_argument(
|
||||
"-t", "--target", help="select an target image or video", dest="target_path"
|
||||
)
|
||||
program.add_argument(
|
||||
"-o", "--output", help="select output file or directory", dest="output_path"
|
||||
)
|
||||
program.add_argument(
|
||||
"--frame-processor",
|
||||
help="frame processors (choices: face_swapper, face_enhancer, ...)",
|
||||
dest="frame_processor",
|
||||
default=["face_swapper"],
|
||||
nargs="+",
|
||||
)
|
||||
program.add_argument(
|
||||
"--keep-fps", help="keep target fps", dest="keep_fps", action="store_true"
|
||||
)
|
||||
program.add_argument(
|
||||
"--keep-frames",
|
||||
help="keep temporary frames",
|
||||
dest="keep_frames",
|
||||
action="store_true",
|
||||
)
|
||||
program.add_argument(
|
||||
"--skip-audio", help="skip target audio", dest="skip_audio", action="store_true"
|
||||
)
|
||||
program.add_argument(
|
||||
"--many-faces",
|
||||
help="process every face",
|
||||
dest="many_faces",
|
||||
action="store_true",
|
||||
)
|
||||
program.add_argument(
|
||||
"--reference-face-position",
|
||||
help="position of the reference face",
|
||||
dest="reference_face_position",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
program.add_argument(
|
||||
"--reference-frame-number",
|
||||
help="number of the reference frame",
|
||||
dest="reference_frame_number",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
program.add_argument(
|
||||
"--similar-face-distance",
|
||||
help="face distance used for recognition",
|
||||
dest="similar_face_distance",
|
||||
type=float,
|
||||
default=0.85,
|
||||
)
|
||||
program.add_argument(
|
||||
"--video-encoder",
|
||||
help="adjust output video encoder",
|
||||
dest="video_encoder",
|
||||
default="libx264",
|
||||
choices=["libx264", "libx265", "libvpx-vp9"],
|
||||
)
|
||||
program.add_argument(
|
||||
"--video-quality",
|
||||
help="adjust output video quality",
|
||||
dest="video_quality",
|
||||
type=int,
|
||||
default=18,
|
||||
choices=range(52),
|
||||
metavar="[0-51]",
|
||||
)
|
||||
program.add_argument(
|
||||
"--max-memory", help="maximum amount of RAM in GB", dest="max_memory", type=int
|
||||
)
|
||||
program.add_argument(
|
||||
"--execution-provider",
|
||||
help="available execution provider (choices: cpu, ...)",
|
||||
dest="execution_provider",
|
||||
default=["cpu"],
|
||||
choices=suggest_execution_providers(),
|
||||
nargs="+",
|
||||
)
|
||||
program.add_argument(
|
||||
"--execution-threads",
|
||||
help="number of execution threads",
|
||||
dest="execution_threads",
|
||||
type=int,
|
||||
default=suggest_execution_threads(),
|
||||
)
|
||||
program.add_argument(
|
||||
"-v",
|
||||
"--version",
|
||||
action="version",
|
||||
version=f"{roop.metadata.name} {roop.metadata.version}",
|
||||
)
|
||||
|
||||
args = program.parse_args()
|
||||
|
||||
roop.globals.source_path = args.source_path
|
||||
roop.globals.target_path = args.target_path
|
||||
roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path) # type: ignore
|
||||
roop.globals.headless = roop.globals.source_path and roop.globals.target_path and roop.globals.output_path
|
||||
roop.globals.headless = (
|
||||
roop.globals.source_path
|
||||
and roop.globals.target_path
|
||||
and roop.globals.output_path
|
||||
)
|
||||
roop.globals.frame_processors = args.frame_processor
|
||||
roop.globals.keep_fps = args.keep_fps
|
||||
roop.globals.keep_frames = args.keep_frames
|
||||
@ -68,17 +167,31 @@ def parse_args() -> None:
|
||||
roop.globals.video_encoder = args.video_encoder
|
||||
roop.globals.video_quality = args.video_quality
|
||||
roop.globals.max_memory = args.max_memory
|
||||
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
|
||||
roop.globals.execution_providers = decode_execution_providers(
|
||||
args.execution_provider
|
||||
)
|
||||
roop.globals.execution_threads = args.execution_threads
|
||||
|
||||
|
||||
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
|
||||
return [
|
||||
execution_provider.replace("ExecutionProvider", "").lower()
|
||||
for execution_provider in execution_providers
|
||||
]
|
||||
|
||||
|
||||
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
|
||||
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
|
||||
return [
|
||||
provider
|
||||
for provider, encoded_execution_provider in zip(
|
||||
onnxruntime.get_available_providers(),
|
||||
encode_execution_providers(onnxruntime.get_available_providers()),
|
||||
)
|
||||
if any(
|
||||
execution_provider in encoded_execution_provider
|
||||
for execution_provider in execution_providers
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def suggest_execution_providers() -> List[str]:
|
||||
@ -86,44 +199,55 @@ def suggest_execution_providers() -> List[str]:
|
||||
|
||||
|
||||
def suggest_execution_threads() -> int:
|
||||
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
return 8
|
||||
return 1
|
||||
|
||||
|
||||
def limit_resources() -> None:
|
||||
# prevent tensorflow memory leak
|
||||
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
|
||||
gpus = tensorflow.config.experimental.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
|
||||
tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
|
||||
])
|
||||
tensorflow.config.experimental.set_virtual_device_configuration(
|
||||
gpu,
|
||||
[
|
||||
tensorflow.config.experimental.VirtualDeviceConfiguration(
|
||||
memory_limit=1024
|
||||
)
|
||||
],
|
||||
)
|
||||
# limit memory usage
|
||||
if roop.globals.max_memory:
|
||||
memory = roop.globals.max_memory * 1024 ** 3
|
||||
if platform.system().lower() == 'darwin':
|
||||
memory = roop.globals.max_memory * 1024 ** 6
|
||||
if platform.system().lower() == 'windows':
|
||||
memory = roop.globals.max_memory * 1024**3
|
||||
if platform.system().lower() == "darwin":
|
||||
memory = roop.globals.max_memory * 1024**6
|
||||
if platform.system().lower() == "windows":
|
||||
import ctypes
|
||||
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
|
||||
kernel32.SetProcessWorkingSetSize(
|
||||
-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)
|
||||
)
|
||||
else:
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
|
||||
|
||||
|
||||
def pre_check() -> bool:
|
||||
if sys.version_info < (3, 9):
|
||||
update_status('Python version is not supported - please upgrade to 3.9 or higher.')
|
||||
update_status(
|
||||
"Python version is not supported - please upgrade to 3.9 or higher."
|
||||
)
|
||||
return False
|
||||
if not shutil.which('ffmpeg'):
|
||||
update_status('ffmpeg is not installed.')
|
||||
if not shutil.which("ffmpeg"):
|
||||
update_status("ffmpeg is not installed.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
|
||||
print(f'[{scope}] {message}')
|
||||
def update_status(message: str, scope: str = "ROOP.CORE") -> None:
|
||||
print(f"[{scope}] {message}")
|
||||
if not roop.globals.headless:
|
||||
ui.update_status(message)
|
||||
|
||||
@ -138,60 +262,66 @@ def start() -> None:
|
||||
destroy()
|
||||
shutil.copy2(roop.globals.target_path, roop.globals.output_path)
|
||||
# process frame
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
update_status('Progressing...', frame_processor.NAME)
|
||||
frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
|
||||
for frame_processor in get_frame_processors_modules(
|
||||
roop.globals.frame_processors
|
||||
):
|
||||
update_status("Progressing...", frame_processor.NAME)
|
||||
frame_processor.process_image(
|
||||
roop.globals.source_path,
|
||||
roop.globals.output_path,
|
||||
roop.globals.output_path,
|
||||
)
|
||||
frame_processor.post_process()
|
||||
# validate image
|
||||
if is_image(roop.globals.target_path):
|
||||
update_status('Processing to image succeed!')
|
||||
update_status("Processing to image succeed!")
|
||||
else:
|
||||
update_status('Processing to image failed!')
|
||||
update_status("Processing to image failed!")
|
||||
return
|
||||
# process image to videos
|
||||
if predict_video(roop.globals.target_path):
|
||||
destroy()
|
||||
update_status('Creating temp resources...')
|
||||
update_status("Creating temp resources...")
|
||||
create_temp(roop.globals.target_path)
|
||||
# extract frames
|
||||
if roop.globals.keep_fps:
|
||||
fps = detect_fps(roop.globals.target_path)
|
||||
update_status(f'Extracting frames with {fps} FPS...')
|
||||
update_status(f"Extracting frames with {fps} FPS...")
|
||||
extract_frames(roop.globals.target_path, fps)
|
||||
else:
|
||||
update_status('Extracting frames with 30 FPS...')
|
||||
update_status("Extracting frames with 30 FPS...")
|
||||
extract_frames(roop.globals.target_path)
|
||||
# process frame
|
||||
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
|
||||
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
|
||||
update_status('Progressing...', frame_processor.NAME)
|
||||
update_status("Progressing...", frame_processor.NAME)
|
||||
frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
|
||||
frame_processor.post_process()
|
||||
# create video
|
||||
if roop.globals.keep_fps:
|
||||
fps = detect_fps(roop.globals.target_path)
|
||||
update_status(f'Creating video with {fps} FPS...')
|
||||
update_status(f"Creating video with {fps} FPS...")
|
||||
create_video(roop.globals.target_path, fps)
|
||||
else:
|
||||
update_status('Creating video with 30 FPS...')
|
||||
update_status("Creating video with 30 FPS...")
|
||||
create_video(roop.globals.target_path)
|
||||
# handle audio
|
||||
if roop.globals.skip_audio:
|
||||
move_temp(roop.globals.target_path, roop.globals.output_path)
|
||||
update_status('Skipping audio...')
|
||||
update_status("Skipping audio...")
|
||||
else:
|
||||
if roop.globals.keep_fps:
|
||||
update_status('Restoring audio...')
|
||||
update_status("Restoring audio...")
|
||||
else:
|
||||
update_status('Restoring audio might cause issues as fps are not kept...')
|
||||
update_status("Restoring audio might cause issues as fps are not kept...")
|
||||
restore_audio(roop.globals.target_path, roop.globals.output_path)
|
||||
# clean temp
|
||||
clean_temp(roop.globals.target_path)
|
||||
# validate video
|
||||
if is_video(roop.globals.target_path):
|
||||
update_status('Processing to video succeed!')
|
||||
update_status("Processing to video succeed!")
|
||||
else:
|
||||
update_status('Processing to video failed!')
|
||||
update_status("Processing to video failed!")
|
||||
|
||||
|
||||
def destroy() -> None:
|
||||
|
||||
58
roop/ui.py
58
roop/ui.py
@ -37,6 +37,8 @@ preview_slider = None
|
||||
source_label = None
|
||||
target_label = None
|
||||
status_label = None
|
||||
video_quality_display = None
|
||||
video_quality_value = None
|
||||
|
||||
|
||||
# todo: remove by native support -> https://github.com/TomSchimansky/CustomTkinter/issues/934
|
||||
@ -143,11 +145,16 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
|
||||
)
|
||||
many_faces_switch.place(relx=0.6, rely=0.65)
|
||||
|
||||
# Add a slider for video quality
|
||||
video_quality_label = ctk.CTkLabel(root, text="Quality", justify="center")
|
||||
video_quality_label.place(relx=0.1, rely=0.5, relwidth=0.3, relheight=0.1)
|
||||
video_quality_label.place(relx=0.1, rely=0.7)
|
||||
|
||||
video_quality_value = ctk.IntVar(value=10) # Set a default value for video quality
|
||||
|
||||
video_quality_display = ctk.CTkLabel(
|
||||
root, textvariable=video_quality_value, justify="center"
|
||||
)
|
||||
video_quality_display.place(relx=0.2, rely=0.7)
|
||||
|
||||
video_quality_value = ctk.IntVar(value=25) # Set a default value for video quality
|
||||
video_quality_slider = ctk.CTkSlider(
|
||||
root,
|
||||
from_=0,
|
||||
@ -155,22 +162,23 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
|
||||
variable=video_quality_value,
|
||||
cursor="hand2",
|
||||
)
|
||||
video_quality_slider.place(relx=0.4, rely=0.5, relwidth=0.9, relheight=0.01)
|
||||
video_quality_slider.place(relx=0.09, rely=0.75, relwidth=0.8, relheight=0.03)
|
||||
video_quality_slider.bind("<<Value>>", update_video_quality_value)
|
||||
|
||||
start_button = ctk.CTkButton(
|
||||
root, text="Start", cursor="hand2", command=lambda: select_output_path(start)
|
||||
)
|
||||
start_button.place(relx=0.15, rely=0.75, relwidth=0.2, relheight=0.05)
|
||||
start_button.place(relx=0.15, rely=0.83, relwidth=0.2, relheight=0.05)
|
||||
|
||||
stop_button = ctk.CTkButton(
|
||||
root, text="Destroy", cursor="hand2", command=lambda: destroy()
|
||||
)
|
||||
stop_button.place(relx=0.4, rely=0.75, relwidth=0.2, relheight=0.05)
|
||||
stop_button.place(relx=0.4, rely=0.83, relwidth=0.2, relheight=0.05)
|
||||
|
||||
preview_button = ctk.CTkButton(
|
||||
root, text="Preview", cursor="hand2", command=lambda: toggle_preview()
|
||||
)
|
||||
preview_button.place(relx=0.65, rely=0.75, relwidth=0.2, relheight=0.05)
|
||||
preview_button.place(relx=0.65, rely=0.83, relwidth=0.2, relheight=0.05)
|
||||
|
||||
status_label = ctk.CTkLabel(root, text=None, justify="center")
|
||||
status_label.place(relx=0.1, rely=0.9, relwidth=0.8)
|
||||
@ -189,6 +197,34 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
|
||||
return root
|
||||
|
||||
|
||||
def update_video_quality_value() -> None:
|
||||
video_quality = video_quality_value.get()
|
||||
video_quality_display.configure(text=str(video_quality))
|
||||
setattr(roop.globals, "video_quality", video_quality)
|
||||
ROOT.update()
|
||||
|
||||
|
||||
def show_alert_box(title: str, message: str) -> None:
|
||||
alert_window = ctk.CTkToplevel()
|
||||
alert_window.title(title)
|
||||
|
||||
alert_width = 300
|
||||
alert_height = 100
|
||||
alert_window.geometry(f"{alert_width}x{alert_height}")
|
||||
|
||||
root_x = (ROOT_WIDTH - alert_width) // 2
|
||||
root_y = (ROOT_HEIGHT - alert_height) // 2
|
||||
alert_window.geometry(f"+{root_x}+{root_y}")
|
||||
|
||||
alert_window.minsize(alert_width, alert_height)
|
||||
|
||||
alert_label = ctk.CTkLabel(alert_window, text=message)
|
||||
alert_label.pack(expand=True)
|
||||
|
||||
ok_button = ctk.CTkButton(alert_window, text="OK", command=alert_window.destroy)
|
||||
ok_button.pack(pady=10)
|
||||
|
||||
|
||||
def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
|
||||
global preview_label, preview_slider
|
||||
|
||||
@ -282,6 +318,14 @@ def select_output_path(start: Callable[[], None]) -> None:
|
||||
roop.globals.output_path = output_path
|
||||
RECENT_DIRECTORY_OUTPUT = os.path.dirname(roop.globals.output_path)
|
||||
start()
|
||||
else:
|
||||
# Check if source and target paths are provided
|
||||
if not roop.globals.source_path:
|
||||
show_alert_box("Missing Input", "Please select an input image.")
|
||||
return
|
||||
if not roop.globals.target_path:
|
||||
show_alert_box("Missing Target", "Please select a target image or video.")
|
||||
return
|
||||
|
||||
|
||||
def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user