diff --git a/api/account_test.go b/api/account_test.go index 06ebc14..a5f8b80 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -144,7 +144,7 @@ func TestCreateAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() data, err := json.Marshal(tc.body) @@ -244,7 +244,7 @@ func TestListAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() url := fmt.Sprintf("/accounts?page_id=%d&page_size=%d", tc.pageID, tc.pageSize) @@ -330,7 +330,7 @@ func TestGetAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() url := fmt.Sprintf("/accounts/%d", tc.accountID) diff --git a/api/main_test.go b/api/main_test.go index 10b4f22..2e57174 100644 --- a/api/main_test.go +++ b/api/main_test.go @@ -3,10 +3,25 @@ package api import ( "os" "testing" + "time" "github.com/gin-gonic/gin" + db "github.com/sinachaichi/gault/db/sqlc" + "github.com/sinachaichi/gault/util" + "github.com/stretchr/testify/require" ) +func newTestServer(t *testing.T, store db.Store) *Server { + config := util.Config{ + TokenSymmetricKey: util.RandomString(32), + AccessTokenDuration: time.Minute, + } + + server, err := NewServer(config, store) + require.NoError(t, err) + + return server +} func TestMain(m *testing.M) { gin.SetMode(gin.TestMode) diff --git a/api/server.go b/api/server.go index bf7cbf1..5975b90 100644 --- a/api/server.go +++ b/api/server.go @@ -1,6 +1,11 @@ package api import ( + "fmt" + + "github.com/sinachaichi/gault/token" + "github.com/sinachaichi/gault/util" + "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" @@ -8,26 +13,46 @@ import ( ) type Server struct{ + config util.Config store db.Store + tokenMaker token.Maker router *gin.Engine } -func NewServer(store db.Store) *Server { - server := &Server{store: store} - router := gin.Default() +func (server *Server) setupRouter() { + router := gin.Default() + + router.POST("/users", server.createUser) + router.POST("/users/login", server.loginUser) + + router.POST("/accounts", server.createAccount) + router.GET("/accounts/:id", server.getAccount) + router.GET("/accounts", server.listAccount) + + router.POST("/transfers", server.createTransfer) + + server.router = router +} + + +func NewServer(config util.Config, store db.Store) (*Server, error) { + tokenMaker, err := token.NewJWTMaker(config.TokenSymmetricKey) + if err != nil { + return nil, fmt.Errorf("cannot create token maker: %w", err) + } + server := &Server{ + config: config, + store: store, + tokenMaker: tokenMaker, + } if v, ok := binding.Validator.Engine().(*validator.Validate); ok { v.RegisterValidation("currency", validCurrency) } - router.POST("/users", server.createUser) - router.POST("/accounts", server.createAccount) - router.GET("/accounts/:id", server.getAccount) - router.GET("/accounts", server.listAccount) - router.POST("/transfers", server.createTransfer) - server.router = router - return server + server.setupRouter() + return server, nil } func (server *Server) Start(address string) error { diff --git a/api/transfer_test.go b/api/transfer_test.go index f1019e5..37d0dc0 100644 --- a/api/transfer_test.go +++ b/api/transfer_test.go @@ -238,7 +238,7 @@ func TestCreateTransferAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() data, err := json.Marshal(tc.body) diff --git a/api/user.go b/api/user.go index 1ed936d..f72d564 100644 --- a/api/user.go +++ b/api/user.go @@ -1,6 +1,7 @@ package api import ( + "database/sql" "net/http" "time" @@ -17,7 +18,7 @@ type createUserRequest struct { Email string `json:"email" binding:"required,email"` } -type createUserResponse struct { +type userResponse struct { Username string `json:"username"` FullName string `json:"full_name"` Email string `json:"email"` @@ -26,6 +27,30 @@ type createUserResponse struct { } +type loginUserRequest struct { + Username string `json:"username" binding:"required,alphanum"` + Password string `json:"password" binding:"required,min=6"` +} + + +type loginUserResponse struct { + AccessToken string `json:"access_token"` + User userResponse `json:"user"` +} + + + +func newUserResponse(user db.User) userResponse { + return userResponse{ + Username: user.Username, + FullName: user.FullName, + Email: user.Email, + PasswordChangedAt: user.PasswordChangedAt, + CreatedAt: user.CreatedAt, + } +} + + func (server *Server) createUser(ctx *gin.Context) { var req createUserRequest if err := ctx.ShouldBindJSON(&req); err != nil { @@ -60,12 +85,46 @@ func (server *Server) createUser(ctx *gin.Context) { return } - rsp := createUserResponse{ - Username: user.Username, - FullName: user.FullName, - Email: user.Email, - PasswordChangedAt: user.PasswordChangedAt, - CreatedAt: user.CreatedAt, + rsp := newUserResponse(user) + ctx.JSON(http.StatusOK, rsp) +} + + +func (server *Server) loginUser(ctx *gin.Context) { + var req loginUserRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + user, err := server.store.GetUser(ctx, req.Username) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + err = util.CheckPassword(req.Password, user.HashedPassword) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken, err := server.tokenMaker.CreateToken( + user.Username, + server.config.AccessTokenDuration, + ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + rsp := loginUserResponse{ + AccessToken: accessToken, + User: newUserResponse(user), } ctx.JSON(http.StatusOK, rsp) } \ No newline at end of file diff --git a/api/user_test.go b/api/user_test.go index d4a5f24..55acd89 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -207,7 +207,7 @@ func TestCreateUserAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(store) + server := newTestServer(t, store) recorder := httptest.NewRecorder() // Marshal body data to JSON diff --git a/app.env b/app.env index 5451792..4a639d0 100644 --- a/app.env +++ b/app.env @@ -1,3 +1,5 @@ DB_DRIVER="postgres" DB_SOURCE="postgresql://root:secret@localhost:25432/bank_db?sslmode=disable" -SERVER_ADDRESS="0.0.0.0:8080" \ No newline at end of file +SERVER_ADDRESS="0.0.0.0:8080" +TOKEN_SYMMETRIC_KEY="12345678901234567890123456789012" +ACCESS_TOKEN_DURATION="15m" \ No newline at end of file diff --git a/go.mod b/go.mod index 5f6fa8b..ba12d7d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,13 @@ require ( go.uber.org/mock v0.6.0 ) +require ( + github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb // indirect + github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect + github.com/pkg/errors v0.8.0 // indirect +) + require ( github.com/bytedance/gopkg v0.1.4 // indirect github.com/bytedance/sonic v1.15.0 // indirect @@ -21,16 +28,19 @@ require ( github.com/gin-contrib/sse v1.1.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.30.2 // indirect + github.com/go-playground/validator/v10 v10.30.2 github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-json v0.10.6 // indirect github.com/goccy/go-yaml v1.19.2 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.6.0 github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.21 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/o1egl/paseto v1.0.0 github.com/pelletier/go-toml/v2 v2.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect @@ -46,10 +56,10 @@ require ( go.mongodb.org/mongo-driver/v2 v2.5.1 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/arch v0.26.0 // indirect - golang.org/x/crypto v0.50.0 // indirect + golang.org/x/crypto v0.51.0 golang.org/x/net v0.53.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 12c85cf..702bd41 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= +github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb h1:6Z/wqhPFZ7y5ksCEV/V5MXOazLaeu/EW97CU5rz8NWk= +github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb/go.mod h1:UzH9IX1MMqOcwhoNOIjmTQeAxrFgzs50j4golQtXXxU= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= github.com/bytedance/gopkg v0.1.4 h1:oZnQwnX82KAIWb7033bEwtxvTqXcYMxDBaQxo5JJHWM= github.com/bytedance/gopkg v0.1.4/go.mod h1:v1zWfPm21Fb+OsyXN2VAHdL6TBb2L88anLQgdyje6R4= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= @@ -33,9 +39,13 @@ github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= @@ -55,8 +65,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/o1egl/paseto v1.0.0 h1:bwpvPu2au176w4IBlhbyUv/S5VPptERIA99Oap5qUd0= +github.com/o1egl/paseto v1.0.0/go.mod h1:5HxsZPmw/3RI2pAwGo1HhOOwSdvBpcuVzO7uDkm+CLU= github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= @@ -81,6 +95,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -102,14 +117,16 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/arch v0.26.0 h1:jZ6dpec5haP/fUv1kLCbuJy6dnRrfX6iVK08lZBFpk4= golang.org/x/arch v0.26.0/go.mod h1:0X+GdSIP+kL5wPmpK7sdkEVTt2XoYP0cSjQSbZBwOi8= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.0.0-20181025213731-e84da0312774/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 31b7885..ce53d39 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,10 @@ func main() { } store := db.NewStore(conn) - server := api.NewServer(store) + server, err := api.NewServer(config, store) + if err != nil { + log.Fatal("cannot create server:", err) + } err = server.Start(config.ServerAddress) if err != nil { diff --git a/token/jwt_maker.go b/token/jwt_maker.go new file mode 100644 index 0000000..a0fe117 --- /dev/null +++ b/token/jwt_maker.go @@ -0,0 +1,59 @@ +package token + +import ( + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt" +) + +type JWTMaker struct { + secretKey string +} + + +const minSecretKeySize = 32 + +func NewJWTMaker(secretKey string) (Maker, error) { + if len(secretKey) < minSecretKeySize { + return nil, fmt.Errorf("invalid key size: must be at least %d characters", minSecretKeySize) + } + return &JWTMaker{secretKey}, nil +} + +func (maker *JWTMaker) CreateToken(username string, duration time.Duration) (string, error) { + payload, err := NewPayload(username, duration) + if err != nil { + return "", err + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + return jwtToken.SignedString([]byte(maker.secretKey)) +} + +func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { + keyFunc := func(token *jwt.Token) (interface{}, error) { + _, ok := token.Method.(*jwt.SigningMethodHMAC) + if !ok { + return nil, ErrInvalidToken + } + return []byte(maker.secretKey), nil + } + + jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc) + if err != nil { + verr, ok := err.(*jwt.ValidationError) + if ok && errors.Is(verr.Inner, ErrExpiredToken) { + return nil, ErrExpiredToken + } + return nil, ErrInvalidToken + } + payload, ok := jwtToken.Claims.(*Payload) + if !ok { + return nil, ErrInvalidToken + } + + return payload, nil +} + diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go new file mode 100644 index 0000000..b74e96f --- /dev/null +++ b/token/jwt_maker_test.go @@ -0,0 +1,67 @@ +package token + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt" + "github.com/sinachaichi/gault/util" + "github.com/stretchr/testify/require" +) + +func TestJWTMaker(t *testing.T) { + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := time.Minute + + issuedAt := time.Now() + expiredAt := issuedAt.Add(duration) + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.NoError(t, err) + require.NotEmpty(t, token) + + require.NotZero(t, payload.ID) + require.Equal(t, username, payload.Username) + require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) + require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) +} + + +func TestExpiredJWTToken(t *testing.T) { + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + token, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrExpiredToken.Error()) + require.Nil(t, payload) +} + + +func TestInvalidJWTTokenAlgNone(t *testing.T) { + payload, err := NewPayload(util.RandomOwner(), time.Minute) + require.NoError(t, err) + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload) + token, err := jwtToken.SignedString(jwt.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + + maker, err := NewJWTMaker(util.RandomString(32)) + require.NoError(t, err) + + payload, err = maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrInvalidToken.Error()) + require.Nil(t, payload) +} \ No newline at end of file diff --git a/token/maker.go b/token/maker.go new file mode 100644 index 0000000..2182faa --- /dev/null +++ b/token/maker.go @@ -0,0 +1,8 @@ +package token + +import "time" + +type Maker interface { + CreateToken(username string, duration time.Duration) (string, error) + VerifyToken(token string) (*Payload, error) +} \ No newline at end of file diff --git a/token/paseto_maker.go b/token/paseto_maker.go new file mode 100644 index 0000000..4f817e8 --- /dev/null +++ b/token/paseto_maker.go @@ -0,0 +1,56 @@ +package token + +import ( + "fmt" + "time" + + "github.com/aead/chacha20poly1305" + "github.com/o1egl/paseto" +) + + +type PasetoMaker struct { + paseto *paseto.V2 + symmetricKey []byte +} + + +func NewPasetoMaker(symmetricKey string) (Maker, error) { + if len(symmetricKey) != chacha20poly1305.KeySize { + return nil, fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize) + } + + maker := &PasetoMaker{ + paseto: paseto.NewV2(), + symmetricKey: []byte(symmetricKey), + } + + return maker, nil +} + + +func (maker *PasetoMaker) CreateToken(username string, duration time.Duration) (string, error) { + payload, err := NewPayload(username, duration) + if err != nil { + return "", err + } + + return maker.paseto.Encrypt(maker.symmetricKey, payload, nil) +} + +func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { + payload := &Payload{} + + err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil) + if err != nil { + return nil, ErrInvalidToken + } + + err = payload.Valid() + if err != nil { + return nil, err + } + + return payload, nil +} + diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go new file mode 100644 index 0000000..41ef528 --- /dev/null +++ b/token/paseto_maker_test.go @@ -0,0 +1,48 @@ +package token + +import ( + "testing" + "time" + + "github.com/sinachaichi/gault/util" + "github.com/stretchr/testify/require" +) + + +func TestPasetoMaker(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + username := util.RandomOwner() + duration := time.Minute + + issuedAt := time.Now() + expiredAt := issuedAt.Add(duration) + + token, err := maker.CreateToken(username, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.NoError(t, err) + require.NotEmpty(t, token) + + require.NotZero(t, payload.ID) + require.Equal(t, username, payload.Username) + require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) + require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) +} + +func TestExpiredPasetoToken(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + token, err := maker.CreateToken(util.RandomOwner(), -time.Minute) + require.NoError(t, err) + require.NotEmpty(t, token) + + payload, err := maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrExpiredToken.Error()) + require.Nil(t, payload) +} \ No newline at end of file diff --git a/token/payload.go b/token/payload.go new file mode 100644 index 0000000..48c9a5a --- /dev/null +++ b/token/payload.go @@ -0,0 +1,45 @@ +package token + +import ( + "errors" + "time" + + "github.com/google/uuid" +) + +type Payload struct { + ID uuid.UUID `json:"id"` + Username string `json:"username"` + IssuedAt time.Time `json:"issued_at"` + ExpiredAt time.Time `json:"expireed_at"` +} + +var ( + ErrInvalidToken = errors.New("token is invalid") + ErrExpiredToken = errors.New("token has expired") +) + + +func (payload *Payload) Valid() error { + if time.Now().After(payload.ExpiredAt) { + return ErrExpiredToken + } + return nil +} + + +func NewPayload(username string, duration time.Duration) (*Payload, error) { + tokenID, err := uuid.NewRandom() + if err != nil { + return nil, err + } + + paylod := &Payload{ + ID: tokenID, + Username: username, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(duration), + } + + return paylod, nil +} diff --git a/util/config.go b/util/config.go index 5ccb4f1..9401264 100644 --- a/util/config.go +++ b/util/config.go @@ -1,11 +1,17 @@ package util -import "github.com/spf13/viper" +import ( + "time" + + "github.com/spf13/viper" +) type Config struct { DBDriver string `mapstructure:"DB_DRIVER"` DBSource string `mapstructure:"DB_SOURCE"` ServerAddress string `mapstructure:"SERVER_ADDRESS"` + TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` + AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` }