Skip to content

Commit 38a596f

Browse files
committed
Draft implementation
1 parent a53cd8f commit 38a596f

2 files changed

Lines changed: 65 additions & 3 deletions

File tree

ipykernel/heartbeat.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,29 @@
2727
class Heartbeat(Thread):
2828
"""A simple ping-pong style heartbeat that runs in a thread."""
2929

30-
def __init__(self, context, addr=None):
31-
"""Initialize the heartbeat thread."""
30+
def __init__(self, context, addr=None, curve_publickey=None, curve_secretkey=None):
31+
"""Initialize the heartbeat thread.
32+
33+
Parameters
34+
----------
35+
context : zmq.Context
36+
addr : tuple, optional
37+
(transport, ip, port)
38+
curve_publickey : bytes, optional
39+
Z85-encoded CurveZMQ public key. When provided together with
40+
*curve_secretkey*, the heartbeat socket will operate as a
41+
CurveZMQ server so that only authenticated clients can connect.
42+
curve_secretkey : bytes, optional
43+
Z85-encoded CurveZMQ secret key (paired with *curve_publickey*).
44+
"""
3245
if addr is None:
3346
addr = ("tcp", localhost(), 0)
3447
Thread.__init__(self, name="Heartbeat")
3548
self.context = context
3649
self.transport, self.ip, self.port = addr
3750
self.original_port = self.port
51+
self.curve_publickey = curve_publickey
52+
self.curve_secretkey = curve_secretkey
3853
if self.original_port == 0:
3954
self.pick_port()
4055
self.addr = (self.ip, self.port)
@@ -94,6 +109,10 @@ def run(self):
94109
self.name = "Heartbeat"
95110
self.socket = self.context.socket(zmq.ROUTER)
96111
self.socket.linger = 1000
112+
if self.curve_secretkey is not None:
113+
self.socket.curve_secretkey = self.curve_secretkey
114+
self.socket.curve_publickey = self.curve_publickey
115+
self.socket.curve_server = True
97116
try:
98117
self._bind_socket()
99118
except Exception:

ipykernel/kernelapp.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,19 @@ def abs_connection_file(self):
188188
""",
189189
).tag(config=True)
190190

191+
enable_curve = Bool(
192+
bool(int(os.environ.get("JUPYTER_ENABLE_CURVE", "0"))),
193+
help="Enable CurveZMQ transport encryption and authentication. "
194+
"When True, a keypair is generated at startup and stored in the "
195+
"connection file so that clients can authenticate and encrypt "
196+
"all ZMQ channels.",
197+
).tag(config=True)
198+
199+
# Internal CurveZMQ keypair (Z85-encoded bytes); populated in init_sockets
200+
# when enable_curve is True.
201+
_curve_publickey: bytes | None = None
202+
_curve_secretkey: bytes | None = None
203+
191204
# polling
192205
parent_handle = Integer(
193206
int(os.environ.get("JPY_PARENT_PID") or 0),
@@ -211,6 +224,17 @@ def excepthook(self, etype, evalue, tb):
211224
# write uncaught traceback to 'real' stderr, not zmq-forwarder
212225
traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)
213226

227+
def _apply_curve_server_options(self, socket: zmq.sugar.socket.Socket) -> None:
228+
"""Set CurveZMQ server-side options on *socket* before it is bound.
229+
230+
This is a no-op when enable_curve is False or keys have not been
231+
generated yet, so it is safe to call unconditionally.
232+
"""
233+
if self.enable_curve and self._curve_secretkey is not None:
234+
socket.curve_secretkey = self._curve_secretkey
235+
socket.curve_publickey = self._curve_publickey
236+
socket.curve_server = True
237+
214238
def init_poller(self):
215239
"""Initialize the poller."""
216240
if sys.platform == "win32":
@@ -274,6 +298,12 @@ def write_connection_file(self, **kwargs: Any) -> None:
274298
iopub_port=self.iopub_port,
275299
control_port=self.control_port,
276300
)
301+
if self.enable_curve and self._curve_publickey is not None:
302+
# Store Z85-encoded keys as ASCII strings alongside the HMAC key.
303+
# Clients that understand CurveZMQ will use these to configure
304+
# their sockets; legacy clients ignore the unknown fields.
305+
connection_info["curve_publickey"] = self._curve_publickey.decode("ascii")
306+
connection_info["curve_secretkey"] = self._curve_secretkey.decode("ascii") # type: ignore[union-attr]
277307
if Path(cf).exists():
278308
# If the file exists, merge our info into it. For example, if the
279309
# original file had port number 0, we update with the actual port
@@ -328,13 +358,19 @@ def init_sockets(self):
328358
self.context = context = zmq.Context()
329359
atexit.register(self.close)
330360

361+
if self.enable_curve:
362+
self._curve_publickey, self._curve_secretkey = zmq.curve_keypair()
363+
self.log.debug("CurveZMQ enabled; generated server keypair")
364+
331365
self.shell_socket = context.socket(zmq.ROUTER)
332366
self.shell_socket.linger = 1000
367+
self._apply_curve_server_options(self.shell_socket)
333368
self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
334369
self.log.debug("shell ROUTER Channel on port: %i", self.shell_port)
335370

336371
self.stdin_socket = context.socket(zmq.ROUTER)
337372
self.stdin_socket.linger = 1000
373+
self._apply_curve_server_options(self.stdin_socket)
338374
self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
339375
self.log.debug("stdin ROUTER Channel on port: %i", self.stdin_port)
340376

@@ -351,6 +387,7 @@ def init_control(self, context):
351387
"""Initialize the control channel."""
352388
self.control_socket = context.socket(zmq.ROUTER)
353389
self.control_socket.linger = 1000
390+
self._apply_curve_server_options(self.control_socket)
354391
self.control_port = self._bind_socket(self.control_socket, self.control_port)
355392
self.log.debug("control ROUTER Channel on port: %i", self.control_port)
356393

@@ -379,6 +416,7 @@ def init_iopub(self, context):
379416
"""Initialize the iopub channel."""
380417
self.iopub_socket = context.socket(zmq.XPUB)
381418
self.iopub_socket.linger = 1000
419+
self._apply_curve_server_options(self.iopub_socket)
382420
self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
383421
self.log.debug("iopub PUB Channel on port: %i", self.iopub_port)
384422
self.configure_tornado_logger()
@@ -392,7 +430,12 @@ def init_heartbeat(self):
392430
# heartbeat doesn't share context, because it mustn't be blocked
393431
# by the GIL, which is accessed by libzmq when freeing zero-copy messages
394432
hb_ctx = zmq.Context()
395-
self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
433+
self.heartbeat = Heartbeat(
434+
hb_ctx,
435+
(self.transport, self.ip, self.hb_port),
436+
curve_publickey=self._curve_publickey if self.enable_curve else None,
437+
curve_secretkey=self._curve_secretkey if self.enable_curve else None,
438+
)
396439
self.hb_port = self.heartbeat.port
397440
self.log.debug("Heartbeat REP Channel on port: %i", self.hb_port)
398441
self.heartbeat.start()

0 commit comments

Comments
 (0)