From 159f1c3958aebc50e5172705958d35f2de848053 Mon Sep 17 00:00:00 2001 From: AlanFokCo Date: Sat, 16 May 2026 23:23:04 +0800 Subject: [PATCH 1/2] feat: TradingView-style report optimization and Python 3.9 compatibility - Add RSI(14), MACD(12,26,9), Bollinger Bands(20,2) indicators to HTML report - Add indicator toggle panel (MA/BB/VOL/S-R show/hide) on K-line chart - Add crosshair-linked dynamic legend with real-time indicator values - Replace iframe-based web console with native Lightweight Charts (ReportViewer) - Replace SVG sparklines with LTVC mini area charts in MetricsComparison - Extract shared _compute_chart_data() helper for unified chart data production - Fix Python 3.9 compatibility: replace str|None/list[str]/dict[str] with Optional/List/Dict across all backend files - Fix RSI reference lines using OHLCV time axis instead of portfolio dates - Fix ReportLinkModal scrolling by changing overflow:hidden to overflow:auto Co-Authored-By: Claude Opus 4.7 --- eqlib/report.py | 556 ++++++++++++++---- .../backend/studio_api/backtest_executor.py | 16 +- .../backend/studio_api/completion_service.py | 5 +- .../backend/studio_api/config.py | 7 +- .../backend/studio_api/format_service.py | 5 +- .../backend/studio_api/isolated_runner.py | 9 +- .../backend/studio_api/lint_service.py | 11 +- .../backend/studio_api/models.py | 33 +- .../backend/studio_api/proc_registry.py | 8 +- .../backend/studio_api/routers/runs.py | 46 +- .../backend/studio_api/routers/strategies.py | 3 +- .../backend/studio_api/run_queue.py | 25 +- .../backend/studio_api/schemas.py | 102 ++-- .../backend/studio_api/security_scanner.py | 9 +- .../backend/studio_api/stream_hub.py | 26 +- .../frontend/package-lock.json | 16 + web_strategy_studio/frontend/package.json | 1 + .../src/components/MetricsComparison.tsx | 87 ++- .../src/components/ReportLinkModal.tsx | 31 +- .../frontend/src/components/ReportViewer.tsx | 333 +++++++++++ 20 files changed, 1037 insertions(+), 292 deletions(-) create mode 100644 web_strategy_studio/frontend/src/components/ReportViewer.tsx diff --git a/eqlib/report.py b/eqlib/report.py index 3ea6f8e..b5c9f75 100644 --- a/eqlib/report.py +++ b/eqlib/report.py @@ -441,24 +441,21 @@ def generate_chart(result, out_path): print(f"Chart saved: {out_path}") -def generate_html_report(result, out_path): - """Generate interactive HTML report with TradingView lightweight-charts. - - Features: K-line with volume, support/resistance, pivot levels, - strategy vs 沪深300 vs 上证指数 (cumulative return %), - drawdown, daily P&L, daily return, trade calendar. +def _compute_chart_data(result): + """Compute all chart data arrays from a backtest result dict. + + Returns a dict with: candlestick_data, volume_data, ma5/20/60_data, + support_data, resistance_data, rsi_data, macd_data, macd_signal_data, + macd_hist_data, bb_upper/middle/lower_data, markers, cum_return_data, + ret_hs300_data, ret_sse_data, drawdown_data, pnl_bar_data, + daily_returns_data, symbol. """ ctx = result["context"] trade_log = result["trade_log"] recorded = result["recorded_values"] - benchmark = result.get("benchmark", "000300.XSHG") - + initial = ctx.portfolio.starting_cash start = ctx.start_date end = ctx.end_date - initial = ctx.portfolio.starting_cash - final = ctx.portfolio.total_value - pnl = final - initial - pnl_pct = (pnl / initial * 100) if initial > 0 else 0.0 # Collect traded securities securities = set() @@ -468,76 +465,80 @@ def generate_html_report(result, out_path): securities.add(ctx.universe[0]) if not securities: securities.add("601390") - symbol = list(securities)[0] - # ============================================================ - # K-line + technical indicators (use preloaded OHLCV first) - # ============================================================ - candlestick_data = [] - ma5_data = [] - ma20_data = [] - ma60_data = [] - volume_data = [] - support_data = [] - resistance_data = [] + # K-line + technical indicators + candlestick_data, ma5_data, ma20_data, ma60_data = [], [], [], [] + volume_data, support_data, resistance_data = [], [], [] ohlcv_data = result.get("ohlcv_data", {}) - if symbol in ohlcv_data: - df = ohlcv_data[symbol] - else: - df = pd.DataFrame() + df = ohlcv_data.get(symbol, pd.DataFrame()) + if df.empty: try: + from eqlib.data import fetch_stock_data df = fetch_stock_data(symbol, start, end) except Exception: pass + if not df.empty: - # Trim OHLCV to backtest period so K-line aligns with return charts - start_ts = pd.Timestamp(start) - end_ts = pd.Timestamp(end) - df_sorted = df.sort_index() - df_sorted = df_sorted.loc[start_ts:end_ts] + start_ts, end_ts = pd.Timestamp(start), pd.Timestamp(end) + df_sorted = df.sort_index().loc[start_ts:end_ts] if df_sorted.empty: df_sorted = df.sort_index() - closes = df_sorted["close"] - highs = df_sorted["high"] - lows = df_sorted["low"] - ma5 = closes.rolling(5).mean() - ma20 = closes.rolling(20).mean() - ma60 = closes.rolling(60).mean() + closes, highs, lows = df_sorted["close"], df_sorted["high"], df_sorted["low"] + ma5, ma20, ma60 = closes.rolling(5).mean(), closes.rolling(20).mean(), closes.rolling(60).mean() support, resistance = _compute_support_resistance(closes, highs, lows, window=20) for (date, row), m5, m20, m60, sup, res in zip( df_sorted.iterrows(), ma5, ma20, ma60, support, resistance): d = _to_tv_date(date) - o = float(row.get("open", 0)) - h = float(row.get("high", 0)) - l = float(row.get("low", 0)) - c = float(row.get("close", 0)) - v = float(row.get("volume", 0)) - - candlestick_data.append({ - "time": d, "open": round(o, 3), "high": round(h, 3), - "low": round(l, 3), "close": round(c, 3), - }) - volume_data.append({ - "time": d, "value": round(v, 0), - "color": "#26a69a" if c >= o else "#ef5350", - }) - if not pd.isna(m5): - ma5_data.append({"time": d, "value": round(float(m5), 3)}) - if not pd.isna(m20): - ma20_data.append({"time": d, "value": round(float(m20), 3)}) - if not pd.isna(m60): - ma60_data.append({"time": d, "value": round(float(m60), 3)}) - if not pd.isna(sup): - support_data.append({"time": d, "value": round(float(sup), 3)}) - if not pd.isna(res): - resistance_data.append({"time": d, "value": round(float(res), 3)}) + o, h, l, c, v = float(row.get("open", 0)), float(row.get("high", 0)), float(row.get("low", 0)), float(row.get("close", 0)), float(row.get("volume", 0)) + candlestick_data.append({"time": d, "open": round(o, 3), "high": round(h, 3), "low": round(l, 3), "close": round(c, 3)}) + volume_data.append({"time": d, "value": round(v, 0), "color": "#26a69a" if c >= o else "#ef5350"}) + if not pd.isna(m5): ma5_data.append({"time": d, "value": round(float(m5), 3)}) + if not pd.isna(m20): ma20_data.append({"time": d, "value": round(float(m20), 3)}) + if not pd.isna(m60): ma60_data.append({"time": d, "value": round(float(m60), 3)}) + if not pd.isna(sup): support_data.append({"time": d, "value": round(float(sup), 3)}) + if not pd.isna(res): resistance_data.append({"time": d, "value": round(float(res), 3)}) + + # RSI(14), MACD(12,26,9), Bollinger Bands(20,2) + rsi_data, macd_data, macd_signal_data, macd_hist_data = [], [], [], [] + bb_upper_data, bb_middle_data, bb_lower_data = [], [], [] + + if not df.empty and len(df_sorted) >= 26: + closes = df_sorted["close"] + delta = closes.diff() + gain, loss = delta.clip(lower=0), (-delta.clip(upper=0)) + avg_gain = gain.ewm(alpha=1/14, min_periods=14, adjust=False).mean() + avg_loss = loss.ewm(alpha=1/14, min_periods=14, adjust=False).mean() + rs = avg_gain / avg_loss.replace(0, np.nan) + rsi_series = 100 - (100 / (1 + rs)) + + ema_fast = closes.ewm(span=12, adjust=False).mean() + ema_slow = closes.ewm(span=26, adjust=False).mean() + macd_line = ema_fast - ema_slow + signal_line = macd_line.ewm(span=9, adjust=False).mean() + histogram = macd_line - signal_line + + bb_middle_series = closes.rolling(20).mean() + bb_std = closes.rolling(20).std() + bb_upper_series, bb_lower_series = bb_middle_series + 2 * bb_std, bb_middle_series - 2 * bb_std + + for (date, row), rsi_v, macd_v, sig_v, hist_v, bb_u, bb_m, bb_l in zip( + df_sorted.iterrows(), rsi_series, macd_line, signal_line, + histogram, bb_upper_series, bb_middle_series, bb_lower_series): + d = _to_tv_date(date) + if not pd.isna(rsi_v): rsi_data.append({"time": d, "value": round(float(rsi_v), 3)}) + if not pd.isna(macd_v): macd_data.append({"time": d, "value": round(float(macd_v), 4)}) + if not pd.isna(sig_v): macd_signal_data.append({"time": d, "value": round(float(sig_v), 4)}) + if not pd.isna(hist_v): + hv = round(float(hist_v), 4) + macd_hist_data.append({"time": d, "value": hv, "color": "rgba(245,34,45,0.6)" if hv >= 0 else "rgba(82,196,26,0.6)"}) + if not pd.isna(bb_u): bb_upper_data.append({"time": d, "value": round(float(bb_u), 3)}) + if not pd.isna(bb_m): bb_middle_data.append({"time": d, "value": round(float(bb_m), 3)}) + if not pd.isna(bb_l): bb_lower_data.append({"time": d, "value": round(float(bb_l), 3)}) - # ============================================================ # Buy/sell markers - # ============================================================ markers = [] for t in trade_log: markers.append({ @@ -549,21 +550,11 @@ def generate_html_report(result, out_path): }) markers.sort(key=lambda x: x["time"]) - # ============================================================ - # Cumulative return series (strategy) - # ============================================================ + # Cumulative return, daily P&L, daily return cum_return_data = _build_return_series(recorded, initial) - - # ============================================================ - # Daily P&L and daily return - # ============================================================ pnl_bar_data, daily_returns_data = _build_daily_pnl(recorded, initial) - # ============================================================ - # Benchmark cumulative returns: 沪深300 + 上证指数(图表双线); - # bench_data 为 set_benchmark 配置的指数,用于 _calc_metrics 等 - # ============================================================ - bench_data = _fetch_benchmark_returns(benchmark, start, end, recorded) + # Benchmark cumulative returns ret_hs300 = result.get("chart_index_hs300") if isinstance(result.get("chart_index_hs300"), list) else None if not ret_hs300: ret_hs300 = _fetch_index_returns("sh000300", start, end, recorded) @@ -571,21 +562,123 @@ def generate_html_report(result, out_path): if not ret_sse: ret_sse = _fetch_index_returns("sh000001", start, end, recorded) - # ============================================================ # Drawdown series - # ============================================================ drawdown_data = [] if cum_return_data: peak = cum_return_data[0]["value"] for d in cum_return_data: - if d["value"] > peak: - peak = d["value"] - dd = round(d["value"] - peak, 3) - drawdown_data.append({"time": d["time"], "value": dd}) + if d["value"] > peak: peak = d["value"] + drawdown_data.append({"time": d["time"], "value": round(d["value"] - peak, 3)}) + + # Technical summary stats for HTML report + tech_stats = {} + if not df.empty: + df_s = df.sort_index() + c = float(df_s["close"].iloc[-1]) + ma5_last = float(df_s["close"].rolling(5).mean().iloc[-1]) + ma20_last = float(df_s["close"].rolling(20).mean().iloc[-1]) + ma60_ser = df_s["close"].rolling(60).mean().dropna() + ma60_v = round(float(ma60_ser.iloc[-1]), 3) if len(ma60_ser) > 0 else None + atr14 = _compute_atr(df_s["high"], df_s["low"], df_s["close"], 14) + vol20 = df_s["volume"].rolling(20).mean().iloc[-1] + vol_ratio = round(float(df_s["volume"].iloc[-1] / vol20), 2) if vol20 > 0 else None + + rsi_last = rsi_data[-1]["value"] if rsi_data else None + macd_last = macd_data[-1]["value"] if macd_data else None + macd_sig_last = macd_signal_data[-1]["value"] if macd_signal_data else None + macd_hist_last = macd_hist_data[-1]["value"] if macd_hist_data else None + bb_u_last = bb_upper_data[-1]["value"] if bb_upper_data else None + bb_m_last = bb_middle_data[-1]["value"] if bb_middle_data else None + bb_l_last = bb_lower_data[-1]["value"] if bb_lower_data else None + bb_width = round((bb_u_last - bb_l_last) / bb_m_last * 100, 3) if bb_u_last is not None and bb_m_last and bb_m_last != 0 else None + + tech_stats = { + "latest_price": round(c, 3), + "ma5": round(ma5_last, 3), + "ma20": round(ma20_last, 3), + "ma60": ma60_v, + "atr14": round(float(atr14), 3) if atr14 else None, + "vol_ratio": vol_ratio, + "period_high": round(float(df_s["high"].max()), 3), + "period_low": round(float(df_s["low"].min()), 3), + "rsi14": rsi_last, + "macd": macd_last, + "macd_signal": macd_sig_last, + "macd_hist": macd_hist_last, + "bb_upper": bb_u_last, + "bb_middle": bb_m_last, + "bb_lower": bb_l_last, + "bb_width": bb_width, + } + + return { + "symbol": symbol, + "candlestick_data": candlestick_data, + "volume_data": volume_data, + "ma5_data": ma5_data, "ma20_data": ma20_data, "ma60_data": ma60_data, + "support_data": support_data, "resistance_data": resistance_data, + "rsi_data": rsi_data, "macd_data": macd_data, + "macd_signal_data": macd_signal_data, "macd_hist_data": macd_hist_data, + "bb_upper_data": bb_upper_data, "bb_middle_data": bb_middle_data, + "bb_lower_data": bb_lower_data, + "markers": markers, + "cum_return_data": cum_return_data, + "ret_hs300_data": ret_hs300 if ret_hs300 else [], + "ret_sse_data": ret_sse if ret_sse else [], + "drawdown_data": drawdown_data, + "pnl_bar_data": pnl_bar_data, + "daily_returns_data": daily_returns_data, + "tech_stats": tech_stats, + } + + +def generate_html_report(result, out_path): + """Generate interactive HTML report with TradingView lightweight-charts. + + Features: K-line with volume, support/resistance, pivot levels, + strategy vs 沪深300 vs 上证指数 (cumulative return %), + drawdown, daily P&L, daily return, trade calendar. + """ + ctx = result["context"] + trade_log = result["trade_log"] + recorded = result["recorded_values"] + benchmark = result.get("benchmark", "000300.XSHG") + + start = ctx.start_date + end = ctx.end_date + initial = ctx.portfolio.starting_cash + final = ctx.portfolio.total_value + pnl = final - initial + pnl_pct = (pnl / initial * 100) if initial > 0 else 0.0 + + # Compute all chart data arrays (shared with generate_report_json) + chart = _compute_chart_data(result) + symbol = chart["symbol"] + candlestick_data = chart["candlestick_data"] + ma5_data, ma20_data, ma60_data = chart["ma5_data"], chart["ma20_data"], chart["ma60_data"] + volume_data = chart["volume_data"] + support_data, resistance_data = chart["support_data"], chart["resistance_data"] + rsi_data = chart["rsi_data"] + macd_data = chart["macd_data"] + macd_signal_data = chart["macd_signal_data"] + macd_hist_data = chart["macd_hist_data"] + bb_upper_data = chart["bb_upper_data"] + bb_middle_data = chart["bb_middle_data"] + bb_lower_data = chart["bb_lower_data"] + markers = chart["markers"] + cum_return_data = chart["cum_return_data"] + ret_hs300 = chart["ret_hs300_data"] + ret_sse = chart["ret_sse_data"] + drawdown_data = chart["drawdown_data"] + pnl_bar_data = chart["pnl_bar_data"] + daily_returns_data = chart["daily_returns_data"] dd_hs300 = _build_drawdown_from_cumulative_pct(ret_hs300) dd_sse = _build_drawdown_from_cumulative_pct(ret_sse) + # Benchmark data for metrics calculation + bench_data = _fetch_benchmark_returns(benchmark, start, end, recorded) + # ============================================================ # Performance metrics # ============================================================ @@ -675,30 +768,9 @@ def generate_html_report(result, out_path): sell_count = sum(1 for t in trade_log if t["type"] == "SELL") # ============================================================ - # Technical summary stats + # Technical summary stats (computed by _compute_chart_data) # ============================================================ - tech_stats = {} - if not df.empty: - df_s = df.sort_index() - c = df_s["close"].iloc[-1] - ma5_last = df_s["close"].rolling(5).mean().iloc[-1] - ma20_last = df_s["close"].rolling(20).mean().iloc[-1] - ma60_last = df_s["close"].rolling(60).mean().dropna() - ma60_v = round(float(ma60_last.iloc[-1]), 3) if len(ma60_last) > 0 else None - atr14 = _compute_atr(df_s["high"], df_s["low"], df_s["close"], 14) - vol20 = df_s["volume"].rolling(20).mean().iloc[-1] - vol_ratio = round(float(df_s["volume"].iloc[-1] / vol20), 2) if vol20 > 0 else None - - tech_stats = { - "latest_price": round(c, 3), - "ma5": round(float(ma5_last), 3), - "ma20": round(float(ma20_last), 3), - "ma60": ma60_v, - "atr14": round(float(atr14), 3) if atr14 else None, - "vol_ratio": vol_ratio, - "period_high": round(float(df_s["high"].max()), 3), - "period_low": round(float(df_s["low"].min()), 3), - } + tech_stats = chart.get("tech_stats", {}) # ============================================================ # Build HTML @@ -729,6 +801,13 @@ def generate_html_report(result, out_path): markers_json=json.dumps(markers), support_json=json.dumps(support_data), resistance_json=json.dumps(resistance_data), + rsi_json=json.dumps(rsi_data), + macd_json=json.dumps(macd_data), + macd_signal_json=json.dumps(macd_signal_data), + macd_hist_json=json.dumps(macd_hist_data), + bb_upper_json=json.dumps(bb_upper_data), + bb_middle_json=json.dumps(bb_middle_data), + bb_lower_json=json.dumps(bb_lower_data), cum_return_json=json.dumps(cum_return_data), ret_hs300_json=json.dumps(ret_hs300), ret_sse_json=json.dumps(ret_sse), @@ -1024,6 +1103,38 @@ def _calc_metrics(result, bench_data): #drawdown {{ width: 100%; height: 160px; }} #pnlbar {{ width: 100%; height: 160px; }} #dailyret {{ width: 100%; height: 160px; }} + #rsichart {{ width: 100%; height: 160px; }} + #macdchart {{ width: 100%; height: 160px; }} + /* Indicator toggle panel */ + .indicator-panel {{ + position: absolute; top: 8px; left: 8px; z-index: 10; + display: flex; gap: 4px; + }} + .ind-btn {{ + padding: 4px 10px; font-size: 11px; font-weight: 500; + border: 1px solid var(--border); border-radius: 3px; + background: rgba(255,255,255,.85); color: var(--text-dim); + cursor: pointer; transition: all .15s; backdrop-filter: blur(4px); + }} + .ind-btn.active {{ background: var(--primary); color: #fff; border-color: var(--primary); }} + .ind-btn:hover:not(.active) {{ background: #fafafa; }} + /* Crosshair legend */ + .chart-legend {{ + position: absolute; top: 8px; right: 70px; z-index: 10; + background: rgba(255,255,255,.92); border: 1px solid var(--border); + border-radius: 4px; padding: 6px 10px; font-size: 11px; + font-family: "SF Mono", "Menlo", monospace; line-height: 1.6; + color: var(--text-secondary); pointer-events: none; + backdrop-filter: blur(4px); min-width: 200px; + box-shadow: 0 2px 8px rgba(0,0,0,.08); + display: none; + }} + .chart-legend.visible {{ display: block; }} + .chart-legend .leg-date {{ font-weight: 600; color: var(--text); margin-bottom: 2px; }} + .chart-legend .leg-row {{ display: flex; justify-content: space-between; gap: 12px; }} + .chart-legend .leg-label {{ color: var(--text-dim); }} + .chart-legend .leg-val {{ font-variant-numeric: tabular-nums; }} + .leg-dot {{ display: inline-block; width: 8px; height: 8px; border-radius: 50%; margin-right: 4px; vertical-align: middle; }} /* Legend */ .legend {{ display: flex; gap: 16px; font-size: 12px; color: var(--text-secondary); align-items: center; }} .legend span {{ display: flex; align-items: center; gap: 4px; }} @@ -1315,11 +1426,11 @@ def _calc_metrics(result, bench_data): -
+

K 线图 · 技术指标

-
日 K 线含 MA5/MA20/MA60 均线、20日动态支撑/压力位、成交量柱,以及买卖信号标记。· 使用前复权价格(含分红调整)
+
日 K 线含 MA5/MA20/MA60 均线、布林带(20,2)、20日动态支撑/压力位、成交量柱,以及买卖信号标记。· 使用前复权价格(含分红调整)
MA5 @@ -1329,9 +1440,41 @@ def _calc_metrics(result, bench_data): 卖出
+
+ + + + +
+
+ +
+
+

RSI(14) 相对强弱指标

+
+ RSI(14) + 超卖区 <30 / 超买区 >70 +
+
+
+
+ + +
+
+

MACD(12,26,9) 指数平滑异同移动平均

+
+ MACD + Signal + 柱状图 +
+
+
+
+
@@ -1836,11 +1979,30 @@ def _calc_metrics(result, bench_data): }}); cSeries.setData({candlestick_json}); cSeries.setMarkers({markers_json}); - kChart.addLineSeries({{ color: '#f5222d', lineWidth: 1, priceLineVisible: false, lastValueVisible: false }}).setData({ma5_json}); - kChart.addLineSeries({{ color: '#1890ff', lineWidth: 1, priceLineVisible: false, lastValueVisible: false }}).setData({ma20_json}); - kChart.addLineSeries({{ color: '#722ed1', lineWidth: 1, priceLineVisible: false, lastValueVisible: false }}).setData({ma60_json}); - kChart.addLineSeries({{ color: 'rgba(82,196,26,0.55)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false }}).setData({support_json}); - kChart.addLineSeries({{ color: 'rgba(245,34,45,0.55)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false }}).setData({resistance_json}); + + // MA series (group: 'ma') + const ma5S = kChart.addLineSeries({{ color: '#f5222d', lineWidth: 1, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + ma5S.setData({ma5_json}); + const ma20S = kChart.addLineSeries({{ color: '#1890ff', lineWidth: 1, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + ma20S.setData({ma20_json}); + const ma60S = kChart.addLineSeries({{ color: '#722ed1', lineWidth: 1, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + ma60S.setData({ma60_json}); + + // Bollinger Bands (group: 'bb') + const bbUpperS = kChart.addLineSeries({{ color: 'rgba(24,144,255,0.5)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + bbUpperS.setData({bb_upper_json}); + const bbMiddleS = kChart.addLineSeries({{ color: 'rgba(24,144,255,0.7)', lineWidth: 1, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + bbMiddleS.setData({bb_middle_json}); + const bbLowerS = kChart.addLineSeries({{ color: 'rgba(24,144,255,0.5)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + bbLowerS.setData({bb_lower_json}); + + // Support/Resistance (group: 'sr') + const supS = kChart.addLineSeries({{ color: 'rgba(82,196,26,0.55)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + supS.setData({support_json}); + const resS = kChart.addLineSeries({{ color: 'rgba(245,34,45,0.55)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, crosshairMarkerVisible: false }}); + resS.setData({resistance_json}); + + // Volume (group: 'vol') const volS = kChart.addHistogramSeries({{ priceFormat: {{ type: 'volume' }}, priceScaleId: 'vol', @@ -1910,6 +2072,61 @@ def _calc_metrics(result, bench_data): excessLine.applyOptions({{ visible: retVis.excess }}); }}; + /* RSI(14) chart */ + const rsiEl = document.getElementById('rsichart'); + const rsiChart = LightweightCharts.createChart(rsiEl, {{ + ...cmn, width: rsiEl.clientWidth, height: 160, + rightPriceScale: {{ scaleMargins: {{ top: 0.05, bottom: 0.05 }} }}, + }}); + const rsiLine = rsiChart.addLineSeries({{ + color: '#722ed1', lineWidth: 1.5, priceLineVisible: false, lastValueVisible: false, + }}); + rsiLine.setData({rsi_json}); + // Overbought line (70) + const rsiOB = rsiChart.addLineSeries({{ + color: 'rgba(245,34,45,0.4)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, + }}); + rsiOB.setData({rsi_json}.length > 0 ? {rsi_json}.map(d => ({{ time: d.time, value: 70 }})) : []); + // Oversold line (30) + const rsiOS = rsiChart.addLineSeries({{ + color: 'rgba(82,196,26,0.4)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, + }}); + rsiOS.setData({rsi_json}.length > 0 ? {rsi_json}.map(d => ({{ time: d.time, value: 30 }})) : []); + // Middle line (50) + const rsiMid = rsiChart.addLineSeries({{ + color: 'rgba(140,140,140,0.3)', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, + }}); + rsiMid.setData({rsi_json}.length > 0 ? {rsi_json}.map(d => ({{ time: d.time, value: 50 }})) : []); + rsiChart.timeScale().fitContent(); + + /* MACD(12,26,9) chart */ + const macdEl = document.getElementById('macdchart'); + const macdChart = LightweightCharts.createChart(macdEl, {{ + ...cmn, width: macdEl.clientWidth, height: 160, + rightPriceScale: {{ scaleMargins: {{ top: 0.05, bottom: 0.05 }} }}, + }}); + // Histogram + const macdHist = macdChart.addHistogramSeries({{ + priceFormat: {{ type: 'price' }}, + }}); + macdHist.setData({macd_hist_json}); + // MACD line + const macdLineS = macdChart.addLineSeries({{ + color: '#1890ff', lineWidth: 1.5, priceLineVisible: false, lastValueVisible: false, + }}); + macdLineS.setData({macd_json}); + // Signal line + const macdSigS = macdChart.addLineSeries({{ + color: '#fa8c16', lineWidth: 1, priceLineVisible: false, lastValueVisible: false, + }}); + macdSigS.setData({macd_signal_json}); + // Zero line + const macdZeroData = {macd_json}.length > 0 ? [{{ time: {macd_json}[0].time, value: 0 }}, {{ time: {macd_json}[{macd_json}.length - 1].time, value: 0 }}] : []; + macdChart.addLineSeries({{ + color: '#d9d9d9', lineWidth: 1, lineStyle: 2, priceLineVisible: false, lastValueVisible: false, + }}).setData(macdZeroData); + macdChart.timeScale().fitContent(); + /* Drawdown — strategy area + HS300 + SSE lines */ const ddEl = document.getElementById('drawdown'); const ddChart = LightweightCharts.createChart(ddEl, {{ @@ -1980,7 +2197,7 @@ def _calc_metrics(result, bench_data): drChart.timeScale().fitContent(); /* Sync all time scales */ - const allCharts = [kChart, rChart, ddChart, pChart, drChart]; + const allCharts = [kChart, rChart, ddChart, pChart, drChart, rsiChart, macdChart]; allCharts.forEach(src => {{ src.timeScale().subscribeVisibleLogicalRangeChange(range => {{ if (!range) return; @@ -1993,10 +2210,88 @@ def _calc_metrics(result, bench_data): window.addEventListener('resize', () => {{ clearTimeout(rTimer); rTimer = setTimeout(() => {{ - [[kChart, kEl], [rChart, rEl], [ddChart, ddEl], [pChart, pEl], [drChart, drEl]] + [[kChart, kEl], [rChart, rEl], [ddChart, ddEl], [pChart, pEl], [drChart, drEl], [rsiChart, rsiEl], [macdChart, macdEl]] .forEach(([c, el]) => c.applyOptions({{ width: el.clientWidth }})); }}, 150); }}); + + /* ================================================================= + INDICATOR TOGGLE + ================================================================= */ + const indGroups = {{ + ma: [ma5S, ma20S, ma60S], + bb: [bbUpperS, bbMiddleS, bbLowerS], + vol: [volS], + sr: [supS, resS], + }}; + window.toggleInd = function(group, btn) {{ + const show = !btn.classList.contains('active'); + btn.classList.toggle('active'); + (indGroups[group] || []).forEach(s => s.applyOptions({{ visible: show }})); + }}; + + /* ================================================================= + CROSSHAIR-LINKED DYNAMIC LEGEND + ================================================================= */ + const legendEl = document.getElementById('klineLegend'); + // Build lookup maps for indicator data + function buildMap(arr) {{ + const m = {{}}; + arr.forEach(d => {{ m[d.time] = d.value; }}); + return m; + }} + const ma5Map = buildMap({ma5_json}); + const ma20Map = buildMap({ma20_json}); + const ma60Map = buildMap({ma60_json}); + const rsiMap = buildMap({rsi_json}); + const macdMap = buildMap({macd_json}); + const sigMap = buildMap({macd_signal_json}); + const histMap = buildMap({macd_hist_json}); + const bbUpMap = buildMap({bb_upper_json}); + const bbMidMap = buildMap({bb_middle_json}); + const bbLoMap = buildMap({bb_lower_json}); + const volMap = buildMap({volume_json}); + + function fmt(v, d) {{ + if (v === undefined || v === null || isNaN(v)) return '—'; + return Number(v).toFixed(d || 2); + }} + function fmtVol(v) {{ + if (v === undefined || v === null || isNaN(v)) return '—'; + const n = Number(v); + if (n >= 1e8) return (n / 1e8).toFixed(2) + '亿'; + if (n >= 1e4) return (n / 1e4).toFixed(2) + '万'; + return n.toLocaleString(); + }} + + kChart.subscribeCrosshairMove(param => {{ + if (!param || !param.time) {{ + legendEl.classList.remove('visible'); + return; + }} + legendEl.classList.add('visible'); + const t = param.time; + const sd = param.seriesData.get(cSeries); + const o = sd ? fmt(sd.open, 3) : '—'; + const h = sd ? fmt(sd.high, 3) : '—'; + const l = sd ? fmt(sd.low, 3) : '—'; + const c = sd ? fmt(sd.close, 3) : '—'; + + let html = `
${{t}}
`; + html += `
O/H/L/C${{o}} / ${{h}} / ${{l}} / ${{c}}
`; + html += `
MA5${{fmt(ma5Map[t])}}
`; + html += `
MA20${{fmt(ma20Map[t])}}
`; + html += `
MA60${{fmt(ma60Map[t])}}
`; + html += `
RSI(14)${{fmt(rsiMap[t])}}
`; + html += `
MACD${{fmt(macdMap[t], 4)}}
`; + html += `
Signal${{fmt(sigMap[t], 4)}}
`; + html += `
MACD Hist${{fmt(histMap[t], 4)}}
`; + html += `
BB Upper${{fmt(bbUpMap[t], 3)}}
`; + html += `
BB Middle${{fmt(bbMidMap[t], 3)}}
`; + html += `
BB Lower${{fmt(bbLoMap[t], 3)}}
`; + html += `
VOL${{fmtVol(volMap[t])}}
`; + legendEl.innerHTML = html; + }}); }} catch(e) {{ chartError = true; console.error('Chart initialization error:', e); @@ -2025,6 +2320,14 @@ def _calc_metrics(result, bench_data): ['MA20', tech.ma20], ['MA60', tech.ma60], ['ATR(14)', tech.atr14], + ['RSI(14)', tech.rsi14], + ['MACD', tech.macd], + ['MACD Signal', tech.macd_signal], + ['MACD Hist', tech.macd_hist], + ['BB Upper', tech.bb_upper], + ['BB Middle',tech.bb_middle], + ['BB Lower', tech.bb_lower], + ['BB Width(%)', tech.bb_width], ['量比', tech.vol_ratio], ['期间最高', tech.period_high], ['期间最低', tech.period_low], @@ -2339,6 +2642,9 @@ def generate_report_json(result, out_path): "cumulative_return": round(r["total_value"] / initial - 1, 6) if initial > 0 else 0.0, }) + # Chart data arrays for native Lightweight Charts rendering + chart = _compute_chart_data(result) + report = { "metadata": { "generated_at": str(datetime.datetime.now().replace(microsecond=0)), @@ -2393,6 +2699,28 @@ def generate_report_json(result, out_path): if pos.amount > 0 }, "cumulative_returns": cumulative_returns, + # Chart data arrays for native Lightweight Charts rendering (ReportViewer) + "candlestick_data": chart["candlestick_data"], + "volume_data": chart["volume_data"], + "ma5_data": chart["ma5_data"], + "ma20_data": chart["ma20_data"], + "ma60_data": chart["ma60_data"], + "rsi_data": chart["rsi_data"], + "macd_data": chart["macd_data"], + "macd_signal_data": chart["macd_signal_data"], + "macd_hist_data": chart["macd_hist_data"], + "bb_upper_data": chart["bb_upper_data"], + "bb_middle_data": chart["bb_middle_data"], + "bb_lower_data": chart["bb_lower_data"], + "support_data": chart["support_data"], + "resistance_data": chart["resistance_data"], + "markers": chart["markers"], + "cum_return_data": chart["cum_return_data"], + "ret_hs300_data": chart["ret_hs300_data"], + "ret_sse_data": chart["ret_sse_data"], + "drawdown_data": chart["drawdown_data"], + "pnl_bar_data": chart["pnl_bar_data"], + "daily_returns_data": chart["daily_returns_data"], } # Add risk metrics diff --git a/web_strategy_studio/backend/studio_api/backtest_executor.py b/web_strategy_studio/backend/studio_api/backtest_executor.py index 778d331..33cf405 100644 --- a/web_strategy_studio/backend/studio_api/backtest_executor.py +++ b/web_strategy_studio/backend/studio_api/backtest_executor.py @@ -11,7 +11,7 @@ import tempfile from datetime import date, datetime from pathlib import Path -from typing import Any +from typing import Any, Dict, Optional import pandas as pd @@ -32,7 +32,7 @@ def _parse_iso(d: str) -> date: def _estimate_trading_fraction(done_days: int, start: date, end: date) -> float: """Rough progress from trading-day span when bar-level hooks are unavailable.""" # Use pandas bdate_range (Mon-Fri) as a proxy for trading days (~250/yr) - # instead of calendar days (~365/yr) to avoid the ~1.46× overestimate. + # instead of calendar days (~365/yr) to avoid the ~1.46x overestimate. total = max(len(pd.bdate_range(start=start, end=end)), 1) return min(0.95, 0.15 + 0.75 * (done_days / total)) @@ -40,9 +40,9 @@ def _estimate_trading_fraction(done_days: int, start: date, end: date) -> float: async def execute_backtest( run_id: str, source_code: str, - params: dict[str, Any], + params: Dict[str, Any], on_log: Any = None, -) -> dict[str, Any]: +) -> Dict[str, Any]: """Run isolated subprocess; stream logs; return artifact paths or error.""" work = Path(tempfile.mkdtemp(prefix=f"eqrun_{run_id}_")) artifact_sub = settings.artifact_dir / "reports" / run_id @@ -80,7 +80,7 @@ async def execute_backtest( } ) - filtered_env: dict[str, str] = {} + filtered_env: Dict[str, str] = {} for k, v in os.environ.items(): if k in _ALLOWED_ENV_KEYS or any(k.startswith(p) for p in _ALLOWED_ENV_PREFIXES): filtered_env[k] = v @@ -128,7 +128,7 @@ async def pump_stream(stream: asyncio.StreamReader, name: str) -> None: log_lines += 1 # S5: Parse structured progress lines emitted by the engine. - # Format: "📍 Backtest progress: N/M (pct%)" or "Backtest progress N/M" + # Format: "Backtest progress: N/M (pct%)" or "Backtest progress N/M" # The regex handles optional emoji prefix, colon, and trailing percentage. m = _PROGRESS_RE.search(line) if m: @@ -177,7 +177,7 @@ async def progress_tick() -> None: t_err = asyncio.create_task(pump_stream(proc.stderr, "stderr")) # type: ignore[arg-type] t_prog = asyncio.create_task(progress_tick()) - timeout_payload: dict[str, Any] | None = None + timeout_payload: Optional[Dict[str, Any]] = None try: await asyncio.wait_for(proc.wait(), timeout=settings.run_timeout_sec) except asyncio.TimeoutError: @@ -198,7 +198,7 @@ async def progress_tick() -> None: return timeout_payload result_path = work / "result.json" - payload: dict[str, Any] = {"ok": False, "error": "No result.json", "error_code": "NO_RESULT"} + payload: Dict[str, Any] = {"ok": False, "error": "No result.json", "error_code": "NO_RESULT"} if result_path.is_file(): try: payload = json.loads(result_path.read_text(encoding="utf-8")) diff --git a/web_strategy_studio/backend/studio_api/completion_service.py b/web_strategy_studio/backend/studio_api/completion_service.py index 7549de5..217c1e4 100644 --- a/web_strategy_studio/backend/studio_api/completion_service.py +++ b/web_strategy_studio/backend/studio_api/completion_service.py @@ -5,6 +5,7 @@ import json from functools import lru_cache from pathlib import Path +from typing import Any, Dict, List @lru_cache @@ -13,14 +14,14 @@ def _symbols_path() -> Path: @lru_cache -def _load_symbols() -> list[dict]: +def _load_symbols() -> List[Dict[str, Any]]: p = _symbols_path() if not p.is_file(): return [] return json.loads(p.read_text(encoding="utf-8")) -def suggest(source: str, cursor_line: int, cursor_col: int) -> list[dict]: +def suggest(source: str, cursor_line: int, cursor_col: int) -> List[Dict[str, Any]]: lines = source.splitlines() if cursor_line < 1 or cursor_line > len(lines): line = "" diff --git a/web_strategy_studio/backend/studio_api/config.py b/web_strategy_studio/backend/studio_api/config.py index 782acb3..5d5c73e 100644 --- a/web_strategy_studio/backend/studio_api/config.py +++ b/web_strategy_studio/backend/studio_api/config.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from pathlib import Path +from typing import List, Optional from pydantic_settings import BaseSettings, SettingsConfigDict @@ -16,7 +19,7 @@ class Settings(BaseSettings): ) database_url: str = "sqlite+aiosqlite:///./studio.sqlite3" - redis_url: str | None = None # reserved for future queue split + redis_url: Optional[str] = None # reserved for future queue split # S11: Always resolve artifact_dir to absolute path so subprocess CWD # (a temp directory) doesn't break file lookups in backtest_executor. artifact_dir: Path = _default_repo_root() / "artifacts" @@ -29,7 +32,7 @@ class Settings(BaseSettings): api_port: int = 8080 # S1: CORS — restrict to localhost by default; override via env for staging/production. # Do NOT use ["*"] together with allow_credentials=True (browser spec disallows it). - cors_allowed_origins: list[str] = [ + cors_allowed_origins: List[str] = [ "http://localhost:5173", "http://127.0.0.1:5173", "http://localhost:8080", diff --git a/web_strategy_studio/backend/studio_api/format_service.py b/web_strategy_studio/backend/studio_api/format_service.py index f634000..a3b691b 100644 --- a/web_strategy_studio/backend/studio_api/format_service.py +++ b/web_strategy_studio/backend/studio_api/format_service.py @@ -6,11 +6,12 @@ import sys import tempfile from pathlib import Path +from typing import Any, Dict, Optional -def format_python(source: str, timeout: float = 30.0) -> dict: +def format_python(source: str, timeout: float = 30.0) -> Dict[str, Any]: proc = None - tmp_path: str | None = None + tmp_path: Optional[str] = None try: with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as f: f.write(source) diff --git a/web_strategy_studio/backend/studio_api/isolated_runner.py b/web_strategy_studio/backend/studio_api/isolated_runner.py index da28218..72b0356 100644 --- a/web_strategy_studio/backend/studio_api/isolated_runner.py +++ b/web_strategy_studio/backend/studio_api/isolated_runner.py @@ -12,6 +12,7 @@ import sys import traceback from pathlib import Path +from typing import Optional def main() -> int: @@ -125,10 +126,10 @@ def _write_result( work: Path, *, ok: bool, - html: str | None = None, - report_json: str | None = None, - error: str | None = None, - error_code: str | None = None, + html: Optional[str] = None, + report_json: Optional[str] = None, + error: Optional[str] = None, + error_code: Optional[str] = None, ) -> None: # Local import: avoid any accidental shadowing of the stdlib `json` module. import json as json_stdlib diff --git a/web_strategy_studio/backend/studio_api/lint_service.py b/web_strategy_studio/backend/studio_api/lint_service.py index ee5b677..d5c336a 100644 --- a/web_strategy_studio/backend/studio_api/lint_service.py +++ b/web_strategy_studio/backend/studio_api/lint_service.py @@ -8,6 +8,7 @@ import sys import tempfile from pathlib import Path +from typing import Any, Dict, List from studio_api.security_scanner import SecurityScanner, require_initialize_function @@ -15,8 +16,8 @@ PROFILE_STRICT = "strict" -def _syntax_errors(source: str) -> list[dict]: - out: list[dict] = [] +def _syntax_errors(source: str) -> List[Dict[str, Any]]: + out: List[Dict[str, Any]] = [] try: compile(source, "", "exec", ast.PyCF_ONLY_AST) except SyntaxError as e: @@ -31,7 +32,7 @@ def _syntax_errors(source: str) -> list[dict]: return out -def _ruff_issues(source: str, timeout: float = 15.0) -> list[dict]: +def _ruff_issues(source: str, timeout: float = 15.0) -> List[Dict[str, Any]]: with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as f: f.write(source) tmp = f.name @@ -68,7 +69,7 @@ def _ruff_issues(source: str, timeout: float = 15.0) -> list[dict]: return issues -def lint_source(source: str, profile: str = PROFILE_FAST) -> dict: +def lint_source(source: str, profile: str = PROFILE_FAST) -> Dict[str, Any]: syntax_errors = _syntax_errors(source) scanner = SecurityScanner() sec = scanner.scan(source) @@ -76,7 +77,7 @@ def lint_source(source: str, profile: str = PROFILE_FAST) -> dict: security_notes = [{"code": n.code, "line": n.line, "message": n.message} for n in sec] - lint_issues: list[dict] = [] + lint_issues: List[Dict[str, Any]] = [] if not syntax_errors: lint_issues = _ruff_issues(source) diff --git a/web_strategy_studio/backend/studio_api/models.py b/web_strategy_studio/backend/studio_api/models.py index d65747b..aa82b80 100644 --- a/web_strategy_studio/backend/studio_api/models.py +++ b/web_strategy_studio/backend/studio_api/models.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +from typing import Dict, List, Optional from sqlalchemy import JSON, DateTime, Float, ForeignKey, Integer, String, Text, func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -16,22 +17,22 @@ class Strategy(Base): __tablename__ = "strategies" id: Mapped[str] = mapped_column(String(64), primary_key=True) - owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True) + owner_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) name: Mapped[str] = mapped_column(Text) - description: Mapped[str | None] = mapped_column(Text, nullable=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) current_version: Mapped[int] = mapped_column(Integer, default=1) - default_params: Mapped[dict | None] = mapped_column(JSON, nullable=True) + default_params: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True) - versions: Mapped[list["StrategyVersion"]] = relationship( + versions: Mapped[List["StrategyVersion"]] = relationship( back_populates="strategy", cascade="all, delete-orphan", order_by="StrategyVersion.version", ) - runs: Mapped[list["Run"]] = relationship(back_populates="strategy") + runs: Mapped[List["Run"]] = relationship(back_populates="strategy") class StrategyVersion(Base): @@ -44,9 +45,9 @@ class StrategyVersion(Base): version: Mapped[int] = mapped_column(Integer) source_code: Mapped[str] = mapped_column(Text) # B4/B15: content hash for dedup; sha256 hex (64 chars) or NULL for legacy rows - content_hash: Mapped[str | None] = mapped_column(String(64), nullable=True) + content_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # Named snapshot label (set by POST /snapshot) - label: Mapped[str | None] = mapped_column(String(256), nullable=True) + label: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) strategy: Mapped["Strategy"] = relationship(back_populates="versions") @@ -62,14 +63,14 @@ class Run(Base): strategy_version: Mapped[int] = mapped_column(Integer) status: Mapped[str] = mapped_column(String(32), default="queued") progress: Mapped[float] = mapped_column(Float, default=0.0) - stage: Mapped[str | None] = mapped_column(String(64), nullable=True) - params: Mapped[dict] = mapped_column(JSON, default=dict) - error_code: Mapped[str | None] = mapped_column(String(64), nullable=True) - error_message: Mapped[str | None] = mapped_column(Text, nullable=True) - html_path: Mapped[str | None] = mapped_column(Text, nullable=True) - json_path: Mapped[str | None] = mapped_column(Text, nullable=True) - started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) - finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) - worker_hostname: Mapped[str | None] = mapped_column(String(256), nullable=True) + stage: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + params: Mapped[Dict] = mapped_column(JSON, default=dict) + error_code: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + html_path: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + json_path: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + worker_hostname: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) strategy: Mapped["Strategy"] = relationship(back_populates="runs") diff --git a/web_strategy_studio/backend/studio_api/proc_registry.py b/web_strategy_studio/backend/studio_api/proc_registry.py index 577ba03..3b68e57 100644 --- a/web_strategy_studio/backend/studio_api/proc_registry.py +++ b/web_strategy_studio/backend/studio_api/proc_registry.py @@ -3,12 +3,10 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING +from typing import Dict, Optional -if TYPE_CHECKING: - pass -_procs: dict[str, asyncio.subprocess.Process] = {} +_procs: Dict[str, asyncio.subprocess.Process] = {} def register(run_id: str, proc: asyncio.subprocess.Process) -> None: @@ -19,7 +17,7 @@ def unregister(run_id: str) -> None: _procs.pop(run_id, None) -def get_proc(run_id: str) -> asyncio.subprocess.Process | None: +def get_proc(run_id: str) -> Optional[asyncio.subprocess.Process]: """Public accessor for a live subprocess handle (B21).""" return _procs.get(run_id) diff --git a/web_strategy_studio/backend/studio_api/routers/runs.py b/web_strategy_studio/backend/studio_api/routers/runs.py index ea38524..c41ffa2 100644 --- a/web_strategy_studio/backend/studio_api/routers/runs.py +++ b/web_strategy_studio/backend/studio_api/routers/runs.py @@ -9,11 +9,11 @@ import time from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import Any, Dict, List, Optional, Set import structlog from fastapi import APIRouter, Depends, Header, HTTPException, Request -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -205,7 +205,7 @@ async def _process_run_task(run_id: str) -> None: "message": "finished", }, ) - done_payload: dict[str, Any] = { + done_payload: Dict[str, Any] = { "status": "succeeded" if exec_result.get("ok") else "failed", "artifacts": arts, } @@ -220,7 +220,7 @@ async def create_run( request: Request, body: CreateRunBody, session: AsyncSession = Depends(get_session), - idempotency_key: str | None = Header(None, alias="Idempotency-Key"), + idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): # B18: per-IP rate limit client_ip = (request.headers.get("X-Forwarded-For") or "").split(",")[0].strip() or ( @@ -374,7 +374,7 @@ async def delete_run(run_id: str, session: AsyncSession = Depends(get_session)): async def run_stream( run_id: str, session: AsyncSession = Depends(get_session), - last_event_id: str | None = Header(None, alias="Last-Event-ID"), + last_event_id: Optional[str] = Header(None, alias="Last-Event-ID"), ): """SSE endpoint with Last-Event-ID replay and immediate done for terminal runs (B6/B13).""" # Resolve last_event_id to an int (default -1 = send everything). @@ -460,7 +460,7 @@ async def get_queue(): @router.get("/runs", response_model=RunListResponse) async def list_runs( session: AsyncSession = Depends(get_session), - strategy_id: str | None = None, + strategy_id: Optional[str] = None, limit: int = 100, offset: int = 0, ): @@ -481,7 +481,7 @@ async def list_runs( total = (await session.execute(count_q)).scalar_one() or 0 rows = (await session.execute(q.limit(limit).offset(offset))).scalars().all() - items: list[RunListItem] = [] + items: List[RunListItem] = [] for run in rows: items.append( RunListItem( @@ -500,7 +500,7 @@ async def list_runs( return RunListResponse(runs=items, total=total) -def _read_metrics_from_json(run: Run) -> dict[str, Any]: +def _read_metrics_from_json(run: Run) -> Dict[str, Any]: """Try to read metrics from the stored report.json artifact.""" alt = settings.artifact_dir / "reports" / run.id / "report.json" if alt.is_file(): @@ -518,7 +518,7 @@ def _read_metrics_from_json(run: Run) -> dict[str, Any]: return {} -def _extract_equity_curve(raw: dict[str, Any]) -> list[EquityCurvePoint]: +def _extract_equity_curve(raw: Dict[str, Any]) -> List[EquityCurvePoint]: """Extract equity curve from report.json. eqlib uses ``cumulative_returns`` as a list of @@ -526,7 +526,7 @@ def _extract_equity_curve(raw: dict[str, Any]) -> list[EquityCurvePoint]: We expose it as ``{"date": str, "value": float}`` (portfolio value). """ points = raw.get("cumulative_returns", []) - result: list[EquityCurvePoint] = [] + result: List[EquityCurvePoint] = [] for p in points: date = p.get("date") value = p.get("total_value") @@ -554,9 +554,9 @@ def _extract_equity_curve(raw: dict[str, Any]) -> list[EquityCurvePoint]: ) -def _extract_metrics(raw: dict[str, Any]) -> dict[str, float | None]: +def _extract_metrics(raw: Dict[str, Any]) -> Dict[str, Optional[float]]: risk = raw.get("risk_metrics", raw) - metrics: dict[str, float | None] = {} + metrics: Dict[str, Optional[float]] = {} for key in _METRIC_KEYS: val = risk.get(key) if val is not None: @@ -585,11 +585,11 @@ async def get_run_metrics(run_id: str, session: AsyncSession = Depends(get_sessi @router.post("/runs/compare", response_model=CompareResponse) async def compare_runs( - body: dict[str, Any], + body: Dict[str, Any], session: AsyncSession = Depends(get_session), ): """Compare metrics + equity curves across multiple runs (B22).""" - run_ids: list[str] = body.get("run_ids", []) + run_ids: List[str] = body.get("run_ids", []) if not run_ids: raise HTTPException( status_code=400, @@ -601,8 +601,8 @@ async def compare_runs( rows = (await session.execute(stmt)).scalars().all() runs_by_id = {r.id: r for r in rows} - runs_items: list[CompareRunItem] = [] - all_metric_keys: set[str] = set() + runs_items: List[CompareRunItem] = [] + all_metric_keys: Set[str] = set() for rid in run_ids: run = runs_by_id.get(rid) @@ -624,3 +624,17 @@ async def compare_runs( ) common = sorted(all_metric_keys) return CompareResponse(runs=runs_items, common_keys=common) + + +@router.get("/runs/{run_id}/report/data") +async def get_run_report_data(run_id: str, session: AsyncSession = Depends(get_session)): + """Return the full report.json contents for native frontend rendering.""" + run = await session.get(Run, run_id) + if run is None: + raise HTTPException(status_code=404, detail=api_error("NOT_FOUND", "Run not found")) + if run.status != "succeeded": + raise HTTPException(status_code=400, detail=api_error("RUN_NOT_SUCCEEDED", "Run has not completed")) + raw = _read_metrics_from_json(run) + if not raw: + raise HTTPException(status_code=404, detail=api_error("REPORT_NOT_FOUND", "Report data not found")) + return JSONResponse(content=raw) diff --git a/web_strategy_studio/backend/studio_api/routers/strategies.py b/web_strategy_studio/backend/studio_api/routers/strategies.py index b73984f..80eb771 100644 --- a/web_strategy_studio/backend/studio_api/routers/strategies.py +++ b/web_strategy_studio/backend/studio_api/routers/strategies.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from typing import List from studio_api.config import settings from studio_api.db import get_session @@ -259,7 +260,7 @@ async def create_snapshot( ) -@router.get("/strategies/{strategy_id}/versions", response_model=list[StrategyVersionItem]) +@router.get("/strategies/{strategy_id}/versions", response_model=List[StrategyVersionItem]) async def list_strategy_versions( strategy_id: str, session: AsyncSession = Depends(get_session), diff --git a/web_strategy_studio/backend/studio_api/run_queue.py b/web_strategy_studio/backend/studio_api/run_queue.py index 9303884..0a30ef2 100644 --- a/web_strategy_studio/backend/studio_api/run_queue.py +++ b/web_strategy_studio/backend/studio_api/run_queue.py @@ -15,6 +15,7 @@ import asyncio from collections.abc import Awaitable, Callable from datetime import datetime, timezone +from typing import Dict, List, Optional, Set, Tuple import structlog @@ -23,22 +24,22 @@ log = structlog.get_logger(__name__) # Queue of (run_id, coroutine_factory) pairs. -# A coroutine_factory is a zero-argument async callable that runs the task. +# A coroutine_factory is a zero-arg async callable that runs the task. _TaskCoro = Callable[[], Awaitable[None]] # Module-level state — re-initialised by start_worker() on each lifespan start # so tests using different event loops work correctly. -_queue: asyncio.Queue | None = None -_worker_task: asyncio.Task | None = None +_queue: Optional[asyncio.Queue] = None +_worker_task: Optional[asyncio.Task] = None # Ordered list of run_ids currently sitting in the queue (not yet started). -_pending: list[str] = [] +_pending: List[str] = [] # Set of run_ids currently executing. -_active: set[str] = set() +_active: Set[str] = set() # Semaphore limiting simultaneous executions (created in _worker()). -_semaphore: asyncio.Semaphore | None = None +_semaphore: Optional[asyncio.Semaphore] = None -def queue_position(run_id: str) -> int | None: +def queue_position(run_id: str) -> Optional[int]: """Return 1-based queue position of run_id, or None if not in queue.""" try: return _pending.index(run_id) + 1 @@ -46,11 +47,11 @@ def queue_position(run_id: str) -> int | None: return None -def active_run_ids() -> list[str]: +def active_run_ids() -> List[str]: return list(_active) -def pending_run_ids() -> list[str]: +def pending_run_ids() -> List[str]: return list(_pending) @@ -124,7 +125,7 @@ async def mark_orphan_runs_failed() -> None: async with SessionLocal() as session: result = await session.execute(select(Run).where(Run.status.in_(["running", "queued"]))) - orphans: list[Run] = result.scalars().all() + orphans: List[Run] = result.scalars().all() if not orphans: return now = datetime.now(timezone.utc) @@ -156,9 +157,9 @@ class _RateLimiter: def __init__(self, limit: int, window_sec: int) -> None: self._limit = limit self._window = window_sec - self._hits: dict[str, list[float]] = {} + self._hits: Dict[str, List[float]] = {} - def is_allowed(self, key: str) -> tuple[bool, int]: + def is_allowed(self, key: str) -> Tuple[bool, int]: """Return (allowed, remaining_hits). Updates the sliding window.""" import time diff --git a/web_strategy_studio/backend/studio_api/schemas.py b/web_strategy_studio/backend/studio_api/schemas.py index 7fedb8f..38b01ea 100644 --- a/web_strategy_studio/backend/studio_api/schemas.py +++ b/web_strategy_studio/backend/studio_api/schemas.py @@ -1,29 +1,27 @@ """Pydantic schemas aligned with design spec §4.""" -from __future__ import annotations - from datetime import datetime -from typing import Any, Literal +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field class ErrorDetail(BaseModel): - field: str | None = None + field: Optional[str] = None issue: str class ErrorEnvelope(BaseModel): code: str message: str - details: list[ErrorDetail] | None = None + details: Optional[List[ErrorDetail]] = None class ErrorResponse(BaseModel): error: ErrorEnvelope -def api_error(code: str, message: str, details: Any = None) -> dict: +def api_error(code: str, message: str, details: Any = None) -> Dict[str, Any]: """Build a standardised error envelope dict for HTTPException detail.""" return {"error": {"code": code, "message": message, "details": details}} @@ -34,16 +32,16 @@ class DefaultBacktestParams(BaseModel): starting_cash: float = 100_000 benchmark: str = "000300.XSHG" use_local: bool = False - report_dir: str | None = None - securities: list[str] | None = None + report_dir: Optional[str] = None + securities: Optional[List[str]] = None max_memory_mb: int = 1024 class CreateStrategyBody(BaseModel): name: str - description: str | None = None + description: Optional[str] = None source_code: str - default_params: DefaultBacktestParams | None = None + default_params: Optional[DefaultBacktestParams] = None class StrategyCreated(BaseModel): @@ -56,35 +54,35 @@ class StrategyCreated(BaseModel): class StrategyDetail(BaseModel): id: str name: str - description: str | None + description: Optional[str] source_code: str version: int - updated_at: datetime | None - default_params: dict | None = None + updated_at: Optional[datetime] + default_params: Optional[Dict[str, Any]] = None class StrategyVersionItem(BaseModel): """One entry in GET /strategies/{id}/versions.""" version: int - label: str | None = None - content_hash: str | None = None + label: Optional[str] = None + content_hash: Optional[str] = None created_at: datetime class PatchStrategyBody(BaseModel): - source_code: str | None = None - name: str | None = None - description: str | None = None + source_code: Optional[str] = None + name: Optional[str] = None + description: Optional[str] = None class SnapshotBody(BaseModel): - label: str | None = None + label: Optional[str] = None class StrategyTemplateResponse(BaseModel): source_code: str - hints: list[str] + hints: List[str] class LintBody(BaseModel): @@ -94,9 +92,9 @@ class LintBody(BaseModel): class LintResponse(BaseModel): ok: bool - syntax_errors: list[dict[str, Any]] - lint_issues: list[dict[str, Any]] - security_notes: list[dict[str, Any]] + syntax_errors: List[Dict[str, Any]] + lint_issues: List[Dict[str, Any]] + security_notes: List[Dict[str, Any]] class RunParams(BaseModel): @@ -105,14 +103,14 @@ class RunParams(BaseModel): starting_cash: float = 100_000 benchmark: str = "000300.XSHG" use_local: bool = False - report_dir: str | None = None - securities: list[str] | None = None + report_dir: Optional[str] = None + securities: Optional[List[str]] = None max_memory_mb: int = 1024 class CreateRunBody(BaseModel): strategy_id: str - source_code: str | None = None + source_code: Optional[str] = None params: RunParams @@ -121,24 +119,24 @@ class CreateRunResponse(BaseModel): status: Literal["queued", "running", "succeeded", "failed", "cancelled"] poll_url: str stream_url: str # B16: renamed from ws_url - queue_position: int | None = None # B18: 1-based position when queued + queue_position: Optional[int] = None # B18: 1-based position when queued class RunArtifacts(BaseModel): - html_report_url: str | None = None - json_report_url: str | None = None + html_report_url: Optional[str] = None + json_report_url: Optional[str] = None class RunStatusResponse(BaseModel): run_id: str status: str progress: float - stage: str | None = None - started_at: datetime | None = None - finished_at: datetime | None = None + stage: Optional[str] = None + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None artifacts: RunArtifacts - error: str | None = None - queue_position: int | None = None # B18 + error: Optional[str] = None + queue_position: Optional[int] = None # B18 class CompletionBody(BaseModel): @@ -156,7 +154,7 @@ class CompletionItem(BaseModel): class CompletionResponse(BaseModel): - suggestions: list[CompletionItem] + suggestions: List[CompletionItem] class FormatBody(BaseModel): @@ -173,33 +171,33 @@ class RunListItem(BaseModel): run_id: str strategy_id: str - strategy_name: str | None = None + strategy_name: Optional[str] = None status: str progress: float - stage: str | None = None - started_at: datetime | None = None - finished_at: datetime | None = None - error_message: str | None = None - queue_position: int | None = None # B18 + stage: Optional[str] = None + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + error_message: Optional[str] = None + queue_position: Optional[int] = None # B18 class RunListResponse(BaseModel): - runs: list[RunListItem] + runs: List[RunListItem] total: int class MetricValue(BaseModel): name: str - value: float | None = None + value: Optional[float] = None label: str = "" class RunMetricsResponse(BaseModel): run_id: str status: str - metrics: dict[str, float | None] + metrics: Dict[str, Optional[float]] # Raw dict from report.json so the frontend can render any key - raw: dict[str, Any] = Field(default_factory=dict) + raw: Dict[str, Any] = Field(default_factory=dict) class EquityCurvePoint(BaseModel): @@ -209,17 +207,17 @@ class EquityCurvePoint(BaseModel): class CompareRunItem(BaseModel): run_id: str - strategy_name: str | None = None + strategy_name: Optional[str] = None status: str - started_at: datetime | None = None - metrics: dict[str, float | None] - equity_curve: list[EquityCurvePoint] = Field(default_factory=list) # B22 + started_at: Optional[datetime] = None + metrics: Dict[str, Optional[float]] + equity_curve: List[EquityCurvePoint] = Field(default_factory=list) # B22 class CompareResponse(BaseModel): - runs: list[CompareRunItem] + runs: List[CompareRunItem] # Column names shared across all runs - common_keys: list[str] + common_keys: List[str] # --------------------------------------------------------------------------- @@ -236,4 +234,4 @@ class QueueStatusResponse(BaseModel): queue_length: int active_count: int max_concurrent: int - queued_runs: list[QueueRunItem] + queued_runs: List[QueueRunItem] diff --git a/web_strategy_studio/backend/studio_api/security_scanner.py b/web_strategy_studio/backend/studio_api/security_scanner.py index e77c913..997320f 100644 --- a/web_strategy_studio/backend/studio_api/security_scanner.py +++ b/web_strategy_studio/backend/studio_api/security_scanner.py @@ -26,6 +26,7 @@ import ast from dataclasses import dataclass +from typing import List @dataclass @@ -91,8 +92,8 @@ class SecurityNote: class SecurityScanner: - def scan(self, source: str) -> list[SecurityNote]: - notes: list[SecurityNote] = [] + def scan(self, source: str) -> List[SecurityNote]: + notes: List[SecurityNote] = [] try: tree = ast.parse(source) except SyntaxError: @@ -149,8 +150,8 @@ def scan(self, source: str) -> list[SecurityNote]: return notes -def require_initialize_function(source: str) -> list[SecurityNote]: - notes: list[SecurityNote] = [] +def require_initialize_function(source: str) -> List[SecurityNote]: + notes: List[SecurityNote] = [] try: tree = ast.parse(source) except SyntaxError: diff --git a/web_strategy_studio/backend/studio_api/stream_hub.py b/web_strategy_studio/backend/studio_api/stream_hub.py index 4afcb1c..9f61a91 100644 --- a/web_strategy_studio/backend/studio_api/stream_hub.py +++ b/web_strategy_studio/backend/studio_api/stream_hub.py @@ -6,7 +6,7 @@ import json import time from collections import defaultdict, deque -from typing import Any +from typing import Any, Dict, List, Optional import structlog @@ -30,12 +30,12 @@ class _RunBuffer: __slots__ = ("events", "terminal", "_expires_at", "_seq") def __init__(self) -> None: - self.events: deque[dict[str, Any]] = deque(maxlen=_RING_SIZE) - self.terminal: dict[str, Any] | None = None - self._expires_at: float | None = None + self.events: deque[Dict[str, Any]] = deque(maxlen=_RING_SIZE) + self.terminal: Optional[Dict[str, Any]] = None + self._expires_at: Optional[float] = None self._seq: int = 0 - def push(self, event: str, data: dict[str, Any], ttl_sec: int) -> dict[str, Any]: + def push(self, event: str, data: Dict[str, Any], ttl_sec: int) -> Dict[str, Any]: self._seq += 1 entry = {"id": self._seq, "event": event, "data": data} self.events.append(entry) @@ -49,7 +49,7 @@ def is_expired(self, now: float) -> bool: return False return now > self._expires_at - def missed_since(self, last_event_id: int) -> list[dict[str, Any]]: + def missed_since(self, last_event_id: int) -> List[Dict[str, Any]]: """Return all buffered events with id > last_event_id.""" return [e for e in self.events if e["id"] > last_event_id] @@ -58,11 +58,11 @@ class StreamHub: """Fan-out hub with per-run ring buffers and Last-Event-ID replay.""" def __init__(self, max_queued: int = 2000, buffer_ttl_sec: int = 1800) -> None: - self._queues: dict[str, list[asyncio.Queue]] = defaultdict(list) - self._buffers: dict[str, _RunBuffer] = {} + self._queues: Dict[str, List[asyncio.Queue]] = defaultdict(list) + self._buffers: Dict[str, _RunBuffer] = {} self._max = max_queued self._ttl = buffer_ttl_sec - self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # ------------------------------------------------------------------ # Public interface @@ -83,7 +83,7 @@ def unsubscribe(self, run_id: str, q: asyncio.Queue) -> None: del self._queues[run_id] self._locks.pop(run_id, None) - def get_buffer(self, run_id: str) -> _RunBuffer | None: + def get_buffer(self, run_id: str) -> Optional[_RunBuffer]: """Return the ring buffer for `run_id` if it exists and hasn't expired.""" buf = self._buffers.get(run_id) if buf is None: @@ -93,13 +93,13 @@ def get_buffer(self, run_id: str) -> _RunBuffer | None: return None return buf - async def publish(self, run_id: str, event: str, data: dict[str, Any]) -> None: + async def publish(self, run_id: str, event: str, data: Dict[str, Any]) -> None: # Store in ring buffer first (so late subscribers can replay). buf = self._buffers.setdefault(run_id, _RunBuffer()) entry = buf.push(event, data, self._ttl) line = {"id": entry["id"], "event": event, "data": data} - dead: list[asyncio.Queue] = [] + dead: List[asyncio.Queue] = [] for q in list(self._queues.get(run_id, [])): try: q.put_nowait(line) @@ -118,7 +118,7 @@ async def publish(self, run_id: str, event: str, data: dict[str, Any]) -> None: self._queues.pop(run_id, None) self._locks.pop(run_id, None) - def format_sse(self, event_id: int, event: str, data: dict[str, Any]) -> str: + def format_sse(self, event_id: int, event: str, data: Dict[str, Any]) -> str: return ( f"id: {event_id}\n" f"event: {event}\n" diff --git a/web_strategy_studio/frontend/package-lock.json b/web_strategy_studio/frontend/package-lock.json index 961c421..8c1ddd3 100644 --- a/web_strategy_studio/frontend/package-lock.json +++ b/web_strategy_studio/frontend/package-lock.json @@ -12,6 +12,7 @@ "@tanstack/react-query": "^5.28.0", "clsx": "^2.1.1", "cmdk": "^1.1.1", + "lightweight-charts": "^4.1.1", "lucide-react": "^1.16.0", "monaco-editor": "^0.47.0", "react": "^18.3.1", @@ -3701,6 +3702,12 @@ "node": ">=12.0.0" } }, + "node_modules/fancy-canvas": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fancy-canvas/-/fancy-canvas-2.1.0.tgz", + "integrity": "sha512-nifxXJ95JNLFR2NgRV4/MxVP45G9909wJTEKz5fg/TZS20JJZA6hfgRVh/bC9bwl2zBtBNcYPjiBE4njQHVBwQ==", + "license": "MIT" + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -4339,6 +4346,15 @@ "node": ">= 0.8.0" } }, + "node_modules/lightweight-charts": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/lightweight-charts/-/lightweight-charts-4.1.1.tgz", + "integrity": "sha512-HYjm66NAIOhoLDNaaQsiwOVWiFHL1yrygZeKd4PgdZESnWyp5dPoTe3pH3t2h4ix+Ix5TwLZaNbWroZqQuj6OA==", + "license": "Apache-2.0", + "dependencies": { + "fancy-canvas": "2.1.0" + } + }, "node_modules/lilconfig": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz", diff --git a/web_strategy_studio/frontend/package.json b/web_strategy_studio/frontend/package.json index ec72337..45b4254 100644 --- a/web_strategy_studio/frontend/package.json +++ b/web_strategy_studio/frontend/package.json @@ -18,6 +18,7 @@ "@tanstack/react-query": "^5.28.0", "clsx": "^2.1.1", "cmdk": "^1.1.1", + "lightweight-charts": "^4.1.1", "lucide-react": "^1.16.0", "monaco-editor": "^0.47.0", "react": "^18.3.1", diff --git a/web_strategy_studio/frontend/src/components/MetricsComparison.tsx b/web_strategy_studio/frontend/src/components/MetricsComparison.tsx index 3807786..45bf66a 100644 --- a/web_strategy_studio/frontend/src/components/MetricsComparison.tsx +++ b/web_strategy_studio/frontend/src/components/MetricsComparison.tsx @@ -1,8 +1,11 @@ import { useQuery } from "@tanstack/react-query"; +import { useEffect, useRef } from "react"; import { useMemo } from "react"; import { compareRunMetrics, EquityCurvePoint } from "../api/client"; import { useEditorStore } from "../store/editorStore"; +import { createChart } from "lightweight-charts"; +import type { IChartApi } from "lightweight-charts"; const METRIC_LABELS: Record = { total_return: "总收益率", @@ -41,30 +44,82 @@ function isGood(metric: string, val: number | null | undefined): boolean { return false; } -/** Minimal sparkline for the equity curve (SVG). */ +/** Lightweight Charts mini area chart for equity curve (60x24). */ function EquitySpark({ points }: { points: EquityCurvePoint[] }) { + const containerRef = useRef(null); + const chartRef = useRef(null); + + useEffect(() => { + if (!containerRef.current) return; + // Cleanup previous chart + if (chartRef.current) { + chartRef.current.remove(); + chartRef.current = null; + } + if (!points.length) return; + + const chart = createChart(containerRef.current, { + width: 60, + height: 24, + layout: { + background: { type: "solid", color: "transparent" }, + textColor: "transparent", + fontSize: 0, + }, + grid: { vertLines: { visible: false }, horzLines: { visible: false } }, + timeScale: { visible: false }, + rightPriceScale: { visible: false }, + crosshair: { mode: 0 }, + } as any); + chartRef.current = chart; + + const area = chart.addAreaSeries({ + lineColor: "transparent", + topColor: "rgba(34,197,94,0.3)", + bottomColor: "rgba(34,197,94,0)", + lineWidth: 1, + priceLineVisible: false, + lastValueVisible: false, + crosshairMarkerVisible: false, + }); + + const values = points.map((p) => p.value); + const lastVal = values[values.length - 1]; + const firstVal = values[0]; + const pct = ((lastVal - firstVal) / Math.abs(firstVal || 1)) * 100; + const positive = pct >= 0; + + // Reconfigure colors based on direction + area.applyOptions({ + lineColor: positive ? "rgba(34,197,94,0.8)" : "rgba(239,68,68,0.8)", + topColor: positive ? "rgba(34,197,94,0.3)" : "rgba(239,68,68,0.3)", + bottomColor: positive ? "rgba(34,197,94,0)" : "rgba(239,68,68,0)", + }); + + area.setData( + points.map((p) => ({ + time: p.date as unknown as string, + value: p.value, + })) as any + ); + chart.timeScale().fitContent(); + + return () => { + chart.remove(); + chartRef.current = null; + }; + }, [points]); + if (!points.length) return ; + const values = points.map((p) => p.value); - const min = Math.min(...values); - const max = Math.max(...values); - const range = max - min || 1; - const w = 80; - const h = 24; - const path = points - .map((p, i) => { - const x = (i / Math.max(points.length - 1, 1)) * w; - const y = h - ((p.value - min) / range) * h; - return `${i === 0 ? "M" : "L"}${x.toFixed(1)},${y.toFixed(1)}`; - }) - .join(" "); const lastVal = values[values.length - 1]; const pct = ((lastVal - values[0]) / Math.abs(values[0] || 1)) * 100; const color = pct >= 0 ? "var(--success)" : "var(--error)"; + return (
- - - +
{pct >= 0 ? "+" : ""}{pct.toFixed(1)}% diff --git a/web_strategy_studio/frontend/src/components/ReportLinkModal.tsx b/web_strategy_studio/frontend/src/components/ReportLinkModal.tsx index ffbd2da..b46b3f6 100644 --- a/web_strategy_studio/frontend/src/components/ReportLinkModal.tsx +++ b/web_strategy_studio/frontend/src/components/ReportLinkModal.tsx @@ -2,6 +2,7 @@ import type { CSSProperties } from "react"; import { useMemo } from "react"; import { resolveArtifactUrl } from "../api/client"; +import ReportViewer from "./ReportViewer"; type Props = { open: boolean; @@ -11,10 +12,10 @@ type Props = { }; export function ReportLinkModal({ open, htmlUrl, runId, onClose }: Props) { - const iframeSrc = useMemo(() => { - const fromApi = resolveArtifactUrl(htmlUrl ?? undefined); - if (fromApi) return fromApi; - if (runId) return resolveArtifactUrl(`/static/reports/${runId}/report.html`); + const jsonUrl = useMemo(() => { + if (runId) return `/static/reports/${runId}/report.json`; + // Derive from htmlUrl if available + if (htmlUrl) return htmlUrl.replace(/\.html$/, ".json"); return undefined; }, [htmlUrl, runId]); @@ -51,11 +52,11 @@ export function ReportLinkModal({ open, htmlUrl, runId, onClose }: Props) {

回测报告

- {iframeSrc ? ( + {htmlUrl ? ( @@ -65,20 +66,10 @@ export function ReportLinkModal({ open, htmlUrl, runId, onClose }: Props) {
- {iframeSrc ? ( -