diff --git a/nmostesting/IS10Utils.py b/nmostesting/IS10Utils.py index 8f170cc00..bae35b7a2 100644 --- a/nmostesting/IS10Utils.py +++ b/nmostesting/IS10Utils.py @@ -15,6 +15,7 @@ from Crypto.PublicKey import RSA from authlib.jose import jwt, JsonWebKey +import re import time import uuid @@ -22,6 +23,7 @@ from OpenSSL import crypto from cryptography.hazmat.primitives import serialization from cryptography import x509 +from flask import request from .TestHelper import get_default_ip, get_mocks_hostname from . import Config as CONFIG @@ -157,3 +159,36 @@ def is_any_contain(list, enum): if item in [e.name for e in enum]: return True return False + + @staticmethod + def check_authorization(auth, path, scope="x-nmos-registration", write=False): + def _check_path_match(path, path_wildcards): + path_match = False + for path_wildcard in path_wildcards: + pattern = path_wildcard.replace("*", ".*") + if re.search(pattern, path): + path_match = True + break + return path_match + + if CONFIG.ENABLE_AUTH: + try: + if "Authorization" not in request.headers: + return 400, "Authorization header not found" + if not request.headers["Authorization"].startswith("Bearer "): + return 400, "Bearer not found in Authorization header" + token = request.headers["Authorization"].split(" ")[1] + claims = jwt.decode(token, auth.generate_jwk()) + claims.validate() + if claims["iss"] != auth.make_issuer(): + return 401, f"Unexpected issuer, expected: {auth.make_issuer()}, actual: {claims['iss']}" + # TODO: Check 'aud' claim matches 'mocks.' + if not _check_path_match(path, claims[scope]["read"]): + return 403, f"Paths mismatch for {scope} read claims" + if write and not _check_path_match(path, claims[scope]["write"]): + return 403, f"Paths mismatch for {scope} write claims" + except KeyError as err: + return 400, f"KeyError: {err}" + except Exception as err: + return 400, f"Exception: {err}" + return True, "" diff --git a/nmostesting/IS12Utils.py b/nmostesting/IS12Utils.py index 7cf85bed3..b1e10cca5 100644 --- a/nmostesting/IS12Utils.py +++ b/nmostesting/IS12Utils.py @@ -171,6 +171,12 @@ def send_command(self, test, command_json): for tm in self.ncp_websocket.get_timestamped_messages(): parsed_message = json.loads(tm.message) + if parsed_message is None: + raise NMOSTestException(test.FAIL( + f"Null message received for command: {str(command_json)}", + f"https://specs.amwa.tv/is-12/branches/{self.apis[CONTROL_API_KEY]['spec_branch']}" + "/docs/Protocol_messaging.html#command-message-type")) + if self.message_type_to_schema_name(parsed_message.get("messageType")): self._validate_is12_schema( test, @@ -230,7 +236,7 @@ def get_notifications(self): # Get any timestamped messages that have arrived in the interim period for tm in self.ncp_websocket.get_timestamped_messages(): parsed_message = json.loads(tm.message) - if parsed_message["messageType"] == MessageTypes.Notification: + if parsed_message and parsed_message["messageType"] == MessageTypes.Notification: self.notifications += [IS12Notification(n, tm.received_time) for n in parsed_message["notifications"]] return self.notifications diff --git a/nmostesting/NMOSTesting.py b/nmostesting/NMOSTesting.py index b95f4a15c..77387470e 100644 --- a/nmostesting/NMOSTesting.py +++ b/nmostesting/NMOSTesting.py @@ -169,6 +169,7 @@ # Primary Authorization server if CONFIG.ENABLE_AUTH: auth_app = Flask(__name__) + CORS(auth_app) auth_app.debug = False auth_app.config['AUTH_INSTANCE'] = 0 auth_app.config['PORT'] = PRIMARY_AUTH.port diff --git a/nmostesting/mocks/Auth.py b/nmostesting/mocks/Auth.py index 4b43921fa..b4e70c69f 100644 --- a/nmostesting/mocks/Auth.py +++ b/nmostesting/mocks/Auth.py @@ -113,6 +113,7 @@ def __init__(self, port_increment, version="v1.0"): self.host = get_mocks_hostname() # authorization code of the authorization code flow self.code = None + self.scopes_cache = {} # remember client scopes def make_mdns_info(self, priority=0, api_ver=None, ip=None): """Get an mDNS ServiceInfo object in order to create an advertisement""" @@ -302,10 +303,6 @@ def auth_auth(): # Recommended parameters # state - ctype_valid, ctype_message = check_content_type(request.headers, "application/x-www-form-urlencoded") - if not ctype_valid: - raise AuthException("invalid_request", ctype_message) - # hmm, no client authorization done, just redirects a random authorization code back to the client # TODO: add web pages for client authorization for the future @@ -342,6 +339,8 @@ def auth_auth(): if not scope_found: error = "invalid_request" error_description = "scope: {} are not supported".format(scopes) + # cache the client scopes + auth.scopes_cache[request.args["client_id"]] = scopes vars = {} if error: @@ -370,7 +369,6 @@ def auth_auth(): def auth_token(): auth = AUTHS[flask.current_app.config["AUTH_INSTANCE"]] try: - auth_header_required = False scopes = [] ctype_valid, ctype_message = check_content_type(request.headers, "application/x-www-form-urlencoded") @@ -395,7 +393,13 @@ def auth_token(): refresh_token = query["refresh_token"][0] if "refresh_token" in query else None - scopes = query["scope"][0].split() if "scope" in query else SCOPE.split() if SCOPE else [] + # Scope query parameter is OPTIONAL + # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2 + # and https://datatracker.ietf.org/doc/html/rfc6749#section-6 + # Use scopes cached from when the token was created if not provided in query + cached_scopes = auth.scopes_cache[client_id] if client_id in auth.scopes_cache else [] + scopes = query["scope"][0].split() if "scope" in query else cached_scopes \ + if len(cached_scopes) else SCOPE.split() if SCOPE else [] if scopes: scope_found = IS10Utils.is_any_contain(scopes, SCOPES) if not scope_found: @@ -484,8 +488,6 @@ def auth_token(): else: raise AuthException("unsupported_grant_type", "missing client_assertion_type used for private_key_jwt client authentication") - else: - auth_header_required = True # for the Confidential client, client_id and client_secret are embedded in the Authorization header auth_header = request.headers.get("Authorization", None) @@ -504,8 +506,6 @@ def auth_token(): "missing client_id or client_secret from authorization header") else: raise AuthException("invalid_client", "invalid authorization header") - elif auth_header_required: - raise AuthException("invalid_client", "invalid authorization header", HTTPStatus.UNAUTHORIZED) # client_id MUST be provided by all types of client if not client_id: diff --git a/nmostesting/mocks/Node.py b/nmostesting/mocks/Node.py index 80245f1c9..42aa0fdf7 100644 --- a/nmostesting/mocks/Node.py +++ b/nmostesting/mocks/Node.py @@ -24,6 +24,8 @@ from .. import Config as CONFIG from ..TestHelper import get_default_ip, do_request from ..IS04Utils import IS04Utils +from ..IS10Utils import IS10Utils +from .Auth import PRIMARY_AUTH class Node(object): @@ -39,6 +41,7 @@ def reset(self): self.receivers = {} self.senders = {} self.patched_sdp = {} + self.auth_cache = {} def get_sender(self, media_type="video/raw", version="v1.3"): protocol = "http" @@ -360,11 +363,49 @@ def patch_staged(self, resource, resource_id, request_json): return response_data, response_code + def check_authorization(self, auth, path, scope, write=False): + if not CONFIG.ENABLE_AUTH: + return True, "" + + if "Authorization" in request.headers and request.headers["Authorization"].startswith("Bearer ") \ + and scope in self.auth_cache and \ + ((write and self.auth_cache[scope]["Write"]) or self.auth_cache[scope]["Read"]): + return True, "" + + authorized, error_message = IS10Utils.check_authorization(auth, + path, + scope=scope, + write=write) + if authorized: + if scope not in self.auth_cache: + self.auth_cache[scope] = {"Read": True, "Write": write} + else: + self.auth_cache[scope]["Read"] = True + self.auth_cache[scope]["Write"] = self.auth_cache[scope]["Write"] or write + return authorized, error_message + NODE = Node(1) NODE_API = Blueprint('node_api', __name__) +# Authorization decorator +def check_authorization(func): + def wrapper(*args, **kwargs): + write = (request.method == 'PATCH') + authorized, error_message = NODE.check_authorization(PRIMARY_AUTH, + request.path, + scope="x-nmos-connection", + write=write) + if authorized is not True: + abort(authorized, description=error_message) + + return func(*args, **kwargs) + # Rename wrapper to allow decoration of decorator + wrapper.__name__ = func.__name__ + return wrapper + + @NODE_API.route('/x-nmos', methods=['GET'], strict_slashes=False) def x_nmos_root(): base_data = ['connection/'] @@ -373,6 +414,7 @@ def x_nmos_root(): @NODE_API.route('/x-nmos/connection', methods=['GET'], strict_slashes=False) +@check_authorization def connection_root(): base_data = ['v1.0/', 'v1.1/'] @@ -380,6 +422,7 @@ def connection_root(): @NODE_API.route('/x-nmos/connection/', methods=['GET'], strict_slashes=False) +@check_authorization def version(version): base_data = ['bulk/', 'single/'] @@ -387,6 +430,7 @@ def version(version): @NODE_API.route('/x-nmos/connection//single', methods=['GET'], strict_slashes=False) +@check_authorization def single(version): base_data = ['senders/', 'receivers/'] @@ -394,6 +438,7 @@ def single(version): @NODE_API.route('/x-nmos/connection//single//', methods=["GET"], strict_slashes=False) +@check_authorization def resources(version, resource): if resource == 'senders': base_data = [r + '/' for r in [*NODE.senders]] @@ -404,6 +449,7 @@ def resources(version, resource): @NODE_API.route('/x-nmos/connection//single//', methods=["GET"], strict_slashes=False) +@check_authorization def connection(version, resource, resource_id): if resource != 'senders' and resource != 'receivers': abort(404) @@ -440,6 +486,7 @@ def _get_constraints(resource): @NODE_API.route('/x-nmos/connection//single///constraints', methods=["GET"], strict_slashes=False) +@check_authorization def constraints(version, resource, resource_id): base_data = [_get_constraints(resource)] @@ -472,6 +519,7 @@ def _check_constraint(constraint, transport_param): @NODE_API.route('/x-nmos/connection//single///staged', methods=["GET", "PATCH"], strict_slashes=False) +@check_authorization def staged(version, resource, resource_id): """ GET returns current staged data for given resource @@ -515,6 +563,7 @@ def staged(version, resource, resource_id): @NODE_API.route('/x-nmos/connection//single///active', methods=["GET"], strict_slashes=False) +@check_authorization def active(version, resource, resource_id): try: if resource == 'senders': @@ -529,6 +578,7 @@ def active(version, resource, resource_id): @NODE_API.route('/x-nmos/connection//single///transporttype', methods=["GET"], strict_slashes=False) +@check_authorization def transport_type(version, resource, resource_id): # TODO fetch from resource info base_data = "urn:x-nmos:transport:rtp" @@ -583,6 +633,7 @@ def node_sdp(media_type, media_subtype): @NODE_API.route('/x-nmos/connection//single///transportfile', methods=["GET"], strict_slashes=False) +@check_authorization def transport_file(version, resource, resource_id): # GET should either redirect to the location of the transport file or return it directly try: diff --git a/nmostesting/mocks/Registry.py b/nmostesting/mocks/Registry.py index af2e79225..401f5db6e 100644 --- a/nmostesting/mocks/Registry.py +++ b/nmostesting/mocks/Registry.py @@ -15,12 +15,13 @@ import time import flask import json -import re import uuid import functools from flask import request, jsonify, abort, Blueprint, Response from threading import Event, Lock + +from ..IS10Utils import IS10Utils from ..Config import PORT_BASE, ENABLE_AUTH, \ WEBSOCKET_PORT_BASE, ENABLE_HTTPS, SPECIFICATIONS from authlib.jose import jwt @@ -78,6 +79,7 @@ def reset(self): self.query_api_called = False self.paging_limit = 100 self.pagination_used = False + self.auth_cache = {} def add(self, headers, payload, version): self.last_time = time.time() @@ -162,38 +164,26 @@ def _get_client_id(self, headers): return None return None - def _check_path_match(self, path, path_wildcards): - path_match = False - for path_wildcard in path_wildcards: - pattern = path_wildcard.replace("*", ".*") - if re.search(pattern, path): - path_match = True - break - return path_match - - def check_authorized(self, headers, path, write=False): - if ENABLE_AUTH: - try: - if not request.headers["Authorization"].startswith("Bearer "): - return 400 - token = request.headers["Authorization"].split(" ")[1] - claims = jwt.decode(token, PRIMARY_AUTH.generate_jwk()) - claims.validate() - if claims["iss"] != PRIMARY_AUTH.make_issuer(): - return 401 - # TODO: Check 'aud' claim matches 'mocks.' - if not self._check_path_match(path, claims["x-nmos-registration"]["read"]): - return 403 - if write: - if not self._check_path_match(path, claims["x-nmos-registration"]["write"]): - return 403 - except KeyError: - # TODO: Add debug which can be returned in the error response JSON - return 400 - except Exception: - # TODO: Add debug which can be returned in the error response JSON - return 400 - return True + def check_authorization(self, auth, path, scope, write=False): + if not ENABLE_AUTH: + return True, "" + + if "Authorization" in request.headers and request.headers["Authorization"].startswith("Bearer ") \ + and scope in self.auth_cache and \ + ((write and self.auth_cache[scope]["Write"]) or self.auth_cache[scope]["Read"]): + return True, "" + + authorized, error_message = IS10Utils.check_authorization(auth, + path, + scope=scope, + write=write) + if authorized: + if scope not in self.auth_cache: + self.auth_cache[scope] = {"Read": True, "Write": write} + else: + self.auth_cache[scope]["Read"] = True + self.auth_cache[scope]["Write"] = self.auth_cache[scope]["Write"] or write + return authorized, error_message # Query API subscription support methods @@ -335,26 +325,35 @@ def _close_subscription_websockets(self): REGISTRY_API = Blueprint('registry_api', __name__) +# Authorization decorator +def check_enabled_and_authorization(func): + def wrapper(*args, **kwargs): + registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] + if not registry.enabled: + abort(503) + authorized, error_message = registry.check_authorization(PRIMARY_AUTH, + request.path, + scope="x-nmos-registration") + if authorized is not True: + abort(authorized, description=error_message) + + return func(*args, **kwargs) + # Rename wrapper to allow decoration of decorator + wrapper.__name__ = func.__name__ + return wrapper + + @REGISTRY_API.route('/x-nmos', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def x_nmos(): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - base_data = ['query/', 'registration/'] return Response(json.dumps(base_data), mimetype='application/json') @REGISTRY_API.route('/x-nmos/registration', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def registration_root(): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - base_data = [version + '/' for version in SPECIFICATIONS["is-04"]["versions"]] return Response(json.dumps(base_data), mimetype='application/json') @@ -362,13 +361,8 @@ def registration_root(): # IS-04 resources @REGISTRY_API.route('/x-nmos/registration/', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def base_resource(version): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) base_data = ["resource/", "health/"] # Using json.dumps to support older Flask versions http://flask.pocoo.org/docs/1.0/security/#json-security @@ -376,13 +370,9 @@ def base_resource(version): @REGISTRY_API.route('/x-nmos/registration//resource', methods=["POST"]) +@check_enabled_and_authorization def post_resource(version): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(500) - authorized = registry.check_authorized(request.headers, request.path, True) - if authorized is not True: - abort(authorized) if not registry.test_first_reg: registered = False try: @@ -406,13 +396,9 @@ def post_resource(version): @REGISTRY_API.route('/x-nmos/registration//resource//', methods=["DELETE"]) +@check_enabled_and_authorization def delete_resource(version, resource_type, resource_id): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(500) - authorized = registry.check_authorized(request.headers, request.path, True) - if authorized is not True: - abort(authorized) resource_type = resource_type.rstrip("s") if not registry.test_first_reg: registered = False @@ -440,13 +426,9 @@ def delete_resource(version, resource_type, resource_id): @REGISTRY_API.route('/x-nmos/registration//health/nodes/', methods=["POST"]) +@check_enabled_and_authorization def heartbeat(version, node_id): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(500) - authorized = registry.check_authorized(request.headers, request.path, True) - if authorized is not True: - abort(authorized) if node_id in registry.get_resources()["node"]: # store raw request payload, in order to check for empty request bodies later try: @@ -459,28 +441,17 @@ def heartbeat(version, node_id): @REGISTRY_API.route('/x-nmos/query', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def query_root(): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - base_data = [version + '/' for version in SPECIFICATIONS["is-04"]["versions"]] return Response(json.dumps(base_data), mimetype='application/json') @REGISTRY_API.route('/x-nmos/query/', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def query(version): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - registry.requested_query_api_version = version base_data = ['devices/', 'flows/', 'nodes/', 'receivers/', 'senders/', 'sources/', 'subscriptions/'] @@ -497,14 +468,9 @@ def compare_resources(resource1, resource2): @REGISTRY_API.route('/x-nmos/query//', methods=["GET"], strict_slashes=False) +@check_enabled_and_authorization def query_resource(version, resource): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - registry.requested_query_api_version = version # NOTE: Advanced Query Syntax (RQL) is not currently supported @@ -622,14 +588,9 @@ def query_resource(version, resource): @REGISTRY_API.route('/x-nmos/query///', methods=['GET'], strict_slashes=False) +@check_enabled_and_authorization def get_resource(version, resource, resource_id): registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - registry.requested_query_api_version = version registry.query_api_called = True @@ -647,15 +608,9 @@ def get_resource(version, resource, resource_id): @REGISTRY_API.route('/x-nmos/query//subscriptions', methods=["POST"]) +@check_enabled_and_authorization def post_subscription(version): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - registry.requested_query_api_version = version subscription_request = request.json @@ -691,15 +646,9 @@ def post_subscription(version): @REGISTRY_API.route('/x-nmos/query//subscriptions/', methods=["DELETE"]) +@check_enabled_and_authorization def delete_subscription(version, subscription_id): - registry = REGISTRIES[flask.current_app.config["REGISTRY_INSTANCE"]] - if not registry.enabled: - abort(503) - authorized = registry.check_authorized(request.headers, request.path) - if authorized is not True: - abort(authorized) - registry.requested_query_api_version = version try: