Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions django/core/management/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ def run_from_argv(self, argv):
else:
self.stderr.write("%s: %s" % (e.__class__.__name__, e))
sys.exit(e.returncode)
except KeyboardInterrupt:
if options.traceback:
raise
self.stderr.write("\nOperation cancelled.")
sys.exit(1)
finally:
try:
connections.close_all()
Expand Down
18 changes: 17 additions & 1 deletion django/db/backends/base/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import signal
import subprocess


Expand Down Expand Up @@ -28,4 +29,19 @@ def runshell(self, parameters):
self.connection.settings_dict, parameters
)
env = {**os.environ, **env} if env else None
subprocess.run(args, env=env, check=True)
sigint_handler = None
if hasattr(signal, "SIGINT"):
try:
sigint_handler = signal.getsignal(signal.SIGINT)
except ValueError:
pass
try:
if sigint_handler is not None:
signal.signal(signal.SIGINT, signal.SIG_IGN)
subprocess.run(args, env=env, check=True)
finally:
if sigint_handler is not None:
try:
signal.signal(signal.SIGINT, sigint_handler)
except ValueError:
pass
12 changes: 0 additions & 12 deletions django/db/backends/mysql/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


Expand Down Expand Up @@ -60,13 +58,3 @@ def settings_to_cmd_args_env(cls, settings_dict, parameters):
args += [database]
args.extend(parameters)
return args, env

def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to mysql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)
12 changes: 0 additions & 12 deletions django/db/backends/postgresql/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


Expand Down Expand Up @@ -52,13 +50,3 @@ def settings_to_cmd_args_env(cls, settings_dict, parameters):
if passfile:
env["PGPASSFILE"] = str(passfile)
return args, (env or None)

def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)
25 changes: 25 additions & 0 deletions tests/admin_scripts/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,31 @@ def raise_command_error(*args, **kwargs):
with self.assertRaises(CommandError):
command.run_from_argv(["", "", "--traceback"])

def test_run_from_argv_keyboard_interrupt(self):
"""
Test run_from_argv handles KeyboardInterrupt cleanly by printing an
error message and exiting with 1.
"""
err = StringIO()
command = BaseCommand(stderr=err)

def raise_keyboard_interrupt(*args, **kwargs):
raise KeyboardInterrupt()

command.execute = raise_keyboard_interrupt

# If --traceback is not present, should print "Operation cancelled."
# and exit with SystemExit(1)
err.truncate(0)
with self.assertRaises(SystemExit) as cm:
command.run_from_argv(["", ""])
self.assertEqual(cm.exception.code, 1)
self.assertIn("Operation cancelled.", err.getvalue())

# If --traceback is present, should propagate KeyboardInterrupt
with self.assertRaises(KeyboardInterrupt):
command.run_from_argv(["", "", "--traceback"])

def test_run_from_argv_non_ascii_error(self):
"""
Non-ASCII message of CommandError does not raise any
Expand Down
22 changes: 22 additions & 0 deletions tests/dbshell/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,25 @@ def test_command_missing(self):
with self.assertRaisesMessage(CommandError, msg):
with mock.patch("subprocess.run", side_effect=FileNotFoundError):
call_command("dbshell")

@mock.patch("django.db.backends.base.client.subprocess.run")
def test_sigint_ignored_during_runshell(self, mock_run):
import signal

from django.db.backends.base.client import BaseDatabaseClient

original_handler = signal.getsignal(signal.SIGINT)

def mock_run_side_effect(*args, **kwargs):
self.assertEqual(signal.getsignal(signal.SIGINT), signal.SIG_IGN)

mock_run.side_effect = mock_run_side_effect

client = BaseDatabaseClient(connection)
# Mock settings_to_cmd_args_env to return dummy args
with mock.patch.object(
client, "settings_to_cmd_args_env", return_value=(["mock_db_client"], None)
):
client.runshell([])

self.assertEqual(signal.getsignal(signal.SIGINT), original_handler)