Skip to content

SerenaTradingResearch/deep-trader

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Intro

  • Exploring ideas in deep learning for trading
  • Example training result:

  • Labeled data:

Usage

# Install first (shell): pip install deep-trader
import h5py
import numpy as np
import talib as ta
import torch as tc
import torch.nn as nn

from deep_trader.core import DeepTraderBase
from deep_trader.plot import plot1
from deep_trader.utils import D_TYPE, mlp, postfix, tensor


class DeepTrader(DeepTraderBase):
    """Example trader: an MLP maps price-vs-SMA features to a position in [-1, 1].

    Trading and training hyper-parameters are plain class attributes; how they
    are consumed is defined by ``DeepTraderBase`` (not visible in this file).
    """

    # SMA look-back windows: 16, 32, 64, 128, 256, 512 samples.
    periods = [1 << k for k in range(4, 10)]

    # NOTE(review): names suggest leverage, timeout, take-profit, stop-loss and
    # fee rate; their actual use lives in DeepTraderBase — confirm there.
    lev = 2
    TO = 60
    TP = 0.02
    SL = -0.5
    fee = 1e-3

    # Training schedule; presumably read by DeepTraderBase.train — confirm.
    epochs = 1000
    f_plot = 100
    f_test = 10
    plot_id = "./data/DeepTrader"

    @property
    def obs_skip(s):
        """Return the longest SMA window in ``periods``.

        Presumably the number of leading samples to skip because the longest
        moving average is not yet defined there — confirm in DeepTraderBase.
        """
        return max(s.periods)

    def make_obs(s, price: np.ndarray):
        """Build one feature column per SMA window in ``periods``.

        Column ``i`` is ``price / SMA(price, periods[i]) - 1``, i.e. the relative
        distance of the price from its moving average over that window.
        Assuming ``price`` is 1-D, the result has shape
        ``(len(price), len(s.periods))``.
        """
        columns = [price / ta.SMA(price, window) - 1 for window in s.periods]
        return np.array(columns).T

    def loss(s, d: D_TYPE):
        """Compute the training objective and diagnostics for batch ``d``.

        The first network output is squashed with ``tanh`` into a position in
        [-1, 1]. ``s.pos_abs_times`` (defined in the base class) combines the
        positions with ``d["profit_SL"]`` into a profit tensor. The returned
        ``"loss"`` is the negated mean profit, so minimizing it maximizes the
        average profit. ``"pr_std"`` is reported only and is not optimized.
        """
        pos = tc.tanh(s.net(d["obs"])[:, 0])
        profit = s.pos_abs_times(pos, d["profit_SL"])
        out = {"pos.hist": pos, "profit.hist": profit}
        out["pr_mean"] = profit.mean()
        out["pr_std"] = profit.std()
        out["loss"] = -out["pr_mean"]
        return out


def main(path="./data/futures_data_2025-07-01_2025-08-01.h5", device=None):
    """Label, normalize, plot and train ``DeepTrader`` on HDF5 futures data.

    Data from 2025-07-01..2025-07-15 is passed to ``train`` first and data
    from 2025-07-15..2025-08-01 second (presumably train / test split).

    Args:
        path: HDF5 file holding the market data; opened read-only and closed
            when training finishes.
        device: torch device to train on. ``None`` (default) picks ``"cuda"``
            when a GPU is available and ``"cpu"`` otherwise.
    """
    tc.manual_seed(0)
    if device is None:
        # Fall back to CPU so the example also runs on machines without a GPU.
        device = "cuda" if tc.cuda.is_available() else "cpu"

    s = DeepTrader()
    # Read-only, and the handle is released once training is done.
    with h5py.File(path, "r") as h5:
        s.set_h5_file(h5)

        d1 = s.label_data(s.get_period("2025-07-01", "2025-07-15"))
        d2 = s.label_data(s.get_period("2025-07-15", "2025-08-01"))
        s.normalize_obs(d1, d2)
        plot1(postfix(d1, ".hist"), "./data/labeled_data")

        d1, d2 = tensor([d1, d2], device)
        # Size the input layer from the observations actually fed to the net
        # (see DeepTrader.loss) instead of hard-coding the feature count.
        n_features = d1["obs"].shape[-1]
        s.net = mlp([n_features, 32, 32, 1], nn.ReLU).to(device)
        s.opt = tc.optim.Adam(s.net.parameters())
        s.train(d1, d2)


if __name__ == "__main__":
    main()

About

Deep Learning for trading

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors