From 80677098aef651930dfe827775b6b58aefc6a8af Mon Sep 17 00:00:00 2001 From: Durand Fabrice Date: Wed, 13 Nov 2024 12:49:23 -0500 Subject: [PATCH] Updated db connection --- go/cmd/pfdhcp/api.go | 56 ++++++++++++++++++--------------- go/cmd/pfdhcp/config.go | 11 ++++--- go/cmd/pfdhcp/keysoption.go | 20 ++++++------ go/cmd/pfdhcp/main.go | 59 +++++++++++++++++------------------ go/cmd/pfdhcp/server.go | 3 +- go/cmd/pfdhcp/utils.go | 35 +++++++++++---------- go/cmd/pfdhcp/workers_pool.go | 4 ++- 7 files changed, 98 insertions(+), 90 deletions(-) diff --git a/go/cmd/pfdhcp/api.go b/go/cmd/pfdhcp/api.go index f640f46265a9..7a8058fbe05a 100644 --- a/go/cmd/pfdhcp/api.go +++ b/go/cmd/pfdhcp/api.go @@ -1,12 +1,12 @@ package main import ( + "database/sql" "encoding/binary" "encoding/json" "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "strconv" @@ -21,6 +21,10 @@ import ( "github.com/inverse-inc/packetfence/go/pfconfigdriver" ) +type API struct { + DB *sql.DB +} + // Node struct type Node struct { Mac string `json:"mac"` @@ -82,7 +86,7 @@ type OptionsFromFilter struct { Type string `json:"type"` } -func handleIP2Mac(res http.ResponseWriter, req *http.Request) { +func (a *API) handleIP2Mac(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) if index, expiresAt, found := GlobalIPCache.GetWithExpiration(vars["ip"]); found { @@ -102,7 +106,7 @@ func handleIP2Mac(res http.ResponseWriter, req *http.Request) { return } -func handleMac2Ip(res http.ResponseWriter, req *http.Request) { +func (a *API) handleMac2Ip(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) if index, expiresAt, found := GlobalMacCache.GetWithExpiration(vars["mac"]); found { @@ -122,7 +126,7 @@ func handleMac2Ip(res http.ResponseWriter, req *http.Request) { return } -func handleAllStats(res http.ResponseWriter, req *http.Request) { +func (a *API) handleAllStats(res http.ResponseWriter, req *http.Request) { var result Items var interfaces pfconfigdriver.ListenInts pfconfigdriver.FetchDecodeSocket(ctx, &interfaces) @@ -132,7 +136,7 @@ func handleAllStats(res http.ResponseWriter, req *http.Request) { } for _, i := range interfaces.Element { if h, ok := intNametoInterface[i]; ok { - stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: i, NetWork: ""}) + stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: i, NetWork: ""}, a.DB) for _, s := range stat.([]Stats) { result.Items = append(result.Items, s) } @@ -151,11 +155,11 @@ func handleAllStats(res http.ResponseWriter, req *http.Request) { return } -func handleStats(res http.ResponseWriter, req *http.Request) { +func (a *API) handleStats(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) if h, ok := intNametoInterface[vars["int"]]; ok { - stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]}) + stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB) outgoingJSON, err := json.Marshal(stat) @@ -172,11 +176,11 @@ func handleStats(res http.ResponseWriter, req *http.Request) { return } -func handleDuplicates(res http.ResponseWriter, req *http.Request) { +func (a *API) handleDuplicates(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) if h, ok := intNametoInterface[vars["int"]]; ok { - stat := h.handleAPIReq(APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]}) + stat := h.handleAPIReq(APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB) outgoingJSON, err := json.Marshal(stat) @@ -193,11 +197,11 @@ func handleDuplicates(res http.ResponseWriter, req *http.Request) { return } -func handleDebug(res http.ResponseWriter, req *http.Request) { +func (a *API) handleDebug(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) if h, ok := intNametoInterface[vars["int"]]; ok { - stat := h.handleAPIReq(APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]}) + stat := h.handleAPIReq(APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]}, a.DB) outgoingJSON, err := json.Marshal(stat) @@ -213,7 +217,7 @@ func handleDebug(res http.ResponseWriter, req *http.Request) { return } -func handleReleaseIP(res http.ResponseWriter, req *http.Request) { +func (a *API) handleReleaseIP(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) _ = InterfaceScopeFromMac(vars["mac"]) @@ -226,11 +230,11 @@ func handleReleaseIP(res http.ResponseWriter, req *http.Request) { } } -func handleOverrideOptions(res http.ResponseWriter, req *http.Request) { +func (a *API) handleOverrideOptions(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - body, err := ioutil.ReadAll(io.LimitReader(req.Body, 1048576)) + body, err := io.ReadAll(io.LimitReader(req.Body, 1048576)) if err != nil { panic(err) } @@ -239,7 +243,7 @@ func handleOverrideOptions(res http.ResponseWriter, req *http.Request) { } // Insert information in MySQL - _ = MysqlInsert(vars["mac"], sharedutils.ConvertToString(body)) + _ = MysqlInsert(vars["mac"], sharedutils.ConvertToString(body), a.DB) var result = &Info{Mac: vars["mac"], Status: "ACK"} @@ -250,11 +254,11 @@ func handleOverrideOptions(res http.ResponseWriter, req *http.Request) { } } -func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) { +func (a *API) handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - body, err := ioutil.ReadAll(io.LimitReader(req.Body, 1048576)) + body, err := io.ReadAll(io.LimitReader(req.Body, 1048576)) if err != nil { panic(err) } @@ -263,7 +267,7 @@ func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) { } // Insert information in MySQL - _ = MysqlInsert(vars["network"], sharedutils.ConvertToString(body)) + _ = MysqlInsert(vars["network"], sharedutils.ConvertToString(body), a.DB) var result = &Info{Network: vars["network"], Status: "ACK"} @@ -274,13 +278,13 @@ func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) { } } -func handleRemoveOptions(res http.ResponseWriter, req *http.Request) { +func (a *API) handleRemoveOptions(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) var result = &Info{Mac: vars["mac"], Status: "ACK"} - err := MysqlDel(vars["mac"]) + err := MysqlDel(vars["mac"], a.DB) if !err { result = &Info{Mac: vars["mac"], Status: "NAK"} } @@ -291,13 +295,13 @@ func handleRemoveOptions(res http.ResponseWriter, req *http.Request) { } } -func handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) { +func (a *API) handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) var result = &Info{Network: vars["network"], Status: "ACK"} - err := MysqlDel(vars["network"]) + err := MysqlDel(vars["network"], a.DB) if !err { result = &Info{Network: vars["network"], Status: "NAK"} } @@ -308,9 +312,9 @@ func handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) { } } -func decodeOptions(b string) (map[dhcp.OptionCode][]byte, error) { +func decodeOptions(b string, db *sql.DB) (map[dhcp.OptionCode][]byte, error) { var options []Options - _, value := MysqlGet(b) + _, value := MysqlGet(b, db) decodedValue := sharedutils.ConvertToByte(value) var dhcpOptions = make(map[dhcp.OptionCode][]byte) if err := json.Unmarshal(decodedValue, &options); err != nil { @@ -360,7 +364,7 @@ func extractMembers(v Network) ([]Node, []string, int) { return Members, Macs, Count } -func (h *Interface) handleAPIReq(Request APIReq) interface{} { +func (h *Interface) handleAPIReq(Request APIReq, db *sql.DB) interface{} { var stats []Stats if Request.Req == "duplicates" { @@ -399,7 +403,7 @@ func (h *Interface) handleAPIReq(Request APIReq) interface{} { } // Add network options on the fly - x, err := decodeOptions(v.network.IP.String()) + x, err := decodeOptions(v.network.IP.String(), db) if err == nil { for key, value := range x { Options[key.String()] = Tlv.Tlvlist[int(key)].Transform.String(value) diff --git a/go/cmd/pfdhcp/config.go b/go/cmd/pfdhcp/config.go index 17fd4a038294..d4ced741d412 100644 --- a/go/cmd/pfdhcp/config.go +++ b/go/cmd/pfdhcp/config.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "encoding/binary" "math" "net" @@ -68,7 +69,7 @@ func newDHCPConfig() *Interfaces { return &p } -func (d *Interfaces) readConfig() { +func (d *Interfaces) readConfig(MyDB *sql.DB) { interfaces := pfconfigdriver.GetType[pfconfigdriver.ListenInts](ctx) DHCPinterfaces := pfconfigdriver.GetType[pfconfigdriver.DHCPInts](ctx) portal := pfconfigdriver.GetType[pfconfigdriver.PfConfCaptivePortal](ctx) @@ -246,7 +247,7 @@ func (d *Interfaces) readConfig() { } DHCPScope.dstIp = dstReplyIp // Initialize dhcp pool - available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(ip, ips)), DHCPNet.network.IP.String()+Role, algorithm, StatsdClient, MySQLdatabase) + available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(ip, ips)), DHCPNet.network.IP.String()+Role, algorithm, StatsdClient, MyDB) DHCPScope.available = available @@ -267,7 +268,7 @@ func (d *Interfaces) readConfig() { DHCPScope.xid = xid wg.Add(1) go func() { - initiaLease(DHCPScope, ConfNet) + initiaLease(DHCPScope, ConfNet, MyDB) wg.Done() }() var options = make(map[dhcp.OptionCode][]byte) @@ -326,7 +327,7 @@ func (d *Interfaces) readConfig() { } DHCPScope.dstIp = dstReplyIp // Initialize dhcp pool - available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(net.ParseIP(ConfNet.DhcpStart), net.ParseIP(ConfNet.DhcpEnd))), DHCPNet.network.IP.String(), algorithm, StatsdClient, MySQLdatabase) + available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(net.ParseIP(ConfNet.DhcpStart), net.ParseIP(ConfNet.DhcpEnd))), DHCPNet.network.IP.String(), algorithm, StatsdClient, MyDB) DHCPScope.available = available @@ -347,7 +348,7 @@ func (d *Interfaces) readConfig() { DHCPScope.xid = xid wg.Add(1) go func() { - initiaLease(DHCPScope, ConfNet) + initiaLease(DHCPScope, ConfNet, MyDB) wg.Done() }() diff --git a/go/cmd/pfdhcp/keysoption.go b/go/cmd/pfdhcp/keysoption.go index 4d2619413d77..9fabf1a54dbf 100644 --- a/go/cmd/pfdhcp/keysoption.go +++ b/go/cmd/pfdhcp/keysoption.go @@ -1,16 +1,18 @@ package main import ( + "database/sql" + "github.com/inverse-inc/go-utils/log" ) // MysqlInsert function -func MysqlInsert(key string, value string) bool { - if err := MySQLdatabase.PingContext(ctx); err != nil { +func MysqlInsert(key string, value string, db *sql.DB) bool { + if err := db.PingContext(ctx); err != nil { log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error()) } - _, err := MySQLdatabase.Exec( + _, err := db.Exec( ` INSERT into key_value_storage values(?,?) ON DUPLICATE KEY UPDATE value = VALUES(value) @@ -28,11 +30,11 @@ ON DUPLICATE KEY UPDATE value = VALUES(value) } // MysqlGet function -func MysqlGet(key string) (string, string) { - if err := MySQLdatabase.PingContext(ctx); err != nil { +func MysqlGet(key string, db *sql.DB) (string, string) { + if err := db.PingContext(ctx); err != nil { log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error()) } - rows, err := MySQLdatabase.Query("select id, value from key_value_storage where id = ?", "/dhcpd/"+key) + rows, err := db.Query("select id, value from key_value_storage where id = ?", "/dhcpd/"+key) defer rows.Close() if err != nil { log.LoggerWContext(ctx).Debug("Error while getting MySQL '" + key + "': " + err.Error()) @@ -52,11 +54,11 @@ func MysqlGet(key string) (string, string) { } // MysqlDel function -func MysqlDel(key string) bool { - if err := MySQLdatabase.PingContext(ctx); err != nil { +func MysqlDel(key string, db *sql.DB) bool { + if err := db.PingContext(ctx); err != nil { log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error()) } - rows, err := MySQLdatabase.Query("delete from key_value_storage where id = ?", "/dhcpd/"+key) + rows, err := db.Query("delete from key_value_storage where id = ?", "/dhcpd/"+key) defer rows.Close() if err != nil { log.LoggerWContext(ctx).Error("Error while deleting MySQL key '" + key + "': " + err.Error()) diff --git a/go/cmd/pfdhcp/main.go b/go/cmd/pfdhcp/main.go index 171a8831ed56..3e1432f751a4 100644 --- a/go/cmd/pfdhcp/main.go +++ b/go/cmd/pfdhcp/main.go @@ -32,9 +32,6 @@ import ( // DHCPConfig global var var DHCPConfig *Interfaces -// MySQLdatabase global var -var MySQLdatabase *sql.DB - // GlobalIPCache global var var GlobalIPCache *cache.Cache @@ -97,12 +94,12 @@ func main() { // Read DB config configDatabase := pfconfigdriver.GetType[pfconfigdriver.PfConfDatabase](ctx) - connectDB(configDatabase) + MyDB := connectDB(configDatabase) // Keep the db alive - go func() { + go func(*sql.DB) { for { - err := MySQLdatabase.Ping() + err := MyDB.Ping() if err != nil { log.LoggerWContext(ctx).Error("Unable to ping DB: " + err.Error()) } else { @@ -110,12 +107,12 @@ func main() { } time.Sleep(5 * time.Second) } - }() + }(MyDB) VIP = make(map[string]bool) VIPIp = make(map[string]net.IP) - go func() { + go func(*sql.DB) { var DHCPinterfaces pfconfigdriver.DHCPInts pfconfigdriver.FetchDecodeSocket(ctx, &DHCPinterfaces) var interfaces pfconfigdriver.ListenInts @@ -132,11 +129,11 @@ func main() { } for { - DHCPConfig.detectVIP(sharedutils.RemoveDuplicates(append(interfaces.Element, intDhcp...))) + DHCPConfig.detectVIP(sharedutils.RemoveDuplicates(append(interfaces.Element, intDhcp...)), MyDB) time.Sleep(3 * time.Second) } - }() + }(MyDB) go func() { var err error @@ -161,7 +158,7 @@ func main() { // Read pfconfig DHCPConfig = newDHCPConfig() - DHCPConfig.readConfig() + DHCPConfig.readConfig(MyDB) webservices := pfconfigdriver.GetType[pfconfigdriver.PfConfWebservices](ctx) // Queue value @@ -202,21 +199,21 @@ func main() { v.run(ctx, jobs) }() } - + api := &API{DB: MyDB} // Api router := mux.NewRouter() - router.HandleFunc("/api/v1/dhcp/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", handleMac2Ip).Methods("GET") - router.HandleFunc("/api/v1/dhcp/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", handleReleaseIP).Methods("DELETE") - router.HandleFunc("/api/v1/dhcp/ip/{ip:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", handleIP2Mac).Methods("GET") - router.HandleFunc("/api/v1/dhcp/stats", handleAllStats).Methods("GET") - router.HandleFunc("/api/v1/dhcp/stats/{int:.*}/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", handleStats).Methods("GET") - router.HandleFunc("/api/v1/dhcp/stats/{int:.*}", handleStats).Methods("GET") - router.HandleFunc("/api/v1/dhcp/debug/{int:.*}/{role:(?:[^/]*)}", handleDebug).Methods("GET") - router.HandleFunc("/api/v1/dhcp/detect_duplicates/{int:.*}", handleDuplicates).Methods("GET") - router.HandleFunc("/api/v1/dhcp/options/network/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", handleOverrideNetworkOptions).Methods("POST") - router.HandleFunc("/api/v1/dhcp/options/network/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", handleRemoveNetworkOptions).Methods("DELETE") - router.HandleFunc("/api/v1/dhcp/options/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", handleOverrideOptions).Methods("POST") - router.HandleFunc("/api/v1/dhcp/options/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", handleRemoveOptions).Methods("DELETE") + router.HandleFunc("/api/v1/dhcp/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", api.handleMac2Ip).Methods("GET") + router.HandleFunc("/api/v1/dhcp/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", api.handleReleaseIP).Methods("DELETE") + router.HandleFunc("/api/v1/dhcp/ip/{ip:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", api.handleIP2Mac).Methods("GET") + router.HandleFunc("/api/v1/dhcp/stats", api.handleAllStats).Methods("GET") + router.HandleFunc("/api/v1/dhcp/stats/{int:.*}/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", api.handleStats).Methods("GET") + router.HandleFunc("/api/v1/dhcp/stats/{int:.*}", api.handleStats).Methods("GET") + router.HandleFunc("/api/v1/dhcp/debug/{int:.*}/{role:(?:[^/]*)}", api.handleDebug).Methods("GET") + router.HandleFunc("/api/v1/dhcp/detect_duplicates/{int:.*}", api.handleDuplicates).Methods("GET") + router.HandleFunc("/api/v1/dhcp/options/network/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", api.handleOverrideNetworkOptions).Methods("POST") + router.HandleFunc("/api/v1/dhcp/options/network/{network:(?:[0-9]{1,3}.){3}(?:[0-9]{1,3})}", api.handleRemoveNetworkOptions).Methods("DELETE") + router.HandleFunc("/api/v1/dhcp/options/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", api.handleOverrideOptions).Methods("POST") + router.HandleFunc("/api/v1/dhcp/options/mac/{mac:(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}}", api.handleRemoveOptions).Methods("DELETE") http.Handle("/", httpauth.SimpleBasicAuth(webservices.User, webservices.Pass)(router)) srv := &http.Server{ @@ -277,7 +274,7 @@ func (I *Interface) runUnicast(ctx context.Context, jobs chan job) { } // ServeDHCP function is the main function that will deal with the dhcp packet -func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.MessageType, srcIP net.Addr, srvIP net.IP) (answer Answer) { +func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.MessageType, srcIP net.Addr, srvIP net.IP, db *sql.DB) (answer Answer) { var handler DHCPHandler var NetScope net.IPNet @@ -303,7 +300,7 @@ func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.M if x, found := NodeCache.Get(answer.MAC.String()); found { node = x.(NodeInfo) } else { - node = NodeInformation(ctx, answer.MAC) + node = NodeInformation(ctx, answer.MAC, db) NodeCache.Set(answer.MAC.String(), node, 3*time.Second) } @@ -558,7 +555,7 @@ func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.M leaseDuration := handler.leaseDuration // Add network options on the fly - x, err := decodeOptions(NetScope.IP.String()) + x, err := decodeOptions(NetScope.IP.String(), db) if err == nil { for key, value := range x { if key == dhcp.OptionIPAddressLeaseTime { @@ -584,7 +581,7 @@ func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.M leaseDuration = 0 } // Add device (mac) options on the fly - x, err = decodeOptions(answer.MAC.String()) + x, err = decodeOptions(answer.MAC.String(), db) if err == nil { for key, value := range x { if key == dhcp.OptionIPAddressLeaseTime { @@ -696,9 +693,9 @@ func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.M GlobalOptions = options leaseDuration := handler.leaseDuration // Add network options - AddDevicesOptions(NetScope.IP.String(), &leaseDuration, GlobalOptions) + AddDevicesOptions(NetScope.IP.String(), &leaseDuration, GlobalOptions, db) // Add device options - AddDevicesOptions(answer.MAC.String(), &leaseDuration, GlobalOptions) + AddDevicesOptions(answer.MAC.String(), &leaseDuration, GlobalOptions, db) info = GetFromGlobalFilterCache(msgType.String(), answer.MAC.String(), Options) // Add options on the fly from pffilter reject := AddPffilterDevicesOptions(info, GlobalOptions) @@ -723,7 +720,7 @@ func (I *Interface) ServeDHCP(ctx context.Context, p dhcp.Packet, msgType dhcp.M // Update Global Caches GlobalIPCache.Set(reqIP.String(), answer.MAC.String(), cacheDuration) GlobalMacCache.Set(answer.MAC.String(), reqIP.String(), cacheDuration) - err := MysqlUpdateIP4Log(answer.MAC.String(), reqIP.String(), cacheDuration) + err := MysqlUpdateIP4Log(answer.MAC.String(), reqIP.String(), cacheDuration, db) if err != nil { log.LoggerWContext(ctx).Info(err.Error()) } diff --git a/go/cmd/pfdhcp/server.go b/go/cmd/pfdhcp/server.go index abdfdb7da0bc..f2e9d949e864 100644 --- a/go/cmd/pfdhcp/server.go +++ b/go/cmd/pfdhcp/server.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "net" dhcp "github.com/inverse-inc/dhcp4" @@ -19,7 +20,7 @@ type Answer struct { // Handler interface type Handler interface { - ServeDHCP(ctx context.Context, req dhcp.Packet, msgType dhcp.MessageType, srcIP net.Addr, srvIP net.IP) Answer + ServeDHCP(ctx context.Context, req dhcp.Packet, msgType dhcp.MessageType, srcIP net.Addr, srvIP net.IP, db *sql.DB) Answer } // ServeConn is the bare minimum connection functions required by Serve() diff --git a/go/cmd/pfdhcp/utils.go b/go/cmd/pfdhcp/utils.go index a21435099fc0..be87933e766f 100644 --- a/go/cmd/pfdhcp/utils.go +++ b/go/cmd/pfdhcp/utils.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "encoding/binary" "fmt" "math/rand" @@ -32,14 +33,14 @@ type NodeInfo struct { } // connectDB connect to the database -func connectDB(configDatabase *pfconfigdriver.PfConfDatabase) { +func connectDB(configDatabase *pfconfigdriver.PfConfDatabase) *sql.DB { db, err := db.DbFromConfig(ctx) sharedutils.CheckError(err) - MySQLdatabase = db + return db } // initiaLease fetch the database to remove already assigned ip addresses -func initiaLease(dhcpHandler *DHCPHandler, ConfNet pfconfigdriver.RessourseNetworkConf) { +func initiaLease(dhcpHandler *DHCPHandler, ConfNet pfconfigdriver.RessourseNetworkConf, db *sql.DB) { // Need to calculate the end ip because of the ip per role feature now := time.Now() endip := binary.BigEndian.Uint32(dhcpHandler.start.To4()) + uint32(dhcpHandler.leaseRange) - uint32(1) @@ -47,7 +48,7 @@ func initiaLease(dhcpHandler *DHCPHandler, ConfNet pfconfigdriver.RessourseNetwo binary.BigEndian.PutUint32(a, endip) ipend := net.IPv4(a[0], a[1], a[2], a[3]) - rows, err := MySQLdatabase.Query("select ip,mac,end_time,start_time from ip4log i where inet_aton(ip) between inet_aton(?) and inet_aton(?) and (end_time = '"+ZeroDate+"' OR end_time > NOW()) and end_time in (select MAX(end_time) from ip4log where mac = i.mac) ORDER BY mac,end_time desc", dhcpHandler.start.String(), ipend.String()) + rows, err := db.Query("select ip,mac,end_time,start_time from ip4log i where inet_aton(ip) between inet_aton(?) and inet_aton(?) and (end_time = '"+ZeroDate+"' OR end_time > NOW()) and end_time in (select MAX(end_time) from ip4log where mac = i.mac) ORDER BY mac,end_time desc", dhcpHandler.start.String(), ipend.String()) if err != nil { log.LoggerWContext(ctx).Error(err.Error()) return @@ -121,7 +122,7 @@ func InterfaceScopeFromMac(MAC string) string { } // Detect the vip on each interfaces -func (d *Interfaces) detectVIP(interfaces []string) { +func (d *Interfaces) detectVIP(interfaces []string, db *sql.DB) { var keyConfCluster pfconfigdriver.NetInterface keyConfCluster.PfconfigNS = "config::Pf(CLUSTER," + pfconfigdriver.FindClusterName(ctx) + ")" @@ -151,7 +152,7 @@ func (d *Interfaces) detectVIP(interfaces []string) { if VIP[v] == false { log.LoggerWContext(ctx).Info(v + " got the VIP") if h, ok := intNametoInterface[v]; ok { - go h.handleAPIReq(APIReq{Req: "initialease", NetInterface: v, NetWork: ""}) + go h.handleAPIReq(APIReq{Req: "initialease", NetInterface: v, NetWork: ""}, db) } VIP[v] = true } @@ -164,9 +165,9 @@ func (d *Interfaces) detectVIP(interfaces []string) { } // NodeInformation return the node information -func NodeInformation(ctx context.Context, target net.HardwareAddr) (r NodeInfo) { +func NodeInformation(ctx context.Context, target net.HardwareAddr, db *sql.DB) (r NodeInfo) { - rows, err := MySQLdatabase.Query("SELECT mac, status, IF(ISNULL(nc.name), '', nc.name) as category FROM node LEFT JOIN node_category as nc on node.category_id = nc.category_id WHERE mac = ?", target.String()) + rows, err := db.Query("SELECT mac, status, IF(ISNULL(nc.name), '', nc.name) as category FROM node LEFT JOIN node_category as nc on node.category_id = nc.category_id WHERE mac = ?", target.String()) defer rows.Close() if err != nil { @@ -361,8 +362,8 @@ func AssignIP(dhcpHandler *DHCPHandler, ipRange string) (map[string]uint32, []ne } // AddDevicesOptions function add options on the fly -func AddDevicesOptions(object string, leaseDuration *time.Duration, GlobalOptions map[dhcp.OptionCode][]byte) { - x, err := decodeOptions(object) +func AddDevicesOptions(object string, leaseDuration *time.Duration, GlobalOptions map[dhcp.OptionCode][]byte, db *sql.DB) { + x, err := decodeOptions(object, db) if err == nil { for key, value := range x { if key == dhcp.OptionIPAddressLeaseTime { @@ -433,30 +434,30 @@ func IsIPv6(address net.IP) bool { } // MysqlUpdateIP4Log update the ip4log table -func MysqlUpdateIP4Log(mac string, ip string, duration time.Duration) error { - if err := MySQLdatabase.PingContext(ctx); err != nil { +func MysqlUpdateIP4Log(mac string, ip string, duration time.Duration, db *sql.DB) error { + if err := db.PingContext(ctx); err != nil { log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error()) } - MAC2IP, err := MySQLdatabase.Prepare("SELECT ip FROM ip4log WHERE mac = ? AND (end_time = \"" + ZeroDate + "\" OR ( end_time + INTERVAL 30 SECOND ) > NOW()) ORDER BY start_time DESC LIMIT 1") + MAC2IP, err := db.Prepare("SELECT ip FROM ip4log WHERE mac = ? AND (end_time = \"" + ZeroDate + "\" OR ( end_time + INTERVAL 30 SECOND ) > NOW()) ORDER BY start_time DESC LIMIT 1") if err != nil { return err } defer MAC2IP.Close() - IP2MAC, err := MySQLdatabase.Prepare("SELECT mac FROM ip4log WHERE ip = ? AND (end_time = \"" + ZeroDate + "\" OR end_time > NOW()) ORDER BY start_time DESC") + IP2MAC, err := db.Prepare("SELECT mac FROM ip4log WHERE ip = ? AND (end_time = \"" + ZeroDate + "\" OR end_time > NOW()) ORDER BY start_time DESC") if err != nil { return err } defer IP2MAC.Close() - IPClose, err := MySQLdatabase.Prepare(" UPDATE ip4log SET end_time = NOW() WHERE ip = ?") + IPClose, err := db.Prepare(" UPDATE ip4log SET end_time = NOW() WHERE ip = ?") if err != nil { return err } - defer IP2MAC.Close() + defer IPClose.Close() - IPInsert, err := MySQLdatabase.Prepare("INSERT INTO ip4log (mac, ip, start_time, end_time) VALUES (?, ?, NOW(), DATE_ADD(NOW(), INTERVAL ? SECOND)) ON DUPLICATE KEY UPDATE mac=VALUES(mac), start_time=NOW(), end_time=VALUES(end_time)") + IPInsert, err := db.Prepare("INSERT INTO ip4log (mac, ip, start_time, end_time) VALUES (?, ?, NOW(), DATE_ADD(NOW(), INTERVAL ? SECOND)) ON DUPLICATE KEY UPDATE mac=VALUES(mac), start_time=NOW(), end_time=VALUES(end_time)") if err != nil { return err } diff --git a/go/cmd/pfdhcp/workers_pool.go b/go/cmd/pfdhcp/workers_pool.go index 11d5274f6b06..3e8075c54484 100644 --- a/go/cmd/pfdhcp/workers_pool.go +++ b/go/cmd/pfdhcp/workers_pool.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" _ "expvar" "net" "strconv" @@ -18,11 +19,12 @@ type job struct { clientAddr net.Addr //remote client ip srvAddr net.IP localCtx context.Context + db *sql.DB } func doWork(id int, element job) { var ans Answer - if ans = element.handler.ServeDHCP(element.localCtx, element.DHCPpacket, element.msgType, element.clientAddr, element.srvAddr); ans.D != nil { + if ans = element.handler.ServeDHCP(element.localCtx, element.DHCPpacket, element.msgType, element.clientAddr, element.srvAddr, element.db); ans.D != nil { ipStr, portStr, _ := net.SplitHostPort(element.clientAddr.String()) ctx = log.AddToLogContext(ctx, "mac", ans.MAC.String()) log.LoggerWContext(ctx).Debug("Giaddr " + element.DHCPpacket.GIAddr().String())