-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
25 lines (20 loc) · 777 Bytes
/
server.py
File metadata and controls
25 lines (20 loc) · 777 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import sys
# Compatibility shim for loading checkpoints pickled under NumPy >= 2.
# Those pickles reference the module path ``numpy._core``; NumPy 1.x only
# ships ``numpy.core``. When the new path is missing, register the old
# module under the new names so unpickling resolves. On NumPy >= 2 nothing
# is touched: overwriting the real ``numpy._core`` with the deprecated
# ``numpy.core`` stub would only add warnings and indirection.
try:
    import numpy._core.multiarray  # noqa: F401  (present on NumPy >= 2)
except ImportError:
    import numpy.core as _c
    sys.modules['numpy._core'] = _c
    sys.modules['numpy._core.multiarray'] = _c.multiarray
import os

from flask import Flask, request, jsonify
import numpy as np
from stable_baselines3 import PPO
# Load the trained PPO policy once, at import time, so every request reuses
# the same model object. The checkpoint path can be overridden with the
# MODEL_PATH environment variable; by default it is 'ppo_hft_scalper.zip',
# resolved relative to the current working directory as before.
MODEL = PPO.load(os.environ.get('MODEL_PATH', 'ppo_hft_scalper.zip'))
app = Flask(__name__)
# Observation shape the endpoint accepts (rows x columns of the window).
_WINDOW_SHAPE = (50, 5)
# Discrete action index returned by the policy -> trading signal label.
_SIGNALS = {0: 'HOLD', 1: 'BUY', 2: 'SELL'}


@app.route('/predict', methods=['POST'])
def predict():
    """Return the model's deterministic action for one observation window.

    Expects a JSON object body ``{"window": [[...], ...]}`` whose ``window``
    is a 50x5 grid of finite numbers. On success responds with
    ``{"action": <int>, "signal": "HOLD" | "BUY" | "SELL"}``. Any malformed
    request is answered with ``{"error": <message>}`` and HTTP 400 rather
    than an unhandled exception (HTTP 500).
    """
    # silent=True makes get_json return None for a missing/invalid JSON body
    # or a non-JSON content type instead of raising, so it maps to a 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'window' not in payload:
        return jsonify(error="JSON body with a 'window' field required"), 400
    try:
        w = np.array(payload['window'], dtype=np.float32)
    except (TypeError, ValueError):
        # Ragged rows or non-numeric entries cannot form a float array.
        return jsonify(error='50x5 required'), 400
    if w.shape != _WINDOW_SHAPE:
        return jsonify(error='50x5 required'), 400
    # NaN/inf (including values overflowing float32) would feed garbage into
    # the policy and yield a meaningless signal; reject them up front.
    if not np.isfinite(w).all():
        return jsonify(error='window values must be finite'), 400
    a, _ = MODEL.predict(w, deterministic=True)
    action = int(a)
    return jsonify(action=action, signal=_SIGNALS[action])
if __name__ == '__main__':
    # Bind on all interfaces so the server is reachable from outside the
    # host/container. The port defaults to 10000 and can be overridden with
    # the PORT environment variable (set by many hosting platforms).
    # NOTE(review): app.run starts Flask's development server; consider a
    # production WSGI server for real traffic.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '10000')))