What needs to be done?
Currently, when a modelopt state file is loaded with modelopt.torch.opt.conversion.load_modelopt_state(), no validation is performed to check that the file actually contains a valid modelopt state object (a dictionary) with the expected structure.
Proposed Solution
Add validation logic inside load_modelopt_state (located in modelopt/torch/opt/conversion.py) to verify the schema of the loaded dictionary before returning it.
Specifically:
- Ensure the loaded object is a dictionary.
- Verify that it contains the expected keys (such as "modelopt_state", or whichever keys the expected schema defines).
- Raise a clear ValueError or TypeError if the file doesn't match the expected structure, rather than letting the problem surface later as a confusing downstream error.
Context
This addresses the existing TODO in modelopt/torch/opt/conversion.py:
```python
# TODO: Add some validation to ensure the file is a valid modelopt state file.
modelopt_state = torch.load(modelopt_state_path, **kwargs)
return modelopt_state
```
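A minimal sketch of how the TODO could be resolved, under stated assumptions. It assumes the saved state is a dict whose top-level keys are modelopt_state_dict and modelopt_version. This is an assumption and should be checked against how the state is actually saved. The names validate_modelopt_state and EXPECTED_KEYS are hypothetical. The load_modelopt_state body below only illustrates where the check would replace the TODO; it is not the library's current code.

```python
from typing import Any, Iterable

# Assumed top-level keys of a saved modelopt state; confirm against the real schema.
EXPECTED_KEYS = ("modelopt_state_dict", "modelopt_version")


def validate_modelopt_state(obj: Any, required_keys: Iterable[str] = EXPECTED_KEYS) -> dict:
    """Hypothetical helper: check the structure of a loaded modelopt state and return it."""
    # A modelopt state is expected to be a dictionary; reject anything else early.
    if not isinstance(obj, dict):
        raise TypeError(
            "Invalid modelopt state file: expected a dict, "
            f"got {type(obj).__name__}."
        )
    # Report every missing key at once instead of failing on a later lookup.
    missing = [key for key in required_keys if key not in obj]
    if missing:
        raise ValueError(
            f"Invalid modelopt state file: missing key(s) {missing}; "
            f"found {sorted(map(str, obj))}."
        )
    return obj


def load_modelopt_state(modelopt_state_path, **kwargs) -> dict:
    """Sketch of the proposed change: validate the loaded object before returning it."""
    import torch  # imported here so the validator above has no torch dependency

    modelopt_state = torch.load(modelopt_state_path, **kwargs)
    return validate_modelopt_state(modelopt_state)
```

Raising TypeError for a non-dict and ValueError for missing keys follows the split proposed in the issue. The required keys are a parameter, so the check can follow the real schema if it differs from the assumed one.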