diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81cbd78..0b8f4ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,9 @@ jobs: with: python-version: 3.9 - run: pip install flake8 + - run: pip install mypy - run: flake8 run.py roop + - run: mypy --config-file mypi.ini run.py roop test: runs-on: ubuntu-latest steps: diff --git a/mypi.ini b/mypi.ini new file mode 100644 index 0000000..64218bc --- /dev/null +++ b/mypi.ini @@ -0,0 +1,7 @@ +[mypy] +check_untyped_defs = True +disallow_any_generics = True +disallow_untyped_calls = True +disallow_untyped_defs = True +ignore_missing_imports = True +strict_optional = False diff --git a/roop/core.py b/roop/core.py index f59a0e3..0bdbc81 100755 --- a/roop/core.py +++ b/roop/core.py @@ -87,7 +87,7 @@ def parse_args() -> None: roop.globals.execution_providers = decode_execution_providers(['cuda']) if args.gpu_vendor_deprecated == 'amd': print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m') - roop.globals.execution_threads = decode_execution_providers(['rocm']) + roop.globals.execution_providers = decode_execution_providers(['rocm']) if args.gpu_threads_deprecated: print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m') roop.globals.execution_threads = args.gpu_threads_deprecated diff --git a/roop/globals.py b/roop/globals.py index c66fef2..77fd391 100644 --- a/roop/globals.py +++ b/roop/globals.py @@ -1,7 +1,9 @@ +from typing import List + source_path = None target_path = None output_path = None -frame_processors = [] +frame_processors: List[str] = [] keep_fps = None keep_audio = None keep_frames = None @@ -9,7 +11,7 @@ many_faces = None video_encoder = None video_quality = None max_memory = None -execution_providers = [] +execution_providers: List[str] = [] execution_threads = None headless = None log_level = 'error' diff --git a/roop/processors/frame/core.py b/roop/processors/frame/core.py index a64293e..6118e4e 100644 --- a/roop/processors/frame/core.py +++ b/roop/processors/frame/core.py @@ -1,13 +1,14 @@ import sys import importlib from concurrent.futures import ThreadPoolExecutor -from typing import Any, List +from types import ModuleType +from typing import Any, List, Callable from tqdm import tqdm import roop from roop import state -FRAME_PROCESSORS_MODULES = None +FRAME_PROCESSORS_MODULES: List[ModuleType] = [] FRAME_PROCESSORS_INTERFACE = [ 'pre_check', 'pre_start', @@ -28,17 +29,17 @@ def load_frame_processor_module(frame_processor: str) -> Any: return frame_processor_module -def get_frame_processors_modules(frame_processors): +def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType]: global FRAME_PROCESSORS_MODULES - if FRAME_PROCESSORS_MODULES is None: - FRAME_PROCESSORS_MODULES = [] + + if not FRAME_PROCESSORS_MODULES: for frame_processor in frame_processors: frame_processor_module = load_frame_processor_module(frame_processor) FRAME_PROCESSORS_MODULES.append(frame_processor_module) return FRAME_PROCESSORS_MODULES -def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames, progress) -> None: +def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None: with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor: futures = [] for path in temp_frame_paths: @@ -48,7 +49,7 @@ def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_f future.result() -def process_video(source_path: str, frame_paths: list[str], process_frames: Any) -> None: +def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None: progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' total = state.total_frames_count(roop.globals.target_path) with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format, initial=state.processed_frames_count(roop.globals.target_path)) as progress: diff --git a/roop/processors/frame/face_enhancer.py b/roop/processors/frame/face_enhancer.py index c0eb6cf..1dc522c 100644 --- a/roop/processors/frame/face_enhancer.py +++ b/roop/processors/frame/face_enhancer.py @@ -19,7 +19,7 @@ NAME = 'ROOP.FACE-ENHANCER' def pre_check() -> bool: download_directory_path = resolve_relative_path('../models') - conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.3.pth']) + conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth']) return True @@ -30,14 +30,14 @@ def pre_start() -> bool: return True -def get_face_enhancer() -> None: +def get_face_enhancer() -> Any: global FACE_ENHANCER with THREAD_LOCK: if FACE_ENHANCER is None: - model_path = resolve_relative_path('../models/GFPGANv1.3.pth') + model_path = resolve_relative_path('../models/GFPGANv1.4.pth') # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399 - FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) + FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined] return FACE_ENHANCER @@ -57,7 +57,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any: return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None: +def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) result = process_frame(None, temp_frame) diff --git a/roop/processors/frame/face_swapper.py b/roop/processors/frame/face_swapper.py index 580a600..7a67060 100644 --- a/roop/processors/frame/face_swapper.py +++ b/roop/processors/frame/face_swapper.py @@ -35,7 +35,7 @@ def pre_start() -> bool: return True -def get_face_swapper() -> None: +def get_face_swapper() -> Any: global FACE_SWAPPER with THREAD_LOCK: @@ -62,7 +62,7 @@ def process_frame(source_face: Any, temp_frame: Any) -> Any: return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None: +def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: source_face = get_one_face(cv2.imread(source_path)) for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) diff --git a/roop/ui.py b/roop/ui.py index 134fb4a..d672234 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -12,16 +12,26 @@ from roop.predicter import predict_frame from roop.processors.frame.core import get_frame_processors_modules from roop.utilities import is_image, is_video, resolve_relative_path -WINDOW_HEIGHT = 700 -WINDOW_WIDTH = 600 +ROOT = None +ROOT_HEIGHT = 700 +ROOT_WIDTH = 600 + +PREVIEW = None PREVIEW_MAX_HEIGHT = 700 PREVIEW_MAX_WIDTH = 1200 + RECENT_DIRECTORY_SOURCE = None RECENT_DIRECTORY_TARGET = None RECENT_DIRECTORY_OUTPUT = None +preview_label = None +preview_slider = None +source_label = None +target_label = None +status_label = None -def init(start: Callable, destroy: Callable) -> ctk.CTk: + +def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk: global ROOT, PREVIEW ROOT = create_root(start, destroy) @@ -30,14 +40,14 @@ def init(start: Callable, destroy: Callable) -> ctk.CTk: return ROOT -def create_root(start: Callable, destroy: Callable) -> ctk.CTk: +def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk: global source_label, target_label, status_label ctk.deactivate_automatic_dpi_awareness() ctk.set_appearance_mode('system') ctk.set_default_color_theme(resolve_relative_path('ui.json')) root = ctk.CTk() - root.minsize(WINDOW_WIDTH, WINDOW_HEIGHT) + root.minsize(ROOT_WIDTH, ROOT_HEIGHT) root.title('roop') root.configure() root.protocol('WM_DELETE_WINDOW', lambda: destroy()) @@ -85,7 +95,7 @@ def create_root(start: Callable, destroy: Callable) -> ctk.CTk: return root -def create_preview(parent) -> ctk.CTkToplevel: +def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel: global preview_label, preview_slider preview = ctk.CTkToplevel(parent) @@ -143,7 +153,7 @@ def select_target_path() -> None: target_label.configure(image=None) -def select_output_path(start): +def select_output_path(start: Callable[[], None]) -> None: global RECENT_DIRECTORY_OUTPUT if is_image(roop.globals.target_path): @@ -158,14 +168,14 @@ def select_output_path(start): start() -def render_image_preview(image_path: str, size: Tuple[int, int] = None) -> ctk.CTkImage: +def render_image_preview(image_path: str, size: Tuple[int, int]) -> ctk.CTkImage: image = Image.open(image_path) if size: image = ImageOps.fit(image, size, Image.LANCZOS) return ctk.CTkImage(image, size=image.size) -def render_video_preview(video_path: str, size: Tuple[int, int] = None, frame_number: int = 0) -> ctk.CTkImage: +def render_video_preview(video_path: str, size: Tuple[int, int], frame_number: int = 0) -> ctk.CTkImage: capture = cv2.VideoCapture(video_path) if frame_number: capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number) diff --git a/roop/utilities.py b/roop/utilities.py index bed206c..aa70e7c 100644 --- a/roop/utilities.py +++ b/roop/utilities.py @@ -7,7 +7,7 @@ import ssl import subprocess import urllib from pathlib import Path -from typing import List +from typing import List, Any from tqdm import tqdm import roop.globals @@ -45,7 +45,7 @@ def detect_fps(target_path: str) -> float: def extract_frames(target_path: str) -> None: temp_directory_path = get_temp_directory_path(target_path) - run_ffmpeg(['-i', target_path, os.path.join(temp_directory_path, '%04d.png')]) + run_ffmpeg(['-i', target_path, '-pix_fmt', 'rgb24', '-sws_flags', '+accurate_rnd+full_chroma_int', '-colorspace', '1', '-color_primaries', '1', '-color_trc', '1', os.path.join(temp_directory_path, '%04d.png')]) def create_video(target_path: str, fps: float = 30.0) -> None: @@ -77,7 +77,7 @@ def get_temp_output_path(target_path: str) -> str: return os.path.join(temp_directory_path, TEMP_FILE) -def normalize_output_path(source_path: str, target_path: str, output_path: str) -> str: +def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Any: if source_path and target_path: source_name, _ = os.path.splitext(os.path.basename(source_path)) target_name, target_extension = os.path.splitext(os.path.basename(target_path)) @@ -115,14 +115,14 @@ def has_image_extension(image_path: str) -> bool: def is_image(image_path: str) -> bool: if image_path and os.path.isfile(image_path): mimetype, _ = mimetypes.guess_type(image_path) - return mimetype and mimetype.startswith('image/') + return bool(mimetype and mimetype.startswith('image/')) return False def is_video(video_path: str) -> bool: if video_path and os.path.isfile(video_path): mimetype, _ = mimetypes.guess_type(video_path) - return mimetype and mimetype.startswith('video/') + return bool(mimetype and mimetype.startswith('video/')) return False @@ -132,10 +132,10 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None: for url in urls: download_file_path = os.path.join(download_directory_path, os.path.basename(url)) if not os.path.exists(download_file_path): - request = urllib.request.urlopen(url) + request = urllib.request.urlopen(url) # type: ignore[attr-defined] total = int(request.headers.get('Content-Length', 0)) with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: - urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) + urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) # type: ignore[attr-defined] def resolve_relative_path(path: str) -> str: