mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
move the model_path to the method
This commit is contained in:
parent
eaf899043c
commit
36446a92ca
@ -16,23 +16,20 @@ if 'ROCMExecutionProvider' in roop.globals.providers:
|
||||
|
||||
CODE_FORMER = None
|
||||
THREAD_LOCK = threading.Lock()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
ckpt_path = os.path.join("models", "codeformer.pth")
|
||||
|
||||
|
||||
def pre_check() -> None:
|
||||
pretrain_model_url = [
|
||||
'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
]
|
||||
download_directory_path = resolve_relative_path('../models')
|
||||
conditional_download(download_directory_path, pretrain_model_url)
|
||||
download_directory_path = resolve_relative_path('../models/codeformer.pth')
|
||||
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
|
||||
|
||||
|
||||
def get_code_former():
|
||||
global CODE_FORMER
|
||||
with THREAD_LOCK:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model_path = os.path.join("models", "codeformer.pth")
|
||||
if CODE_FORMER is None:
|
||||
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||
model = torch.load(model_path)["params_ema"]
|
||||
CODE_FORMER = ARCH_REGISTRY.get("CodeFormer")(
|
||||
dim_embd=512,
|
||||
codebook_size=1024,
|
||||
@ -40,12 +37,13 @@ def get_code_former():
|
||||
n_layers=9,
|
||||
connect_list=["32", "64", "128", "256"],
|
||||
).to(device)
|
||||
CODE_FORMER.load_state_dict(checkpoint)
|
||||
CODE_FORMER.load_state_dict(model)
|
||||
CODE_FORMER.eval()
|
||||
return CODE_FORMER
|
||||
|
||||
|
||||
def get_face_enhancer(FACE_ENHANCER):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if FACE_ENHANCER is None:
|
||||
FACE_ENHANCER = FaceRestoreHelper(
|
||||
upscale_factor = int(2),
|
||||
@ -91,6 +89,7 @@ def process_faces(source_face: any, frame: any) -> any:
|
||||
|
||||
|
||||
def normalize_face(face):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
face_in_tensor = img2tensor(face / 255.0, bgr2rgb=True, float32=True)
|
||||
normalize(face_in_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
return face_in_tensor.unsqueeze(0).to(device)
|
||||
|
||||
@ -14,7 +14,7 @@ THREAD_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def pre_check() -> None:
|
||||
download_directory_path = resolve_relative_path('../models')
|
||||
download_directory_path = resolve_relative_path('../models/inswapper_128.onnx')
|
||||
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user