Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/watchtower/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestService_Notifications(t *testing.T) {
err = s.DeleteOldNotifications()
odize.AssertNoError(t, err)
}).
Test("CreateUnreadPRNotification should create notifications for recent PRs", func(t *testing.T) {
Test("createUnreadPRNotification should create notifications for recent PRs", func(t *testing.T) {
// Setup: Org, Product, Repo, PR
org, err := s.CreateOrganisation("Test Org PR", "test-org-pr", "token", "desc")
odize.AssertNoError(t, err)
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestService_Notifications(t *testing.T) {
odize.AssertNoError(t, err)

// Action
_, err = s.CreateUnreadPRNotification()
_, err = s.createUnreadPRNotification()
odize.AssertNoError(t, err)

// Verify
Expand All @@ -120,7 +120,7 @@ func TestService_Notifications(t *testing.T) {
odize.AssertTrue(t, strings.Contains(unreadNotification.Content, "repo-pr"))
odize.AssertTrue(t, strings.Contains(unreadNotification.Content, "New pull request"))
}).
Test("CreateUnreadSecurityNotification should create notifications for recent security alerts", func(t *testing.T) {
Test("createUnreadSecurityNotification should create notifications for recent security alerts", func(t *testing.T) {
// Setup: Org, Product, Repo, Security Alert
org, err := s.CreateOrganisation("Test Org Sec", "test-org-sec", "token", "desc")
odize.AssertNoError(t, err)
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestService_Notifications(t *testing.T) {
odize.AssertNoError(t, err)

// Action
_, err = s.CreateUnreadSecurityNotification()
_, err = s.createUnreadSecurityNotification()
odize.AssertNoError(t, err)

// Verify
Expand Down
16 changes: 10 additions & 6 deletions internal/watchtower/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,22 @@ func NewService(ctx context.Context, db *database.Queries, txnDB *sql.DB) *Servi
}
}

// Startup initialises the service with the provided context and sets it for further use.
func (s *Service) Startup(ctx context.Context) {
s.ctx = ctx
}

// CreateUnreadNotification generates notifications for unread pull requests and security alerts and returns the total count.
func (s *Service) CreateUnreadNotification() (int, error) {
logger := logging.FromContext(s.ctx)

prCount, err := s.CreateUnreadPRNotification()
prCount, err := s.createUnreadPRNotification()
if err != nil {
logger.Error("Error creating unread pull request notification", "error", err)
return 0, err
}

secCount, err := s.CreateUnreadSecurityNotification()
secCount, err := s.createUnreadSecurityNotification()
if err != nil {
logger.Error("Error creating unread security notification", "error", err)
return 0, err
Expand All @@ -51,8 +53,8 @@ func (s *Service) CreateUnreadNotification() (int, error) {
return prCount + secCount, nil
}

// CreateUnreadPRNotification generates unread notifications for recent pull requests by fetching their IDs and creating notifications.
func (s *Service) CreateUnreadPRNotification() (int, error) {
// createUnreadPRNotification generates unread notifications for recent pull requests by fetching their IDs and creating notifications.
func (s *Service) createUnreadPRNotification() (int, error) {
logger := logging.FromContext(s.ctx)

prs, err := s.productSvc.GetRecentPullRequests(s.ctx)
Expand All @@ -64,10 +66,10 @@ func (s *Service) CreateUnreadPRNotification() (int, error) {
return s.createNotification("OPEN_PULL_REQUEST", "New pull request", prs)
}

// CreateUnreadSecurityNotification generates unread security notifications for recent security alerts.
// createUnreadSecurityNotification generates unread security notifications for recent security alerts.
// It retrieves recent security-related IDs and creates notifications for each using the notification service.
// Returns an error if fetching security IDs or creating notifications fails.
func (s *Service) CreateUnreadSecurityNotification() (int, error) {
func (s *Service) createUnreadSecurityNotification() (int, error) {
logger := logging.FromContext(s.ctx)
secList, err := s.productSvc.GetRecentSecurity(s.ctx)
if err != nil {
Expand All @@ -78,6 +80,8 @@ func (s *Service) CreateUnreadSecurityNotification() (int, error) {
return s.createNotification("OPEN_SECURITY_ALERT", "New security alert", secList)
}

// createNotification generates and dispatches notifications for a list of recently changed entities.
// Returns the count of notifications created or an error if the operation fails.
func (s *Service) createNotification(notificationType string, content string, recentlyChanged []products.RecentlyChangedEntity) (int, error) {

notificationsList := make([]notifications.CreateNotificationParams, len(recentlyChanged))
Expand Down
90 changes: 90 additions & 0 deletions internal/watchtower/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2551,3 +2551,93 @@ func TestService_UpdateOrganisation(t *testing.T) {
Run()
odize.AssertNoError(t, err)
}

func TestService_CreateUnreadNotification(t *testing.T) {
group := odize.NewGroup(t, nil)

var s *Service

ctx := context.Background()

var orgID int64

group.BeforeAll(func() {
s = NewService(ctx, _testDB, _testTxnDB)
s.Startup(ctx)

org, err := s.CreateOrganisation("test_org_notifications", "test_org_namespace_notifications", "token", "test description")
if err != nil {
fmt.Print("create org error", err)
}
odize.AssertNoError(t, err)
orgID = org.ID

tags := []string{"web-app"}
_, err = s.CreateProduct("Web App Product", "A product for notifications", tags, orgID)
odize.AssertNoError(t, err)

err = s.productSvc.UpsertRepo(ctx, products.CreateRepoParams{
Name: "my-repo",
Url: "http://github.com/org/my-repo",
Topic: "web-app",
Owner: "org",
})
odize.AssertNoError(t, err)
})

group.BeforeEach(func() {
s = NewService(ctx, _testDB, _testTxnDB)
s.Startup(ctx)
})

err := group.
Test("should create unread notifications for PRs and Security alerts", func(t *testing.T) {

err := s.productSvc.UpsertPullRequest(ctx, products.CreatePRParams{
ExternalID: "pr-1",
Title: "Fix bug",
RepositoryName: "my-repo",
Url: "http://github.com/org/my-repo/pull/1",
State: "OPEN",
Author: "user",
CreatedAt: time.Now(),
})
odize.AssertNoError(t, err)

err = s.productSvc.UpsertSecurity(ctx, products.CreateSecurityParams{
ExternalID: "sec-1",
RepositoryName: "my-repo",
PackageName: "vulnerable-pkg",
State: "OPEN",
Severity: "HIGH",
PatchedVersion: "1.0.1",
CreatedAt: time.Now(),
})
odize.AssertNoError(t, err)

count, err := s.CreateUnreadNotification()
odize.AssertNoError(t, err)
odize.AssertEqual(t, count, 2)

unread, err := s.notificationSvc.GetUnreadNotifications(ctx)
odize.AssertNoError(t, err)

foundPR := false
foundSec := false
for _, n := range unread {
if n.ExternalID == "pr-1" {
foundPR = true
odize.AssertEqual(t, n.Type, "OPEN_PULL_REQUEST")
}
if n.ExternalID == "sec-1" {
foundSec = true
odize.AssertEqual(t, n.Type, "OPEN_SECURITY_ALERT")
}
}

odize.AssertTrue(t, foundPR)
odize.AssertTrue(t, foundSec)
}).
Run()
odize.AssertNoError(t, err)
}