Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions dryad/conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package conf
import (
"fmt"
"io"
"io/ioutil"

"github.com/BurntSushi/toml"

Expand Down Expand Up @@ -84,9 +83,9 @@ func (g *General) Marshal(w io.Writer) error {
return toml.NewEncoder(w).Encode(g)
}

// Unmarshal reads TOML representation from r and parses it into g.
// Unmarshal reads TOML representation from r and parses it into g. Function may panic (e.g. when reader is nil).
func (g *General) Unmarshal(r io.Reader) error {
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
if err != nil {
return err
}
Expand Down
29 changes: 0 additions & 29 deletions dryad/conf/conf_suite_test.go

This file was deleted.

112 changes: 84 additions & 28 deletions dryad/conf/conf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,29 @@
* limitations under the License
*/

package conf_test
package conf

import (
"bytes"
"errors"
"io"
"strings"
"testing"

"github.com/SamsungSLAV/boruta"
. "github.com/SamsungSLAV/boruta/dryad/conf"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)

var _ = Describe("Conf", func() {
marshaled := `listen_address = ":7175"
type brokenReader struct {
io.Reader
}

func (*brokenReader) Read(_ []byte) (n int, err error) {
return 0, errors.New("broken reader")
}

var (
marshaled = `listen_address = ":7175"
boruta_address = ""
ssh_address = ":22"
sdcard = "/dev/sdX"
Expand All @@ -40,7 +48,13 @@ stm_path = "/run/stm.socket"
name = "boruta-user"
groups = []
`
unmarshaled := &General{
empty = `listen_address = ""
boruta_address = ""
ssh_address = ""
sdcard = ""
stm_path = ""
`
unmarshaled = &General{
Address: ":7175",
SSHAdress: ":22",
Caps: boruta.Capabilities(map[string]string{}),
Expand All @@ -51,27 +65,69 @@ stm_path = "/run/stm.socket"
SDcard: "/dev/sdX",
STMsocket: "/run/stm.socket",
}
var g *General
)

BeforeEach(func() {
g = NewConf()
})
func TestNewConf(t *testing.T) {
assert.Equal(t, NewConf(), unmarshaled)
}

It("should initially have default configuration", func() {
Expect(g).To(Equal(unmarshaled))
})
func TestMarshal(t *testing.T) {
testCases := [...]struct {
name string
conf *General
str string
err error
}{
{name: "valid", conf: NewConf(), str: marshaled, err: nil},
{name: "empty", conf: new(General), str: empty, err: nil},
{name: "nil", conf: nil, str: "", err: nil},
}
assert := assert.New(t)

It("should encode default configuration", func() {
var w bytes.Buffer
g.Marshal(&w)
result := w.String()
Expect(result).ToNot(BeEmpty())
Expect(result).To(Equal(marshaled))
})
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
var b bytes.Buffer
assert.ErrorIs(test.conf.Marshal(&b), test.err)
assert.Equal(b.String(), test.str)
})
}
}

It("should decode default configuration", func() {
g = new(General)
g.Unmarshal(strings.NewReader(marshaled))
Expect(g).To(Equal(unmarshaled))
})
})
func TestUnmarshal(t *testing.T) {
testCases := [...]struct {
name string
conf *General
read io.Reader
err error
panics bool
}{
{name: "valid", conf: NewConf(), read: strings.NewReader(marshaled), err: nil},
{name: "invalid", conf: new(General), read: strings.NewReader(`/4`), err: errors.New(`toml: line 1: expected '.' or '=', but got '/' instead`)},
{name: "empty", conf: new(General), read: strings.NewReader(empty), err: nil},
{name: "brokenReader", conf: new(General), read: new(brokenReader), err: errors.New("broken reader")},
{name: "nil", conf: new(General), read: nil, err: nil, panics: true},
}
assert := assert.New(t)

for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
var err error
g := new(General)
if test.panics {
assert.Panics(func() { err = g.Unmarshal(test.read) })
} else {
assert.NotPanics(func() { err = g.Unmarshal(test.read) })
}
if test.err != nil {
assert.ErrorContains(err, test.err.Error())
} else {
assert.NoError(err)
}
assert.Equal(g, test.conf)
})
}
}
105 changes: 71 additions & 34 deletions filter/filter_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2017-2018 Samsung Electronics Co., Ltd All Rights Reserved
* Copyright (c) 2017-2022 Samsung Electronics Co., Ltd All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@
package filter

import (
"fmt"
"testing"

"github.com/SamsungSLAV/boruta"
Expand Down Expand Up @@ -76,32 +77,36 @@ func TestNewRequest(t *testing.T) {
}

for _, tcase := range newRequestTests {
filter := NewRequests(tcase.ids, tcase.priorities, tcase.states)
assert.NotNil(filter, tcase.name)
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
filter := NewRequests(tcase.ids, tcase.priorities, tcase.states)
assert.NotNil(filter, tcase.name)

// Verify IDs.
assert.Len(filter.IDs, len(tcase.ids), tcase.name)
if len(tcase.ids) > 0 {
assert.Equal(tcase.ids, filter.IDs, tcase.name)
} else {
assert.Nil(filter.IDs)
}
// Verify IDs.
assert.Len(filter.IDs, len(tcase.ids), tcase.name)
if len(tcase.ids) > 0 {
assert.Equal(tcase.ids, filter.IDs, tcase.name)
} else {
assert.Nil(filter.IDs)
}

// Verify Priorities.
assert.Len(filter.Priorities, len(tcase.priorities), tcase.name)
if len(tcase.priorities) > 0 {
assert.Equal(tcase.priorities, filter.Priorities, tcase.name)
} else {
assert.Nil(filter.Priorities)
}
// Verify Priorities.
assert.Len(filter.Priorities, len(tcase.priorities), tcase.name)
if len(tcase.priorities) > 0 {
assert.Equal(tcase.priorities, filter.Priorities, tcase.name)
} else {
assert.Nil(filter.Priorities)
}

// Verify States.
assert.Len(filter.States, len(tcase.states), tcase.name)
if len(tcase.states) > 0 {
assert.Equal(tcase.expectedStates, filter.States, tcase.name)
} else {
assert.Nil(filter.States)
}
// Verify States.
assert.Len(filter.States, len(tcase.states), tcase.name)
if len(tcase.states) > 0 {
assert.Equal(tcase.expectedStates, filter.States, tcase.name)
} else {
assert.Nil(filter.States)
}
})
}
}

Expand Down Expand Up @@ -203,17 +208,32 @@ func TestRequestMatch(t *testing.T) {
},
}

var filter Requests
makeName := func(states []boruta.ReqState, priorities []boruta.Priority, ids []boruta.ReqID) string {
return fmt.Sprintf("States: %v, Priorities: %v, ISs: %v", states, priorities, ids)
}
makeFilter := func(states []boruta.ReqState, priorities []boruta.Priority, ids []boruta.ReqID) Requests {
return Requests{
States: states,
Priorities: priorities,
IDs: ids,
}
}

for _, stest := range statesTests {
filter.States = stest.states
stest := stest
for _, ptest := range priorityTests {
filter.Priorities = ptest.priorities
ptest := ptest
for _, idstest := range idsTests {
filter.IDs = idstest.ids
assert.Equal(stest.result && ptest.result && idstest.result, filter.Match(&req))
idtest := idstest
filter := makeFilter(stest.states, ptest.priorities, idtest.ids)
t.Run(makeName(filter.States, filter.Priorities, filter.IDs), func(t *testing.T) {
t.Parallel()
assert.Equal(stest.result && ptest.result && idtest.result, filter.Match(&req))
})
}
}
}
var filter Requests
assert.False(filter.Match(nil))
assert.False(filter.Match(5))
}
Expand Down Expand Up @@ -261,63 +281,80 @@ func TestWorkerMatch(t *testing.T) {
other := boruta.Group("other")

var tests = [...]struct {
name string
worker *boruta.WorkerInfo
filter *Workers
result bool
}{
{
name: "NilGroupsAndCaps",
worker: newWorker(groups(all), caps("armv7", "true")),
filter: NewWorkers(nil, nil),
result: true,
},
{
name: "EmptyGroupsAndNilCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(groups(empty), nil),
result: false,
},
{
name: "NilGroupsAndMatchingCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(nil, caps("aarch64", "true")),
result: true,
},
{
name: "NilGroupsAndDefaultCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(nil, make(boruta.Capabilities)),
result: true,
},
{
name: "DefaultGroupsAndNilCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(make(boruta.Groups, 0), nil),
result: true,
},
{
name: "MatchingOtherGroupsAndMatchingCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(groups(all, other), caps("aarch64", "true")),
result: true,
},
{
name: "NotMatchingGroupsAndMatchingCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(groups(other), caps("aarch64", "true")),
result: false,
},
{
name: "MatchingAllGroupsAndMatchingCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(groups(all, other), caps("aarch64", "false")),
result: false,
},
{
name: "MatchingAllGroupsAndNotMatchingCaps",
worker: newWorker(groups(all, some), caps("aarch64", "true")),
filter: NewWorkers(groups(all, other),
boruta.Capabilities{"foo": "bar"}),
result: false,
},
{
name: "DefaultFilterNil",
worker: nil,
filter: new(Workers),
result: false,
},
}

for _, tcase := range tests {
assert.Equal(tcase.result, tcase.filter.Match(tcase.worker))
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
assert.Equal(tcase.result, tcase.filter.Match(tcase.worker))
})
}

filter := new(Workers)
assert.False(filter.Match(nil))
assert.False(filter.Match(5))
assert.False(new(Workers).Match(5), "WrongType")
}
Loading