diff --git a/compair/api/classlist.py b/compair/api/classlist.py index 14015f5a..55797f44 100644 --- a/compair/api/classlist.py +++ b/compair/api/classlist.py @@ -180,6 +180,9 @@ def import_users(import_type, course, users): # overwrite password if user has not logged in yet if u.last_online == None and not password in [None, '*']: set_user_passwords.append((u, password)) + if (import_type == ThirdPartyType.cas.value or import_type == ThirdPartyType.saml.value) \ + and u.global_unique_identifier is None: + u.global_unique_identifier = username else: u = User( username=None, @@ -191,6 +194,7 @@ def import_users(import_type, course, users): ) if import_type == ThirdPartyType.cas.value or import_type == ThirdPartyType.saml.value: # CAS/SAML login + u.global_unique_identifier = username u.third_party_auths.append(ThirdPartyUser( unique_identifier=username, third_party_type=ThirdPartyType(import_type) diff --git a/compair/models/lti_models/lti_user.py b/compair/models/lti_models/lti_user.py index 558b4851..4259d27c 100644 --- a/compair/models/lti_models/lti_user.py +++ b/compair/models/lti_models/lti_user.py @@ -49,6 +49,13 @@ def generate_or_link_user_account(self): .filter_by(global_unique_identifier=self.global_unique_identifier) \ .one_or_none() + if not self.compair_user and self.student_number: + self.compair_user = User.query \ + .filter_by(student_number=self.student_number) \ + .one_or_none() + if self.compair_user and self.compair_user.global_unique_identifier is None: + self.compair_user.global_unique_identifier = self.global_unique_identifier + if not self.compair_user: self.compair_user = User( username=None, diff --git a/compair/tests/test_models.py b/compair/tests/test_models.py index 6ea8a4f3..9e4d39f8 100644 --- a/compair/tests/test_models.py +++ b/compair/tests/test_models.py @@ -7,13 +7,16 @@ from compair import db from compair.models import User, Comparison, AnswerScore, \ - AnswerCriterionScore, LTIOutcome, SystemRole + AnswerCriterionScore, LTIOutcome, SystemRole, ThirdPartyUser +from compair.models.lti_models import LTIUser +from compair.models import ThirdPartyType from compair.models.comparison import update_answer_scores, \ update_answer_criteria_scores from compair.tests.test_compair import ComPAIRTestCase from compair.algorithms import ComparisonPair, ComparisonWinner from compair.algorithms.score import calculate_score from data.fixtures.test_data import TestFixture, LTITestData +from data.factories import LTIConsumerFactory, UserFactory class TestUsersModel(ComPAIRTestCase): user = User() @@ -115,6 +118,86 @@ def test_update_answer_criteria_scores(self): scores = update_answer_criteria_scores([score], 1, criterion_comparison_results) self.assertEqual(len(scores), 4) +class TestLTIUserGenerateOrLinkAccount(ComPAIRTestCase): + def setUp(self): + super(TestLTIUserGenerateOrLinkAccount, self).setUp() + self.lti_consumer = LTIConsumerFactory( + global_unique_identifier_param='custom_puid', + student_number_param='custom_student_number' + ) + db.session.commit() + + def test_links_existing_saml_user_by_student_number_when_global_unique_identifier_missing(self): + # user created via SAML - has student number but no global_unique_identifier + existing_user = UserFactory( + system_role=SystemRole.student, + student_number='12345678', + global_unique_identifier=None, + username=None, + password=None + ) + ThirdPartyUser( + unique_identifier='saml_identifier', + third_party_type=ThirdPartyType.saml, + user=existing_user + ) + db.session.commit() + + user_count_before = User.query.count() + + lti_user = LTIUser( + lti_consumer=self.lti_consumer, + user_id='canvas_user_123', + system_role=SystemRole.student, + global_unique_identifier='puid_abc', + student_number='12345678' + ) + db.session.add(lti_user) + lti_user.generate_or_link_user_account() + + # no new user should be created + self.assertEqual(User.query.count(), user_count_before) + # lti_user should be linked to the existing user + self.assertEqual(lti_user.compair_user_id, existing_user.id) + # global_unique_identifier should be backfilled on the existing user + self.assertEqual(existing_user.global_unique_identifier, 'puid_abc') + + def test_does_not_overwrite_existing_global_unique_identifier_when_linking_by_student_number(self): + # user already has a global_unique_identifier set from a prior SAML/CAS login + existing_user = UserFactory( + system_role=SystemRole.student, + student_number='12345678', + global_unique_identifier='existing_puid', + username=None, + password=None + ) + ThirdPartyUser( + unique_identifier='saml_identifier', + third_party_type=ThirdPartyType.saml, + user=existing_user + ) + db.session.commit() + + user_count_before = User.query.count() + + lti_user = LTIUser( + lti_consumer=self.lti_consumer, + user_id='canvas_user_456', + system_role=SystemRole.student, + global_unique_identifier='different_puid', + student_number='12345678' + ) + db.session.add(lti_user) + lti_user.generate_or_link_user_account() + + # no new user should be created + self.assertEqual(User.query.count(), user_count_before) + # lti_user should be linked to the existing user + self.assertEqual(lti_user.compair_user_id, existing_user.id) + # existing global_unique_identifier must not be overwritten + self.assertEqual(existing_user.global_unique_identifier, 'existing_puid') + + class TestLTIOutcome(ComPAIRTestCase): def setUp(self):