diff --git a/api/account.go b/api/account.go index 78c91d7..66ab996 100644 --- a/api/account.go +++ b/api/account.go @@ -2,16 +2,17 @@ package api import ( "database/sql" + "errors" "net/http" "github.com/gin-gonic/gin" "github.com/lib/pq" db "github.com/sinachaichi/gault/db/sqlc" + "github.com/sinachaichi/gault/token" ) type createAccountRequest struct { - Owner string `json:"owner" binding:"required"` Currency string `json:"currency" binding:"required,currency"` } @@ -31,8 +32,9 @@ func (server *Server) createAccount(ctx *gin.Context){ return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.CreateAccountParams{ - Owner: req.Owner, + Owner: authPayload.Username, Currency: req.Currency, Balance: 0, } @@ -72,6 +74,12 @@ func (server *Server) getAccount(ctx *gin.Context){ return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Owner != authPayload.Username { + err := errors.New("account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } ctx.JSON(http.StatusOK, account) } @@ -83,7 +91,9 @@ func (server *Server) listAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.ListAccountsParams{ + Owner: authPayload.Username, Limit: req.PageSize, Offset: (req.PageID - 1) * req.PageSize, } diff --git a/api/account_test.go b/api/account_test.go index a5f8b80..2cbf694 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -9,21 +9,22 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" - "github.com/lib/pq" mockdb "github.com/sinachaichi/gault/db/mock" db "github.com/sinachaichi/gault/db/sqlc" + "github.com/sinachaichi/gault/token" "github.com/sinachaichi/gault/util" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -func randomAccount() db.Account { +func randomAccount(owner string) db.Account { return db.Account{ ID: util.RandomInt(1, 1000), - Owner: util.RandomOwner(), + Owner: owner, Balance: util.RandomMoney(), Currency: util.RandomCurrency(), } @@ -41,94 +42,89 @@ func requireBodyMatchAccount(t *testing.T, body *bytes.Buffer, account db.Accoun } func TestCreateAccountAPI(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccount(user.Username) testCases := []struct { name string body gin.H + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) - checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + checkResponse func(recoder *httptest.ResponseRecorder) }{ { name: "OK", body: gin.H{ - "owner": account.Owner, "currency": account.Currency, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.CreateAccountParams{ Owner: account.Owner, Currency: account.Currency, Balance: 0, } + store.EXPECT(). CreateAccount(gomock.Any(), gomock.Eq(arg)). Times(1). Return(account, nil) }, - checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + checkResponse: func(recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusOK, recorder.Code) requireBodyMatchAccount(t, recorder.Body, account) }, }, { - name: "InternalError", + name: "NoAuthorization", body: gin.H{ - "owner": account.Owner, "currency": account.Currency, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). - Times(1). - Return(db.Account{}, sql.ErrConnDone) + Times(0) }, - checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusInternalServerError, recorder.Code) + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) }, }, { - name: "DuplicateOwnerCurrency", + name: "InternalError", body: gin.H{ - "owner": account.Owner, "currency": account.Currency, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). Times(1). - Return(db.Account{}, &pq.Error{Code: "23505"}) + Return(db.Account{}, sql.ErrConnDone) }, - checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusForbidden, recorder.Code) + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusInternalServerError, recorder.Code) }, }, { name: "InvalidCurrency", body: gin.H{ - "owner": account.Owner, - "currency": "XYZ", + "currency": "invalid", }, - buildStubs: func(store *mockdb.MockStore) { - store.EXPECT(). - CreateAccount(gomock.Any(), gomock.Any()). - Times(0) - }, - checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusBadRequest, recorder.Code) - }, - }, - { - name: "MissingOwner", - body: gin.H{ - "currency": account.Currency, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). Times(0) }, - checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + checkResponse: func(recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusBadRequest, recorder.Code) }, }, @@ -147,30 +143,34 @@ func TestCreateAccountAPI(t *testing.T) { server := newTestServer(t, store) recorder := httptest.NewRecorder() + // Marshal body data to JSON data, err := json.Marshal(tc.body) require.NoError(t, err) - request, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(data)) + url := "/accounts" + request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) require.NoError(t, err) - request.Header.Set("Content-Type", "application/json") + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) - tc.checkResponse(t, recorder) + tc.checkResponse(recorder) }) } } func TestListAccountAPI(t *testing.T) { + user, _ := randomUser(t) n := 5 accounts := make([]db.Account, n) for i := 0; i < n; i++ { - accounts[i] = randomAccount() + accounts[i] = randomAccount(user.Username) } testCases := []struct { name string pageID int pageSize int + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) }{ @@ -178,8 +178,12 @@ func TestListAccountAPI(t *testing.T) { name: "OK", pageID: 1, pageSize: 5, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.ListAccountsParams{ + Owner: user.Username, Limit: 5, Offset: 0, } @@ -196,6 +200,9 @@ func TestListAccountAPI(t *testing.T) { name: "InternalError", pageID: 1, pageSize: 5, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -210,6 +217,9 @@ func TestListAccountAPI(t *testing.T) { name: "InvalidPageID", pageID: 0, pageSize: 5, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -223,6 +233,9 @@ func TestListAccountAPI(t *testing.T) { name: "InvalidPageSize", pageID: 1, pageSize: 2, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -251,6 +264,7 @@ func TestListAccountAPI(t *testing.T) { request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) @@ -258,17 +272,22 @@ func TestListAccountAPI(t *testing.T) { } func TestGetAccountAPI(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccount(user.Username) testCases := []struct { name string accountID int64 + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recoder *httptest.ResponseRecorder) }{ { name: "OK", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -279,10 +298,42 @@ func TestGetAccountAPI(t *testing.T) { require.Equal(t, http.StatusOK, recorder.Code) requireBodyMatchAccount(t, recorder.Body, account) }, + }, + { + name: "UnauthorizedUser", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute) + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "NoAuthorization", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){}, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, }, { name: "NotFound", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -296,6 +347,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InternalError", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -309,6 +363,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InvalidID", accountID: 0, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Any()). @@ -336,7 +393,8 @@ func TestGetAccountAPI(t *testing.T) { url := fmt.Sprintf("/accounts/%d", tc.accountID) request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) - + + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..6bd2a5b --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,57 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/sinachaichi/gault/token" +) + + +const ( + authorizationHeaderKey = "authorization" + authorizationTypeBearer = "bearer" + authorizationPayloadKey = "authorization_payload" +) + + +func authMiddleWare(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authorizationHeaderKey := ctx.GetHeader(authorizationHeaderKey) + fmt.Println(authorizationHeaderKey) + fmt.Println("*************************************") + if len(authorizationHeaderKey) == 0 { + err := errors.New("authorization key is not provided") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + fields := strings.Fields(authorizationHeaderKey) + if len(fields) < 2 { + err := errors.New("invalid authorization header format") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + err := fmt.Errorf("unsupported authorization type %s", authorizationType) + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authorizationPayloadKey, payload) + ctx.Next() + + } +} \ No newline at end of file diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..3dc1297 --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,101 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/sinachaichi/gault/token" + "github.com/stretchr/testify/require" +) + + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + username string, + duration time.Duration, +) { + token, err := tokenMaker.CreateToken(username, duration) + require.NoError(t, err) + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) + request.Header.Set(authorizationHeaderKey, authorizationHeader) +} +func TestAuthMiddleWare(t *testing.T) { + testCases := []struct{ + name string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) + checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + name: "Ok", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "username", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "No Authorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){}, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "Unsupported Authorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, "unsupported type", "username", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "Invalid Authorization Format", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, "", "username", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "Expired Authorization Token", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker){ + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "username", -time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder){ + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + t.Run(tc.name, func(t *testing.T){ + server := newTestServer(t, nil) + authPath := "/auth" + server.router.GET( + authPath, + authMiddleWare(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + + }) + } +} \ No newline at end of file diff --git a/api/server.go b/api/server.go index 5975b90..d18e968 100644 --- a/api/server.go +++ b/api/server.go @@ -26,11 +26,13 @@ func (server *Server) setupRouter() { router.POST("/users", server.createUser) router.POST("/users/login", server.loginUser) - router.POST("/accounts", server.createAccount) - router.GET("/accounts/:id", server.getAccount) - router.GET("/accounts", server.listAccount) + authRoutes := router.Group("/").Use(authMiddleWare(server.tokenMaker)) - router.POST("/transfers", server.createTransfer) + authRoutes.POST("/accounts", server.createAccount) + authRoutes.GET("/accounts/:id", server.getAccount) + authRoutes.GET("/accounts", server.listAccount) + + authRoutes.POST("/transfers", server.createTransfer) server.router = router } diff --git a/api/transfer.go b/api/transfer.go index fcd38ed..14d9375 100644 --- a/api/transfer.go +++ b/api/transfer.go @@ -2,11 +2,13 @@ package api import ( "database/sql" + "errors" "fmt" "net/http" "github.com/gin-gonic/gin" db "github.com/sinachaichi/gault/db/sqlc" + "github.com/sinachaichi/gault/token" ) type transferRequest struct { @@ -17,25 +19,25 @@ type transferRequest struct { } -func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) bool { +func (server *Server) validAccount(ctx *gin.Context, accountID int64, currency string) (db.Account, bool) { account, err := server.store.GetAccount(ctx, accountID) if err != nil { if err == sql.ErrNoRows { ctx.JSON(http.StatusNotFound, errorResponse(err)) - return false + return account, false } ctx.JSON(http.StatusInternalServerError, errorResponse(err)) - return false + return account, false } if account.Currency != currency { err := fmt.Errorf("account [%d] currency mismatch: %s vs %s", account.ID, account.Currency, currency) ctx.JSON(http.StatusBadRequest, errorResponse(err)) - return false + return account, false } - return true + return account, true } @@ -46,11 +48,22 @@ func (server *Server) createTransfer(ctx *gin.Context) { return } - if !server.validAccount(ctx, req.FromAccountID, req.Currency) { + fromAccount, valid := server.validAccount(ctx, req.FromAccountID, req.Currency) + + if !valid { return } - if !server.validAccount(ctx, req.ToAccountID, req.Currency) { + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if fromAccount.Owner != authPayload.Username { + err := errors.New("account doesn't blong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + _, valid = server.validAccount(ctx, req.ToAccountID, req.Currency) + + if !valid { return } diff --git a/api/transfer_test.go b/api/transfer_test.go index 37d0dc0..12c9f7d 100644 --- a/api/transfer_test.go +++ b/api/transfer_test.go @@ -7,10 +7,12 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" mockdb "github.com/sinachaichi/gault/db/mock" db "github.com/sinachaichi/gault/db/sqlc" + "github.com/sinachaichi/gault/token" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -18,9 +20,14 @@ import ( func TestCreateTransferAPI(t *testing.T) { amount := int64(10) - account1 := randomAccount() - account2 := randomAccount() - account3 := randomAccount() + user1, _ := randomUser(t) + account1 := randomAccount(user1.Username) + + user2, _ := randomUser(t) + account2 := randomAccount(user2.Username) + + user3, _ := randomUser(t) + account3 := randomAccount(user3.Username) account1.Currency = "USD" account2.Currency = "USD" @@ -29,6 +36,7 @@ func TestCreateTransferAPI(t *testing.T) { testCases := []struct { name string body gin.H + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) }{ @@ -40,6 +48,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -70,6 +81,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -90,6 +104,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -113,6 +130,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user3.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account3.ID)). @@ -133,6 +153,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -156,6 +179,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "XYZ", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0) store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0) @@ -172,6 +198,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": -amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0) store.EXPECT().TransferTx(gomock.Any(), gomock.Any()).Times(0) @@ -188,6 +217,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -208,6 +240,9 @@ func TestCreateTransferAPI(t *testing.T) { "amount": amount, "currency": "USD", }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user1.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account1.ID)). @@ -248,6 +283,7 @@ func TestCreateTransferAPI(t *testing.T) { require.NoError(t, err) request.Header.Set("Content-Type", "application/json") + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) diff --git a/db/query/account.sql b/db/query/account.sql index 9e9de38..9aff937 100644 --- a/db/query/account.sql +++ b/db/query/account.sql @@ -13,9 +13,10 @@ WHERE id = $1 LIMIT 1; -- name: ListAccounts :many SELECT * FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2; +LIMIT $2 +OFFSET $3; -- name: UpdateAccount :one UPDATE accounts diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index f73155c..e75f572 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -112,18 +112,20 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e const listAccounts = `-- name: ListAccounts :many SELECT id, owner, balance, currency, created_at FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2 +LIMIT $2 +OFFSET $3 ` type ListAccountsParams struct { - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + Owner string `json:"owner"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.QueryContext(ctx, listAccounts, arg.Limit, arg.Offset) + rows, err := q.db.QueryContext(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index d5dc903..45f7479 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -102,20 +102,23 @@ func TestWithTx(t *testing.T) { } func TestListAccounts(t *testing.T) { + var lastAccount Account for i := 0; i < 10; i++ { - createRandomAccount(t) + lastAccount = createRandomAccount(t) } arg := ListAccountsParams{ + Owner: lastAccount.Owner, Limit: 5, - Offset: 5, + Offset: 0, } accounts, err := testQueries.ListAccounts(context.Background(), arg) require.NoError(t, err) - require.Len(t, accounts, 5) + require.NotEmpty(t, accounts) for _, account := range accounts { require.NotEmpty(t, account) + require.Equal(t, account.Owner, lastAccount.Owner) } } \ No newline at end of file