Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 39 additions & 35 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,35 +233,46 @@ func (app *App) RegisterIntervals(intervals ...Interval) {
}
}

func (app *App) registerEntityListener(etl EntityListener) {
if etl.delay != 0 && etl.toState == "" {
slog.Error("EntityListener error: you have to use ToState() when using Duration()")
panic(ErrInvalidArgs)
}

for _, entity := range etl.entityIds {
app.entityListeners[entity] = append(app.entityListeners[entity], &etl)
}
}

func (app *App) RegisterEntityListeners(etls ...EntityListener) {
for _, etl := range etls {
etl := etl
if etl.delay != 0 && etl.toState == "" {
slog.Error("EntityListener error: you have to use ToState() when using Duration()")
panic(ErrInvalidArgs)
}
app.registerEntityListener(etl)
}
}

for _, entity := range etl.entityIds {
if elList, ok := app.entityListeners[entity]; ok {
app.entityListeners[entity] = append(elList, &etl)
} else {
app.entityListeners[entity] = []*EntityListener{&etl}
}
func (app *App) registerEventListener(evl EventListener) {
for _, eventType := range evl.eventTypes {
elList, ok := app.eventListeners[eventType]
if !ok {
// We're not listening to that event type yet. Ask HA to
// send them to us, and when they arrive, call any event
// listeners for that type (including any that are
// registered in the future).
eventType := eventType
app.conn.SubscribeToEventType(
eventType,
func(msg websocket.ChanMsg) {
go app.callEventListeners(eventType, msg)
},
)
}
app.eventListeners[eventType] = append(elList, &evl)
}
}

func (app *App) RegisterEventListeners(evls ...EventListener) {
for _, evl := range evls {
evl := evl
for _, eventType := range evl.eventTypes {
if elList, ok := app.eventListeners[eventType]; ok {
app.eventListeners[eventType] = append(elList, &evl)
} else {
websocket.SubscribeToEventType(eventType, app.conn)
app.eventListeners[eventType] = []*EventListener{&evl}
}
}
app.registerEventListener(evl)
}
}

Expand Down Expand Up @@ -316,7 +327,11 @@ func (app *App) Start() {
go app.runScheduledActions(app.ctx)

// subscribe to state_changed events
app.entitySubscription = websocket.SubscribeToStateChangedEvents(app.conn)
app.entitySubscription = app.conn.SubscribeToStateChangedEvents(
func(msg websocket.ChanMsg) {
go app.callEntityListeners(msg.Raw)
},
)

// entity listeners runOnStartup
for eid, etls := range app.entityListeners {
Expand All @@ -342,20 +357,9 @@ func (app *App) Start() {
}
}

// entity listeners and event listeners
elChan := make(chan websocket.ChanMsg)
go app.conn.ListenWebsocket(elChan)

for {
msg, ok := <-elChan
if !ok {
break
}
if app.entitySubscription.ID() == msg.Id {
go callEntityListeners(app, msg.Raw)
} else {
go callEventListeners(app, msg)
}
// Start listen on the connection for incoming messages:
if err := app.conn.Run(); err != nil {
slog.Error("Error reading from websocket", "err", err)
}
}

Expand Down
122 changes: 64 additions & 58 deletions entitylistener.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ type stateChangedMsg struct {
ID int `json:"id"`
Type string `json:"type"`
Event struct {
Data struct {
EntityID string `json:"entity_id"`
NewState msgState `json:"new_state"`
OldState msgState `json:"old_state"`
} `json:"data"`
EventType string `json:"event_type"`
Origin string `json:"origin"`
Data stateData `json:"data"`
EventType string `json:"event_type"`
Origin string `json:"origin"`
} `json:"event"`
}

type stateData struct {
EntityID string `json:"entity_id"`
NewState msgState `json:"new_state"`
OldState msgState `json:"old_state"`
}

type msgState struct {
EntityID string `json:"entity_id"`
LastChanged time.Time `json:"last_changed"`
Expand Down Expand Up @@ -191,8 +193,52 @@ func (b elBuilder3) Build() EntityListener {
return b.entityListener
}

func (l *EntityListener) maybeCall(app *App, entityData EntityData, data stateData) {
// Check conditions
if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail {
return
}
if c := checkStatesMatch(l.fromState, data.OldState.State); c.fail {
return
}
if c := checkStatesMatch(l.toState, data.NewState.State); c.fail {
if l.delayTimer != nil {
l.delayTimer.Stop()
}
return
}
if c := checkThrottle(l.throttle, l.lastRan); c.fail {
return
}
if c := checkExceptionDates(l.exceptionDates); c.fail {
return
}
if c := checkExceptionRanges(l.exceptionRanges); c.fail {
return
}
if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail {
return
}
if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail {
return
}

if l.delay != 0 {
l := l
l.delayTimer = time.AfterFunc(l.delay, func() {
go l.callback(app.service, app.state, entityData)
l.lastRan = carbon.Now()
})
return
}

// run now if no delay set
go l.callback(app.service, app.state, entityData)
l.lastRan = carbon.Now()
}

/* Functions */
func callEntityListeners(app *App, msgBytes []byte) {
func (app *App) callEntityListeners(msgBytes []byte) {
msg := stateChangedMsg{}
_ = json.Unmarshal(msgBytes, &msg)
data := msg.Event.Data
Expand All @@ -211,56 +257,16 @@ func callEntityListeners(app *App, msgBytes []byte) {
return
}

for _, l := range listeners {
// Check conditions
if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail {
continue
}
if c := checkStatesMatch(l.fromState, data.OldState.State); c.fail {
continue
}
if c := checkStatesMatch(l.toState, data.NewState.State); c.fail {
if l.delayTimer != nil {
l.delayTimer.Stop()
}
continue
}
if c := checkThrottle(l.throttle, l.lastRan); c.fail {
continue
}
if c := checkExceptionDates(l.exceptionDates); c.fail {
continue
}
if c := checkExceptionRanges(l.exceptionRanges); c.fail {
continue
}
if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail {
continue
}
if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail {
continue
}

entityData := EntityData{
TriggerEntityId: eid,
FromState: data.OldState.State,
FromAttributes: data.OldState.Attributes,
ToState: data.NewState.State,
ToAttributes: data.NewState.Attributes,
LastChanged: data.OldState.LastChanged,
}

if l.delay != 0 {
l := l
l.delayTimer = time.AfterFunc(l.delay, func() {
go l.callback(app.service, app.state, entityData)
l.lastRan = carbon.Now()
})
continue
}
entityData := EntityData{
TriggerEntityId: eid,
FromState: data.OldState.State,
FromAttributes: data.OldState.Attributes,
ToState: data.NewState.State,
ToAttributes: data.NewState.Attributes,
LastChanged: data.OldState.LastChanged,
}

// run now if no delay set
go l.callback(app.service, app.state, entityData)
l.lastRan = carbon.Now()
for _, l := range listeners {
l.maybeCall(app, entityData, data)
}
}
66 changes: 31 additions & 35 deletions eventListener.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gomeassistant

import (
"encoding/json"
"fmt"
"time"

Expand Down Expand Up @@ -133,48 +132,45 @@ func (b eventListenerBuilder3) Build() EventListener {
return b.eventListener
}

type BaseEventMsg struct {
Event struct {
EventType string `json:"event_type"`
} `json:"event"`
func (l *EventListener) maybeCall(app *App, eventData EventData) {
// Check conditions
if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail {
return
}
if c := checkThrottle(l.throttle, l.lastRan); c.fail {
return
}
if c := checkExceptionDates(l.exceptionDates); c.fail {
return
}
if c := checkExceptionRanges(l.exceptionRanges); c.fail {
return
}
if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail {
return
}
if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail {
return
}

go l.callback(app.service, app.state, eventData)
l.lastRan = carbon.Now()
}

/* Functions */
func callEventListeners(app *App, msg websocket.ChanMsg) {
baseEventMsg := BaseEventMsg{}
_ = json.Unmarshal(msg.Raw, &baseEventMsg)
listeners, ok := app.eventListeners[baseEventMsg.Event.EventType]
func (app *App) callEventListeners(eventType string, msg websocket.ChanMsg) {
listeners, ok := app.eventListeners[eventType]
if !ok {
// no listeners registered for this event type
return
}

eventData := EventData{
Type: eventType,
RawEventJSON: msg.Raw,
}

for _, l := range listeners {
// Check conditions
if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail {
continue
}
if c := checkThrottle(l.throttle, l.lastRan); c.fail {
continue
}
if c := checkExceptionDates(l.exceptionDates); c.fail {
continue
}
if c := checkExceptionRanges(l.exceptionRanges); c.fail {
continue
}
if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail {
continue
}
if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail {
continue
}

eventData := EventData{
Type: baseEventMsg.Event.EventType,
RawEventJSON: msg.Raw,
}
go l.callback(app.service, app.state, eventData)
l.lastRan = carbon.Now()
l.maybeCall(app, eventData)
}
}
45 changes: 45 additions & 0 deletions internal/websocket/locked_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ type LockedConn interface {
// `LockedConn` is still active.
NextMessageID() int64

// Subscribe allocates a new message ID and subscribes
// `subscriber` to it, in the sense that the subscriber will be
// called for any incoming messages that have that ID. This
// doesn't actually interact with the server. Typically the next
// step would be to send a message with its message ID set to
// `Subscription.ID()`.
//
// The returned `Subscription` must eventually be passed at least
// once to `Unsubscribe()`, though `Unsubscribe()` can be called
// against a different `LockedConn` than the one that generated
// it.
Subscribe(subscriber Subscriber) Subscription

// Unsubscribe terminates `subscription` at the websocket level;
// i.e., no more incoming messages will be forwarded to the
// corresponding `Subscriber`. Note that this does not interact
// with the server; it is the caller's responsibility to send it
// an "unsubscribe" command if necessary.
Unsubscribe(subscription Subscription)

// SendMessage sends the specified message over the websocket
// connection. `msg` must be JSON-serializable and have the
// correct format and a unique, monotonically-increasing ID, which
Expand All @@ -30,6 +50,31 @@ func (lc lockedConn) NextMessageID() int64 {
return lc.conn.lastMessageID
}

// Subscribe implements [LockedConn.Subscribe].
func (lc lockedConn) Subscribe(subscriber Subscriber) Subscription {
lc.conn.subscribersLock.Lock()
defer lc.conn.subscribersLock.Unlock()

id := lc.NextMessageID()
lc.conn.subscribers[id] = subscriber
return Subscription{
messageID: id,
}
}

// Unsubscribe implements [LockedConn.Unsubscribe].
func (lc lockedConn) Unsubscribe(subscription Subscription) {
if subscription.messageID == 0 {
return
}

lc.conn.subscribersLock.Lock()
defer lc.conn.subscribersLock.Unlock()

delete(lc.conn.subscribers, subscription.messageID)
subscription.messageID = 0
}

// SendMessage implements [LockedConn.SendMessage].
func (lc lockedConn) SendMessage(msg any) error {
if err := lc.conn.conn.WriteJSON(msg); err != nil {
Expand Down
Loading