Merge branch 'next' into issue_377

This commit is contained in:
Pozitronik 2023-06-19 19:47:01 +04:00
commit d6cd54b36c
9 changed files with 55 additions and 33 deletions

View File

@ -13,7 +13,9 @@ jobs:
with:
python-version: 3.9
- run: pip install flake8
- run: pip install mypy
- run: flake8 run.py roop
- run: mypy --config-file mypi.ini run.py roop
test:
runs-on: ubuntu-latest
steps:

7
mypi.ini Normal file
View File

@ -0,0 +1,7 @@
[mypy]
check_untyped_defs = True
disallow_any_generics = True
disallow_untyped_calls = True
disallow_untyped_defs = True
ignore_missing_imports = True
strict_optional = False

View File

@ -87,7 +87,7 @@ def parse_args() -> None:
roop.globals.execution_providers = decode_execution_providers(['cuda'])
if args.gpu_vendor_deprecated == 'amd':
print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m')
roop.globals.execution_threads = decode_execution_providers(['rocm'])
roop.globals.execution_providers = decode_execution_providers(['rocm'])
if args.gpu_threads_deprecated:
print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m')
roop.globals.execution_threads = args.gpu_threads_deprecated

View File

@ -1,7 +1,9 @@
from typing import List
source_path = None
target_path = None
output_path = None
frame_processors = []
frame_processors: List[str] = []
keep_fps = None
keep_audio = None
keep_frames = None
@ -9,7 +11,7 @@ many_faces = None
video_encoder = None
video_quality = None
max_memory = None
execution_providers = []
execution_providers: List[str] = []
execution_threads = None
headless = None
log_level = 'error'

View File

@ -1,13 +1,14 @@
import sys
import importlib
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List
from types import ModuleType
from typing import Any, List, Callable
from tqdm import tqdm
import roop
from roop import state
FRAME_PROCESSORS_MODULES = None
FRAME_PROCESSORS_MODULES: List[ModuleType] = []
FRAME_PROCESSORS_INTERFACE = [
'pre_check',
'pre_start',
@ -28,17 +29,17 @@ def load_frame_processor_module(frame_processor: str) -> Any:
return frame_processor_module
def get_frame_processors_modules(frame_processors):
def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType]:
global FRAME_PROCESSORS_MODULES
if FRAME_PROCESSORS_MODULES is None:
FRAME_PROCESSORS_MODULES = []
if not FRAME_PROCESSORS_MODULES:
for frame_processor in frame_processors:
frame_processor_module = load_frame_processor_module(frame_processor)
FRAME_PROCESSORS_MODULES.append(frame_processor_module)
return FRAME_PROCESSORS_MODULES
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames, progress) -> None:
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None:
with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor:
futures = []
for path in temp_frame_paths:
@ -48,7 +49,7 @@ def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_f
future.result()
def process_video(source_path: str, frame_paths: list[str], process_frames: Any) -> None:
def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = state.total_frames_count(roop.globals.target_path)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format, initial=state.processed_frames_count(roop.globals.target_path)) as progress:

View File

@ -19,7 +19,7 @@ NAME = 'ROOP.FACE-ENHANCER'
def pre_check() -> bool:
download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.3.pth'])
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
return True
@ -30,14 +30,14 @@ def pre_start() -> bool:
return True
def get_face_enhancer() -> None:
def get_face_enhancer() -> Any:
global FACE_ENHANCER
with THREAD_LOCK:
if FACE_ENHANCER is None:
model_path = resolve_relative_path('../models/GFPGANv1.3.pth')
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1)
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
return FACE_ENHANCER
@ -57,7 +57,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any:
return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None:
for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path)
result = process_frame(None, temp_frame)

View File

@ -35,7 +35,7 @@ def pre_start() -> bool:
return True
def get_face_swapper() -> None:
def get_face_swapper() -> Any:
global FACE_SWAPPER
with THREAD_LOCK:
@ -62,7 +62,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any:
return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None:
source_face = get_one_face(cv2.imread(source_path))
for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path)

View File

