Added Quality Scaler and Alert Box

This commit is contained in:
Lahfir 2023-07-20 10:55:40 +05:30
parent 86100e28da
commit 19c0948eb3
2 changed files with 242 additions and 68 deletions

View File

@ -2,11 +2,12 @@
import os
import sys
# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
if any(arg.startswith("--execution-provider") for arg in sys.argv):
os.environ["OMP_NUM_THREADS"] = "1"
# reduce tensorflow log level
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import warnings
from typing import List
import platform
@ -15,7 +16,8 @@ import shutil
import argparse
import torch
import onnxruntime
if not 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
if not "CUDAExecutionProvider" in onnxruntime.get_available_providers():
del torch
import tensorflow
@ -24,39 +26,136 @@ import roop.metadata
import roop.ui as ui
from roop.predictor import predict_image, predict_video
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
from roop.utilities import (
has_image_extension,
is_image,
is_video,
detect_fps,
create_video,
extract_frames,
get_temp_frame_paths,
restore_audio,
create_temp,
move_temp,
clean_temp,
normalize_output_path,
)
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
warnings.filterwarnings("ignore", category=FutureWarning, module="insightface")
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
def parse_args() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
program.add_argument('-s', '--source', help='select an source image', dest='source_path')
program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true')
program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true')
program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true')
program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true')
program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0)
program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0)
program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85)
program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int)
program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
program = argparse.ArgumentParser(
formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)
)
program.add_argument(
"-s", "--source", help="select an source image", dest="source_path"
)
program.add_argument(
"-t", "--target", help="select an target image or video", dest="target_path"
)
program.add_argument(
"-o", "--output", help="select output file or directory", dest="output_path"
)
program.add_argument(
"--frame-processor",
help="frame processors (choices: face_swapper, face_enhancer, ...)",
dest="frame_processor",
default=["face_swapper"],
nargs="+",
)
program.add_argument(
"--keep-fps", help="keep target fps", dest="keep_fps", action="store_true"
)
program.add_argument(
"--keep-frames",
help="keep temporary frames",
dest="keep_frames",
action="store_true",
)
program.add_argument(
"--skip-audio", help="skip target audio", dest="skip_audio", action="store_true"
)
program.add_argument(
"--many-faces",
help="process every face",
dest="many_faces",
action="store_true",
)
program.add_argument(
"--reference-face-position",
help="position of the reference face",
dest="reference_face_position",
type=int,
default=0,
)
program.add_argument(
"--reference-frame-number",
help="number of the reference frame",
dest="reference_frame_number",
type=int,
default=0,
)
program.add_argument(
"--similar-face-distance",
help="face distance used for recognition",
dest="similar_face_distance",
type=float,
default=0.85,
)
program.add_argument(
"--video-encoder",
help="adjust output video encoder",
dest="video_encoder",
default="libx264",
choices=["libx264", "libx265", "libvpx-vp9"],
)
program.add_argument(
"--video-quality",
help="adjust output video quality",
dest="video_quality",
type=int,
default=18,
choices=range(52),
metavar="[0-51]",
)
program.add_argument(
"--max-memory", help="maximum amount of RAM in GB", dest="max_memory", type=int
)
program.add_argument(
"--execution-provider",
help="available execution provider (choices: cpu, ...)",
dest="execution_provider",
default=["cpu"],
choices=suggest_execution_providers(),
nargs="+",
)
program.add_argument(
"--execution-threads",
help="number of execution threads",
dest="execution_threads",
type=int,
default=suggest_execution_threads(),
)
program.add_argument(
"-v",
"--version",
action="version",
version=f"{roop.metadata.name} {roop.metadata.version}",
)
args = program.parse_args()
roop.globals.source_path = args.source_path
roop.globals.target_path = args.target_path
roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path) # type: ignore
roop.globals.headless = roop.globals.source_path and roop.globals.target_path and roop.globals.output_path
roop.globals.headless = (
roop.globals.source_path
and roop.globals.target_path
and roop.globals.output_path
)
roop.globals.frame_processors = args.frame_processor
roop.globals.keep_fps = args.keep_fps
roop.globals.keep_frames = args.keep_frames
@ -68,17 +167,31 @@ def parse_args() -> None:
roop.globals.video_encoder = args.video_encoder
roop.globals.video_quality = args.video_quality
roop.globals.max_memory = args.max_memory
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
roop.globals.execution_providers = decode_execution_providers(
args.execution_provider
)
roop.globals.execution_threads = args.execution_threads
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
return [
execution_provider.replace("ExecutionProvider", "").lower()
for execution_provider in execution_providers
]
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
return [
provider
for provider, encoded_execution_provider in zip(
onnxruntime.get_available_providers(),
encode_execution_providers(onnxruntime.get_available_providers()),
)
if any(
execution_provider in encoded_execution_provider
for execution_provider in execution_providers
)
]
def suggest_execution_providers() -> List[str]:
@ -86,44 +199,55 @@ def suggest_execution_providers() -> List[str]:
def suggest_execution_threads() -> int:
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
return 8
return 1
def limit_resources() -> None:
# prevent tensorflow memory leak
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
gpus = tensorflow.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
])
tensorflow.config.experimental.set_virtual_device_configuration(
gpu,
[
tensorflow.config.experimental.VirtualDeviceConfiguration(
memory_limit=1024
)
],
)
# limit memory usage
if roop.globals.max_memory:
memory = roop.globals.max_memory * 1024 ** 3
if platform.system().lower() == 'darwin':
memory = roop.globals.max_memory * 1024 ** 6
if platform.system().lower() == 'windows':
memory = roop.globals.max_memory * 1024**3
if platform.system().lower() == "darwin":
memory = roop.globals.max_memory * 1024**6
if platform.system().lower() == "windows":
import ctypes
kernel32 = ctypes.windll.kernel32
kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
kernel32.SetProcessWorkingSetSize(
-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory)
)
else:
import resource
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
def pre_check() -> bool:
if sys.version_info < (3, 9):
update_status('Python version is not supported - please upgrade to 3.9 or higher.')
update_status(
"Python version is not supported - please upgrade to 3.9 or higher."
)
return False
if not shutil.which('ffmpeg'):
update_status('ffmpeg is not installed.')
if not shutil.which("ffmpeg"):
update_status("ffmpeg is not installed.")
return False
return True
def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
print(f'[{scope}] {message}')
def update_status(message: str, scope: str = "ROOP.CORE") -> None:
print(f"[{scope}] {message}")
if not roop.globals.headless:
ui.update_status(message)
@ -138,60 +262,66 @@ def start() -> None:
destroy()
shutil.copy2(roop.globals.target_path, roop.globals.output_path)
# process frame
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
for frame_processor in get_frame_processors_modules(
roop.globals.frame_processors
):
update_status("Progressing...", frame_processor.NAME)
frame_processor.process_image(
roop.globals.source_path,
roop.globals.output_path,
roop.globals.output_path,
)
frame_processor.post_process()
# validate image
if is_image(roop.globals.target_path):
update_status('Processing to image succeed!')
update_status("Processing to image succeed!")
else:
update_status('Processing to image failed!')
update_status("Processing to image failed!")
return
# process image to videos
if predict_video(roop.globals.target_path):
destroy()
update_status('Creating temp resources...')
update_status("Creating temp resources...")
create_temp(roop.globals.target_path)
# extract frames
if roop.globals.keep_fps:
fps = detect_fps(roop.globals.target_path)
update_status(f'Extracting frames with {fps} FPS...')
update_status(f"Extracting frames with {fps} FPS...")
extract_frames(roop.globals.target_path, fps)
else:
update_status('Extracting frames with 30 FPS...')
update_status("Extracting frames with 30 FPS...")
extract_frames(roop.globals.target_path)
# process frame
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME)
update_status("Progressing...", frame_processor.NAME)
frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
frame_processor.post_process()
# create video
if roop.globals.keep_fps:
fps = detect_fps(roop.globals.target_path)
update_status(f'Creating video with {fps} FPS...')
update_status(f"Creating video with {fps} FPS...")
create_video(roop.globals.target_path, fps)
else:
update_status('Creating video with 30 FPS...')
update_status("Creating video with 30 FPS...")
create_video(roop.globals.target_path)
# handle audio
if roop.globals.skip_audio:
move_temp(roop.globals.target_path, roop.globals.output_path)
update_status('Skipping audio...')
update_status("Skipping audio...")
else:
if roop.globals.keep_fps:
update_status('Restoring audio...')
update_status("Restoring audio...")
else:
update_status('Restoring audio might cause issues as fps are not kept...')
update_status("Restoring audio might cause issues as fps are not kept...")
restore_audio(roop.globals.target_path, roop.globals.output_path)
# clean temp
clean_temp(roop.globals.target_path)
# validate video
if is_video(roop.globals.target_path):
update_status('Processing to video succeed!')
update_status("Processing to video succeed!")
else:
update_status('Processing to video failed!')
update_status("Processing to video failed!")
def destroy() -> None:

