mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-16 03:57:19 +00:00
Fix suggested threads and memory
This commit is contained in:
parent
56dbf66a34
commit
663439fcee
15
roop/core.py
15
roop/core.py
@ -15,6 +15,8 @@ import shutil
|
||||
import argparse
|
||||
import torch
|
||||
import onnxruntime
|
||||
if not 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
|
||||
del torch
|
||||
import tensorflow
|
||||
|
||||
import roop.globals
|
||||
@ -24,9 +26,6 @@ from roop.predictor import predict_image, predict_video
|
||||
from roop.processors.frame.core import get_frame_processors_modules
|
||||
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
|
||||
|
||||
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
||||
del torch
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
||||
|
||||
@ -85,7 +84,7 @@ def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
def suggest_max_memory() -> int:
|
||||
if platform.system().lower() == 'darwin':
|
||||
return 4
|
||||
return 16
|
||||
return 8
|
||||
|
||||
|
||||
def suggest_execution_providers() -> List[str]:
|
||||
@ -93,11 +92,9 @@ def suggest_execution_providers() -> List[str]:
|
||||
|
||||
|
||||
def suggest_execution_threads() -> int:
|
||||
if 'DmlExecutionProvider' in roop.globals.execution_providers:
|
||||
return 1
|
||||
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
||||
return 1
|
||||
return 8
|
||||
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
|
||||
return 8
|
||||
return 1
|
||||
|
||||
|
||||
def limit_resources() -> None:
|
||||
|
||||
@ -16,7 +16,7 @@ def get_face_analyser() -> Any:
|
||||
with THREAD_LOCK:
|
||||
if FACE_ANALYSER is None:
|
||||
FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
|
||||
FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640))
|
||||
FACE_ANALYSER.prepare(ctx_id=0)
|
||||
return FACE_ANALYSER
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user