From b104741e689daea5180a083cb8c4f2f6fcf83e52 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 30 Jun 2023 09:04:33 +0200 Subject: [PATCH] Improve return typing --- roop/capturer.py | 6 ++++-- roop/face_analyser.py | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/roop/capturer.py b/roop/capturer.py index fd49d46..515fc8e 100644 --- a/roop/capturer.py +++ b/roop/capturer.py @@ -1,8 +1,10 @@ -from typing import Any +from typing import Optional import cv2 +from roop.typing import Frame -def get_video_frame(video_path: str, frame_number: int = 0) -> Any: + +def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]: capture = cv2.VideoCapture(video_path) frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT) capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) diff --git a/roop/face_analyser.py b/roop/face_analyser.py index 9c0afe4..287b3f3 100644 --- a/roop/face_analyser.py +++ b/roop/face_analyser.py @@ -1,9 +1,9 @@ import threading -from typing import Any +from typing import Any, Optional, List import insightface import roop.globals -from roop.typing import Frame +from roop.typing import Frame, Face FACE_ANALYSER = None THREAD_LOCK = threading.Lock() @@ -19,15 +19,15 @@ def get_face_analyser() -> Any: return FACE_ANALYSER -def get_one_face(frame: Frame) -> Any: - face = get_face_analyser().get(frame) +def get_one_face(frame: Frame) -> Optional[Face]: + faces = get_many_faces(frame) try: - return min(face, key=lambda x: x.bbox[0]) + return faces[0] except ValueError: return None -def get_many_faces(frame: Frame) -> Any: +def get_many_faces(frame: Frame) -> Optional[List[Face]]: try: return get_face_analyser().get(frame) except IndexError: