@@ -188,6 +188,19 @@ def abs_connection_file(self):
188188 """ ,
189189 ).tag (config = True )
190190
191+ enable_curve = Bool (
192+ bool (int (os .environ .get ("JUPYTER_ENABLE_CURVE" , "0" ))),
193+ help = "Enable CurveZMQ transport encryption and authentication. "
194+ "When True, a keypair is generated at startup and stored in the "
195+ "connection file so that clients can authenticate and encrypt "
196+ "all ZMQ channels." ,
197+ ).tag (config = True )
198+
199+ # Internal CurveZMQ keypair (Z85-encoded bytes); populated in init_sockets
200+ # when enable_curve is True.
201+ _curve_publickey : bytes | None = None
202+ _curve_secretkey : bytes | None = None
203+
191204 # polling
192205 parent_handle = Integer (
193206 int (os .environ .get ("JPY_PARENT_PID" ) or 0 ),
@@ -211,6 +224,17 @@ def excepthook(self, etype, evalue, tb):
211224 # write uncaught traceback to 'real' stderr, not zmq-forwarder
212225 traceback .print_exception (etype , evalue , tb , file = sys .__stderr__ )
213226
227+ def _apply_curve_server_options (self , socket : zmq .sugar .socket .Socket ) -> None :
228+ """Set CurveZMQ server-side options on *socket* before it is bound.
229+
230+ This is a no-op when enable_curve is False or keys have not been
231+ generated yet, so it is safe to call unconditionally.
232+ """
233+ if self .enable_curve and self ._curve_secretkey is not None :
234+ socket .curve_secretkey = self ._curve_secretkey
235+ socket .curve_publickey = self ._curve_publickey
236+ socket .curve_server = True
237+
214238 def init_poller (self ):
215239 """Initialize the poller."""
216240 if sys .platform == "win32" :
@@ -274,6 +298,12 @@ def write_connection_file(self, **kwargs: Any) -> None:
274298 iopub_port = self .iopub_port ,
275299 control_port = self .control_port ,
276300 )
301+ if self .enable_curve and self ._curve_publickey is not None :
302+ # Store Z85-encoded keys as ASCII strings alongside the HMAC key.
303+ # Clients that understand CurveZMQ will use these to configure
304+ # their sockets; legacy clients ignore the unknown fields.
305+ connection_info ["curve_publickey" ] = self ._curve_publickey .decode ("ascii" )
306+ connection_info ["curve_secretkey" ] = self ._curve_secretkey .decode ("ascii" ) # type: ignore[union-attr]
277307 if Path (cf ).exists ():
278308 # If the file exists, merge our info into it. For example, if the
279309 # original file had port number 0, we update with the actual port
@@ -328,13 +358,19 @@ def init_sockets(self):
328358 self .context = context = zmq .Context ()
329359 atexit .register (self .close )
330360
361+ if self .enable_curve :
362+ self ._curve_publickey , self ._curve_secretkey = zmq .curve_keypair ()
363+ self .log .debug ("CurveZMQ enabled; generated server keypair" )
364+
331365 self .shell_socket = context .socket (zmq .ROUTER )
332366 self .shell_socket .linger = 1000
367+ self ._apply_curve_server_options (self .shell_socket )
333368 self .shell_port = self ._bind_socket (self .shell_socket , self .shell_port )
334369 self .log .debug ("shell ROUTER Channel on port: %i" , self .shell_port )
335370
336371 self .stdin_socket = context .socket (zmq .ROUTER )
337372 self .stdin_socket .linger = 1000
373+ self ._apply_curve_server_options (self .stdin_socket )
338374 self .stdin_port = self ._bind_socket (self .stdin_socket , self .stdin_port )
339375 self .log .debug ("stdin ROUTER Channel on port: %i" , self .stdin_port )
340376
@@ -351,6 +387,7 @@ def init_control(self, context):
351387 """Initialize the control channel."""
352388 self .control_socket = context .socket (zmq .ROUTER )
353389 self .control_socket .linger = 1000
390+ self ._apply_curve_server_options (self .control_socket )
354391 self .control_port = self ._bind_socket (self .control_socket , self .control_port )
355392 self .log .debug ("control ROUTER Channel on port: %i" , self .control_port )
356393
@@ -379,6 +416,7 @@ def init_iopub(self, context):
379416 """Initialize the iopub channel."""
380417 self .iopub_socket = context .socket (zmq .XPUB )
381418 self .iopub_socket .linger = 1000
419+ self ._apply_curve_server_options (self .iopub_socket )
382420 self .iopub_port = self ._bind_socket (self .iopub_socket , self .iopub_port )
383421 self .log .debug ("iopub PUB Channel on port: %i" , self .iopub_port )
384422 self .configure_tornado_logger ()
@@ -392,7 +430,12 @@ def init_heartbeat(self):
392430 # heartbeat doesn't share context, because it mustn't be blocked
393431 # by the GIL, which is accessed by libzmq when freeing zero-copy messages
394432 hb_ctx = zmq .Context ()
395- self .heartbeat = Heartbeat (hb_ctx , (self .transport , self .ip , self .hb_port ))
433+ self .heartbeat = Heartbeat (
434+ hb_ctx ,
435+ (self .transport , self .ip , self .hb_port ),
436+ curve_publickey = self ._curve_publickey if self .enable_curve else None ,
437+ curve_secretkey = self ._curve_secretkey if self .enable_curve else None ,
438+ )
396439 self .hb_port = self .heartbeat .port
397440 self .log .debug ("Heartbeat REP Channel on port: %i" , self .hb_port )
398441 self .heartbeat .start ()
0 commit comments