From f53f8258e9494ff764f3b7dcaa4bd58396a83c5b Mon Sep 17 00:00:00 2001 From: frag223 Date: Thu, 8 Jan 2026 22:05:00 +1100 Subject: [PATCH] improving notification message to include repository --- internal/database/products.sql.gen.go | 60 +++++++++++++++++--------- internal/database/queries/products.sql | 28 +++++++----- internal/products/interfaces.go | 4 +- internal/products/service.go | 12 +++--- internal/products/service_test.go | 38 ++++++++++++---- internal/products/transforms.go | 34 +++++++++++++++ internal/products/types.go | 6 +++ internal/watchtower/sync.go | 44 ++++++++----------- 8 files changed, 154 insertions(+), 72 deletions(-) diff --git a/internal/database/products.sql.gen.go b/internal/database/products.sql.gen.go index beed269..7312954 100644 --- a/internal/database/products.sql.gen.go +++ b/internal/database/products.sql.gen.go @@ -506,26 +506,36 @@ func (q *Queries) GetPullRequestsByOrganisationAndState(ctx context.Context, arg } const getRecentPullRequests = `-- name: GetRecentPullRequests :many -SELECT external_id -FROM pull_requests -WHERE created_at >= unixepoch() - 300 -AND state = 'OPEN' -ORDER BY created_at DESC +SELECT pr.external_id, pr.repository_name, po.organisation_id +FROM pull_requests pr + JOIN repositories r ON r.name = pr.repository_name + JOIN products p ON JSON_VALID(p.tags) + AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic) + JOIN product_organisations po ON po.product_id = p.id +WHERE pr.created_at >= unixepoch() - 300 + AND pr.state = 'OPEN' +ORDER BY pr.created_at DESC ` -func (q *Queries) GetRecentPullRequests(ctx context.Context) ([]string, error) { +type GetRecentPullRequestsRow struct { + ExternalID string + RepositoryName string + OrganisationID sql.NullInt64 +} + +func (q *Queries) GetRecentPullRequests(ctx context.Context) ([]GetRecentPullRequestsRow, error) { rows, err := q.db.QueryContext(ctx, getRecentPullRequests) if err != nil { return nil, err } defer rows.Close() - var items []string + var items []GetRecentPullRequestsRow for rows.Next() { - var external_id string - if err := rows.Scan(&external_id); err != nil { + var i GetRecentPullRequestsRow + if err := rows.Scan(&i.ExternalID, &i.RepositoryName, &i.OrganisationID); err != nil { return nil, err } - items = append(items, external_id) + items = append(items, i) } if err := rows.Close(); err != nil { return nil, err @@ -537,26 +547,36 @@ func (q *Queries) GetRecentPullRequests(ctx context.Context) ([]string, error) { } const getRecentSecurity = `-- name: GetRecentSecurity :many -SELECT external_id -FROM securities -WHERE created_at >= unixepoch() - 300 -and state = 'OPEN' -ORDER BY created_at DESC +SELECT sec.external_id, sec.repository_name, po.organisation_id +FROM securities sec + JOIN repositories r ON r.name = sec.repository_name + JOIN products p ON JSON_VALID(p.tags) + AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic) + JOIN product_organisations po ON po.product_id = p.id +WHERE sec.created_at >= unixepoch() - 300 + and state = 'OPEN' +ORDER BY sec.created_at DESC ` -func (q *Queries) GetRecentSecurity(ctx context.Context) ([]string, error) { +type GetRecentSecurityRow struct { + ExternalID string + RepositoryName string + OrganisationID sql.NullInt64 +} + +func (q *Queries) GetRecentSecurity(ctx context.Context) ([]GetRecentSecurityRow, error) { rows, err := q.db.QueryContext(ctx, getRecentSecurity) if err != nil { return nil, err } defer rows.Close() - var items []string + var items []GetRecentSecurityRow for rows.Next() { - var external_id string - if err := rows.Scan(&external_id); err != nil { + var i GetRecentSecurityRow + if err := rows.Scan(&i.ExternalID, &i.RepositoryName, &i.OrganisationID); err != nil { return nil, err } - items = append(items, external_id) + items = append(items, i) } if err := rows.Close(); err != nil { return nil, err diff --git a/internal/database/queries/products.sql b/internal/database/queries/products.sql index 01dc8e6..04a2b42 100644 --- a/internal/database/queries/products.sql +++ b/internal/database/queries/products.sql @@ -192,19 +192,27 @@ WHERE external_id IN (SELECT pr.external_id -- name: GetRecentPullRequests :many -SELECT external_id -FROM pull_requests -WHERE created_at >= unixepoch() - 300 -AND state = 'OPEN' -ORDER BY created_at DESC; +SELECT pr.external_id, pr.repository_name, po.organisation_id +FROM pull_requests pr + JOIN repositories r ON r.name = pr.repository_name + JOIN products p ON JSON_VALID(p.tags) + AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic) + JOIN product_organisations po ON po.product_id = p.id +WHERE pr.created_at >= unixepoch() - 300 + AND pr.state = 'OPEN' +ORDER BY pr.created_at DESC; -- name: GetRecentSecurity :many -SELECT external_id -FROM securities -WHERE created_at >= unixepoch() - 300 -and state = 'OPEN' -ORDER BY created_at DESC; +SELECT sec.external_id, sec.repository_name, po.organisation_id +FROM securities sec + JOIN repositories r ON r.name = sec.repository_name + JOIN products p ON JSON_VALID(p.tags) + AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic) + JOIN product_organisations po ON po.product_id = p.id +WHERE sec.created_at >= unixepoch() - 300 + and state = 'OPEN' +ORDER BY sec.created_at DESC; -- name: CreateSecurity :one INSERT INTO securities (external_id, diff --git a/internal/products/interfaces.go b/internal/products/interfaces.go index 55d4986..970dfb8 100644 --- a/internal/products/interfaces.go +++ b/internal/products/interfaces.go @@ -38,7 +38,7 @@ type PullRequestStore interface { CreatePullRequest(ctx context.Context, arg database.CreatePullRequestParams) (database.PullRequest, error) UpdatePullRequest(ctx context.Context, arg database.UpdatePullRequestParams) (database.PullRequest, error) GetPullRequestByExternalID(ctx context.Context, externalID string) (database.PullRequest, error) - GetRecentPullRequests(ctx context.Context) ([]string, error) + GetRecentPullRequests(ctx context.Context) ([]database.GetRecentPullRequestsRow, error) DeletePullRequestsByProductID(ctx context.Context, id int64) error } @@ -46,7 +46,7 @@ type SecurityStore interface { GetSecurityByProductIDAndState(ctx context.Context, arg database.GetSecurityByProductIDAndStateParams) ([]database.GetSecurityByProductIDAndStateRow, error) GetSecurityByOrganisationAndState(ctx context.Context, arg database.GetSecurityByOrganisationAndStateParams) ([]database.GetSecurityByOrganisationAndStateRow, error) CreateSecurity(ctx context.Context, arg database.CreateSecurityParams) (database.Security, error) - GetRecentSecurity(ctx context.Context) ([]string, error) + GetRecentSecurity(ctx context.Context) ([]database.GetRecentSecurityRow, error) UpdateSecurity(ctx context.Context, arg database.UpdateSecurityParams) (database.Security, error) GetSecurityByExternalID(ctx context.Context, externalID string) (database.Security, error) DeleteSecurityByProductID(ctx context.Context, id int64) error diff --git a/internal/products/service.go b/internal/products/service.go index 905e5d3..64568c1 100644 --- a/internal/products/service.go +++ b/internal/products/service.go @@ -372,16 +372,16 @@ func (s *Service) UpsertPullRequest(ctx context.Context, params CreatePRParams) }) } -func (s *Service) GetRecentPullRequests(ctx context.Context) ([]string, error) { +func (s *Service) GetRecentPullRequests(ctx context.Context) ([]RecentlyChangedEntity, error) { logger := logging.FromContext(ctx).With("service", "products") logger.Debug("Getting recent pull requests") - externalIDs, err := s.store.GetRecentPullRequests(ctx) + recentPRs, err := s.store.GetRecentPullRequests(ctx) if err != nil { logger.Error("Error fetching recent pull requests", "error", err) return nil, err } - return externalIDs, nil + return fromRecentlyChangedPRModels(recentPRs), nil } func (s *Service) BulkCreatePullRequest(ctx context.Context, paramsList []CreatePRParams) error { @@ -435,17 +435,17 @@ func (s *Service) GetSecurityByOrg(ctx context.Context, orgID int64) ([]Security return orgToSecurityDTOs(model), nil } -func (s *Service) GetRecentSecurity(ctx context.Context) ([]string, error) { +func (s *Service) GetRecentSecurity(ctx context.Context) ([]RecentlyChangedEntity, error) { logger := logging.FromContext(ctx).With("service", "products") logger.Debug("Getting recent security") - externalIDs, err := s.store.GetRecentSecurity(ctx) + secList, err := s.store.GetRecentSecurity(ctx) if err != nil { logger.Error("Error fetching recent security", "error", err) return nil, err } - return externalIDs, nil + return fromRecentlyChangedSecurityModels(secList), nil } func (s *Service) BulkCreateSecurity(ctx context.Context, paramsList []CreateSecurityParams) error { diff --git a/internal/products/service_test.go b/internal/products/service_test.go index 4820315..67954f0 100644 --- a/internal/products/service_test.go +++ b/internal/products/service_test.go @@ -10,6 +10,7 @@ import ( "watchtower/internal/github" "github.com/code-gorilla-au/odize" + "github.com/google/uuid" ) func TestService(t *testing.T) { @@ -281,16 +282,37 @@ func TestService(t *testing.T) { }). Test("GetRecentPullRequests should return external IDs of recent PRs", func(t *testing.T) { params := CreatePRParams{ - ExternalID: "recent-pr-1", - Title: "Recent PR", - RepositoryName: "repo1", + ExternalID: uuid.New().String(), + Title: uuid.New().String(), + RepositoryName: uuid.New().String(), Url: "url1", State: "OPEN", Author: "author1", CreatedAt: time.Now(), } - err := s.CreatePullRequest(ctx, params) + err := s.CreateRepo(ctx, CreateRepoParams{Name: params.RepositoryName, Topic: "tag", Owner: "owner"}) + odize.AssertNoError(t, err) + + prodID, err := s.Create(ctx, CreateProductParams{ + Name: params.RepositoryName, + Desc: "", + Tags: []string{"tag"}, + }) + odize.AssertNoError(t, err) + + _ = _testDB.AddProductToOrganisation(ctx, database.AddProductToOrganisationParams{ + ProductID: sql.NullInt64{ + Int64: prodID.ID, + Valid: true, + }, + OrganisationID: sql.NullInt64{ + Int64: 0, + Valid: true, + }, + }) + + err = s.CreatePullRequest(ctx, params) odize.AssertNoError(t, err) recent, err := s.GetRecentPullRequests(ctx) @@ -298,8 +320,8 @@ func TestService(t *testing.T) { odize.AssertTrue(t, len(recent) > 0) found := false - for _, id := range recent { - if id == params.ExternalID { + for _, entity := range recent { + if entity.ExternalID == params.ExternalID { found = true break } @@ -356,8 +378,8 @@ func TestService(t *testing.T) { odize.AssertTrue(t, len(recent) > 0) found := false - for _, id := range recent { - if id == params.ExternalID { + for _, entity := range recent { + if entity.ExternalID == params.ExternalID { found = true break } diff --git a/internal/products/transforms.go b/internal/products/transforms.go index 442f951..7d2cc18 100644 --- a/internal/products/transforms.go +++ b/internal/products/transforms.go @@ -207,3 +207,37 @@ func toSecParamsFromGithubVulnerabilities(secs github.RootNode[github.Vulnerabil return result } + +func fromRecentlyChangedPRModel(model database.GetRecentPullRequestsRow) RecentlyChangedEntity { + return RecentlyChangedEntity{ + ExternalID: model.ExternalID, + OrganisationID: model.OrganisationID.Int64, + RepositoryName: model.RepositoryName, + } +} + +func fromRecentlyChangedPRModels(models []database.GetRecentPullRequestsRow) []RecentlyChangedEntity { + result := make([]RecentlyChangedEntity, 0, len(models)) + for _, m := range models { + result = append(result, fromRecentlyChangedPRModel(m)) + } + + return result +} + +func fromRecentlyChangedSecurityModel(model database.GetRecentSecurityRow) RecentlyChangedEntity { + return RecentlyChangedEntity{ + ExternalID: model.ExternalID, + OrganisationID: model.OrganisationID.Int64, + RepositoryName: model.RepositoryName, + } +} + +func fromRecentlyChangedSecurityModels(models []database.GetRecentSecurityRow) []RecentlyChangedEntity { + result := make([]RecentlyChangedEntity, 0, len(models)) + for _, m := range models { + result = append(result, fromRecentlyChangedSecurityModel(m)) + } + + return result +} diff --git a/internal/products/types.go b/internal/products/types.go index 0aa9387..2dac93c 100644 --- a/internal/products/types.go +++ b/internal/products/types.go @@ -129,3 +129,9 @@ type UpdateSecurityParams struct { PatchedVersion string FixedAt *time.Time } + +type RecentlyChangedEntity struct { + ExternalID string + RepositoryName string + OrganisationID int64 +} diff --git a/internal/watchtower/sync.go b/internal/watchtower/sync.go index 888b933..bd3ae4c 100644 --- a/internal/watchtower/sync.go +++ b/internal/watchtower/sync.go @@ -3,6 +3,7 @@ package watchtower import ( "context" "database/sql" + "fmt" "strings" "watchtower/internal/database" "watchtower/internal/notifications" @@ -36,27 +37,13 @@ func (s *Service) Startup(ctx context.Context) { func (s *Service) CreateUnreadPRNotification() error { logger := logging.FromContext(s.ctx) - prIDs, err := s.productSvc.GetRecentPullRequests(s.ctx) + prs, err := s.productSvc.GetRecentPullRequests(s.ctx) if err != nil { - logging.FromContext(s.ctx).Error("Error fetching recent pull requests", "error", err) + logger.Error("Error fetching recent pull requests", "error", err) return err } - logger.Debug("Creating unread notifications for pull requests", "count", len(prIDs)) - - for _, id := range prIDs { - if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{ - OrgID: 0, - ExternalID: id, - NotificationType: "OPEN_PULL_REQUEST", - Content: "New pull request", - }); notifyErr != nil { - logger.Error("Error creating notification", "error", err) - return err - } - } - - return nil + return s.createNotification("OPEN_PULL_REQUEST", "New pull request", prs) } // CreateUnreadSecurityNotification generates unread security notifications for recent security alerts. @@ -64,21 +51,26 @@ func (s *Service) CreateUnreadPRNotification() error { // Returns an error if fetching security IDs or creating notifications fails. func (s *Service) CreateUnreadSecurityNotification() error { logger := logging.FromContext(s.ctx) - externalIDs, err := s.productSvc.GetRecentSecurity(s.ctx) + secList, err := s.productSvc.GetRecentSecurity(s.ctx) if err != nil { logger.Error("Error fetching recent security", "error", err) return err } - logger.Debug("creating unread notifications for security alerts", "count", len(externalIDs)) + return s.createNotification("OPEN_SECURITY_ALERT", "New security alert", secList) +} - for _, id := range externalIDs { - if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{ - OrgID: 0, - ExternalID: id, - NotificationType: "SECURITY_ALERT", - Content: "New security alert", - }); notifyErr != nil { +func (s *Service) createNotification(notificationType string, content string, recentlyChanged []products.RecentlyChangedEntity) error { + logger := logging.FromContext(s.ctx) + logger.Debug("creating unread notifications", "count", len(recentlyChanged)) + + for _, entity := range recentlyChanged { + if err := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{ + OrgID: entity.OrganisationID, + ExternalID: entity.ExternalID, + NotificationType: notificationType, + Content: fmt.Sprintf("%s: %s", entity.RepositoryName, content), + }); err != nil { logger.Error("Error creating notification", "error", err) return err }