From 556140e3e51d8690921862cfc6b81cdde142da23 Mon Sep 17 00:00:00 2001 From: Bartosz Cierocki Date: Sun, 3 Mar 2019 23:25:33 +0100 Subject: [PATCH] Infer type from default if provided but type omitted --- README.rst | 10 ++++++++++ smart_getenv.py | 13 +++++++------ tests.py | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index de9929d..52f028b 100644 --- a/README.rst +++ b/README.rst @@ -38,14 +38,24 @@ Get them: 'true' >>> getenv('BOOLEAN', type=bool) True + >>> getenv('BOOLEAN', default=False) + True >>> getenv('LIST', type=list) ['a', 'b', 'c'] + >>> getenv('LIST', default=[1, 2, 3]) + ['a', 'b', 'c'] >>> getenv('LIST', type=tuple) ('a', 'b', 'c') + >>> getenv('LIST', default=(1, 2, 3)) + ('a', 'b', 'c') >>> getenv('TRICKY_LIST', type=list, separator=':') ['d', 'e', 'f'] + >>> getenv('TRICKY_LIST', default=[1, 2, 3], separator=':') + ['d', 'e', 'f'] >>> getenv('DICT', type=dict) {'foo': 'bar'} + >>> getenv('DICT', default={'key': 'value'}) + {'foo': 'bar'} >>> getenv('LOST', default='default value anyone?') 'default value anyone?' diff --git a/smart_getenv.py b/smart_getenv.py index 3953e15..c8e2912 100644 --- a/smart_getenv.py +++ b/smart_getenv.py @@ -1,8 +1,7 @@ import os from ast import literal_eval - -__version__ = '1.1.0' +__version__ = '1.2.0' def getenv(name, **kwargs): @@ -10,9 +9,11 @@ def getenv(name, **kwargs): Retrieves environment variable by name and casts the value to desired type. If desired type is list or tuple - uses separator to split the value. """ - default_value = kwargs.pop('default', None) - desired_type = kwargs.pop('type', str) - list_separator = kwargs.pop('separator', ',') + default_value = kwargs.get('default', None) + list_separator = kwargs.get('separator', ',') + desired_type = kwargs.get('type') + if not desired_type and default_value is not None: + desired_type = type(default_value) value = os.getenv(name, None) @@ -28,7 +29,7 @@ def getenv(name, **kwargs): else: return bool(value) - if desired_type is list or desired_type is tuple: + if desired_type in (list, tuple): value = value.split(list_separator) return desired_type(value) diff --git a/tests.py b/tests.py index 2d26b56..36f7fbc 100644 --- a/tests.py +++ b/tests.py @@ -40,6 +40,7 @@ def test_getenv_type_str(self): """ os.environ[self.test_var_name] = 'abc' self.assertEqual(getenv(self.test_var_name, type=str), 'abc') + self.assertEqual(getenv(self.test_var_name, default='qwe'), 'abc') def test_getenv_type_int(self): """ @@ -58,6 +59,9 @@ def test_getenv_type_int(self): except ValueError: pass + os.environ[self.test_var_name] = '2' + self.assertEqual(getenv(self.test_var_name, default=1), 2) + def test_getenv_type_float(self): """ If environment variable exists and desired type is float: @@ -75,6 +79,9 @@ def test_getenv_type_float(self): except ValueError: pass + os.environ[self.test_var_name] = '123.4' + self.assertEqual(getenv(self.test_var_name, default=245.6), 123.4) + def test_getenv_type_bool(self): """ If environment variable exists and desired type is bool, ensure getenv returns bool. @@ -106,6 +113,9 @@ def test_getenv_type_bool(self): os.environ[self.test_var_name] = '' self.assertEqual(getenv(self.test_var_name, type=bool), False) + os.environ[self.test_var_name] = 'absolutely not a boolean' + self.assertEqual(getenv(self.test_var_name, default=False), True) + def test_getenv_type_list(self): """ If environment variable exists and desired type is list: @@ -125,6 +135,9 @@ def test_getenv_type_list(self): os.environ[self.test_var_name] = 'a:b:c' self.assertEqual(getenv(self.test_var_name, type=list, separator=':'), ['a', 'b', 'c']) + os.environ[self.test_var_name] = 'a,b,c' + self.assertEqual(getenv(self.test_var_name, default=[1, 2]), ['a', 'b', 'c']) + def test_getenv_type_tuple(self): """ If environment variable exists and desired type is tuple: @@ -144,6 +157,9 @@ def test_getenv_type_tuple(self): os.environ[self.test_var_name] = 'a:b:c' self.assertEqual(getenv(self.test_var_name, type=tuple, separator=':'), ('a', 'b', 'c')) + os.environ[self.test_var_name] = 'a,b,c' + self.assertEqual(getenv(self.test_var_name, default=(1, 2)), ('a', 'b', 'c')) + def test_getenv_type_dict(self): """ If environment variable exists and desired type is dict: @@ -169,6 +185,9 @@ def test_getenv_type_dict(self): except SyntaxError: pass + os.environ[self.test_var_name] = '{ "key": "value" }' + self.assertEqual(getenv(self.test_var_name, default={"key": "other value"}), {'key': 'value'}) + if __name__ == '__main__': unittest.main()