Merge pull request #497 from s0md3v/refactor-enhancer-gfpgan

Give gfpgan a shot
Henry Ruhs 2023-06-15 14:18:57 +02:00 committed by GitHub
commit 1d28572829
2 changed files with 48 additions and 119 deletions


@@ -26,7 +26,8 @@ from roop.utilities import has_image_extension, is_image, is_video, detect_fps,
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
def parse_args() -> None:
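
The hunk above swaps a blanket FutureWarning filter for module-scoped filters. A small standalone illustration of the difference, using only the standard library:

# Illustrative sketch, not part of the diff.
import warnings

# Old approach: silence every FutureWarning, regardless of which library raises it.
warnings.simplefilter(action='ignore', category=FutureWarning)

# New approach: suppress only warnings raised from modules matching the given
# name pattern; warnings from other code still surface.
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
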
@@ -89,10 +90,6 @@ def parse_args() -> None:
print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m')
roop.globals.execution_threads = args.gpu_threads_deprecated
# limit face enhancer to cuda
if 'CUDAExecutionProvider' not in roop.globals.execution_providers and 'face_enhancer' in roop.globals.frame_processors:
roop.globals.frame_processors.remove('face_enhancer')
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
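
For reference, the unchanged encode_execution_providers helper shown above maps ONNX Runtime provider names to the short names used on the CLI; a quick usage sketch:

# Illustrative usage, not part of the diff.
from typing import List

def encode_execution_providers(execution_providers: List[str]) -> List[str]:
    # Drop the 'ExecutionProvider' suffix and lowercase, e.g. 'CUDAExecutionProvider' -> 'cuda'.
    return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]

print(encode_execution_providers(['CPUExecutionProvider', 'CUDAExecutionProvider', 'ROCMExecutionProvider']))
# ['cpu', 'cuda', 'rocm']
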


@@ -1,28 +1,23 @@
from typing import List
from typing import Any, List
import cv2
import torch
import threading
from torchvision.transforms.functional import normalize
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
from codeformer.basicsr.utils import img2tensor, tensor2img
import gfpgan
import roop.globals
import roop.processors.frame.core
from roop.core import update_status
from roop.face_analyser import get_one_face, get_many_faces
from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
CODE_FORMER = None
FACE_ENHANCER = None
THREAD_SEMAPHORE = threading.Semaphore()
THREAD_LOCK = threading.Lock()
NAME = 'ROOP.FACE-ENHANCER'
def pre_check() -> bool:
download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
conditional_download(download_directory_path, ['https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'])
return True
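
With this change, pre_check only needs the GFPGAN weights on disk. The conditional_download helper itself is not part of this diff; a rough sketch of what such a helper could look like (illustrative only, not roop's actual implementation):

# Illustrative sketch of a conditional download helper; roop's real
# implementation may differ.
import os
import urllib.request
from typing import List

def conditional_download(download_directory_path: str, urls: List[str]) -> None:
    # Download each file only if it is missing, so repeated runs skip the fetch.
    os.makedirs(download_directory_path, exist_ok=True)
    for url in urls:
        download_file_path = os.path.join(download_directory_path, os.path.basename(url))
        if not os.path.exists(download_file_path):
            urllib.request.urlretrieve(url, download_file_path)
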
@@ -33,124 +28,61 @@ def pre_start() -> bool:
return True
def get_code_former():
global CODE_FORMER
def get_face_enhancer() -> None:
global FACE_ENHANCER
with THREAD_LOCK:
model_path = resolve_relative_path('../models/codeformer.pth')
if CODE_FORMER is None:
model = torch.load(model_path)['params_ema']
CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=['32', '64', '128', '256'],
).to('cuda')
CODE_FORMER.load_state_dict(model)
CODE_FORMER.eval()
return CODE_FORMER
def get_face_enhancer(FACE_ENHANCER):
if FACE_ENHANCER is None:
FACE_ENHANCER = FaceRestoreHelper(
upscale_factor = int(2),
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device='cuda'
)
if FACE_ENHANCER is None:
model_path = resolve_relative_path('../models/GFPGANv1.3.pth')
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = gfpgan.GFPGANer(
model_path=model_path,
channel_multiplier=2,
upscale=2
)
return FACE_ENHANCER
def enhance_face_in_frame(cropped_faces):
try:
faces_enhanced = []
for _, cropped_face in enumerate(cropped_faces):
face_in_tensor = normalize_face(cropped_face)
face_enhanced = restore_face(face_in_tensor)
faces_enhanced.append(face_enhanced)
return faces_enhanced
except RuntimeError as error:
print(f'Failed inference for CodeFormer-code: {error}')
def process_frame(source_face: any, temp_frame: any) -> any:
try:
face_helper = get_face_enhancer(None)
face_helper.read_image(temp_frame)
# get face landmarks for each face
face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
def enhance_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any:
THREAD_SEMAPHORE.acquire()
if target_face:
_, _, temp_frame = get_face_enhancer().enhance(
temp_frame,
paste_back=True
)
# align and warp each face
face_helper.align_warp_face()
cropped_faces = face_helper.cropped_faces
faces_enhanced = enhance_face_in_frame(cropped_faces)
for face_enhanced in faces_enhanced:
face_helper.add_restored_face(face_enhanced)
face_helper.get_inverse_affine()
result = face_helper.paste_faces_to_input_image()
face_helper.clean_all()
return result
except RuntimeError as error:
print(f'Failed inference for CodeFormer-code-paste: {error}')
THREAD_SEMAPHORE.release()
return temp_frame
def normalize_face(face):
face_in_tensor = img2tensor(face / 255.0, bgr2rgb=True, float32=True)
normalize(face_in_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
return face_in_tensor.unsqueeze(0).to('cuda')
def process_frame(source_face: Any, temp_frame: Any) -> Any:
if roop.globals.many_faces:
many_faces = get_many_faces(temp_frame)
if many_faces:
for target_face in many_faces:
temp_frame = enhance_face(source_face, target_face, temp_frame)
else:
target_face = get_one_face(temp_frame)
if target_face:
temp_frame = enhance_face(source_face, target_face, temp_frame)
return temp_frame
def enhance_face_in_tensor(face_in_tensor, codeformer_fidelity = 0.6):
with torch.no_grad():
enhanced_face_in_tensor = get_code_former()(face_in_tensor, w=codeformer_fidelity, adain=True)[0]
return enhanced_face_in_tensor
def convert_tensor_to_image(enhanced_face_in_tensor):
restored_face = tensor2img(enhanced_face_in_tensor, rgb2bgr=True, min_max=(-1, 1))
return restored_face.astype('uint8')
def restore_face(face_in_tensor):
enhanced_face_in_tensor = enhance_face_in_tensor(face_in_tensor)
try:
restored_face = convert_tensor_to_image(enhanced_face_in_tensor)
del enhanced_face_in_tensor
except RuntimeError as error:
print(f'Failed inference for CodeFormer-tensor: {error}')
restored_face = convert_tensor_to_image(face_in_tensor)
return restored_face
return restored_face
def process_frames(source_path: str, frame_paths: list[str], progress=None) -> None:
for frame_path in frame_paths:
try:
frame = cv2.imread(frame_path)
result = process_frame(None, frame)
cv2.imwrite(frame_path, result)
except Exception as exception:
print(exception)
continue
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
source_face = get_one_face(cv2.imread(source_path))
for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path)
result = process_frame(source_face, temp_frame)
cv2.imwrite(temp_frame_path, result)
if progress:
progress.update(1)
def process_image(source_path: str, image_path: str, output_file: str) -> None:
image = cv2.imread(image_path)
result = process_frame(None, image)
cv2.imwrite(output_file, result)
def process_image(source_path: str, target_path: str, output_path: str) -> None:
source_face = get_one_face(cv2.imread(source_path))
target_frame = cv2.imread(target_path)
result = process_frame(source_face, target_frame)
cv2.imwrite(output_path, result)
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
# todo: remove threads limitation again
execution_threads = roop.globals.execution_threads
roop.globals.execution_threads = 1
roop.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames)
roop.globals.execution_threads = execution_threads
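
Taken together, the new processor is a lock-guarded lazy singleton for the GFPGAN model plus a semaphore that serializes inference, which is also why process_video temporarily pins execution_threads to 1. A condensed, self-contained sketch of that pattern (model path, upscale factor and the enhance_frame name are illustrative):

# Condensed illustrative sketch of the pattern used above; not part of the diff.
import threading
from typing import Any, Optional

import cv2
import gfpgan

FACE_ENHANCER: Optional[Any] = None
THREAD_LOCK = threading.Lock()
THREAD_SEMAPHORE = threading.Semaphore()

def get_face_enhancer() -> Any:
    global FACE_ENHANCER
    # The lock ensures the model is built exactly once, even when several
    # worker threads request it concurrently.
    with THREAD_LOCK:
        if FACE_ENHANCER is None:
            FACE_ENHANCER = gfpgan.GFPGANer(model_path='models/GFPGANv1.3.pth', upscale=2, channel_multiplier=2)
    return FACE_ENHANCER

def enhance_frame(temp_frame: Any) -> Any:
    # Inference is serialized because GFPGAN is not treated as thread-safe here.
    # enhance() returns (cropped_faces, restored_faces, restored_img);
    # paste_back=True blends the restored faces back into the full frame.
    with THREAD_SEMAPHORE:
        _, _, temp_frame = get_face_enhancer().enhance(temp_frame, paste_back=True)
    return temp_frame

if __name__ == '__main__':
    frame = cv2.imread('frame.png')
    cv2.imwrite('frame_enhanced.png', enhance_frame(frame))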