@ -12,16 +12,26 @@ from roop.predicter import predict_frame
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import is_image, is_video, resolve_relative_path
WINDOW_HEIGHT = 700
WINDOW_WIDTH = 600
ROOT = None
ROOT_HEIGHT = 700
ROOT_WIDTH = 600
PREVIEW = None
PREVIEW_MAX_HEIGHT = 700
PREVIEW_MAX_WIDTH = 1200
RECENT_DIRECTORY_SOURCE = None
RECENT_DIRECTORY_TARGET = None
RECENT_DIRECTORY_OUTPUT = None
preview_label = None
preview_slider = None
source_label = None
target_label = None
status_label = None
def init(start: Callable, destroy: Callable) -> ctk.CTk:
def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
global ROOT, PREVIEW
ROOT = create_root(start, destroy)
@ -30,14 +40,14 @@ def init(start: Callable, destroy: Callable) -> ctk.CTk:
return ROOT
def create_root(start: Callable, destroy: Callable) -> ctk.CTk:
def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
global source_label, target_label, status_label
ctk.deactivate_automatic_dpi_awareness()
ctk.set_appearance_mode('system')
ctk.set_default_color_theme(resolve_relative_path('ui.json'))
root = ctk.CTk()
root.minsize(WINDOW_WIDTH, WINDOW_HEIGHT)
root.minsize(ROOT_WIDTH, ROOT_HEIGHT)
root.title('roop')
root.configure()
root.protocol('WM_DELETE_WINDOW', lambda: destroy())
@ -85,7 +95,7 @@ def create_root(start: Callable, destroy: Callable) -> ctk.CTk:
return root
def create_preview(parent) -> ctk.CTkToplevel:
def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel:
global preview_label, preview_slider
preview = ctk.CTkToplevel(parent)
@ -143,7 +153,7 @@ def select_target_path() -> None:
target_label.configure(image=None)
def select_output_path(start):
def select_output_path(start: Callable[[], None]) -> None:
global RECENT_DIRECTORY_OUTPUT
if is_image(roop.globals.target_path):
@ -158,14 +168,14 @@ def select_output_path(start):
start()
def render_image_preview(image_path: str, size: Tuple[int, int] = None) -> ctk.CTkImage:
def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage:
image = Image.open(image_path)
if size:
image = ImageOps.fit(image, size, Image.LANCZOS)
return ctk.CTkImage(image, size=image.size)
def render_video_preview(video_path: str, size: Tuple[int, int] = None, frame_number: int = 0) -> ctk.CTkImage:
def render_video_preview(video_path: str, size: Tuple[int, int], frame_number: int = 0) -> ctk.CTkImage:
capture = cv2.VideoCapture(video_path)
if frame_number:
capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)

View File

@ -7,7 +7,7 @@ import ssl
import subprocess
import urllib
from pathlib import Path
from typing import List
from typing import List, Any
from tqdm import tqdm
import roop.globals
@ -45,7 +45,7 @@ def detect_fps(target_path: str) -> float:
def extract_frames(target_path: str) -> None:
temp_directory_path = get_temp_directory_path(target_path)
run_ffmpeg(['-i', target_path, os.path.join(temp_directory_path, '%04d.png')])
run_ffmpeg(['-i', target_path, '-pix_fmt', 'rgb24', '-sws_flags', '+accurate_rnd+full_chroma_int', '-colorspace', '1', '-color_primaries', '1', '-color_trc', '1', os.path.join(temp_directory_path, '%04d.png')])
def create_video(target_path: str, fps: float = 30.0) -> None:
@ -77,7 +77,7 @@ def get_temp_output_path(target_path: str) -> str:
return os.path.join(temp_directory_path, TEMP_FILE)
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> str:
def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any:
if source_path and target_path:
source_name, _ = os.path.splitext(os.path.basename(source_path))
target_name, target_extension = os.path.splitext(os.path.basename(target_path))
@ -115,14 +115,14 @@ def has_image_extension(image_path: str) -> bool:
def is_image(image_path: str) -> bool:
if image_path and os.path.isfile(image_path):
mimetype, _ = mimetypes.guess_type(image_path)
return mimetype and mimetype.startswith('image/')
return bool(mimetype and mimetype.startswith('image/'))
return False
def is_video(video_path: str) -> bool:
if video_path and os.path.isfile(video_path):
mimetype, _ = mimetypes.guess_type(video_path)
return mimetype and mimetype.startswith('video/')
return bool(mimetype and mimetype.startswith('video/'))
return False
@ -132,10 +132,10 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None:
for url in urls:
download_file_path = os.path.join(download_directory_path, os.path.basename(url))
if not os.path.exists(download_file_path):
request = urllib.request.urlopen(url)
request = urllib.request.urlopen(url) # type: ignore[attr-defined]
total = int(request.headers.get('Content-Length', 0))
with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress:
urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size))
urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined]
def resolve_relative_path(path: str) -> str: