diff --git a/api/http/api_devauth.go b/api/http/api_devauth.go index 1ea34592..24d8f746 100644 --- a/api/http/api_devauth.go +++ b/api/http/api_devauth.go @@ -52,6 +52,7 @@ const ( uriTenantDevicesCount = "/api/internal/v1/devauth/tenants/#tid/devices/count" // management API v2 + v2uriBulkDevices = "/api/management/v2/devauth/bulk/devices" v2uriDevices = "/api/management/v2/devauth/devices" v2uriDevicesCount = "/api/management/v2/devauth/devices/count" v2uriDevicesSearch = "/api/management/v2/devauth/devices/search" @@ -76,16 +77,22 @@ var ( type DevAuthApiHandlers struct { devAuth devauth.App db store.DataStore + limits Limits } type DevAuthApiStatus struct { Status string `json:"status"` } -func NewDevAuthApiHandlers(devAuth devauth.App, db store.DataStore) ApiHandler { +type Limits struct { + MaxPreAuthElements int +} + +func NewDevAuthApiHandlers(devAuth devauth.App, db store.DataStore, limits Limits) ApiHandler { return &DevAuthApiHandlers{ devAuth: devAuth, db: db, + limits: limits, } } @@ -112,6 +119,7 @@ func (d *DevAuthApiHandlers) GetApp() (rest.App, error) { rest.Get(v2uriDevices, d.GetDevicesV2Handler), rest.Post(v2uriDevicesSearch, d.SearchDevicesV2Handler), rest.Post(v2uriDevices, d.PostDevicesV2Handler), + rest.Post(v2uriBulkDevices, d.PostBulkDevicesV2Handler), rest.Get(v2uriDevice, d.GetDeviceV2Handler), rest.Delete(v2uriDevice, d.DeleteDeviceHandler), rest.Delete(v2uriDeviceAuthSet, d.DeleteDeviceAuthSetHandler), @@ -238,6 +246,47 @@ func (d *DevAuthApiHandlers) SubmitAuthRequestHandler(w rest.ResponseWriter, r * } } +func (d *DevAuthApiHandlers) PostBulkDevicesV2Handler(w rest.ResponseWriter, r *rest.Request) { + ctx := r.Context() + + l := log.FromContext(ctx) + + reqs, err := parsePreAuthReqs(r.Body) + if err != nil { + err = errors.Wrap(err, "failed to decode preauth request") + rest_utils.RestErrWithLog(w, r, l, err, http.StatusBadRequest) + return + } + + count := 0 + for _, req := range reqs { + if count > d.limits.MaxPreAuthElements { + break + } + reqDbModel, err := req.getDbModel() + if err != nil { + rest_utils.RestErrWithLogInternal(w, r, l, err) + return + } + + device, err := d.devAuth.PreauthorizeDevice(ctx, reqDbModel) + switch err { + case nil: + w.Header().Set("Location", "devices/"+reqDbModel.DeviceId) + case devauth.ErrDeviceExists: + l.Error(err) + w.WriteHeader(http.StatusConflict) + _ = w.WriteJson(device) + return + default: + rest_utils.RestErrWithLogInternal(w, r, l, err) + return + } + count++ + } + w.WriteHeader(http.StatusCreated) +} + func (d *DevAuthApiHandlers) PostDevicesV2Handler(w rest.ResponseWriter, r *rest.Request) { ctx := r.Context() diff --git a/api/http/api_devauth_test.go b/api/http/api_devauth_test.go index f1a54e32..42ef35b3 100644 --- a/api/http/api_devauth_test.go +++ b/api/http/api_devauth_test.go @@ -69,7 +69,8 @@ func runTestRequest(t *testing.T, handler http.Handler, req *http.Request, code } func makeMockApiHandler(t *testing.T, da devauth.App, db store.DataStore) http.Handler { - handlers := NewDevAuthApiHandlers(da, db) + defaultLimits := 128 + handlers := NewDevAuthApiHandlers(da, db, Limits{MaxPreAuthElements: defaultLimits}) assert.NotNil(t, handlers) app, err := handlers.GetApp() @@ -395,6 +396,176 @@ func (d *DevicePreauthReturnID) CheckHeaders(t *testing.T, recorded *test.Record assert.Contains(t, recorded.Recorder.HeaderMap["Location"][0], "devices/") } +func TestApiV2DevAuthBulkPreauthDevice(t *testing.T) { + t.Parallel() + + // enforce specific field naming in errors returned by API + updateRestErrorFieldName() + + pubkeyStr := mtest.LoadPubKeyStr("testdata/public.pem") + + type brokenPreAuthReq struct { + IdData string `json:"identity_data"` + PubKey string `json:"pubkey"` + } + + testCases := map[string]struct { + body interface{} + + devAuthErr error + outDev *model.Device + + callApp bool + + checker mt.ResponseChecker + }{ + "ok": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + callApp: true, + checker: mt.NewJSONResponse( + http.StatusCreated, + nil, + nil), + }, + "ok - verify Location header": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + callApp: true, + checker: NewJSONResponseIDChecker( + http.StatusCreated, + map[string]string{"Location": "devices/somegeneratedid"}, + nil), + }, + "invalid: id data is not json": { + body: []brokenPreAuthReq{ + { + IdData: `"sn":"0001"`, + PubKey: pubkeyStr, + }, + }, + checker: mt.NewJSONResponse( + http.StatusBadRequest, + nil, + restError("failed to decode preauth request: json: cannot unmarshal string into Go struct field preAuthReq.identity_data of type map[string]interface {}")), + }, + "invalid: no id data": { + body: []preAuthReq{ + { + PubKey: pubkeyStr, + }, + }, + checker: mt.NewJSONResponse( + http.StatusBadRequest, + nil, + restError("failed to decode preauth request: identity_data: cannot be blank.")), + }, + "invalid: no pubkey": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + }, + }, + checker: mt.NewJSONResponse( + http.StatusBadRequest, + nil, + restError("failed to decode preauth request: pubkey: cannot be blank.")), + }, + "invalid: no body": { + checker: mt.NewJSONResponse( + http.StatusBadRequest, + nil, + restError("failed to decode preauth request: EOF")), + }, + "invalid public key": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: "invalid", + }, + }, + devAuthErr: devauth.ErrDeviceExists, + checker: mt.NewJSONResponse( + http.StatusBadRequest, + nil, + restError("failed to decode preauth request: cannot decode public key")), + }, + "devauth: device exists": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + devAuthErr: devauth.ErrDeviceExists, + outDev: &model.Device{Id: "foo"}, + callApp: true, + checker: mt.NewJSONResponse( + http.StatusConflict, + nil, + model.Device{Id: "foo"}), + }, + "devauth: generic error": { + body: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + callApp: true, + devAuthErr: errors.New("generic"), + checker: mt.NewJSONResponse( + http.StatusInternalServerError, + nil, + restError("internal error")), + }, + } + + for name, tc := range testCases { + t.Run(fmt.Sprintf("tc %s", name), func(t *testing.T) { + da := &mocks.App{} + if tc.callApp { + da.On("PreauthorizeDevice", + mtest.ContextMatcher(), + mock.AnythingOfType("*model.PreAuthReq")). + Return(tc.outDev, tc.devAuthErr) + } + + apih := makeMockApiHandler(t, da, nil) + + //make request + req := makeReq("POST", + "http://1.2.3.4/api/management/v2/devauth/bulk/devices", + "", + tc.body) + + recorded := test.RunRequest(t, apih, req) + mt.CheckResponse(t, tc.checker, recorded) + da.AssertExpectations(t) + }) + } +} + func TestApiV2DevAuthPreauthDevice(t *testing.T) { t.Parallel() diff --git a/api/http/model_pre_ath_req_test.go b/api/http/model_pre_ath_req_test.go new file mode 100644 index 00000000..76a6381c --- /dev/null +++ b/api/http/model_pre_ath_req_test.go @@ -0,0 +1,80 @@ +// Copyright 2018 Northern.tech AS +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package http + +import ( + "encoding/json" + "fmt" + mtest "github.com/mendersoftware/deviceauth/utils/testing" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParsePreAuthReqs(t *testing.T) { + t.Parallel() + + pubkeyStr := mtest.LoadPubKeyStr("testdata/public.pem") + + testCases := map[string]struct { + input interface{} + }{ + "ok": { + input: []preAuthReq{ + { + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + }, + } + for name, tc := range testCases { + t.Run(fmt.Sprintf("tc %s", name), func(t *testing.T) { + data, _ := json.Marshal(tc.input) + req, err := parsePreAuthReqs(strings.NewReader(string(data))) + assert.NoError(t, err) + assert.Equal(t, tc.input, req) + }) + } +} + +func TestParsePreAuthReq(t *testing.T) { + t.Parallel() + + pubkeyStr := mtest.LoadPubKeyStr("testdata/public.pem") + + testCases := map[string]struct { + input interface{} + }{ + "ok": { + input: &preAuthReq{ + IdData: map[string]interface{}{ + "sn": "0001", + }, + PubKey: pubkeyStr, + }, + }, + } + for name, tc := range testCases { + t.Run(fmt.Sprintf("tc %s", name), func(t *testing.T) { + data, _ := json.Marshal(tc.input) + req, err := parsePreAuthReq(strings.NewReader(string(data))) + assert.NoError(t, err) + assert.Equal(t, tc.input, req) + }) + } +} diff --git a/api/http/model_pre_auth_req.go b/api/http/model_pre_auth_req.go index 337d9f78..289e050c 100644 --- a/api/http/model_pre_auth_req.go +++ b/api/http/model_pre_auth_req.go @@ -45,6 +45,24 @@ func parsePreAuthReq(source io.Reader) (*preAuthReq, error) { return &req, nil } +func parsePreAuthReqs(source io.Reader) ([]preAuthReq, error) { + jd := json.NewDecoder(source) + + var req []preAuthReq + + if err := jd.Decode(&req); err != nil { + return nil, err + } + + for _, r := range req { + if err := r.validate(); err != nil { + return nil, err + } + } + + return req, nil +} + func (r *preAuthReq) validate() error { err := validation.ValidateStruct(r, validation.Field(&r.IdData, validation.Required), diff --git a/config/config.go b/config/config.go index dc2e09ea..646f96a7 100644 --- a/config/config.go +++ b/config/config.go @@ -82,6 +82,9 @@ const ( SettingRedisLimitsExpSec = "redis_limits_expire_sec" SettingRedisLimitsExpSecDefault = "1800" + SettingMaxPreAuthElements = "max_pre_auth_requests" + SettingMaxPreAuthElementsDefault = "128" + // SettingHaveAddons is a feature toggle for using addon restrictions. // Has no effect if not running in multi-tenancy context. SettingHaveAddons = "have_addons" @@ -112,6 +115,7 @@ var ( {Key: SettingRedisTimeoutSec, Value: SettingRedisTimeoutSecDefault}, {Key: SettingRedisDb, Value: SettingRedisDbDefault}, {Key: SettingRedisLimitsExpSec, Value: SettingRedisLimitsExpSecDefault}, + {Key: SettingMaxPreAuthElements, Value: SettingMaxPreAuthElementsDefault}, {Key: SettingHaveAddons, Value: SettingHaveAddonsDefault}, } ) diff --git a/devauth/devauth.go b/devauth/devauth.go index a03321c9..1af99003 100644 --- a/devauth/devauth.go +++ b/devauth/devauth.go @@ -108,6 +108,7 @@ type App interface { RejectDeviceAuth(ctx context.Context, dev_id string, auth_id string) error ResetDeviceAuth(ctx context.Context, dev_id string, auth_id string) error PreauthorizeDevice(ctx context.Context, req *model.PreAuthReq) (*model.Device, error) + PreauthorizeDevices(ctx context.Context, req []model.PreAuthReq) error RevokeToken(ctx context.Context, tokenID string) error VerifyToken(ctx context.Context, token string) error @@ -238,6 +239,49 @@ func (d *DevAuth) setDeviceIdentity(ctx context.Context, dev *model.Device, tena return nil } +func (d *DevAuth) setDeviceIdentities(ctx context.Context, devices []*model.Device, tenantId string) error { + for _, dev := range devices { + attributes := make([]model.DeviceAttribute, len(dev.IdDataStruct)) + i := 0 + for name, value := range dev.IdDataStruct { + if name == "status" { + //we have to forbid the client to override attribute status in identity scope + //since it stands for status of a device (as in: accepted, rejected, preauthorized) + continue + } + attribute := model.DeviceAttribute{ + Name: name, + Description: nil, + Value: value, + Scope: "identity", + } + attributes[i] = attribute + i++ + } + attrJson, err := json.Marshal(attributes) + if err != nil { + return errors.New("internal error: cannot marshal attributes into json") + } + if err := d.cOrch.SubmitUpdateDeviceInventoryJob( + ctx, + orchestrator.UpdateDeviceInventoryReq{ + RequestId: requestid.FromContext(ctx), + TenantId: tenantId, + DeviceId: dev.Id, + Scope: "identity", + Attributes: string(attrJson), + }); err != nil { + return errors.Wrap(err, "failed to start device inventory update job") + } + if d.config.EnableReporting { + if err := d.cOrch.SubmitReindexReporting(ctx, string(dev.Id)); err != nil { + return errors.Wrap(err, "reindex reporting job error") + } + } + } + return nil +} + func (d *DevAuth) getDeviceFromAuthRequest( ctx context.Context, r *model.AuthReq, @@ -1150,6 +1194,92 @@ func (d *DevAuth) PreauthorizeDevice( } } +func (d *DevAuth) PreauthorizeDevices( + ctx context.Context, + req []model.PreAuthReq, +) error { + // try add device, if a device with the given id_data exists - + // the unique index on id_data will prevent it (conflict) + // this is the only safeguard against id data conflict - we won't try to handle it + // additionally on inserting the auth set (can't add an id data index on auth set - would + // prevent key rotation) + + // FIXME: tenant_token is "" on purpose, will be removed + + dev := make([]*model.Device, len(req)) + authset := make([]model.AuthSet, len(req)) + inventoryDevices := make([]model.DeviceInventoryUpdate, len(req)) + for i, r := range req { + device := model.NewDevice(r.DeviceId, r.IdData, r.PubKey) + device.Status = model.DevStatusPreauth + + idDataStruct, idDataSha256, err := parseIdData(r.IdData) + if err != nil { + return MakeErrDevAuthBadRequest(err) + } + + device.IdDataStruct = idDataStruct + device.IdDataSha256 = idDataSha256 + device.Status = model.DevStatusPreauth + dev[i] = device + + // record authentication request + authset[i] = model.AuthSet{ + Id: r.AuthSetId, + IdData: r.IdData, + IdDataStruct: idDataStruct, + IdDataSha256: idDataSha256, + PubKey: r.PubKey, + DeviceId: r.DeviceId, + Status: model.DevStatusPreauth, + Timestamp: uto.TimePtr(time.Now()), + } + inventoryDevices[i] = model.DeviceInventoryUpdate{ + Id: device.Id, + Revision: device.Revision, + } + } + + err := d.db.AddDevices(ctx, dev) + switch err { + case nil: + break + case store.ErrObjectExists: + return ErrDeviceExists + default: + return errors.Wrap(err, "failed to add device") + } + + tenantId := "" + idData := identity.FromContext(ctx) + if idData != nil { + tenantId = idData.Tenant + } + + wfReq := orchestrator.UpdateDeviceStatusReq{ + RequestId: requestid.FromContext(ctx), + Devices: inventoryDevices, + TenantId: tenantId, + Status: model.DevStatusPreauth, + } + if err = d.cOrch.SubmitUpdateDeviceStatusJob(ctx, wfReq); err != nil { + return errors.Wrap(err, "update device status job error") + } + + err = d.db.AddAuthSets(ctx, authset) + switch err { + case nil: + if err := d.setDeviceIdentities(ctx, dev, tenantId); err != nil { + return err + } + return nil + case store.ErrObjectExists: + return ErrDeviceExists + default: + return errors.Wrap(err, "failed to add auth set") + } +} + func (d *DevAuth) RevokeToken(ctx context.Context, tokenID string) error { l := log.FromContext(ctx) tokenOID := oid.FromString(tokenID) diff --git a/server.go b/server.go index 97161dab..f739da5e 100644 --- a/server.go +++ b/server.go @@ -134,7 +134,7 @@ func RunServer(c config.Reader) error { return errors.Wrap(err, "API setup failed") } - devauthapi := api_http.NewDevAuthApiHandlers(devauth, db) + devauthapi := api_http.NewDevAuthApiHandlers(devauth, db, api_http.Limits{MaxPreAuthElements: c.GetInt(dconfig.SettingMaxPreAuthElements)}) apph, err := devauthapi.GetApp() if err != nil { diff --git a/store/datastore.go b/store/datastore.go index 2d3a1806..ce2a5554 100644 --- a/store/datastore.go +++ b/store/datastore.go @@ -80,6 +80,8 @@ type DataStore interface { AddDevice(ctx context.Context, d model.Device) error + AddDevices(ctx context.Context, d []*model.Device) error + // updates a single device with deviceID, using data from `up` UpdateDevice(ctx context.Context, deviceID string, up model.DeviceUpdate) error @@ -88,6 +90,8 @@ type DataStore interface { AddAuthSet(ctx context.Context, set model.AuthSet) error + AddAuthSets(ctx context.Context, set []model.AuthSet) error + GetAuthSetByIdDataHashKey( ctx context.Context, idDataHash []byte, diff --git a/store/mongo/datastore_mongo.go b/store/mongo/datastore_mongo.go index 3446da02..11ae6ea7 100644 --- a/store/mongo/datastore_mongo.go +++ b/store/mongo/datastore_mongo.go @@ -313,6 +313,27 @@ func (db *DataStoreMongo) AddDevice(ctx context.Context, d model.Device) error { return nil } +func (db *DataStoreMongo) AddDevices(ctx context.Context, d []*model.Device) error { + devices := make(bson.A, len(d)) + for i := range d { + if d[i].Id == "" { + uid := oid.NewUUIDv4() + d[i].Id = uid.String() + } + devices[i] = d[i] + } + + c := db.client.Database(ctxstore.DbFromContext(ctx, DbName)).Collection(DbDevicesColl) + + if _, err := c.InsertMany(ctx, devices); err != nil { + if strings.Contains(err.Error(), "duplicate key error") { + return store.ErrObjectExists + } + return errors.Wrap(err, "failed to store device") + } + return nil +} + func (db *DataStoreMongo) UpdateDevice(ctx context.Context, deviceID string, updev model.DeviceUpdate) error { @@ -570,6 +591,28 @@ func (db *DataStoreMongo) AddAuthSet(ctx context.Context, set model.AuthSet) err return nil } +func (db *DataStoreMongo) AddAuthSets(ctx context.Context, set []model.AuthSet) error { + c := db.client.Database(ctxstore.DbFromContext(ctx, DbName)).Collection(DbAuthSetColl) + + sets := make(bson.A, len(set)) + for i := range set { + if set[i].Id == "" { + uid := oid.NewUUIDv4() + set[i].Id = uid.String() + } + sets[i] = set[i] + } + + if _, err := c.InsertMany(ctx, sets); err != nil { + if strings.Contains(err.Error(), "duplicate key error") { + return store.ErrObjectExists + } + return errors.Wrap(err, "failed to store device") + } + + return nil +} + func (db *DataStoreMongo) GetAuthSetByIdDataHashKey( ctx context.Context, idDataHash []byte,