diff --git a/README.md b/README.md index 820459e..e709e12 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ func myFunc(se *ga.Service, st *ga.State) { Entity Listeners are used to respond to entities changing state. The simplest entity listener looks like: ```go -etl := ga.NewEntityListener().EntityIds("binary_sensor.front_door").Call(myFunc).Build() +etl := ga.NewEntityListener().EntityIDs("binary_sensor.front_door").Call(myFunc).Build() ``` Entity listeners have other functions to change the behavior. @@ -139,7 +139,7 @@ func myFunc(se *ga.Service, st *ga.State, e ga.EntityData) { Event Listeners are used to respond to entities changing state. The simplest event listener looks like: ```go -evl := ga.NewEntityListener().EntityIds("binary_sensor.front_door").Call(myFunc).Build() +evl := ga.NewEntityListener().EntityIDs("binary_sensor.front_door").Call(myFunc).Build() ``` Event listeners have other functions to change the behavior. diff --git a/app.go b/app.go deleted file mode 100644 index 4cd6e79..0000000 --- a/app.go +++ /dev/null @@ -1,328 +0,0 @@ -package gomeassistant - -import ( - "context" - "errors" - "fmt" - "log/slog" - "time" - - "github.com/golang-module/carbon" - "github.com/gorilla/websocket" - sunriseLib "github.com/nathan-osman/go-sunrise" - "saml.dev/gome-assistant/internal" - "saml.dev/gome-assistant/internal/http" - pq "saml.dev/gome-assistant/internal/priorityqueue" - ws "saml.dev/gome-assistant/internal/websocket" -) - -// Returned by NewApp() if authentication fails -var ErrInvalidToken = ws.ErrInvalidToken - -var ErrInvalidArgs = errors.New("invalid arguments provided") - -type App struct { - ctx context.Context - ctxCancel context.CancelFunc - conn *websocket.Conn - - // Wraps the ws connection with added mutex locking - wsWriter *ws.WebsocketWriter - - httpClient *http.HttpClient - - service *Service - state *StateImpl - - schedules pq.PriorityQueue - intervals pq.PriorityQueue - entityListeners map[string][]*EntityListener - entityListenersId int64 - eventListeners map[string][]*EventListener -} - -/* -DurationString represents a duration, such as "2s" or "24h". -See https://pkg.go.dev/time#ParseDuration for all valid time units. -*/ -type DurationString string - -/* -TimeString is a 24-hr format time "HH:MM" such as "07:30". -*/ -type TimeString string - -type timeRange struct { - start time.Time - end time.Time -} - -type NewAppRequest struct { - // Required - // IpAddress of your Home Assistant instance i.e. "localhost" - // or "192.168.86.59" etc. - IpAddress string - - // Optional - // Port number Home Assistant is running on. Defaults to 8123. - Port string - - // Required - // Auth token generated in Home Assistant. Used - // to connect to the Websocket API. - HAAuthToken string - - // Required - // EntityId of the zone representing your home e.g. "zone.home". - // Used to pull latitude/longitude from Home Assistant - // to calculate sunset/sunrise times. - HomeZoneEntityId string - - // Optional - // Whether to use secure connections for http and websockets. - // Setting this to `true` will use `https://` instead of `https://` - // and `wss://` instead of `ws://`. - Secure bool -} - -/* -NewApp establishes the websocket connection and returns an object -you can use to register schedules and listeners. -*/ -func NewApp(request NewAppRequest) (*App, error) { - if request.IpAddress == "" || request.HAAuthToken == "" || request.HomeZoneEntityId == "" { - slog.Error("IpAddress, HAAuthToken, and HomeZoneEntityId are all required arguments in NewAppRequest") - return nil, ErrInvalidArgs - } - port := request.Port - if port == "" { - port = "8123" - } - - var ( - conn *websocket.Conn - ctx context.Context - ctxCancel context.CancelFunc - err error - ) - - if request.Secure { - conn, ctx, ctxCancel, err = ws.SetupSecureConnection(request.IpAddress, port, request.HAAuthToken) - } else { - conn, ctx, ctxCancel, err = ws.SetupConnection(request.IpAddress, port, request.HAAuthToken) - } - - if conn == nil { - return nil, err - } - - var httpClient *http.HttpClient - - if request.Secure { - httpClient = http.NewHttpsClient(request.IpAddress, port, request.HAAuthToken) - } else { - httpClient = http.NewHttpClient(request.IpAddress, port, request.HAAuthToken) - } - - wsWriter := &ws.WebsocketWriter{Conn: conn} - service := newService(wsWriter, ctx, httpClient) - state, err := newState(httpClient, request.HomeZoneEntityId) - if err != nil { - return nil, err - } - - return &App{ - conn: conn, - wsWriter: wsWriter, - ctx: ctx, - ctxCancel: ctxCancel, - httpClient: httpClient, - service: service, - state: state, - schedules: pq.New(), - intervals: pq.New(), - entityListeners: map[string][]*EntityListener{}, - eventListeners: map[string][]*EventListener{}, - }, nil -} - -func (a *App) Cleanup() { - if a.ctxCancel != nil { - a.ctxCancel() - } -} - -func (a *App) RegisterSchedules(schedules ...DailySchedule) { - for _, s := range schedules { - // realStartTime already set for sunset/sunrise - if s.isSunrise || s.isSunset { - s.nextRunTime = getNextSunRiseOrSet(a, s.isSunrise, s.sunOffset).Carbon2Time() - a.schedules.Insert(s, float64(s.nextRunTime.Unix())) - continue - } - - now := carbon.Now() - startTime := carbon.Now().SetTimeMilli(s.hour, s.minute, 0, 0) - - // advance first scheduled time by frequency until it is in the future - if startTime.Lt(now) { - startTime = startTime.AddDay() - } - - s.nextRunTime = startTime.Carbon2Time() - a.schedules.Insert(s, float64(startTime.Carbon2Time().Unix())) - } -} - -func (a *App) RegisterIntervals(intervals ...Interval) { - for _, i := range intervals { - if i.frequency == 0 { - slog.Error("A schedule must use either set frequency via Every()") - panic(ErrInvalidArgs) - } - - i.nextRunTime = internal.ParseTime(string(i.startTime)).Carbon2Time() - now := time.Now() - for i.nextRunTime.Before(now) { - i.nextRunTime = i.nextRunTime.Add(i.frequency) - } - a.intervals.Insert(i, float64(i.nextRunTime.Unix())) - } -} - -func (a *App) RegisterEntityListeners(etls ...EntityListener) { - for _, etl := range etls { - etl := etl - if etl.delay != 0 && etl.toState == "" { - slog.Error("EntityListener error: you have to use ToState() when using Duration()") - panic(ErrInvalidArgs) - } - - for _, entity := range etl.entityIds { - if elList, ok := a.entityListeners[entity]; ok { - a.entityListeners[entity] = append(elList, &etl) - } else { - a.entityListeners[entity] = []*EntityListener{&etl} - } - } - } -} - -func (a *App) RegisterEventListeners(evls ...EventListener) { - for _, evl := range evls { - evl := evl - for _, eventType := range evl.eventTypes { - if elList, ok := a.eventListeners[eventType]; ok { - a.eventListeners[eventType] = append(elList, &evl) - } else { - ws.SubscribeToEventType(eventType, a.wsWriter, a.ctx) - a.eventListeners[eventType] = []*EventListener{&evl} - } - } - } -} - -func getSunriseSunset(s *StateImpl, sunrise bool, dateToUse carbon.Carbon, offset ...DurationString) carbon.Carbon { - date := dateToUse.Carbon2Time() - rise, set := sunriseLib.SunriseSunset(s.latitude, s.longitude, date.Year(), date.Month(), date.Day()) - rise, set = rise.Local(), set.Local() - - val := set - printString := "Sunset" - if sunrise { - val = rise - printString = "Sunrise" - } - - setOrRiseToday := carbon.Parse(val.String()) - - var t time.Duration - var err error - if len(offset) == 1 { - t, err = time.ParseDuration(string(offset[0])) - if err != nil { - parsingErr := fmt.Errorf("could not parse offset passed to %s: \"%s\": %w", printString, offset[0], err) - slog.Error(parsingErr.Error()) - panic(parsingErr) - } - } - - // add offset if set, this code works for negative values too - if t.Microseconds() != 0 { - setOrRiseToday = setOrRiseToday.AddMinutes(int(t.Minutes())) - } - - return setOrRiseToday -} - -func getNextSunRiseOrSet(a *App, sunrise bool, offset ...DurationString) carbon.Carbon { - sunriseOrSunset := getSunriseSunset(a.state, sunrise, carbon.Now(), offset...) - if sunriseOrSunset.Lt(carbon.Now()) { - // if we're past today's sunset or sunrise (accounting for offset) then get tomorrows - // as that's the next time the schedule will run - sunriseOrSunset = getSunriseSunset(a.state, sunrise, carbon.Tomorrow(), offset...) - } - return sunriseOrSunset -} - -func (a *App) Start() { - slog.Info("Starting", "schedules", a.schedules.Len()) - slog.Info("Starting", "entity listeners", len(a.entityListeners)) - slog.Info("Starting", "event listeners", len(a.eventListeners)) - - go runSchedules(a) - go runIntervals(a) - - // subscribe to state_changed events - id := internal.GetId() - ws.SubscribeToStateChangedEvents(id, a.wsWriter, a.ctx) - a.entityListenersId = id - - // entity listeners runOnStartup - for eid, etls := range a.entityListeners { - for _, etl := range etls { - // ensure each ETL only runs once, even if - // it listens to multiple entities - if etl.runOnStartup && !etl.runOnStartupCompleted { - entityState, err := a.state.Get(eid) - if err != nil { - slog.Warn("Failed to get entity state \"", eid, "\" during startup, skipping RunOnStartup") - } - - etl.runOnStartupCompleted = true - go etl.callback(a.service, a.state, EntityData{ - TriggerEntityId: eid, - FromState: entityState.State, - FromAttributes: entityState.Attributes, - ToState: entityState.State, - ToAttributes: entityState.Attributes, - LastChanged: entityState.LastChanged, - }) - } - } - } - - // entity listeners and event listeners - elChan := make(chan ws.ChanMsg) - go ws.ListenWebsocket(a.conn, a.ctx, elChan) - - for { - msg, ok := <-elChan - if !ok { - break - } - if a.entityListenersId == msg.Id { - go callEntityListeners(a, msg.Raw) - } else { - go callEventListeners(a, msg) - } - } -} - -func (a *App) GetService() *Service { - return a.service -} - -func (a *App) GetState() State { - return a.state -} diff --git a/app/app.go b/app/app.go new file mode 100644 index 0000000..5bbb26a --- /dev/null +++ b/app/app.go @@ -0,0 +1,541 @@ +package app + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/golang-module/carbon" + sunriseLib "github.com/nathan-osman/go-sunrise" + "golang.org/x/sync/errgroup" + + "saml.dev/gome-assistant/internal/http" + "saml.dev/gome-assistant/internal/priorityqueue" + "saml.dev/gome-assistant/websocket" +) + +// Returned by NewApp() if authentication fails +var ErrInvalidToken = websocket.ErrInvalidToken + +var ErrInvalidArgs = errors.New("invalid arguments provided") + +type App struct { + // Wraps the ws connection with added mutex locking + wsConn *websocket.Conn + + httpClient *http.HttpClient + + Service *Service + State State + + scheduledActions priorityqueue.PriorityQueue + entityListeners map[string][]*EntityListener + eventListeners map[string][]*EventListener + + // Ready is closed when the app is ready for use. + ready chan struct{} + + // If `App.Start()` has been called, `cancel()` cancels the + // context being used, which causes the app to shut down cleanly. + cancel context.CancelFunc + + closeOnce sync.Once +} + +// DurationString represents a duration, such as "2s" or "24h". See +// https://pkg.go.dev/time#ParseDuration for all valid time units. +type DurationString string + +// TimeString is a 24-hr format time "HH:MM" such as "07:30". +type TimeString string + +type timeRange struct { + start time.Time + end time.Time +} + +type NewAppConfig struct { + // RESTBaseURI is the base URI for REST requests; for example, + // * `http://homeassistant.local:8123/api` from outside of the + // HA appliance (without encryption) + // * `https://homeassistant.local:8123/api` from outside of the + // HA appliance (with encryption) + // * `http://supervisor/core/api` from an add-on running within + // the appliance and connecting via the proxy + RESTBaseURI string + + // WebsocketURI is the base URI for websocket connections; for + // example, + // * `ws://homeassistant.local:8123/api/websocket` from outside + // of the HA appliance (without encryption) + // * `wss://homeassistant.local:8123/api/websocket` from outside + // of the HA appliance (with encryption) + // * `ws://supervisor/core/api/websocket` from an add-on running + // within the appliance and connecting via the proxy + WebsocketURI string + + // Auth token generated in Home Assistant. Used to connect to the + // Websocket API. + HAAuthToken string + + // Required + // EntityID of the zone representing your home e.g. "zone.home". + // Used to pull latitude/longitude from Home Assistant + // to calculate sunset/sunrise times. + HomeZoneEntityID string +} + +// NewAppFromConfig establishes the websocket connection and returns +// an object you can use to register schedules and listeners, based on +// the URIs that it should connect to. `ctx` is used only to limit the +// time spent connecting; it cannot be used after that to cancel the +// app. +func NewAppFromConfig(ctx context.Context, config NewAppConfig) (*App, error) { + if config.RESTBaseURI == "" || config.WebsocketURI == "" || + config.HAAuthToken == "" || config.HomeZoneEntityID == "" { + slog.Error( + "RESTBaseURI, WebsocketURI, HAAuthToken, and HomeZoneEntityID " + + "are all required arguments in NewAppRequest", + ) + return nil, ErrInvalidArgs + } + + wsWriter, err := websocket.NewConnFromURI(ctx, config.WebsocketURI, config.HAAuthToken) + if err != nil { + return nil, err + } + + httpClient := http.ClientFromUri(config.RESTBaseURI, config.HAAuthToken) + + state, err := newState(httpClient, config.HomeZoneEntityID) + if err != nil { + return nil, err + } + app := App{ + wsConn: wsWriter, + httpClient: httpClient, + State: state, + scheduledActions: priorityqueue.New(), + entityListeners: map[string][]*EntityListener{}, + eventListeners: map[string][]*EventListener{}, + ready: make(chan struct{}), + cancel: func() {}, + } + app.Service = newService(&app, httpClient) + + return &app, nil +} + +type NewAppRequest struct { + // Required + // IpAddress of your Home Assistant instance i.e. "localhost" + // or "192.168.86.59" etc. + IpAddress string + + // Optional + // Port number Home Assistant is running on. Defaults to 8123. + Port string + + // Required + // Auth token generated in Home Assistant. Used + // to connect to the Websocket API. + HAAuthToken string + + // Required + // EntityID of the zone representing your home e.g. "zone.home". + // Used to pull latitude/longitude from Home Assistant + // to calculate sunset/sunrise times. + HomeZoneEntityID string + + // Optional + // Whether to use secure connections for http and websockets. + // Setting this to `true` will use `https://` instead of `https://` + // and `wss://` instead of `ws://`. + Secure bool +} + +// NewApp establishes the websocket connection and returns an object +// you can use to register schedules and listeners. `ctx` is used only +// to limit the time spent connecting; it cannot be used after that to +// cancel the app. If this function returns successfully, then +// `App.Close()` must eventually be called to release resources. +func NewApp(ctx context.Context, request NewAppRequest) (*App, error) { + if request.IpAddress == "" || request.HAAuthToken == "" || request.HomeZoneEntityID == "" { + slog.Error( + "IpAddress, HAAuthToken, and HomeZoneEntityID " + + "are all required arguments in NewAppRequest", + ) + return nil, ErrInvalidArgs + } + port := request.Port + if port == "" { + port = "8123" + } + + config := NewAppConfig{ + HAAuthToken: request.HAAuthToken, + HomeZoneEntityID: request.HomeZoneEntityID, + } + + if request.Secure { + config.WebsocketURI = fmt.Sprintf("wss://%s:%s/api/websocket", request.IpAddress, port) + config.RESTBaseURI = fmt.Sprintf("https://%s:%s/api", request.IpAddress, port) + } else { + config.WebsocketURI = fmt.Sprintf("ws://%s:%s/api/websocket", request.IpAddress, port) + config.RESTBaseURI = fmt.Sprintf("http://%s:%s/api", request.IpAddress, port) + } + + return NewAppFromConfig(ctx, config) +} + +// Ready returns a channel that is closed when the app is ready for use. +func (app *App) Ready() <-chan struct{} { + return app.ready +} + +type scheduledAction interface { + String() string + Hash() string + initializeNextRunTime(app *App) + shouldRun(app *App) bool + run(app *App) + updateNextRunTime(app *App) + getNextRunTime() time.Time +} + +func (app *App) RegisterScheduledAction(action scheduledAction) { + action.initializeNextRunTime(app) + app.scheduledActions.Insert(action, float64(action.getNextRunTime().Unix())) +} + +func (app *App) RegisterSchedules(schedules ...*DailySchedule) { + for _, s := range schedules { + app.RegisterScheduledAction(s) + } +} + +func (app *App) RegisterIntervals(intervals ...*Interval) { + for _, i := range intervals { + app.RegisterScheduledAction(i) + } +} + +func (app *App) RegisterEntityListener(etl EntityListener) { + if etl.delay != 0 && etl.toState == "" { + slog.Error("EntityListener error: you have to use ToState() when using Duration()") + panic(ErrInvalidArgs) + } + + for _, entity := range etl.entityIDs { + if elList, ok := app.entityListeners[entity]; ok { + app.entityListeners[entity] = append(elList, &etl) + } else { + app.entityListeners[entity] = []*EntityListener{&etl} + } + } +} + +func (app *App) RegisterEntityListeners(etls ...EntityListener) { + for _, etl := range etls { + app.RegisterEntityListener(etl) + } +} + +func (app *App) RegisterEventListener(evl EventListener) { + for _, eventType := range evl.eventTypes { + elList, ok := app.eventListeners[eventType] + if !ok { + // FIXME: keep track of subscriptions so that they can + // be unsubscribed from. + _, err := app.SubscribeEvents( + eventType, + func(msg websocket.Message) { + go app.callEventListeners(msg) + }, + ) + if err != nil { + // FIXME: better error handling + panic(err) + } + } + app.eventListeners[eventType] = append(elList, &evl) + } +} + +func (app *App) RegisterEventListeners(evls ...EventListener) { + for _, evl := range evls { + app.RegisterEventListener(evl) + } +} + +func getSunriseSunset( + s State, sunrise bool, dateToUse carbon.Carbon, offset ...DurationString, +) carbon.Carbon { + date := dateToUse.Carbon2Time() + rise, set := sunriseLib.SunriseSunset( + s.Latitude(), s.Longitude(), date.Year(), date.Month(), date.Day(), + ) + rise, set = rise.Local(), set.Local() + + val := set + printString := "Sunset" + if sunrise { + val = rise + printString = "Sunrise" + } + + setOrRiseToday := carbon.Parse(val.String()) + + var t time.Duration + var err error + if len(offset) == 1 { + t, err = time.ParseDuration(string(offset[0])) + if err != nil { + parsingErr := fmt.Errorf( + "could not parse offset passed to %s: \"%s\": %w", + printString, offset[0], err, + ) + slog.Error(parsingErr.Error()) + panic(parsingErr) + } + } + + // add offset if set, this code works for negative values too + if t.Microseconds() != 0 { + setOrRiseToday = setOrRiseToday.AddMinutes(int(t.Minutes())) + } + + return setOrRiseToday +} + +func getNextSunRiseOrSet(app *App, sunrise bool, offset ...DurationString) carbon.Carbon { + sunriseOrSunset := getSunriseSunset(app.State, sunrise, carbon.Now(), offset...) + if sunriseOrSunset.Lt(carbon.Now()) { + // if we're past today's sunset or sunrise (accounting for offset) then get tomorrows + // as that's the next time the schedule will run + sunriseOrSunset = getSunriseSunset(app.State, sunrise, carbon.Tomorrow(), offset...) + } + return sunriseOrSunset +} + +// Start the app. When `ctx` expires, the app closes the connection +// and returns. +func (app *App) Start(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + app.cancel = cancel + defer cancel() + + eg, ctx := errgroup.WithContext(ctx) + + slog.Info("Starting", "scheduled actions", app.scheduledActions.Len()) + slog.Info("Starting", "entity listeners", len(app.entityListeners)) + slog.Info("Starting", "event listeners", len(app.eventListeners)) + + // entity listeners and event listeners + eg.Go(func() error { + app.wsConn.Start() + cancel() + return nil + }) + + eg.Go(func() error { + app.runScheduledActions(ctx) + return nil + }) + + // subscribe to state_changed events + stateChangedSubscription, err := app.SubscribeStateChangedEvents( + func(msg websocket.Message) { + go app.callEntityListeners(msg) + }, + ) + if err != nil { + return fmt.Errorf("subscribing to 'state_changed' events: %w", err) + } + + defer app.UnsubscribeEvents(stateChangedSubscription) + + // entity listeners runOnStartup + for eid, etls := range app.entityListeners { + eid := eid + for _, etl := range etls { + etl := etl + // ensure each ETL only runs once, even if + // it listens to multiple entities + if etl.runOnStartup && !etl.runOnStartupCompleted { + entityState, err := app.State.Get(eid) + if err != nil { + slog.Warn( + "Failed to get entity state \"", eid, + "\" during startup, skipping RunOnStartup", + ) + } + + etl.runOnStartupCompleted = true + eg.Go(func() error { + etl.callback(EntityData{ + TriggerEntityID: eid, + FromState: entityState.State, + FromAttributes: entityState.Attributes, + ToState: entityState.State, + ToAttributes: entityState.Attributes, + LastChanged: entityState.LastChanged, + }) + return nil + }) + } + } + } + + close(app.ready) + + eg.Go(func() error { + <-ctx.Done() + app.Close() + return nil + }) + + eg.Wait() + + return nil +} + +// Close closes the connection and releases any resources. It may be +// called more than once; only the first call does anything. +func (app *App) Close() { + app.closeOnce.Do(func() { + app.close() + }) +} + +// close closes the connection and releases resources. It must be +// called exactly once. +func (app *App) close() { + app.cancel() + app.wsConn.Close() +} + +func (app *App) GetService() *Service { + return app.Service +} + +func (app *App) GetState() State { + return app.State +} + +func (app *App) runScheduledActions(ctx context.Context) { + if app.scheduledActions.Len() == 0 { + return + } + + // Create a new, but stopped, timer: + timer := time.NewTimer(1 * time.Hour) + if !timer.Stop() { + <-timer.C + } + + for { + action := app.popScheduledAction() + if action.getNextRunTime().After(time.Now()) { + timer.Reset(time.Until(action.getNextRunTime())) + + select { + case <-timer.C: + case <-ctx.Done(): + return + } + } + + if action.shouldRun(app) { + go action.run(app) + } + + app.requeueScheduledAction(action) + } +} + +func (app *App) popScheduledAction() scheduledAction { + action, _ := app.scheduledActions.Pop() + return action.(scheduledAction) +} + +func (app *App) requeueScheduledAction(action scheduledAction) { + action.updateNextRunTime(app) + app.scheduledActions.Insert(action, float64(action.getNextRunTime().Unix())) +} + +type subscribeEventsRequest struct { + websocket.BaseMessage + EventType string `json:"event_type"` +} + +// SubscribeEvents subscribes to events of the given type, invoking +// `subscriber` when any such events are received. `eventType` can be +// `*` to listen to all event types. Calls to `subscriber` are +// synchronous with respect to any other received messages, but +// asynchronous with respect to writes. +func (app *App) SubscribeEvents( + eventType string, subscriber websocket.Subscriber, +) (websocket.Subscription, error) { + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + // Make sure we're listening before events might start arriving: + e := subscribeEventsRequest{ + BaseMessage: websocket.BaseMessage{ + Type: "subscribe_events", + }, + EventType: eventType, + } + + response, subscription, err := app.Subscribe(ctx, &e, subscriber) + if err != nil { + return websocket.Subscription{}, err + } + + // FIXME: check response for success + _ = response + + return subscription, nil +} + +func (app *App) SubscribeStateChangedEvents( + subscriber websocket.Subscriber, +) (websocket.Subscription, error) { + return app.SubscribeEvents("state_changed", subscriber) +} + +type unsubscribeEventsRequest struct { + websocket.BaseMessage + Subscription int64 `json:"subscription"` +} + +// UnsubscribeEvents unsubscribes, at the server, from events that +// were subscribed to via the specified `subscription`. +func (app *App) UnsubscribeEvents(subscription websocket.Subscription) error { + ctx := context.TODO() + + req := unsubscribeEventsRequest{ + BaseMessage: websocket.BaseMessage{ + Type: "unsubscribe_events", + }, + Subscription: subscription.ID(), + } + + var result any + rs := newResultSubscriber(app, &result) + err := app.wsConn.Send(func(lc websocket.LockedConn) error { + lc.Unsubscribe(subscription) + // Subscribe, so that we receive the result of the unsubscribe + // command itself: + return rs.subscribe(lc, &req) + }) + if err != nil { + return fmt.Errorf("unsubscribing from ID %d: %w", subscription.ID(), err) + } + + return rs.wait(ctx) +} diff --git a/app/calls.go b/app/calls.go new file mode 100644 index 0000000..2508d80 --- /dev/null +++ b/app/calls.go @@ -0,0 +1,156 @@ +package app + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + ga "saml.dev/gome-assistant" + "saml.dev/gome-assistant/websocket" +) + +// Call invokes an RPC and processes the result as follows: +// 1. Generate a message ID. +// 2. Subscribe to that ID. +// 3. Send `req` over the websocket +// 4. Waits for a single "result" message +// 5. Unsubscribe from ID +// 6. Unmarshal the result into `result`. +// +// `msg` must be serializable to JSON. It shouldn't have its ID filled +// in yet; that will be done within this method. `result` must be +// something that `json.Unmarshal()` can deserialize into; typically, +// it is a pointer. If the result indicates a failure +// (success==false), then return that as an error. +func (app *App) Call( + ctx context.Context, req websocket.Request, result any, +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rs := newResultSubscriber(app, result) + err := app.wsConn.Send(func(lc websocket.LockedConn) error { + return rs.subscribe(lc, req) + }) + if err != nil { + return err + } + return rs.wait(ctx) +} + +type CallServiceRequest struct { + websocket.BaseMessage + Domain string `json:"domain"` + Service string `json:"service"` + + // ServiceData must be serializable to a JSON object. + ServiceData any `json:"service_data,omitempty"` + + Target ga.Target `json:"target,omitempty"` +} + +// CallService invokes a service using a `call_service` message, then +// waits for the response. The response is evaluated; if it indicates +// an error, then this method returns that error. Otherwise, the +// "result" field is stored to `result`, which must be something that +// `json.Unmarshal()` can serialize into (typically a pointer). +func (app *App) CallService( + ctx context.Context, domain string, service string, serviceData any, target ga.Target, + result any, +) error { + req := CallServiceRequest{ + BaseMessage: websocket.BaseMessage{ + Type: "call_service", + }, + Domain: domain, + Service: service, + ServiceData: serviceData, + Target: target, + } + + if err := app.Call(ctx, &req, result); err != nil { + switch target { + case ga.Target{}: + return fmt.Errorf("calling '%s.%s': %w", domain, service, err) + default: + return fmt.Errorf("calling '%s.%s' for %s: %w", domain, service, target, err) + } + } + return nil +} + +// Subscribe subscribes to some events via `req`, waits for a single +// response, and then leaves `subscriber` subscribed to the events. If +// this method returns without an error, `subscriber` must eventually +// be unsubscribed. `ctx` covers the subscription and the wait for the +// first answer, but not the forwarding of subsequent events or +// unsubscribing. +// +// FIXME: should this subscriber and subscription be specialized to +// event messages? +// +// FIXME: should the result be examined? If the subscription request +// failed, then we could fail more generally instead of leaving the +// cleanup to the caller. +func (app *App) Subscribe( + ctx context.Context, req websocket.Request, subscriber websocket.Subscriber, +) (websocket.ResultMessage, websocket.Subscription, error) { + // The result of the attempt to subscribe (i.e., the first + // message) will be sent to this channel. + resultReceived := false + var resultMsg websocket.ResultMessage + var resultErr error + done := make(chan struct{}) + + var subscription websocket.Subscription + + // Receive a single "result" message, send it to `responseCh`, + // then unsubscribe: + dualSubscriber := func(msg websocket.Message) { + if msg.Type == "result" { + if resultReceived { + slog.Warn( + "Error: multiple responses received for one 'subscribe' request (ignored)", + ) + return + } + resultReceived = true + + defer close(done) + + resultErr = json.Unmarshal(msg.Raw, &resultMsg) + if resultErr != nil { + return + } + // FIXME: turn non-success responses into errors. + return + } + + // Forward other responses (i.e., the events themselves) to + // `subscriber`: + subscriber(msg) + } + + err := app.wsConn.Send(func(lc websocket.LockedConn) error { + subscription = lc.Subscribe(dualSubscriber) + req.SetID(subscription.ID()) + if err := lc.SendMessage(req); err != nil { + lc.Unsubscribe(subscription) + return fmt.Errorf("error writing to websocket: %w", err) + } + return nil + }) + + if err != nil { + return websocket.ResultMessage{}, websocket.Subscription{}, err + } + + select { + case <-done: + return resultMsg, subscription, nil + case <-ctx.Done(): + // FIXME: unsubscribe + return websocket.ResultMessage{}, websocket.Subscription{}, ctx.Err() + } +} diff --git a/checkers.go b/app/checkers.go similarity index 96% rename from checkers.go rename to app/checkers.go index 1936190..7b80a0a 100644 --- a/checkers.go +++ b/app/checkers.go @@ -1,4 +1,4 @@ -package gomeassistant +package app import ( "time" @@ -20,7 +20,8 @@ func checkWithinTimeRange(startTime, endTime string) conditionCheck { parsedEnd := internal.ParseTime(endTime) // check for midnight overlap - if parsedEnd.Lt(parsedStart) { // example turn on night lights when motion from 23:00 to 07:00 + if parsedEnd.Lt(parsedStart) { + // example turn on night lights when motion from 23:00 to 07:00 if parsedEnd.IsPast() { // such as at 15:00, 22:00 parsedEnd = parsedEnd.AddDay() } else { diff --git a/checkers_test.go b/app/checkers_test.go similarity index 96% rename from checkers_test.go rename to app/checkers_test.go index 22dd451..59c18ce 100644 --- a/checkers_test.go +++ b/app/checkers_test.go @@ -1,4 +1,4 @@ -package gomeassistant +package app import ( "errors" @@ -15,6 +15,14 @@ type MockState struct { GetError bool } +func (s MockState) Latitude() float64 { + return 0.0 +} + +func (s MockState) Longitude() float64 { + return 0.0 +} + func (s MockState) AfterSunrise(_ ...DurationString) bool { return true } diff --git a/entitylistener.go b/app/entitylistener.go similarity index 72% rename from entitylistener.go rename to app/entitylistener.go index 2dd3a3a..b6625e9 100644 --- a/entitylistener.go +++ b/app/entitylistener.go @@ -1,4 +1,4 @@ -package gomeassistant +package app import ( "encoding/json" @@ -7,10 +7,11 @@ import ( "github.com/golang-module/carbon" "saml.dev/gome-assistant/internal" + "saml.dev/gome-assistant/websocket" ) type EntityListener struct { - entityIds []string + entityIDs []string callback EntityListenerCallback fromState string toState string @@ -33,20 +34,19 @@ type EntityListener struct { disabledEntities []internal.EnabledDisabledInfo } -type EntityListenerCallback func(*Service, State, EntityData) +type EntityListenerCallback func(EntityData) type EntityData struct { - TriggerEntityId string + TriggerEntityID string FromState string FromAttributes map[string]any ToState string ToAttributes map[string]any - LastChanged time.Time + LastChanged websocket.TimeStamp } type stateChangedMsg struct { - ID int `json:"id"` - Type string `json:"type"` + websocket.BaseMessage Event struct { Data struct { EntityID string `json:"entity_id"` @@ -59,10 +59,10 @@ type stateChangedMsg struct { } type msgState struct { - EntityID string `json:"entity_id"` - LastChanged time.Time `json:"last_changed"` - State string `json:"state"` - Attributes map[string]any `json:"attributes"` + EntityID string `json:"entity_id"` + LastChanged websocket.TimeStamp `json:"last_changed"` + State string `json:"state"` + Attributes map[string]any `json:"attributes"` } /* Methods */ @@ -77,11 +77,11 @@ type elBuilder1 struct { entityListener EntityListener } -func (b elBuilder1) EntityIds(entityIds ...string) elBuilder2 { - if len(entityIds) == 0 { - panic("must pass at least one entityId to EntityIds()") +func (b elBuilder1) EntityIDs(entityIDs ...string) elBuilder2 { + if len(entityIDs) == 0 { + panic("must pass at least one entityID to EntityIDs()") } else { - b.entityListener.entityIds = entityIds + b.entityListener.entityIDs = entityIDs } return elBuilder2(b) } @@ -143,7 +143,9 @@ func (b elBuilder3) ExceptionDates(t time.Time, tl ...time.Time) elBuilder3 { } func (b elBuilder3) ExceptionRange(start, end time.Time) elBuilder3 { - b.entityListener.exceptionRanges = append(b.entityListener.exceptionRanges, timeRange{start, end}) + b.entityListener.exceptionRanges = append( + b.entityListener.exceptionRanges, timeRange{start, end}, + ) return b } @@ -152,16 +154,20 @@ func (b elBuilder3) RunOnStartup() elBuilder3 { return b } -/* -Enable this listener only when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the listener runs if {runOnNetworkError} is true. -*/ -func (b elBuilder3) EnabledWhen(entityId, state string, runOnNetworkError bool) elBuilder3 { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Enable this listener only when the current state of {entityID} +// matches {state}. If there is a network error while retrieving +// state, the listener runs if {runOnNetworkError} is true. +func (b elBuilder3) EnabledWhen(entityID, state string, runOnNetworkError bool) elBuilder3 { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -169,16 +175,20 @@ func (b elBuilder3) EnabledWhen(entityId, state string, runOnNetworkError bool) return b } -/* -Disable this listener when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the listener runs if {runOnNetworkError} is true. -*/ -func (b elBuilder3) DisabledWhen(entityId, state string, runOnNetworkError bool) elBuilder3 { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Disable this listener when the current state of {entityID} matches +// {state}. If there is a network error while retrieving state, the +// listener runs if {runOnNetworkError} is true. +func (b elBuilder3) DisabledWhen(entityID, state string, runOnNetworkError bool) elBuilder3 { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -191,7 +201,8 @@ func (b elBuilder3) Build() EntityListener { } /* Functions */ -func callEntityListeners(app *App, msgBytes []byte) { +func (app *App) callEntityListeners(chanMsg websocket.Message) { + msgBytes := chanMsg.Raw msg := stateChangedMsg{} json.Unmarshal(msgBytes, &msg) data := msg.Event.Data @@ -233,15 +244,15 @@ func callEntityListeners(app *App, msgBytes []byte) { if c := checkExceptionRanges(l.exceptionRanges); c.fail { continue } - if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { + if c := checkEnabledEntity(app.State, l.enabledEntities); c.fail { continue } - if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { + if c := checkDisabledEntity(app.State, l.disabledEntities); c.fail { continue } entityData := EntityData{ - TriggerEntityId: eid, + TriggerEntityID: eid, FromState: data.OldState.State, FromAttributes: data.OldState.Attributes, ToState: data.NewState.State, @@ -252,14 +263,14 @@ func callEntityListeners(app *App, msgBytes []byte) { if l.delay != 0 { l := l l.delayTimer = time.AfterFunc(l.delay, func() { - go l.callback(app.service, app.state, entityData) + go l.callback(entityData) l.lastRan = carbon.Now() }) continue } // run now if no delay set - go l.callback(app.service, app.state, entityData) + go l.callback(entityData) l.lastRan = carbon.Now() } } diff --git a/eventListener.go b/app/eventListener.go similarity index 60% rename from eventListener.go rename to app/eventListener.go index 37fe98a..37c3656 100644 --- a/eventListener.go +++ b/app/eventListener.go @@ -1,4 +1,4 @@ -package gomeassistant +package app import ( "encoding/json" @@ -6,8 +6,9 @@ import ( "time" "github.com/golang-module/carbon" + "saml.dev/gome-assistant/internal" - ws "saml.dev/gome-assistant/internal/websocket" + "saml.dev/gome-assistant/websocket" ) type EventListener struct { @@ -25,7 +26,7 @@ type EventListener struct { disabledEntities []internal.EnabledDisabledInfo } -type EventListenerCallback func(*Service, State, EventData) +type EventListenerCallback func(websocket.Event) type EventData struct { Type string @@ -84,26 +85,38 @@ func (b eventListenerBuilder3) Throttle(s DurationString) eventListenerBuilder3 return b } -func (b eventListenerBuilder3) ExceptionDates(t time.Time, tl ...time.Time) eventListenerBuilder3 { +func (b eventListenerBuilder3) ExceptionDates( + t time.Time, tl ...time.Time, +) eventListenerBuilder3 { b.eventListener.exceptionDates = append(tl, t) return b } func (b eventListenerBuilder3) ExceptionRange(start, end time.Time) eventListenerBuilder3 { - b.eventListener.exceptionRanges = append(b.eventListener.exceptionRanges, timeRange{start, end}) + b.eventListener.exceptionRanges = append( + b.eventListener.exceptionRanges, + timeRange{start, end}, + ) return b } -/* -Enable this listener only when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the listener runs if {runOnNetworkError} is true. -*/ -func (b eventListenerBuilder3) EnabledWhen(entityId, state string, runOnNetworkError bool) eventListenerBuilder3 { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in eventListener EnabledWhen entityId='%s' state='%s' runOnNetworkError='%t'", entityId, state, runOnNetworkError)) +// Enable this listener only when the current state of {entityID} +// matches {state}. If there is a network error while retrieving +// state, the listener runs if {runOnNetworkError} is true. +func (b eventListenerBuilder3) EnabledWhen( + entityID, state string, runOnNetworkError bool, +) eventListenerBuilder3 { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in eventListener EnabledWhen "+ + "entityID='%s' state='%s' runOnNetworkError='%t'", + entityID, state, runOnNetworkError, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -111,16 +124,23 @@ func (b eventListenerBuilder3) EnabledWhen(entityId, state string, runOnNetworkE return b } -/* -Disable this listener when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the listener runs if {runOnNetworkError} is true. -*/ -func (b eventListenerBuilder3) DisabledWhen(entityId, state string, runOnNetworkError bool) eventListenerBuilder3 { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in eventListener EnabledWhen entityId='%s' state='%s' runOnNetworkError='%t'", entityId, state, runOnNetworkError)) +// Disable this listener when the current state of {entityID} matches +// {state}. If there is a network error while retrieving state, the +// listener runs if {runOnNetworkError} is true. +func (b eventListenerBuilder3) DisabledWhen( + entityID, state string, runOnNetworkError bool, +) eventListenerBuilder3 { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in eventListener EnabledWhen "+ + "entityID='%s' state='%s' runOnNetworkError='%t'", + entityID, state, runOnNetworkError, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -132,17 +152,11 @@ func (b eventListenerBuilder3) Build() EventListener { return b.eventListener } -type BaseEventMsg struct { - Event struct { - EventType string `json:"event_type"` - } `json:"event"` -} - /* Functions */ -func callEventListeners(app *App, msg ws.ChanMsg) { - baseEventMsg := BaseEventMsg{} - json.Unmarshal(msg.Raw, &baseEventMsg) - listeners, ok := app.eventListeners[baseEventMsg.Event.EventType] +func (app *App) callEventListeners(msg websocket.Message) { + var eventMessage websocket.EventMessage + json.Unmarshal(msg.Raw, &eventMessage) + listeners, ok := app.eventListeners[eventMessage.Event.EventType] if !ok { // no listeners registered for this event type return @@ -162,18 +176,14 @@ func callEventListeners(app *App, msg ws.ChanMsg) { if c := checkExceptionRanges(l.exceptionRanges); c.fail { continue } - if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { + if c := checkEnabledEntity(app.State, l.enabledEntities); c.fail { continue } - if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { + if c := checkDisabledEntity(app.State, l.disabledEntities); c.fail { continue } - eventData := EventData{ - Type: baseEventMsg.Event.EventType, - RawEventJSON: msg.Raw, - } - go l.callback(app.service, app.state, eventData) + go l.callback(eventMessage.Event) l.lastRan = carbon.Now() } } diff --git a/app/eventTypes.go b/app/eventTypes.go new file mode 100644 index 0000000..517a504 --- /dev/null +++ b/app/eventTypes.go @@ -0,0 +1,18 @@ +package app + +type ZWaveJSEventData struct { + Domain string `json:"domain"` + NodeID int `json:"node_id"` + HomeID int64 `json:"home_id"` + Endpoint int `json:"endpoint"` + DeviceID string `json:"device_id"` + CommandClass int `json:"command_class"` + CommandClassName string `json:"command_class_name"` + Label string `json:"label"` + Property string `json:"property"` + PropertyName string `json:"property_name"` + PropertyKey string `json:"property_key"` + PropertyKeyName string `json:"property_key_name"` + Value string `json:"value"` + ValueRaw int `json:"value_raw"` +} diff --git a/interval.go b/app/interval.go similarity index 59% rename from interval.go rename to app/interval.go index 3d525d9..067bf54 100644 --- a/interval.go +++ b/app/interval.go @@ -1,13 +1,14 @@ -package gomeassistant +package app import ( "fmt" + "log/slog" "time" "saml.dev/gome-assistant/internal" ) -type IntervalCallback func(*Service, State) +type IntervalCallback func() type Interval struct { frequency time.Duration @@ -23,28 +24,30 @@ type Interval struct { disabledEntities []internal.EnabledDisabledInfo } -func (i Interval) Hash() string { - return fmt.Sprint(i.startTime, i.endTime, i.frequency, i.callback, i.exceptionDates, i.exceptionRanges) +func (i *Interval) Hash() string { + return fmt.Sprint( + i.startTime, i.endTime, i.frequency, i.callback, i.exceptionDates, i.exceptionRanges, + ) } // Call type intervalBuilder struct { - interval Interval + interval *Interval } // Every type intervalBuilderCall struct { - interval Interval + interval *Interval } // Offset, ExceptionDates, ExceptionRange type intervalBuilderEnd struct { - interval Interval + interval *Interval } func NewInterval() intervalBuilder { return intervalBuilder{ - Interval{ + &Interval{ frequency: 0, startTime: "00:00", endTime: "00:00", @@ -52,7 +55,7 @@ func NewInterval() intervalBuilder { } } -func (i Interval) String() string { +func (i *Interval) String() string { return fmt.Sprintf("Interval{ call %q every %s%s%s }", internal.GetFunctionName(i.callback), i.frequency, @@ -106,16 +109,22 @@ func (ib intervalBuilderEnd) ExceptionRange(start, end time.Time) intervalBuilde return ib } -/* -Enable this interval only when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the interval runs if {runOnNetworkError} is true. -*/ -func (ib intervalBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkError bool) intervalBuilderEnd { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Enable this interval only when the current state of {entityID} +// matches {state}. If there is a network error while retrieving +// state, the interval runs if {runOnNetworkError} is true. +func (ib intervalBuilderEnd) EnabledWhen( + entityID, state string, runOnNetworkError bool, +) intervalBuilderEnd { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -123,16 +132,22 @@ func (ib intervalBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkErr return ib } -/* -Disable this interval when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the interval runs if {runOnNetworkError} is true. -*/ -func (ib intervalBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkError bool) intervalBuilderEnd { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Disable this interval when the current state of {entityID} matches +// {state}. If there is a network error while retrieving state, the +// interval runs if {runOnNetworkError} is true. +func (ib intervalBuilderEnd) DisabledWhen( + entityID, state string, runOnNetworkError bool, +) intervalBuilderEnd { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -140,62 +155,53 @@ func (ib intervalBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkEr return ib } -func (sb intervalBuilderEnd) Build() Interval { +func (sb intervalBuilderEnd) Build() *Interval { return sb.interval } -// app.Start() functions -func runIntervals(a *App) { - if a.intervals.Len() == 0 { - return +func (i *Interval) initializeNextRunTime(app *App) { + if i.frequency == 0 { + slog.Error("A schedule must use either set frequency via Every()") + panic(ErrInvalidArgs) } - for { - i := popInterval(a) - - // run callback for all intervals before now in case they overlap - for i.nextRunTime.Before(time.Now()) { - i.maybeRunCallback(a) - requeueInterval(a, i) - - i = popInterval(a) - } - - time.Sleep(time.Until(i.nextRunTime)) - i.maybeRunCallback(a) - requeueInterval(a, i) + i.nextRunTime = internal.ParseTime(string(i.startTime)).Carbon2Time() + now := time.Now() + for i.nextRunTime.Before(now) { + i.nextRunTime = i.nextRunTime.Add(i.frequency) } } -func (i Interval) maybeRunCallback(a *App) { +func (i *Interval) getNextRunTime() time.Time { + return i.nextRunTime +} + +func (i Interval) shouldRun(app *App) bool { if c := checkStartEndTime(i.startTime /* isStart = */, true); c.fail { - return + return false } if c := checkStartEndTime(i.endTime /* isStart = */, false); c.fail { - return + return false } if c := checkExceptionDates(i.exceptionDates); c.fail { - return + return false } if c := checkExceptionRanges(i.exceptionRanges); c.fail { - return + return false } - if c := checkEnabledEntity(a.state, i.enabledEntities); c.fail { - return + if c := checkEnabledEntity(app.State, i.enabledEntities); c.fail { + return false } - if c := checkDisabledEntity(a.state, i.disabledEntities); c.fail { - return + if c := checkDisabledEntity(app.State, i.disabledEntities); c.fail { + return false } - go i.callback(a.service, a.state) + return true } -func popInterval(a *App) Interval { - i, _ := a.intervals.Pop() - return i.(Interval) +func (i *Interval) run(app *App) { + i.callback() } -func requeueInterval(a *App, i Interval) { +func (i *Interval) updateNextRunTime(app *App) { i.nextRunTime = i.nextRunTime.Add(i.frequency) - - a.intervals.Insert(i, float64(i.nextRunTime.Unix())) } diff --git a/app/result_subscriber.go b/app/result_subscriber.go new file mode 100644 index 0000000..e3513fd --- /dev/null +++ b/app/result_subscriber.go @@ -0,0 +1,82 @@ +package app + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "saml.dev/gome-assistant/websocket" +) + +// resultSubscriber is a helper type for handling the result message +// sent by the server in response to some kind of request. It +// subscribes itself, sends a message to the server, captures the +// first `result` message, then unsubscribes itself. The +// `ResultMessage` or error can be read using `wait()`. +type resultSubscriber struct { + app *App + subscription websocket.Subscription + + once sync.Once + result any + err error + done chan struct{} +} + +// newResultSubscriber creates a new subscriber that writes its result +// into `result`, which must be something that `json.Unmarshal()` can +// marshal into (typically a pointer). +func newResultSubscriber(app *App, result any) *resultSubscriber { + return &resultSubscriber{ + app: app, + result: result, + done: make(chan struct{}), + } +} + +// subscribe prepares and sends `req` to `lc`, but first subscribes +// `rs.callback` to receive the result of the request. +func (rs *resultSubscriber) subscribe( + lc websocket.LockedConn, req websocket.Request, +) error { + rs.subscription = lc.Subscribe(rs.callback) + req.SetID(rs.subscription.ID()) + if err := lc.SendMessage(req); err != nil { + lc.Unsubscribe(rs.subscription) + return fmt.Errorf("error writing to websocket: %w", err) + } + return nil +} + +// callback receives a single "result" message, stores the result to +// `rs`, then unsubscribes. It implements `websocket.Subscriber`. +func (rs *resultSubscriber) callback(msg websocket.Message) { + defer rs.close() + rs.err = msg.GetResult(rs.result) +} + +// wait waits for the result message to be received by `callback()`, +// then returns it to the caller. +func (rs *resultSubscriber) wait(ctx context.Context) error { + select { + case <-rs.done: + return rs.err + case <-ctx.Done(): + rs.close() + return ctx.Err() + } +} + +func (rs *resultSubscriber) close() { + rs.once.Do(func() { + close(rs.done) + err := rs.app.wsConn.Send(func(lc websocket.LockedConn) error { + lc.Unsubscribe(rs.subscription) + return nil + }) + if err != nil { + slog.Warn("Error unsubscribing", "message_id", rs.subscription.ID()) + } + }) +} diff --git a/schedule.go b/app/schedule.go similarity index 59% rename from schedule.go rename to app/schedule.go index 77a26d5..2c58537 100644 --- a/schedule.go +++ b/app/schedule.go @@ -1,15 +1,14 @@ -package gomeassistant +package app import ( "fmt" - "log/slog" "time" "github.com/golang-module/carbon" "saml.dev/gome-assistant/internal" ) -type ScheduleCallback func(*Service, State) +type ScheduleCallback func() type DailySchedule struct { // 0-23 @@ -31,25 +30,25 @@ type DailySchedule struct { disabledEntities []internal.EnabledDisabledInfo } -func (s DailySchedule) Hash() string { +func (s *DailySchedule) Hash() string { return fmt.Sprint(s.hour, s.minute, s.callback) } type scheduleBuilder struct { - schedule DailySchedule + schedule *DailySchedule } type scheduleBuilderCall struct { - schedule DailySchedule + schedule *DailySchedule } type scheduleBuilderEnd struct { - schedule DailySchedule + schedule *DailySchedule } func NewDailySchedule() scheduleBuilder { return scheduleBuilder{ - DailySchedule{ + &DailySchedule{ hour: 0, minute: 0, sunOffset: "0s", @@ -57,7 +56,7 @@ func NewDailySchedule() scheduleBuilder { } } -func (s DailySchedule) String() string { +func (s *DailySchedule) String() string { return fmt.Sprintf("Schedule{ call %q daily at %s }", internal.GetFunctionName(s.callback), stringHourMinute(s.hour, s.minute), @@ -113,16 +112,22 @@ func (sb scheduleBuilderEnd) OnlyOnDates(t time.Time, tl ...time.Time) scheduleB return sb } -/* -Enable this schedule only when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the schedule runs if {runOnNetworkError} is true. -*/ -func (sb scheduleBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkError bool) scheduleBuilderEnd { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Enable this schedule only when the current state of {entityID} +// matches {state}. If there is a network error while retrieving +// state, the schedule runs if {runOnNetworkError} is true. +func (sb scheduleBuilderEnd) EnabledWhen( + entityID, state string, runOnNetworkError bool, +) scheduleBuilderEnd { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -130,16 +135,22 @@ func (sb scheduleBuilderEnd) EnabledWhen(entityId, state string, runOnNetworkErr return sb } -/* -Disable this schedule when the current state of {entityId} matches {state}. -If there is a network error while retrieving state, the schedule runs if {runOnNetworkError} is true. -*/ -func (sb scheduleBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkError bool) scheduleBuilderEnd { - if entityId == "" { - panic(fmt.Sprintf("entityId is empty in EnabledWhen entityId='%s' state='%s'", entityId, state)) +// Disable this schedule when the current state of {entityID} matches +// {state}. If there is a network error while retrieving state, the +// schedule runs if {runOnNetworkError} is true. +func (sb scheduleBuilderEnd) DisabledWhen( + entityID, state string, runOnNetworkError bool, +) scheduleBuilderEnd { + if entityID == "" { + panic( + fmt.Sprintf( + "entityID is empty in EnabledWhen entityID='%s' state='%s'", + entityID, state, + ), + ) } i := internal.EnabledDisabledInfo{ - Entity: entityId, + Entity: entityID, State: state, RunOnError: runOnNetworkError, } @@ -147,69 +158,64 @@ func (sb scheduleBuilderEnd) DisabledWhen(entityId, state string, runOnNetworkEr return sb } -func (sb scheduleBuilderEnd) Build() DailySchedule { +func (sb scheduleBuilderEnd) Build() *DailySchedule { return sb.schedule } -// app.Start() functions -func runSchedules(a *App) { - if a.schedules.Len() == 0 { +func (s *DailySchedule) initializeNextRunTime(app *App) { + // realStartTime already set for sunset/sunrise + if s.isSunrise || s.isSunset { + s.nextRunTime = getNextSunRiseOrSet(app, s.isSunrise, s.sunOffset).Carbon2Time() return } - for { - sched := popSchedule(a) + now := carbon.Now() + startTime := carbon.Now().SetTimeMilli(s.hour, s.minute, 0, 0) - // run callback for all schedules before now in case they overlap - for sched.nextRunTime.Before(time.Now()) { - sched.maybeRunCallback(a) - requeueSchedule(a, sched) + // advance first scheduled time by frequency until it is in the future + if startTime.Lt(now) { + startTime = startTime.AddDay() + } - sched = popSchedule(a) - } + s.nextRunTime = startTime.Carbon2Time() +} - slog.Info("Next schedule", "start_time", sched.nextRunTime) - time.Sleep(time.Until(sched.nextRunTime)) - sched.maybeRunCallback(a) - requeueSchedule(a, sched) - } +func (s *DailySchedule) getNextRunTime() time.Time { + return s.nextRunTime } -func (s DailySchedule) maybeRunCallback(a *App) { +func (s *DailySchedule) shouldRun(app *App) bool { if c := checkExceptionDates(s.exceptionDates); c.fail { - return + return false } if c := checkAllowlistDates(s.allowlistDates); c.fail { - return + return false } - if c := checkEnabledEntity(a.state, s.enabledEntities); c.fail { - return + if c := checkEnabledEntity(app.State, s.enabledEntities); c.fail { + return false } - if c := checkDisabledEntity(a.state, s.disabledEntities); c.fail { - return + if c := checkDisabledEntity(app.State, s.disabledEntities); c.fail { + return false } - go s.callback(a.service, a.state) + return true } -func popSchedule(a *App) DailySchedule { - _sched, _ := a.schedules.Pop() - return _sched.(DailySchedule) +func (s *DailySchedule) run(app *App) { + s.callback() } -func requeueSchedule(a *App, s DailySchedule) { +func (s *DailySchedule) updateNextRunTime(app *App) { if s.isSunrise || s.isSunset { var nextSunTime carbon.Carbon // "0s" is default value if s.sunOffset != "0s" { - nextSunTime = getNextSunRiseOrSet(a, s.isSunrise, s.sunOffset) + nextSunTime = getNextSunRiseOrSet(app, s.isSunrise, s.sunOffset) } else { - nextSunTime = getNextSunRiseOrSet(a, s.isSunrise) + nextSunTime = getNextSunRiseOrSet(app, s.isSunrise) } s.nextRunTime = nextSunTime.Carbon2Time() } else { s.nextRunTime = carbon.Time2Carbon(s.nextRunTime).AddDay().Carbon2Time() } - - a.schedules.Insert(s, float64(s.nextRunTime.Unix())) } diff --git a/app/service.go b/app/service.go new file mode 100644 index 0000000..b41eb13 --- /dev/null +++ b/app/service.go @@ -0,0 +1,56 @@ +package app + +import ( + "saml.dev/gome-assistant/internal/http" + "saml.dev/gome-assistant/internal/services" +) + +type Service struct { + AlarmControlPanel *services.AlarmControlPanel + Climate *services.Climate + Cover *services.Cover + HomeAssistant *services.HomeAssistant + Light *services.Light + Lock *services.Lock + MediaPlayer *services.MediaPlayer + Switch *services.Switch + InputBoolean *services.InputBoolean + InputButton *services.InputButton + InputText *services.InputText + InputDatetime *services.InputDatetime + InputNumber *services.InputNumber + Event *services.Event + Notify *services.Notify + Number *services.Number + Scene *services.Scene + Script *services.Script + TTS *services.TTS + Vacuum *services.Vacuum + ZWaveJS *services.ZWaveJS +} + +func newService(app *App, httpClient *http.HttpClient) *Service { + return &Service{ + AlarmControlPanel: services.NewAlarmControlPanel(app), + Climate: services.NewClimate(app), + Cover: services.NewCover(app), + Light: services.NewLight(app), + HomeAssistant: services.NewHomeAssistant(app), + Lock: services.NewLock(app), + MediaPlayer: services.NewMediaPlayer(app), + Switch: services.NewSwitch(app), + InputBoolean: services.NewInputBoolean(app), + InputButton: services.NewInputButton(app), + InputText: services.NewInputText(app), + InputDatetime: services.NewInputDatetime(app), + InputNumber: services.NewInputNumber(app), + Event: services.NewEvent(app), + Notify: services.NewNotify(app), + Number: services.NewNumber(app), + Scene: services.NewScene(app), + Script: services.NewScript(app), + TTS: services.NewTTS(app), + Vacuum: services.NewVacuum(app), + ZWaveJS: services.NewZWaveJS(app), + } +} diff --git a/state.go b/app/state.go similarity index 60% rename from state.go rename to app/state.go index edc9c91..a834475 100644 --- a/state.go +++ b/app/state.go @@ -1,22 +1,24 @@ -package gomeassistant +package app import ( "encoding/json" "errors" "fmt" - "time" "github.com/golang-module/carbon" "saml.dev/gome-assistant/internal/http" + "saml.dev/gome-assistant/websocket" ) type State interface { + Latitude() float64 + Longitude() float64 AfterSunrise(...DurationString) bool BeforeSunrise(...DurationString) bool AfterSunset(...DurationString) bool BeforeSunset(...DurationString) bool - Get(entityId string) (EntityState, error) - Equals(entityId, state string) (bool, error) + Get(entityID string) (EntityState, error) + Equals(entityID, state string) (bool, error) } // State is used to retrieve state from Home Assistant. @@ -27,25 +29,32 @@ type StateImpl struct { } type EntityState struct { - EntityID string `json:"entity_id"` - State string `json:"state"` - Attributes map[string]any `json:"attributes"` - LastChanged time.Time `json:"last_changed"` + EntityID string `json:"entity_id"` + State string `json:"state"` + Attributes map[string]any `json:"attributes"` + LastChanged websocket.TimeStamp `json:"last_changed"` + + // The whole message, in JSON format: + Raw websocket.RawMessage `json:"-"` } -func newState(c *http.HttpClient, homeZoneEntityId string) (*StateImpl, error) { +func newState(c *http.HttpClient, homeZoneEntityID string) (*StateImpl, error) { state := &StateImpl{httpClient: c} - err := state.getLatLong(c, homeZoneEntityId) + err := state.getLatLong(c, homeZoneEntityID) if err != nil { return nil, err } return state, nil } -func (s *StateImpl) getLatLong(c *http.HttpClient, homeZoneEntityId string) error { - resp, err := s.Get(homeZoneEntityId) +func (s *StateImpl) getLatLong(c *http.HttpClient, homeZoneEntityID string) error { + resp, err := s.Get(homeZoneEntityID) if err != nil { - return fmt.Errorf("couldn't get latitude/longitude from home assistant entity '%s'. Did you type it correctly? It should be a zone like 'zone.home'", homeZoneEntityId) + return fmt.Errorf( + "couldn't get latitude/longitude from home assistant entity '%s'. "+ + "Did you type it correctly? It should be a zone like 'zone.home'", + homeZoneEntityID, + ) } if resp.Attributes["latitude"] != nil { @@ -63,18 +72,27 @@ func (s *StateImpl) getLatLong(c *http.HttpClient, homeZoneEntityId string) erro return nil } -func (s *StateImpl) Get(entityId string) (EntityState, error) { - resp, err := s.httpClient.GetState(entityId) +func (s *StateImpl) Latitude() float64 { + return s.latitude +} + +func (s *StateImpl) Longitude() float64 { + return s.longitude +} + +func (s *StateImpl) Get(entityID string) (EntityState, error) { + resp, err := s.httpClient.GetState(entityID) if err != nil { return EntityState{}, err } es := EntityState{} json.Unmarshal(resp, &es) + es.Raw = resp return es, nil } -func (s *StateImpl) Equals(entityId string, expectedState string) (bool, error) { - currentState, err := s.Get(entityId) +func (s *StateImpl) Equals(entityID string, expectedState string) (bool, error) { + currentState, err := s.Get(entityID) if err != nil { return false, err } diff --git a/eventTypes.go b/eventTypes.go deleted file mode 100644 index cefc6dc..0000000 --- a/eventTypes.go +++ /dev/null @@ -1,29 +0,0 @@ -package gomeassistant - -import "time" - -type EventZWaveJSValueNotification struct { - ID int `json:"id"` - Type string `json:"type"` - Event struct { - EventType string `json:"event_type"` - Data struct { - Domain string `json:"domain"` - NodeID int `json:"node_id"` - HomeID int64 `json:"home_id"` - Endpoint int `json:"endpoint"` - DeviceID string `json:"device_id"` - CommandClass int `json:"command_class"` - CommandClassName string `json:"command_class_name"` - Label string `json:"label"` - Property string `json:"property"` - PropertyName string `json:"property_name"` - PropertyKey string `json:"property_key"` - PropertyKeyName string `json:"property_key_name"` - Value string `json:"value"` - ValueRaw int `json:"value_raw"` - } `json:"data"` - Origin string `json:"origin"` - TimeFired time.Time `json:"time_fired"` - } `json:"event"` -} diff --git a/example/.gitignore b/example/.gitignore new file mode 100644 index 0000000..6f30a3a --- /dev/null +++ b/example/.gitignore @@ -0,0 +1 @@ +/example diff --git a/example/example.go b/example/example.go index 6d77d46..1e64512 100644 --- a/example/example.go +++ b/example/example.go @@ -1,46 +1,60 @@ -package example +package main import ( + "context" "encoding/json" "log/slog" "os" "time" ga "saml.dev/gome-assistant" + "saml.dev/gome-assistant/app" + gaapp "saml.dev/gome-assistant/app" + "saml.dev/gome-assistant/websocket" ) func main() { - app, err := ga.NewApp(ga.NewAppRequest{ - IpAddress: "192.168.86.67", // Replace with your Home Assistant IP Address - HAAuthToken: os.Getenv("HA_AUTH_TOKEN"), - HomeZoneEntityId: "zone.home", - }) + ctx := context.Background() + app, err := gaapp.NewApp( + ctx, + gaapp.NewAppRequest{ + IpAddress: "192.168.86.67", // Replace with your Home Assistant IP Address + HAAuthToken: os.Getenv("HA_AUTH_TOKEN"), + HomeZoneEntityID: "zone.home", + }, + ) if err != nil { slog.Error("Error connecting to HASS:", err) os.Exit(1) } - defer app.Cleanup() + defer app.Close() - pantryDoor := ga. + pantryDoor := gaapp. NewEntityListener(). - EntityIds("binary_sensor.pantry_door"). - Call(pantryLights). + EntityIDs("binary_sensor.pantry_door"). + Call(func(sensor gaapp.EntityData) { + pantryLights(app, sensor) + }). Build() - _11pmSched := ga. + _11pmSched := gaapp. NewDailySchedule(). - Call(lightsOut). + Call(func() { + lightsOut(app) + }). At("23:00"). Build() - _30minsBeforeSunrise := ga. + _30minsBeforeSunrise := gaapp. NewDailySchedule(). - Call(sunriseSched). + Call(func() { + sunriseSched(app) + }). Sunrise("-30m"). Build() - zwaveEventListener := ga. + zwaveEventListener := gaapp. NewEventListener(). EventTypes("zwave_js_value_notification"). Call(onEvent). @@ -50,33 +64,32 @@ func main() { app.RegisterSchedules(_11pmSched, _30minsBeforeSunrise) app.RegisterEventListeners(zwaveEventListener) - app.Start() + app.Start(ctx) } -func pantryLights(service *ga.Service, state ga.State, sensor ga.EntityData) { - l := "light.pantry" +func pantryLights(app *app.App, sensor gaapp.EntityData) { + l := ga.EntityTarget("light.pantry") if sensor.ToState == "on" { - service.HomeAssistant.TurnOn(l) + app.Service.HomeAssistant.TurnOn(l, nil) } else { - service.HomeAssistant.TurnOff(l) + app.Service.HomeAssistant.TurnOff(l) } } -func onEvent(service *ga.Service, state ga.State, data ga.EventData) { - // Since the structure of the event changes depending - // on the event type, you can Unmarshal the raw json - // into a Go type. If a type for your event doesn't - // exist, you can write it yourself! PR's welcome to - // the eventTypes.go file :) - ev := ga.EventZWaveJSValueNotification{} - json.Unmarshal(data.RawEventJSON, &ev) - slog.Info("On event invoked", "event", ev) +func onEvent(ev websocket.Event) { + // Since the structure of the event data changes depending on the + // event type, you can Unmarshal the data into a Go type. If a + // type for your event doesn't exist, you can write it yourself! + // PR's welcome to the eventTypes.go file :) + var data gaapp.ZWaveJSEventData + json.Unmarshal(ev.RawData, &data) + slog.Info("On event invoked", "data", data) } -func lightsOut(service *ga.Service, state ga.State) { +func lightsOut(app *app.App) { // always turn off outside lights - service.Light.TurnOff("light.outside_lights") - s, err := state.Get("binary_sensor.living_room_motion") + app.Service.Light.TurnOff(ga.EntityTarget("light.outside_lights")) + s, err := app.State.Get("binary_sensor.living_room_motion") if err != nil { slog.Warn("couldnt get living room motion state, doing nothing") return @@ -84,11 +97,11 @@ func lightsOut(service *ga.Service, state ga.State) { // if no motion detected in living room for 30mins if s.State == "off" && time.Since(s.LastChanged).Minutes() > 30 { - service.Light.TurnOff("light.main_lights") + app.Service.Light.TurnOff(ga.EntityTarget("light.main_lights")) } } -func sunriseSched(service *ga.Service, state ga.State) { - service.Light.TurnOn("light.living_room_lamps") - service.Light.TurnOff("light.christmas_lights") +func sunriseSched(app *app.App) { + app.Service.Light.TurnOn(ga.EntityTarget("light.living_room_lamps"), nil) + app.Service.Light.TurnOff(ga.EntityTarget("light.christmas_lights")) } diff --git a/example/example_live_test.go b/example/example_live_test.go index e636a22..ce285d4 100644 --- a/example/example_live_test.go +++ b/example/example_live_test.go @@ -1,6 +1,7 @@ -package example +package main import ( + "context" "log/slog" "os" "testing" @@ -10,13 +11,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "gopkg.in/yaml.v3" + ga "saml.dev/gome-assistant" + gaapp "saml.dev/gome-assistant/app" ) type ( MySuite struct { suite.Suite - app *ga.App + app *gaapp.App config *Config suiteCtx map[string]any } @@ -26,10 +29,10 @@ type ( HAAuthToken string `yaml:"token"` IpAddress string `yaml:"address"` Port string `yaml:"port"` - HomeZoneEntityId string `yaml:"zone"` + HomeZoneEntityID string `yaml:"zone"` } Entities struct { - LightEntityId string `yaml:"light_entity_id"` + LightEntityID string `yaml:"light_entity_id"` } } ) @@ -45,7 +48,7 @@ func setupLogging() { slog.SetDefault(slog.New(devslog.NewHandler(os.Stdout, opts))) } -func (s *MySuite) SetupSuite() { +func (s *MySuite) SetupSuite(ctx context.Context) { setupLogging() slog.Debug("Setting up test suite...") s.suiteCtx = make(map[string]any) @@ -61,53 +64,62 @@ func (s *MySuite) SetupSuite() { slog.Error("Error unmarshalling config file", err) } - s.app, err = ga.NewApp(ga.NewAppRequest{ - HAAuthToken: s.config.Hass.HAAuthToken, - IpAddress: s.config.Hass.IpAddress, - HomeZoneEntityId: s.config.Hass.HomeZoneEntityId, - }) + s.app, err = gaapp.NewApp( + ctx, + gaapp.NewAppRequest{ + HAAuthToken: s.config.Hass.HAAuthToken, + IpAddress: s.config.Hass.IpAddress, + HomeZoneEntityID: s.config.Hass.HomeZoneEntityID, + }, + ) if err != nil { slog.Error("Failed to createw new app", err) s.T().FailNow() } // Register all automations - entityId := s.config.Entities.LightEntityId - if entityId != "" { + entityID := s.config.Entities.LightEntityID + if entityID != "" { s.suiteCtx["entityCallbackInvoked"] = false - etl := ga.NewEntityListener().EntityIds(entityId).Call(s.entityCallback).Build() + etl := gaapp.NewEntityListener().EntityIDs(entityID).Call(s.entityCallback).Build() s.app.RegisterEntityListeners(etl) } s.suiteCtx["dailyScheduleCallbackInvoked"] = false runTime := time.Now().Add(1 * time.Minute).Format("15:04") - dailySchedule := ga.NewDailySchedule().Call(s.dailyScheduleCallback).At(runTime).Build() + dailySchedule := gaapp.NewDailySchedule().Call(s.dailyScheduleCallback).At(runTime).Build() s.app.RegisterSchedules(dailySchedule) // start GA app - go s.app.Start() + go s.app.Start(ctx) } func (s *MySuite) TearDownSuite() { if s.app != nil { - s.app.Cleanup() + s.app.Close() s.app = nil } } // Basic test of light toggle service and entity listener func (s *MySuite) TestLightService() { - entityId := s.config.Entities.LightEntityId - - if entityId != "" { - initState := getEntityState(s, entityId) - s.app.GetService().Light.Toggle(entityId) - - assert.EventuallyWithT(s.T(), func(c *assert.CollectT) { - newState := getEntityState(s, entityId) - assert.NotEqual(c, initState, newState) - assert.True(c, s.suiteCtx["entityCallbackInvoked"].(bool)) - }, 10*time.Second, 1*time.Second, "State of light entity did not change or callback was not invoked") + entityID := s.config.Entities.LightEntityID + + if entityID != "" { + target := ga.EntityTarget(entityID) + initState := getEntityState(s, entityID) + s.app.GetService().Light.Toggle(target, nil) + + assert.EventuallyWithT( + s.T(), + func(c *assert.CollectT) { + newState := getEntityState(s, entityID) + assert.NotEqual(c, initState, newState) + assert.True(c, s.suiteCtx["entityCallbackInvoked"].(bool)) + }, + 10*time.Second, 1*time.Second, + "State of light entity did not change or callback was not invoked", + ) } else { s.T().Skip("No light entity id provided") } @@ -121,19 +133,24 @@ func (s *MySuite) TestSchedule() { } // Capture event after light entity state has changed -func (s *MySuite) entityCallback(se *ga.Service, st ga.State, e ga.EntityData) { - slog.Info("Entity callback called.", "entity id", e.TriggerEntityId, "from state", e.FromState, "to state", e.ToState) +func (s *MySuite) entityCallback(e gaapp.EntityData) { + slog.Info( + "Entity callback called.", + "entity id", e.TriggerEntityID, + "from state", e.FromState, + "to state", e.ToState, + ) s.suiteCtx["entityCallbackInvoked"] = true } // Capture planned daily schedule -func (s *MySuite) dailyScheduleCallback(se *ga.Service, st ga.State) { +func (s *MySuite) dailyScheduleCallback() { slog.Info("Daily schedule callback called.") s.suiteCtx["dailyScheduleCallbackInvoked"] = true } -func getEntityState(s *MySuite, entityId string) string { - state, err := s.app.GetState().Get(entityId) +func getEntityState(s *MySuite, entityID string) string { + state, err := s.app.GetState().Get(entityID) if err != nil { slog.Error("Error getting entity state", err) s.T().FailNow() diff --git a/example/go.mod b/example/go.mod deleted file mode 100644 index 49184a0..0000000 --- a/example/go.mod +++ /dev/null @@ -1,24 +0,0 @@ -module example - -go 1.21 - -require ( - github.com/golang-cz/devslog v0.0.8 - github.com/stretchr/testify v1.8.4 - gopkg.in/yaml.v3 v3.0.1 - saml.dev/gome-assistant v0.2.0 -) - -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/gobuffalo/envy v1.10.2 // indirect - github.com/gobuffalo/packd v1.0.2 // indirect - github.com/gobuffalo/packr v1.30.1 // indirect - github.com/golang-module/carbon v1.7.3 // indirect - github.com/gorilla/websocket v1.5.0 // indirect - github.com/joho/godotenv v1.5.1 // indirect - github.com/nathan-osman/go-sunrise v1.1.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.11.0 // indirect - golang.org/x/mod v0.9.0 // indirect -) diff --git a/example/go.sum b/example/go.sum deleted file mode 100644 index 232d6d1..0000000 --- a/example/go.sum +++ /dev/null @@ -1,100 +0,0 @@ -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= -github.com/gobuffalo/envy v1.10.2 h1:EIi03p9c3yeuRCFPOKcSfajzkLb3hrRjEpHGI8I2Wo4= -github.com/gobuffalo/envy v1.10.2/go.mod h1:qGAGwdvDsaEtPhfBzb3o0SfDea8ByGn9j8bKmVft9z8= -github.com/gobuffalo/logger v1.0.0/go.mod h1:2zbswyIUa45I+c+FLXuWl9zSWEiVuthsk8ze5s8JvPs= -github.com/gobuffalo/packd v0.3.0/go.mod h1:zC7QkmNkYVGKPw4tHpBQ+ml7W/3tIebgeo1b36chA3Q= -github.com/gobuffalo/packd v1.0.2 h1:Yg523YqnOxGIWCp69W12yYBKsoChwI7mtu6ceM9Bwfw= -github.com/gobuffalo/packd v1.0.2/go.mod h1:sUc61tDqGMXON80zpKGp92lDb86Km28jfvX7IAyxFT8= -github.com/gobuffalo/packr v1.30.1 h1:hu1fuVR3fXEZR7rXNW3h8rqSML8EVAf6KNm0NKO/wKg= -github.com/gobuffalo/packr v1.30.1/go.mod h1:ljMyFO2EcrnzsHsN99cvbq055Y9OhRrIaviy289eRuk= -github.com/gobuffalo/packr/v2 v2.5.1/go.mod h1:8f9c96ITobJlPzI44jj+4tHnEKNt0xXWSVlXRN9X1Iw= -github.com/golang-cz/devslog v0.0.8 h1:53ipA2rC5JzWBWr9qB8EfenvXppenNiF/8DwgtNT5Q4= -github.com/golang-cz/devslog v0.0.8/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= -github.com/golang-module/carbon v1.7.3 h1:p5mUZj7Tg62MblrkF7XEoxVPvhVs20N/kimqsZOQ+/U= -github.com/golang-module/carbon v1.7.3/go.mod h1:nUMnXq90Rv8a7h2+YOo2BGKS77Y0w/hMPm4/a8h19N8= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/karrick/godirwalk v1.10.12/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/nathan-osman/go-sunrise v1.1.0 h1:ZqZmtmtzs8Os/DGQYi0YMHpuUqR/iRoJK+wDO0wTCw8= -github.com/nathan-osman/go-sunrise v1.1.0/go.mod h1:RcWqhT+5ShCZDev79GuWLayetpJp78RSjSWxiDowmlM= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= -golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190515120540-06a5c4944438/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -saml.dev/gome-assistant v0.2.0 h1:Clo5DrziTdsYydVUTQfroeBVmToMnNHoObr+k6HhbMY= -saml.dev/gome-assistant v0.2.0/go.mod h1:jsZUtnxANCP0zB2B7iyy4j7sZohMGop8g+5EB2MER3o= diff --git a/go.mod b/go.mod index 4c9b4d1..92e2948 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module saml.dev/gome-assistant -go 1.21 +go 1.21.0 + +toolchain go1.21.6 require ( github.com/golang-module/carbon v1.7.1 @@ -13,10 +15,12 @@ require ( github.com/gobuffalo/envy v1.10.2 // indirect github.com/gobuffalo/packd v1.0.2 // indirect github.com/gobuffalo/packr v1.30.1 // indirect + github.com/golang-cz/devslog v0.0.8 // indirect github.com/joho/godotenv v1.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/stretchr/objx v0.5.0 // indirect github.com/stretchr/testify v1.8.4 // indirect + golang.org/x/sync v0.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e5e5f2f..f34677d 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/gobuffalo/packd v1.0.2/go.mod h1:sUc61tDqGMXON80zpKGp92lDb86Km28jfvX7 github.com/gobuffalo/packr v1.30.1 h1:hu1fuVR3fXEZR7rXNW3h8rqSML8EVAf6KNm0NKO/wKg= github.com/gobuffalo/packr v1.30.1/go.mod h1:ljMyFO2EcrnzsHsN99cvbq055Y9OhRrIaviy289eRuk= github.com/gobuffalo/packr/v2 v2.5.1/go.mod h1:8f9c96ITobJlPzI44jj+4tHnEKNt0xXWSVlXRN9X1Iw= +github.com/golang-cz/devslog v0.0.8 h1:53ipA2rC5JzWBWr9qB8EfenvXppenNiF/8DwgtNT5Q4= +github.com/golang-cz/devslog v0.0.8/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= github.com/golang-module/carbon v1.7.1 h1:EDPV0YjxeS2kE2cRedfGgDikU6l5D79HB/teHuZDLu8= github.com/golang-module/carbon v1.7.1/go.mod h1:M/TDTYPp3qWtW68u49dLDJOyGmls6L6BXdo/pyvkMaU= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -75,6 +77,8 @@ golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8U golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/http/http.go b/internal/http/http.go index 2f89643..f541e4c 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -36,8 +36,8 @@ func ClientFromUri(uri, token string) *HttpClient { } } -func (c *HttpClient) GetState(entityId string) ([]byte, error) { - resp, err := get(c.url+"/states/"+entityId, c.token) +func (c *HttpClient) GetState(entityID string) ([]byte, error) { + resp, err := get(c.url+"/states/"+entityID, c.token) if err != nil { return nil, err } diff --git a/internal/internal.go b/internal/internal.go index f5efc14..c730a0a 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -16,18 +16,13 @@ type EnabledDisabledInfo struct { RunOnError bool } -var id int64 = 0 - -func GetId() int64 { - id += 1 - return id -} - // Parses a HH:MM string. func ParseTime(s string) carbon.Carbon { t, err := time.Parse("15:04", s) if err != nil { - parsingErr := fmt.Errorf("failed to parse time string \"%s\"; format must be HH:MM.: %w", s, err) + parsingErr := fmt.Errorf( + "failed to parse time string \"%s\"; format must be HH:MM.: %w", s, err, + ) slog.Error(parsingErr.Error()) panic(parsingErr) } @@ -37,7 +32,11 @@ func ParseTime(s string) carbon.Carbon { func ParseDuration(s string) time.Duration { d, err := time.ParseDuration(s) if err != nil { - parsingErr := fmt.Errorf("couldn't parse string duration: \"%s\" see https://pkg.go.dev/time#ParseDuration for valid time units: %w", s, err) + parsingErr := fmt.Errorf( + "couldn't parse string duration: \"%s\" see "+ + "https://pkg.go.dev/time#ParseDuration for valid time units: %w", + s, err, + ) slog.Error(parsingErr.Error()) panic(parsingErr) } diff --git a/internal/services/alarm_control_panel.go b/internal/services/alarm_control_panel.go index 5b0756e..4406ba3 100644 --- a/internal/services/alarm_control_panel.go +++ b/internal/services/alarm_control_panel.go @@ -3,112 +3,123 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type AlarmControlPanel struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } /* Public API */ -// Send the alarm the command for arm away. -// Takes an entityId and an optional -// map that is translated into service_data. -func (acp AlarmControlPanel) ArmAway(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_arm_away" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func NewAlarmControlPanel(service Service) *AlarmControlPanel { + return &AlarmControlPanel{ + service: service, } +} - acp.conn.WriteMessage(req, acp.ctx) +// Send the alarm the command for arm away. +func (acp AlarmControlPanel) ArmAway(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_arm_away", + serviceData, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the alarm the command for arm away. -// Takes an entityId and an optional +// Takes an entityID and an optional // map that is translated into service_data. -func (acp AlarmControlPanel) ArmWithCustomBypass(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_arm_custom_bypass" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) ArmWithCustomBypass( + target ga.Target, serviceData any, +) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_arm_custom_bypass", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } // Send the alarm the command for arm home. -// Takes an entityId and an optional +// Takes an entityID and an optional // map that is translated into service_data. -func (acp AlarmControlPanel) ArmHome(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_arm_home" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) ArmHome(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_arm_home", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } // Send the alarm the command for arm night. -// Takes an entityId and an optional -// map that is translated into service_data. -func (acp AlarmControlPanel) ArmNight(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_arm_night" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) ArmNight(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_arm_night", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } // Send the alarm the command for arm vacation. -// Takes an entityId and an optional -// map that is translated into service_data. -func (acp AlarmControlPanel) ArmVacation(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_arm_vacation" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) ArmVacation(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_arm_vacation", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } // Send the alarm the command for disarm. -// Takes an entityId and an optional -// map that is translated into service_data. -func (acp AlarmControlPanel) Disarm(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_disarm" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) Disarm(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_disarm", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } // Send the alarm the command for trigger. -// Takes an entityId and an optional -// map that is translated into service_data. -func (acp AlarmControlPanel) Trigger(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "alarm_control_panel" - req.Service = "alarm_trigger" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (acp AlarmControlPanel) Trigger(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := acp.service.CallService( + ctx, "alarm_control_panel", "alarm_trigger", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - acp.conn.WriteMessage(req, acp.ctx) + return result, nil } diff --git a/internal/services/climate.go b/internal/services/climate.go index 797215f..2cf37f3 100644 --- a/internal/services/climate.go +++ b/internal/services/climate.go @@ -3,33 +3,71 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" - "saml.dev/gome-assistant/types" + ga "saml.dev/gome-assistant" ) /* Structs */ type Climate struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ +func NewClimate(service Service) *Climate { + return &Climate{ + service: service, + } +} -func (c Climate) SetFanMode(entityId string, fanMode string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "climate" - req.Service = "set_fan_mode" - req.ServiceData = map[string]any{"fan_mode": fanMode} +func (c Climate) SetFanMode(target ga.Target, fanMode string) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "climate", "set_fan_mode", + map[string]any{"fan_mode": fanMode}, + target, &result, + ) + if err != nil { + return nil, err + } + return result, nil +} - c.conn.WriteMessage(req, c.ctx) +type SetTemperatureRequest struct { + Temperature *float32 + TargetTempHigh *float32 + TargetTempLow *float32 + HvacMode string } -func (c Climate) SetTemperature(entityId string, serviceData types.SetTemperatureRequest) { - req := NewBaseServiceRequest(entityId) - req.Domain = "climate" - req.Service = "set_temperature" - req.ServiceData = serviceData.ToJSON() +func (r *SetTemperatureRequest) ToJSON() map[string]any { + m := map[string]any{} + if r.Temperature != nil { + m["temperature"] = *r.Temperature + } + if r.TargetTempHigh != nil { + m["target_temp_high"] = *r.TargetTempHigh + } + if r.TargetTempLow != nil { + m["target_temp_low"] = *r.TargetTempLow + } + if r.HvacMode != "" { + m["hvac_mode"] = r.HvacMode + } + return m +} - c.conn.WriteMessage(req, c.ctx) +func (c Climate) SetTemperature( + target ga.Target, setTemperatureRequest SetTemperatureRequest, +) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "climate", "set_temperature", + setTemperatureRequest.ToJSON(), + target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/cover.go b/internal/services/cover.go index 8fb6e75..9ac0569 100644 --- a/internal/services/cover.go +++ b/internal/services/cover.go @@ -3,112 +3,161 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Cover struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ - -// Close all or specified cover. Takes an entityId. -func (c Cover) Close(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "close_cover" - - c.conn.WriteMessage(req, c.ctx) +func NewCover(service Service) *Cover { + return &Cover{ + service: service, + } } -// Close all or specified cover tilt. Takes an entityId. -func (c Cover) CloseTilt(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "close_cover_tilt" +/* Public API */ - c.conn.WriteMessage(req, c.ctx) +// Close all or specified cover. Takes an entityID. +func (c Cover) Close(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "close_cover", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Open all or specified cover. Takes an entityId. -func (c Cover) Open(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "open_cover" - - c.conn.WriteMessage(req, c.ctx) +// Close all or specified cover tilt. Takes an entityID. +func (c Cover) CloseTilt(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "close_cover_tilt", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Open all or specified cover tilt. Takes an entityId. -func (c Cover) OpenTilt(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "open_cover_tilt" +// Open all or specified cover. Takes an entityID. +func (c Cover) Open(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "open_cover", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil +} - c.conn.WriteMessage(req, c.ctx) +// Open all or specified cover tilt. Takes an entityID. +func (c Cover) OpenTilt(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "open_cover_tilt", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Move to specific position all or specified cover. Takes an entityId and an optional +// Move to specific position all or specified cover. Takes an entityID and an optional // map that is translated into service_data. -func (c Cover) SetPosition(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "set_cover_position" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (c Cover) SetPosition(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "set_cover_position", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - c.conn.WriteMessage(req, c.ctx) + return result, nil } -// Move to specific position all or specified cover tilt. Takes an entityId and an optional +// Move to specific position all or specified cover tilt. Takes an entityID and an optional // map that is translated into service_data. -func (c Cover) SetTiltPosition(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "set_cover_tilt_position" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (c Cover) SetTiltPosition(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "set_cover_tilt_position", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - c.conn.WriteMessage(req, c.ctx) + return result, nil } -// Stop a cover entity. Takes an entityId. -func (c Cover) Stop(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "stop_cover" - - c.conn.WriteMessage(req, c.ctx) +// Stop a cover entity. Takes an entityID. +func (c Cover) Stop(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "stop_cover", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Stop a cover entity tilt. Takes an entityId. -func (c Cover) StopTilt(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "stop_cover_tilt" - - c.conn.WriteMessage(req, c.ctx) +// Stop a cover entity tilt. Takes an entityID. +func (c Cover) StopTilt(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "stop_cover_tilt", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Toggle a cover open/closed. Takes an entityId. -func (c Cover) Toggle(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "toggle" - - c.conn.WriteMessage(req, c.ctx) +// Toggle a cover open/closed. Takes an entityID. +func (c Cover) Toggle(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "toggle", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Toggle a cover tilt open/closed. Takes an entityId. -func (c Cover) ToggleTilt(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "cover" - req.Service = "toggle_cover_tilt" - - c.conn.WriteMessage(req, c.ctx) +// Toggle a cover tilt open/closed. Takes an entityID. +func (c Cover) ToggleTilt(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := c.service.CallService( + ctx, "cover", "toggle_cover_tilt", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/event.go b/internal/services/event.go index 9205db1..3c50d41 100644 --- a/internal/services/event.go +++ b/internal/services/event.go @@ -3,19 +3,22 @@ package services import ( "context" - "saml.dev/gome-assistant/internal" - ws "saml.dev/gome-assistant/internal/websocket" + "saml.dev/gome-assistant/websocket" ) type Event struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewEvent(service Service) *Event { + return &Event{ + service: service, + } } // Fire an event type FireEventRequest struct { - Id int64 `json:"id"` - Type string `json:"type"` // always set to "fire_event" + websocket.BaseMessage EventType string `json:"event_type"` EventData map[string]any `json:"event_data,omitempty"` } @@ -24,16 +27,18 @@ type FireEventRequest struct { // Fire an event. Takes an event type and an optional map that is sent // as `event_data`. -func (e Event) Fire(eventType string, eventData ...map[string]any) { +func (e Event) Fire(eventType string, eventData map[string]any) error { + ctx := context.TODO() + req := FireEventRequest{ - Id: internal.GetId(), - Type: "fire_event", + BaseMessage: websocket.BaseMessage{ + Type: "fire_event", + }, } req.EventType = eventType - if len(eventData) != 0 { - req.EventData = eventData[0] - } + req.EventData = eventData - e.conn.WriteMessage(req, e.ctx) + var result any + return e.service.Call(ctx, &req, &result) } diff --git a/internal/services/homeassistant.go b/internal/services/homeassistant.go index 8400509..d683433 100644 --- a/internal/services/homeassistant.go +++ b/internal/services/homeassistant.go @@ -3,44 +3,58 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) type HomeAssistant struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -// TurnOn a Home Assistant entity. Takes an entityId and an optional -// map that is translated into service_data. -func (ha *HomeAssistant) TurnOn(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "homeassistant" - req.Service = "turn_on" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func NewHomeAssistant(service Service) *HomeAssistant { + return &HomeAssistant{ + service: service, } - - ha.conn.WriteMessage(req, ha.ctx) } -// Toggle a Home Assistant entity. Takes an entityId and an optional +// TurnOn a Home Assistant entity. Takes an entityID and an optional // map that is translated into service_data. -func (ha *HomeAssistant) Toggle(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "homeassistant" - req.Service = "toggle" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (ha *HomeAssistant) TurnOn(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := ha.service.CallService( + ctx, "homeassistant", "turn_on", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - ha.conn.WriteMessage(req, ha.ctx) + return result, nil } -func (ha *HomeAssistant) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "homeassistant" - req.Service = "turn_off" +// Toggle a Home Assistant entity. Takes an entityID and an optional +// map that is translated into service_data. +func (ha *HomeAssistant) Toggle(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := ha.service.CallService( + ctx, "homeassistant", "toggle", + serviceData, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil +} - ha.conn.WriteMessage(req, ha.ctx) +func (ha *HomeAssistant) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ha.service.CallService( + ctx, "homeassistant", "turn_off", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/input_boolean.go b/internal/services/input_boolean.go index ac589f7..62884df 100644 --- a/internal/services/input_boolean.go +++ b/internal/services/input_boolean.go @@ -3,44 +3,70 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type InputBoolean struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ +func NewInputBoolean(service Service) *InputBoolean { + return &InputBoolean{ + service: service, + } +} -func (ib InputBoolean) TurnOn(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_boolean" - req.Service = "turn_on" +/* Public API */ - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputBoolean) TurnOn(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_boolean", "turn_on", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputBoolean) Toggle(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_boolean" - req.Service = "toggle" - - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputBoolean) Toggle(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_boolean", "toggle", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputBoolean) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_boolean" - req.Service = "turn_off" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputBoolean) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_boolean", "turn_off", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputBoolean) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "input_boolean" - req.Service = "reload" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputBoolean) Reload() (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_boolean", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/input_button.go b/internal/services/input_button.go index e0ec541..45c93b1 100644 --- a/internal/services/input_button.go +++ b/internal/services/input_button.go @@ -3,29 +3,44 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type InputButton struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ +func NewInputButton(service Service) *InputButton { + return &InputButton{ + service: service, + } +} -func (ib InputButton) Press(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_button" - req.Service = "press" +/* Public API */ - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputButton) Press(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_button", "press", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputButton) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "input_button" - req.Service = "reload" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputButton) Reload() (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_button", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/input_datetime.go b/internal/services/input_datetime.go index 92c12d5..f206fdb 100644 --- a/internal/services/input_datetime.go +++ b/internal/services/input_datetime.go @@ -5,32 +5,47 @@ import ( "fmt" "time" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type InputDatetime struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewInputDatetime(service Service) *InputDatetime { + return &InputDatetime{ + service: service, + } } /* Public API */ -func (ib InputDatetime) Set(entityId string, value time.Time) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_datetime" - req.Service = "set_datetime" - req.ServiceData = map[string]any{ - "timestamp": fmt.Sprint(value.Unix()), +func (ib InputDatetime) Set(target ga.Target, value time.Time) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_datetime", "set_datetime", + map[string]any{ + "timestamp": fmt.Sprint(value.Unix()), + }, + target, &result, + ) + if err != nil { + return nil, err } - - ib.conn.WriteMessage(req, ib.ctx) + return result, nil } -func (ib InputDatetime) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "input_datetime" - req.Service = "reload" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputDatetime) Reload() (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_datetime", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/input_number.go b/internal/services/input_number.go index 59409f6..b803960 100644 --- a/internal/services/input_number.go +++ b/internal/services/input_number.go @@ -3,46 +3,71 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type InputNumber struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ - -func (ib InputNumber) Set(entityId string, value float32) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_number" - req.Service = "set_value" - req.ServiceData = map[string]any{"value": value} - - ib.conn.WriteMessage(req, ib.ctx) +func NewInputNumber(service Service) *InputNumber { + return &InputNumber{ + service: service, + } } -func (ib InputNumber) Increment(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_number" - req.Service = "increment" +/* Public API */ - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputNumber) Set(target ga.Target, value float32) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_number", "set_value", + map[string]any{"value": value}, + target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputNumber) Decrement(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_number" - req.Service = "decrement" +func (ib InputNumber) Increment(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_number", "increment", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil +} - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputNumber) Decrement(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_number", "decrement", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (ib InputNumber) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "input_number" - req.Service = "reload" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputNumber) Reload() (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_number", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/input_text.go b/internal/services/input_text.go index b7a0d1a..26e0f76 100644 --- a/internal/services/input_text.go +++ b/internal/services/input_text.go @@ -3,32 +3,47 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type InputText struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewInputText(service Service) *InputText { + return &InputText{ + service: service, + } } /* Public API */ -func (ib InputText) Set(entityId string, value string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "input_text" - req.Service = "set_value" - req.ServiceData = map[string]any{ - "value": value, +func (ib InputText) Set(target ga.Target, value string) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_text", "set_value", + map[string]any{ + "value": value, + }, + target, &result, + ) + if err != nil { + return nil, err } - - ib.conn.WriteMessage(req, ib.ctx) + return result, nil } -func (ib InputText) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "input_text" - req.Service = "reload" - ib.conn.WriteMessage(req, ib.ctx) +func (ib InputText) Reload() (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "input_text", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/light.go b/internal/services/light.go index c1a2179..ecd7576 100644 --- a/internal/services/light.go +++ b/internal/services/light.go @@ -3,47 +3,57 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Light struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewLight(service Service) *Light { + return &Light{ + service: service, + } } /* Public API */ -// TurnOn a light entity. Takes an entityId and an optional -// map that is translated into service_data. -func (l Light) TurnOn(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "light" - req.Service = "turn_on" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// TurnOn a light entity. +func (l Light) TurnOn(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := l.service.CallService( + ctx, "light", "turn_on", serviceData, target, &result, + ) + if err != nil { + return nil, err } - - l.conn.WriteMessage(req, l.ctx) + return result, nil } -// Toggle a light entity. Takes an entityId and an optional -// map that is translated into service_data. -func (l Light) Toggle(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "light" - req.Service = "toggle" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Toggle a light entity. +func (l Light) Toggle(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := l.service.CallService( + ctx, "light", "toggle", serviceData, target, &result, + ) + if err != nil { + return nil, err } - - l.conn.WriteMessage(req, l.ctx) + return result, nil } -func (l Light) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "light" - req.Service = "turn_off" - l.conn.WriteMessage(req, l.ctx) +func (l Light) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := l.service.CallService( + ctx, "light", "turn_off", nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/lock.go b/internal/services/lock.go index e122b25..086f49b 100644 --- a/internal/services/lock.go +++ b/internal/services/lock.go @@ -3,40 +3,47 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Lock struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewLock(service Service) *Lock { + return &Lock{ + service: service, + } } /* Public API */ -// Lock a lock entity. Takes an entityId and an optional -// map that is translated into service_data. -func (l Lock) Lock(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "lock" - req.Service = "lock" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Lock a lock entity. +func (l Lock) Lock(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := l.service.CallService( + ctx, "lock", "lock", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - l.conn.WriteMessage(req, l.ctx) + return result, nil } -// Unlock a lock entity. Takes an entityId and an optional -// map that is translated into service_data. -func (l Lock) Unlock(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "lock" - req.Service = "unlock" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Unlock a lock entity. +func (l Lock) Unlock(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := l.service.CallService( + ctx, "lock", "unlock", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - l.conn.WriteMessage(req, l.ctx) + return result, nil } diff --git a/internal/services/media_player.go b/internal/services/media_player.go index 727d7a9..ed6eb53 100644 --- a/internal/services/media_player.go +++ b/internal/services/media_player.go @@ -3,270 +3,328 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type MediaPlayer struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewMediaPlayer(service Service) *MediaPlayer { + return &MediaPlayer{ + service: service, + } } /* Public API */ // Send the media player the command to clear players playlist. -// Takes an entityId. -func (mp MediaPlayer) ClearPlaylist(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "clear_playlist" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) ClearPlaylist(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "clear_playlist", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Group players together. Only works on platforms with support for player groups. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) Join(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "join" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) Join(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "join", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Send the media player the command for next track. -// Takes an entityId. -func (mp MediaPlayer) Next(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_next_track" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Next(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_next_track", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the media player the command for pause. -// Takes an entityId. -func (mp MediaPlayer) Pause(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_pause" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Pause(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_pause", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the media player the command for play. -// Takes an entityId. -func (mp MediaPlayer) Play(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_play" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Play(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_play", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Toggle media player play/pause state. -// Takes an entityId. -func (mp MediaPlayer) PlayPause(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_play_pause" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) PlayPause(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_play_pause", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the media player the command for previous track. -// Takes an entityId. -func (mp MediaPlayer) Previous(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_previous_track" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Previous(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_previous_track", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the media player the command to seek in current playing media. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) Seek(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_seek" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) Seek(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_seek", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Send the media player the stop command. -// Takes an entityId. -func (mp MediaPlayer) Stop(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "media_stop" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Stop(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "media_stop", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Send the media player the command for playing media. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) PlayMedia(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "play_media" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) PlayMedia(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "play_media", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } -// Set repeat mode. Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) RepeatSet(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "repeat_set" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Set repeat mode. +func (mp MediaPlayer) RepeatSet(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "repeat_set", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Send the media player the command to change sound mode. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) SelectSoundMode(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "select_sound_mode" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) SelectSoundMode(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "select_sound_mode", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Send the media player the command to change input source. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) SelectSource(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "select_source" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) SelectSource(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "select_source", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Set shuffling state. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) Shuffle(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "shuffle_set" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) Shuffle(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "shuffle_set", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Toggles a media player power state. -// Takes an entityId. -func (mp MediaPlayer) Toggle(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "toggle" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Toggle(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "toggle", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Turn a media player power off. -// Takes an entityId. -func (mp MediaPlayer) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "turn_off" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "turn_off", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Turn a media player power on. -// Takes an entityId. -func (mp MediaPlayer) TurnOn(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "turn_on" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) TurnOn(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "turn_on", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Unjoin the player from a group. Only works on // platforms with support for player groups. -// Takes an entityId. -func (mp MediaPlayer) Unjoin(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "unjoin" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) Unjoin(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "unjoin", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Turn a media player volume down. -// Takes an entityId. -func (mp MediaPlayer) VolumeDown(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "volume_down" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) VolumeDown(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "volume_down", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Mute a media player's volume. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) VolumeMute(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "volume_mute" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) VolumeMute(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "volume_mute", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Set a media player's volume level. -// Takes an entityId and an optional -// map that is translated into service_data. -func (mp MediaPlayer) VolumeSet(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "volume_set" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (mp MediaPlayer) VolumeSet(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "volume_set", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - mp.conn.WriteMessage(req, mp.ctx) + return result, nil } // Turn a media player volume up. -// Takes an entityId. -func (mp MediaPlayer) VolumeUp(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "media_player" - req.Service = "volume_up" - - mp.conn.WriteMessage(req, mp.ctx) +func (mp MediaPlayer) VolumeUp(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := mp.service.CallService( + ctx, "media_player", "volume_up", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/notify.go b/internal/services/notify.go index e76dd42..73510eb 100644 --- a/internal/services/notify.go +++ b/internal/services/notify.go @@ -3,28 +3,45 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" - "saml.dev/gome-assistant/types" + ga "saml.dev/gome-assistant" ) type Notify struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -// Send a notification. Takes a types.NotifyRequest. -func (ha *Notify) Notify(reqData types.NotifyRequest) { - req := NewBaseServiceRequest("") - req.Domain = "notify" - req.Service = reqData.ServiceName +func NewNotify(service Service) *Notify { + return &Notify{ + service: service, + } +} + +type NotifyRequest struct { + // Which notify service to call, such as mobile_app_sams_iphone + ServiceName string + Message string + Title string + Data map[string]any +} - serviceData := map[string]any{} - serviceData["message"] = reqData.Message - serviceData["title"] = reqData.Title +// Send a notification. +func (ha *Notify) Notify(reqData NotifyRequest) (any, error) { + ctx := context.TODO() + serviceData := map[string]any{ + "message": reqData.Message, + "title": reqData.Title, + } if reqData.Data != nil { serviceData["data"] = reqData.Data } - req.ServiceData = serviceData - ha.conn.WriteMessage(req, ha.ctx) + var result any + err := ha.service.CallService( + ctx, "notify", reqData.ServiceName, + serviceData, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/number.go b/internal/services/number.go index 243603e..f282620 100644 --- a/internal/services/number.go +++ b/internal/services/number.go @@ -3,23 +3,33 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Number struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ +func NewNumber(service Service) *Number { + return &Number{ + service: service, + } +} -func (ib Number) SetValue(entityId string, value float32) { - req := NewBaseServiceRequest(entityId) - req.Domain = "number" - req.Service = "set_value" - req.ServiceData = map[string]any{"value": value} +/* Public API */ - ib.conn.WriteMessage(req, ib.ctx) +func (ib Number) SetValue(target ga.Target, value float32) (any, error) { + ctx := context.TODO() + var result any + err := ib.service.CallService( + ctx, "number", "set_value", + map[string]any{"value": value}, + target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/scene.go b/internal/services/scene.go index e17ada9..d37a674 100644 --- a/internal/services/scene.go +++ b/internal/services/scene.go @@ -3,61 +3,74 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Scene struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewScene(service Service) *Scene { + return &Scene{ + service: service, + } } /* Public API */ // Apply a scene. Takes map that is translated into service_data. -func (s Scene) Apply(serviceData ...map[string]any) { - req := NewBaseServiceRequest("") - req.Domain = "scene" - req.Service = "apply" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (s Scene) Apply(serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "scene", "apply", + serviceData, ga.Target{}, &result, + ) + if err != nil { + return nil, err } - - s.conn.WriteMessage(req, s.ctx) + return result, nil } -// Create a scene entity. Takes an entityId and an optional -// map that is translated into service_data. -func (s Scene) Create(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "scene" - req.Service = "create" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Create a scene entity. +func (s Scene) Create(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "scene", "create", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - s.conn.WriteMessage(req, s.ctx) + return result, nil } // Reload the scenes. -func (s Scene) Reload() { - req := NewBaseServiceRequest("") - req.Domain = "scene" - req.Service = "reload" - - s.conn.WriteMessage(req, s.ctx) +func (s Scene) Reload() (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "scene", "reload", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// TurnOn a scene entity. Takes an entityId and an optional -// map that is translated into service_data. -func (s Scene) TurnOn(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "scene" - req.Service = "turn_on" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// TurnOn a scene entity. +func (s Scene) TurnOn(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "scene", "turn_on", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - s.conn.WriteMessage(req, s.ctx) + return result, nil } diff --git a/internal/services/script.go b/internal/services/script.go index b80dbbb..7a63149 100644 --- a/internal/services/script.go +++ b/internal/services/script.go @@ -3,50 +3,75 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Script struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewScript(service Service) *Script { + return &Script{ + service: service, + } } /* Public API */ // Reload a script that was created in the HA UI. -func (s Script) Reload(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "script" - req.Service = "reload" - - s.conn.WriteMessage(req, s.ctx) +func (s Script) Reload(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "script", "reload", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Toggle a script that was created in the HA UI. -func (s Script) Toggle(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "script" - req.Service = "toggle" - - s.conn.WriteMessage(req, s.ctx) +func (s Script) Toggle(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "script", "toggle", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Turn off a script that was created in the HA UI. -func (s Script) TurnOff() { - req := NewBaseServiceRequest("") - req.Domain = "script" - req.Service = "turn_off" - - s.conn.WriteMessage(req, s.ctx) +func (s Script) TurnOff() (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "script", "turn_off", + nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Turn on a script that was created in the HA UI. -func (s Script) TurnOn(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "script" - req.Service = "turn_on" - - s.conn.WriteMessage(req, s.ctx) +func (s Script) TurnOn(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "script", "turn_on", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/services.go b/internal/services/services.go index 6dbb024..2052ea1 100644 --- a/internal/services/services.go +++ b/internal/services/services.go @@ -3,55 +3,17 @@ package services import ( "context" - "saml.dev/gome-assistant/internal" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" + "saml.dev/gome-assistant/websocket" ) -func BuildService[ - T AlarmControlPanel | - Climate | - Cover | - Light | - HomeAssistant | - Lock | - MediaPlayer | - Switch | - InputBoolean | - InputButton | - InputDatetime | - InputText | - InputNumber | - Event | - Notify | - Number | - Scene | - Script | - TTS | - Vacuum | - ZWaveJS, -](conn *ws.WebsocketWriter, ctx context.Context) *T { - return &T{conn: conn, ctx: ctx} -} - -type BaseServiceRequest struct { - Id int64 `json:"id"` - RequestType string `json:"type"` // hardcoded "call_service" - Domain string `json:"domain"` - Service string `json:"service"` - ServiceData map[string]any `json:"service_data,omitempty"` - Target struct { - EntityId string `json:"entity_id,omitempty"` - } `json:"target,omitempty"` -} +type Service interface { + Call( + ctx context.Context, req websocket.Request, result any, + ) error -func NewBaseServiceRequest(entityId string) BaseServiceRequest { - id := internal.GetId() - bsr := BaseServiceRequest{ - Id: id, - RequestType: "call_service", - } - if entityId != "" { - bsr.Target.EntityId = entityId - } - return bsr + CallService( + ctx context.Context, domain string, service string, serviceData any, target ga.Target, + result any, + ) error } diff --git a/internal/services/switch.go b/internal/services/switch.go index 0e7be52..65bc593 100644 --- a/internal/services/switch.go +++ b/internal/services/switch.go @@ -3,37 +3,58 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Switch struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service } -/* Public API */ +func NewSwitch(service Service) *Switch { + return &Switch{ + service: service, + } +} -func (s Switch) TurnOn(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "switch" - req.Service = "turn_on" +/* Public API */ - s.conn.WriteMessage(req, s.ctx) +func (s Switch) TurnOn(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "switch", "turn_on", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (s Switch) Toggle(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "switch" - req.Service = "toggle" - - s.conn.WriteMessage(req, s.ctx) +func (s Switch) Toggle(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "switch", "toggle", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -func (s Switch) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "switch" - req.Service = "turn_off" - s.conn.WriteMessage(req, s.ctx) +func (s Switch) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := s.service.CallService( + ctx, "switch", "turn_off", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/tts.go b/internal/services/tts.go index 74b4963..a05a7c2 100644 --- a/internal/services/tts.go +++ b/internal/services/tts.go @@ -3,51 +3,61 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type TTS struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewTTS(service Service) *TTS { + return &TTS{ + service: service, + } } /* Public API */ // Remove all text-to-speech cache files and RAM cache. -func (tts TTS) ClearCache() { - req := NewBaseServiceRequest("") - req.Domain = "tts" - req.Service = "clear_cache" - - tts.conn.WriteMessage(req, tts.ctx) +func (tts TTS) ClearCache() (any, error) { + ctx := context.TODO() + var result any + err := tts.service.CallService( + ctx, "tts", "clear_cache", nil, ga.Target{}, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Say something using text-to-speech on a media player with cloud. -// Takes an entityId and an optional -// map that is translated into service_data. -func (tts TTS) CloudSay(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "tts" - req.Service = "cloud_say" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +func (tts TTS) CloudSay(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := tts.service.CallService( + ctx, "tts", "cloud_say", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - tts.conn.WriteMessage(req, tts.ctx) + return result, nil } -// Say something using text-to-speech on a media player with google_translate. -// Takes an entityId and an optional -// map that is translated into service_data. -func (tts TTS) GoogleTranslateSay(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "tts" - req.Service = "google_translate_say" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Say something using text-to-speech on a media player with +// google_translate. +func (tts TTS) GoogleTranslateSay(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := tts.service.CallService( + ctx, "tts", "google_translate_say", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - tts.conn.WriteMessage(req, tts.ctx) + return result, nil } diff --git a/internal/services/vacuum.go b/internal/services/vacuum.go index fbc71b0..0dda3db 100644 --- a/internal/services/vacuum.go +++ b/internal/services/vacuum.go @@ -3,131 +3,173 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type Vacuum struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewVacuum(service Service) *Vacuum { + return &Vacuum{ + service: service, + } } /* Public API */ // Tell the vacuum cleaner to do a spot clean-up. -// Takes an entityId. -func (v Vacuum) CleanSpot(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "clean_spot" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) CleanSpot(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "clean_spot", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Locate the vacuum cleaner robot. -// Takes an entityId. -func (v Vacuum) Locate(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "locate" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) Locate(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "locate", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Pause the cleaning task. -// Takes an entityId. -func (v Vacuum) Pause(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "pause" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) Pause(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "pause", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Tell the vacuum cleaner to return to its dock. -// Takes an entityId. -func (v Vacuum) ReturnToBase(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "return_to_base" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) ReturnToBase(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "return_to_base", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } -// Send a raw command to the vacuum cleaner. Takes an entityId and an optional -// map that is translated into service_data. -func (v Vacuum) SendCommand(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "send_command" - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Send a raw command to the vacuum cleaner. +func (v Vacuum) SendCommand(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "send_command", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - v.conn.WriteMessage(req, v.ctx) + return result, nil } -// Set the fan speed of the vacuum cleaner. Takes an entityId and an optional -// map that is translated into service_data. -func (v Vacuum) SetFanSpeed(entityId string, serviceData ...map[string]any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "set_fan_speed" - - if len(serviceData) != 0 { - req.ServiceData = serviceData[0] +// Set the fan speed of the vacuum cleaner. +func (v Vacuum) SetFanSpeed(target ga.Target, serviceData any) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "set_fan_speed", + serviceData, target, &result, + ) + if err != nil { + return nil, err } - - v.conn.WriteMessage(req, v.ctx) + return result, nil } // Start or resume the cleaning task. -// Takes an entityId. -func (v Vacuum) Start(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "start" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) Start(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "start", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Start, pause, or resume the cleaning task. -// Takes an entityId. -func (v Vacuum) StartPause(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "start_pause" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) StartPause(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "start_pause", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Stop the current cleaning task. -// Takes an entityId. -func (v Vacuum) Stop(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "stop" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) Stop(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "stop", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Stop the current cleaning task and return to home. -// Takes an entityId. -func (v Vacuum) TurnOff(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "turn_off" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) TurnOff(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "turn_off", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } // Start a new cleaning task. -// Takes an entityId. -func (v Vacuum) TurnOn(entityId string) { - req := NewBaseServiceRequest(entityId) - req.Domain = "vacuum" - req.Service = "turn_on" - - v.conn.WriteMessage(req, v.ctx) +func (v Vacuum) TurnOn(target ga.Target) (any, error) { + ctx := context.TODO() + var result any + err := v.service.CallService( + ctx, "vacuum", "turn_on", + nil, target, &result, + ) + if err != nil { + return nil, err + } + return result, nil } diff --git a/internal/services/zwavejs.go b/internal/services/zwavejs.go index f19fc6f..ec8d701 100644 --- a/internal/services/zwavejs.go +++ b/internal/services/zwavejs.go @@ -3,27 +3,39 @@ package services import ( "context" - ws "saml.dev/gome-assistant/internal/websocket" + ga "saml.dev/gome-assistant" ) /* Structs */ type ZWaveJS struct { - conn *ws.WebsocketWriter - ctx context.Context + service Service +} + +func NewZWaveJS(service Service) *ZWaveJS { + return &ZWaveJS{ + service: service, + } } /* Public API */ // ZWaveJS bulk_set_partial_config_parameters service. -func (zw ZWaveJS) BulkSetPartialConfigParam(entityId string, parameter int, value any) { - req := NewBaseServiceRequest(entityId) - req.Domain = "zwave_js" - req.Service = "bulk_set_partial_config_parameters" - req.ServiceData = map[string]any{ - "parameter": parameter, - "value": value, +func (zw ZWaveJS) BulkSetPartialConfigParam( + target ga.Target, parameter int, value any, +) (any, error) { + ctx := context.TODO() + var result any + err := zw.service.CallService( + ctx, "zwave_js", "bulk_set_partial_config_parameters", + map[string]any{ + "parameter": parameter, + "value": value, + }, + target, &result, + ) + if err != nil { + return nil, err } - - zw.conn.WriteMessage(req, zw.ctx) + return result, nil } diff --git a/internal/websocket/reader.go b/internal/websocket/reader.go deleted file mode 100644 index 6ac4640..0000000 --- a/internal/websocket/reader.go +++ /dev/null @@ -1,50 +0,0 @@ -package websocket - -import ( - "context" - "encoding/json" - "log/slog" - - "github.com/gorilla/websocket" -) - -type BaseMessage struct { - Type string `json:"type"` - Id int64 `json:"id"` - Success bool `json:"success"` -} - -type ChanMsg struct { - Id int64 - Type string - Success bool - Raw []byte -} - -func ListenWebsocket(conn *websocket.Conn, ctx context.Context, c chan ChanMsg) { - for { - bytes, err := ReadMessage(conn, ctx) - if err != nil { - slog.Error("Error reading from websocket:", err) - close(c) - break - } - - base := BaseMessage{ - // default to true for messages that don't include "success" at all - Success: true, - } - json.Unmarshal(bytes, &base) - if !base.Success { - slog.Warn("Received unsuccessful response", "response", string(bytes)) - } - chanMsg := ChanMsg{ - Type: base.Type, - Id: base.Id, - Success: base.Success, - Raw: bytes, - } - - c <- chanMsg - } -} diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go deleted file mode 100644 index 2eec28b..0000000 --- a/internal/websocket/websocket.go +++ /dev/null @@ -1,160 +0,0 @@ -// Package websocket is used to interact with the Home Assistant -// websocket API. All HA interaction is done via websocket -// except for cases explicitly called out in http package -// documentation. -package websocket - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "sync" - "time" - - "github.com/gorilla/websocket" - i "saml.dev/gome-assistant/internal" -) - -var ErrInvalidToken = errors.New("invalid authentication token") - -type AuthMessage struct { - MsgType string `json:"type"` - AccessToken string `json:"access_token"` -} - -type WebsocketWriter struct { - Conn *websocket.Conn - mutex sync.Mutex -} - -func (w *WebsocketWriter) WriteMessage(msg interface{}, ctx context.Context) error { - w.mutex.Lock() - defer w.mutex.Unlock() - - err := w.Conn.WriteJSON(msg) - if err != nil { - return err - } - - return nil -} - -func ReadMessage(conn *websocket.Conn, ctx context.Context) ([]byte, error) { - _, msg, err := conn.ReadMessage() - if err != nil { - return []byte{}, err - } - return msg, nil -} - -func SetupConnection(ip, port, authToken string) (*websocket.Conn, context.Context, context.CancelFunc, error) { - uri := fmt.Sprintf("ws://%s:%s/api/websocket", ip, port) - return ConnectionFromUri(uri, authToken) -} - -func SetupSecureConnection(ip, port, authToken string) (*websocket.Conn, context.Context, context.CancelFunc, error) { - uri := fmt.Sprintf("wss://%s:%s/api/websocket", ip, port) - return ConnectionFromUri(uri, authToken) -} - -func ConnectionFromUri(uri, authToken string) (*websocket.Conn, context.Context, context.CancelFunc, error) { - ctx, ctxCancel := context.WithTimeout(context.Background(), time.Second*3) - - // Init websocket connection - dialer := websocket.DefaultDialer - conn, _, err := dialer.DialContext(ctx, uri, nil) - if err != nil { - ctxCancel() - slog.Error("Failed to connect to websocket. Check URI\n", "uri", uri) - return nil, nil, nil, err - } - - // Read auth_required message - _, err = ReadMessage(conn, ctx) - if err != nil { - ctxCancel() - slog.Error("Unknown error creating websocket client\n") - return nil, nil, nil, err - } - - // Send auth message - err = SendAuthMessage(conn, ctx, authToken) - if err != nil { - ctxCancel() - slog.Error("Unknown error creating websocket client\n") - return nil, nil, nil, err - } - - // Verify auth message was successful - err = VerifyAuthResponse(conn, ctx) - if err != nil { - ctxCancel() - slog.Error("Auth token is invalid. Please double check it or create a new token in your Home Assistant profile\n") - return nil, nil, nil, err - } - - return conn, ctx, ctxCancel, nil -} - -func SendAuthMessage(conn *websocket.Conn, ctx context.Context, token string) error { - err := conn.WriteJSON(AuthMessage{MsgType: "auth", AccessToken: token}) - if err != nil { - return err - } - return nil -} - -type authResponse struct { - MsgType string `json:"type"` - Message string `json:"message"` -} - -func VerifyAuthResponse(conn *websocket.Conn, ctx context.Context) error { - msg, err := ReadMessage(conn, ctx) - if err != nil { - return err - } - - var authResp authResponse - json.Unmarshal(msg, &authResp) - // log.Println(authResp.MsgType) - if authResp.MsgType != "auth_ok" { - return ErrInvalidToken - } - - return nil -} - -type SubEvent struct { - Id int64 `json:"id"` - Type string `json:"type"` - EventType string `json:"event_type"` -} - -func SubscribeToStateChangedEvents(id int64, conn *WebsocketWriter, ctx context.Context) { - SubscribeToEventType("state_changed", conn, ctx, id) -} - -func SubscribeToEventType(eventType string, conn *WebsocketWriter, ctx context.Context, id ...int64) { - var finalId int64 - if len(id) == 0 { - finalId = i.GetId() - } else { - finalId = id[0] - } - e := SubEvent{ - Id: finalId, - Type: "subscribe_events", - EventType: eventType, - } - err := conn.WriteMessage(e, ctx) - if err != nil { - wrappedErr := fmt.Errorf("error writing to websocket: %w", err) - slog.Error(wrappedErr.Error()) - panic(wrappedErr) - } - // m, _ := ReadMessage(conn, ctx) - // log.Default().Println(string(m)) -} diff --git a/service.go b/service.go deleted file mode 100644 index c861ec7..0000000 --- a/service.go +++ /dev/null @@ -1,59 +0,0 @@ -package gomeassistant - -import ( - "context" - - "saml.dev/gome-assistant/internal/http" - "saml.dev/gome-assistant/internal/services" - ws "saml.dev/gome-assistant/internal/websocket" -) - -type Service struct { - AlarmControlPanel *services.AlarmControlPanel - Climate *services.Climate - Cover *services.Cover - HomeAssistant *services.HomeAssistant - Light *services.Light - Lock *services.Lock - MediaPlayer *services.MediaPlayer - Switch *services.Switch - InputBoolean *services.InputBoolean - InputButton *services.InputButton - InputText *services.InputText - InputDatetime *services.InputDatetime - InputNumber *services.InputNumber - Event *services.Event - Notify *services.Notify - Number *services.Number - Scene *services.Scene - Script *services.Script - TTS *services.TTS - Vacuum *services.Vacuum - ZWaveJS *services.ZWaveJS -} - -func newService(conn *ws.WebsocketWriter, ctx context.Context, httpClient *http.HttpClient) *Service { - return &Service{ - AlarmControlPanel: services.BuildService[services.AlarmControlPanel](conn, ctx), - Climate: services.BuildService[services.Climate](conn, ctx), - Cover: services.BuildService[services.Cover](conn, ctx), - Light: services.BuildService[services.Light](conn, ctx), - HomeAssistant: services.BuildService[services.HomeAssistant](conn, ctx), - Lock: services.BuildService[services.Lock](conn, ctx), - MediaPlayer: services.BuildService[services.MediaPlayer](conn, ctx), - Switch: services.BuildService[services.Switch](conn, ctx), - InputBoolean: services.BuildService[services.InputBoolean](conn, ctx), - InputButton: services.BuildService[services.InputButton](conn, ctx), - InputText: services.BuildService[services.InputText](conn, ctx), - InputDatetime: services.BuildService[services.InputDatetime](conn, ctx), - InputNumber: services.BuildService[services.InputNumber](conn, ctx), - Event: services.BuildService[services.Event](conn, ctx), - Notify: services.BuildService[services.Notify](conn, ctx), - Number: services.BuildService[services.Number](conn, ctx), - Scene: services.BuildService[services.Scene](conn, ctx), - Script: services.BuildService[services.Script](conn, ctx), - TTS: services.BuildService[services.TTS](conn, ctx), - Vacuum: services.BuildService[services.Vacuum](conn, ctx), - ZWaveJS: services.BuildService[services.ZWaveJS](conn, ctx), - } -} diff --git a/target.go b/target.go new file mode 100644 index 0000000..243a49c --- /dev/null +++ b/target.go @@ -0,0 +1,32 @@ +package ga + +import "fmt" + +// Target represents the target of the service call, if applicable. +type Target struct { + EntityID string `json:"entity_id,omitempty"` + DeviceID string `json:"device_id,omitempty"` +} + +func EntityTarget(entityID string) Target { + return Target{ + EntityID: entityID, + } +} + +func DeviceTarget(deviceID string) Target { + return Target{ + DeviceID: deviceID, + } +} + +func (t Target) String() string { + switch { + case t.EntityID != "": + return fmt.Sprintf("entity %s", t.EntityID) + case t.DeviceID != "": + return fmt.Sprintf("device %s", t.DeviceID) + default: + return "unset target" + } +} diff --git a/types/requestTypes.go b/types/requestTypes.go deleted file mode 100644 index a2f736f..0000000 --- a/types/requestTypes.go +++ /dev/null @@ -1,33 +0,0 @@ -package types - -type NotifyRequest struct { - // Which notify service to call, such as mobile_app_sams_iphone - ServiceName string - Message string - Title string - Data map[string]any -} - -type SetTemperatureRequest struct { - Temperature float32 - TargetTempHigh float32 - TargetTempLow float32 - HvacMode string -} - -func (r *SetTemperatureRequest) ToJSON() map[string]any { - m := map[string]any{} - if r.Temperature != 0 { - m["temperature"] = r.Temperature - } - if r.TargetTempHigh != 0 { - m["target_temp_high"] = r.TargetTempHigh - } - if r.TargetTempLow != 0 { - m["target_temp_low"] = r.TargetTempLow - } - if r.HvacMode != "" { - m["hvac_mode"] = r.HvacMode - } - return m -} diff --git a/websocket/context.go b/websocket/context.go new file mode 100644 index 0000000..d0a9dea --- /dev/null +++ b/websocket/context.go @@ -0,0 +1,39 @@ +package websocket + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" +) + +type Context struct { + ID *string `json:"id"` + UserID *string `json:"user_id"` + ParentID *string `json:"parent_id"` +} + +func (c *Context) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, []byte("null")) { + return nil + } + if b[0] == '"' { + // The context is stored as a naked string. I think this can + // happen but I don't know what it's supposed to signify. + slog.Info("bare string as context; ignored", "input", string(b)) + return nil + } + + // Unmarshal into a type that is assignable to Context but without + // an `UnmarshalJSON()` method: + var context struct { + ID *string `json:"id"` + UserID *string `json:"user_id"` + ParentID *string `json:"parent_id"` + } + if err := json.Unmarshal(b, &context); err != nil { + return fmt.Errorf("unmarshaling context '%s': %w", string(b), err) + } + *c = context + return nil +} diff --git a/websocket/entity.go b/websocket/entity.go new file mode 100644 index 0000000..8b9449a --- /dev/null +++ b/websocket/entity.go @@ -0,0 +1,41 @@ +package websocket + +// "state_changed" events are compressed in a rather awkward way. +// These types help pick them apart. + +type Entity[AttributesT any] struct { + State EntityState `json:"state"` + Attributes AttributesT `json:"attributes"` + Context Context `json:"context"` + LastChanged TimeStamp `json:"last_changed"` +} + +type EntityItem[AttributesT any] struct { + EntityID string `json:"entity_id"` + Entity[AttributesT] +} + +// CompressedEntity is similar to `Entity` except that the JSON field +// names are abbreviated. +type CompressedEntity[AttributesT any] struct { + State EntityState `json:"s"` + Attributes AttributesT `json:"a"` + Context Context `json:"c"` + LastChanged TimeStamp `json:"lc"` +} + +// EntityState is the state of an entity ( // E.g., "on", "off", +// "unavailable"; there are probably more. +type EntityState string + +func (s EntityState) On() bool { + return s == "on" +} + +func (s EntityState) Off() bool { + return s == "off" +} + +func (s EntityState) Unavailable() bool { + return s == "unavailable" +} diff --git a/websocket/event_message.go b/websocket/event_message.go new file mode 100644 index 0000000..e913ea4 --- /dev/null +++ b/websocket/event_message.go @@ -0,0 +1,20 @@ +package websocket + +import "time" + +type BaseEvent struct { + EventType string `json:"event_type"` + Origin string `json:"origin"` + TimeFired time.Time `json:"time_fired"` + Context Context `json:"context"` +} + +type Event struct { + BaseEvent + RawData RawMessage `json:"data"` +} + +type EventMessage struct { + BaseMessage + Event Event `json:"event"` +} diff --git a/websocket/locked_conn.go b/websocket/locked_conn.go new file mode 100644 index 0000000..5043232 --- /dev/null +++ b/websocket/locked_conn.go @@ -0,0 +1,32 @@ +package websocket + +// LockedConn represents a `Conn` object that is currently locked. It +// allows user access to operations that usually require the lock. +type LockedConn interface { + // NextID returns the next unused id to be used in a websocket + // message. The IDs so generated must be used in order, while the + // `LockedConn` is still active. + NextID() int64 + + // Subscribe creates a new (unique) subscription ID and subscribes + // `subscriber` to it, in the sense that the subscriber will be + // called for any responses that have that ID. This doesn't + // actually interact with the server. The returned `Subscription` + // must eventually be passed at least once to `Unsubscribe()`, + // though `Unsubscribe()` can be called against a different + // `LockedConn` than the one that generated it. + Subscribe(subscriber Subscriber) Subscription + + // Unsubscribe terminates `subscription` at the websocket level; + // i.e., no more incoming messages will be forwarded to the + // corresponding `Subscriber`. Note that this does not interact + // with the server; it is the caller's responsibility to send it + // an "unsubscribe" command if necessary. + Unsubscribe(subscription Subscription) + + // SendMessage sends the specified message over the websocket + // connection. `msg` must be JSON-serializable and have the + // correct format and a unique, monotonically-increasing ID, which + // should be generated using `NextID()` and used in order. + SendMessage(msg any) error +} diff --git a/websocket/message.go b/websocket/message.go new file mode 100644 index 0000000..bd05767 --- /dev/null +++ b/websocket/message.go @@ -0,0 +1,31 @@ +package websocket + +// BaseMessage implements the required part of any websocket message. +// The idea is to embed this type in other message types. +type BaseMessage struct { + Type string `json:"type"` + ID int64 `json:"id"` +} + +type Request interface { + GetID() int64 + SetID(id int64) +} + +func (msg *BaseMessage) GetID() int64 { + return msg.ID +} + +func (msg *BaseMessage) SetID(id int64) { + msg.ID = id +} + +// Message holds a complete message, only partly parsed. The entire, +// original, unparsed message is available in the `Raw` field. +type Message struct { + BaseMessage + + // Raw contains the original, full, unparsed message (including + // fields `Type` and `ID`, which also appear in `BaseMessage`). + Raw RawMessage `json:"-"` +} diff --git a/websocket/raw_message.go b/websocket/raw_message.go new file mode 100644 index 0000000..f28ceec --- /dev/null +++ b/websocket/raw_message.go @@ -0,0 +1,29 @@ +package websocket + +import ( + "encoding/json" +) + +// RawMessage is like `json.RawMessage`, except that its `String()` +// method converts it directly to a string. +type RawMessage json.RawMessage + +func (m RawMessage) MarshalJSON() ([]byte, error) { + if m == nil { + return []byte("null"), nil + } + return m, nil +} + +// UnmarshalJSON delegates to `json.RawMessage`. (The method has a +// pointer receiver, so we have to implement it explicitly.) +func (m *RawMessage) UnmarshalJSON(data []byte) error { + return (*json.RawMessage)(m).UnmarshalJSON(data) +} + +func (rm RawMessage) String() string { + return string(rm) +} + +// RawObject is a minimally-parsed representation of a JSON object. +type RawObject map[string]RawMessage diff --git a/websocket/read.go b/websocket/read.go new file mode 100644 index 0000000..bef256d --- /dev/null +++ b/websocket/read.go @@ -0,0 +1,33 @@ +package websocket + +import ( + "encoding/json" + "log/slog" +) + +// Start reads JSON-formatted messages from `conn`, partly +// deserializes them, and processes them. If the message ID is +// currently subscribed to, invoke the subscriber for the message. If +// there is an error reading from `conn`, log it and return. +func (conn *Conn) Start() { + for { + b, err := conn.readMessage() + if err != nil { + slog.Error("Error reading from websocket:", err) + return + } + + var msg Message + if err := json.Unmarshal(b, &msg); err != nil { + slog.Error("Error parsing JSON message from websocket:", err) + return + } + // We've only deserialized part of the message, so store the + // raw bytes as well, so that the listeners can handle them. + msg.Raw = b + + if subscriber, ok := conn.getSubscriber(msg.ID); ok { + subscriber(msg) + } + } +} diff --git a/websocket/result_message.go b/websocket/result_message.go new file mode 100644 index 0000000..6219b9c --- /dev/null +++ b/websocket/result_message.go @@ -0,0 +1,67 @@ +package websocket + +import ( + "encoding/json" + "fmt" +) + +type BaseResultMessage struct { + BaseMessage + Success bool `json:"success"` + Error *ResultError `json:"error,omitempty"` +} + +type ResultError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (err *ResultError) Error() string { + switch { + case err.Code != "" && err.Message != "": + return fmt.Sprintf("%s: %s", err.Code, err.Message) + case err.Code == "" && err.Message != "": + return fmt.Sprintf("unknown_error: %s", err.Message) + case err.Code != "" && err.Message == "": + return fmt.Sprintf("%s", err.Code) + default: + // This seems not to be an error at all. + return fmt.Sprintf("INVALID (seems not to be an error)") + } +} + +type ResultMessage struct { + BaseResultMessage + + // Raw contains the "result" part of the message, unparsed. + Result RawMessage `json:"result"` +} + +// GetResult parses a result out of `msg` (which must have type +// "result"). If `msg` indicates that an error occurred, convert that +// to an error and return it. Parse the result into `result`, which +// must be unmarshalable as JSON. +func (msg Message) GetResult(result any) error { + if msg.Type != "result" { + return fmt.Errorf( + "response message was not of type 'result': %#v", msg, + ) + } + var resultMsg ResultMessage + if err := json.Unmarshal(msg.Raw, &resultMsg); err != nil { + return fmt.Errorf("unmarshaling result message: %w", err) + } + if !resultMsg.Success { + if resultMsg.Error == nil { + return fmt.Errorf( + "request did not succeed but no error was returned", + ) + } + return resultMsg.Error + } + + if err := json.Unmarshal(resultMsg.Result, result); err != nil { + return fmt.Errorf("unmarshalling result from %q: %w", resultMsg.Result, err) + } + return nil +} diff --git a/websocket/send.go b/websocket/send.go new file mode 100644 index 0000000..ce1e502 --- /dev/null +++ b/websocket/send.go @@ -0,0 +1,69 @@ +package websocket + +import ( + "fmt" +) + +// Messager is called by `Send()` while holding the `writeMutex`. It +// can send a message by allocating an ID using `lc.NextID()` then +// sending it using `lc.SendMessage()`. The `MessageWriter` should +// only be used while the callback is running. +type Messager func(lc LockedConn) error + +// Send is the primary way to write a message over the websocket +// interface. Since these messages require monotonically-increasing ID +// numbers, the work from allocating a new ID number through sending +// the message has to be done under the `writeMutex`. This is done by +// passing this function a `Messager`, which is invoked while holding +// the lock and passed the ID that it should use. +// +// Usage: +// +// msg := NewFooMessage{…} +// err := conn.Send(func(lc MessageWriter) error { +// id := lc.NextID() +// // …do anything else that needs to be done with `id`… +// msg.ID = id +// return lc.SendMessage(msg) +// }) +func (conn *Conn) Send(msgr Messager) error { + conn.writeMutex.Lock() + defer conn.writeMutex.Unlock() + + return msgr(lockedConn{conn: conn}) +} + +// lockedConn is a `LockedConn` view of a `Conn`, to be used +// only for a finite time when the connection is locked. +type lockedConn struct { + conn *Conn +} + +func (lc lockedConn) SendMessage(msg any) error { + if err := lc.conn.conn.WriteJSON(msg); err != nil { + return fmt.Errorf("sending websocket message to server: %w", err) + } + + return nil +} + +func (lc lockedConn) NextID() int64 { + lc.conn.lastID++ + return lc.conn.lastID +} + +func (lc lockedConn) Subscribe(subscriber Subscriber) Subscription { + id := lc.NextID() + lc.conn.subscribers[id] = subscriber + return Subscription{ + id: id, + } +} + +func (lc lockedConn) Unsubscribe(subscription Subscription) { + if subscription.id == 0 { + return + } + delete(lc.conn.subscribers, subscription.id) + subscription.id = 0 +} diff --git a/websocket/state_changed_message.go b/websocket/state_changed_message.go new file mode 100644 index 0000000..f6e9e0e --- /dev/null +++ b/websocket/state_changed_message.go @@ -0,0 +1,169 @@ +package websocket + +import ( + "encoding/json" + "fmt" +) + +// "state_changed" events are compressed in a rather awkward way. +// These types help pick them apart. + +// CompressedEntityChange keeps tracks of fields added and removed as +// part of a change. Fields that are mutated appear as "additions". +type CompressedEntityChange struct { + Additions CompressedEntity[RawObject] `json:"+,omitempty"` + Removals struct { + Attributes []string `json:"a"` + Context []string `json:"c"` + } `json:"-,omitempty"` +} + +type CompressedStateChangedMessage struct { + BaseMessage + Event struct { + Added map[string]CompressedEntity[RawObject] `json:"a,omitempty"` + Changed map[string]CompressedEntityChange `json:"c,omitempty"` + Removed []string `json:"r,omitempty"` + } `json:"event"` +} + +// ApplyChange applies the changes indicated in `msg` to the entity with the +// specified `entityID` and whose old state was `oldEntity`, returning the +// new entity. If the entity was removed altogether, return an empty +// entity. +// +// Because the entity being changed might not store its attributes as +// a generic `RawObject`, we have to do the conversion in an awkward +// way to avoiding needing specialized code for each `AttributeT`: + +// 1. Convert the old attributes from an `AttributeT` into a +// `RawObject`; +// 2. Apply the attribute changes to the `RawObject`; +// 3. Convert the updated `RawObject` back into an `AttributeT`. +func ApplyChange[AttributeT any]( + msg CompressedStateChangedMessage, + entityID string, oldEntity Entity[AttributeT], +) (Entity[AttributeT], error) { + for _, eid := range msg.Event.Removed { + if eid == entityID { + return Entity[AttributeT]{}, nil + } + } + + if entity, ok := msg.Event.Added[entityID]; ok { + // This entityID was added. The new state was right there in + // the message. + var newAttributes AttributeT + if err := convertTypes(&newAttributes, entity.Attributes); err != nil { + return Entity[AttributeT]{}, fmt.Errorf( + "converting the added attributes: %w", err, + ) + } + return Entity[AttributeT]{ + State: entity.State, + Attributes: newAttributes, + // FIXME: apparently, context can also be a single string. + Context: entity.Context, + LastChanged: entity.LastChanged, + }, nil + } + + change, ok := msg.Event.Changed[entityID] + if !ok { + // There were no changes. + return oldEntity, nil + } + + // The existing entry has had some fields changed. Apply them to + // `entity` to produce the new entity: + + newEntity := Entity[AttributeT]{ + State: oldEntity.State, + Context: mergeContexts( + oldEntity.Context, + change.Additions.Context, + change.Removals.Context, + ), + LastChanged: change.Additions.LastChanged, + } + + if change.Additions.State != "" { + newEntity.State = change.Additions.State + } + + var oldAttributes RawObject + if err := convertTypes(&oldAttributes, oldEntity.Attributes); err != nil { + return Entity[AttributeT]{}, fmt.Errorf("converting the old attributes: %w", err) + } + + attributes := mergeMaps( + oldAttributes, + change.Additions.Attributes, + change.Removals.Attributes, + ) + + if err := convertTypes(&newEntity.Attributes, attributes); err != nil { + return Entity[AttributeT]{}, fmt.Errorf("converting the new attributes: %w", err) + } + + return newEntity, nil +} + +func mergeMaps(old, additions RawObject, removals []string) RawObject { + new := make(RawObject, len(old)+len(additions)-len(removals)) + for k, v := range old { + new[k] = v + } + for k, v := range additions { + new[k] = v + } + for _, k := range removals { + delete(new, k) + } + return new +} + +func mergeContexts(context, additions Context, removals []string) Context { + // Adjust context for any additions: + if additions.ID != nil { + context.ID = additions.ID + } + if additions.UserID != nil { + context.UserID = additions.UserID + } + if additions.ParentID != nil { + context.ParentID = additions.ParentID + } + + // Adjust context for any removals: + for _, key := range removals { + switch key { + case "user_id": + context.UserID = nil + case "id": + context.ID = nil + case "parent_id": + context.ParentID = nil + } + } + + return context +} + +// Convert `src` to `dst` (which can be of two different types) by +// serializing to JSON then deserializing. `src` must be something +// that can be passed to `json.Marshal()`, and `dst` must be something +// that can be passed to `json.Unmarshal()` (i.e., typically a +// pointer). +func convertTypes(dst any, src any) error { + b, err := json.Marshal(src) + if err != nil { + return fmt.Errorf("serializing src: %w", err) + } + + if err := json.Unmarshal(b, dst); err != nil { + return fmt.Errorf("deserializing to dst: %w", err) + } + + return nil +} diff --git a/websocket/subscriptions.go b/websocket/subscriptions.go new file mode 100644 index 0000000..80d2f30 --- /dev/null +++ b/websocket/subscriptions.go @@ -0,0 +1,24 @@ +package websocket + +func (conn *Conn) getSubscriber(id int64) (Subscriber, bool) { + conn.subscribeMutex.RLock() + defer conn.subscribeMutex.RUnlock() + + subscriber, ok := conn.subscribers[id] + return subscriber, ok +} + +// Subscriber is called synchronously when a message with the +// subscribed `id` is received. +type Subscriber func(msg Message) + +// Subscription represents a websocket-level subscription to a +// particular message ID. Incoming messages with that ID will be +// forwarded to the corresponding `Subscriber`. +type Subscription struct { + id int64 +} + +func (subscription Subscription) ID() int64 { + return subscription.id +} diff --git a/websocket/time_stamp.go b/websocket/time_stamp.go new file mode 100644 index 0000000..e4f67e5 --- /dev/null +++ b/websocket/time_stamp.go @@ -0,0 +1,29 @@ +package websocket + +import ( + "encoding/json" + "fmt" + "math" + "time" +) + +type TimeStamp time.Time + +// UnmarshalJSON unmarshals a timestamp from JSON. HA sometimes +// formats timestamps as RFC 3339 strings, sometimes as fractional +// seconds since the epoch. Handle either one (without recording which +// one it was). +func (ts *TimeStamp) UnmarshalJSON(b []byte) error { + if err := (*time.Time)(ts).UnmarshalJSON(b); err == nil { + return nil + } + + var v float64 + if err := json.Unmarshal(b, &v); err == nil { + seconds := math.Floor(v) + *(*time.Time)(ts) = time.Unix(int64(seconds), int64((v-seconds)*1e+9)) + return nil + } + + return fmt.Errorf("unmarshaling timestamp: '%s'", string(b)) +} diff --git a/websocket/websocket.go b/websocket/websocket.go new file mode 100644 index 0000000..faf64ed --- /dev/null +++ b/websocket/websocket.go @@ -0,0 +1,126 @@ +// Package websocket is used to interact with the Home Assistant +// websocket API. All HA interaction is done via websocket +// except for cases explicitly called out in http package +// documentation. +package websocket + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "sync" + + "github.com/gorilla/websocket" +) + +var ErrInvalidToken = errors.New("invalid authentication token") + +type Conn struct { + writeMutex sync.Mutex + conn *websocket.Conn + + subscribeMutex sync.RWMutex + subscribers map[int64]Subscriber + + // lastID is the last message ID that has already been used. It + // must be accessed atomically. + lastID int64 +} + +func NewConnFromURI(ctx context.Context, uri string, authToken string) (*Conn, error) { + // Init websocket connection + dialer := websocket.DefaultDialer + wsConn, _, err := dialer.DialContext(ctx, uri, nil) + if err != nil { + slog.Error("Failed to connect to websocket. Check URI\n", "uri", uri) + return nil, err + } + + conn := &Conn{ + conn: wsConn, + subscribers: make(map[int64]Subscriber), + } + + // Read auth_required message + if _, err := conn.readMessage(); err != nil { + slog.Error("Unknown error creating websocket client\n") + return nil, err + } + + // Send auth message + err = conn.sendAuthMessage(authToken) + if err != nil { + slog.Error("Unknown error creating websocket client\n") + return nil, err + } + + // Verify auth message was successful + err = conn.verifyAuthResponse() + if err != nil { + slog.Error( + "Auth token is invalid. Please double check it " + + "or create a new token in your Home Assistant profile\n", + ) + return nil, err + } + + return conn, nil +} + +func NewConn(ctx context.Context, ip, port, authToken string) (*Conn, error) { + uri := fmt.Sprintf("ws://%s:%s/api/websocket", ip, port) + return NewConnFromURI(ctx, uri, authToken) +} + +func NewSecureConn(ctx context.Context, ip, port, authToken string) (*Conn, error) { + uri := fmt.Sprintf("wss://%s:%s/api/websocket", ip, port) + return NewConnFromURI(ctx, uri, authToken) +} + +func (conn *Conn) readMessage() ([]byte, error) { + _, msg, err := conn.conn.ReadMessage() + if err != nil { + return []byte{}, err + } + return msg, nil +} + +func (conn *Conn) Close() error { + return conn.conn.Close() +} + +type authRequest struct { + MsgType string `json:"type"` + AccessToken string `json:"access_token"` +} + +func (conn *Conn) sendAuthMessage(token string) error { + err := conn.conn.WriteJSON(authRequest{MsgType: "auth", AccessToken: token}) + if err != nil { + return err + } + return nil +} + +type authResponse struct { + MsgType string `json:"type"` + Message string `json:"message"` +} + +func (conn *Conn) verifyAuthResponse() error { + msg, err := conn.readMessage() + if err != nil { + return err + } + + var authResp authResponse + json.Unmarshal(msg, &authResp) + // log.Println(authResp.MsgType) + if authResp.MsgType != "auth_ok" { + return ErrInvalidToken + } + + return nil +}