diff --git a/src/story_protocol_python_sdk/__init__.py b/src/story_protocol_python_sdk/__init__.py index 4953f30..5092a83 100644 --- a/src/story_protocol_python_sdk/__init__.py +++ b/src/story_protocol_python_sdk/__init__.py @@ -29,8 +29,8 @@ RegistrationWithRoyaltyVaultAndLicenseTermsResponse, RegistrationWithRoyaltyVaultResponse, ) -from .types.resource.License import LicenseTermsInput -from .types.resource.Royalty import RoyaltyShareInput +from .types.resource.License import LicenseTermsInput, LicenseTermsOverride +from .types.resource.Royalty import NativeRoyaltyPolicy, RoyaltyShareInput from .utils.constants import ( DEFAULT_FUNCTION_SELECTOR, MAX_ROYALTY_TOKEN, @@ -44,6 +44,7 @@ from .utils.derivative_data import DerivativeDataInput from .utils.ip_metadata import IPMetadataInput from .utils.licensing_config_data import LicensingConfig +from .utils.pil_flavor import PILFlavor, PILFlavorError __all__ = [ "StoryClient", @@ -72,7 +73,9 @@ "LicensingConfig", "RegisterPILTermsAndAttachResponse", "RoyaltyShareInput", + "NativeRoyaltyPolicy", "LicenseTermsInput", + "LicenseTermsOverride", "MintNFT", "MintedNFT", "RegisterIpAssetResponse", @@ -86,4 +89,7 @@ "DEFAULT_FUNCTION_SELECTOR", "MAX_ROYALTY_TOKEN", "WIP_TOKEN_ADDRESS", + # utils + "PILFlavor", + "PILFlavorError", ] diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 3166f05..1789f8b 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -22,6 +22,9 @@ from story_protocol_python_sdk.abi.LicensingModule.LicensingModule_client import ( LicensingModuleClient, ) +from story_protocol_python_sdk.abi.ModuleRegistry.ModuleRegistry_client import ( + ModuleRegistryClient, +) from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( PILicenseTemplateClient, ) @@ -32,7 +35,7 @@ CollectRoyaltiesResponse, ) from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH -from story_protocol_python_sdk.utils.license_terms import LicenseTerms +from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfigData from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction from story_protocol_python_sdk.utils.validation import get_revenue_share @@ -59,8 +62,7 @@ def __init__(self, web3: Web3, account, chain_id: int): self.licensing_module_client = LicensingModuleClient(web3) self.license_registry_client = LicenseRegistryClient(web3) self.pi_license_template_client = PILicenseTemplateClient(web3) - - self.license_terms_util = LicenseTerms(web3) + self.module_registry_client = ModuleRegistryClient(web3) self.sign_util = Sign(web3, self.chain_id, self.account) def register_group(self, group_pool: str, tx_options: dict | None = None) -> dict: @@ -707,8 +709,8 @@ def _get_license_data(self, license_data: list) -> list: processed_item = { "licenseTemplate": license_template, "licenseTermsId": item["license_terms_id"], - "licensingConfig": self.license_terms_util.validate_licensing_config( - item.get("licensing_config", {}) + "licensingConfig": LicensingConfigData.validate_license_config( + self.module_registry_client, item.get("licensing_config", {}) ), } diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 79558ed..5c23f4c 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -1,6 +1,6 @@ """Module for handling IP Account operations and transactions.""" -from dataclasses import asdict, is_dataclass +from dataclasses import asdict, is_dataclass, replace from typing import cast from ens.ens import Address, HexStr @@ -37,6 +37,9 @@ from story_protocol_python_sdk.abi.LicensingModule.LicensingModule_client import ( LicensingModuleClient, ) +from story_protocol_python_sdk.abi.ModuleRegistry.ModuleRegistry_client import ( + ModuleRegistryClient, +) from story_protocol_python_sdk.abi.Multicall3.Multicall3_client import Multicall3Client from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( PILicenseTemplateClient, @@ -68,6 +71,7 @@ RegistrationWithRoyaltyVaultAndLicenseTermsResponse, RegistrationWithRoyaltyVaultResponse, ) +from story_protocol_python_sdk.types.resource.License import LicenseTermsInput from story_protocol_python_sdk.types.resource.Royalty import RoyaltyShareInput from story_protocol_python_sdk.utils.constants import ( DEADLINE, @@ -86,11 +90,14 @@ get_ip_metadata_dict, is_initial_ip_metadata, ) -from story_protocol_python_sdk.utils.license_terms import LicenseTerms +from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfigData +from story_protocol_python_sdk.utils.pil_flavor import PILFlavor from story_protocol_python_sdk.utils.royalty import get_royalty_shares from story_protocol_python_sdk.utils.sign import Sign from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction +from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case from story_protocol_python_sdk.utils.validation import ( + get_revenue_share, validate_address, validate_max_rts, ) @@ -128,8 +135,8 @@ def __init__(self, web3: Web3, account, chain_id: int): ) self.royalty_module_client = RoyaltyModuleClient(web3) self.multicall3_client = Multicall3Client(web3) - self.license_terms_util = LicenseTerms(web3) self.sign_util = Sign(web3, self.chain_id, self.account) + self.module_registry_client = ModuleRegistryClient(web3) def mint( self, @@ -685,7 +692,6 @@ def register_ip_and_attach_pil_terms( f"The NFT with id {token_id} is already registered as IP." ) license_terms = self._validate_license_terms_data(license_terms_data) - calculated_deadline = self.sign_util.get_deadline(deadline=deadline) # Get permission signature for all required permissions @@ -1329,7 +1335,6 @@ def register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( license_terms = self._validate_license_terms_data(license_terms_data) calculated_deadline = self.sign_util.get_deadline(deadline=deadline) royalty_shares_obj = get_royalty_shares(royalty_shares) - signature_response = self.sign_util.get_permission_signature( ip_id=ip_id, deadline=calculated_deadline, @@ -2181,11 +2186,34 @@ def _validate_license_terms_data( terms_dict = term["terms"] licensing_config_dict = term["licensing_config"] + license_terms = PILFlavor.validate_license_terms( + LicenseTermsInput(**terms_dict) + ) + license_terms = replace( + license_terms, + commercial_rev_share=get_revenue_share( + license_terms.commercial_rev_share + ), + ) + if license_terms.royalty_policy != ZERO_ADDRESS: + is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyPolicy( + license_terms.royalty_policy + ) + if not is_whitelisted: + raise ValueError("The royalty_policy is not whitelisted.") + + if license_terms.currency != ZERO_ADDRESS: + is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyToken( + license_terms.currency + ) + if not is_whitelisted: + raise ValueError("The currency is not whitelisted.") + validated_license_terms_data.append( { - "terms": self.license_terms_util.validate_license_terms(terms_dict), - "licensingConfig": self.license_terms_util.validate_licensing_config( - licensing_config_dict + "terms": convert_dict_keys_to_camel_case(asdict(license_terms)), + "licensingConfig": LicensingConfigData.validate_license_config( + self.module_registry_client, licensing_config_dict ), } ) diff --git a/src/story_protocol_python_sdk/resources/License.py b/src/story_protocol_python_sdk/resources/License.py index 0abe68c..063782f 100644 --- a/src/story_protocol_python_sdk/resources/License.py +++ b/src/story_protocol_python_sdk/resources/License.py @@ -1,4 +1,7 @@ +from dataclasses import asdict, replace + from ens.ens import Address, HexStr +from typing_extensions import deprecated from web3 import Web3 from story_protocol_python_sdk.abi.IPAssetRegistry.IPAssetRegistry_client import ( @@ -16,14 +19,19 @@ from story_protocol_python_sdk.abi.PILicenseTemplate.PILicenseTemplate_client import ( PILicenseTemplateClient, ) +from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( + RoyaltyModuleClient, +) from story_protocol_python_sdk.types.common import RevShareType +from story_protocol_python_sdk.types.resource.License import LicenseTermsInput from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS -from story_protocol_python_sdk.utils.license_terms import LicenseTerms from story_protocol_python_sdk.utils.licensing_config_data import ( LicensingConfig, LicensingConfigData, ) +from story_protocol_python_sdk.utils.pil_flavor import PILFlavor from story_protocol_python_sdk.utils.transaction_utils import build_and_send_transaction +from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case from story_protocol_python_sdk.utils.validation import ( get_revenue_share, validate_address, @@ -49,8 +57,7 @@ def __init__(self, web3: Web3, account, chain_id: int): self.licensing_module_client = LicensingModuleClient(web3) self.ip_asset_registry_client = IPAssetRegistryClient(web3) self.module_registry_client = ModuleRegistryClient(web3) - - self.license_terms_util = LicenseTerms(web3) + self.royalty_module_client = RoyaltyModuleClient(web3) def _get_license_terms_id(self, license_terms: dict) -> int: """ @@ -106,48 +113,35 @@ def register_pil_terms( :return dict: A dictionary with the transaction hash and license terms ID. """ try: - license_terms = self.license_terms_util.validate_license_terms( - { - "transferable": transferable, - "royalty_policy": royalty_policy, - "default_minting_fee": default_minting_fee, - "expiration": expiration, - "commercial_use": commercial_use, - "commercial_attribution": commercial_attribution, - "commercializer_checker": commercializer_checker, - "commercializer_checker_data": commercializer_checker_data, - "commercial_rev_share": commercial_rev_share, - "commercial_rev_ceiling": commercial_rev_ceiling, - "derivatives_allowed": derivatives_allowed, - "derivatives_attribution": derivatives_attribution, - "derivatives_approval": derivatives_approval, - "derivatives_reciprocal": derivatives_reciprocal, - "derivative_rev_ceiling": derivative_rev_ceiling, - "currency": currency, - "uri": uri, - } - ) - - license_terms_id = self._get_license_terms_id(license_terms) - if (license_terms_id is not None) and (license_terms_id != 0): - return {"license_terms_id": license_terms_id} - - response = build_and_send_transaction( - self.web3, - self.account, - self.license_template_client.build_registerLicenseTerms_transaction, - license_terms, + return self._register_license_terms_helper( + license_terms=LicenseTermsInput( + transferable=transferable, + royalty_policy=royalty_policy, + default_minting_fee=default_minting_fee, + expiration=expiration, + commercial_use=commercial_use, + commercial_attribution=commercial_attribution, + commercializer_checker=commercializer_checker, + commercializer_checker_data=commercializer_checker_data, + commercial_rev_share=commercial_rev_share, + commercial_rev_ceiling=commercial_rev_ceiling, + derivatives_allowed=derivatives_allowed, + derivatives_attribution=derivatives_attribution, + derivatives_approval=derivatives_approval, + derivatives_reciprocal=derivatives_reciprocal, + derivative_rev_ceiling=derivative_rev_ceiling, + currency=currency, + uri=uri, + ), tx_options=tx_options, ) - - target_logs = self._parse_tx_license_terms_registered_event( - response["tx_receipt"] - ) - return {"tx_hash": response["tx_hash"], "license_terms_id": target_logs} - except Exception as e: raise e + @deprecated( + "Use register_pil_terms(**asdict(PILFlavor.non_commercial_social_remixing())) instead. " + "In the next major version, register_pil_terms will accept LicenseTermsInput directly, " + ) def register_non_com_social_remixing_pil( self, tx_options: dict | None = None ) -> dict: @@ -158,30 +152,16 @@ def register_non_com_social_remixing_pil( :return dict: A dictionary with the transaction hash and the license terms ID. """ try: - license_terms = self.license_terms_util.get_license_term_by_type( - self.license_terms_util.PIL_TYPE["NON_COMMERCIAL_REMIX"] - ) - - license_terms_id = self._get_license_terms_id(license_terms) - if (license_terms_id is not None) and (license_terms_id != 0): - return {"license_terms_id": license_terms_id} - - response = build_and_send_transaction( - self.web3, - self.account, - self.license_template_client.build_registerLicenseTerms_transaction, - license_terms, - tx_options=tx_options, - ) - - target_logs = self._parse_tx_license_terms_registered_event( - response["tx_receipt"] - ) - return {"tx_hash": response["tx_hash"], "license_terms_id": target_logs} - + license_terms = PILFlavor.non_commercial_social_remixing() + response = self._register_license_terms_helper(license_terms, tx_options) + return response except Exception as e: raise e + @deprecated( + "Use register_pil_terms(**asdict(PILFlavor.commercial_use(default_minting_fee, currency, royalty_policy))) instead. " + "In the next major version, register_pil_terms will accept LicenseTermsInput directly, " + ) def register_commercial_use_pil( self, default_minting_fee: int, @@ -199,38 +179,21 @@ def register_commercial_use_pil( :return dict: A dictionary with the transaction hash and the license terms ID. """ try: - complete_license_terms = self.license_terms_util.get_license_term_by_type( - self.license_terms_util.PIL_TYPE["COMMERCIAL_USE"], - { - "defaultMintingFee": default_minting_fee, - "currency": currency, - "royaltyPolicyAddress": royalty_policy, - }, + license_terms = PILFlavor.commercial_use( + default_minting_fee=default_minting_fee, + currency=currency, + royalty_policy=royalty_policy, ) - - license_terms_id = self._get_license_terms_id(complete_license_terms) - if (license_terms_id is not None) and (license_terms_id != 0): - return {"license_terms_id": license_terms_id} - - response = build_and_send_transaction( - self.web3, - self.account, - self.license_template_client.build_registerLicenseTerms_transaction, - complete_license_terms, - tx_options=tx_options, - ) - tx_hash = response["tx_hash"] - if not response["tx_receipt"]["logs"]: - return {"tx_hash": tx_hash} - - target_logs = self._parse_tx_license_terms_registered_event( - response["tx_receipt"] - ) - return {"tx_hash": tx_hash, "license_terms_id": target_logs} + response = self._register_license_terms_helper(license_terms, tx_options) + return response except Exception as e: raise e + @deprecated( + "Use register_pil_terms(**asdict(PILFlavor.commercial_remix(default_minting_fee, currency, commercial_rev_share, royalty_policy))) instead. " + "In the next major version, register_pil_terms will accept LicenseTermsInput directly, " + ) def register_commercial_remix_pil( self, default_minting_fee: int, @@ -250,39 +213,66 @@ def register_commercial_remix_pil( :return dict: A dictionary with the transaction hash and the license terms ID. """ try: - complete_license_terms = self.license_terms_util.get_license_term_by_type( - self.license_terms_util.PIL_TYPE["COMMERCIAL_REMIX"], - { - "defaultMintingFee": default_minting_fee, - "currency": currency, - "commercialRevShare": commercial_rev_share, - "royaltyPolicyAddress": royalty_policy, - }, + license_terms = PILFlavor.commercial_remix( + default_minting_fee=default_minting_fee, + currency=currency, + commercial_rev_share=commercial_rev_share, + royalty_policy=royalty_policy, ) + response = self._register_license_terms_helper(license_terms, tx_options) + return response - license_terms_id = self._get_license_terms_id(complete_license_terms) - if license_terms_id and license_terms_id != 0: - return {"license_terms_id": license_terms_id} - - response = build_and_send_transaction( - self.web3, - self.account, - self.license_template_client.build_registerLicenseTerms_transaction, - complete_license_terms, - tx_options=tx_options, - ) + except Exception as e: + raise e - tx_hash = response["tx_hash"] - if not response["tx_receipt"]["logs"]: - return {"tx_hash": tx_hash} + def _register_license_terms_helper( + self, license_terms: LicenseTermsInput, tx_options: dict | None = None + ): + """ + Validate the license terms. - target_logs = self._parse_tx_license_terms_registered_event( - response["tx_receipt"] + :param license_terms `LicenseTermsInput`: The license terms. + :param tx_options dict: [Optional] The transaction options. + :return dict: A dictionary with the transaction hash and the license terms ID. + """ + validated_license_terms = PILFlavor.validate_license_terms(license_terms) + validated_license_terms = replace( + validated_license_terms, + commercial_rev_share=get_revenue_share( + validated_license_terms.commercial_rev_share + ), + ) + if validated_license_terms.royalty_policy != ZERO_ADDRESS: + is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyPolicy( + validated_license_terms.royalty_policy ) - return {"tx_hash": tx_hash, "license_terms_id": target_logs} + if not is_whitelisted: + raise ValueError("The royalty_policy is not whitelisted.") - except Exception as e: - raise e + if validated_license_terms.currency != ZERO_ADDRESS: + is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyToken( + validated_license_terms.currency + ) + if not is_whitelisted: + raise ValueError("The currency is not whitelisted.") + camel_case_license_terms = convert_dict_keys_to_camel_case( + asdict(validated_license_terms) + ) + license_terms_id = self._get_license_terms_id(camel_case_license_terms) + if (license_terms_id is not None) and (license_terms_id != 0): + return {"license_terms_id": license_terms_id} + + response = build_and_send_transaction( + self.web3, + self.account, + self.license_template_client.build_registerLicenseTerms_transaction, + camel_case_license_terms, + tx_options=tx_options, + ) + target_logs = self._parse_tx_license_terms_registered_event( + response["tx_receipt"] + ) + return {"tx_hash": response["tx_hash"], "license_terms_id": target_logs} def _parse_tx_license_terms_registered_event(self, tx_receipt: dict) -> int | None: """ diff --git a/src/story_protocol_python_sdk/types/resource/License.py b/src/story_protocol_python_sdk/types/resource/License.py index 8de050f..187b8c1 100644 --- a/src/story_protocol_python_sdk/types/resource/License.py +++ b/src/story_protocol_python_sdk/types/resource/License.py @@ -1,10 +1,56 @@ from dataclasses import dataclass +from typing import Optional from ens.ens import Address, HexStr from story_protocol_python_sdk.types.resource.Royalty import RoyaltyPolicyInput +@dataclass +class LicenseTermsOverride: + """ + Optional override parameters for license terms. + All fields are optional and default to None. + + Attributes: + transferable: Whether the license is transferable. + royalty_policy: The type of royalty policy to be used. + default_minting_fee: The fee to be paid when minting a license. + expiration: The expiration period of the license. + commercial_use: Whether commercial use is allowed. + commercial_attribution: Whether commercial attribution is required. + commercializer_checker: The address of the commercializer checker contract. + commercializer_checker_data: The data to be passed to the commercializer checker contract. + commercial_rev_share: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100 (where 100% represents 100_000_000). + commercial_rev_ceiling: The maximum revenue that can be collected from commercial use. + derivatives_allowed: Whether derivatives are allowed. + derivatives_attribution: Whether attribution is required for derivatives. + derivatives_approval: Whether approval is required for derivatives. + derivatives_reciprocal: Whether derivatives must have the same license terms. + derivative_rev_ceiling: The maximum revenue that can be collected from derivatives. + currency: The ERC20 token to be used to pay the minting fee. + uri: The URI of the license terms. + """ + + transferable: Optional[bool] = None + royalty_policy: Optional[RoyaltyPolicyInput] = None + default_minting_fee: Optional[int] = None + expiration: Optional[int] = None + commercial_use: Optional[bool] = None + commercial_attribution: Optional[bool] = None + commercializer_checker: Optional[Address] = None + commercializer_checker_data: Optional[Address | HexStr] = None + commercial_rev_share: Optional[int] = None + commercial_rev_ceiling: Optional[int] = None + derivatives_allowed: Optional[bool] = None + derivatives_attribution: Optional[bool] = None + derivatives_approval: Optional[bool] = None + derivatives_reciprocal: Optional[bool] = None + derivative_rev_ceiling: Optional[int] = None + currency: Optional[Address] = None + uri: Optional[str] = None + + @dataclass class LicenseTermsInput: """ diff --git a/src/story_protocol_python_sdk/utils/constants.py b/src/story_protocol_python_sdk/utils/constants.py index 1035c44..7151c27 100644 --- a/src/story_protocol_python_sdk/utils/constants.py +++ b/src/story_protocol_python_sdk/utils/constants.py @@ -4,7 +4,7 @@ DEFAULT_FUNCTION_SELECTOR = "0x00000000" MAX_ROYALTY_TOKEN = 100_000_000 ROYALTY_POLICY_LAP_ADDRESS = "0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E" -ROYALTY_POLICY_LRP_ADDRESS = "0x9156E603C949481883B1D3355C6F1132D191FC41" +ROYALTY_POLICY_LRP_ADDRESS = "0x9156e603C949481883B1d3355c6f1132D191fC41" WIP_TOKEN_ADDRESS = "0x1514000000000000000000000000000000000000" # Default deadline for signature in seconds DEADLINE = 1000 diff --git a/src/story_protocol_python_sdk/utils/derivative_data.py b/src/story_protocol_python_sdk/utils/derivative_data.py index 7927e78..00ba2f0 100644 --- a/src/story_protocol_python_sdk/utils/derivative_data.py +++ b/src/story_protocol_python_sdk/utils/derivative_data.py @@ -15,7 +15,10 @@ ) from story_protocol_python_sdk.types.common import RevShareType from story_protocol_python_sdk.utils.constants import MAX_ROYALTY_TOKEN, ZERO_ADDRESS -from story_protocol_python_sdk.utils.validation import get_revenue_share +from story_protocol_python_sdk.utils.validation import ( + get_revenue_share, + validate_address, +) @dataclass @@ -110,13 +113,12 @@ def validate_parent_ip_ids_and_license_terms_ids(self): raise ValueError( "The number of parent IP IDs must match the number of license terms IDs." ) - total_royalty_percent = 0 for parent_ip_id, license_terms_id in zip( self.parent_ip_ids, self.license_terms_ids ): - if not Web3.is_checksum_address(parent_ip_id): - raise ValueError("The parent IP ID must be a valid address.") + validate_address(parent_ip_id) + if not self.ip_asset_registry_client.isRegistered(parent_ip_id): raise ValueError(f"The parent IP ID {parent_ip_id} must be registered.") if not self.license_registry_client.hasIpAttachedLicenseTerms( diff --git a/src/story_protocol_python_sdk/utils/license_terms.py b/src/story_protocol_python_sdk/utils/license_terms.py deleted file mode 100644 index 7b190b3..0000000 --- a/src/story_protocol_python_sdk/utils/license_terms.py +++ /dev/null @@ -1,288 +0,0 @@ -# src/story_protocol_python_sdk/utils/license_terms.py - -from ens.async_ens import HexStr -from web3 import Web3 - -from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import ( - RoyaltyModuleClient, -) -from story_protocol_python_sdk.types.common import RevShareType -from story_protocol_python_sdk.utils.constants import ( - ROYALTY_POLICY_LAP_ADDRESS, - ZERO_ADDRESS, -) -from story_protocol_python_sdk.utils.validation import get_revenue_share - - -class LicenseTerms: - def __init__(self, web3: Web3): - self.web3 = web3 - self.royalty_module_client = RoyaltyModuleClient(web3) - - PIL_TYPE = { - "NON_COMMERCIAL_REMIX": "non_commercial_remix", - "COMMERCIAL_USE": "commercial_use", - "COMMERCIAL_REMIX": "commercial_remix", - } - - def get_license_term_by_type(self, type, term=None): - license_terms = { - "transferable": True, - "royaltyPolicy": "0x0000000000000000000000000000000000000000", - "defaultMintingFee": 0, - "expiration": 0, - "commercialUse": False, - "commercialAttribution": False, - "commercializerChecker": "0x0000000000000000000000000000000000000000", - "commercializerCheckerData": "0x0000000000000000000000000000000000000000", - "commercialRevShare": 0, - "commercialRevCeiling": 0, - "derivativesAllowed": True, - "derivativesAttribution": True, - "derivativesApproval": False, - "derivativesReciprocal": True, - "derivativeRevCeiling": 0, - "currency": "0x0000000000000000000000000000000000000000", - "uri": "", - } - - if type == self.PIL_TYPE["NON_COMMERCIAL_REMIX"]: - license_terms["commercializerCheckerData"] = "0x" - return license_terms - elif type == self.PIL_TYPE["COMMERCIAL_USE"]: - if not term or "defaultMintingFee" not in term or "currency" not in term: - raise ValueError( - "DefaultMintingFee, currency are required for commercial use PIL." - ) - - if term["royaltyPolicyAddress"] is None: - term["royaltyPolicyAddress"] = ROYALTY_POLICY_LAP_ADDRESS - - license_terms.update( - { - "defaultMintingFee": int(term["defaultMintingFee"]), - "currency": term["currency"], - "commercialUse": True, - "commercialAttribution": True, - "derivativesReciprocal": False, - "royaltyPolicy": term["royaltyPolicyAddress"], - } - ) - return license_terms - else: - if ( - not term - or "defaultMintingFee" not in term - or "currency" not in term - or "commercialRevShare" not in term - ): - raise ValueError( - "DefaultMintingFee, currency and commercialRevShare are required for commercial remix PIL." - ) - - if "royaltyPolicyAddress" not in term: - raise ValueError("royaltyPolicyAddress is required") - - if term["commercialRevShare"] < 0 or term["commercialRevShare"] > 100: - raise ValueError("CommercialRevShare should be between 0 and 100.") - - license_terms.update( - { - "defaultMintingFee": int(term["defaultMintingFee"]), - "currency": term["currency"], - "commercialUse": True, - "commercialAttribution": True, - "commercialRevShare": get_revenue_share(term["commercialRevShare"]), - "derivativesReciprocal": True, - "royaltyPolicy": term["royaltyPolicyAddress"], - } - ) - return license_terms - - def validate_license_terms(self, params): - royalty_policy = params.get("royalty_policy") - currency = params.get("currency") - if royalty_policy != ZERO_ADDRESS: - is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyPolicy( - royalty_policy - ) - if not is_whitelisted: - raise ValueError("The royalty policy is not whitelisted.") - - if currency != ZERO_ADDRESS: - is_whitelisted = self.royalty_module_client.isWhitelistedRoyaltyToken( - currency - ) - if not is_whitelisted: - raise ValueError("The currency token is not whitelisted.") - - if royalty_policy != ZERO_ADDRESS and currency == ZERO_ADDRESS: - raise ValueError("Royalty policy requires currency token.") - - commercial_rev_share = params.get("commercial_rev_share", 0) - if commercial_rev_share < 0 or commercial_rev_share > 100: - raise ValueError("commercial_rev_share should be between 0 and 100.") - - validated_params = { - "transferable": params.get("transferable"), - "royaltyPolicy": params.get("royalty_policy"), - "defaultMintingFee": int(params.get("default_minting_fee", 0)), - "expiration": int(params.get("expiration", 0)), - "commercialUse": params.get("commercial_use"), - "commercialAttribution": params.get("commercial_attribution"), - "commercializerChecker": params.get("commercializer_checker"), - "commercializerCheckerData": Web3.to_bytes( - hexstr=HexStr(params.get("commercializer_checker_data", ZERO_ADDRESS)) - ), - "commercialRevShare": get_revenue_share( - params.get("commercial_rev_share", 0) - ), - "commercialRevCeiling": int(params.get("commercial_rev_ceiling", 0)), - "derivativesAllowed": params.get("derivatives_allowed"), - "derivativesAttribution": params.get("derivatives_attribution"), - "derivativesApproval": params.get("derivatives_approval"), - "derivativesReciprocal": params.get("derivatives_reciprocal"), - "derivativeRevCeiling": int(params.get("derivative_rev_ceiling", 0)), - "currency": params.get("currency"), - "uri": params.get("uri"), - } - - self.verify_commercial_use(validated_params) - self.verify_derivatives(validated_params) - return validated_params - - def validate_licensing_config(self, params): - if not isinstance(params, dict): - raise TypeError("Licensing config parameters must be a dictionary") - - required_params = { - "is_set": bool, - "minting_fee": int, - "hook_data": str, - "licensing_hook": str, - "commercial_rev_share": int, - "disabled": bool, - "expect_minimum_group_reward_share": int, - "expect_group_reward_pool": str, - } - - for param, expected_type in required_params.items(): - if param in params: - if not isinstance(params[param], expected_type): - raise TypeError(f"{param} must be of type {expected_type.__name__}") - - default_params = { - "isSet": False, - "mintingFee": 0, - "hookData": ZERO_ADDRESS, - "licensingHook": ZERO_ADDRESS, - "commercialRevShare": 0, - "disabled": False, - "expectMinimumGroupRewardShare": 0, - "expectGroupRewardPool": ZERO_ADDRESS, - } - - if not params.get("is_set", False): - return default_params - - if params.get("minting_fee", 0) < 0: - raise ValueError("Minting fee cannot be negative") - - if ( - params.get("commercial_rev_share", 0) < 0 - or params.get("commercial_rev_share", 0) > 100 - ): - raise ValueError("Commercial revenue share must be between 0 and 100") - if ( - params.get("expect_minimum_group_reward_share", 0) < 0 - or params.get("expect_minimum_group_reward_share", 0) > 100 - ): - raise ValueError( - "Expect minimum group reward share must be between 0 and 100" - ) - validated_params = { - "isSet": params.get("is_set", False), - "mintingFee": params.get("minting_fee", 0), - "hookData": Web3.to_bytes(hexstr=HexStr(params["hook_data"])), - "licensingHook": params.get("licensing_hook", ZERO_ADDRESS), - "commercialRevShare": get_revenue_share(params["commercial_rev_share"]), - "disabled": params.get("disabled", False), - "expectMinimumGroupRewardShare": get_revenue_share( - params["expect_minimum_group_reward_share"], - RevShareType.EXPECT_MINIMUM_GROUP_REWARD_SHARE, - ), - "expectGroupRewardPool": params.get( - "expect_group_reward_pool", ZERO_ADDRESS - ), - } - - return validated_params - - def verify_commercial_use(self, terms): - if not terms.get("commercialUse", False): - if terms.get("commercialAttribution", False): - raise ValueError( - "Cannot add commercial attribution when commercial use is disabled." - ) - if terms.get("commercializerChecker") != ZERO_ADDRESS: - raise ValueError( - "Cannot add commercializerChecker when commercial use is disabled." - ) - if terms.get("commercialRevShare", 0) > 0: - raise ValueError( - "Cannot add commercial revenue share when commercial use is disabled." - ) - if terms.get("commercialRevCeiling", 0) > 0: - raise ValueError( - "Cannot add commercial revenue ceiling when commercial use is disabled." - ) - if terms.get("derivativeRevCeiling", 0) > 0: - raise ValueError( - "Cannot add derivative revenue ceiling when commercial use is disabled." - ) - if terms.get("royaltyPolicy") != ZERO_ADDRESS: - raise ValueError( - "Cannot add commercial royalty policy when commercial use is disabled." - ) - else: - if terms.get("royaltyPolicy") == ZERO_ADDRESS: - raise ValueError( - "Royalty policy is required when commercial use is enabled." - ) - - def verify_derivatives(self, terms): - if not terms.get("derivativesAllowed", False): - if terms.get("derivativesAttribution", False): - raise ValueError( - "Cannot add derivative attribution when derivative use is disabled." - ) - if terms.get("derivativesApproval", False): - raise ValueError( - "Cannot add derivative approval when derivative use is disabled." - ) - if terms.get("derivativesReciprocal", False): - raise ValueError( - "Cannot add derivative reciprocal when derivative use is disabled." - ) - if terms.get("derivativeRevCeiling", 0) > 0: - raise ValueError( - "Cannot add derivative revenue ceiling when derivative use is disabled." - ) - - def get_revenue_share(self, rev_share: int | str) -> int: - """ - Convert revenue share percentage to token amount. - - :param rev_share int|str: Revenue share percentage between 0-100 - :return int: Revenue share token amount - """ - try: - rev_share_number = float(rev_share) - except ValueError: - raise ValueError("CommercialRevShare must be a valid number.") - - if rev_share_number < 0 or rev_share_number > 100: - raise ValueError("CommercialRevShare should be between 0 and 100.") - - MAX_ROYALTY_TOKEN = 100000000 - return int((rev_share_number / 100) * MAX_ROYALTY_TOKEN) diff --git a/src/story_protocol_python_sdk/utils/licensing_config_data.py b/src/story_protocol_python_sdk/utils/licensing_config_data.py index abe3a8a..c1b0e89 100644 --- a/src/story_protocol_python_sdk/utils/licensing_config_data.py +++ b/src/story_protocol_python_sdk/utils/licensing_config_data.py @@ -2,6 +2,7 @@ from typing import TypedDict from ens.ens import Address, HexStr +from web3 import Web3 from story_protocol_python_sdk.abi.ModuleRegistry.ModuleRegistry_client import ( ModuleRegistryClient, @@ -37,7 +38,7 @@ class LicensingConfig(TypedDict): is_set: bool minting_fee: int licensing_hook: Address - hook_data: HexStr + hook_data: str commercial_rev_share: int disabled: bool expect_minimum_group_reward_share: int @@ -131,7 +132,11 @@ def validate_license_config( isSet=licensing_config["is_set"], mintingFee=licensing_config["minting_fee"], licensingHook=validate_address(licensing_config["licensing_hook"]), - hookData=licensing_config["hook_data"], + hookData=( + Web3.to_bytes(text=licensing_config["hook_data"]) + if licensing_config["hook_data"] != ZERO_HASH + else ZERO_HASH + ), commercialRevShare=get_revenue_share( licensing_config["commercial_rev_share"] ), diff --git a/src/story_protocol_python_sdk/utils/pil_flavor.py b/src/story_protocol_python_sdk/utils/pil_flavor.py new file mode 100644 index 0000000..5543e9c --- /dev/null +++ b/src/story_protocol_python_sdk/utils/pil_flavor.py @@ -0,0 +1,312 @@ +from dataclasses import asdict, replace +from typing import Optional + +from ens.ens import Address + +from story_protocol_python_sdk.types.resource.License import ( + LicenseTermsInput, + LicenseTermsOverride, +) +from story_protocol_python_sdk.types.resource.Royalty import RoyaltyPolicyInput +from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS +from story_protocol_python_sdk.utils.royalty import royalty_policy_input_to_address +from story_protocol_python_sdk.utils.validation import validate_address + + +def _apply_override( + base: LicenseTermsInput, override: Optional[LicenseTermsOverride] +) -> LicenseTermsInput: + """Apply override values to base license terms, ignoring None values.""" + if not override: + return base + # Filter out None values from override + overrides = {k: v for k, v in asdict(override).items() if v is not None} + return replace(base, **overrides) + + +# PIL URIs for off-chain terms +PIL_URIS = { + "NCSR": "https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/NCSR.json", + "COMMERCIAL_USE": "https://github.com/piplabs/pil-document/blob/9a1f803fcf8101a8a78f1dcc929e6014e144ab56/off-chain-terms/CommercialUse.json", + "COMMERCIAL_REMIX": "https://github.com/piplabs/pil-document/blob/ad67bb632a310d2557f8abcccd428e4c9c798db1/off-chain-terms/CommercialRemix.json", + "CC_BY": "https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/CC-BY.json", +} + +# Common default values for license terms +COMMON_DEFAULTS: LicenseTermsInput = LicenseTermsInput( + transferable=True, + royalty_policy=ZERO_ADDRESS, + default_minting_fee=0, + expiration=0, + commercial_use=False, + commercial_attribution=False, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + commercial_rev_share=0, + commercial_rev_ceiling=0, + derivatives_allowed=False, + derivatives_attribution=False, + derivatives_approval=False, + derivatives_reciprocal=False, + derivative_rev_ceiling=0, + currency=ZERO_ADDRESS, + uri="", +) + + +class PILFlavorError(Exception): + """Exception for PIL flavor validation errors.""" + + pass + + +class PILFlavor: + """ + Pre-configured Programmable IP License (PIL) flavors for ease of use. + + The PIL is highly configurable, but these pre-configured license terms (flavors) + are the most popular options that cover common use cases. + + See: https://docs.story.foundation/concepts/programmable-ip-license/pil-flavors + + Example: + # Create a commercial use license + commercial_license = PILFlavor.commercial_use( + default_minting_fee=1000000000000000000, # 1 IP minting fee + currency="0x1234...", # currency token + royalty_policy="LAP" # royalty policy + ) + + # Create a non-commercial social remixing license + remix_license = PILFlavor.non_commercial_social_remixing() + """ + + _non_commercial_social_remixing_pil = replace( + COMMON_DEFAULTS, + commercial_use=False, + commercial_attribution=False, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri=PIL_URIS["NCSR"], + ) + + _commercial_use = replace( + COMMON_DEFAULTS, + commercial_use=True, + commercial_attribution=True, + derivatives_allowed=False, + derivatives_attribution=False, + derivatives_approval=False, + derivatives_reciprocal=False, + uri=PIL_URIS["COMMERCIAL_USE"], + ) + + _commercial_remix = replace( + COMMON_DEFAULTS, + commercial_use=True, + commercial_attribution=True, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri=PIL_URIS["COMMERCIAL_REMIX"], + ) + + _creative_commons_attribution = replace( + COMMON_DEFAULTS, + commercial_use=True, + commercial_attribution=True, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri=PIL_URIS["CC_BY"], + ) + + @staticmethod + def non_commercial_social_remixing( + override: Optional[LicenseTermsOverride] = None, + ) -> LicenseTermsInput: + """ + Gets the values to create a Non-Commercial Social Remixing license terms flavor. + + See: https://docs.story.foundation/concepts/programmable-ip-license/pil-flavors#non-commercial-social-remixing + + :param `override` `Optional[LicenseTermsOverride]`: Optional overrides for the default license terms. + :return: `LicenseTermsInput`: The license terms. + """ + terms = _apply_override(PILFlavor._non_commercial_social_remixing_pil, override) + return PILFlavor.validate_license_terms(terms) + + @staticmethod + def commercial_use( + default_minting_fee: int, + currency: Address, + royalty_policy: Optional[RoyaltyPolicyInput] = None, + override: Optional[LicenseTermsOverride] = None, + ) -> LicenseTermsInput: + """ + Gets the values to create a Commercial Use license terms flavor. + + See: https://docs.story.foundation/concepts/programmable-ip-license/pil-flavors#commercial-use + + :param `default_minting_fee` int: The fee to be paid when minting a license. + :param `currency` Address: The ERC20 token to be used to pay the minting fee. + :param `royalty_policy` `Optional[RoyaltyPolicyInput]`: The type of royalty policy to be used.(default: LAP) + :param `override` `Optional[LicenseTermsOverride]`: Optional overrides for the default license terms. + :return: `LicenseTermsInput`: The license terms. + """ + base = replace( + PILFlavor._commercial_use, + default_minting_fee=default_minting_fee, + currency=currency, + royalty_policy=royalty_policy, + ) + terms = _apply_override(base, override) + return PILFlavor.validate_license_terms(terms) + + @staticmethod + def commercial_remix( + default_minting_fee: int, + currency: Address, + commercial_rev_share: int, + royalty_policy: Optional[RoyaltyPolicyInput] = None, + override: Optional[LicenseTermsOverride] = None, + ) -> LicenseTermsInput: + """ + Gets the values to create a Commercial Remixing license terms flavor. + + See: https://docs.story.foundation/concepts/programmable-ip-license/pil-flavors#commercial-remix + + :param `default_minting_fee` int: The fee to be paid when minting a license. + :param `currency` Address: The ERC20 token to be used to pay the minting fee. + :param `commercial_rev_share` int: Percentage of revenue that must be shared with the licensor. Must be between 0 and 100. + :param `royalty_policy` `Optional[RoyaltyPolicyInput]`: The type of royalty policy to be used.(default: LAP) + :param `override` `Optional[LicenseTermsOverride]`: Optional overrides for the default license terms. + :return: `LicenseTermsInput`: The license terms. + """ + base = replace( + PILFlavor._commercial_remix, + default_minting_fee=default_minting_fee, + currency=currency, + commercial_rev_share=commercial_rev_share, + royalty_policy=royalty_policy, + ) + terms = _apply_override(base, override) + return PILFlavor.validate_license_terms(terms) + + @staticmethod + def creative_commons_attribution( + currency: Address, + royalty_policy: Optional[RoyaltyPolicyInput] = None, + override: Optional[LicenseTermsOverride] = None, + ) -> LicenseTermsInput: + """ + Gets the values to create a Creative Commons Attribution (CC-BY) license terms flavor. + + See: https://docs.story.foundation/concepts/programmable-ip-license/pil-flavors#creative-commons-attribution + + :param `currency` Address: The ERC20 token to be used to pay the minting fee. + :param `royalty_policy` `Optional[RoyaltyPolicyInput]`: The type of royalty policy to be used.(default: LAP) + :param `override` `Optional[LicenseTermsOverride]`: Optional overrides for the default license terms. + :return: `LicenseTermsInput`: The license terms. + """ + base = replace( + PILFlavor._creative_commons_attribution, + currency=currency, + royalty_policy=royalty_policy, + ) + terms = _apply_override(base, override) + return PILFlavor.validate_license_terms(terms) + + @staticmethod + def validate_license_terms(params: LicenseTermsInput) -> LicenseTermsInput: + """ + Validates and normalizes license terms. + + :param params `LicenseTermsInput`: The license terms parameters to validate. + :return: `LicenseTermsInput`: The validated and normalized license terms. + :raises PILFlavorError: If validation fails. + """ + # Normalize royalty_policy to address + royalty_policy = royalty_policy_input_to_address(params.royalty_policy) + currency = validate_address(params.currency) + + normalized = replace( + params, + royalty_policy=royalty_policy, + ) + + # Validate royalty policy and currency relationship + if royalty_policy != ZERO_ADDRESS and currency == ZERO_ADDRESS: + raise PILFlavorError( + "royalty_policy is not zero address and currency cannot be zero address." + ) + + # Validate default_minting_fee + if normalized.default_minting_fee < 0: + raise PILFlavorError( + "default_minting_fee should be greater than or equal to 0." + ) + + if normalized.default_minting_fee > 0 and royalty_policy == ZERO_ADDRESS: + raise PILFlavorError( + "royalty_policy is required when default_minting_fee is greater than 0." + ) + + # Validate commercial use and derivatives + PILFlavor._verify_commercial_use(normalized) + PILFlavor._verify_derivatives(normalized) + + if normalized.commercial_rev_share > 100 or normalized.commercial_rev_share < 0: + raise PILFlavorError("commercial_rev_share must be between 0 and 100.") + + return normalized + + @staticmethod + def _verify_commercial_use(terms: LicenseTermsInput) -> None: + """Verify commercial use related fields.""" + royalty_policy = royalty_policy_input_to_address(terms.royalty_policy) + + if not terms.commercial_use: + commercial_fields = [ + ("commercial_attribution", terms.commercial_attribution), + ( + "commercializer_checker", + terms.commercializer_checker != ZERO_ADDRESS, + ), + ("commercial_rev_share", terms.commercial_rev_share > 0), + ("commercial_rev_ceiling", terms.commercial_rev_ceiling > 0), + ("derivative_rev_ceiling", terms.derivative_rev_ceiling > 0), + ("royalty_policy", royalty_policy != ZERO_ADDRESS), + ] + + for field, value in commercial_fields: + if value: + raise PILFlavorError( + f"cannot add {field} when commercial_use is False." + ) + else: + if royalty_policy == ZERO_ADDRESS: + raise PILFlavorError( + "royalty_policy is required when commercial_use is True." + ) + + @staticmethod + def _verify_derivatives(terms: LicenseTermsInput) -> None: + """Verify derivatives related fields.""" + if not terms.derivatives_allowed: + derivative_fields = [ + ("derivatives_attribution", terms.derivatives_attribution), + ("derivatives_approval", terms.derivatives_approval), + ("derivatives_reciprocal", terms.derivatives_reciprocal), + ("derivative_rev_ceiling", terms.derivative_rev_ceiling > 0), + ] + + for field, value in derivative_fields: + if value: + raise PILFlavorError( + f"cannot add {field} when derivatives_allowed is False." + ) diff --git a/src/story_protocol_python_sdk/utils/royalty.py b/src/story_protocol_python_sdk/utils/royalty.py index 367fca1..acea2a3 100644 --- a/src/story_protocol_python_sdk/utils/royalty.py +++ b/src/story_protocol_python_sdk/utils/royalty.py @@ -3,6 +3,7 @@ from typing import List from ens.ens import Address +from typing_extensions import cast from story_protocol_python_sdk.types.resource.Royalty import ( NativeRoyaltyPolicy, @@ -82,12 +83,9 @@ def royalty_policy_input_to_address( if input is None: return ROYALTY_POLICY_LAP_ADDRESS - if isinstance(input, str): - return validate_address(input) - if input == NativeRoyaltyPolicy.LAP: return ROYALTY_POLICY_LAP_ADDRESS elif input == NativeRoyaltyPolicy.LRP: return ROYALTY_POLICY_LRP_ADDRESS - - return ROYALTY_POLICY_LAP_ADDRESS + else: + return validate_address(cast(str, input)) diff --git a/src/story_protocol_python_sdk/utils/util.py b/src/story_protocol_python_sdk/utils/util.py new file mode 100644 index 0000000..85e8b35 --- /dev/null +++ b/src/story_protocol_python_sdk/utils/util.py @@ -0,0 +1,19 @@ +def snake_to_camel(snake_str: str) -> str: + """ + Convert a snake_case string to camelCase. + + :param snake_str str: The snake_case string to convert. + :return str: The camelCase string. + """ + components = snake_str.split("_") + return components[0] + "".join(word.capitalize() for word in components[1:]) + + +def convert_dict_keys_to_camel_case(snake_dict: dict) -> dict: + """ + Convert all keys in a dictionary from snake_case to camelCase. + + :param snake_dict dict: The dictionary with snake_case keys. + :return dict: A new dictionary with camelCase keys. + """ + return {snake_to_camel(key): value for key, value in snake_dict.items()} diff --git a/tests/integration/test_integration_ip_asset.py b/tests/integration/test_integration_ip_asset.py index 6ed8b1e..a8076cd 100644 --- a/tests/integration/test_integration_ip_asset.py +++ b/tests/integration/test_integration_ip_asset.py @@ -1,3 +1,5 @@ +from dataclasses import asdict + import pytest from story_protocol_python_sdk import ( @@ -8,9 +10,12 @@ IPMetadataInput, LicenseTermsDataInput, LicenseTermsInput, + LicenseTermsOverride, LicensingConfig, MintedNFT, MintNFT, + NativeRoyaltyPolicy, + PILFlavor, RoyaltyShareInput, StoryClient, ) @@ -114,8 +119,8 @@ def child_ip_id(self, story_client: StoryClient): @pytest.fixture(scope="module") def non_commercial_license(self, story_client: StoryClient): - license_register_response = ( - story_client.License.register_non_com_social_remixing_pil() + license_register_response = story_client.License.register_pil_terms( + **asdict(PILFlavor.non_commercial_social_remixing()) ) no_commercial_license_terms_id = license_register_response["license_terms_id"] return no_commercial_license_terms_id @@ -1234,24 +1239,9 @@ def test_register_ip_asset_minted_with_license_terms( ), license_terms_data=[ LicenseTermsDataInput( - terms=LicenseTermsInput( - transferable=True, - royalty_policy=ROYALTY_POLICY, - default_minting_fee=10000, - expiration=1000, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data=ZERO_HASH, - commercial_rev_share=10, - commercial_rev_ceiling=0, - derivatives_allowed=True, - derivatives_attribution=True, - derivatives_approval=False, - derivatives_reciprocal=True, - derivative_rev_ceiling=0, + terms=PILFlavor.commercial_use( + default_minting_fee=10, currency=WIP_TOKEN_ADDRESS, - uri="test-minted-license-terms", ), licensing_config=LicensingConfig( is_set=True, @@ -1300,24 +1290,10 @@ def test_register_ip_asset_minted_with_license_terms_and_royalty_shares( ), license_terms_data=[ LicenseTermsDataInput( - terms=LicenseTermsInput( - transferable=True, - royalty_policy=ROYALTY_POLICY, - default_minting_fee=10000, - expiration=1000, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data=ZERO_HASH, - commercial_rev_share=10, - commercial_rev_ceiling=0, - derivatives_allowed=True, - derivatives_attribution=True, - derivatives_approval=False, - derivatives_reciprocal=True, - derivative_rev_ceiling=0, + terms=PILFlavor.commercial_remix( + default_minting_fee=10, currency=WIP_TOKEN_ADDRESS, - uri="test-minted-license-terms-with-royalty", + commercial_rev_share=10, ), licensing_config=LicensingConfig( is_set=True, @@ -1382,24 +1358,8 @@ def test_register_ip_asset_mint_with_license_terms( ), license_terms_data=[ LicenseTermsDataInput( - terms=LicenseTermsInput( - transferable=True, - royalty_policy=ROYALTY_POLICY, - default_minting_fee=10000, - expiration=1000, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data=ZERO_HASH, - commercial_rev_share=10, - commercial_rev_ceiling=0, - derivatives_allowed=True, - derivatives_attribution=True, - derivatives_approval=False, - derivatives_reciprocal=True, - derivative_rev_ceiling=0, + terms=PILFlavor.creative_commons_attribution( currency=WIP_TOKEN_ADDRESS, - uri="test-mint-license-terms", ), licensing_config=LicensingConfig( is_set=True, @@ -1444,24 +1404,13 @@ def test_register_ip_asset_mint_with_license_terms_and_royalty_shares( ), license_terms_data=[ LicenseTermsDataInput( - terms=LicenseTermsInput( - transferable=True, - royalty_policy=ROYALTY_POLICY, - default_minting_fee=10000, - expiration=1000, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data=ZERO_HASH, - commercial_rev_share=10, - commercial_rev_ceiling=0, - derivatives_allowed=True, - derivatives_attribution=True, - derivatives_approval=False, - derivatives_reciprocal=True, - derivative_rev_ceiling=0, - currency=WIP_TOKEN_ADDRESS, - uri="test-mint-license-terms-with-royalty", + terms=PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride( + commercial_use=True, + commercial_attribution=True, + royalty_policy=NativeRoyaltyPolicy.LRP, + currency=WIP_TOKEN_ADDRESS, + ), ), licensing_config=LicensingConfig( is_set=True, @@ -1505,24 +1454,10 @@ def parent_ip_with_commercial_license( ), license_terms_data=[ LicenseTermsDataInput( - terms=LicenseTermsInput( - transferable=True, - royalty_policy=ROYALTY_POLICY, + terms=PILFlavor.commercial_remix( default_minting_fee=0, - expiration=0, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data=ZERO_HASH, commercial_rev_share=10, - commercial_rev_ceiling=0, - derivatives_allowed=True, - derivatives_attribution=True, - derivatives_approval=False, - derivatives_reciprocal=True, - derivative_rev_ceiling=0, currency=WIP_TOKEN_ADDRESS, - uri="test-parent-license-for-derivative", ), licensing_config=LicensingConfig( is_set=True, diff --git a/tests/integration/test_integration_license.py b/tests/integration/test_integration_license.py index 881ffe0..6398adf 100644 --- a/tests/integration/test_integration_license.py +++ b/tests/integration/test_integration_license.py @@ -336,7 +336,7 @@ def test_set_licensing_config( minting_fee=100, is_set=True, licensing_hook=ZERO_ADDRESS, - hook_data=b"", + hook_data="test", commercial_rev_share=100, disabled=False, expect_minimum_group_reward_share=10, @@ -361,7 +361,7 @@ def test_get_licensing_config( is_set=True, minting_fee=100, licensing_hook=ZERO_ADDRESS, - hook_data=b"", + hook_data=b"test", disabled=False, expect_minimum_group_reward_share=10 * 10**6, expect_group_reward_pool=ZERO_ADDRESS, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e5389e8..f1de4cc 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -48,16 +48,6 @@ def create_mock_contract(*args, **kwargs): return mock_web3 -@pytest.fixture(scope="package") -def mock_is_checksum_address(): - def _mock(is_checksum_address: bool = True): - return patch.object( - Web3, "is_checksum_address", return_value=is_checksum_address - ) - - return _mock - - @pytest.fixture(scope="package") def mock_signature_related_methods(): class SignatureMockContext: @@ -71,10 +61,6 @@ def __enter__(self): mock_contract.encode_abi = MagicMock(return_value=b"encoded_data") mock_client.contract = mock_contract - # Create all the patches - mock_web3_to_bytes = patch.object( - Web3, "to_bytes", return_value=b"mock_bytes" - ) mock_account_sign_message = patch.object( Account, "sign_message", @@ -95,13 +81,11 @@ def __init__(self, web3, contract_address=None): ) # Apply all patches at once - mock_web3_to_bytes.start() mock_account_sign_message.start() mock_ip_account_client.start() # Store patches for cleanup self.patches = [ - mock_web3_to_bytes, mock_account_sign_message, mock_ip_account_client, ] diff --git a/tests/unit/fixtures/data.py b/tests/unit/fixtures/data.py index bfca21d..136ded9 100644 --- a/tests/unit/fixtures/data.py +++ b/tests/unit/fixtures/data.py @@ -21,6 +21,7 @@ "expiration": 100, "commercial_use": True, "commercial_attribution": True, + "commercial_rev_ceiling": 0, "commercializer_checker": True, "commercializer_checker_data": ADDRESS, "derivatives_allowed": True, @@ -36,7 +37,7 @@ "is_set": True, "minting_fee": 10, "licensing_hook": ADDRESS, - "hook_data": ADDRESS, + "hook_data": "test", "commercial_rev_share": 10, "disabled": False, "expect_minimum_group_reward_share": 10, diff --git a/tests/unit/resources/test_ip_asset.py b/tests/unit/resources/test_ip_asset.py index 16ca846..0ed3785 100644 --- a/tests/unit/resources/test_ip_asset.py +++ b/tests/unit/resources/test_ip_asset.py @@ -6,8 +6,14 @@ from story_protocol_python_sdk import ( MAX_ROYALTY_TOKEN, + LicenseTermsDataInput, + LicenseTermsOverride, + LicensingConfig, MintedNFT, MintNFT, + NativeRoyaltyPolicy, + PILFlavor, + PILFlavorError, RoyaltyShareInput, ) from story_protocol_python_sdk.abi.IPAccountImpl.IPAccountImpl_client import ( @@ -193,6 +199,30 @@ def test_register_with_metadata( assert result["ip_id"] == IP_ID +@pytest.fixture(scope="class") +def mock_is_whitelisted_royalty_policy(ip_asset): + def _mock(is_whitelisted: bool = True): + return patch.object( + ip_asset.royalty_module_client, + "isWhitelistedRoyaltyPolicy", + return_value=is_whitelisted, + ) + + return _mock + + +@pytest.fixture(scope="class") +def mock_is_whitelisted_royalty_token(ip_asset): + def _mock(is_whitelisted: bool = True): + return patch.object( + ip_asset.royalty_module_client, + "isWhitelistedRoyaltyToken", + return_value=is_whitelisted, + ) + + return _mock + + class TestRegisterDerivativeIp: def test_ip_is_already_registered( self, ip_asset, mock_get_ip_id, mock_is_registered @@ -298,7 +328,7 @@ def test_royalty_policy_commercial_rev_share_is_less_than_0( ): with mock_get_ip_id(), mock_is_registered(): with pytest.raises( - ValueError, match="commercial_rev_share should be between 0 and 100." + PILFlavorError, match="commercial_rev_share must be between 0 and 100." ): ip_asset.register_ip_and_attach_pil_terms( nft_contract=ADDRESS, @@ -334,7 +364,6 @@ def test_transaction_to_be_called_with_correct_parameters( ip_asset.license_attachment_workflows_client, "build_registerIpAndAttachPILTerms_transaction", ) as mock_build_registerIpAndAttachPILTerms_transaction: - ip_asset.register_ip_and_attach_pil_terms( nft_contract=ADDRESS, token_id=3, @@ -359,7 +388,7 @@ def test_transaction_to_be_called_with_correct_parameters( "commercialUse": True, "commercialAttribution": True, "commercializerChecker": True, - "commercializerCheckerData": b"mock_bytes", + "commercializerCheckerData": "0x1234567890123456789012345678901234567890", "commercialRevShare": 19000000, "commercialRevCeiling": 0, "derivativesAllowed": True, @@ -373,7 +402,7 @@ def test_transaction_to_be_called_with_correct_parameters( "licensingConfig": { "isSet": True, "mintingFee": 10, - "hookData": b"mock_bytes", + "hookData": Web3.to_bytes(text="test"), "licensingHook": "0x1234567890123456789012345678901234567890", "commercialRevShare": 10000000, "disabled": False, @@ -1003,7 +1032,7 @@ def test_successful_registration( "commercialUse": True, "commercialAttribution": True, "commercializerChecker": True, - "commercializerCheckerData": b"mock_bytes", + "commercializerCheckerData": "0x1234567890123456789012345678901234567890", "commercialRevShare": 19000000, "commercialRevCeiling": 0, "derivativesAllowed": True, @@ -1017,7 +1046,7 @@ def test_successful_registration( "licensingConfig": { "isSet": True, "mintingFee": 10, - "hookData": b"mock_bytes", + "hookData": Web3.to_bytes(text="test"), "licensingHook": "0x1234567890123456789012345678901234567890", "commercialRevShare": 10000000, "disabled": False, @@ -2172,12 +2201,40 @@ def test_success_when_license_terms_data_and_royalty_shares_provided_for_minted_ assert result["royalty_vault"] == royalty_vault assert result["distribute_royalty_tokens_tx_hash"] == TX_HASH.hex() + def test_throw_error_when_royalty_policy_is_not_whitelisted( + self, + ip_asset: IPAsset, + mock_is_whitelisted_royalty_policy, + ): + with mock_is_whitelisted_royalty_policy(False): + with pytest.raises( + ValueError, match="The royalty_policy is not whitelisted." + ): + ip_asset.register_ip_asset( + nft=MintNFT(type="mint", spg_nft_contract=ADDRESS), + license_terms_data=LICENSE_TERMS_DATA, + ) + + def test_throw_error_when_currency_is_not_whitelisted( + self, + ip_asset: IPAsset, + mock_is_whitelisted_royalty_token, + ): + with mock_is_whitelisted_royalty_token(False): + with pytest.raises(ValueError, match="The currency is not whitelisted."): + ip_asset.register_ip_asset( + nft=MintNFT(type="mint", spg_nft_contract=ADDRESS), + license_terms_data=LICENSE_TERMS_DATA, + ) + def test_success_when_license_terms_data_and_royalty_shares_provided_for_mint_nft( self, ip_asset: IPAsset, mock_parse_ip_registered_event, mock_parse_tx_license_terms_attached_event, mock_get_royalty_vault_address_by_ip_id, + mock_is_whitelisted_royalty_policy, + mock_is_whitelisted_royalty_token, ): royalty_shares = [ RoyaltyShareInput(recipient=ACCOUNT_ADDRESS, percentage=50.0), @@ -2188,6 +2245,8 @@ def test_success_when_license_terms_data_and_royalty_shares_provided_for_mint_nf mock_parse_ip_registered_event(), mock_parse_tx_license_terms_attached_event(), mock_get_royalty_vault_address_by_ip_id(royalty_vault), + mock_is_whitelisted_royalty_policy(True), + mock_is_whitelisted_royalty_token(True), patch.object( ip_asset.royalty_token_distribution_workflows_client, "build_mintAndRegisterIpAndAttachPILTermsAndDistributeRoyaltyTokens_transaction", @@ -2196,7 +2255,30 @@ def test_success_when_license_terms_data_and_royalty_shares_provided_for_mint_nf ): result = ip_asset.register_ip_asset( nft=MintNFT(type="mint", spg_nft_contract=ADDRESS), - license_terms_data=LICENSE_TERMS_DATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride( + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + royalty_policy=ZERO_ADDRESS, + uri="https://example.com", + ), + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data="test", + commercial_rev_share=10, + disabled=False, + expect_minimum_group_reward_share=1, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], royalty_shares=royalty_shares, ) assert ( @@ -2209,6 +2291,39 @@ def test_success_when_license_terms_data_and_royalty_shares_provided_for_mint_nf mock_build_register_transaction.call_args[0][2] == IPMetadata.from_input().get_validated_data() ) # ip_metadata + assert mock_build_register_transaction.call_args[0][3] == [ + { + "terms": { + "transferable": True, + "commercialAttribution": False, + "commercialRevCeiling": 0, + "commercialRevShare": 0, + "commercialUse": False, + "currency": ZERO_ADDRESS, + "derivativeRevCeiling": 0, + "derivativesAllowed": True, + "derivativesApproval": False, + "derivativesAttribution": True, + "derivativesReciprocal": True, + "expiration": 0, + "defaultMintingFee": 0, + "royaltyPolicy": ZERO_ADDRESS, + "commercializerChecker": ZERO_ADDRESS, + "commercializerCheckerData": ZERO_ADDRESS, + "uri": "https://example.com", + }, + "licensingConfig": { + "isSet": True, + "mintingFee": 10, + "hookData": Web3.to_bytes(text="test"), + "licensingHook": ZERO_ADDRESS, + "commercialRevShare": 10 * 10**6, + "disabled": False, + "expectMinimumGroupRewardShare": 1 * 10**6, + "expectGroupRewardPool": ZERO_ADDRESS, + }, + }, + ] # license_terms_data assert ( mock_build_register_transaction.call_args[0][4] == royalty_shares_obj["royalty_shares"] @@ -2250,7 +2365,33 @@ def test_success_when_license_terms_data_royalty_shares_and_all_optional_paramet allow_duplicates=False, recipient=ADDRESS, ), - license_terms_data=LICENSE_TERMS_DATA, + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.creative_commons_attribution( + currency=ADDRESS, + override=LicenseTermsOverride( + commercial_attribution=True, + derivatives_allowed=False, + derivatives_attribution=False, + derivatives_approval=False, + derivatives_reciprocal=False, + royalty_policy=ADDRESS, + commercial_rev_share=12, + commercializer_checker="0x", + ), + ), + licensing_config=LicensingConfig( + is_set=False, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data="test", + commercial_rev_share=10, + disabled=False, + expect_minimum_group_reward_share=11, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], royalty_shares=royalty_shares, ip_metadata=IP_METADATA, ) @@ -2261,6 +2402,40 @@ def test_success_when_license_terms_data_royalty_shares_and_all_optional_paramet mock_build_register_transaction.call_args[0][2] == IPMetadata.from_input(IP_METADATA).get_validated_data() ) # ip_metadata + + assert mock_build_register_transaction.call_args[0][3] == [ + { + "terms": { + "transferable": True, + "commercialAttribution": True, + "commercialRevCeiling": 0, + "commercialRevShare": 12 * 10**6, + "commercialUse": True, + "currency": ADDRESS, + "derivativeRevCeiling": 0, + "derivativesAllowed": False, + "derivativesApproval": False, + "derivativesAttribution": False, + "derivativesReciprocal": False, + "expiration": 0, + "defaultMintingFee": 0, + "royaltyPolicy": ADDRESS, + "commercializerChecker": "0x", + "commercializerCheckerData": ZERO_ADDRESS, + "uri": "https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/CC-BY.json", + }, + "licensingConfig": { + "isSet": False, + "mintingFee": 10, + "hookData": Web3.to_bytes(text="test"), + "licensingHook": ZERO_ADDRESS, + "commercialRevShare": 10 * 10**6, + "disabled": False, + "expectMinimumGroupRewardShare": 11 * 10**6, + "expectGroupRewardPool": ZERO_ADDRESS, + }, + }, + ] # license_terms_data assert ( mock_build_register_transaction.call_args[0][5] is False ) # allow_duplicates @@ -2302,6 +2477,87 @@ def test_success_when_license_terms_data_provided_for_mint_nft( assert result["token_id"] == 3 assert result["license_terms_ids"] is not None + def test_success_when_license_terms_data_is_commercial_remix_for_mint_nft( + self, + ip_asset: IPAsset, + mock_parse_ip_registered_event, + mock_parse_tx_license_terms_attached_event, + ): + with ( + mock_parse_ip_registered_event(), + mock_parse_tx_license_terms_attached_event(), + patch.object( + ip_asset.license_attachment_workflows_client, + "build_mintAndRegisterIpAndAttachPILTerms_transaction", + return_value={"tx_hash": TX_HASH.hex()}, + ) as mock_build_register_transaction, + ): + ip_asset.register_ip_asset( + nft=MintNFT(type="mint", spg_nft_contract=ADDRESS), + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_remix( + default_minting_fee=10, + currency=ADDRESS, + commercial_rev_share=10, + override=LicenseTermsOverride( + commercial_attribution=False, + derivatives_allowed=False, + derivatives_attribution=False, + derivatives_approval=False, + derivatives_reciprocal=False, + royalty_policy=ADDRESS, + commercial_rev_share=12, + ), + ), + licensing_config=LicensingConfig( + is_set=True, + minting_fee=10, + licensing_hook=ZERO_ADDRESS, + hook_data="test", + commercial_rev_share=10, + disabled=False, + expect_minimum_group_reward_share=11, + expect_group_reward_pool=ZERO_ADDRESS, + ), + ), + ], + ) + assert mock_build_register_transaction.call_args[0][3] == [ + { + "terms": { + "transferable": True, + "commercialAttribution": False, + "commercialRevCeiling": 0, + "commercialRevShare": 12 * 10**6, + "commercialUse": True, + "commercializerChecker": ZERO_ADDRESS, + "commercializerCheckerData": ZERO_ADDRESS, + "currency": ADDRESS, + "derivativeRevCeiling": 0, + "derivativesAllowed": False, + "derivativesApproval": False, + "derivativesAttribution": False, + "derivativesReciprocal": False, + "expiration": 0, + "defaultMintingFee": 10, + "royaltyPolicy": ADDRESS, + "uri": "https://github.com/piplabs/pil-document/blob/ad67bb632a310d2557f8abcccd428e4c9c798db1/off-chain-terms/CommercialRemix.json", + }, + "licensingConfig": { + "isSet": True, + "mintingFee": 10, + "hookData": Web3.to_bytes(text="test"), + "licensingHook": ZERO_ADDRESS, + "commercialRevShare": 10 * 10**6, + "disabled": False, + "expectMinimumGroupRewardShare": 11 * 10**6, + "expectGroupRewardPool": ZERO_ADDRESS, + }, + } + ] + # license_terms_data + def test_success_when_license_terms_data_and_all_optional_parameters_provided_for_mint_nft( self, ip_asset: IPAsset, @@ -2377,6 +2633,87 @@ def test_success_when_license_terms_data_provided_for_minted_nft( assert result["token_id"] == 3 assert result["license_terms_ids"] is not None + def test_success_when_license_terms_data_is_commercial_use_for_minted_nft( + self, + ip_asset: IPAsset, + mock_parse_ip_registered_event, + mock_parse_tx_license_terms_attached_event, + mock_get_ip_id, + mock_signature_related_methods, + mock_is_registered, + ): + with ( + mock_parse_ip_registered_event(), + mock_parse_tx_license_terms_attached_event(), + mock_get_ip_id(), + mock_signature_related_methods(), + mock_is_registered(is_registered=False), + patch.object( + ip_asset.license_attachment_workflows_client, + "build_registerIpAndAttachPILTerms_transaction", + return_value={"tx_hash": TX_HASH.hex()}, + ) as mock_build_register_transaction, + ): + ip_asset.register_ip_asset( + nft=MintedNFT(type="minted", nft_contract=ADDRESS, token_id=3), + license_terms_data=[ + LicenseTermsDataInput( + terms=PILFlavor.commercial_use( + default_minting_fee=10, + currency=ADDRESS, + override=LicenseTermsOverride( + commercial_rev_share=10, + royalty_policy=NativeRoyaltyPolicy.LAP, + ), + ), + licensing_config={ + "is_set": True, + "minting_fee": 10, + "licensing_hook": ADDRESS, + "hook_data": "11", + "commercial_rev_share": 10, + "disabled": False, + "expect_minimum_group_reward_share": 0, + "expect_group_reward_pool": ZERO_ADDRESS, + }, + ) + ], + ip_metadata=IP_METADATA, + ) + assert mock_build_register_transaction.call_args[0][3] == [ + { + "terms": { + "transferable": True, + "commercialAttribution": True, + "commercialRevCeiling": 0, + "commercialRevShare": 10 * 10**6, + "commercialUse": True, + "commercializerChecker": ZERO_ADDRESS, + "commercializerCheckerData": ZERO_ADDRESS, + "currency": ADDRESS, + "derivativeRevCeiling": 0, + "derivativesAllowed": False, + "derivativesApproval": False, + "derivativesAttribution": False, + "derivativesReciprocal": False, + "expiration": 0, + "defaultMintingFee": 10, + "royaltyPolicy": "0xBe54FB168b3c982b7AaE60dB6CF75Bd8447b390E", + "uri": "https://github.com/piplabs/pil-document/blob/9a1f803fcf8101a8a78f1dcc929e6014e144ab56/off-chain-terms/CommercialUse.json", + }, + "licensingConfig": { + "isSet": True, + "mintingFee": 10, + "hookData": Web3.to_bytes(text="11"), + "licensingHook": ADDRESS, + "commercialRevShare": 10 * 10**6, + "disabled": False, + "expectMinimumGroupRewardShare": 0, + "expectGroupRewardPool": ZERO_ADDRESS, + }, + } + ] + def test_success_when_ip_metadata_provided_for_minted_nft( self, ip_asset: IPAsset, diff --git a/tests/unit/resources/test_license.py b/tests/unit/resources/test_license.py index 23ce92e..73c47d2 100644 --- a/tests/unit/resources/test_license.py +++ b/tests/unit/resources/test_license.py @@ -1,3 +1,4 @@ +from dataclasses import asdict, replace from typing import Callable from unittest.mock import patch @@ -5,9 +6,15 @@ from _pytest.fixtures import fixture from web3 import Web3 -from story_protocol_python_sdk.resources.License import License -from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS -from story_protocol_python_sdk.utils.licensing_config_data import LicensingConfig +from story_protocol_python_sdk import ( + WIP_TOKEN_ADDRESS, + ZERO_ADDRESS, + License, + LicensingConfig, + PILFlavor, + PILFlavorError, +) +from story_protocol_python_sdk.utils.util import convert_dict_keys_to_camel_case from tests.unit.fixtures.data import ADDRESS, CHAIN_ID, IP_ID, TX_HASH from tests.unit.resources.test_ip_account import ZERO_HASH @@ -24,33 +31,23 @@ def test_register_pil_terms_license_terms_id_registered(self, license: License): with patch.object( license.license_template_client, "getLicenseTermsId", return_value=1 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyToken", return_value=True, ): response = license.register_pil_terms( - default_minting_fee=1513, - currency=ADDRESS, - royalty_policy=ADDRESS, - transferable=False, - expiration=0, - commercial_use=True, - commercial_attribution=False, - commercializer_checker=ZERO_ADDRESS, - commercializer_checker_data="0x", - commercial_rev_share=0, - commercial_rev_ceiling=0, - derivatives_allowed=False, - derivatives_attribution=False, - derivatives_approval=False, - derivatives_reciprocal=False, - derivative_rev_ceiling=0, - uri="", + **asdict( + PILFlavor.commercial_use( + default_minting_fee=1513, + currency=ADDRESS, + royalty_policy=ADDRESS, + ) + ) ) assert response["license_terms_id"] == 1 assert "tx_hash" not in response @@ -59,11 +56,11 @@ def test_register_pil_terms_success(self, license: License): with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyToken", return_value=True, ), patch.object( @@ -75,7 +72,7 @@ def test_register_pil_terms_success(self, license: License): "gas": 2000000, "gasPrice": Web3.to_wei("100", "gwei"), }, - ): + ) as mock_build_registerLicenseTerms_transaction: response = license.register_pil_terms( transferable=False, @@ -96,7 +93,12 @@ def test_register_pil_terms_success(self, license: License): currency=ADDRESS, uri="", ) - + assert ( + mock_build_registerLicenseTerms_transaction.call_args[0][0][ + "commercialRevShare" + ] + == 90 * 10**6 + ) assert "tx_hash" in response assert response["tx_hash"] == TX_HASH.hex() assert isinstance(response["tx_hash"], str) @@ -107,17 +109,18 @@ def test_register_pil_terms_commercial_rev_share_error_more_than_100( with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyToken", return_value=True, ): with pytest.raises( - ValueError, match="commercial_rev_share should be between 0 and 100." + PILFlavorError, + match="commercial_rev_share must be between 0 and 100.", ): license.register_pil_terms( transferable=False, @@ -145,17 +148,17 @@ def test_register_pil_terms_commercial_rev_share_error_less_than_0( with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyToken", return_value=True, ): with pytest.raises( - ValueError, match="commercial_rev_share should be between 0 and 100." + PILFlavorError, match="commercial_rev_share must be between 0 and 100." ): license.register_pil_terms( transferable=False, @@ -177,6 +180,165 @@ def test_register_pil_terms_commercial_rev_share_error_less_than_0( uri="", ) + def test_register_non_commercial_social_remixing_pil_success( + self, license: License + ): + with patch.object( + license.license_template_client, "getLicenseTermsId", return_value=0 + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyPolicy", + return_value=True, + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyToken", + return_value=True, + ), patch.object( + license.license_template_client, + "build_registerLicenseTerms_transaction", + return_value={ + "from": ADDRESS, + "nonce": 1, + "gas": 2000000, + "gasPrice": Web3.to_wei("100", "gwei"), + }, + ) as mock_build_registerLicenseTerms_transaction: + + license.register_pil_terms( + **asdict(PILFlavor.non_commercial_social_remixing()) + ) + assert mock_build_registerLicenseTerms_transaction.call_args[0][ + 0 + ] == convert_dict_keys_to_camel_case( + asdict(PILFlavor.non_commercial_social_remixing()) + ) + + def test_register_commercial_remix_pil_success(self, license: License): + with patch.object( + license.license_template_client, "getLicenseTermsId", return_value=0 + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyPolicy", + return_value=True, + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyToken", + return_value=True, + ), patch.object( + license.license_template_client, + "build_registerLicenseTerms_transaction", + return_value={ + "from": ADDRESS, + "nonce": 1, + "gas": 2000000, + "gasPrice": Web3.to_wei("100", "gwei"), + }, + ) as mock_build_registerLicenseTerms_transaction: + + license.register_pil_terms( + **asdict( + PILFlavor.commercial_remix( + default_minting_fee=1513, + currency=ADDRESS, + commercial_rev_share=90, + ) + ) + ) + assert mock_build_registerLicenseTerms_transaction.call_args[0][ + 0 + ] == convert_dict_keys_to_camel_case( + asdict( + replace( + PILFlavor.commercial_remix( + default_minting_fee=1513, + currency=ADDRESS, + commercial_rev_share=90, + ), + commercial_rev_share=90 * 10**6, + ) + ) + ) + + def test_register_commercial_use_pil_success(self, license: License): + with patch.object( + license.license_template_client, "getLicenseTermsId", return_value=0 + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyPolicy", + return_value=True, + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyToken", + return_value=True, + ), patch.object( + license.license_template_client, + "build_registerLicenseTerms_transaction", + return_value={ + "from": ADDRESS, + "nonce": 1, + "gas": 2000000, + "gasPrice": Web3.to_wei("100", "gwei"), + }, + ) as mock_build_registerLicenseTerms_transaction: + + license.register_pil_terms( + **asdict( + PILFlavor.commercial_use( + default_minting_fee=1513, + currency=ADDRESS, + ) + ) + ) + assert mock_build_registerLicenseTerms_transaction.call_args[0][ + 0 + ] == convert_dict_keys_to_camel_case( + asdict( + PILFlavor.commercial_use( + default_minting_fee=1513, + currency=ADDRESS, + ) + ) + ) + + def test_register_creative_commons_attribution_pil_success(self, license: License): + with patch.object( + license.license_template_client, "getLicenseTermsId", return_value=0 + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyPolicy", + return_value=True, + ), patch.object( + license.royalty_module_client, + "isWhitelistedRoyaltyToken", + return_value=True, + ), patch.object( + license.license_template_client, + "build_registerLicenseTerms_transaction", + return_value={ + "from": ADDRESS, + "nonce": 1, + "gas": 2000000, + "gasPrice": Web3.to_wei("100", "gwei"), + }, + ) as mock_build_registerLicenseTerms_transaction: + + license.register_pil_terms( + **asdict( + PILFlavor.creative_commons_attribution( + currency=ADDRESS, + ) + ) + ) + assert mock_build_registerLicenseTerms_transaction.call_args[0][ + 0 + ] == convert_dict_keys_to_camel_case( + asdict( + PILFlavor.creative_commons_attribution( + currency=ADDRESS, + ) + ) + ) + class TestNonComSocialRemixingPIL: """Tests for non-commercial social remixing PIL functionality.""" @@ -218,7 +380,7 @@ def test_register_non_com_social_remixing_pil_error(self, license: License): with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( @@ -241,7 +403,9 @@ def test_register_commercial_use_pil_license_terms_id_registered( license.license_template_client, "getLicenseTermsId", return_value=1 ): response = license.register_commercial_use_pil( - default_minting_fee=1, currency=ZERO_ADDRESS + default_minting_fee=1, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, ) assert response["license_terms_id"] == 1 assert "tx_hash" not in response @@ -250,7 +414,7 @@ def test_register_commercial_use_pil_success_without_logs(self, license: License with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( @@ -267,7 +431,7 @@ def test_register_commercial_use_pil_success_without_logs(self, license: License ): response = license.register_commercial_use_pil( - default_minting_fee=1, currency=ZERO_ADDRESS + default_minting_fee=1, currency=WIP_TOKEN_ADDRESS ) assert response is not None assert "tx_hash" in response @@ -278,7 +442,7 @@ def test_register_commercial_use_pil_error(self, license: License): with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( @@ -288,7 +452,9 @@ def test_register_commercial_use_pil_error(self, license: License): ): with pytest.raises(Exception, match="request fail."): license.register_commercial_use_pil( - default_minting_fee=1, currency=ZERO_ADDRESS + default_minting_fee=1, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, ) @@ -301,15 +467,15 @@ def test_register_commercial_remix_pil_license_terms_id_registered( with patch.object( license.license_template_client, "getLicenseTermsId", return_value=1 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ): response = license.register_commercial_remix_pil( default_minting_fee=1, commercial_rev_share=100, - currency=ZERO_ADDRESS, - royalty_policy=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, ) assert response["license_terms_id"] == 1 assert "tx_hash" not in response @@ -318,7 +484,7 @@ def test_register_commercial_remix_pil_success(self, license: License): with patch.object( license.license_template_client, "getLicenseTermsId", return_value=0 ), patch.object( - license.license_terms_util.royalty_module_client, + license.royalty_module_client, "isWhitelistedRoyaltyPolicy", return_value=True, ), patch.object( @@ -337,8 +503,8 @@ def test_register_commercial_remix_pil_success(self, license: License): response = license.register_commercial_remix_pil( default_minting_fee=1, commercial_rev_share=100, - currency=ZERO_ADDRESS, - royalty_policy=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, ) assert response is not None assert "tx_hash" in response @@ -624,7 +790,7 @@ def default_licensing_config() -> LicensingConfig: "is_set": True, "minting_fee": 1, "licensing_hook": ZERO_ADDRESS, - "hook_data": "0x", + "hook_data": ZERO_HASH, "commercial_rev_share": 0, "disabled": False, "expect_minimum_group_reward_share": 0, @@ -1014,7 +1180,7 @@ def test_set_licensing_config_success_with_custom_template( "isSet": True, "mintingFee": 1, "licensingHook": ZERO_ADDRESS, - "hookData": "0x", + "hookData": ZERO_HASH, "commercialRevShare": 0, "disabled": False, "expectMinimumGroupRewardShare": 0, diff --git a/tests/unit/utils/test_derivative_data.py b/tests/unit/utils/test_derivative_data.py index 1162df4..0551c51 100644 --- a/tests/unit/utils/test_derivative_data.py +++ b/tests/unit/utils/test_derivative_data.py @@ -114,30 +114,26 @@ def test_validate_parent_ip_ids_and_license_terms_ids_are_not_equal( license_template="0x1234567890123456789012345678901234567890", ) - def test_validate_parent_ip_ids_is_not_valid_address( - self, mock_web3, mock_is_checksum_address - ): - with mock_is_checksum_address(is_checksum_address=False): - with raises(ValueError, match="The parent IP ID must be a valid address."): - DerivativeData( - web3=mock_web3, - parent_ip_ids=["0x1234567890123456789012345678901234567890"], - license_terms_ids=[2], - max_minting_fee=10, - max_rts=10, - max_revenue_share=100, - license_template="0x1234567890123456789012345678901234567890", - ) + def test_validate_parent_ip_ids_is_not_valid_address(self, mock_web3): + with raises( + ValueError, match="Invalid address: 0x12345678901234567890123901234567890." + ): + DerivativeData( + web3=mock_web3, + parent_ip_ids=["0x12345678901234567890123901234567890"], + license_terms_ids=[2], + max_minting_fee=10, + max_rts=10, + max_revenue_share=100, + license_template="0x1234567890123456789012345678901234567890", + ) def test_validate_parent_ip_ids_is_not_registered( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client( - is_registered=False - ): + with mock_ip_asset_registry_client(is_registered=False): with raises( ValueError, match=f"The parent IP ID {IP_ID} must be registered.", @@ -155,11 +151,10 @@ def test_validate_parent_ip_ids_is_not_registered( def test_validate_license_terms_not_attached( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client( + with mock_ip_asset_registry_client( is_registered=True ), mock_license_registry_client(has_ip_attached_license_terms=False): with raises( @@ -179,11 +174,10 @@ def test_validate_license_terms_not_attached( def test_validate_royalty_percent_exceeds_max_revenue_share( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client( + with mock_ip_asset_registry_client( is_registered=True ), mock_license_registry_client( has_ip_attached_license_terms=True, get_royalty_percent=1500000000000 @@ -205,11 +199,10 @@ def test_validate_royalty_percent_exceeds_max_revenue_share( def test_validate_royalty_percent_is_less_than_max_revenue_share( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData.from_input( web3=mock_web3, input_data=DerivativeDataInput( @@ -227,11 +220,10 @@ class TestValidateMaxMintingFee: def test_validate_max_minting_fee_is_less_than_0( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): with raises( ValueError, match="The max minting fee must be greater than 0." ): @@ -250,11 +242,10 @@ class TestValidateMaxRts: def test_validate_max_rts_is_less_than_0( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): with raises( ValueError, match="The maxRts must be greater than 0 and less than 100000000.", @@ -288,11 +279,10 @@ def test_validate_max_rts_is_greater_than_100_000_000( def test_validate_max_rts_default_value_is_max_rts( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData.from_input( web3=mock_web3, input_data=DerivativeDataInput( @@ -325,11 +315,10 @@ def test_validate_max_revenue_share_is_less_than_0( def test_validate_max_revenue_share_is_greater_than_100( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): with raises( ValueError, match="max_revenue_share must be between 0 and 100." ): @@ -347,11 +336,10 @@ def test_validate_max_revenue_share_is_greater_than_100( def test_validate_max_revenue_share_default_value_is_100( self, mock_web3, - mock_is_checksum_address, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData.from_input( web3=mock_web3, input_data=DerivativeDataInput( @@ -366,12 +354,11 @@ class TestValidateLicenseTemplate: def test_validate_license_template_default_value_is_pi_license_template( self, mock_web3, - mock_is_checksum_address, mock_pi_license_template_client, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData.from_input( web3=mock_web3, input_data=DerivativeDataInput( @@ -386,12 +373,11 @@ class TestGetValidatedData: def test_get_validated_data_with_default_values( self, mock_web3, - mock_is_checksum_address, mock_pi_license_template_client, mock_ip_asset_registry_client, mock_license_registry_client, ): - with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData.from_input( web3=mock_web3, input_data=DerivativeDataInput( @@ -415,9 +401,8 @@ def test_get_validated_data_with_custom_values( mock_ip_asset_registry_client, mock_license_registry_client, mock_pi_license_template_client, - mock_is_checksum_address, ): - with mock_is_checksum_address(), mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): + with mock_pi_license_template_client(), mock_ip_asset_registry_client(), mock_license_registry_client(): derivative_data = DerivativeData( web3=mock_web3, parent_ip_ids=[IP_ID], diff --git a/tests/unit/utils/test_licensing_config_data.py b/tests/unit/utils/test_licensing_config_data.py index 46ea42d..42918be 100644 --- a/tests/unit/utils/test_licensing_config_data.py +++ b/tests/unit/utils/test_licensing_config_data.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, Mock import pytest +from web3 import Web3 from story_protocol_python_sdk.utils.constants import ZERO_ADDRESS, ZERO_HASH from story_protocol_python_sdk.utils.licensing_config_data import ( @@ -44,7 +45,37 @@ def test_validate_license_config_valid_input(self, mock_module_registry_client): "is_set": True, "minting_fee": 100, "licensing_hook": ZERO_ADDRESS, - "hook_data": "0xabcdef", + "hook_data": ZERO_HASH, + "commercial_rev_share": 50, + "disabled": False, + "expect_minimum_group_reward_share": 25, + "expect_group_reward_pool": ZERO_ADDRESS, + } + + result = LicensingConfigData.validate_license_config( + mock_module_registry_client(), input_config + ) + + assert result == ValidatedLicensingConfig( + isSet=True, + mintingFee=100, + licensingHook=ZERO_ADDRESS, + hookData=ZERO_HASH, + commercialRevShare=50 * 10**6, + disabled=False, + expectMinimumGroupRewardShare=25 * 10**6, + expectGroupRewardPool=ZERO_ADDRESS, + ) + + def test_validate_license_config_valid_input_with_custom_hook_data( + self, mock_module_registry_client + ): + """Test validate_license_config with valid input.""" + input_config: LicensingConfig = { + "is_set": True, + "minting_fee": 100, + "licensing_hook": ZERO_ADDRESS, + "hook_data": "test", "commercial_rev_share": 50, "disabled": False, "expect_minimum_group_reward_share": 25, @@ -59,7 +90,7 @@ def test_validate_license_config_valid_input(self, mock_module_registry_client): isSet=True, mintingFee=100, licensingHook=ZERO_ADDRESS, - hookData="0xabcdef", + hookData=Web3.to_bytes(text="test"), commercialRevShare=50 * 10**6, disabled=False, expectMinimumGroupRewardShare=25 * 10**6, diff --git a/tests/unit/utils/test_pil_flavor.py b/tests/unit/utils/test_pil_flavor.py new file mode 100644 index 0000000..35d92b5 --- /dev/null +++ b/tests/unit/utils/test_pil_flavor.py @@ -0,0 +1,527 @@ +import pytest + +from story_protocol_python_sdk import ( + ROYALTY_POLICY_LAP_ADDRESS, + ROYALTY_POLICY_LRP_ADDRESS, + WIP_TOKEN_ADDRESS, + ZERO_ADDRESS, + LicenseTermsInput, + LicenseTermsOverride, + NativeRoyaltyPolicy, + PILFlavor, +) +from story_protocol_python_sdk.utils.pil_flavor import PILFlavorError +from tests.unit.fixtures.data import ADDRESS + + +class TestPILFlavor: + """Test PILFlavor class.""" + + class TestNonCommercialSocialRemixing: + """Test non commercial social remixing PIL flavor.""" + + def test_default_values(self): + """Test default values.""" + pil_flavor = PILFlavor.non_commercial_social_remixing() + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=False, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=False, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=ZERO_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=0, + royalty_policy=ZERO_ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/NCSR.json", + ) + + def test_override_values(self): + """Test override values.""" + pil_flavor = PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride( + commercial_use=True, + commercial_attribution=True, + royalty_policy=NativeRoyaltyPolicy.LAP, + currency=WIP_TOKEN_ADDRESS, + ), + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=True, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=0, + royalty_policy=ROYALTY_POLICY_LAP_ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/NCSR.json", + ) + + def test_throw_commercial_attribution_error_when_commercial_use_is_false(self): + """Test throw commercial attribution error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add commercial_attribution when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride(commercial_attribution=True), + ) + + def test_throw_commercializer_checker_error_when_commercial_use_is_false(self): + """Test throw commercializer checker error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add commercializer_checker when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride(commercializer_checker=ADDRESS), + ) + + def test_throw_commercial_rev_share_error_when_commercial_use_is_false(self): + """Test throw commercial rev share error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add commercial_rev_share when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride(commercial_rev_share=10), + ) + + def test_throw_commercial_rev_ceiling_error_when_commercial_use_is_false(self): + """Test throw commercial rev ceiling error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add commercial_rev_ceiling when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride(commercial_rev_ceiling=10000), + ) + + def test_throw_derivative_rev_ceiling_error_when_commercial_use_is_false(self): + """Test throw derivative rev ceiling error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add derivative_rev_ceiling when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride(derivative_rev_ceiling=10000), + ) + + def test_throw_royalty_policy_error_when_commercial_use_is_false(self): + """Test throw royalty policy error when commercial use is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add royalty_policy when commercial_use is False.", + ): + PILFlavor.non_commercial_social_remixing( + override=LicenseTermsOverride( + royalty_policy=ADDRESS, currency=WIP_TOKEN_ADDRESS + ), + ) + + class TestCommercialUse: + """Test commercial use PIL flavor.""" + + def test_default_values(self): + """Test default values.""" + pil_flavor = PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=True, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=False, + derivatives_approval=False, + derivatives_attribution=False, + derivatives_reciprocal=False, + expiration=0, + default_minting_fee=10000, + royalty_policy=ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/9a1f803fcf8101a8a78f1dcc929e6014e144ab56/off-chain-terms/CommercialUse.json", + ) + + def test_without_royalty_policy(self): + """Test without royalty policy.""" + pil_flavor = PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=True, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=False, + derivatives_approval=False, + derivatives_attribution=False, + derivatives_reciprocal=False, + expiration=0, + default_minting_fee=10000, + royalty_policy=ROYALTY_POLICY_LAP_ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/9a1f803fcf8101a8a78f1dcc929e6014e144ab56/off-chain-terms/CommercialUse.json", + ) + + def test_with_custom_values(self): + """Test with custom values.""" + pil_flavor = PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + override=LicenseTermsOverride( + commercial_attribution=False, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri="https://example.com", + royalty_policy=NativeRoyaltyPolicy.LRP, + default_minting_fee=10, + commercial_rev_share=10, + ), + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=False, + commercial_rev_ceiling=0, + commercial_rev_share=10, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=10, + royalty_policy=ROYALTY_POLICY_LRP_ADDRESS, + uri="https://example.com", + ) + + def test_throw_error_when_royalty_policy_is_not_zero_address_and_currency_is_zero_address( + self, + ): + """Test throw error when royalty policy is not zero address and currency is zero address.""" + with pytest.raises( + PILFlavorError, + match="royalty_policy is not zero address and currency cannot be zero address.", + ): + PILFlavor.commercial_use( + default_minting_fee=10000, + currency=ZERO_ADDRESS, + royalty_policy=ADDRESS, + ) + + def test_throw_error_when_default_minting_fee_is_less_than_zero(self): + """Test throw error when default minting fee is less than zero.""" + with pytest.raises( + PILFlavorError, + match="default_minting_fee should be greater than or equal to 0.", + ): + PILFlavor.commercial_use( + default_minting_fee=-1, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + ) + + def test_not_throw_error_when_default_minting_fee_is_zero_and_royalty_policy_is_not_zero_address( + self, + ): + """Test not throw error when default minting fee is zero and royalty policy is not zero address.""" + pil_flavor = PILFlavor.commercial_use( + default_minting_fee=0, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + ) + assert pil_flavor.default_minting_fee == 0 + + def test_not_throw_error_when_default_minting_fee_is_100_(self): + """Test not throw error when default minting fee is 100""" + pil_flavor = PILFlavor.commercial_use( + default_minting_fee=100, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + ) + assert pil_flavor.default_minting_fee == 100 + + def test_throw_error_when_default_minting_fee_is_greater_than_zero_and_royalty_policy_is_zero_address( + self, + ): + """Test throw error when default minting fee is greater than zero and royalty policy is zero address.""" + with pytest.raises( + PILFlavorError, + match="royalty_policy is required when default_minting_fee is greater than 0.", + ): + PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ZERO_ADDRESS, + ) + + def test_throw_error_when_commercial_rev_share_is_less_than_zero(self): + """Test throw error when commercial rev share is less than zero.""" + with pytest.raises( + PILFlavorError, match="commercial_rev_share must be between 0 and 100." + ): + PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + override=LicenseTermsOverride(commercial_rev_share=-1), + ) + + def test_throw_error_when_commercial_rev_share_is_greater_than_100(self): + """Test throw error when commercial rev share is greater than 100.""" + with pytest.raises( + PILFlavorError, match="commercial_rev_share must be between 0 and 100." + ): + PILFlavor.commercial_use( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ADDRESS, + override=LicenseTermsOverride(commercial_rev_share=101), + ) + + def test_throw_error_when_commercial_is_true_and_royalty_policy_is_zero_address( + self, + ): + """Test throw error when commercial is true and royalty policy is zero address.""" + with pytest.raises( + PILFlavorError, + match="royalty_policy is required when commercial_use is True.", + ): + PILFlavor.commercial_use( + default_minting_fee=0, + currency=WIP_TOKEN_ADDRESS, + royalty_policy=ZERO_ADDRESS, + ) + + class TestCommercialRemix: + """Test commercial remix PIL flavor.""" + + def test_default_values(self): + """Test default values.""" + pil_flavor = PILFlavor.commercial_remix( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + commercial_rev_share=10, + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=True, + commercial_rev_ceiling=0, + commercial_rev_share=10, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=10000, + royalty_policy=ROYALTY_POLICY_LAP_ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/ad67bb632a310d2557f8abcccd428e4c9c798db1/off-chain-terms/CommercialRemix.json", + ) + + def test_with_custom_values(self): + """Test with custom values.""" + pil_flavor = PILFlavor.commercial_remix( + default_minting_fee=10000, + currency=WIP_TOKEN_ADDRESS, + commercial_rev_share=100, + override=LicenseTermsOverride( + commercial_attribution=False, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri="https://example.com", + royalty_policy=NativeRoyaltyPolicy.LRP, + default_minting_fee=10, + commercial_rev_share=10, + ), + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=False, + commercial_rev_ceiling=0, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + currency=WIP_TOKEN_ADDRESS, + commercial_rev_share=10, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + expiration=0, + default_minting_fee=10, + royalty_policy=ROYALTY_POLICY_LRP_ADDRESS, + uri="https://example.com", + ) + + class TestCreativeCommonsAttribution: + """Test creative commons attribution PIL flavor.""" + + def test_default_values(self): + """Test default values.""" + pil_flavor = PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=True, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=0, + royalty_policy=ROYALTY_POLICY_LAP_ADDRESS, + uri="https://github.com/piplabs/pil-document/blob/998c13e6ee1d04eb817aefd1fe16dfe8be3cd7a2/off-chain-terms/CC-BY.json", + ) + + def test_with_custom_values(self): + """Test with custom values.""" + pil_flavor = PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + override=LicenseTermsOverride( + commercial_attribution=False, + derivatives_allowed=True, + derivatives_attribution=True, + derivatives_approval=False, + derivatives_reciprocal=True, + uri="https://example.com", + royalty_policy=ADDRESS, + ), + ) + assert pil_flavor == LicenseTermsInput( + transferable=True, + commercial_attribution=False, + commercial_rev_ceiling=0, + commercial_rev_share=0, + commercial_use=True, + commercializer_checker=ZERO_ADDRESS, + commercializer_checker_data=ZERO_ADDRESS, + currency=WIP_TOKEN_ADDRESS, + derivative_rev_ceiling=0, + derivatives_allowed=True, + derivatives_approval=False, + derivatives_attribution=True, + derivatives_reciprocal=True, + expiration=0, + default_minting_fee=0, + royalty_policy=ADDRESS, + uri="https://example.com", + ) + + def test_throw_derivatives_attribution_error_when_derivatives_allowed_is_false( + self, + ): + """Test throw derivatives attribution error when derivatives allowed is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add derivatives_attribution when derivatives_allowed is False.", + ): + PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + override=LicenseTermsOverride( + derivatives_allowed=False, + ), + ) + + def test_throw_derivatives_approval_error_when_derivatives_allowed_is_false( + self, + ): + """Test throw derivatives approval error when derivatives allowed is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add derivatives_approval when derivatives_allowed is False.", + ): + PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + override=LicenseTermsOverride( + derivatives_allowed=False, + derivatives_approval=True, + derivatives_attribution=False, + ), + ) + + def test_throw_derivatives_reciprocal_error_when_derivatives_allowed_is_false( + self, + ): + """Test throw derivatives reciprocal error when derivatives allowed is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add derivatives_reciprocal when derivatives_allowed is False.", + ): + PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + override=LicenseTermsOverride( + derivatives_allowed=False, + derivatives_reciprocal=True, + derivatives_attribution=False, + derivatives_approval=False, + ), + ) + + def test_throw_derivative_rev_ceiling_error_when_derivatives_allowed_is_false( + self, + ): + """Test throw derivative rev ceiling error when derivatives allowed is false.""" + with pytest.raises( + PILFlavorError, + match="cannot add derivative_rev_ceiling when derivatives_allowed is False.", + ): + PILFlavor.creative_commons_attribution( + currency=WIP_TOKEN_ADDRESS, + override=LicenseTermsOverride( + derivatives_allowed=False, + derivative_rev_ceiling=10000, + derivatives_attribution=False, + derivatives_approval=False, + derivatives_reciprocal=False, + ), + ) diff --git a/tests/unit/utils/test_util.py b/tests/unit/utils/test_util.py new file mode 100644 index 0000000..94c491d --- /dev/null +++ b/tests/unit/utils/test_util.py @@ -0,0 +1,45 @@ +from story_protocol_python_sdk.utils.util import ( + convert_dict_keys_to_camel_case, + snake_to_camel, +) + + +class TestSnakeToCamel: + def test_single_word(self): + assert snake_to_camel("hello") == "hello" + + def test_two_words(self): + assert snake_to_camel("hello_world") == "helloWorld" + + def test_multiple_words(self): + assert snake_to_camel("this_is_a_test") == "thisIsATest" + + def test_empty_string(self): + assert snake_to_camel("") == "" + + def test_already_camel_case(self): + assert snake_to_camel("alreadyCamel") == "alreadyCamel" + + +class TestConvertDictKeysToCamelCase: + def test_single_key(self): + result = convert_dict_keys_to_camel_case({"hello_world": 1}) + assert result == {"helloWorld": 1} + + def test_multiple_keys(self): + result = convert_dict_keys_to_camel_case( + { + "first_key": 1, + "second_key": 2, + "third_key": 3, + } + ) + assert result == { + "firstKey": 1, + "secondKey": 2, + "thirdKey": 3, + } + + def test_empty_dict(self): + result = convert_dict_keys_to_camel_case({}) + assert result == {}