diff --git a/database/mock/fixtures/store.go b/database/mock/fixtures/store.go index 005553760b..410053c086 100644 --- a/database/mock/fixtures/store.go +++ b/database/mock/fixtures/store.go @@ -124,6 +124,9 @@ func WithSuccessfulUpsertPullRequest( mockStore.EXPECT(). UpsertPullRequest(gomock.Any(), gomock.Any()). Return(pullRequest, nil) + mockStore.EXPECT(). + CreateOrEnsureEntityByID(gomock.Any(), gomock.Any()). + Return(db.EntityInstance{}, nil) } } @@ -132,6 +135,9 @@ func WithSuccessfulDeletePullRequest() func(*mockdb.MockStore) { mockStore.EXPECT(). DeletePullRequest(gomock.Any(), gomock.Any()). Return(nil) + mockStore.EXPECT(). + DeleteEntityByName(gomock.Any(), gomock.Any()). + Return(nil) } } @@ -154,3 +160,20 @@ func WithSuccessfulUpsertArtifact( Return(artifact, nil) } } + +func WithTransaction() func(*mockdb.MockStore) { + return func(mockStore *mockdb.MockStore) { + mockStore.EXPECT(). + BeginTransaction(). + Return(nil, nil) + mockStore.EXPECT(). + GetQuerierWithTransaction(gomock.Any()). + Return(mockStore) + mockStore.EXPECT(). + Commit(gomock.Any()). + Return(nil) + mockStore.EXPECT(). + Rollback(gomock.Any()). + Return(nil) + } +} diff --git a/database/mock/store.go b/database/mock/store.go index aa3ce96b68..88e256ae97 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -206,6 +206,21 @@ func (mr *MockStoreMockRecorder) CreateInvitation(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateInvitation", reflect.TypeOf((*MockStore)(nil).CreateInvitation), arg0, arg1) } +// CreateOrEnsureEntityByID mocks base method. +func (m *MockStore) CreateOrEnsureEntityByID(arg0 context.Context, arg1 db.CreateOrEnsureEntityByIDParams) (db.EntityInstance, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateOrEnsureEntityByID", arg0, arg1) + ret0, _ := ret[0].(db.EntityInstance) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateOrEnsureEntityByID indicates an expected call of CreateOrEnsureEntityByID. +func (mr *MockStoreMockRecorder) CreateOrEnsureEntityByID(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrEnsureEntityByID", reflect.TypeOf((*MockStore)(nil).CreateOrEnsureEntityByID), arg0, arg1) +} + // CreateProfile mocks base method. func (m *MockStore) CreateProfile(arg0 context.Context, arg1 db.CreateProfileParams) (db.Profile, error) { m.ctrl.T.Helper() @@ -414,6 +429,20 @@ func (mr *MockStoreMockRecorder) DeleteEntity(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEntity", reflect.TypeOf((*MockStore)(nil).DeleteEntity), arg0, arg1) } +// DeleteEntityByName mocks base method. +func (m *MockStore) DeleteEntityByName(arg0 context.Context, arg1 db.DeleteEntityByNameParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteEntityByName", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteEntityByName indicates an expected call of DeleteEntityByName. +func (mr *MockStoreMockRecorder) DeleteEntityByName(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEntityByName", reflect.TypeOf((*MockStore)(nil).DeleteEntityByName), arg0, arg1) +} + // DeleteEvaluationHistoryByIDs mocks base method. func (m *MockStore) DeleteEvaluationHistoryByIDs(arg0 context.Context, arg1 []uuid.UUID) (int64, error) { m.ctrl.T.Helper() diff --git a/database/query/entities.sql b/database/query/entities.sql index 8fb793ea69..16d636cc6e 100644 --- a/database/query/entities.sql +++ b/database/query/entities.sql @@ -23,12 +23,33 @@ INSERT INTO entity_instances ( ) VALUES ($1, $2, $3, sqlc.arg(project_id), sqlc.arg(provider_id), sqlc.narg(originated_from)) RETURNING *; + +-- CreateOrEnsureEntityByID adds an entry to the entity_instances table if it does not exist, or returns the existing entry. + +-- name: CreateOrEnsureEntityByID :one +INSERT INTO entity_instances ( + id, + entity_type, + name, + project_id, + provider_id, + originated_from +) VALUES ($1, $2, $3, sqlc.arg(project_id), sqlc.arg(provider_id), sqlc.narg(originated_from)) +ON CONFLICT (id) DO NOTHING +RETURNING *; + -- DeleteEntity removes an entity from the entity_instances table for a project. -- name: DeleteEntity :exec DELETE FROM entity_instances WHERE id = $1 AND project_id = $2; +-- DeleteEntityByName removes an entity from the entity_instances table for a project. + +-- name: DeleteEntityByName :exec +DELETE FROM entity_instances +WHERE name = sqlc.arg(name) AND project_id = $1; + -- GetEntityByID retrieves an entity by its ID for a project or hierarchy of projects. -- name: GetEntityByID :one diff --git a/internal/controlplane/handlers_githubwebhooks.go b/internal/controlplane/handlers_githubwebhooks.go index ae4ca9d3af..b4c64bb486 100644 --- a/internal/controlplane/handlers_githubwebhooks.go +++ b/internal/controlplane/handlers_githubwebhooks.go @@ -1438,22 +1438,60 @@ func (s *Server) reconcilePrWithDb( // published, see here // https://pkg.go.dev/github.com/google/go-github/v62@v62.0.0/github#PullRequestEvent case webhookActionEventOpened, webhookActionEventReopened: - dbPr, err := s.store.UpsertPullRequest(ctx, db.UpsertPullRequestParams{ - RepositoryID: dbrepo.ID, - PrNumber: prEvalInfo.Number, + var err error + retPr, err = db.WithTransaction(s.store, func(t db.ExtendQuerier) (*db.PullRequest, error) { + dbPr, err := t.UpsertPullRequest(ctx, db.UpsertPullRequestParams{ + RepositoryID: dbrepo.ID, + PrNumber: prEvalInfo.Number, + }) + if err != nil { + return nil, err + } + + _, err = t.CreateOrEnsureEntityByID(ctx, db.CreateOrEnsureEntityByIDParams{ + ID: dbPr.ID, + EntityType: db.EntitiesPullRequest, + Name: pullRequestName(dbrepo.RepoOwner, dbrepo.RepoName, prEvalInfo.Number), + ProjectID: dbrepo.ProjectID, + ProviderID: dbrepo.ProviderID, + OriginatedFrom: uuid.NullUUID{ + UUID: dbrepo.ID, + Valid: true, + }, + }) + if err != nil { + return nil, err + } + + return &dbPr, nil }) if err != nil { return nil, fmt.Errorf( "cannot upsert PR %d in repo %s/%s", prEvalInfo.Number, dbrepo.RepoOwner, dbrepo.RepoName) } - retPr = &dbPr retErr = nil case webhookActionEventClosed: - err := s.store.DeletePullRequest(ctx, db.DeletePullRequestParams{ - RepositoryID: dbrepo.ID, - PrNumber: prEvalInfo.Number, + _, err := db.WithTransaction(s.store, func(t db.ExtendQuerier) (*db.PullRequest, error) { + err := t.DeletePullRequest(ctx, db.DeletePullRequestParams{ + RepositoryID: dbrepo.ID, + PrNumber: prEvalInfo.Number, + }) + if err != nil { + return nil, err + } + + err = t.DeleteEntityByName(ctx, db.DeleteEntityByNameParams{ + Name: pullRequestName(dbrepo.RepoOwner, dbrepo.RepoName, prEvalInfo.Number), + ProjectID: dbrepo.ProjectID, + }) + if err != nil { + return nil, err + } + + return nil, nil }) + if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("cannot delete PR record %d in repo %s/%s", prEvalInfo.Number, dbrepo.RepoOwner, dbrepo.RepoName) @@ -1485,3 +1523,7 @@ func updatePullRequestInfoFromProvider( prEvalInfo.RepoName = dbrepo.RepoName return nil } + +func pullRequestName(owner, name string, number int64) string { + return fmt.Sprintf("%s/%s/%d", owner, name, number) +} diff --git a/internal/controlplane/handlers_githubwebhooks_test.go b/internal/controlplane/handlers_githubwebhooks_test.go index a1f8a15be6..4f52dc92fd 100644 --- a/internal/controlplane/handlers_githubwebhooks_test.go +++ b/internal/controlplane/handlers_githubwebhooks_test.go @@ -3041,6 +3041,7 @@ func (s *UnitTestSuite) TestHandleGitHubWebHook() { df.WithSuccessfulUpsertPullRequest( db.PullRequest{}, ), + df.WithTransaction(), ), topic: events.TopicQueueEntityEvaluate, statusCode: http.StatusOK, @@ -3092,6 +3093,7 @@ func (s *UnitTestSuite) TestHandleGitHubWebHook() { }, ), df.WithSuccessfulDeletePullRequest(), + df.WithTransaction(), ), topic: events.TopicQueueEntityEvaluate, statusCode: http.StatusOK, diff --git a/internal/db/entities.sql.go b/internal/db/entities.sql.go index 132c684cda..91c92b1f3e 100644 --- a/internal/db/entities.sql.go +++ b/internal/db/entities.sql.go @@ -99,6 +99,52 @@ func (q *Queries) CreateEntityWithID(ctx context.Context, arg CreateEntityWithID return i, err } +const createOrEnsureEntityByID = `-- name: CreateOrEnsureEntityByID :one + +INSERT INTO entity_instances ( + id, + entity_type, + name, + project_id, + provider_id, + originated_from +) VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (id) DO NOTHING +RETURNING id, entity_type, name, project_id, provider_id, created_at, originated_from +` + +type CreateOrEnsureEntityByIDParams struct { + ID uuid.UUID `json:"id"` + EntityType Entities `json:"entity_type"` + Name string `json:"name"` + ProjectID uuid.UUID `json:"project_id"` + ProviderID uuid.UUID `json:"provider_id"` + OriginatedFrom uuid.NullUUID `json:"originated_from"` +} + +// CreateOrEnsureEntityByID adds an entry to the entity_instances table if it does not exist, or returns the existing entry. +func (q *Queries) CreateOrEnsureEntityByID(ctx context.Context, arg CreateOrEnsureEntityByIDParams) (EntityInstance, error) { + row := q.db.QueryRowContext(ctx, createOrEnsureEntityByID, + arg.ID, + arg.EntityType, + arg.Name, + arg.ProjectID, + arg.ProviderID, + arg.OriginatedFrom, + ) + var i EntityInstance + err := row.Scan( + &i.ID, + &i.EntityType, + &i.Name, + &i.ProjectID, + &i.ProviderID, + &i.CreatedAt, + &i.OriginatedFrom, + ) + return i, err +} + const deleteEntity = `-- name: DeleteEntity :exec DELETE FROM entity_instances @@ -116,6 +162,23 @@ func (q *Queries) DeleteEntity(ctx context.Context, arg DeleteEntityParams) erro return err } +const deleteEntityByName = `-- name: DeleteEntityByName :exec + +DELETE FROM entity_instances +WHERE name = $2 AND project_id = $1 +` + +type DeleteEntityByNameParams struct { + ProjectID uuid.UUID `json:"project_id"` + Name string `json:"name"` +} + +// DeleteEntityByName removes an entity from the entity_instances table for a project. +func (q *Queries) DeleteEntityByName(ctx context.Context, arg DeleteEntityByNameParams) error { + _, err := q.db.ExecContext(ctx, deleteEntityByName, arg.ProjectID, arg.Name) + return err +} + const getEntitiesByType = `-- name: GetEntitiesByType :many SELECT id, entity_type, name, project_id, provider_id, created_at, originated_from FROM entity_instances diff --git a/internal/db/querier.go b/internal/db/querier.go index 2cc3fec66a..332c6cd2ad 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -28,6 +28,8 @@ type Querier interface { // invitation. The project is the project to which the invitee will be invited. // The sponsor is the user who is inviting the invitee. CreateInvitation(ctx context.Context, arg CreateInvitationParams) (UserInvite, error) + // CreateOrEnsureEntityByID adds an entry to the entity_instances table if it does not exist, or returns the existing entry. + CreateOrEnsureEntityByID(ctx context.Context, arg CreateOrEnsureEntityByIDParams) (EntityInstance, error) CreateProfile(ctx context.Context, arg CreateProfileParams) (Profile, error) CreateProfileForEntity(ctx context.Context, arg CreateProfileForEntityParams) (EntityProfile, error) CreateProject(ctx context.Context, arg CreateProjectParams) (Project, error) @@ -44,6 +46,8 @@ type Querier interface { DeleteArtifact(ctx context.Context, id uuid.UUID) error // DeleteEntity removes an entity from the entity_instances table for a project. DeleteEntity(ctx context.Context, arg DeleteEntityParams) error + // DeleteEntityByName removes an entity from the entity_instances table for a project. + DeleteEntityByName(ctx context.Context, arg DeleteEntityByNameParams) error DeleteEvaluationHistoryByIDs(ctx context.Context, evaluationids []uuid.UUID) (int64, error) DeleteExpiredSessionStates(ctx context.Context) (int64, error) DeleteInstallationIDByAppID(ctx context.Context, appInstallationID int64) error