diff --git a/api/email.go b/api/email.go index 77f9d3a..e6a37f0 100644 --- a/api/email.go +++ b/api/email.go @@ -7,7 +7,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/piquel-fr/api/config" - "github.com/piquel-fr/api/database" "github.com/piquel-fr/api/database/repository" "github.com/piquel-fr/api/services/auth" "github.com/piquel-fr/api/services/email" @@ -314,7 +313,7 @@ func (h *EmailHandler) handleAddAccount(w http.ResponseWriter, r *http.Request) } params.OwnerId = user.ID - if _, err = database.Queries.AddEmailAccount(r.Context(), params); err != nil { + if _, err = h.emailService.AddAccount(r.Context(), params); err != nil { errors.HandleError(w, r, err) return } diff --git a/database/database.go b/database/database.go deleted file mode 100644 index de48acc..0000000 --- a/database/database.go +++ /dev/null @@ -1,31 +0,0 @@ -package database - -import ( - "context" - "log" - - "github.com/jackc/pgx/v5/pgxpool" - "github.com/piquel-fr/api/config" - "github.com/piquel-fr/api/database/repository" -) - -var Queries *repository.Queries -var Connection *pgxpool.Pool - -func InitDatabase() { - log.Printf("[Database] Attempting to connect to the database...\n") - - conn, err := pgxpool.New(context.Background(), config.Envs.DBURL) - if err != nil { - panic(err) - } - - Connection = conn - - if err = Connection.Ping(context.Background()); err != nil { - panic(err) - } - - Queries = repository.New(Connection) - log.Printf("[Database] Successfully connected to the database!\n") -} diff --git a/main.go b/main.go index f551813..7ac5892 100644 --- a/main.go +++ b/main.go @@ -7,9 +7,9 @@ import ( "github.com/piquel-fr/api/api" "github.com/piquel-fr/api/config" - "github.com/piquel-fr/api/database" "github.com/piquel-fr/api/services/auth" "github.com/piquel-fr/api/services/email" + "github.com/piquel-fr/api/services/storage" "github.com/piquel-fr/api/services/users" gh "github.com/piquel-fr/api/utils/github" "github.com/piquel-fr/api/utils/oauth" @@ -22,12 +22,12 @@ func main() { config.LoadConfig() gh.InitGithubClient() oauth.InitOAuth() - database.InitDatabase() - defer database.Connection.Close() - userService := users.NewRealUserService() - authService := auth.NewRealAuthService(userService) - emailService := email.NewRealEmailService() + storageService := storage.NewDatabaseStorageService() + defer storageService.Close() + userService := users.NewRealUserService(storageService) + authService := auth.NewRealAuthService(storageService, userService) + emailService := email.NewRealEmailService(storageService) config.UsernameBlacklist = userService.GetUsernameBlacklist() config.Policy = authService.GetPolicy() diff --git a/services/auth/auth.go b/services/auth/auth.go index 3894cd5..c27521d 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -12,8 +12,8 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/jackc/pgx/v5" "github.com/piquel-fr/api/config" - "github.com/piquel-fr/api/database" "github.com/piquel-fr/api/database/repository" + "github.com/piquel-fr/api/services/storage" "github.com/piquel-fr/api/services/users" "github.com/piquel-fr/api/utils" "github.com/piquel-fr/api/utils/errors" @@ -50,11 +50,12 @@ type AuthService interface { } type realAuthService struct { - userService users.UserService + userService users.UserService + storageService storage.StorageService } -func NewRealAuthService(userService users.UserService) AuthService { - return &realAuthService{userService} +func NewRealAuthService(storageService storage.StorageService, userService users.UserService) AuthService { + return &realAuthService{userService, storageService} } func (s *realAuthService) GetPolicy() *config.PolicyConfiguration { return &policy } @@ -87,7 +88,7 @@ func (s *realAuthService) FinishAuth(user *repository.User, r *http.Request, w h IpAdress: ipAddress, ExpiresAt: time.Now().Add(refreshExpiry), // one month } - if _, err := database.Queries.AddSession(r.Context(), sessionParams); err != nil { + if _, err := s.storageService.AddSession(r.Context(), sessionParams); err != nil { return err } @@ -101,7 +102,7 @@ func (s *realAuthService) Refresh(w http.ResponseWriter, r *http.Request) error cookies := utils.GetCookiesFromStr(r.Header.Get("Cookie")) hash := s.hashRefreshToken(cookies[refreshKey], ipAddress) - session, err := database.Queries.GetSessionFromHash(r.Context(), hash) + session, err := s.storageService.GetSessionFromHash(r.Context(), hash) if errors.Is(err, pgx.ErrNoRows) { return errors.ErrorNotAuthenticated } @@ -135,7 +136,7 @@ func (s *realAuthService) Refresh(w http.ResponseWriter, r *http.Request) error TokenHash: refreshHash, ExpiresAt: time.Now().Add(refreshExpiry), } - if err := database.Queries.UpdateSession(r.Context(), updateSessionParams); err != nil { + if err := s.storageService.UpdateSession(r.Context(), updateSessionParams); err != nil { return err } @@ -152,7 +153,7 @@ func (s *realAuthService) Logout(w http.ResponseWriter, r *http.Request) error { w.Header().Add("Set-Cookie", utils.GenerateClearCookie(accessKey, config.Envs.Domain, "/")) hash := s.hashRefreshToken(cookies[refreshKey], ipAddress) - return database.Queries.DeleteSessionByHash(r.Context(), hash) + return s.storageService.DeleteSessionByHash(r.Context(), hash) } func (s *realAuthService) generateAccessToken(user *repository.User, expiresAt time.Time) *jwt.Token { @@ -226,13 +227,13 @@ func (s *realAuthService) AuthMiddleware(next http.Handler) http.Handler { } func (s *realAuthService) GetUserSessions(ctx context.Context, userId int32) ([]*repository.UserSession, error) { - return database.Queries.GetUserSessions(ctx, userId) + return s.storageService.GetUserSessions(ctx, userId) } func (s *realAuthService) DeleteUserSession(ctx context.Context, userId, id int32) error { - return database.Queries.DeleteSessionById(ctx, userId, id) + return s.storageService.DeleteSessionById(ctx, userId, id) } func (s *realAuthService) DeleteUserSessions(ctx context.Context, userId int32) error { - return database.Queries.ClearUserSessions(ctx, userId) + return s.storageService.ClearUserSessions(ctx, userId) } diff --git a/services/email/accounts.go b/services/email/accounts.go index 4ab4583..aa23598 100644 --- a/services/email/accounts.go +++ b/services/email/accounts.go @@ -4,7 +4,6 @@ import ( "context" "github.com/emersion/go-imap/v2/imapclient" - "github.com/piquel-fr/api/database" "github.com/piquel-fr/api/database/repository" ) @@ -20,29 +19,29 @@ type AccountInfo struct { Shares []string `json:"shares"` } -func (r *realEmailService) GetAccountByEmail(ctx context.Context, email string) (*repository.MailAccount, error) { - return database.Queries.GetMailAccountByEmail(ctx, email) +func (s *realEmailService) GetAccountByEmail(ctx context.Context, email string) (*repository.MailAccount, error) { + return s.storageService.GetMailAccountByEmail(ctx, email) } -func (r *realEmailService) ListAccounts(ctx context.Context, userId int32) ([]*repository.MailAccount, error) { - return database.Queries.ListUserMailAccounts(ctx, userId) +func (s *realEmailService) ListAccounts(ctx context.Context, userId int32) ([]*repository.MailAccount, error) { + return s.storageService.ListUserMailAccounts(ctx, userId) } -func (r *realEmailService) CountAccounts(ctx context.Context, userId int32) (int64, error) { - return database.Queries.CountUserMailAccounts(ctx, userId) +func (s *realEmailService) CountAccounts(ctx context.Context, userId int32) (int64, error) { + return s.storageService.CountUserMailAccounts(ctx, userId) } -func (r *realEmailService) AddAccount(ctx context.Context, params repository.AddEmailAccountParams) (int32, error) { - return database.Queries.AddEmailAccount(ctx, params) +func (s *realEmailService) AddAccount(ctx context.Context, params repository.AddEmailAccountParams) (int32, error) { + return s.storageService.AddEmailAccount(ctx, params) } -func (r *realEmailService) RemoveAccount(ctx context.Context, accountId int32) error { +func (s *realEmailService) RemoveAccount(ctx context.Context, accountId int32) error { // TODO: remove the shares as well - return database.Queries.DeleteMailAccount(ctx, accountId) + return s.storageService.DeleteMailAccount(ctx, accountId) } -func (r *realEmailService) GetAccountInfo(ctx context.Context, account *repository.MailAccount) (AccountInfo, error) { - client, err := imapclient.DialTLS(r.imapAddr, nil) +func (s *realEmailService) GetAccountInfo(ctx context.Context, account *repository.MailAccount) (AccountInfo, error) { + client, err := imapclient.DialTLS(s.imapAddr, nil) if err != nil { return AccountInfo{}, err } @@ -73,13 +72,13 @@ func (r *realEmailService) GetAccountInfo(ctx context.Context, account *reposito } // get shares - shares, err := r.GetAccountShares(ctx, account.ID) + shares, err := s.GetAccountShares(ctx, account.ID) if err != nil { return AccountInfo{}, err } for _, share := range shares { - user, err := database.Queries.GetUserById(ctx, share) + user, err := s.storageService.GetUserById(ctx, share) if err != nil { return AccountInfo{}, err } @@ -89,14 +88,14 @@ func (r *realEmailService) GetAccountInfo(ctx context.Context, account *reposito return accountInfo, nil } -func (r *realEmailService) AddShare(ctx context.Context, params repository.AddShareParams) error { - return database.Queries.AddShare(ctx, params) +func (s *realEmailService) AddShare(ctx context.Context, params repository.AddShareParams) error { + return s.storageService.AddShare(ctx, params) } -func (r *realEmailService) RemoveShare(ctx context.Context, userId, accountId int32) error { - return database.Queries.DeleteShare(ctx, userId, accountId) +func (s *realEmailService) RemoveShare(ctx context.Context, userId, accountId int32) error { + return s.storageService.DeleteShare(ctx, userId, accountId) } -func (r *realEmailService) GetAccountShares(ctx context.Context, account int32) ([]int32, error) { - return database.Queries.ListAccountShares(ctx, account) +func (s *realEmailService) GetAccountShares(ctx context.Context, account int32) ([]int32, error) { + return s.storageService.ListAccountShares(ctx, account) } diff --git a/services/email/email.go b/services/email/email.go index c89226f..44674b0 100644 --- a/services/email/email.go +++ b/services/email/email.go @@ -6,6 +6,7 @@ import ( "github.com/piquel-fr/api/config" "github.com/piquel-fr/api/database/repository" + "github.com/piquel-fr/api/services/storage" ) type EmailService interface { @@ -24,12 +25,14 @@ type EmailService interface { } type realEmailService struct { - imapAddr string + imapAddr string + storageService storage.StorageService } -func NewRealEmailService() *realEmailService { +func NewRealEmailService(storageService storage.StorageService) *realEmailService { addr := fmt.Sprintf("%s:%s", config.Envs.ImapHost, config.Envs.ImapPort) return &realEmailService{ - imapAddr: addr, + imapAddr: addr, + storageService: storageService, } } diff --git a/services/storage/storage.go b/services/storage/storage.go new file mode 100644 index 0000000..90fd203 --- /dev/null +++ b/services/storage/storage.go @@ -0,0 +1,45 @@ +package storage + +import ( + "context" + "log" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/piquel-fr/api/config" + "github.com/piquel-fr/api/database/repository" +) + +type StorageService interface { + repository.Querier + Close() +} + +type databaseStorageService struct { + repository.Queries + connection *pgxpool.Pool +} + +func NewDatabaseStorageService() StorageService { + log.Printf("[Database] Attempting to connect to the database...\n") + + connection, err := pgxpool.New(context.Background(), config.Envs.DBURL) + if err != nil { + panic(err) + } + + if err = connection.Ping(context.Background()); err != nil { + panic(err) + } + + queries := repository.New(connection) + log.Printf("[Database] Successfully connected to the database!\n") + + return &databaseStorageService{ + connection: connection, + Queries: *queries, + } +} + +func (s *databaseStorageService) Close() { + s.connection.Close() +} diff --git a/services/users/users.go b/services/users/users.go index 950cbf1..55960aa 100644 --- a/services/users/users.go +++ b/services/users/users.go @@ -10,8 +10,8 @@ import ( "strings" "github.com/piquel-fr/api/config" - "github.com/piquel-fr/api/database" "github.com/piquel-fr/api/database/repository" + "github.com/piquel-fr/api/services/storage" "github.com/piquel-fr/api/utils/errors" ) @@ -34,22 +34,24 @@ type UserService interface { ListUsers(ctx context.Context, offset, limit int32) ([]*repository.User, error) } -type realUserService struct{} +type realUserService struct { + storageService storage.StorageService +} -func NewRealUserService() *realUserService { - return &realUserService{} +func NewRealUserService(storageService storage.StorageService) UserService { + return &realUserService{storageService} } func (s *realUserService) GetUserById(ctx context.Context, id int32) (*repository.User, error) { - return database.Queries.GetUserById(ctx, id) + return s.storageService.GetUserById(ctx, id) } func (s *realUserService) GetUserByUsername(ctx context.Context, username string) (*repository.User, error) { - return database.Queries.GetUserByUsername(ctx, username) + return s.storageService.GetUserByUsername(ctx, username) } func (s *realUserService) GetUserByEmail(ctx context.Context, email string) (*repository.User, error) { - return database.Queries.GetUserByEmail(ctx, email) + return s.storageService.GetUserByEmail(ctx, email) } func (s *realUserService) GetUserFromContext(ctx context.Context) (*repository.User, error) { @@ -67,7 +69,7 @@ func (s *realUserService) UpdateUser(ctx context.Context, params repository.Upda } params.Username = username - return database.Queries.UpdateUser(ctx, params) + return s.storageService.UpdateUser(ctx, params) } func (s *realUserService) UpdateUserAdmin(ctx context.Context, params repository.UpdateUserAdminParams) error { @@ -81,7 +83,7 @@ func (s *realUserService) UpdateUserAdmin(ctx context.Context, params repository return err } - return database.Queries.UpdateUserAdmin(ctx, params) + return s.storageService.UpdateUserAdmin(ctx, params) } func (s *realUserService) RegisterUser(ctx context.Context, username, email, name, image, role string) (*repository.User, error) { @@ -102,7 +104,7 @@ func (s *realUserService) RegisterUser(ctx context.Context, username, email, nam Role: role, } - return database.Queries.AddUser(ctx, params) + return s.storageService.AddUser(ctx, params) } func (s *realUserService) DeleteUser(ctx context.Context, user *repository.User) error { @@ -150,7 +152,7 @@ func (s *realUserService) formatAndValidateUsername(ctx context.Context, usernam } // already existing users - names, err := database.Queries.ListUserNames(ctx) + names, err := s.storageService.ListUserNames(ctx) if err != nil { random = true if !force { @@ -181,7 +183,7 @@ func (s *realUserService) ListUsers(ctx context.Context, offset, limit int32) ([ if limit > 200 { limit = 200 } - return database.Queries.ListUsers(ctx, offset, limit) + return s.storageService.ListUsers(ctx, offset, limit) } func (s *realUserService) GetUsernameBlacklist() []string {