diff --git a/sig/flags.go b/sig/flags.go new file mode 100644 index 000000000..6c01244a3 --- /dev/null +++ b/sig/flags.go @@ -0,0 +1,130 @@ +package sig + +import ( + "context" + "sync" +) + +//TODO: Implemet Flags. Be mindful of memory management: +// - Wait/WaitContext should resuse channels whenever possible +// - Wait/WaitContext should discard channels that are no longer needed + +// Flags provides a thread-safe observable flag set +type Flags struct { + flagStates map[string]struct{} + flagChannels map[string]chan struct{} + mu sync.Mutex +} + +func NewFlags() *Flags { + return &Flags{ + flagStates: make(map[string]struct{}), + flagChannels: make(map[string]chan struct{}), + } +} + +// Set sets the provided flags +func (flags *Flags) Set(flag ...string) { + flags.mu.Lock() + defer flags.mu.Unlock() + + for _, f := range flag { + if _, ok := flags.flagStates[f]; ok { + continue + } + + flags.flagStates[f] = struct{}{} + if ch, ok := flags.flagChannels[f]; ok { + close(ch) + delete(flags.flagChannels, f) + } + } +} + +// Clear clears the provided flags +func (flags *Flags) Clear(flag ...string) { + flags.mu.Lock() + defer flags.mu.Unlock() + + for _, f := range flag { + if _, ok := flags.flagStates[f]; !ok { + continue + } + + delete(flags.flagStates, f) + if ch, ok := flags.flagChannels[f]; ok { + close(ch) + delete(flags.flagChannels, f) + } + } +} + +// IsSet returns true if the flag is up, false otherwise +func (flags *Flags) IsSet(flag string) bool { + flags.mu.Lock() + defer flags.mu.Unlock() + + _, ok := flags.flagStates[flag] + return ok +} + +// Flags returns a list of all set flags +func (flags *Flags) Flags() []string { + flags.mu.Lock() + defer flags.mu.Unlock() + + var setFlags []string + for flag := range flags.flagStates { + setFlags = append(setFlags, flag) + } + return setFlags +} + +// Wait returns a channel that will be closed as soon as the flag is in the specified state. +// If the flag is already in the specified state, Wait immediately returns a closed channel. +func (flags *Flags) Wait(flag string, state bool) <-chan struct{} { + flags.mu.Lock() + defer flags.mu.Unlock() + + if _, ok := flags.flagStates[flag]; ok == state { + ch := make(chan struct{}) + close(ch) + return ch + } + + ch, ok := flags.flagChannels[flag] + if !ok { + ch = make(chan struct{}) + flags.flagChannels[flag] = ch + } + return ch +} + +// WaitContext waits until one of the following occurs: +// 1. Context is canceled - WaitContext returns ctx.Err() +// 2. Flag is in the specified state - WaitContext returns nil +// If the flag is already in the specified state when the function is called, it returns nil immediately. +func (flags *Flags) WaitContext(ctx context.Context, flag string, state bool) error { + flags.mu.Lock() + defer flags.mu.Unlock() + + ch, ok := flags.flagChannels[flag] + if !ok { + ch = make(chan struct{}) + flags.flagChannels[flag] = ch + } + + if curState, ok := flags.flagStates[flag]; ok { + if (curState == struct{}{} && !state) || (curState != struct{}{} && state) { + return nil + } + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + +} diff --git a/sig/flags_test.go b/sig/flags_test.go new file mode 100644 index 000000000..d70fc3517 --- /dev/null +++ b/sig/flags_test.go @@ -0,0 +1,118 @@ +package sig + +import ( + "context" + "errors" + "testing" + "time" +) + +var _ FlagsSpec = &Flags{} + +type FlagsSpec interface { + Set(flag ...string) + Clear(flag ...string) + IsSet(flag string) bool + Flags() []string + Wait(flag string, state bool) <-chan struct{} + WaitContext(ctx context.Context, flag string, state bool) error +} + +const TestFlag1 = "test1" +const TestFlag2 = "test2" + +func TestWait(t *testing.T) { + var flags = NewFlags() + var tick = make(chan struct{}) + + go func() { + select { + case <-flags.Wait(TestFlag1, false): + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout reached") + } + + tick <- struct{}{} + + select { + case <-flags.Wait(TestFlag1, true): + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout reached") + } + + select { + case <-flags.Wait(TestFlag2, true): + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout reached") + } + + tick <- struct{}{} + + select { + case <-flags.Wait(TestFlag2, false): + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout reached") + } + + tick <- struct{}{} + }() + + <-tick + time.After(10 * time.Millisecond) + + flags.Set(TestFlag1, TestFlag2) + + <-tick + time.After(10 * time.Millisecond) + + flags.Clear(TestFlag2) + + <-tick +} + +func TestWaitContext(t *testing.T) { + var flags = NewFlags() + var tick = make(chan struct{}) + + go func() { + var err error + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = flags.WaitContext(ctx, TestFlag1, true) + if err != nil { + t.Fatal("unexpected err:", err) + } + + tick <- struct{}{} + + err = flags.WaitContext(ctx, TestFlag1, false) + if err != nil { + t.Fatal("unexpected err:", err) + } + + err = flags.WaitContext(ctx, TestFlag2, false) + if err != nil { + t.Fatal("unexpected err:", err) + } + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + err = flags.WaitContext(ctx, TestFlag2, true) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err:", err) + } + + tick <- struct{}{} + }() + + flags.IsSet(TestFlag1) + + <-tick + + flags.Clear(TestFlag1) + + <-tick +}