Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions internal/database/products.sql.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 18 additions & 10 deletions internal/database/queries/products.sql
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,27 @@ WHERE external_id IN (SELECT pr.external_id


-- name: GetRecentPullRequests :many
SELECT external_id
FROM pull_requests
WHERE created_at >= unixepoch() - 300
AND state = 'OPEN'
ORDER BY created_at DESC;
SELECT pr.external_id, pr.repository_name, po.organisation_id
FROM pull_requests pr
JOIN repositories r ON r.name = pr.repository_name
JOIN products p ON JSON_VALID(p.tags)
AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic)
JOIN product_organisations po ON po.product_id = p.id
WHERE pr.created_at >= unixepoch() - 300
AND pr.state = 'OPEN'
ORDER BY pr.created_at DESC;


-- name: GetRecentSecurity :many
SELECT external_id
FROM securities
WHERE created_at >= unixepoch() - 300
and state = 'OPEN'
ORDER BY created_at DESC;
SELECT sec.external_id, sec.repository_name, po.organisation_id
FROM securities sec
JOIN repositories r ON r.name = sec.repository_name
JOIN products p ON JSON_VALID(p.tags)
AND EXISTS (SELECT 1 FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic)
JOIN product_organisations po ON po.product_id = p.id
WHERE sec.created_at >= unixepoch() - 300
and state = 'OPEN'
ORDER BY sec.created_at DESC;

-- name: CreateSecurity :one
INSERT INTO securities (external_id,
Expand Down
4 changes: 2 additions & 2 deletions internal/products/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ type PullRequestStore interface {
CreatePullRequest(ctx context.Context, arg database.CreatePullRequestParams) (database.PullRequest, error)
UpdatePullRequest(ctx context.Context, arg database.UpdatePullRequestParams) (database.PullRequest, error)
GetPullRequestByExternalID(ctx context.Context, externalID string) (database.PullRequest, error)
GetRecentPullRequests(ctx context.Context) ([]string, error)
GetRecentPullRequests(ctx context.Context) ([]database.GetRecentPullRequestsRow, error)
DeletePullRequestsByProductID(ctx context.Context, id int64) error
}

type SecurityStore interface {
GetSecurityByProductIDAndState(ctx context.Context, arg database.GetSecurityByProductIDAndStateParams) ([]database.GetSecurityByProductIDAndStateRow, error)
GetSecurityByOrganisationAndState(ctx context.Context, arg database.GetSecurityByOrganisationAndStateParams) ([]database.GetSecurityByOrganisationAndStateRow, error)
CreateSecurity(ctx context.Context, arg database.CreateSecurityParams) (database.Security, error)
GetRecentSecurity(ctx context.Context) ([]string, error)
GetRecentSecurity(ctx context.Context) ([]database.GetRecentSecurityRow, error)
UpdateSecurity(ctx context.Context, arg database.UpdateSecurityParams) (database.Security, error)
GetSecurityByExternalID(ctx context.Context, externalID string) (database.Security, error)
DeleteSecurityByProductID(ctx context.Context, id int64) error
Expand Down
12 changes: 6 additions & 6 deletions internal/products/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,16 @@ func (s *Service) UpsertPullRequest(ctx context.Context, params CreatePRParams)
})
}

func (s *Service) GetRecentPullRequests(ctx context.Context) ([]string, error) {
func (s *Service) GetRecentPullRequests(ctx context.Context) ([]RecentlyChangedEntity, error) {
logger := logging.FromContext(ctx).With("service", "products")
logger.Debug("Getting recent pull requests")
externalIDs, err := s.store.GetRecentPullRequests(ctx)
recentPRs, err := s.store.GetRecentPullRequests(ctx)
if err != nil {
logger.Error("Error fetching recent pull requests", "error", err)
return nil, err
}

return externalIDs, nil
return fromRecentlyChangedPRModels(recentPRs), nil
}

func (s *Service) BulkCreatePullRequest(ctx context.Context, paramsList []CreatePRParams) error {
Expand Down Expand Up @@ -435,17 +435,17 @@ func (s *Service) GetSecurityByOrg(ctx context.Context, orgID int64) ([]Security
return orgToSecurityDTOs(model), nil
}

func (s *Service) GetRecentSecurity(ctx context.Context) ([]string, error) {
func (s *Service) GetRecentSecurity(ctx context.Context) ([]RecentlyChangedEntity, error) {
logger := logging.FromContext(ctx).With("service", "products")
logger.Debug("Getting recent security")

externalIDs, err := s.store.GetRecentSecurity(ctx)
secList, err := s.store.GetRecentSecurity(ctx)
if err != nil {
logger.Error("Error fetching recent security", "error", err)
return nil, err
}

return externalIDs, nil
return fromRecentlyChangedSecurityModels(secList), nil
}

func (s *Service) BulkCreateSecurity(ctx context.Context, paramsList []CreateSecurityParams) error {
Expand Down
38 changes: 30 additions & 8 deletions internal/products/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"watchtower/internal/github"

"github.com/code-gorilla-au/odize"
"github.com/google/uuid"
)

func TestService(t *testing.T) {
Expand Down Expand Up @@ -281,25 +282,46 @@ func TestService(t *testing.T) {
}).
Test("GetRecentPullRequests should return external IDs of recent PRs", func(t *testing.T) {
params := CreatePRParams{
ExternalID: "recent-pr-1",
Title: "Recent PR",
RepositoryName: "repo1",
ExternalID: uuid.New().String(),
Title: uuid.New().String(),
RepositoryName: uuid.New().String(),
Url: "url1",
State: "OPEN",
Author: "author1",
CreatedAt: time.Now(),
}

err := s.CreatePullRequest(ctx, params)
err := s.CreateRepo(ctx, CreateRepoParams{Name: params.RepositoryName, Topic: "tag", Owner: "owner"})
odize.AssertNoError(t, err)

prodID, err := s.Create(ctx, CreateProductParams{
Name: params.RepositoryName,
Desc: "",
Tags: []string{"tag"},
})
odize.AssertNoError(t, err)

_ = _testDB.AddProductToOrganisation(ctx, database.AddProductToOrganisationParams{
ProductID: sql.NullInt64{
Int64: prodID.ID,
Valid: true,
},
OrganisationID: sql.NullInt64{
Int64: 0,
Valid: true,
},
})

err = s.CreatePullRequest(ctx, params)
odize.AssertNoError(t, err)

recent, err := s.GetRecentPullRequests(ctx)
odize.AssertNoError(t, err)
odize.AssertTrue(t, len(recent) > 0)

found := false
for _, id := range recent {
if id == params.ExternalID {
for _, entity := range recent {
if entity.ExternalID == params.ExternalID {
found = true
break
}
Expand Down Expand Up @@ -356,8 +378,8 @@ func TestService(t *testing.T) {
odize.AssertTrue(t, len(recent) > 0)

found := false
for _, id := range recent {
if id == params.ExternalID {
for _, entity := range recent {
if entity.ExternalID == params.ExternalID {
found = true
break
}
Expand Down
34 changes: 34 additions & 0 deletions internal/products/transforms.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,37 @@ func toSecParamsFromGithubVulnerabilities(secs github.RootNode[github.Vulnerabil

return result
}

func fromRecentlyChangedPRModel(model database.GetRecentPullRequestsRow) RecentlyChangedEntity {
return RecentlyChangedEntity{
ExternalID: model.ExternalID,
OrganisationID: model.OrganisationID.Int64,
RepositoryName: model.RepositoryName,
}
}

func fromRecentlyChangedPRModels(models []database.GetRecentPullRequestsRow) []RecentlyChangedEntity {
result := make([]RecentlyChangedEntity, 0, len(models))
for _, m := range models {
result = append(result, fromRecentlyChangedPRModel(m))
}

return result
}

func fromRecentlyChangedSecurityModel(model database.GetRecentSecurityRow) RecentlyChangedEntity {
return RecentlyChangedEntity{
ExternalID: model.ExternalID,
OrganisationID: model.OrganisationID.Int64,
RepositoryName: model.RepositoryName,
}
}

func fromRecentlyChangedSecurityModels(models []database.GetRecentSecurityRow) []RecentlyChangedEntity {
result := make([]RecentlyChangedEntity, 0, len(models))
for _, m := range models {
result = append(result, fromRecentlyChangedSecurityModel(m))
}

return result
}
6 changes: 6 additions & 0 deletions internal/products/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,9 @@ type UpdateSecurityParams struct {
PatchedVersion string
FixedAt *time.Time
}

type RecentlyChangedEntity struct {
ExternalID string
RepositoryName string
OrganisationID int64
}
44 changes: 18 additions & 26 deletions internal/watchtower/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package watchtower
import (
"context"
"database/sql"
"fmt"
"strings"
"watchtower/internal/database"
"watchtower/internal/notifications"
Expand Down Expand Up @@ -36,49 +37,40 @@ func (s *Service) Startup(ctx context.Context) {
func (s *Service) CreateUnreadPRNotification() error {
logger := logging.FromContext(s.ctx)

prIDs, err := s.productSvc.GetRecentPullRequests(s.ctx)
prs, err := s.productSvc.GetRecentPullRequests(s.ctx)
if err != nil {
logging.FromContext(s.ctx).Error("Error fetching recent pull requests", "error", err)
logger.Error("Error fetching recent pull requests", "error", err)
return err
}

logger.Debug("Creating unread notifications for pull requests", "count", len(prIDs))

for _, id := range prIDs {
if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{
OrgID: 0,
ExternalID: id,
NotificationType: "OPEN_PULL_REQUEST",
Content: "New pull request",
}); notifyErr != nil {
logger.Error("Error creating notification", "error", err)
return err
}
}

return nil
return s.createNotification("OPEN_PULL_REQUEST", "New pull request", prs)
}

// CreateUnreadSecurityNotification generates unread security notifications for recent security alerts.
// It retrieves recent security-related IDs and creates notifications for each using the notification service.
// Returns an error if fetching security IDs or creating notifications fails.
func (s *Service) CreateUnreadSecurityNotification() error {
logger := logging.FromContext(s.ctx)
externalIDs, err := s.productSvc.GetRecentSecurity(s.ctx)
secList, err := s.productSvc.GetRecentSecurity(s.ctx)
if err != nil {
logger.Error("Error fetching recent security", "error", err)
return err
}

logger.Debug("creating unread notifications for security alerts", "count", len(externalIDs))
return s.createNotification("OPEN_SECURITY_ALERT", "New security alert", secList)
}

for _, id := range externalIDs {
if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{
OrgID: 0,
ExternalID: id,
NotificationType: "SECURITY_ALERT",
Content: "New security alert",
}); notifyErr != nil {
func (s *Service) createNotification(notificationType string, content string, recentlyChanged []products.RecentlyChangedEntity) error {
logger := logging.FromContext(s.ctx)
logger.Debug("creating unread notifications", "count", len(recentlyChanged))

for _, entity := range recentlyChanged {
if err := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{
OrgID: entity.OrganisationID,
ExternalID: entity.ExternalID,
NotificationType: notificationType,
Content: fmt.Sprintf("%s: %s", entity.RepositoryName, content),
}); err != nil {
logger.Error("Error creating notification", "error", err)
return err
}
Expand Down