From 27eb7c9ea0e419d817178012ae6ee67cc6097294 Mon Sep 17 00:00:00 2001 From: Scott Fleener Date: Sat, 7 Jan 2023 20:13:56 -0500 Subject: [PATCH] Update unit tests to include request count tracking --- provider/pihole/pihole_test.go | 50 +++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/provider/pihole/pihole_test.go b/provider/pihole/pihole_test.go index b936b68fb3..99e1fb92ee 100644 --- a/provider/pihole/pihole_test.go +++ b/provider/pihole/pihole_test.go @@ -26,6 +26,7 @@ import ( type testPiholeClient struct { endpoints []*endpoint.Endpoint + requests *requestTracker } func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) { @@ -40,6 +41,7 @@ func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*en func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error { t.endpoints = append(t.endpoints, ep) + t.requests.createRequests += 1 return nil } @@ -51,9 +53,20 @@ func (t *testPiholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoi } } t.endpoints = newEPs + t.requests.deleteRequests += 1 return nil } +type requestTracker struct { + createRequests int + deleteRequests int +} + +func (r *requestTracker) clear() { + r.createRequests = 0 + r.deleteRequests = 0 +} + func TestNewPiholeProvider(t *testing.T) { // Test invalid configuration _, err := NewPiholeProvider(PiholeConfig{}) @@ -68,8 +81,9 @@ func TestNewPiholeProvider(t *testing.T) { } func TestProvider(t *testing.T) { + requests := requestTracker{} p := &PiholeProvider{ - api: &testPiholeClient{}, + api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests}, } records, err := p.Records(context.Background()) @@ -113,6 +127,12 @@ func TestProvider(t *testing.T) { if len(newRecords) != 3 { t.Fatal("Expected list of 3 records, got:", records) } + if requests.createRequests != 3 { + t.Fatal("Expected 3 create requests, got:", requests.createRequests) + } + if requests.deleteRequests != 0 { + t.Fatal("Expected no delete requests, got:", requests.deleteRequests) + } for idx, record := range records { if newRecords[idx].DNSName != record.DNSName { @@ -123,6 +143,8 @@ func TestProvider(t *testing.T) { } } + requests.clear() + // Test delete a record records = []*endpoint.Endpoint{ @@ -148,6 +170,12 @@ func TestProvider(t *testing.T) { }); err != nil { t.Fatal(err) } + if requests.createRequests != 0 { + t.Fatal("Expected no create requests, got:", requests.createRequests) + } + if requests.deleteRequests != 1 { + t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) + } // Test records are updated newRecords, err = p.Records(context.Background()) @@ -167,6 +195,8 @@ func TestProvider(t *testing.T) { } } + requests.clear() + // Test update a record records = []*endpoint.Endpoint{ @@ -183,6 +213,11 @@ func TestProvider(t *testing.T) { } if err := p.ApplyChanges(context.Background(), &plan.Changes{ UpdateOld: []*endpoint.Endpoint{ + { + DNSName: "test1.example.com", + Targets: []string{"192.168.1.1"}, + RecordType: endpoint.RecordTypeA, + }, { DNSName: "test2.example.com", Targets: []string{"192.168.1.2"}, @@ -190,6 +225,11 @@ func TestProvider(t *testing.T) { }, }, UpdateNew: []*endpoint.Endpoint{ + { + DNSName: "test1.example.com", + Targets: []string{"192.168.1.1"}, + RecordType: endpoint.RecordTypeA, + }, { DNSName: "test2.example.com", Targets: []string{"10.0.0.1"}, @@ -208,6 +248,12 @@ func TestProvider(t *testing.T) { if len(newRecords) != 2 { t.Fatal("Expected list of 2 records, got:", records) } + if requests.createRequests != 1 { + t.Fatal("Expected 1 create request, got:", requests.createRequests) + } + if requests.deleteRequests != 1 { + t.Fatal("Expected 1 delete request, got:", requests.deleteRequests) + } for idx, record := range records { if newRecords[idx].DNSName != record.DNSName { @@ -217,4 +263,6 @@ func TestProvider(t *testing.T) { t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets) } } + + requests.clear() }