Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(1, 360, 5)] 0
normalizer_1 (Normalizer) (1, 360, 5) 0
timexer (TimeXer) (1, 1) 6745
denormalizer (Denormalizer (1, 1) 0
)
...
_________________________________________________________________
934/934 [==============================] - 2s 1ms/step
Inference time: 1.866 seconds
Throughput: 16016.67 samples/second
val rmse : 8.589814186096191, test rmse : 13.827042579650879
val rmse : 9.852630615234375, test rmse : 15.736166954040527
val rmse : 17.27320098876953, test rmse : 14.202249526977539
avg test rmse: 14.588486353556315 [13.827043, 15.736167, 14.20225]
๋ ผ๋ฌธ: TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables (NeurIPS 2024)
์ ์/์์: Tsinghua Univ. BNRist
ํต์ฌ ์์ด๋์ด ํ ์ค ์์ฝ: ๋ด์(Endogenous) ์๊ณ์ด์ ํจ์น ๋จ์ ํ ํฐ์ผ๋ก, ์ธ์(Exogenous) ์๊ณ์ด์ ๋ณ์(variates) ๋จ์ ํ ํฐ์ผ๋ก ํํํ๊ณ , ๊ธ๋ก๋ฒ ํ ํฐ์ ๋ค๋ฆฌ๋ก ์ผ์ ํจ์น-์๊ธฐ์ดํ ์ ๊ณผ ๋ณ์-๊ต์ฐจ์ดํ ์ ์ ๋์์ ์ํํด ์ธ๋ถ ์์ธ์ ๊ฒฌ๊ณ ํ๊ฒ ํก์ํ๋ค.
- ์ ์ธ์ ๋ณ์๊ฐ ์ค์ํ๊ฐ?
- ๋ฌธ์ ์ ์
- ๋ชจ๋ธ ์ํคํ ์ฒ
- ํ์ต/์์ค ๋ฐ ๋ฉํฐ๋ณ์ ์์ธก์ผ๋ก์ ์ผ๋ฐํ
- ๊ฒฐ๊ณผ ์์ฝ
- ์ผ๋ฐ์ฑ/๊ฒฌ๊ณ ์ฑ/ํ์ฅ์ฑ
- ์ฌํ์ ์ํ ๊ธฐ๋ณธ ์ค์
- ๋ฐ์ดํฐ์ ๊ฐ์
- ์ ํ์ ๊ณผ ํ
- ์ธ์ฉ
์ค์ธ๊ณ ์๊ณ์ด์ ๊ฒฐ์ธก, ๋น๊ท ์ผ ์ํ๋ง, ์ฃผ๊ธฐ/๊ธธ์ด ๋ถ์ผ์น, ์๊ฐ ์ง์ฐ ํจ๊ณผ๊ฐ ํํ๋ค. ๊ธฐ์กด ์ ๊ทผ(๋ดยท์ธ์์ ๋์ผ ์์ ์ concat)์ผ๋ก๋ ์ ๋ ฌ/๋๊ธฐํ๊ฐ ์ด๋ ต๊ณ , ๋ถํ์ํ ์ํธ์์ฉ๊ณผ ๋ณต์ก๋๊ฐ ์ปค์ง๋ค. TimeXer๋ ์๋ฒ ๋ฉ ๋จ๊ณ์์ ์ญํ ์ ๋ถ๋ฆฌํด ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ์ฐํํ๋ค.
-
์
๋ ฅ: ๋ด์ ๋จ๋ณ๋
$x_{1:T}$ ์ ๋ค์์ ์ธ์ ๋ณ์ ์งํฉ$z^{(1)}{1:T{\mathrm{ex}}}, \dots, z^{(C)}{1:T{\mathrm{ex}}}$ (๋ดยท์ธ์์ look-back ๊ธธ์ด ๋ถ์ผ์น ํ์ฉ,$T \neq T_{\mathrm{ex}}$ ) -
๋ชฉํ: ํฅํ
$S$ ์คํ ์ ๋ด์ ์๊ณ์ด$\hat{x}{T+1:T+S} = F{\theta}!\big(x_{1:T}, z_{1:T_{\mathrm{ex}}}\big)$ ์์ธก
ํต์ฌ ์ค๊ณ:
- ๋ด์ ์๋ฒ ๋ฉ(Endogenous) โ ๋น์ค์ฒฉ ํจ์น๋ก ๋๋ ๋ค, ํจ์น ํ ํฐ๋ค(temporal patch tokens) + ํ์ตํ ๊ธ๋ก๋ฒ ํ ํฐ(series-level global token) ๊ตฌ์ฑ. ๊ธ๋ก๋ฒ ํ ํฐ์ด ํจ์นโ์ธ์ ์ ๋ณด ํต๋ก ์ญํ .
- ์ธ์ ์๋ฒ ๋ฉ(Exogenous) โ ๋ณ์(variates) ๋จ์ ์๊ณ์ด ์ ์ฒด๋ฅผ ํ๋์ ํ ํฐ์ผ๋ก ์๋ฒ ๋ฉ(variate token). ๊ฒฐ์ธก/๋ฏธ์ ๋ ฌ/์ฃผ๊ธฐยท๊ธธ์ด ์์ด์ฑ์ ์์ฐ ์ ์.
- ์ดํ
์
ํ๋ฆ
- ๋ด์ ์๊ธฐ์ดํ ์ (Self-Attn): [ํจ์น ํ ํฐ๋ค + ๊ธ๋ก๋ฒ ํ ํฐ]์ ๋ํด ํจ์น-ํจ์น ๋ฐ ํจ์น-๊ธ๋ก๋ฒ ๊ด๊ณ๋ฅผ ๋์์ ํ์ตํด ์๊ฐ ์์กด์ฑ์ ์ ํํ ์บก์ฒ.
- ์ธ์โ๋ด์ ๊ต์ฐจ์ดํ ์ (Cross-Attn): **๋ด์ ๊ธ๋ก๋ฒ ํ ํฐ(์ง์)**์ด **์ธ์ ๋ณ์ ํ ํฐ๋ค(ํค/๊ฐ)**์ ์ ํ์ ์ผ๋ก ํก์ โ ๋ณ์-์์ค ์๊ด ๋ฐ์.
์ง๊ด: ์ธ์์ โ๋ฌด์์ด ์ค์ํ ๋ณ์์ธ๊ฐโ๋ฅผ ๊ณ ๋ฅด๊ณ (๋ณ์-์์ค), ๋ด์์ โ์ธ์ ์ค์ํ๊ฐโ๋ฅผ ์ ๋ฐํ ๋ณธ๋ค(ํจ์น-์์ค). ๋ ์ถ์ ๊ธ๋ก๋ฒ ํ ํฐ์ผ๋ก ์ฎ์ด ๋ถํ์ํ ์ -๋ณ์๊ฐ ์ํธ์์ฉ ๋น์ฉ์ ์ค์ด๋ฉด์๋ ์ ๋ณด๋ ์ ํ์ ์ผ๋ก ์ ์ ๋๋ค.
- ์ถ๋ ฅ ์์ฑ: ๋ง์ง๋ง ๋ธ๋ก์์ ์ป์ ํจ์น ํํ๊ณผ ์ ์ญ(๊ธ๋ก๋ฒ) ํํ์ ํ๋๋ก ํฉ์น ๋ค, ์ด๋ฅผ **์ ํ ๋ณํ(์์ ์ฐ๊ฒฐ์ธต)**์ ํต๊ณผ์์ผ ๋ฏธ๋ ๊ฐ์ ์์ธกํ๋ค. ์ฆ, ์๊ฐ ๊ตฌ๊ฐ๋ณ ์ ๋ณด(ํจ์น)์ ์๊ณ์ด ์ ๋ฐ์ ์์ฝ ์ ๋ณด(์ ์ญ)๋ฅผ ๊ฒฐํฉํด ์ต์ข ์์ธก์ ๋ง๋ ๋ค.
- ์์ค: L2(์ ๊ณฑ ์ค์ฐจ)
- ๋ฉํฐ๋ณ์ ์์ธก: ๊ฐ ๋ณ์๋ฅผ โ๋ด์โ์ผ๋ก ๋๊ณ ๋๋จธ์ง ๋ณ์๋ ์ธ์์ผ๋ก ๋ณ๋ ฌ ์ฒ๋ฆฌ(์ฑ๋ ๋ ๋ฆฝ), Self/Cross-Attn ์ธต ๊ณต์ .
- ๋จ๊ธฐ ์ ๋ ฅ๊ฐ๊ฒฉ(EPF, ์
๋ ฅ 168โ์์ธก 24) 5๊ฐ ๋ง์ผ ๋ชจ๋์์ SOTA(MSE/MAE). ์:
PJM MSE 0.093 (iTransformer 0.097, Crossformer 0.101 ๋ฑ)
NP MSE 0.236 (Crossformer 0.240, RLinear 0.335 ๋ฑ)
โ ์ธ์ ๋ณ์์ ์ ํํ ํ์ฉ + ์๊ฐ ์์กด ํ์ต์ด ๊ฒฝ์ ๋ชจ๋ธ์ ์ผ๊ด๋๊ฒ ์ํ. - ์ฅ๊ธฐ ๋ฉํฐ๋ณ์(ETT/ECL/Weather/Traffic ๋ฑ, ํ๊ท ) ๋๋ค์ ๋ฐ์ดํฐ์ ์์ ์ผ๊ด๋ ์ฐ์ ์ฑ๋ฅ.
- ์ ์ ๋๋? ๊ธฐ์กด ๋ชจ๋ธ์
- Crossformer: ๋ชจ๋ ๋ณ์๋ฅผ ์ธ๋ฐ ํจ์น ์์ค์ผ๋ก ์ฎ์ด ๋ ธ์ด์ฆ/๋ณต์ก๋ ์ฆ๊ฐ
- iTransformer: ๋ณ์-์์ค๋ง ๋ณด๊ณ ์๊ฐ-์ธ๋ถ๋ ์ ํ ํฌ์์ ์์กด
โ TimeXer๋ ํจ์น(์๊ฐ)ร๋ณ์(์ธ์) ์ด์ ์ค๊ณ๋ก ์ฅ๋จ์ ์ ๋์์ ๋ณด์.
- Look-back ๋ถ์ผ์น(๋ด์/์ธ์ ๊ธธ์ด ๋ค๋ฆ)์๋ ์ฑ๋ฅ ์ด๋ ์ ์ง. ์ธ์ ๊ธธ์ด ํ์ฅ๋ณด๋ค ๋ด์ ๊ธธ์ด ํ์ฅ์ด ํนํ ์ ์ต.
- ๊ฒฐ์ธก/๋๋ค ์ธ์์๋ ๋ด์์ ์๊ฐ ํํ์ด ์์ธก์ ์ฃผ๋ํด ์ฑ๋ฅ ๊ฐ๊ฑด(์ธ์์ด ์์ ํ ๋ฌด์๋ฏธํด๋ ๊ธ๋ฝํ์ง ์์). ๋ฐ๋๋ก ๋ด์์ด ๋ฌด์๋ฏธํด์ง๋ฉด ๊ธ๊ฒฉํ ์ ํ.
- ํจ์จ์ฑ: ์ธ์ ๊ฐ ์ํธ์์ฉ์ ์ธต๋ง๋ค ํ์ด๋์ง ์๊ณ ๊ธ๋ก๋ฒ ํ ํฐ ๊ธฐ๋ฐ ๊ต์ฐจ์ดํ ์ ์ผ๋ก ์ฒ๋ฆฌ โ ๋ฉ๋ชจ๋ฆฌ ์ฐ์/ํ์ต์๋ ์ ๋ฆฌ.
- ํ๋ ์์ํฌ/ํ๋์จ์ด: PyTorch, ๋จ์ผ RTX 4090 24GB
- ์ต์ ํ: Adam, lr=1e-4, L2 Loss, Early Stopping, 10 epoch ๊ณ ์ ํ์ต
-
๋ชจ๋ธ ํฌ๊ธฐ: Block
$L \in {1,2,3}$ ,$d_{\text{model}}\in{128,256,512}$ - ํจ์น ๊ธธ์ด: ์ฅ๊ธฐ 16, ๋จ๊ธฐ 24(๋น์ค์ฒฉ) โ ์์ ํจ์น๋ ์๋ฏธ ์ ๋ณด ํฌ์ ๊ฐ๋ฅ(์ฑ๋ฅ ์ ํ)
# x: endogenous (T,), z_list: [z^(1)_(T_ex), ..., z^(C)_(T_ex)]
patch_tokens = PatchEmbed(split_nonoverlap(x, P)) # (N, D)
g_token = LearnableGlobalToken() # (1, D)
v_tokens = [VariateEmbed(z) for z in z_list] # (C, D)
# L layers
for _ in range(L):
# Self-Attn over [patch_tokens || g_token]
patch_tokens, g_token = SelfAttentionConcat(patch_tokens, g_token)
# Cross-Attn: g_token (Q) <-- v_tokens (K,V)
g_token = CrossAttention(g_token, v_tokens)
y_hat = LinearProjection(concat(patch_tokens, g_token)) # forecast
loss = mse(y_hat, y_true)