Extract common methods to processors.frame.core

This commit is contained in:
henryruhs 2023-06-13 18:19:16 +02:00
parent e862700116
commit 11e641937c
3 changed files with 41 additions and 56 deletions

View File

@ -1,6 +1,11 @@
import sys import sys
import importlib import importlib
from typing import Any import threading
from typing import Any, List
from tqdm import tqdm
import roop
FRAME_PROCESSORS_MODULES = None FRAME_PROCESSORS_MODULES = None
@ -22,3 +27,31 @@ def get_frame_processors_modules(frame_processors):
FRAME_PROCESSORS_MODULES.append(frame_processor_module) FRAME_PROCESSORS_MODULES.append(frame_processor_module)
return FRAME_PROCESSORS_MODULES return FRAME_PROCESSORS_MODULES
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames, progress) -> None:
threads = []
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
start_index = 0
# create threads by frames
for _ in range(roop.globals.execution_threads):
end_index = start_index + frames_per_thread
if remaining_frames > 0:
end_index += 1
remaining_frames -= 1
thread_paths = temp_frame_paths[start_index:end_index]
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
threads.append(thread)
thread.start()
start_index = end_index
# join threads
for thread in threads:
thread.join()
def process_video(source_path: str, frame_paths: list[str], process_frames: Any) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = len(frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
progress.set_postfix({'execution_providers': roop.globals.execution_providers, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
multi_process_frame(source_path, frame_paths, process_frames, progress)

View File

@ -1,13 +1,15 @@
from typing import List
import cv2 import cv2
import torch import torch
import threading import threading
from tqdm import tqdm
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
from codeformer.basicsr.utils.registry import ARCH_REGISTRY from codeformer.basicsr.utils.registry import ARCH_REGISTRY
from codeformer.basicsr.utils import img2tensor, tensor2img from codeformer.basicsr.utils import img2tensor, tensor2img
import roop.globals import roop.globals
import roop.processors.frame.core
from roop.utilities import conditional_download, resolve_relative_path from roop.utilities import conditional_download, resolve_relative_path
if 'ROCMExecutionProvider' in roop.globals.execution_providers: if 'ROCMExecutionProvider' in roop.globals.execution_providers:
@ -136,30 +138,5 @@ def process_frames(source_path: str, frame_paths: list[str], progress=None) -> N
progress.update(1) progress.update(1)
def multi_process_frame(source_img, frame_paths, progress) -> None: def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
threads = [] roop.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames)
frames_per_thread = len(frame_paths) // roop.globals.execution_threads
remaining_frames = len(frame_paths) % roop.globals.execution_threads
start_index = 0
# create threads by frames
for _ in range(roop.globals.execution_threads):
end_index = start_index + frames_per_thread
if remaining_frames > 0:
end_index += 1
remaining_frames -= 1
thread_frame_paths = frame_paths[start_index:end_index]
thread = threading.Thread(target=process_frames, args=(source_img, thread_frame_paths, progress))
threads.append(thread)
thread.start()
start_index = end_index
# join threads
for thread in threads:
thread.join()
def process_video(source_path: str, frame_paths: list[str]) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = len(frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
progress.set_postfix({'execution_providers': roop.globals.execution_providers, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
multi_process_frame(source_path, frame_paths, progress)

View File

@ -1,10 +1,10 @@
from typing import Any, List from typing import Any, List
from tqdm import tqdm
import cv2 import cv2
import insightface import insightface
import threading import threading
import roop.globals import roop.globals
import roop.processors.frame.core
from roop.face_analyser import get_one_face, get_many_faces from roop.face_analyser import get_one_face, get_many_faces
from roop.utilities import conditional_download, resolve_relative_path from roop.utilities import conditional_download, resolve_relative_path
@ -60,27 +60,6 @@ def process_frames(source_path: str, temp_frame_paths: List[str], progress=None)
progress.update(1) progress.update(1)
def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None:
threads = []
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
start_index = 0
# create threads by frames
for _ in range(roop.globals.execution_threads):
end_index = start_index + frames_per_thread
if remaining_frames > 0:
end_index += 1
remaining_frames -= 1
thread_paths = temp_frame_paths[start_index:end_index]
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
threads.append(thread)
thread.start()
start_index = end_index
# join threads
for thread in threads:
thread.join()
def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_image(source_path: str, target_path: str, output_path: str) -> None:
source_face = get_one_face(cv2.imread(source_path)) source_face = get_one_face(cv2.imread(source_path))
target_frame = cv2.imread(target_path) target_frame = cv2.imread(target_path)
@ -89,8 +68,4 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None:
def process_video(source_path: str, temp_frame_paths: List[str]) -> None: def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' roop.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames)
total = len(temp_frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
progress.set_postfix({'execution_providers': roop.globals.execution_providers, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
multi_process_frame(source_path, temp_frame_paths, progress)