diff --git a/src/main/kotlin/plus/maa/backend/controller/UserController.kt b/src/main/kotlin/plus/maa/backend/controller/UserController.kt index 82c1005e..1571e4c4 100644 --- a/src/main/kotlin/plus/maa/backend/controller/UserController.kt +++ b/src/main/kotlin/plus/maa/backend/controller/UserController.kt @@ -1,4 +1,4 @@ -package plus.maa.backend.controller +package plus.maa.backend.controller import io.swagger.v3.oas.annotations.Operation import io.swagger.v3.oas.annotations.media.Content @@ -48,7 +48,7 @@ class UserController( private val helper: AuthenticationHelper, ) { /** - * 更新当前用户的密码(根据原密码) + * 更新当前用户的密码,根据原密码 * * @return http响应 */ @@ -77,7 +77,7 @@ class UserController( } /** - * 邮箱重设密码 + * 邮箱重置密码 * * @param passwordResetDTO 通过邮箱修改密码请求 * @return 成功响应 @@ -156,17 +156,44 @@ class UserController( fun login(@RequestBody user: @Valid LoginDTO): MaaResult = success("登陆成功", userService.login(user)) /** - * 查询用户信息 + * 获取当前登录用户信息 + */ + @GetMapping("/me") + @Operation(summary = "获取当前登录用户信息") + @ApiResponse(description = "当前用户详情信息") + @RequireJwt + fun getMe(): MaaResult { + return success(userService.getMe(helper.userId)) + } + + /** + * 查询用户信息(附带与当前用户的关系) */ @GetMapping("/info") @Operation(summary = "查询用户信息") @ApiResponse(responseCode = "200", description = "用户详情信息") @ApiResponse(responseCode = "404", content = [Content()]) fun getUserInfo(@RequestParam userId: String): MaaResult { - val userInfo = - userService.get(userId.toLongOrNull() ?: throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid user ID")) - ?: throw ResponseStatusException(HttpStatus.NOT_FOUND) - return success(userInfo) + val targetId = userId.toLongOrNull() + ?: throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid user ID") + val currentUserId = helper.obtainUserId()?.toLongOrNull() + return success(userService.getWithRelation(targetId, currentUserId)) + } + + /** + * 批量获取用户信息 + */ + @GetMapping("/batch") + @Operation(summary = "批量获取用户信息") + @ApiResponse(description = "用户信息列表") + fun getBatchUserInfo( + @RequestParam ids: List, + ): MaaResult> { + if (ids.size > 50) { + throw ResponseStatusException(HttpStatus.BAD_REQUEST, "单次查询用户量不能超过50") + } + val currentUserId = helper.obtainUserId()?.toLongOrNull() + return success(userService.getBatchUserInfos(ids, currentUserId)) } /** diff --git a/src/main/kotlin/plus/maa/backend/controller/response/user/MaaUserInfo.kt b/src/main/kotlin/plus/maa/backend/controller/response/user/MaaUserInfo.kt index dcfff513..8b1915ed 100644 --- a/src/main/kotlin/plus/maa/backend/controller/response/user/MaaUserInfo.kt +++ b/src/main/kotlin/plus/maa/backend/controller/response/user/MaaUserInfo.kt @@ -1,8 +1,10 @@ -package plus.maa.backend.controller.response.user +package plus.maa.backend.controller.response.user +import kotlinx.serialization.Contextual import kotlinx.serialization.Serializable import plus.maa.backend.repository.entity.MaaUser import plus.maa.backend.repository.entity.UserEntity +import java.time.Instant /** * 用户可对外公开的信息 @@ -16,6 +18,8 @@ data class MaaUserInfo( val activated: Boolean = false, val followingCount: Int = 0, val fansCount: Int = 0, + val relation: RelationType? = null, + @Contextual val followedAt: Instant? = null, ) { constructor(user: MaaUser) : this( id = user.userId!!, diff --git a/src/main/kotlin/plus/maa/backend/controller/response/user/RelationType.kt b/src/main/kotlin/plus/maa/backend/controller/response/user/RelationType.kt new file mode 100644 index 00000000..627ef7ed --- /dev/null +++ b/src/main/kotlin/plus/maa/backend/controller/response/user/RelationType.kt @@ -0,0 +1,8 @@ +package plus.maa.backend.controller.response.user + +import kotlinx.serialization.Serializable + +@Serializable +enum class RelationType { + SELF, NONE, FOLLOWING, FOLLOWED_BY, MUTUAL +} diff --git a/src/main/kotlin/plus/maa/backend/repository/ktorm/UserKtormRepository.kt b/src/main/kotlin/plus/maa/backend/repository/ktorm/UserKtormRepository.kt index 3e5c5e2c..a011502e 100644 --- a/src/main/kotlin/plus/maa/backend/repository/ktorm/UserKtormRepository.kt +++ b/src/main/kotlin/plus/maa/backend/repository/ktorm/UserKtormRepository.kt @@ -9,6 +9,7 @@ import org.ktorm.dsl.inList import org.ktorm.dsl.insert import org.ktorm.dsl.limit import org.ktorm.dsl.map +import org.ktorm.dsl.mapNotNull import org.ktorm.dsl.minus import org.ktorm.dsl.plus import org.ktorm.dsl.select @@ -164,6 +165,62 @@ class UserKtormRepository( } } + /** + * 查询 userId 关注了 targetIds 中的哪些用户,返回 followUserId -> updatedAt 的映射 + */ + fun getFollowUpdatedAtMap(userId: Long, targetIds: List): Map { + if (targetIds.isEmpty()) return emptyMap() + return database.from(UserFollows) + .select(UserFollows.followUserId, UserFollows.updatedAt) + .where { (UserFollows.userId eq userId) and (UserFollows.followUserId inList targetIds) } + .map { row -> row[UserFollows.followUserId]!! to row[UserFollows.updatedAt]!! } + .toMap() + } + + /** + * 查询 fanIds 中谁关注了 userId,返回 fanId -> updatedAt 的映射 + */ + fun getFansUpdatedAtMap(fanIds: List, userId: Long): Map { + if (fanIds.isEmpty()) return emptyMap() + return database.from(UserFollows) + .select(UserFollows.userId, UserFollows.updatedAt) + .where { (UserFollows.userId inList fanIds) and (UserFollows.followUserId eq userId) } + .map { row -> row[UserFollows.userId]!! to row[UserFollows.updatedAt]!! } + .toMap() + } + + /** + * 查询 userId 关注了 targetIds 中的哪些用户,返回被关注的 targetIds 子集 + */ + fun getFollowedTargetIds(userId: Long, targetIds: List): Set { + if (targetIds.isEmpty()) return emptySet() + return database.from(UserFollows) + .select(UserFollows.followUserId) + .where { (UserFollows.userId eq userId) and (UserFollows.followUserId inList targetIds) } + .mapNotNull { it[UserFollows.followUserId] } + .toSet() + } + + /** + * 查询 targetIds 中哪些用户关注了 userId,返回关注了 userId 的 targetIds 子集 + */ + fun getFollowerTargetIds(targetIds: List, userId: Long): Set { + if (targetIds.isEmpty()) return emptySet() + return database.from(UserFollows) + .select(UserFollows.userId) + .where { (UserFollows.userId inList targetIds) and (UserFollows.followUserId eq userId) } + .mapNotNull { it[UserFollows.userId] } + .toSet() + } + + fun isFollowing(userId: Long, followUserId: Long): Boolean { + return database.from(UserFollows).select(UserFollows.userId) + .where { (UserFollows.userId eq userId) and (UserFollows.followUserId eq followUserId) } + .limit(1) + .map { it[UserFollows.userId] } + .isNotEmpty() + } + override fun save(entity: UserEntity): UserEntity { return if (isNewEntity(entity)) { insertEntity(entity) diff --git a/src/main/kotlin/plus/maa/backend/service/UserService.kt b/src/main/kotlin/plus/maa/backend/service/UserService.kt index 52401b10..89447a40 100644 --- a/src/main/kotlin/plus/maa/backend/service/UserService.kt +++ b/src/main/kotlin/plus/maa/backend/service/UserService.kt @@ -12,8 +12,10 @@ import org.ktorm.entity.sortedBy import org.ktorm.entity.take import org.ktorm.entity.toList import org.springframework.dao.DuplicateKeyException +import org.springframework.http.HttpStatus import org.springframework.security.crypto.password.PasswordEncoder import org.springframework.stereotype.Service +import org.springframework.web.server.ResponseStatusException import plus.maa.backend.common.MaaStatusCode import plus.maa.backend.common.extensions.toMaaUser import plus.maa.backend.controller.request.user.LoginDTO @@ -24,6 +26,7 @@ import plus.maa.backend.controller.request.user.UserInfoUpdateDTO import plus.maa.backend.controller.response.MaaResultException import plus.maa.backend.controller.response.user.MaaLoginRsp import plus.maa.backend.controller.response.user.MaaUserInfo +import plus.maa.backend.controller.response.user.RelationType import plus.maa.backend.repository.entity.MaaUser import plus.maa.backend.repository.entity.UserEntity import plus.maa.backend.repository.entity.users @@ -117,7 +120,7 @@ class UserService( */ fun register(registerDTO: RegisterDTO): MaaUserInfo { val userName = registerDTO.userName.trim() - check(userName.length >= 4) { "用户名长度应在4-24位之间" } + check(userName.length in 4..24) { "用户名长度应在4-24位之间" } check(!userKtormRepository.existsByUserName(userName)) { "用户名已存在,请重新取个名字吧" } @@ -127,7 +130,7 @@ class UserService( val encoded = passwordEncoder.encode(registerDTO.password)!! - val user = MaaUser( + val maaUser = MaaUser( userName = userName, email = registerDTO.email, password = encoded, @@ -135,9 +138,9 @@ class UserService( pwdUpdateTime = Instant.now(), ) return try { - val userEntity = userKtormRepository.createFromMaaUser(user) - userKtormRepository.save(userEntity) - MaaUserInfo(userEntity).also { + val entity = userKtormRepository.createFromMaaUser(maaUser) + userKtormRepository.save(entity) + MaaUserInfo(entity).also { Cache.invalidateMaaUserById(it.id) } } catch (_: DuplicateKeyException) { @@ -154,11 +157,11 @@ class UserService( fun updateUserInfo(userId: Long, updateDTO: UserInfoUpdateDTO) { val userEntity = userKtormRepository.findById(userId) ?: return val newName = updateDTO.userName.trim() - check(newName.length >= 4) { "用户名长度应在4-24位之间" } if (newName == userEntity.userName) { // 暂时只支持修改用户名,如果有其他字段修改需要同步修改该逻辑 return } + check(newName.length in 4..24) { "用户名长度应在4-24位之间" } // 用户名需要trim check(!userKtormRepository.existsByUserName(newName)) { "用户名已存在,请重新取个名字吧" @@ -277,6 +280,72 @@ class UserService( .toList() } + /** + * 获取当前登录用户信息 + */ + fun getMe(userId: Long): MaaUserInfo { + val userEntity = userKtormRepository.findById(userId) + ?: throw ResponseStatusException(HttpStatus.NOT_FOUND) + return MaaUserInfo(userEntity) + } + + /** + * 获取用户信息并附带与当前用户的关系 + */ + fun getWithRelation(targetId: Long, currentUserId: Long?): MaaUserInfo { + val userEntity = userKtormRepository.findById(targetId) + ?: throw ResponseStatusException(HttpStatus.NOT_FOUND) + val base = MaaUserInfo(userEntity) + if (currentUserId == null) return base + val relation = resolveRelation(currentUserId, targetId) + return base.copy(relation = relation) + } + + /** + * 批量获取用户信息(可选附带关系) + */ + fun getBatchUserInfos(ids: List, currentUserId: Long?): List { + if (ids.isEmpty()) return emptyList() + val users = userKtormRepository.findAllById(ids) + // 保证结果顺序与输入 ids 一致 + val userMap = users.associateBy { it.userId } + if (currentUserId == null) { + return ids.mapNotNull { userMap[it] }.map { MaaUserInfo(it) } + } + // 当前用户关注了哪些目标 + val iFollowIds = userKtormRepository.getFollowedTargetIds(currentUserId, ids) + // 哪些目标关注了当前用户 + val theyFollowMeIds = userKtormRepository.getFollowerTargetIds(ids, currentUserId) + return ids.mapNotNull { userMap[it] }.map { user -> + val uid = user.userId + val iFollow = uid in iFollowIds + val theyFollow = uid in theyFollowMeIds + val relation = when { + uid == currentUserId -> RelationType.SELF + iFollow && theyFollow -> RelationType.MUTUAL + iFollow -> RelationType.FOLLOWING + theyFollow -> RelationType.FOLLOWED_BY + else -> RelationType.NONE + } + MaaUserInfo(user).copy(relation = relation) + } + } + + /** + * 解析当前用户与目标用户的关系 + */ + private fun resolveRelation(currentUserId: Long, targetUserId: Long): RelationType { + if (currentUserId == targetUserId) return RelationType.SELF + val iFollow = userKtormRepository.isFollowing(currentUserId, targetUserId) + val theyFollow = userKtormRepository.isFollowing(targetUserId, currentUserId) + return when { + iFollow && theyFollow -> RelationType.MUTUAL + iFollow -> RelationType.FOLLOWING + theyFollow -> RelationType.FOLLOWED_BY + else -> RelationType.NONE + } + } + @Suppress("unused") private fun isAllChinese(input: String): Boolean { return input.all { it in '\u4e00'..'\u9fa5' } diff --git a/src/main/kotlin/plus/maa/backend/service/follow/UserFollowService.kt b/src/main/kotlin/plus/maa/backend/service/follow/UserFollowService.kt index 667ecb0a..a0e5e85f 100644 --- a/src/main/kotlin/plus/maa/backend/service/follow/UserFollowService.kt +++ b/src/main/kotlin/plus/maa/backend/service/follow/UserFollowService.kt @@ -6,7 +6,9 @@ import org.springframework.stereotype.Service import plus.maa.backend.common.extensions.paginate import plus.maa.backend.common.extensions.toMaaUserInfo import plus.maa.backend.controller.response.user.MaaUserInfo +import plus.maa.backend.controller.response.user.RelationType import plus.maa.backend.repository.ktorm.UserKtormRepository +import java.time.ZoneId @Service class UserFollowService( @@ -15,7 +17,7 @@ class UserFollowService( fun follow(userId: Long, followUserId: Long) { check(userId != followUserId) { - "不能关注自己哦~" + "不能关注自己哦~" } val followUser = userKtormRepository.findById(followUserId) check(followUser != null && followUser.status > 0) { @@ -30,11 +32,35 @@ class UserFollowService( fun getFollowingList(userId: Long, pageable: Pageable): PageImpl { val res = userKtormRepository.follows(userId).paginate(pageable) - return PageImpl(res.map { it.toMaaUserInfo() }.toList(), pageable, res.totalElements) + val users = res.toList() + val targetIds = users.map { it.userId } + // 查询关注时间 + val updatedAtMap = userKtormRepository.getFollowUpdatedAtMap(userId, targetIds) + // 查询哪些目标也关注了我(用于判断 MUTUAL) + val mutualIds = userKtormRepository.getFollowerTargetIds(targetIds, userId) + val enriched = users.map { user -> + val info = user.toMaaUserInfo() + val relation = if (user.userId in mutualIds) RelationType.MUTUAL else RelationType.FOLLOWING + val followedAt = updatedAtMap[user.userId]?.atZone(ZoneId.systemDefault())?.toInstant() + info.copy(relation = relation, followedAt = followedAt) + } + return PageImpl(enriched, pageable, res.totalElements) } fun getFansList(userId: Long, pageable: Pageable): PageImpl { val res = userKtormRepository.fans(userId).paginate(pageable) - return PageImpl(res.map { it.toMaaUserInfo() }.toList(), pageable, res.totalElements) + val users = res.toList() + val fanIds = users.map { it.userId } + // 批量查询粉丝关注我的时间 + val fanUpdatedAtMap = userKtormRepository.getFansUpdatedAtMap(fanIds, userId) + // 查询我关注了哪些粉丝(用于判断 MUTUAL) + val iFollowBackIds = userKtormRepository.getFollowedTargetIds(userId, fanIds) + val enriched = users.map { user -> + val info = user.toMaaUserInfo() + val relation = if (user.userId in iFollowBackIds) RelationType.MUTUAL else RelationType.FOLLOWED_BY + val followedAt = fanUpdatedAtMap[user.userId]?.atZone(ZoneId.systemDefault())?.toInstant() + info.copy(relation = relation, followedAt = followedAt) + } + return PageImpl(enriched, pageable, res.totalElements) } }