diff --git a/cmd/XDC/config.go b/cmd/XDC/config.go index a47f19c7bf20..26908a0b2c25 100644 --- a/cmd/XDC/config.go +++ b/cmd/XDC/config.go @@ -120,8 +120,8 @@ func defaultNodeConfig() node.Config { cfg := node.DefaultConfig cfg.Name = clientIdentifier cfg.Version = params.VersionWithCommit(gitCommit) - cfg.HTTPModules = append(cfg.HTTPModules, "eth", "shh") - cfg.WSModules = append(cfg.WSModules, "eth", "shh") + cfg.HTTPModules = append(cfg.HTTPModules, "eth") + cfg.WSModules = append(cfg.WSModules, "eth") cfg.IPCPath = "XDC.ipc" return cfg } diff --git a/cmd/XDC/monitorcmd.go b/cmd/XDC/monitorcmd.go deleted file mode 100644 index 42decc574ff5..000000000000 --- a/cmd/XDC/monitorcmd.go +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . - -package main - -import ( - "fmt" - "math" - "reflect" - "runtime" - "sort" - "strings" - "time" - - "github.com/XinFinOrg/XDPoSChain/cmd/utils" - "github.com/XinFinOrg/XDPoSChain/node" - "github.com/XinFinOrg/XDPoSChain/rpc" - "github.com/gizak/termui" - "gopkg.in/urfave/cli.v1" -) - -var ( - monitorCommandAttachFlag = cli.StringFlag{ - Name: "attach", - Value: node.DefaultIPCEndpoint(clientIdentifier), - Usage: "API endpoint to attach to", - } - monitorCommandRowsFlag = cli.IntFlag{ - Name: "rows", - Value: 5, - Usage: "Maximum rows in the chart grid", - } - monitorCommandRefreshFlag = cli.IntFlag{ - Name: "refresh", - Value: 3, - Usage: "Refresh interval in seconds", - } - monitorCommand = cli.Command{ - Action: utils.MigrateFlags(monitor), // keep track of migration progress - Name: "monitor", - Usage: "Monitor and visualize node metrics", - ArgsUsage: " ", - Category: "MONITOR COMMANDS", - Description: ` -The XDC monitor is a tool to collect and visualize various internal metrics -gathered by the node, supporting different chart types as well as the capacity -to display multiple metrics simultaneously. -`, - Flags: []cli.Flag{ - monitorCommandAttachFlag, - monitorCommandRowsFlag, - monitorCommandRefreshFlag, - }, - } -) - -// monitor starts a terminal UI based monitoring tool for the requested metrics. -func monitor(ctx *cli.Context) error { - var ( - client *rpc.Client - err error - ) - // Attach to an Ethereum node over IPC or RPC - endpoint := ctx.String(monitorCommandAttachFlag.Name) - if client, err = dialRPC(endpoint); err != nil { - utils.Fatalf("Unable to attach to XDC node: %v", err) - } - defer client.Close() - - // Retrieve all the available metrics and resolve the user pattens - metrics, err := retrieveMetrics(client) - if err != nil { - utils.Fatalf("Failed to retrieve system metrics: %v", err) - } - monitored := resolveMetrics(metrics, ctx.Args()) - if len(monitored) == 0 { - list := expandMetrics(metrics, "") - sort.Strings(list) - - if len(list) > 0 { - utils.Fatalf("No metrics specified.\n\nAvailable:\n - %s", strings.Join(list, "\n - ")) - } else { - utils.Fatalf("No metrics collected by XDC (--%s).\n", utils.MetricsEnabledFlag.Name) - } - } - sort.Strings(monitored) - if cols := len(monitored) / ctx.Int(monitorCommandRowsFlag.Name); cols > 6 { - utils.Fatalf("Requested metrics (%d) spans more that 6 columns:\n - %s", len(monitored), strings.Join(monitored, "\n - ")) - } - // Create and configure the chart UI defaults - if err := termui.Init(); err != nil { - utils.Fatalf("Unable to initialize terminal UI: %v", err) - } - defer termui.Close() - - rows := len(monitored) - if max := ctx.Int(monitorCommandRowsFlag.Name); rows > max { - rows = max - } - cols := (len(monitored) + rows - 1) / rows - for i := 0; i < rows; i++ { - termui.Body.AddRows(termui.NewRow()) - } - // Create each individual data chart - footer := termui.NewPar("") - footer.Block.Border = true - footer.Height = 3 - - charts := make([]*termui.LineChart, len(monitored)) - units := make([]int, len(monitored)) - data := make([][]float64, len(monitored)) - for i := 0; i < len(monitored); i++ { - charts[i] = createChart((termui.TermHeight() - footer.Height) / rows) - row := termui.Body.Rows[i%rows] - row.Cols = append(row.Cols, termui.NewCol(12/cols, 0, charts[i])) - } - termui.Body.AddRows(termui.NewRow(termui.NewCol(12, 0, footer))) - - refreshCharts(client, monitored, data, units, charts, ctx, footer) - termui.Body.Align() - termui.Render(termui.Body) - - // Watch for various system events, and periodically refresh the charts - termui.Handle("/sys/kbd/C-c", func(termui.Event) { - termui.StopLoop() - }) - termui.Handle("/sys/wnd/resize", func(termui.Event) { - termui.Body.Width = termui.TermWidth() - for _, chart := range charts { - chart.Height = (termui.TermHeight() - footer.Height) / rows - } - termui.Body.Align() - termui.Render(termui.Body) - }) - go func() { - tick := time.NewTicker(time.Duration(ctx.Int(monitorCommandRefreshFlag.Name)) * time.Second) - for range tick.C { - if refreshCharts(client, monitored, data, units, charts, ctx, footer) { - termui.Body.Align() - } - termui.Render(termui.Body) - } - }() - termui.Loop() - return nil -} - -// retrieveMetrics contacts the attached XDC node and retrieves the entire set -// of collected system metrics. -func retrieveMetrics(client *rpc.Client) (map[string]interface{}, error) { - var metrics map[string]interface{} - err := client.Call(&metrics, "debug_metrics", true) - return metrics, err -} - -// resolveMetrics takes a list of input metric patterns, and resolves each to one -// or more canonical metric names. -func resolveMetrics(metrics map[string]interface{}, patterns []string) []string { - res := []string{} - for _, pattern := range patterns { - res = append(res, resolveMetric(metrics, pattern, "")...) - } - return res -} - -// resolveMetrics takes a single of input metric pattern, and resolves it to one -// or more canonical metric names. -func resolveMetric(metrics map[string]interface{}, pattern string, path string) []string { - results := []string{} - - // If a nested metric was requested, recurse optionally branching (via comma) - parts := strings.SplitN(pattern, "/", 2) - if len(parts) > 1 { - for _, variation := range strings.Split(parts[0], ",") { - if submetrics, ok := metrics[variation].(map[string]interface{}); !ok { - utils.Fatalf("Failed to retrieve system metrics: %s", path+variation) - return nil - } else { - results = append(results, resolveMetric(submetrics, parts[1], path+variation+"/")...) - } - } - return results - } - // Depending what the last link is, return or expand - for _, variation := range strings.Split(pattern, ",") { - switch metric := metrics[variation].(type) { - case float64: - // Final metric value found, return as singleton - results = append(results, path+variation) - - case map[string]interface{}: - results = append(results, expandMetrics(metric, path+variation+"/")...) - - default: - utils.Fatalf("Metric pattern resolved to unexpected type: %v", reflect.TypeOf(metric)) - return nil - } - } - return results -} - -// expandMetrics expands the entire tree of metrics into a flat list of paths. -func expandMetrics(metrics map[string]interface{}, path string) []string { - // Iterate over all fields and expand individually - list := []string{} - for name, metric := range metrics { - switch metric := metric.(type) { - case float64: - // Final metric value found, append to list - list = append(list, path+name) - - case map[string]interface{}: - // Tree of metrics found, expand recursively - list = append(list, expandMetrics(metric, path+name+"/")...) - - default: - utils.Fatalf("Metric pattern %s resolved to unexpected type: %v", path+name, reflect.TypeOf(metric)) - return nil - } - } - return list -} - -// fetchMetric iterates over the metrics map and retrieves a specific one. -func fetchMetric(metrics map[string]interface{}, metric string) float64 { - parts := strings.Split(metric, "/") - for _, part := range parts[:len(parts)-1] { - var found bool - metrics, found = metrics[part].(map[string]interface{}) - if !found { - return 0 - } - } - if v, ok := metrics[parts[len(parts)-1]].(float64); ok { - return v - } - return 0 -} - -// refreshCharts retrieves a next batch of metrics, and inserts all the new -// values into the active datasets and charts -func refreshCharts(client *rpc.Client, metrics []string, data [][]float64, units []int, charts []*termui.LineChart, ctx *cli.Context, footer *termui.Par) (realign bool) { - values, err := retrieveMetrics(client) - for i, metric := range metrics { - if len(data) < 512 { - data[i] = append([]float64{fetchMetric(values, metric)}, data[i]...) - } else { - data[i] = append([]float64{fetchMetric(values, metric)}, data[i][:len(data[i])-1]...) - } - if updateChart(metric, data[i], &units[i], charts[i], err) { - realign = true - } - } - updateFooter(ctx, err, footer) - return -} - -// updateChart inserts a dataset into a line chart, scaling appropriately as to -// not display weird labels, also updating the chart label accordingly. -func updateChart(metric string, data []float64, base *int, chart *termui.LineChart, err error) (realign bool) { - dataUnits := []string{"", "K", "M", "G", "T", "E"} - timeUnits := []string{"ns", "µs", "ms", "s", "ks", "ms"} - colors := []termui.Attribute{termui.ColorBlue, termui.ColorCyan, termui.ColorGreen, termui.ColorYellow, termui.ColorRed, termui.ColorRed} - - // Extract only part of the data that's actually visible - if chart.Width*2 < len(data) { - data = data[:chart.Width*2] - } - // Find the maximum value and scale under 1K - high := 0.0 - if len(data) > 0 { - high = data[0] - for _, value := range data[1:] { - high = math.Max(high, value) - } - } - unit, scale := 0, 1.0 - for high >= 1000 && unit+1 < len(dataUnits) { - high, unit, scale = high/1000, unit+1, scale*1000 - } - // If the unit changes, re-create the chart (hack to set max height...) - if unit != *base { - realign, *base, *chart = true, unit, *createChart(chart.Height) - } - // Update the chart's data points with the scaled values - if cap(chart.Data) < len(data) { - chart.Data = make([]float64, len(data)) - } - chart.Data = chart.Data[:len(data)] - for i, value := range data { - chart.Data[i] = value / scale - } - // Update the chart's label with the scale units - units := dataUnits - if strings.Contains(metric, "/Percentiles/") || strings.Contains(metric, "/pauses/") || strings.Contains(metric, "/time/") { - units = timeUnits - } - chart.BorderLabel = metric - if len(units[unit]) > 0 { - chart.BorderLabel += " [" + units[unit] + "]" - } - chart.LineColor = colors[unit] | termui.AttrBold - if err != nil { - chart.LineColor = termui.ColorRed | termui.AttrBold - } - return -} - -// createChart creates an empty line chart with the default configs. -func createChart(height int) *termui.LineChart { - chart := termui.NewLineChart() - if runtime.GOOS == "windows" { - chart.Mode = "dot" - } - chart.DataLabels = []string{""} - chart.Height = height - chart.AxesColor = termui.ColorWhite - chart.PaddingBottom = -2 - - chart.BorderLabelFg = chart.BorderFg | termui.AttrBold - chart.BorderFg = chart.BorderBg - - return chart -} - -// updateFooter updates the footer contents based on any encountered errors. -func updateFooter(ctx *cli.Context, err error, footer *termui.Par) { - // Generate the basic footer - refresh := time.Duration(ctx.Int(monitorCommandRefreshFlag.Name)) * time.Second - footer.Text = fmt.Sprintf("Press Ctrl+C to quit. Refresh interval: %v.", refresh) - footer.TextFgColor = termui.ThemeAttr("par.fg") | termui.AttrBold - - // Append any encountered errors - if err != nil { - footer.Text = fmt.Sprintf("Error: %v.", err) - footer.TextFgColor = termui.ColorRed | termui.AttrBold - } -} diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index dbeadc9c180d..a4e7174859c1 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -768,12 +768,15 @@ func setNAT(ctx *cli.Context, cfg *p2p.Config) { // splitAndTrim splits input separated by a comma // and trims excessive white space from the substrings. -func splitAndTrim(input string) []string { - result := strings.Split(input, ",") - for i, r := range result { - result[i] = strings.TrimSpace(r) +func splitAndTrim(input string) (ret []string) { + l := strings.Split(input, ",") + for _, r := range l { + r = strings.TrimSpace(r) + if len(r) > 0 { + ret = append(ret, r) + } } - return result + return ret } // setHTTP creates the HTTP RPC listener interface string from the set diff --git a/eth/ethconfig/config.go b/eth/ethconfig/config.go index 512e57387620..642b59c4a0c0 100644 --- a/eth/ethconfig/config.go +++ b/eth/ethconfig/config.go @@ -86,8 +86,15 @@ func init() { home = user.HomeDir } } - if runtime.GOOS == "windows" { - Defaults.Ethash.DatasetDir = filepath.Join(home, "AppData", "Ethash") + if runtime.GOOS == "darwin" { + Defaults.Ethash.DatasetDir = filepath.Join(home, "Library", "Ethash") + } else if runtime.GOOS == "windows" { + localappdata := os.Getenv("LOCALAPPDATA") + if localappdata != "" { + Defaults.Ethash.DatasetDir = filepath.Join(localappdata, "Ethash") + } else { + Defaults.Ethash.DatasetDir = filepath.Join(home, "AppData", "Local", "Ethash") + } } else { Defaults.Ethash.DatasetDir = filepath.Join(home, ".ethash") } diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 8aaa3e9d263f..f726b5d6205b 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -300,11 +300,6 @@ web3._extend({ name: 'chaindbCompact', call: 'debug_chaindbCompact', }), - new web3._extend.Method({ - name: 'metrics', - call: 'debug_metrics', - params: 1 - }), new web3._extend.Method({ name: 'verbosity', call: 'debug_verbosity', diff --git a/node/api.go b/node/api.go index a014931efc6a..362880ceaa7b 100644 --- a/node/api.go +++ b/node/api.go @@ -21,11 +21,9 @@ import ( "errors" "fmt" "strings" - "time" "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/crypto" - "github.com/XinFinOrg/XDPoSChain/metrics" "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/p2p/discover" "github.com/XinFinOrg/XDPoSChain/rpc" @@ -189,7 +187,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis } } - if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil { + if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil { return false, err } return true, nil @@ -249,7 +247,7 @@ func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str return true, nil } -// StopRPC terminates an already running websocket RPC API endpoint. +// StopWS terminates an already running websocket RPC API endpoint. func (api *PrivateAdminAPI) StopWS() (bool, error) { api.node.lock.Lock() defer api.node.lock.Unlock() @@ -298,121 +296,6 @@ func (api *PublicAdminAPI) Datadir() string { return api.node.DataDir() } -// PublicDebugAPI is the collection of debugging related API methods exposed over -// both secure and unsecure RPC channels. -type PublicDebugAPI struct { - node *Node // Node interfaced by this API -} - -// NewPublicDebugAPI creates a new API definition for the public debug methods -// of the node itself. -func NewPublicDebugAPI(node *Node) *PublicDebugAPI { - return &PublicDebugAPI{node: node} -} - -// Metrics retrieves all the known system metric collected by the node. -func (api *PublicDebugAPI) Metrics(raw bool) (map[string]interface{}, error) { - // Create a rate formatter - units := []string{"", "K", "M", "G", "T", "E", "P"} - round := func(value float64, prec int) string { - unit := 0 - for value >= 1000 { - unit, value, prec = unit+1, value/1000, 2 - } - return fmt.Sprintf(fmt.Sprintf("%%.%df%s", prec, units[unit]), value) - } - format := func(total float64, rate float64) string { - return fmt.Sprintf("%s (%s/s)", round(total, 0), round(rate, 2)) - } - // Iterate over all the metrics, and just dump for now - counters := make(map[string]interface{}) - metrics.DefaultRegistry.Each(func(name string, metric interface{}) { - // Create or retrieve the counter hierarchy for this metric - root, parts := counters, strings.Split(name, "/") - for _, part := range parts[:len(parts)-1] { - if _, ok := root[part]; !ok { - root[part] = make(map[string]interface{}) - } - root = root[part].(map[string]interface{}) - } - name = parts[len(parts)-1] - - // Fill the counter with the metric details, formatting if requested - if raw { - switch metric := metric.(type) { - case metrics.Counter: - root[name] = map[string]interface{}{ - "Overall": float64(metric.Count()), - } - - case metrics.Meter: - root[name] = map[string]interface{}{ - "AvgRate01Min": metric.Rate1(), - "AvgRate05Min": metric.Rate5(), - "AvgRate15Min": metric.Rate15(), - "MeanRate": metric.RateMean(), - "Overall": float64(metric.Count()), - } - - case metrics.Timer: - root[name] = map[string]interface{}{ - "AvgRate01Min": metric.Rate1(), - "AvgRate05Min": metric.Rate5(), - "AvgRate15Min": metric.Rate15(), - "MeanRate": metric.RateMean(), - "Overall": float64(metric.Count()), - "Percentiles": map[string]interface{}{ - "5": metric.Percentile(0.05), - "20": metric.Percentile(0.2), - "50": metric.Percentile(0.5), - "80": metric.Percentile(0.8), - "95": metric.Percentile(0.95), - }, - } - - default: - root[name] = "Unknown metric type" - } - } else { - switch metric := metric.(type) { - case metrics.Counter: - root[name] = map[string]interface{}{ - "Overall": float64(metric.Count()), - } - - case metrics.Meter: - root[name] = map[string]interface{}{ - "Avg01Min": format(metric.Rate1()*60, metric.Rate1()), - "Avg05Min": format(metric.Rate5()*300, metric.Rate5()), - "Avg15Min": format(metric.Rate15()*900, metric.Rate15()), - "Overall": format(float64(metric.Count()), metric.RateMean()), - } - - case metrics.Timer: - root[name] = map[string]interface{}{ - "Avg01Min": format(metric.Rate1()*60, metric.Rate1()), - "Avg05Min": format(metric.Rate5()*300, metric.Rate5()), - "Avg15Min": format(metric.Rate15()*900, metric.Rate15()), - "Overall": format(float64(metric.Count()), metric.RateMean()), - "Maximum": time.Duration(metric.Max()).String(), - "Minimum": time.Duration(metric.Min()).String(), - "Percentiles": map[string]interface{}{ - "5": time.Duration(metric.Percentile(0.05)).String(), - "20": time.Duration(metric.Percentile(0.2)).String(), - "50": time.Duration(metric.Percentile(0.5)).String(), - "80": time.Duration(metric.Percentile(0.8)).String(), - "95": time.Duration(metric.Percentile(0.95)).String(), - }, - } - - default: - root[name] = "Unknown metric type" - } - } - }) - return counters, nil -} - // PublicWeb3API offers helper utils type PublicWeb3API struct { stack *Node diff --git a/node/config.go b/node/config.go index 1984120f9ff3..355f316a19c6 100644 --- a/node/config.go +++ b/node/config.go @@ -89,11 +89,11 @@ type Config struct { // a simple file name, it is placed inside the data directory (or on the root // pipe path on Windows), whereas if it's a resolvable path name (absolute or // relative), then that specific path is enforced. An empty path disables IPC. - IPCPath string `toml:",omitempty"` + IPCPath string // HTTPHost is the host interface on which to start the HTTP RPC server. If this // field is empty, no HTTP API endpoint will be started. - HTTPHost string `toml:",omitempty"` + HTTPHost string // HTTPPort is the TCP port number on which to start the HTTP RPC server. The // default zero value is/ valid and will pick a port number randomly (useful @@ -117,7 +117,7 @@ type Config struct { // HTTPModules is a list of API modules to expose via the HTTP RPC interface. // If the module list is empty, all RPC API endpoints designated public will be // exposed. - HTTPModules []string `toml:",omitempty"` + HTTPModules []string // HTTPTimeouts allows for customization of the timeout values used by the HTTP RPC // interface. @@ -125,7 +125,7 @@ type Config struct { // WSHost is the host interface on which to start the websocket RPC server. If // this field is empty, no websocket API endpoint will be started. - WSHost string `toml:",omitempty"` + WSHost string // WSPort is the TCP port number on which to start the websocket RPC server. The // default zero value is/ valid and will pick a port number randomly (useful for @@ -140,7 +140,7 @@ type Config struct { // WSModules is a list of API modules to expose via the websocket RPC interface. // If the module list is empty, all RPC API endpoints designated public will be // exposed. - WSModules []string `toml:",omitempty"` + WSModules []string // WSExposeAll exposes all API modules via the WebSocket RPC interface rather // than just the public ones. @@ -215,7 +215,7 @@ func DefaultHTTPEndpoint() string { return config.HTTPEndpoint() } -// WSEndpoint resolves an websocket endpoint based on the configured host interface +// WSEndpoint resolves a websocket endpoint based on the configured host interface // and port parameters. func (c *Config) WSEndpoint() string { if c.WSHost == "" { diff --git a/node/defaults.go b/node/defaults.go index 3f36a50d8f72..8baa8b6e3935 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -56,11 +56,20 @@ func DefaultDataDir() string { // Try to place the data folder in the user's home dir home := homeDir() if home != "" { - if runtime.GOOS == "darwin" { + switch runtime.GOOS { + case "darwin": return filepath.Join(home, "Library", "XDCchain") - } else if runtime.GOOS == "windows" { - return filepath.Join(home, "AppData", "Roaming", "XDCchain") - } else { + case "windows": + // We used to put everything in %HOME%\AppData\Roaming, but this caused + // problems with non-typical setups. If this fallback location exists and + // is non-empty, use it, otherwise DTRT and check %LOCALAPPDATA%. + fallback := filepath.Join(home, "AppData", "Roaming", "XDCchain") + appdata := windowsAppData() + if appdata == "" || isNonEmptyDir(fallback) { + return fallback + } + return filepath.Join(appdata, "XDCchain") + default: return filepath.Join(home, ".XDC") } } @@ -68,6 +77,27 @@ func DefaultDataDir() string { return "" } +func windowsAppData() string { + v := os.Getenv("LOCALAPPDATA") + if v == "" { + // Windows XP and below don't have LocalAppData. Crash here because + // we don't support Windows XP and undefining the variable will cause + // other issues. + panic("environment variable LocalAppData is undefined") + } + return v +} + +func isNonEmptyDir(dir string) bool { + f, err := os.Open(dir) + if err != nil { + return false + } + names, _ := f.Readdir(1) + f.Close() + return len(names) > 0 +} + func homeDir() string { if home := os.Getenv("HOME"); home != "" { return home diff --git a/node/doc.go b/node/doc.go index d9688e0a12cb..e3cc58e5f49c 100644 --- a/node/doc.go +++ b/node/doc.go @@ -59,7 +59,7 @@ using the same data directory will store this information in different subdirect the data directory. LevelDB databases are also stored within the instance subdirectory. If multiple node -instances use the same data directory, openening the databases with identical names will +instances use the same data directory, opening the databases with identical names will create one database for each instance. The account key store is shared among all node instances using the same data directory @@ -69,7 +69,7 @@ unless its location is changed through the KeyStoreDir configuration option. Data Directory Sharing Example In this example, two node instances named A and B are started with the same data -directory. Mode instance A opens the database "db", node instance B opens the databases +directory. Node instance A opens the database "db", node instance B opens the databases "db" and "db-2". The following files will be created in the data directory: data-directory/ @@ -84,7 +84,7 @@ directory. Mode instance A opens the database "db", node instance B opens the da static-nodes.json -- devp2p static node list of instance B db/ -- LevelDB content for "db" db-2/ -- LevelDB content for "db-2" - B.ipc -- JSON-RPC UNIX domain socket endpoint of instance A + B.ipc -- JSON-RPC UNIX domain socket endpoint of instance B keystore/ -- account key store, used by both instances */ package node diff --git a/node/endpoints.go b/node/endpoints.go new file mode 100644 index 000000000000..234b49e73d71 --- /dev/null +++ b/node/endpoints.go @@ -0,0 +1,98 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "net" + "net/http" + "time" + + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/rpc" +) + +// StartHTTPEndpoint starts the HTTP RPC endpoint. +func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (*http.Server, net.Addr, error) { + // start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, nil, err + } + // Bundle and start the HTTP server + httpSrv := &http.Server{ + Handler: handler, + ReadTimeout: timeouts.ReadTimeout, + WriteTimeout: timeouts.WriteTimeout, + IdleTimeout: timeouts.IdleTimeout, + } + log.Info("StartHTTPEndpoint", "ReadTimeout", timeouts.ReadTimeout, "WriteTimeout", timeouts.WriteTimeout, "IdleTimeout", timeouts.IdleTimeout) + go httpSrv.Serve(listener) + return httpSrv, listener.Addr(), err +} + +// startWSEndpoint starts a websocket endpoint. +func startWSEndpoint(endpoint string, handler http.Handler) (*http.Server, net.Addr, error) { + // start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, nil, err + } + wsSrv := &http.Server{Handler: handler} + go wsSrv.Serve(listener) + return wsSrv, listener.Addr(), err +} + +// checkModuleAvailability checks that all names given in modules are actually +// available API services. It assumes that the MetadataApi module ("rpc") is always available; +// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints. +func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) { + availableSet := make(map[string]struct{}) + for _, api := range apis { + if _, ok := availableSet[api.Namespace]; !ok { + availableSet[api.Namespace] = struct{}{} + available = append(available, api.Namespace) + } + } + for _, name := range modules { + if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi { + bad = append(bad, name) + } + } + return bad, available +} + +// CheckTimeouts ensures that timeout values are meaningful +func CheckTimeouts(timeouts *rpc.HTTPTimeouts) { + if timeouts.ReadTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout) + timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout + } + if timeouts.WriteTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout) + timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout + } + if timeouts.IdleTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout) + timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout + } +} diff --git a/node/node.go b/node/node.go index c5adec4451bd..dfa261640c18 100644 --- a/node/node.go +++ b/node/node.go @@ -17,9 +17,11 @@ package node import ( + "context" "errors" "fmt" "net" + "net/http" "os" "path/filepath" "reflect" @@ -60,14 +62,16 @@ type Node struct { ipcListener net.Listener // IPC RPC listener socket to serve API requests ipcHandler *rpc.Server // IPC RPC request handler to process the API requests - httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) - httpWhitelist []string // HTTP RPC modules to allow through this endpoint - httpListener net.Listener // HTTP RPC listener socket to server API requests - httpHandler *rpc.Server // HTTP RPC request handler to process the API requests + httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) + httpWhitelist []string // HTTP RPC modules to allow through this endpoint + httpListenerAddr net.Addr // Address of HTTP RPC listener socket serving API requests + httpServer *http.Server // HTTP RPC HTTP server + httpHandler *rpc.Server // HTTP RPC request handler to process the API requests - wsEndpoint string // Websocket endpoint (interface + port) to listen at (empty = websocket disabled) - wsListener net.Listener // Websocket RPC listener socket to server API requests - wsHandler *rpc.Server // Websocket RPC request handler to process the API requests + wsEndpoint string // WebSocket endpoint (interface + port) to listen at (empty = WebSocket disabled) + wsListenerAddr net.Addr // Address of WebSocket RPC listener socket serving API requests + wsHTTPServer *http.Server // WebSocket RPC HTTP server + wsHandler *rpc.Server // WebSocket RPC request handler to process the API requests stop chan struct{} // Channel to wait for termination notifications lock sync.RWMutex @@ -269,17 +273,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error { n.stopInProc() return err } - if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil { + if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil { n.stopIPC() n.stopInProc() return err } - if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { - n.stopHTTP() - n.stopIPC() - n.stopInProc() - return err + // if endpoints are not the same, start separate servers + if n.httpEndpoint != n.wsEndpoint { + if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { + n.stopHTTP() + n.stopIPC() + n.stopInProc() + return err + } } + // All API endpoints started successfully n.rpcAPIs = apis return nil @@ -348,33 +356,47 @@ func (n *Node) stopIPC() { } // startHTTP initializes and starts the HTTP RPC endpoint. -func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error { +func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error { // Short circuit if the HTTP endpoint isn't being exposed if endpoint == "" { return nil } - listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts) + // register apis and create handler stack + srv := rpc.NewServer() + err := RegisterApisFromWhitelist(apis, modules, srv, false) if err != nil { return err } - n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()), + handler := NewHTTPHandlerStack(srv, cors, vhosts, &timeouts) + // wrap handler in WebSocket handler only if WebSocket port is the same as http rpc + if n.httpEndpoint == n.wsEndpoint { + handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins)) + } + httpServer, addr, err := StartHTTPEndpoint(endpoint, timeouts, handler) + if err != nil { + return err + } + n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", addr), "cors", strings.Join(cors, ","), "vhosts", strings.Join(vhosts, ",")) + if n.httpEndpoint == n.wsEndpoint { + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) + } // All listeners booted successfully n.httpEndpoint = endpoint - n.httpListener = listener - n.httpHandler = handler + n.httpListenerAddr = addr + n.httpServer = httpServer + n.httpHandler = srv return nil } // stopHTTP terminates the HTTP RPC endpoint. func (n *Node) stopHTTP() { - if n.httpListener != nil { - url := fmt.Sprintf("http://%v/", n.httpListener.Addr()) - n.httpListener.Close() - n.httpListener = nil - n.log.Info("HTTP endpoint closed", "url", url) + if n.httpServer != nil { + // Don't bother imposing a timeout here. + n.httpServer.Shutdown(context.Background()) + n.log.Info("HTTP endpoint closed", "url", fmt.Sprintf("http://%v/", n.httpListenerAddr)) } if n.httpHandler != nil { n.httpHandler.Stop() @@ -382,32 +404,39 @@ func (n *Node) stopHTTP() { } } -// startWS initializes and starts the websocket RPC endpoint. +// startWS initializes and starts the WebSocket RPC endpoint. func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string, exposeAll bool) error { // Short circuit if the WS endpoint isn't being exposed if endpoint == "" { return nil } - listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll) + + srv := rpc.NewServer() + handler := srv.WebsocketHandler(wsOrigins) + err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll) + if err != nil { + return err + } + httpServer, addr, err := startWSEndpoint(endpoint, handler) if err != nil { return err } - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%s", listener.Addr())) + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) // All listeners booted successfully n.wsEndpoint = endpoint - n.wsListener = listener - n.wsHandler = handler + n.wsListenerAddr = addr + n.wsHTTPServer = httpServer + n.wsHandler = srv return nil } -// stopWS terminates the websocket RPC endpoint. +// stopWS terminates the WebSocket RPC endpoint. func (n *Node) stopWS() { - if n.wsListener != nil { - n.wsListener.Close() - n.wsListener = nil - - n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%s", n.wsEndpoint)) + if n.wsHTTPServer != nil { + // Don't bother imposing a timeout here. + n.wsHTTPServer.Shutdown(context.Background()) + n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%v", n.wsListenerAddr)) } if n.wsHandler != nil { n.wsHandler.Stop() @@ -574,8 +603,8 @@ func (n *Node) HTTPEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.httpListener != nil { - return n.httpListener.Addr().String() + if n.httpListenerAddr != nil { + return n.httpListenerAddr.String() } return n.httpEndpoint } @@ -585,8 +614,8 @@ func (n *Node) WSEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.wsListener != nil { - return n.wsListener.Addr().String() + if n.wsListenerAddr != nil { + return n.wsListenerAddr.String() } return n.wsEndpoint } @@ -628,11 +657,6 @@ func (n *Node) apis() []rpc.API { Namespace: "debug", Version: "1.0", Service: debug.Handler, - }, { - Namespace: "debug", - Version: "1.0", - Service: NewPublicDebugAPI(n), - Public: true, }, { Namespace: "web3", Version: "1.0", @@ -641,3 +665,25 @@ func (n *Node) apis() []rpc.API { }, } } + +// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules, +// and then registers all of the APIs exposed by the services. +func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error { + if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { + log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available) + } + // Generate the whitelist based on the allowed modules + whitelist := make(map[string]bool) + for _, module := range modules { + whitelist[module] = true + } + // Register all the APIs exposed by the services + for _, api := range apis { + if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if err := srv.RegisterName(api.Namespace, api.Service); err != nil { + return err + } + } + } + return nil +} diff --git a/node/node_test.go b/node/node_test.go index 40f4e7118aaa..4b4eec464d7d 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -18,6 +18,7 @@ package node import ( "errors" + "net/http" "os" "reflect" "testing" @@ -26,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/stretchr/testify/assert" ) var ( @@ -332,7 +334,7 @@ func TestServiceStartupAbortion(t *testing.T) { } // Tests that even if a registered service fails to shut down cleanly, it does -// not influece the rest of the shutdown invocations. +// not influence the rest of the shutdown invocations. func TestServiceTerminationGuarantee(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { @@ -506,8 +508,8 @@ func TestAPIGather(t *testing.T) { } // Register a batch of services with some configured APIs calls := make(chan string, 1) - makeAPI := func(result string) *OneMethodApi { - return &OneMethodApi{fun: func() { calls <- result }} + makeAPI := func(result string) *OneMethodAPI { + return &OneMethodAPI{fun: func() { calls <- result }} } services := map[string]struct { APIs []rpc.API @@ -572,3 +574,58 @@ func TestAPIGather(t *testing.T) { } } } + +func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) { + node := startHTTP(t) + defer node.stopHTTP() + + wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) + if err != nil { + t.Error("could not issue new http request ", err) + } + wsReq.Header.Set("Connection", "upgrade") + wsReq.Header.Set("Upgrade", "websocket") + wsReq.Header.Set("Sec-WebSocket-Version", "13") + wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") + + resp := doHTTPRequest(t, wsReq) + assert.Equal(t, "websocket", resp.Header.Get("Upgrade")) +} + +func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) { + node := startHTTP(t) + defer node.stopHTTP() + + httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) + if err != nil { + t.Error("could not issue new http request ", err) + } + httpReq.Header.Set("Accept-Encoding", "gzip") + + resp := doHTTPRequest(t, httpReq) + assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) +} + +func startHTTP(t *testing.T) *Node { + conf := &Config{HTTPPort: 7453, WSPort: 7453} + node, err := New(conf) + if err != nil { + t.Error("could not create a new node ", err) + } + + err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{}) + if err != nil { + t.Error("could not start http service on node ", err) + } + + return node +} + +func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Error("could not issue a GET request to the given endpoint", err) + } + return resp +} diff --git a/node/rpcstack.go b/node/rpcstack.go new file mode 100644 index 000000000000..2de884f4c3b8 --- /dev/null +++ b/node/rpcstack.go @@ -0,0 +1,170 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "compress/gzip" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/rs/cors" +) + +// NewHTTPHandlerStack returns wrapped http-related handlers +func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, timeouts *rpc.HTTPTimeouts) http.Handler { + // Wrap the CORS-handler within a host-handler + handler := newCorsHandler(srv, cors) + handler = newVHostHandler(vhosts, handler) + handler = newGzipHandler(handler) + + // make sure timeout values are meaningful + CheckTimeouts(timeouts) + + // PR #469: register timeout handler before WebSocket and HTTP + handler = http.TimeoutHandler(handler, timeouts.WriteTimeout, `{"error":"http server timeout"}`) + + // add 1 second to let TimeoutHandler works first + timeouts.WriteTimeout = timeouts.WriteTimeout + time.Second + return handler +} + +func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { + // disable CORS support if user has not specified a custom CORS configuration + if len(allowedOrigins) == 0 { + return srv + } + c := cors.New(cors.Options{ + AllowedOrigins: allowedOrigins, + AllowedMethods: []string{http.MethodPost, http.MethodGet}, + MaxAge: 600, + AllowedHeaders: []string{"*"}, + }) + return c.Handler(srv) +} + +// virtualHostHandler is a handler which validates the Host-header of incoming requests. +// Using virtual hosts can help prevent DNS rebinding attacks, where a 'random' domain name points to +// the service ip address (but without CORS headers). By verifying the targeted virtual host, we can +// ensure that it's a destination that the node operator has defined. +type virtualHostHandler struct { + vhosts map[string]struct{} + next http.Handler +} + +func newVHostHandler(vhosts []string, next http.Handler) http.Handler { + vhostMap := make(map[string]struct{}) + for _, allowedHost := range vhosts { + vhostMap[strings.ToLower(allowedHost)] = struct{}{} + } + return &virtualHostHandler{vhostMap, next} +} + +// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler +func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if r.Host is not set, we can continue serving since a browser would set the Host header + if r.Host == "" { + h.next.ServeHTTP(w, r) + return + } + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + // Either invalid (too many colons) or no port specified + host = r.Host + } + if ipAddr := net.ParseIP(host); ipAddr != nil { + // It's an IP address, we can serve that + h.next.ServeHTTP(w, r) + return + + } + // Not an IP address, but a hostname. Need to validate + if _, exist := h.vhosts["*"]; exist { + h.next.ServeHTTP(w, r) + return + } + if _, exist := h.vhosts[host]; exist { + h.next.ServeHTTP(w, r) + return + } + http.Error(w, "invalid host specified", http.StatusForbidden) +} + +var gzPool = sync.Pool{ + New: func() interface{} { + w := gzip.NewWriter(io.Discard) + return w + }, +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w *gzipResponseWriter) WriteHeader(status int) { + w.Header().Del("Content-Length") + w.ResponseWriter.WriteHeader(status) +} + +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} + +func newGzipHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + + w.Header().Set("Content-Encoding", "gzip") + + gz := gzPool.Get().(*gzip.Writer) + defer gzPool.Put(gz) + + gz.Reset(w) + defer gz.Close() + + next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + }) +} + +// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade +// request to the websocket protocol. If not, serves the the request with the http handler. +func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isWebsocket(r) { + ws.ServeHTTP(w, r) + log.Debug("serving websocket request") + return + } + + h.ServeHTTP(w, r) + }) +} + +// isWebsocket checks the header of an http request for a websocket upgrade request. +func isWebsocket(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" && + strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") +} diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go new file mode 100644 index 000000000000..db08bd40cb43 --- /dev/null +++ b/node/rpcstack_test.go @@ -0,0 +1,53 @@ +package node + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) { + srv := rpc.NewServer() + + handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{})) + ts := httptest.NewServer(handler) + defer ts.Close() + + responses := make(chan *http.Response) + go func(responses chan *http.Response) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") + + resp, err := client.Do(req) + if err != nil { + t.Error("could not issue a GET request to the test http server", err) + } + responses <- resp + }(responses) + + response := <-responses + assert.Equal(t, "websocket", response.Header.Get("Upgrade")) +} + +// TestIsWebsocket tests if an incoming websocket upgrade request is handled properly. +func TestIsWebsocket(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + + assert.False(t, isWebsocket(r)) + r.Header.Set("upgrade", "websocket") + assert.False(t, isWebsocket(r)) + r.Header.Set("connection", "upgrade") + assert.True(t, isWebsocket(r)) + r.Header.Set("connection", "upgrade,keep-alive") + assert.True(t, isWebsocket(r)) + r.Header.Set("connection", " UPGRADE,keep-alive") + assert.True(t, isWebsocket(r)) +} diff --git a/node/utils_test.go b/node/utils_test.go index 034b92d03aa1..d6aff608e05d 100644 --- a/node/utils_test.go +++ b/node/utils_test.go @@ -123,12 +123,12 @@ func InstrumentedServiceMakerC(base ServiceConstructor) ServiceConstructor { return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceC{})) } -// OneMethodApi is a single-method API handler to be returned by test services. -type OneMethodApi struct { +// OneMethodAPI is a single-method API handler to be returned by test services. +type OneMethodAPI struct { fun func() } -func (api *OneMethodApi) TheOneMethod() { +func (api *OneMethodAPI) TheOneMethod() { if api.fun != nil { api.fun() } diff --git a/rpc/endpoints.go b/rpc/endpoints.go index 21802c9542db..0b543102fc0e 100644 --- a/rpc/endpoints.go +++ b/rpc/endpoints.go @@ -23,64 +23,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/log" ) -// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules -func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) { - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := NewServer() - for _, api := range apis { - if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return nil, nil, err - } - log.Debug("HTTP registered", "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener) - return listener, handler, err -} - -// StartWSEndpoint starts a websocket endpoint -func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) { - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := NewServer() - for _, api := range apis { - if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return nil, nil, err - } - log.Debug("WebSocket registered", "service", api.Service, "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - go NewWSServer(wsOrigins, handler).Serve(listener) - return listener, handler, err -} - // StartIPCEndpoint starts an IPC endpoint. func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) { // Register all the APIs exposed by the services. @@ -94,7 +36,7 @@ func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, er log.Info("IPC registration failed", "namespace", api.Namespace, "error", err) return nil, nil, err } - log.Debug("IPC registered", "service", api.Service, "namespace", api.Namespace) + log.Debug("IPC registered", "namespace", api.Namespace) if _, ok := regMap[api.Namespace]; !ok { registered = append(registered, api.Namespace) regMap[api.Namespace] = struct{}{} diff --git a/rpc/http.go b/rpc/http.go index 2416d66753b4..ee23c79a1a95 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -24,15 +24,10 @@ import ( "fmt" "io" "mime" - "net" "net/http" "net/url" - "strings" "sync" "time" - - "github.com/XinFinOrg/XDPoSChain/log" - "github.com/rs/cors" ) const ( @@ -237,41 +232,6 @@ func (t *httpServerConn) RemoteAddr() string { // SetWriteDeadline does nothing and always returns nil. func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } -// NewHTTPServer creates a new HTTP RPC server around an API provider. -// -// Deprecated: Server implements http.Handler -func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv *Server) *http.Server { - // Wrap the CORS-handler within a host-handler - handler := newCorsHandler(srv, cors) - handler = newVHostHandler(vhosts, handler) - - // Make sure timeout values are meaningful - if timeouts.ReadTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout) - timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout - } - if timeouts.WriteTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout) - timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout - } - if timeouts.IdleTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout) - timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout - } - - // PR #469: return http code 503 and error message to client when timeout - handler = http.TimeoutHandler(handler, timeouts.WriteTimeout, `{"error":"http server timeout"}`) - log.Info("NewHTTPServer", "writeTimeout", timeouts.WriteTimeout) - - // Bundle and start the HTTP server - return &http.Server{ - Handler: handler, - ReadTimeout: timeouts.ReadTimeout, - WriteTimeout: timeouts.WriteTimeout + time.Second, - IdleTimeout: timeouts.IdleTimeout, - } -} - // ServeHTTP serves JSON-RPC requests over HTTP. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Permit dumb empty requests for remote health-checks (AWS) @@ -328,64 +288,3 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } - -func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler { - // disable CORS support if user has not specified a custom CORS configuration - if len(allowedOrigins) == 0 { - return srv - } - c := cors.New(cors.Options{ - AllowedOrigins: allowedOrigins, - AllowedMethods: []string{http.MethodPost, http.MethodGet}, - MaxAge: 600, - AllowedHeaders: []string{"*"}, - }) - return c.Handler(srv) -} - -// virtualHostHandler is a handler which validates the Host-header of incoming requests. -// The virtualHostHandler can prevent DNS rebinding attacks, which do not utilize CORS-headers, -// since they do in-domain requests against the RPC api. Instead, we can see on the Host-header -// which domain was used, and validate that against a whitelist. -type virtualHostHandler struct { - vhosts map[string]struct{} - next http.Handler -} - -func newVHostHandler(vhosts []string, next http.Handler) http.Handler { - vhostMap := make(map[string]struct{}) - for _, allowedHost := range vhosts { - vhostMap[strings.ToLower(allowedHost)] = struct{}{} - } - return &virtualHostHandler{vhostMap, next} -} - -// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler -func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // if r.Host is not set, we can continue serving since a browser would set the Host header - if r.Host == "" { - h.next.ServeHTTP(w, r) - return - } - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - // Either invalid (too many colons) or no port specified - host = r.Host - } - if ipAddr := net.ParseIP(host); ipAddr != nil { - // It's an IP address, we can serve that - h.next.ServeHTTP(w, r) - return - - } - // Not an ip address, but a hostname. Need to validate - if _, exist := h.vhosts["*"]; exist { - h.next.ServeHTTP(w, r) - return - } - if _, exist := h.vhosts[host]; exist { - h.next.ServeHTTP(w, r) - return - } - http.Error(w, "invalid host specified", http.StatusForbidden) -} diff --git a/rpc/websocket.go b/rpc/websocket.go index fdc1d05e9fd1..7c7cae455473 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -64,13 +64,6 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { }) } -// NewWSServer creates a new websocket RPC server around an API provider. -// -// Deprecated: use Server.WebsocketHandler -func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { - return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} -} - // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted.