forked from yl4579/StyleTTS2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
119 lines (91 loc) · 3.17 KB
/
utils.py
File metadata and controls
119 lines (91 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
import os
import shutil
import yaml
import matplotlib.pyplot as plt
import numpy as np
import torch
from accelerate.logging import get_logger
from monotonic_align.core import maximum_path_c
from munch import Munch
def maximum_path(neg_cent, mask):
    """Cython optimized version of the monotonic-alignment path search.

    Args:
        neg_cent: score tensor, [b, t_t, t_s].
        mask: mask tensor, [b, t_t, t_s].

    Returns:
        A tensor shaped like ``neg_cent`` holding the path written by
        ``maximum_path_c``, moved back to ``neg_cent``'s device and dtype.
    """

    def _to_host(tensor, np_dtype):
        # The Cython kernel needs contiguous host buffers of a fixed dtype.
        return np.ascontiguousarray(tensor.data.cpu().numpy().astype(np_dtype))

    target = {"device": neg_cent.device, "dtype": neg_cent.dtype}

    scores = _to_host(neg_cent, np.float32)
    # np.zeros is already C-contiguous, so no extra contiguity wrapper is needed.
    path = np.zeros(scores.shape, dtype=np.int32)

    # mask.sum(1)[:, 0]: number of set entries along dim 1, read at index 0 of dim 2.
    # mask.sum(2)[:, 0]: number of set entries along dim 2, read at index 0 of dim 1.
    t_t_max = _to_host(mask.sum(1)[:, 0], np.int32)
    t_s_max = _to_host(mask.sum(2)[:, 0], np.int32)

    # Fills `path` in place.
    maximum_path_c(path, scores, t_t_max, t_s_max)
    return torch.from_numpy(path).to(**target)
def get_data_path_list(train_path=None, val_path=None):
    """Read the training and validation list files.

    Args:
        train_path: path to the training list; defaults to "Data/train_list.txt".
        val_path: path to the validation list; defaults to "Data/val_list.txt".

    Returns:
        A ``(train_lines, val_lines)`` tuple. Each element is the list returned by
        ``readlines()``, so trailing newlines are kept. Files are decoded as UTF-8
        and undecodable bytes are dropped.
    """

    def _read_lines(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.readlines()

    train_file = "Data/train_list.txt" if train_path is None else train_path
    val_file = "Data/val_list.txt" if val_path is None else val_path

    # Tuple elements are evaluated left to right: train is read before val.
    return _read_lines(train_file), _read_lines(val_file)
def length_to_mask(lengths):
    """Build a boolean mask from a tensor of lengths.

    Args:
        lengths: tensor of lengths, assumed 1-D of shape (batch,) — TODO confirm
            with callers.

    Returns:
        Bool tensor of shape (batch, max(lengths)). Entry ``[i, j]`` is True when
        ``j >= lengths[i]``, i.e. True marks positions past each length. An empty
        ``lengths`` yields a (0, 0) mask instead of raising.
    """
    # lengths.max() raises on an empty tensor; treat that case as zero width.
    max_len = lengths.max().item() if lengths.numel() > 0 else 0
    # Create the index range directly on lengths' device and dtype. This avoids
    # building it on the host and copying it over.
    positions = torch.arange(max_len, device=lengths.device, dtype=lengths.dtype)
    positions = positions.unsqueeze(0).expand(lengths.shape[0], -1)
    # (j + 1 > length) <=> (j >= length)
    return torch.gt(positions + 1, lengths.unsqueeze(1))
# for norm consistency loss
def log_norm(x, mean=-4, std=4, dim=2):
    """Compute log(||exp(x * std + mean)||_2) along ``dim``.

    Pipeline: normalized log mel -> mel -> norm -> log(norm).

    The computation uses the identity log(sqrt(sum(exp(y)^2))) == 0.5 * logsumexp(2 * y).
    It gives the same value and the same gradient (softmax(2y)) as the naive form.
    Unlike the naive form, it does not overflow to inf for large inputs, and it does
    not underflow to log(0) = -inf (with NaN gradients) for very negative ones.

    Args:
        x: normalized log-mel tensor.
        mean: offset used to undo the normalization.
        std: scale used to undo the normalization.
        dim: dimension reduced by the norm.

    Returns:
        ``x`` with ``dim`` reduced away.
    """
    log_mel = x * std + mean
    return 0.5 * torch.logsumexp(2 * log_mel, dim=dim)
def get_image(arrs):
    """Draw ``arrs`` with ``imshow`` on a new figure and return that figure.

    First switches matplotlib to the non-interactive "agg" backend. The figure is
    not closed here; the returned object is handed to the caller.
    """
    plt.switch_backend("agg")
    figure = plt.figure()
    # The freshly created figure is the current one, so its gca() is plt.gca().
    figure.gca().imshow(arrs)
    return figure
def recursive_munch(d):
    """Recursively convert nested dicts into ``Munch`` objects.

    dicts become ``Munch`` instances with converted values. Lists are rebuilt with
    each element converted. Any other value, including tuples, is returned as-is.
    """
    if isinstance(d, list):
        return [recursive_munch(item) for item in d]
    if not isinstance(d, dict):
        return d
    return Munch({key: recursive_munch(value) for key, value in d.items()})
def _setup_logging(log_dir, logger_name, log_level="DEBUG"):
    """Create an accelerate logger that writes to the console and ``log_dir/train.log``.

    Args:
        log_dir: directory for ``train.log``; created if it does not exist.
        logger_name: name passed to ``accelerate.logging.get_logger``.
        log_level: level name such as ``"DEBUG"`` or ``"info"``. It is applied to
            the returned logger. ``None`` falls back to DEBUG, as before.

    Returns:
        The adapter returned by ``get_logger``. Its underlying ``.logger`` gets one
        console handler and one file handler. Repeated calls with the same name
        and directory do not add duplicates.
    """
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(log_dir, exist_ok=True)

    logger = get_logger(logger_name, log_level)
    # Previously this was hard-coded to DEBUG, which silently ignored `log_level`.
    if log_level is None:
        level = logging.DEBUG
    elif isinstance(log_level, str):
        level = log_level.upper()
    else:
        level = log_level
    logger.setLevel(level)

    base_logger = logger.logger
    log_path = os.path.abspath(os.path.join(log_dir, "train.log"))
    formatter = logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")

    # Guard against stacking handlers when called more than once for the same
    # logger; that would emit every record multiple times.
    has_console = any(type(h) is logging.StreamHandler for h in base_logger.handlers)
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in base_logger.handlers
    )

    if not has_console:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        base_logger.addHandler(console_handler)
    if not has_file:
        # Explicit UTF-8 so non-ASCII log text does not depend on the platform locale.
        file_handler = logging.FileHandler(log_path, encoding="utf-8")
        file_handler.setFormatter(formatter)
        base_logger.addHandler(file_handler)

    return logger
def configure_environment(config_path):
    """Load a YAML config, set up logging in its ``log_dir``, and archive the config.

    Args:
        config_path: path to a YAML file that must contain a ``log_dir`` key.

    Returns:
        ``(config, logger, log_dir)``: the parsed config, the logger from
        ``_setup_logging``, and ``config["log_dir"]``.
    """
    # YAML is UTF-8; do not depend on the platform's default text encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    log_dir = config["log_dir"]
    # _setup_logging also creates log_dir, which the copy below relies on.
    logger = _setup_logging(log_dir, __name__)

    destination = os.path.join(log_dir, os.path.basename(config_path))
    try:
        shutil.copy(config_path, destination)
    except shutil.SameFileError:
        # The config already lives in log_dir (e.g. resuming from the archived copy).
        pass
    return config, logger, log_dir