-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathutils.py
More file actions
152 lines (121 loc) · 4.4 KB
/
utils.py
File metadata and controls
152 lines (121 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Shared utilities for adversarial attack and text generation models."""
import os
import json
import yaml
import hashlib
import base64
from typing import Dict, Any, List, Union
from omegaconf import OmegaConf
import wandb
from config_schema import MainConfig
def load_api_keys() -> Dict[str, str]:
"""Load API keys from the api_keys file.
Returns:
Dict[str, str]: Dictionary containing API keys for different models
Raises:
FileNotFoundError: If no api_keys file is found
"""
for ext in ['yaml', 'yml', 'json']:
file_path = f'api_keys.{ext}'
if os.path.exists(file_path):
with open(file_path, 'r') as f:
if ext in ['yaml', 'yml']:
return yaml.safe_load(f)
else:
return json.load(f)
raise FileNotFoundError(
"API keys file not found. Please create api_keys.yaml, api_keys.yml, or api_keys.json "
"in the root directory with your API keys."
)
def get_api_key(model_name: str) -> str:
"""Get API key for specified model.
Args:
model_name: Name of the model to get API key for
Returns:
str: API key for the specified model
Raises:
KeyError: If API key for model is not found
"""
api_keys = load_api_keys()
if model_name not in api_keys:
raise KeyError(
f"API key for {model_name} not found in api_keys file. "
f"Available models: {list(api_keys.keys())}"
)
return api_keys[model_name]
def hash_training_config(cfg: MainConfig) -> str:
"""Create a deterministic hash of training-relevant config parameters.
Args:
cfg: Configuration object containing model settings
Returns:
str: MD5 hash of the config parameters
"""
# Convert backbone list to plain Python list
if isinstance(cfg.model.backbone, (list, tuple)):
backbone = list(cfg.model.backbone)
else:
backbone = OmegaConf.to_container(cfg.model.backbone)
# Create config dict with converted values
train_config = {
"data": {
"batch_size": int(cfg.data.batch_size),
"num_samples": int(cfg.data.num_samples),
"cle_data_path": str(cfg.data.cle_data_path),
"tgt_data_path": str(cfg.data.tgt_data_path),
},
"optim": {
"alpha": float(cfg.optim.alpha),
"epsilon": int(cfg.optim.epsilon),
"steps": int(cfg.optim.steps),
},
"model": {
"input_res": int(cfg.model.input_res),
"use_source_crop": bool(cfg.model.use_source_crop),
"use_target_crop": bool(cfg.model.use_target_crop),
"crop_scale": tuple(float(x) for x in cfg.model.crop_scale),
"ensemble": bool(cfg.model.ensemble),
"backbone": backbone,
},
"attack": cfg.attack,
}
# Convert to JSON string with sorted keys
json_str = json.dumps(train_config, sort_keys=True)
return hashlib.md5(json_str.encode()).hexdigest()
def setup_wandb(cfg: MainConfig, tags=None) -> None:
"""Initialize Weights & Biases logging.
Args:
cfg: Configuration object containing wandb settings
"""
config_dict = OmegaConf.to_container(cfg, resolve=True)
wandb.init(
project=cfg.wandb.project,
config=config_dict,
tags=tags,
)
def encode_image(image_path: str) -> str:
"""Encode image file to base64 string.
Args:
image_path: Path to image file
Returns:
str: Base64 encoded image string
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def ensure_dir(path: str) -> None:
"""Ensure directory exists, create if it doesn't.
Args:
path: Directory path to ensure exists
"""
os.makedirs(path, exist_ok=True)
def get_output_paths(cfg: MainConfig, config_hash: str) -> Dict[str, str]:
"""Get dictionary of output paths based on config.
Args:
cfg: Configuration object
config_hash: Hash of training config
Returns:
Dict[str, str]: Dictionary containing output paths
"""
return {
'output_dir': os.path.join(cfg.data.output, "img", config_hash),
'desc_output_dir': os.path.join(cfg.data.output, "description", config_hash)
}