Skip to content

Commit

Permalink
feat: interface for protocols (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusprubio authored Nov 12, 2024
1 parent 784898b commit 64d77d3
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 353 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

## Dependencies

- [Task](https://taskfile.dev/installation/)
- [Task](https://taskfile.dev/installation)
- Linters:

```sh
task dep
```
Expand Down
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ This will display help for the tool.
up -h
```

### Library

Check [the examples](examples) to see how to use this project in your own code.

## License

This project is under the MIT License. See the [LICENSE](LICENSE) file for the full text.
Expand Down
3 changes: 1 addition & 2 deletions examples/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ func main() {
reportCh := make(chan *pkg.Report)
defer close(reportCh)
probe := pkg.Probe{
Protocols: []*pkg.Protocol{pkg.Protocols[1]},
Protocols: []pkg.Protocol{&pkg.TCP{Timeout: 2 * time.Second}},
Count: 3,
Timeout: 2 * time.Second,
Logger: logger,
ReportCh: reportCh,
}
Expand Down
23 changes: 6 additions & 17 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,6 @@ const (
`
)

// ProtocolByID returns the protocol implementation whose ID matches the given
// one.
func ProtocolByID(id string) *pkg.Protocol {
for _, p := range pkg.Protocols {
if p.ID == id {
return p
}
}
return nil
}

// Fatal logs the error to the standard output and exits with status 1.
func Fatal(err error) {
fmt.Fprintf(os.Stderr, "%s: %s\n", appName, err)
Expand All @@ -45,15 +34,15 @@ func Fatal(err error) {

// ReportToLine returns a human-readable representation of the report.
func ReportToLine(r *pkg.Report) string {
line := fmt.Sprintf("%-15s %-14s %s", bold(r.ProtocolID), r.Time, r.RHost)
suffix := r.Extra
prefix := green("✔")
symbol := green("✔")
suffix := r.RHost
if r.Error != nil {
prefix = red("✘")
symbol = red("✘")
suffix = r.Error.Error()
}
suffix = fmt.Sprintf("(%s)", suffix)
return fmt.Sprintf("%s %s %s", prefix, line, faint(suffix))
return fmt.Sprintf("%s %s", symbol, fmt.Sprintf(
"%-15s %-14s %-15s", bold(r.ProtocolID), r.Time, faint(suffix),
))
}

var (
Expand Down
22 changes: 3 additions & 19 deletions internal/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,23 @@ import (
"github.com/jesusprubio/up/pkg"
)

func TestProtocolByID(t *testing.T) {
t.Run("returns the protocol if it exists", func(t *testing.T) {
got := ProtocolByID("http")
if got == nil {
t.Fatal("got nil, want a protocol")
}
})
t.Run("returns nil if the protocol doesn't exist", func(t *testing.T) {
got := ProtocolByID("unknown")
if got != nil {
t.Fatalf("got %q, want nil", got)
}
})
}

func TestReportToLine(t *testing.T) {
r := &pkg.Report{
ProtocolID: "test",
Time: 1 * time.Second,
RHost: "test",
Extra: "test",
Time: 1 * time.Second,
}
t.Run("return success line if no error happened", func(t *testing.T) {
got := ReportToLine(r)
want := "✔ test 1s test (test)"
want := "✔ test 1s test "
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
})
t.Run("return error line if an error happened", func(t *testing.T) {
r.Error = errors.New("probe error")
got := ReportToLine(r)
want := "✘ test 1s test (probe error)"
want := "✘ test 1s probe error "
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
Expand Down
8 changes: 4 additions & 4 deletions internal/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
// Options are the flags supported by the command line application.
type Options struct {
// Input flags.
//Custom DNS Resolver
DNSResolver string
// Protocol to use.
// Protocol to use. Example: 'http'.
Protocol string
// Number of iterations. Zero means infinite.
Count uint
Expand All @@ -20,6 +18,8 @@ type Options struct {
Delay time.Duration
// Stop after the first successful request.
Stop bool
// Custom DNS resolver.
DNSResolver string
// Output flags.
// Output in JSON format.
JSONOutput bool
Expand All @@ -33,7 +33,6 @@ type Options struct {

// Parse fulfills the command line flags provided by the user.
func (opts *Options) Parse() {
flag.StringVar(&opts.DNSResolver, "r", "", "DNS resolution server")
flag.StringVar(&opts.Protocol, "p", "", "Test only one protocol")
flag.UintVar(&opts.Count, "c", 0, "Number of iterations")
flag.DurationVar(
Expand All @@ -45,6 +44,7 @@ func (opts *Options) Parse() {
flag.BoolVar(
&opts.Stop, "s", false, "Stop after the first successful request",
)
flag.StringVar(&opts.DNSResolver, "dr", "", "DNS resolution server")
flag.BoolVar(&opts.JSONOutput, "j", false, "Output in JSON format")
flag.BoolVar(&opts.NoColor, "nc", false, "Disable color output")
flag.BoolVar(&opts.Debug, "dbg", false, "Verbose output")
Expand Down
24 changes: 17 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,27 @@ func main() {
lvl.Set(slog.LevelDebug)
}
logger.Debug("Starting ...", "options", opts)
protocols := pkg.Protocols
dnsProtocol := &pkg.DNS{Timeout: opts.Timeout}
if opts.DNSResolver != "" {
dnsProtocol.Resolver = opts.DNSResolver
}
protocols := []pkg.Protocol{
&pkg.HTTP{Timeout: opts.Timeout},
&pkg.TCP{Timeout: opts.Timeout},
dnsProtocol,
}
if opts.Protocol != "" {
protocol := internal.ProtocolByID(opts.Protocol)
var protocol pkg.Protocol
for _, p := range protocols {
if p.String() == opts.Protocol {
protocol = p
break
}
}
if protocol == nil {
internal.Fatal(fmt.Errorf("unknown protocol: %s", opts.Protocol))
}
if opts.DNSResolver != "" {
protocol.WithDNSResolver(opts.DNSResolver)
}
protocols = []*pkg.Protocol{protocol}
protocols = []pkg.Protocol{protocol}
}
logger.Info("Starting ...", "protocols", protocols, "count", opts.Count)
if opts.Help {
Expand Down Expand Up @@ -70,7 +81,6 @@ func main() {
probe := pkg.Probe{
Protocols: protocols,
Count: opts.Count,
Timeout: opts.Timeout,
Delay: opts.Delay,
Logger: logger,
ReportCh: reportCh,
Expand Down
41 changes: 13 additions & 28 deletions pkg/probe.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@ import (
"time"
)

const tmplRequiredProp = "required property: %s"

// Probe is an experiment to measure the connectivity of a network using
// different protocols and public servers.
type Probe struct {
// Protocols to use.
Protocols []*Protocol
Protocols []Protocol
// Number of iterations. Zero means infinite.
Count uint
// Time to wait for a response.
Timeout time.Duration
// Delay between requests.
Delay time.Duration
// For debugging purposes.
Expand All @@ -29,27 +25,22 @@ type Probe struct {
// Ensures the probe setup is correct.
func (p Probe) validate() error {
if p.Protocols == nil {
return fmt.Errorf(tmplRequiredProp, "Protocols")
}
for _, proto := range p.Protocols {
err := proto.validate()
if err != nil {
return fmt.Errorf("invalid protocol: %w", err)
}
}
if p.Timeout == 0 {
return fmt.Errorf(tmplRequiredProp, "Timeout")
return newErrorReqProp("Protocols")
}
// 'Delay' could be zero.
if p.Logger == nil {
return fmt.Errorf(tmplRequiredProp, "Logger")
return newErrorReqProp("Logger")
}
if p.ReportCh == nil {
return fmt.Errorf(tmplRequiredProp, "ReportCh")
return newErrorReqProp("ReportCh")
}
return nil
}

func newErrorReqProp(prop string) error {
return fmt.Errorf("required property: %s", prop)
}

// Run the connection requests against the public servers.
//
// The context can be cancelled between different protocol attempts or count
Expand Down Expand Up @@ -77,26 +68,20 @@ func (p Probe) Run(ctx context.Context) error {
case <-ctx.Done():
p.Logger.Debug(
"Context cancelled between protocols",
"count", count, "protocol", proto.ID,
"count", count, "protocol", proto,
)
return nil
default:
start := time.Now()
rhost, err := proto.RHost()
if err != nil {
return fmt.Errorf("creating remote host: %w", err)
}
p.Logger.Debug(
"New protocol",
"count", count, "protocol", proto.ID, "rhost", rhost,
"New protocol", "count", count, "protocol", proto,
)
extra, err := proto.Probe(rhost, p.Timeout)
rhost, err := proto.Probe("")
report := Report{
ProtocolID: proto.ID,
RHost: rhost,
ProtocolID: proto.String(),
Time: time.Since(start),
Error: err,
Extra: extra,
RHost: rhost,
}
p.Logger.Debug(
"Sending report back",
Expand Down
Loading

0 comments on commit 64d77d3

Please sign in to comment.