diff --git a/daemon/daemon.py b/daemon/daemon.py index 7aeece2..dad5e15 100644 --- a/daemon/daemon.py +++ b/daemon/daemon.py @@ -2,6 +2,7 @@ from urllib.parse import urlparse from data_dispatcher.db import DBProject, DBReplica, DBRSE, DBProximityMap from data_dispatcher.logs import Logged +from data_dispatcher.scitoken import scitoken from daemon_web_server import DaemonWebServer from tape_interfaces import get_interface @@ -572,12 +573,15 @@ def run(self): self.log("Message broker is not configured. Stopping the RucioListener thread.") return broker_addr = (self.MessageBrokerConfig["host"], self.MessageBrokerConfig["port"]) - cert_file = self.SSLConfig.get("cert") - key_file = self.SSLConfig.get("key") + cert_file = self.SSLConfig.get("cert", None) + key_file = self.SSLConfig.get("key", None) ca_file = self.SSLConfig.get("ca_bundle") vhost = self.MessageBrokerConfig.get("vhost", "/") subscribe = self.MessageBrokerConfig["subscribe"] - connection = stompy.connect(broker_addr, cert_file=cert_file, key_file=key_file, ca_file=ca_file, host=vhost) + if not cert_file and scitoken(): + connection = stompy.connect(broker_addr, ca_file=ca_file, host=vhost, Authorization=f"Bearer {scitoken()}") + else: + connection = stompy.connect(broker_addr, cert_file=cert_file, key_file=key_file, ca_file=ca_file, host=vhost) connection.subscribe(subscribe) for frame in connection: if frame.Command == "MESSAGE": diff --git a/daemon/dcache.py b/daemon/dcache.py index bf0aa10..3c5fb6b 100644 --- a/daemon/dcache.py +++ b/daemon/dcache.py @@ -2,6 +2,7 @@ from data_dispatcher.db import DBReplica from pythreader import Primitive, Scheduler, synchronized, PyThread from data_dispatcher.logs import Logged +from data_dispatcher.scitoken import scitoken class DCachePoller(PyThread, Logged): @@ -42,8 +43,10 @@ def run(self): remove_dids = [] for did, path in burst: url = self.BaseURL + path + "?locality=true" - cert = None if self.Cert is None else (self.Cert, self.Key) #self.debug("dCache poll URL:", url) + cert = None if self.Cert is None else (self.Cert, self.Key) + if cert == None and scitoken(): + headers["Authorization"] = f"Bearer {scitoken()}" response = requests.get(url, headers=headers, cert=cert, verify=False) #self.debug("response:", response.status_code, response.text) if response.status_code == 404: @@ -135,6 +138,8 @@ def send(self): sys.exit(1) self.debug("request data:", json.dumps(data, indent=" ")) + if not self.CertTuple and scitoken(): + headers["Authorization"] = f"Bearer {scitoken()}" r = requests.post(url, data = json.dumps(data), headers=headers, verify=False, cert = self.CertTuple) self.debug("response:", r) @@ -160,6 +165,8 @@ def send(self): def query(self): assert self.URL is not None headers = { "accept" : "application/json" } + if cert == None and scitoken(): + headers["Authorization"] = f"Bearer {scitoken()}" r = requests.get(self.URL, headers=headers, verify=False, cert = self.CertTuple) #self.debug("status(): response:", r) if r.status_code // 100 == 4: @@ -193,6 +200,8 @@ def staged_replicas(self): def delete(self): assert self.URL is not None headers = { "accept" : "application/json" } + if cert == None and scitoken(): + headers["Authorization"] = f"Bearer {scitoken()}" r = requests.delete(self.URL, headers=headers, verify=False, cert = self.CertTuple) #self.debug("status(): response:", r) if r.status_code // 100 == 4: diff --git a/data_dispatcher/scitoken.py b/data_dispatcher/scitoken.py new file mode 100644 index 0000000..bae2662 --- /dev/null +++ b/data_dispatcher/scitoken.py @@ -0,0 +1,30 @@ +import os +import time + +class TokenReader: + + def __init__(self): + self.last_token = None + self.last_fetch = 0 + + def get(self): + if not os.environ.get("BEARER_TOKEN_FILE", ""): + # pick default file as BEARER_TOKEN_FILE if not set and it exists + uid = os.getuid() + deftokenf = f"/var/run/user/{uid}/bt_u{uid}" + if os.access(deftokenf, os.R_OK): + os.environ["BEARER_TOKEN_FILE"] = deftokenf + else: + return "" + + if self.last_fetch < time.time() - 5: + # return cached if less than 5 seconds old + with open(os.environ["BEARER_TOKEN_FILE"], "r") as tin: + self.last_token = tin.read().strip() + self.last_fetch = time.time() + + return self.last_token + +scitoken_obj = TokenReader() + +scitoken = scitoken_obj.get