Skip to content

Commit

Permalink
move plugins to hook_*
Browse files Browse the repository at this point in the history
  • Loading branch information
linyows committed Oct 8, 2023
1 parent ebfd3d7 commit 49e1345
Show file tree
Hide file tree
Showing 12 changed files with 366 additions and 363 deletions.
26 changes: 26 additions & 0 deletions hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package warp

import (
"time"
)

type Hook interface {
AfterInit()
AfterComm(*AfterCommData)
AfterConn(*AfterConnData)
}

type AfterCommData struct {
ConnID string
OccurredAt time.Time
Data
Direction
}

type AfterConnData struct {
ConnID string
OccurredAt time.Time
MailFrom []byte
MailTo []byte
Elapse
}
69 changes: 69 additions & 0 deletions hook_file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package warp

import (
"fmt"
"io"
"os"
"time"
)

const (
fileCommJson string = `{"type":"comm","occurred_at":"%s","connection_id":"%s","direction":"%s","data":"%s"}
`
fileConnJson string = `{"type":"conn","occurred_at":"%s","connection_id":"%s","from":"%s","to":"%s","elapse":"%s"}
`
)

type HookFile struct {
file io.Writer
}

func (h *HookFile) prefix() string {
return "file"
}

func (h *HookFile) writer() (io.Writer, error) {
if h.file != nil {
return h.file, nil
}

path := os.Getenv("FILE_PATH")
if len(path) == 0 {
return nil, fmt.Errorf("missing path for file, please set `FILE_PATH`")
}

var err error
h.file, err = os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, fmt.Errorf("os.OpenFile error: %s\n", err)
}

return h.file, nil
}

func (h *HookFile) AfterInit() {
}

func (h *HookFile) AfterComm(d *AfterCommData) {
writer, err := h.writer()
if err != nil {
fmt.Printf("[%s] %s\n", h.prefix(), err)
return
}

if _, err := fmt.Fprintf(writer, fileCommJson, d.OccurredAt.Format(time.RFC3339), d.ConnID, d.Direction, d.Data); err != nil {
fmt.Printf("[%s] file append error: %s\n", h.prefix(), err)
}
}

func (h *HookFile) AfterConn(d *AfterConnData) {
writer, err := h.writer()
if err != nil {
fmt.Printf("[%s] %s\n", h.prefix(), err)
return
}

if _, err := fmt.Fprintf(writer, fileConnJson, d.OccurredAt.Format(time.RFC3339), d.ConnID, d.MailFrom, d.MailTo, d.Elapse); err != nil {
fmt.Printf("[%s] file append error: %s\n", h.prefix(), err)
}
}
45 changes: 23 additions & 22 deletions plugins/file/main_test.go → hook_file_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package warp

import (
"bytes"
Expand All @@ -7,20 +7,12 @@ import (
"strings"
"testing"
"time"

"github.com/linyows/warp"
)

func TestConst(t *testing.T) {
func TestHookFileConst(t *testing.T) {
var expect string
var got string

expect = "file-plugin"
got = prefix
if got != expect {
t.Errorf("expected %s, got %s", expect, got)
}

replace := func(str string) string {
return strings.ReplaceAll(
strings.ReplaceAll(str, "\n", ""),
Expand All @@ -36,7 +28,7 @@ func TestConst(t *testing.T) {
"data":"%s"
}
`)
got = commJson
got = fileCommJson
if got != expect {
t.Errorf("expected %s, got %s", expect, got)
}
Expand All @@ -51,13 +43,22 @@ func TestConst(t *testing.T) {
"elapse":"%s"
}
`)
got = connJson
got = fileConnJson
if got != expect {
t.Errorf("expected %s, got %s", expect, got)
}
}

func TestHookFilePrefix(t *testing.T) {
f := &HookFile{}
expect := "file"
got := f.prefix()
if got != expect {
t.Errorf("expected %s, got %s", expect, got)
}
}

func TestWriter(t *testing.T) {
func TestHookFileWriter(t *testing.T) {
var tests = []struct {
expectFileName string
expectError string
Expand All @@ -71,10 +72,10 @@ func TestWriter(t *testing.T) {
envVal: "",
},
{
expectFileName: "/tmp/warp-plugin-file",
expectFileName: "/tmp/warp-file",
expectError: "",
envName: "FILE_PATH",
envVal: "/tmp/warp-plugin-file",
envVal: "/tmp/warp-file",
},
}

Expand All @@ -84,7 +85,7 @@ func TestWriter(t *testing.T) {
defer os.Unsetenv(v.envName)
}

f := File{}
f := &HookFile{}
w, err := f.writer()

if w != nil || v.expectFileName != "" {
Expand All @@ -99,13 +100,13 @@ func TestWriter(t *testing.T) {
}
}

func TestAfterComm(t *testing.T) {
func TestHookFileAfterComm(t *testing.T) {
ti := time.Date(2023, time.August, 16, 14, 48, 0, 0, time.UTC)
buffer := new(bytes.Buffer)
f := File{
f := &HookFile{
file: buffer,
}
data := &warp.AfterCommData{
data := &AfterCommData{
ConnID: "abcdefg",
OccurredAt: ti,
Data: []byte("hello"),
Expand All @@ -120,13 +121,13 @@ func TestAfterComm(t *testing.T) {
}
}

func TestAfterConn(t *testing.T) {
func TestHookFileAfterConn(t *testing.T) {
ti := time.Date(2023, time.August, 16, 14, 48, 0, 0, time.UTC)
buffer := new(bytes.Buffer)
f := File{
f := &HookFile{
file: buffer,
}
data := &warp.AfterConnData{
data := &AfterConnData{
ConnID: "abcdefg",
OccurredAt: ti,
MailFrom: []byte("[email protected]"),
Expand Down
82 changes: 82 additions & 0 deletions hook_mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package warp

import (
"database/sql"
"fmt"
"os"
)

const (
mysqlCommQuery string = "insert into communications (id, connection_id, occurred_at, direction, data) values (?, ?, ?, ?, ?)"
mysqlConnQuery string = "insert into connections (id, occurred_at, mail_from, mail_to, elapse) values (?, ?, ?, ?, ?)"
)

type HookMysql struct {
pool *sql.DB // Database connection pool.
}

func (h *HookMysql) prefix() string {
return "mysql"
}

func (h *HookMysql) conn() (*sql.DB, error) {
if h.pool != nil {
return h.pool, nil
}

dsn := os.Getenv("DSN")
if len(dsn) == 0 {
return nil, fmt.Errorf("missing dsn for mysql, please set `DSN`")
}

var err error
h.pool, err = sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("sql.Open error: s%s\n", err)
}

return h.pool, nil
}

func (h *HookMysql) AfterInit() {
}

func (h *HookMysql) AfterComm(d *AfterCommData) {
conn, err := h.conn()
if err != nil {
fmt.Printf("[%s] %s\n", h.prefix(), err)
return
}

_, err = conn.Exec(
mysqlCommQuery,
GenID().String(),
d.ConnID,
d.OccurredAt.Format(TimeFormat),
d.Direction,
d.Data,
)
if err != nil {
fmt.Printf("[%s] db exec error: %s\n", h.prefix(), err)
}
}

func (h *HookMysql) AfterConn(d *AfterConnData) {
conn, err := h.conn()
if err != nil {
fmt.Printf("[%s] %s\n", h.prefix(), err)
return
}

_, err = conn.Exec(
mysqlConnQuery,
d.ConnID,
d.OccurredAt.Format(TimeFormat),
d.MailFrom,
d.MailTo,
d.Elapse,
)
if err != nil {
fmt.Printf("[%s] db exec error: %s\n", h.prefix(), err)
}
}
Loading

0 comments on commit 49e1345

Please sign in to comment.