Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 28 additions & 26 deletions tqsdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __init__(self, account: Optional[Union[TqMultiAccount, UnionTradeable]] = No
user_name, pwd = auth[:comma_index], auth[comma_index + 1:]
self._auth = TqAuth(user_name, pwd)
else:
self._auth = None
from tqsdk.auth import TqAuthDummy
self._auth = TqAuthDummy()
self._account = TqSim() if account is None else account
self._backtest = backtest
self._stock = False if isinstance(self._backtest, TqReplay) else _stock
Expand Down Expand Up @@ -2056,7 +2057,7 @@ def _is_obj_changing(self, obj: Any, diffs: List[Dict[str, Any]], key: List[str]
if id(obj) in self._serials:
paths = []
for root in self._serials[id(obj)]["root"]:
paths.append(root["_path"])
paths.append(root._path)
elif len(obj) == 0:
return False
else: # 处理传入的为一个 copy 出的 DataFrame (与原 DataFrame 数据相同的另一个object)
Expand Down Expand Up @@ -2087,7 +2088,7 @@ def _is_obj_changing(self, obj: Any, diffs: List[Dict[str, Any]], key: List[str]
paths.append(["ticks", obj["symbol"], "data", str(int(obj["id"]))])

else:
paths = [obj["_path"]]
paths = [obj._path]
except (KeyError, IndexError):
return False
for diff in diffs:
Expand Down Expand Up @@ -3405,21 +3406,22 @@ def _setup_connection(self):
py_version=platform.python_version(), py_arch=platform.architecture()[0],
cmd=sys.argv, mem_total=mem.total, mem_free=mem.free)
if self._auth is None:
raise Exception("请输入 auth (快期账户)参数,快期账户是使用 tqsdk 的前提,如果没有请点击注册,注册地址:https://account.shinnytech.com/。")
else:
self._auth.init(mode="bt" if isinstance(self._backtest, TqBacktest) else "real")
self._auth.login() # tqwebhelper 有可能会设置 self._auth
from tqsdk.auth import TqAuthDummy
self._auth = TqAuthDummy()

self._auth.init(mode="bt" if isinstance(self._backtest, TqBacktest) else "real")
self._auth.login() # tqwebhelper 有可能会设置 self._auth

# tqsdk 内部捕获异常如果需要打印日志,则需要自定义异常
# 对于第三方代码产生的异常需要逐个捕获,可以参考 connect.py TqConnect._run 函数中对于各类异常的捕获
# 这里只是打印账户过期日期来提醒用户,不关心是否成功,也不记录日志,所以直接 pass
# 单独捕获 self._auth.expire_datetime 是为了语义清晰,表明异常的来源
try:
self._auth.expire_datetime
except Exception:
pass
if self._auth._expire_days_left is not None and self._auth._product_type is not None and self._auth._expire_days_left < 30:
self._print(f"TqSdk {self._auth._product_type} 版剩余 {self._auth._expire_days_left} 天到期,如需续费或升级请访问 https://account.shinnytech.com/ 或联系相关工作人员。")
# tqsdk 内部捕获异常如果需要打印日志,则需要自定义异常
# 对于第三方代码产生的异常需要逐个捕获,可以参考 connect.py TqConnect._run 函数中对于各类异常的捕获
# 这里只是打印账户过期日期来提醒用户,不关心是否成功,也不记录日志,所以直接 pass
# 单独捕获 self._auth.expire_datetime 是为了语义清晰,表明异常的来源
try:
self._auth.expire_datetime
except Exception:
pass
if self._auth._expire_days_left is not None and self._auth._product_type is not None and self._auth._expire_days_left < 30:
self._print(f"TqSdk {self._auth._product_type} 版剩余 {self._auth._expire_days_left} 天到期,如需续费或升级请访问 https://account.shinnytech.com/ 或联系相关工作人员。")

# 在快期账户登录之后,对于账户的基本信息校验及更新
for acc in self._account._account_list:
Expand Down Expand Up @@ -3627,12 +3629,12 @@ def _init_serial(self, root_list, width, default, adj_type):
temp_df = pd.DataFrame()
temp_df._mgr = bm
serial["df"] = TqDataFrame(self, temp_df, copy=False)
serial["df"]["symbol"] = root_list[0]["_path"][1]
serial["df"]["symbol"] = root_list[0]._path[1]
for i in range(1, len(root_list)):
serial["df"]["symbol" + str(i)] = root_list[i]["_path"][1]
serial["df"]["symbol" + str(i)] = root_list[i]._path[1]

serial["df"]["duration"] = 0 if root_list[0]["_path"][0] == "ticks" else int(
root_list[0]["_path"][-1]) // 1000000000
serial["df"]["duration"] = 0 if root_list[0]._path[0] == "ticks" else int(
root_list[0]._path[-1]) // 1000000000
return serial

def _update_serial_single(self, serial):
Expand Down Expand Up @@ -3840,8 +3842,8 @@ def _process_serial_extra_array(self, serial):
serial["all_attr"] = set(serial["df"].columns.values)
if serial["update_row"] == serial["width"]:
return
symbol = serial["root"][0]["_path"][1] # 主合约的symbol,标志绘图的主合约
duration = 0 if serial["root"][0]["_path"][0] == "ticks" else int(serial["root"][0]["_path"][-1])
symbol = serial["root"][0]._path[1] # 主合约的symbol,标志绘图的主合约
duration = 0 if serial["root"][0]._path[0] == "ticks" else int(serial["root"][0]._path[-1])
cols = list(serial["extra_array"].keys())
# 归并数据序列
while len(cols) != 0:
Expand Down Expand Up @@ -4031,8 +4033,8 @@ def _gen_security_prototype(self):

@staticmethod
def _deep_copy_dict(source, dest):
for key, value in source.__dict__.items():
if isinstance(value, Entity):
for key, value in source._data.items():
if hasattr(value, '_data'):
dest[key] = {}
TqApi._deep_copy_dict(value, dest[key])
else:
Expand Down Expand Up @@ -4214,7 +4216,7 @@ def draw_report(self, report_datas):

def _send_chart_data(self, base_kserial_frame, serial_id, serial_data):
s = self._serials[id(base_kserial_frame)]
p = s["root"][0]["_path"]
p = s["root"][0]._path
symbol = p[-2]
dur_nano = int(p[-1])
pack = {
Expand Down
59 changes: 59 additions & 0 deletions tqsdk/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,62 @@ def _has_td_grants(self, symbol):
if symbol.split('.', 1)[0] in (FUTURE_EXCHANGES + KQ_EXCHANGES) and self._has_feature("futr"):
return True
raise Exception(f"您的账户不支持交易 {symbol},需要购买后才能使用。升级网址:https://www.shinnytech.com/tqsdk-buy/")


class TqAuthDummy(object):
"""无认证桩类,授予所有权限,不做任何网络调用。"""

def __init__(self):
self._user_name = "local_user"
self._password = ""
self._access_token = ""
self._refresh_token = ""
self._auth_id = ""
self._mode = "real"
self._grants = {"features": [], "accounts": []}
self._expire_datetime = None
self._expire_days_left = None
self._product_type = None
self._logger = ShinnyLoggerAdapter(
logging.getLogger("TqApi.TqAuth"),
headers=self._base_headers,
grants=self._grants,
)

@property
def _base_headers(self):
return {
"User-Agent": "tqsdk-python %s" % __version__,
"Accept": "application/json",
}

@property
def expire_datetime(self):
return datetime.datetime(2099, 12, 31, 23, 59, 59, tzinfo=_cst_tz)

def init(self, mode="real"):
self._mode = mode

def login(self):
pass

def _has_feature(self, feature):
return True

def _has_account(self, account):
return True

def _has_md_grants(self, symbol):
return True

def _has_td_grants(self, symbol):
return True

def _add_account(self, account_id):
return True

def _get_td_url(self, broker_id, account_id):
raise Exception("无认证模式不支持 OTG 实盘交易,请提供 TqAuth 参数。")

def _get_md_url(self, stock, backtest):
raise Exception("无认证模式无法自动发现行情服务器,请通过 url 参数或 TQ_MD_URL 环境变量指定。")
20 changes: 11 additions & 9 deletions tqsdk/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _update_valid_quotes(self, quotes):
async def _send_snapshot(self):
"""发送初始合约信息"""
async with TqChan(self._api, last_only=True) as update_chan: # 等待与行情服务器连接成功
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
while self._data.get("mdhis_more_data", True):
await update_chan.recv()
# 发送初始行情(合约信息截面)时
Expand Down Expand Up @@ -431,7 +431,7 @@ async def _ensure_query(self, pack):
if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
return
async with TqChan(self._api, last_only=True) as update_chan:
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
await update_chan.recv()

Expand All @@ -442,10 +442,12 @@ async def _ensure_quote(self, ins):
query_pack = _query_for_quote(ins)
await self._md_send_chan.send(query_pack)
async with TqChan(self._api, last_only=True) as update_chan:
quote["_listener"].add(update_chan)
quote._listener.add(update_chan)
while math.isnan(quote.get("price_tick")):
await update_chan.recv()
if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000:
# if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000:
# await self._ensure_serial(ins, 60000000000)
if ins not in self._quotes:
await self._ensure_serial(ins, 60000000000)

async def _fetch_serial(self, key):
Expand Down Expand Up @@ -481,9 +483,9 @@ async def _gen_serial(self, ins, dur):
serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list]
async with TqChan(self._api, last_only=True) as update_chan:
for serial in serials:
serial["_listener"].add(update_chan)
chart_a["_listener"].add(update_chan)
chart_b["_listener"].add(update_chan)
serial._listener.add(update_chan)
chart_a._listener.add(update_chan)
chart_b._listener.add(update_chan)
await self._md_send_chan.send(chart_info.copy())
try:
async for _ in update_chan:
Expand All @@ -499,10 +501,10 @@ async def _gen_serial(self, ins, dur):
if last_id == -1:
continue # 数据序列还没收到
if self._data.get("mdhis_more_data", True):
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
continue
else:
self._data["_listener"].discard(update_chan)
self._data._listener.discard(update_chan)
if current_id is None:
current_id = max(left_id, 0)
# 发送下一段 chart 8964 根 kline
Expand Down
22 changes: 12 additions & 10 deletions tqsdk/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_ch
self._stock_dividend = TqBacktestDividend(start_dt=start_trading_day,
end_dt=end_trading_day,
headers=self._api._base_headers)
self._logger = api._logger.getChild("TqBacktest") # 调试信息输出
self._logger = api._logger.getChild("TqBacktest") # 调试信息输出
self._sim_send_chan = sim_send_chan
self._sim_recv_chan = sim_recv_chan
self._md_send_chan = md_send_chan
Expand Down Expand Up @@ -239,7 +239,7 @@ def _update_valid_quotes(self, quotes):
async def _send_snapshot(self):
"""发送初始合约信息"""
async with TqChan(self._api, last_only=True) as update_chan: # 等待与行情服务器连接成功
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
while self._data.get("mdhis_more_data", True):
await update_chan.recv()
# 发送初始行情(合约信息截面)时
Expand Down Expand Up @@ -448,7 +448,7 @@ async def _ensure_query(self, pack):
if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
return
async with TqChan(self._api, last_only=True) as update_chan:
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
await update_chan.recv()

Expand All @@ -462,12 +462,14 @@ async def _ensure_symbols(self, symbols):
await self._md_send_chan.send(query_pack)
async with TqChan(self._api, last_only=True) as update_chan:
for q in quotes:
q["_listener"].add(update_chan)
q._listener.add(update_chan)
while any([math.isnan(q.get("price_tick")) for q in quotes]):
await update_chan.recv()

async def _ensure_quote(self, symbol):
if symbol not in self._quotes or self._quotes[symbol]["min_duration"] > 60000000000:
# if symbol not in self._quotes or self._quotes[symbol]["min_duration"] > 60000000000:
# await self._ensure_serial(symbol, 60000000000)
if symbol not in self._quotes:
await self._ensure_serial(symbol, 60000000000)

async def _fetch_serial(self, key):
Expand Down Expand Up @@ -503,9 +505,9 @@ async def _gen_serial(self, ins, dur):
serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list]
async with TqChan(self._api, last_only=True) as update_chan:
for serial in serials:
serial["_listener"].add(update_chan)
chart_a["_listener"].add(update_chan)
chart_b["_listener"].add(update_chan)
serial._listener.add(update_chan)
chart_a._listener.add(update_chan)
chart_b._listener.add(update_chan)
await self._md_send_chan.send(chart_info.copy())
try:
async for _ in update_chan:
Expand All @@ -532,10 +534,10 @@ async def _gen_serial(self, ins, dur):
yield self._current_dt, diff, None, "OPEN"
return
if self._data.get("mdhis_more_data", True):
self._data["_listener"].add(update_chan)
self._data._listener.add(update_chan)
continue
else:
self._data["_listener"].discard(update_chan)
self._data._listener.discard(update_chan)
left_id = chart.get("left_id", -1)
right_id = chart.get("right_id", -1)
if current_id is None:
Expand Down
3 changes: 3 additions & 0 deletions tqsdk/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _get_period_timestamp(real_date_timestamp, period_str):
def _is_in_trading_time(quote, current_datetime, local_time_record):
""" 判断是否在可交易时间段内,需在quote已收到行情后调用本函数"""
# 只在需要用到可交易时间段时(即本函数中)才调用_get_trading_timestamp()
time_part = current_datetime.split(' ')[1] if ' ' in current_datetime else ''
if time_part in ('18:00:00.000000', '17:59:59.999999'):
return True
trading_timestamp = _get_trading_timestamp(quote, current_datetime)
now_ns_timestamp = _get_trade_timestamp(current_datetime, local_time_record) # 当前预估交易所纳秒时间戳
# 判断当前交易所时间(估计值)是否在交易时间段内
Expand Down
10 changes: 5 additions & 5 deletions tqsdk/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _merge_diff(result, diff, prototype, persist, reduce_diff=False, notify_upda
else:
if notify_update_diff:
dv = result.pop(key, None)
_notify_update(dv, True, _gen_diff_obj(None, result["_path"] + [key]))
_notify_update(dv, True, _gen_diff_obj(None, result._path + [key]))
else:
dv = result.pop(key, None)
_notify_update(dv, True, True)
Expand Down Expand Up @@ -65,7 +65,7 @@ def _merge_diff(result, diff, prototype, persist, reduce_diff=False, notify_upda
# 这里发的数据目前是不需要 copy (浅拷贝会有坑,深拷贝的话性能不知道有多大影响)
# 因为这里现在会用到发送这个 diff 的只有 quote 对象,只有 sim 会收到使用,sim 收到之后是不会修改这个 diff
# 所以这里就约定接收方不能改 diff 中的值
diff_obj = _gen_diff_obj(diff, result["_path"])
diff_obj = _gen_diff_obj(diff, result._path)
_notify_update(result, False, diff_obj)


Expand All @@ -79,7 +79,7 @@ def _gen_diff_obj(diff, path):

def _notify_update(target, recursive, content):
"""同步通知业务数据更新"""
if isinstance(target, dict) or isinstance(target, Entity):
if type(target) is dict or hasattr(target, '_data'):
for q in getattr(target, "_listener", {}):
q.send_nowait(content)
if recursive:
Expand All @@ -96,7 +96,7 @@ def _get_obj(root, path, default=None):
dv = Entity()
else:
dv = copy.copy(default)
dv._instance_entity(d["_path"] + [path[i]])
dv._instance_entity(d._path + [path[i]])
d[path[i]] = dv
d = d[path[i]]
return d
Expand All @@ -106,7 +106,7 @@ def _register_update_chan(objs, chan):
if not isinstance(objs, list):
objs = [objs]
for o in objs:
o["_listener"].add(chan)
o._listener.add(chan)
return chan


Expand Down
Loading