View File

@ -37,6 +37,8 @@ preview_slider = None
source_label = None
target_label = None
status_label = None
video_quality_display = None
video_quality_value = None
# todo: remove by native support -> https://github.com/TomSchimansky/CustomTkinter/issues/934
@ -143,11 +145,16 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
)
many_faces_switch.place(relx=0.6, rely=0.65)
# Add a slider for video quality
video_quality_label = ctk.CTkLabel(root, text="Quality", justify="center")
video_quality_label.place(relx=0.1, rely=0.5, relwidth=0.3, relheight=0.1)
video_quality_label.place(relx=0.1, rely=0.7)
video_quality_value = ctk.IntVar(value=10) # Set a default value for video quality
video_quality_display = ctk.CTkLabel(
root, textvariable=video_quality_value, justify="center"
)
video_quality_display.place(relx=0.2, rely=0.7)
video_quality_value = ctk.IntVar(value=25) # Set a default value for video quality
video_quality_slider = ctk.CTkSlider(
root,
from_=0,
@ -155,22 +162,23 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
variable=video_quality_value,
cursor="hand2",
)
video_quality_slider.place(relx=0.4, rely=0.5, relwidth=0.9, relheight=0.01)
video_quality_slider.place(relx=0.09, rely=0.75, relwidth=0.8, relheight=0.03)
video_quality_slider.bind("<<Value>>", update_video_quality_value)
start_button = ctk.CTkButton(
root, text="Start", cursor="hand2", command=lambda: select_output_path(start)
)
start_button.place(relx=0.15, rely=0.75, relwidth=0.2, relheight=0.05)
start_button.place(relx=0.15, rely=0.83, relwidth=0.2, relheight=0.05)
stop_button = ctk.CTkButton(
root, text="Destroy", cursor="hand2", command=lambda: destroy()
)
stop_button.place(relx=0.4, rely=0.75, relwidth=0.2, relheight=0.05)
stop_button.place(relx=0.4, rely=0.83, relwidth=0.2, relheight=0.05)
preview_button = ctk.CTkButton(
root, text="Preview", cursor="hand2", command=lambda: toggle_preview()
)
preview_button.place(relx=0.65, rely=0.75, relwidth=0.2, relheight=0.05)
preview_button.place(relx=0.65, rely=0.83, relwidth=0.2, relheight=0.05)
status_label = ctk.CTkLabel(root, text=None, justify="center")
status_label.place(relx=0.1, rely=0.9, relwidth=0.8)
@ -189,6 +197,34 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
return root
def update_video_quality_value() -> None:
video_quality = video_quality_value.get()
video_quality_display.configure(text=str(video_quality))
setattr(roop.globals, "video_quality", video_quality)
ROOT.update()
def show_alert_box(title: str, message: str) -> None:
alert_window = ctk.CTkToplevel()
alert_window.title(title)
alert_width = 300
alert_height = 100
alert_window.geometry(f"{alert_width}x{alert_height}")
root_x = (ROOT_WIDTH - alert_width) // 2
root_y = (ROOT_HEIGHT - alert_height) // 2
alert_window.geometry(f"+{root_x}+{root_y}")
alert_window.minsize(alert_width, alert_height)
alert_label = ctk.CTkLabel(alert_window, text=message)
alert_label.pack(expand=True)
ok_button = ctk.CTkButton(alert_window, text="OK", command=alert_window.destroy)
ok_button.pack(pady=10)
def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
global preview_label, preview_slider
@ -282,6 +318,14 @@ def select_output_path(start: Callable[[], None]) -> None:
roop.globals.output_path = output_path
RECENT_DIRECTORY_OUTPUT = os.path.dirname(roop.globals.output_path)
start()
else:
# Check if source and target paths are provided
if not roop.globals.source_path:
show_alert_box("Missing Input", "Please select an input image.")
return
if not roop.globals.target_path:
show_alert_box("Missing Target", "Please select a target image or video.")
return
def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage: