Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve DB connection in pfdhcp #8419

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions go/cmd/pfdhcp/api.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package main

import (
"database/sql"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
Expand All @@ -21,6 +21,10 @@ import (
"github.com/inverse-inc/packetfence/go/pfconfigdriver"
)

type API struct {
DB *sql.DB
}

// Node struct
type Node struct {
Mac string `json:"mac"`
Expand Down Expand Up @@ -82,7 +86,7 @@ type OptionsFromFilter struct {
Type string `json:"type"`
}

func handleIP2Mac(res http.ResponseWriter, req *http.Request) {
func (a *API) handleIP2Mac(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)

if index, expiresAt, found := GlobalIPCache.GetWithExpiration(vars["ip"]); found {
Expand All @@ -102,7 +106,7 @@ func handleIP2Mac(res http.ResponseWriter, req *http.Request) {
return
}

func handleMac2Ip(res http.ResponseWriter, req *http.Request) {
func (a *API) handleMac2Ip(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)

if index, expiresAt, found := GlobalMacCache.GetWithExpiration(vars["mac"]); found {
Expand All @@ -122,7 +126,7 @@ func handleMac2Ip(res http.ResponseWriter, req *http.Request) {
return
}

func handleAllStats(res http.ResponseWriter, req *http.Request) {
func (a *API) handleAllStats(res http.ResponseWriter, req *http.Request) {
var result Items
var interfaces pfconfigdriver.ListenInts
pfconfigdriver.FetchDecodeSocket(ctx, &interfaces)
Expand All @@ -132,7 +136,7 @@ func handleAllStats(res http.ResponseWriter, req *http.Request) {
}
for _, i := range interfaces.Element {
if h, ok := intNametoInterface[i]; ok {
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: i, NetWork: ""})
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: i, NetWork: ""}, a.DB)
for _, s := range stat.([]Stats) {
result.Items = append(result.Items, s)
}
Expand All @@ -151,11 +155,11 @@ func handleAllStats(res http.ResponseWriter, req *http.Request) {
return
}

func handleStats(res http.ResponseWriter, req *http.Request) {
func (a *API) handleStats(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)

if h, ok := intNametoInterface[vars["int"]]; ok {
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]})
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)

outgoingJSON, err := json.Marshal(stat)

Expand All @@ -172,11 +176,11 @@ func handleStats(res http.ResponseWriter, req *http.Request) {
return
}

func handleDuplicates(res http.ResponseWriter, req *http.Request) {
func (a *API) handleDuplicates(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)

if h, ok := intNametoInterface[vars["int"]]; ok {
stat := h.handleAPIReq(APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]})
stat := h.handleAPIReq(APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)

outgoingJSON, err := json.Marshal(stat)

Expand All @@ -193,11 +197,11 @@ func handleDuplicates(res http.ResponseWriter, req *http.Request) {
return
}

func handleDebug(res http.ResponseWriter, req *http.Request) {
func (a *API) handleDebug(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)

if h, ok := intNametoInterface[vars["int"]]; ok {
stat := h.handleAPIReq(APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]})
stat := h.handleAPIReq(APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]}, a.DB)

outgoingJSON, err := json.Marshal(stat)

Expand All @@ -213,7 +217,7 @@ func handleDebug(res http.ResponseWriter, req *http.Request) {
return
}

func handleReleaseIP(res http.ResponseWriter, req *http.Request) {
func (a *API) handleReleaseIP(res http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
_ = InterfaceScopeFromMac(vars["mac"])

Expand All @@ -226,11 +230,11 @@ func handleReleaseIP(res http.ResponseWriter, req *http.Request) {
}
}

func handleOverrideOptions(res http.ResponseWriter, req *http.Request) {
func (a *API) handleOverrideOptions(res http.ResponseWriter, req *http.Request) {

vars := mux.Vars(req)

body, err := ioutil.ReadAll(io.LimitReader(req.Body, 1048576))
body, err := io.ReadAll(io.LimitReader(req.Body, 1048576))
if err != nil {
panic(err)
}
Expand All @@ -239,7 +243,7 @@ func handleOverrideOptions(res http.ResponseWriter, req *http.Request) {
}

// Insert information in MySQL
_ = MysqlInsert(vars["mac"], sharedutils.ConvertToString(body))
_ = MysqlInsert(vars["mac"], sharedutils.ConvertToString(body), a.DB)

var result = &Info{Mac: vars["mac"], Status: "ACK"}

Expand All @@ -250,11 +254,11 @@ func handleOverrideOptions(res http.ResponseWriter, req *http.Request) {
}
}

func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) {
func (a *API) handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) {

vars := mux.Vars(req)

body, err := ioutil.ReadAll(io.LimitReader(req.Body, 1048576))
body, err := io.ReadAll(io.LimitReader(req.Body, 1048576))
if err != nil {
panic(err)
}
Expand All @@ -263,7 +267,7 @@ func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) {
}

// Insert information in MySQL
_ = MysqlInsert(vars["network"], sharedutils.ConvertToString(body))
_ = MysqlInsert(vars["network"], sharedutils.ConvertToString(body), a.DB)

var result = &Info{Network: vars["network"], Status: "ACK"}

Expand All @@ -274,13 +278,13 @@ func handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Request) {
}
}

func handleRemoveOptions(res http.ResponseWriter, req *http.Request) {
func (a *API) handleRemoveOptions(res http.ResponseWriter, req *http.Request) {

vars := mux.Vars(req)

var result = &Info{Mac: vars["mac"], Status: "ACK"}

err := MysqlDel(vars["mac"])
err := MysqlDel(vars["mac"], a.DB)
if !err {
result = &Info{Mac: vars["mac"], Status: "NAK"}
}
Expand All @@ -291,13 +295,13 @@ func handleRemoveOptions(res http.ResponseWriter, req *http.Request) {
}
}

func handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) {
func (a *API) handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) {

vars := mux.Vars(req)

var result = &Info{Network: vars["network"], Status: "ACK"}

err := MysqlDel(vars["network"])
err := MysqlDel(vars["network"], a.DB)
if !err {
result = &Info{Network: vars["network"], Status: "NAK"}
}
Expand All @@ -308,9 +312,9 @@ func handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Request) {
}
}

func decodeOptions(b string) (map[dhcp.OptionCode][]byte, error) {
func decodeOptions(b string, db *sql.DB) (map[dhcp.OptionCode][]byte, error) {
var options []Options
_, value := MysqlGet(b)
_, value := MysqlGet(b, db)
decodedValue := sharedutils.ConvertToByte(value)
var dhcpOptions = make(map[dhcp.OptionCode][]byte)
if err := json.Unmarshal(decodedValue, &options); err != nil {
Expand Down Expand Up @@ -360,7 +364,7 @@ func extractMembers(v Network) ([]Node, []string, int) {
return Members, Macs, Count
}

func (h *Interface) handleAPIReq(Request APIReq) interface{} {
func (h *Interface) handleAPIReq(Request APIReq, db *sql.DB) interface{} {
var stats []Stats

if Request.Req == "duplicates" {
Expand Down Expand Up @@ -399,7 +403,7 @@ func (h *Interface) handleAPIReq(Request APIReq) interface{} {
}

// Add network options on the fly
x, err := decodeOptions(v.network.IP.String())
x, err := decodeOptions(v.network.IP.String(), db)
if err == nil {
for key, value := range x {
Options[key.String()] = Tlv.Tlvlist[int(key)].Transform.String(value)
Expand Down
11 changes: 6 additions & 5 deletions go/cmd/pfdhcp/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"database/sql"
"encoding/binary"
"math"
"net"
Expand Down Expand Up @@ -68,7 +69,7 @@ func newDHCPConfig() *Interfaces {
return &p
}

func (d *Interfaces) readConfig() {
func (d *Interfaces) readConfig(MyDB *sql.DB) {
interfaces := pfconfigdriver.GetType[pfconfigdriver.ListenInts](ctx)
DHCPinterfaces := pfconfigdriver.GetType[pfconfigdriver.DHCPInts](ctx)
portal := pfconfigdriver.GetType[pfconfigdriver.PfConfCaptivePortal](ctx)
Expand Down Expand Up @@ -246,7 +247,7 @@ func (d *Interfaces) readConfig() {
}
DHCPScope.dstIp = dstReplyIp
// Initialize dhcp pool
available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(ip, ips)), DHCPNet.network.IP.String()+Role, algorithm, StatsdClient, MySQLdatabase)
available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(ip, ips)), DHCPNet.network.IP.String()+Role, algorithm, StatsdClient, MyDB)

DHCPScope.available = available

Expand All @@ -267,7 +268,7 @@ func (d *Interfaces) readConfig() {
DHCPScope.xid = xid
wg.Add(1)
go func() {
initiaLease(DHCPScope, ConfNet)
initiaLease(DHCPScope, ConfNet, MyDB)
wg.Done()
}()
var options = make(map[dhcp.OptionCode][]byte)
Expand Down Expand Up @@ -326,7 +327,7 @@ func (d *Interfaces) readConfig() {
}
DHCPScope.dstIp = dstReplyIp
// Initialize dhcp pool
available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(net.ParseIP(ConfNet.DhcpStart), net.ParseIP(ConfNet.DhcpEnd))), DHCPNet.network.IP.String(), algorithm, StatsdClient, MySQLdatabase)
available, _ := pool.Create(ctx, backend, uint64(dhcp.IPRange(net.ParseIP(ConfNet.DhcpStart), net.ParseIP(ConfNet.DhcpEnd))), DHCPNet.network.IP.String(), algorithm, StatsdClient, MyDB)

DHCPScope.available = available

Expand All @@ -347,7 +348,7 @@ func (d *Interfaces) readConfig() {
DHCPScope.xid = xid
wg.Add(1)
go func() {
initiaLease(DHCPScope, ConfNet)
initiaLease(DHCPScope, ConfNet, MyDB)
wg.Done()
}()

Expand Down
20 changes: 11 additions & 9 deletions go/cmd/pfdhcp/keysoption.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package main

import (
"database/sql"

"github.com/inverse-inc/go-utils/log"
)

// MysqlInsert function
func MysqlInsert(key string, value string) bool {
if err := MySQLdatabase.PingContext(ctx); err != nil {
func MysqlInsert(key string, value string, db *sql.DB) bool {
if err := db.PingContext(ctx); err != nil {
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
}

_, err := MySQLdatabase.Exec(
_, err := db.Exec(
`
INSERT into key_value_storage values(?,?)
ON DUPLICATE KEY UPDATE value = VALUES(value)
Expand All @@ -28,11 +30,11 @@ ON DUPLICATE KEY UPDATE value = VALUES(value)
}

// MysqlGet function
func MysqlGet(key string) (string, string) {
if err := MySQLdatabase.PingContext(ctx); err != nil {
func MysqlGet(key string, db *sql.DB) (string, string) {
if err := db.PingContext(ctx); err != nil {
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
}
rows, err := MySQLdatabase.Query("select id, value from key_value_storage where id = ?", "/dhcpd/"+key)
rows, err := db.Query("select id, value from key_value_storage where id = ?", "/dhcpd/"+key)
defer rows.Close()
if err != nil {
log.LoggerWContext(ctx).Debug("Error while getting MySQL '" + key + "': " + err.Error())
Expand All @@ -52,11 +54,11 @@ func MysqlGet(key string) (string, string) {
}

// MysqlDel function
func MysqlDel(key string) bool {
if err := MySQLdatabase.PingContext(ctx); err != nil {
func MysqlDel(key string, db *sql.DB) bool {
if err := db.PingContext(ctx); err != nil {
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
}
rows, err := MySQLdatabase.Query("delete from key_value_storage where id = ?", "/dhcpd/"+key)
rows, err := db.Query("delete from key_value_storage where id = ?", "/dhcpd/"+key)
defer rows.Close()
if err != nil {
log.LoggerWContext(ctx).Error("Error while deleting MySQL key '" + key + "': " + err.Error())
Expand Down
Loading
Loading