diff --git a/.travis.yml b/.travis.yml index aede3048a..30eb457d4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,10 @@ notifications: env: - TRAVIS_CI=1 +# set build size to avoid 137 out of memory error +vm: + size: large + # conda setup copied from the conda docs install: # We do this conditionally because it saves us some downloading if the diff --git a/ramutils/cli/expconf.py b/ramutils/cli/expconf.py index 324e2bfb0..039c11c77 100644 --- a/ramutils/cli/expconf.py +++ b/ramutils/cli/expconf.py @@ -81,8 +81,17 @@ def validate_stim_settings(args): if args.experiment != "AmplitudeDetermination" and not args.experiment.startswith('PS'): if args.target_amplitudes is None: raise RuntimeError("--target-amplitudes is required") + if not (args.min_amplitudes is None and args.max_amplitudes is None): + raise RuntimeError('--min-amplitudes and --max-amplitudes are only used with ' + 'AmplitudeDetermination and "PS" experiments. To specify ' + 'a config for ' + args.experiment + ', only specify ' + '--target-amplitudes') valid = len(args.anodes) == len(args.target_amplitudes) else: + if not args.target_amplitudes is None: + raise RuntimeError('Cannot specify --target-amplitudes (which is only used in ' + '"PS" and AtmplitudeDetermination experiments) with ' + '--min-amplitudes or --max-amplitudes!') valid = len(args.anodes) == len( args.min_amplitudes) == len(args.max_amplitudes) diff --git a/ramutils/constants.py b/ramutils/constants.py index c6b91995b..fc3e70eed 100644 --- a/ramutils/constants.py +++ b/ramutils/constants.py @@ -12,7 +12,9 @@ 'CatFR1', 'PAL1', 'DBOY1', - 'RepFR1' + 'RepFR1', + 'IFR1', + 'ICatFR1' ], 'ps': [ 'PS4_FR5', @@ -38,6 +40,12 @@ 'CatFR6', 'TICL_FR', 'TICL_CatFR', + 'IFR3', + 'ICatFR3', + 'IFR5', + 'ICatFR5', + 'IFR6', + 'ICatFR6', ], # Experiments that allow multiple stim locations @@ -48,6 +56,8 @@ 'PS4_CatFR5', 'FR6', 'CatFR6', + 'IFR6', + 'ICatFR6', ] } diff --git a/ramutils/tasks/events.py b/ramutils/tasks/events.py index 4cf68b8e8..5b64f78f7 100644 --- a/ramutils/tasks/events.py +++ b/ramutils/tasks/events.py @@ -11,6 +11,9 @@ from ramutils.tasks import task from ramutils.utils import extract_experiment_series +from ..utils import insert_column +import numpy as np + __all__ = [ 'get_word_event_mask', 'subset_events', @@ -67,8 +70,36 @@ def build_training_data(subject, experiment, paths, sessions=None, **kwargs): pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf']) + ifr_events = load_events(subject, 'IFR1', sessions=sessions, + rootdir=paths.root) + cleaned_ifr_events = clean_events(ifr_events, + start_time=kwargs['baseline_removal_start_time'], + end_time=kwargs['retrieval_time'], + duration=kwargs['empty_epoch_duration'], + pre=kwargs['pre_event_buf'], + post=kwargs['post_event_buf']) + if 'iscorrect' not in cleaned_ifr_events.dtype.names: + iscorrect_index = fr_events.dtype.names.index('iscorrect') + cleaned_ifr_events = insert_column(cleaned_ifr_events, 'iscorrect', + np.full(cleaned_ifr_events.shape, -999), int, iscorrect_index) + + icatfr_events = load_events(subject, 'ICatFR1', + sessions=sessions, + rootdir=paths.root) + cleaned_icatfr_events = clean_events(icatfr_events, + start_time=kwargs['baseline_removal_start_time'], + end_time=kwargs['retrieval_time'], + duration=kwargs['empty_epoch_duration'], + pre=kwargs['pre_event_buf'], + post=kwargs['post_event_buf']) + + if 'iscorrect' not in cleaned_icatfr_events.dtype.names: + iscorrect_index = catfr_events.dtype.names.index('iscorrect') + cleaned_icatfr_events = insert_column(cleaned_icatfr_events, 'iscorrect', + np.full(cleaned_icatfr_events.shape, -999), int, iscorrect_index) + free_recall_events = concatenate_events_across_experiments( - [cleaned_fr_events, cleaned_catfr_events], cat=True) + [cleaned_fr_events, cleaned_catfr_events, cleaned_ifr_events, cleaned_icatfr_events], cat=True) elif "FR" in experiment and not kwargs['combine_events']: free_recall_events = load_events(subject, experiment, sessions=sessions, @@ -140,13 +171,38 @@ def build_test_data(subject, experiment, paths, joint_report, sessions=None, pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'], return_stim_events=True) + # Immediate Free Recall variants + ifr_events = load_events(subject, 'IFR' + series_num, + sessions=sessions, + rootdir=paths.root) + cleaned_ifr_events, ifr_stim_params = clean_events( + ifr_events, start_time=kwargs['baseline_removal_start_time'], + end_time=kwargs['retrieval_time'], + duration=kwargs['empty_epoch_duration'], + pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'], + return_stim_events=True) + + icatfr_events = load_events(subject, 'ICatFR' + series_num, + sessions=sessions, + rootdir=paths.root) + cleaned_icatfr_events, icatfr_stim_params = clean_events( + icatfr_events, start_time=kwargs['baseline_removal_start_time'], + end_time=kwargs['retrieval_time'], + duration=kwargs['empty_epoch_duration'], + pre=kwargs['pre_event_buf'], post=kwargs['post_event_buf'], + return_stim_events=True) + all_events = concatenate_events_across_experiments([fr_events, - catfr_events]) + catfr_events, + ifr_events, + icatfr_events]) task_events = concatenate_events_across_experiments( [cleaned_fr_events, cleaned_catfr_events], cat=True) stim_params = concatenate_events_across_experiments([fr_stim_params, - catfr_stim_params], + catfr_stim_params, + ifr_stim_params, + icatfr_stim_params], stim=True) elif not joint_report and 'FR' in experiment: diff --git a/ramutils/tasks/summary.py b/ramutils/tasks/summary.py index b1a2ad0f9..8d0537600 100644 --- a/ramutils/tasks/summary.py +++ b/ramutils/tasks/summary.py @@ -275,7 +275,7 @@ def summarize_stim_sessions(all_events, task_events, stim_params, pairs_data, # TODO: Add some sort of data quality check here potentially. Do the # observed stim items match what we expect from classifier output? - if experiment in ['FR3', 'FR5', 'catFR3', 'catFR5', 'FR6', 'catFR6']: + if experiment in ['FR3', 'FR5', 'catFR3', 'catFR5', 'FR6', 'catFR6', 'ICatFR5', 'FR5', 'ICatFR6', 'FR6']: stim_events = dataframe_to_recarray(stim_df, expected_dtypes) stim_session_summary = FRStimSessionSummary() stim_session_summary.populate( diff --git a/ramutils/utils.py b/ramutils/utils.py index 8d981d14a..1e84412ea 100644 --- a/ramutils/utils.py +++ b/ramutils/utils.py @@ -390,3 +390,32 @@ def encode_file(fd): """ fd.seek(0) return base64.b64encode(fd.read()).decode() + + +def insert_column(recarr, column_name, data, dtype, position): + """ + Insert a column into a recarray at a specific position. + + Parameters: + - recarr: the original recarray + - column_name: the name of the new column to be added + - data: the data for the new column + - dtype: the data type for the new column + - position: the position to insert the new column at + Returns: + - new_arr: a new recarray with the inserted column + """ + + if position > len(recarr.dtype.names): + raise ValueError("Position is out of range.") + + before = [(name, recarr.dtype.fields[name][0]) for name in recarr.dtype.names[:position]] + after = [(name, recarr.dtype.fields[name][0]) for name in recarr.dtype.names[position:]] + new_dtype = np.dtype(before + [(column_name, dtype)] + after) + + new_arr = np.rec.array(np.zeros(recarr.shape, dtype=new_dtype)) + + for name in recarr.dtype.names: + new_arr[name] = recarr[name] + new_arr[column_name] = data + return new_arr