Skip to content

Commit

Permalink
make file server with options (#18)
Browse files Browse the repository at this point in the history
* change FileServer to options-initialized, add optional file listing support

* make fs change back compat

* clarify dir listing in fs

* join src lines for better readability

* rena FSOpt for consistency

* mark FileServer and FileServerSPA as deprecated
  • Loading branch information
umputun authored Aug 22, 2021
1 parent bb853a4 commit 51f9235
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 64 deletions.
125 changes: 99 additions & 26 deletions file_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,140 @@ import (
"strings"
)

// FileServer returns http.FileServer handler to serve static files from a http.FileSystem,
// prevents directory listing.
// FS provides http.FileServer handler to serve static files from a http.FileSystem,
// prevents directory listing by default and supports spa-friendly mode (off by default) returning /index.html on 404.
// - public defines base path of the url, i.e. for http://example.com/static/* it should be /static
// - local for the local path to the root of the served directory
// - notFound is the reader for the custom 404 html, can be nil for default
func FileServer(public, local string, notFound io.Reader) (http.Handler, error) {
type FS struct {
public, root string
notFound io.Reader
isSpa bool
enableListing bool
handler http.HandlerFunc
}

// NewFileServer creates file server with optional spa mode and optional direcroty listing (disabled by default)
func NewFileServer(public, local string, options ...FsOpt) (*FS, error) {
res := FS{
public: public,
notFound: nil,
isSpa: false,
enableListing: false,
}

root, err := filepath.Abs(local)
if err != nil {
return nil, fmt.Errorf("can't get absolute path for %s: %w", local, err)
}
res.root = root

if _, err = os.Stat(root); os.IsNotExist(err) {
return nil, fmt.Errorf("local path %s doesn't exist: %w", root, err)
}

fs := http.StripPrefix(public, http.FileServer(noDirListingFS{http.Dir(root), false}))
return custom404Handler(fs, notFound)
for _, opt := range options {
err = opt(&res)
if err != nil {
return nil, err
}
}

cfs := customFS{
fs: http.Dir(root),
spa: res.isSpa,
listing: res.enableListing,
}
f := http.StripPrefix(public, http.FileServer(cfs))
res.handler = func(w http.ResponseWriter, r *http.Request) { f.ServeHTTP(w, r) }

if !res.enableListing {
h, err := custom404Handler(f, res.notFound)
if err != nil {
return nil, err
}
res.handler = func(w http.ResponseWriter, r *http.Request) { h.ServeHTTP(w, r) }
}

return &res, nil
}

// FileServer is a shortcut for making FS with listing disabled and the custom noFound reader (can be nil).
// Deprecated: the method is for back-compatibility only and user should use the universal NewFileServer instead
func FileServer(public, local string, notFound io.Reader) (http.Handler, error) {
return NewFileServer(public, local, FsOptCustom404(notFound))
}

// FileServerSPA returns FileServer as above, but instead of no-found returns /local/index.html
// FileServerSPA is a shortcut for making FS with SPA-friendly handling of 404, listing disabled and the custom noFound reader (can be nil).
// Deprecated: the method is for back-compatibility only and user should use the universal NewFileServer instead
func FileServerSPA(public, local string, notFound io.Reader) (http.Handler, error) {
return NewFileServer(public, local, FsOptCustom404(notFound), FsOptSPA)
}

root, err := filepath.Abs(local)
if err != nil {
return nil, fmt.Errorf("can't get absolute path for %s: %w", local, err)
}
if _, err = os.Stat(root); os.IsNotExist(err) {
return nil, fmt.Errorf("local path %s doesn't exist: %w", root, err)
}
// ServeHTTP makes FileServer compatible with http.Handler interface
func (fs *FS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fs.handler(w, r)
}

fs := http.StripPrefix(public, http.FileServer(noDirListingFS{http.Dir(root), true}))
return custom404Handler(fs, notFound)
// FsOpt defines functional option type
type FsOpt func(fs *FS) error

// FsOptSPA turns on SPA mode returning "/index.html" on not-found
func FsOptSPA(fs *FS) error {
fs.isSpa = true
return nil
}

type noDirListingFS struct {
fs http.FileSystem
spa bool
// FsOptListing turns on directory listing
func FsOptListing(fs *FS) error {
fs.enableListing = true
return nil
}

// FsOptCustom404 sets custom 404 reader
func FsOptCustom404(fr io.Reader) FsOpt {
return func(fs *FS) error {
fs.notFound = fr
return nil
}
}

// customFS wraps http.FileSystem with spa and no-listing optional support
type customFS struct {
fs http.FileSystem
spa bool
listing bool
}

// Open file on FS, for directory enforce index.html and fail on a missing index
func (fs noDirListingFS) Open(name string) (http.File, error) {
func (cfs customFS) Open(name string) (http.File, error) {

f, err := fs.fs.Open(name)
f, err := cfs.fs.Open(name)
if err != nil {
if fs.spa {
return fs.fs.Open("/index.html")
if cfs.spa {
return cfs.fs.Open("/index.html")
}
return nil, err
}

s, err := f.Stat()
finfo, err := f.Stat()
if err != nil {
return nil, err
}

if s.IsDir() {
if finfo.IsDir() {
index := strings.TrimSuffix(name, "/") + "/index.html"
if _, err := fs.fs.Open(index); err != nil {
return nil, err
if _, err := cfs.fs.Open(index); err == nil { // index.html will be served if found
return f, nil
}
// no index.html in directory
if !cfs.listing { // listing disabled
if _, err := cfs.fs.Open(index); err != nil {
return nil, err
}
}
}

return f, nil
}

Expand Down
134 changes: 96 additions & 38 deletions file_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ import (
"github.com/go-pkgz/rest/logger"
)

func TestFileServer(t *testing.T) {
fh, err := FileServer("/static", "./testdata/root", nil)
func TestFileServerDefault(t *testing.T) {
fh1, err := NewFileServer("/static", "./testdata/root")
require.NoError(t, err)
ts := httptest.NewServer(logger.Logger(fh))
defer ts.Close()

fh2, err := FileServer("/static", "./testdata/root", nil)
require.NoError(t, err)

ts1 := httptest.NewServer(logger.Logger(fh1))
defer ts1.Close()
ts2 := httptest.NewServer(logger.Logger(fh2))
defer ts2.Close()

client := http.Client{Timeout: 599 * time.Second}

tbl := []struct {
Expand Down Expand Up @@ -48,28 +55,73 @@ func TestFileServer(t *testing.T) {
for i, tt := range tbl {
tt := tt
t.Run(strconv.Itoa(i), func(t *testing.T) {
req, err := http.NewRequest("GET", ts.URL+tt.req, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
t.Logf("headers: %v", resp.Header)
assert.Equal(t, tt.status, resp.StatusCode)
if resp.StatusCode == http.StatusNotFound {
msg, e := ioutil.ReadAll(resp.Body)
require.NoError(t, e)
assert.Equal(t, "404 page not found\n", string(msg))
return
for _, ts := range []*httptest.Server{ts1, ts2} {
req, err := http.NewRequest("GET", ts.URL+tt.req, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
t.Logf("headers: %v", resp.Header)
assert.Equal(t, tt.status, resp.StatusCode)
if resp.StatusCode == http.StatusNotFound {
msg, e := ioutil.ReadAll(resp.Body)
require.NoError(t, e)
assert.Equal(t, "404 page not found\n", string(msg))
return
}
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, tt.body, string(body))
}
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, tt.body, string(body))

})
}
}

func TestFileServerWithListing(t *testing.T) {
fh, err := NewFileServer("/static", "./testdata/root", FsOptListing)
require.NoError(t, err)
ts := httptest.NewServer(logger.Logger(fh))
defer ts.Close()
client := http.Client{Timeout: 599 * time.Second}

{
req, err := http.NewRequest("GET", ts.URL+"/static/1", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
msg, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
exp := `<pre>
<a href="f1.html">f1.html</a>
<a href="f2.html">f2.html</a>
</pre>
`
assert.Equal(t, exp, string(msg))
}

{
req, err := http.NewRequest("GET", ts.URL+"/static/xyz.js", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
msg, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "testdata/xyz.js", string(msg))
}

{
req, err := http.NewRequest("GET", ts.URL+"/static/no-such-thing.html", nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
}

func TestFileServer_Custom404(t *testing.T) {
fh, err := FileServer("/static", "./testdata/root", bytes.NewBufferString("custom 404"))
nf := FsOptCustom404(bytes.NewBufferString("custom 404"))
fh, err := NewFileServer("/static", "./testdata/root", nf)
require.NoError(t, err)
ts := httptest.NewServer(logger.Logger(fh))
defer ts.Close()
Expand Down Expand Up @@ -121,10 +173,15 @@ func TestFileServer_Custom404(t *testing.T) {
}

func TestFileServerSPA(t *testing.T) {
fh, err := FileServerSPA("/static", "./testdata/root", nil)
fh1, err := NewFileServer("/static", "./testdata/root", FsOptSPA)
require.NoError(t, err)
ts := httptest.NewServer(logger.Logger(fh))
defer ts.Close()
fh2, err := FileServerSPA("/static", "./testdata/root", nil)
require.NoError(t, err)

ts1 := httptest.NewServer(logger.Logger(fh1))
defer ts1.Close()
ts2 := httptest.NewServer(logger.Logger(fh2))
defer ts2.Close()
client := http.Client{Timeout: 599 * time.Second}

tbl := []struct {
Expand Down Expand Up @@ -155,22 +212,23 @@ func TestFileServerSPA(t *testing.T) {
for i, tt := range tbl {
tt := tt
t.Run(strconv.Itoa(i), func(t *testing.T) {
req, err := http.NewRequest("GET", ts.URL+tt.req, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
t.Logf("headers: %v", resp.Header)
assert.Equal(t, tt.status, resp.StatusCode)
if resp.StatusCode == http.StatusNotFound {
msg, e := ioutil.ReadAll(resp.Body)
require.NoError(t, e)
assert.Equal(t, "404 page not found\n", string(msg))
return
for _, ts := range []*httptest.Server{ts1, ts2} {
req, err := http.NewRequest("GET", ts.URL+tt.req, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
t.Logf("headers: %v", resp.Header)
assert.Equal(t, tt.status, resp.StatusCode)
if resp.StatusCode == http.StatusNotFound {
msg, e := ioutil.ReadAll(resp.Body)
require.NoError(t, e)
assert.Equal(t, "404 page not found\n", string(msg))
return
}
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, tt.body, string(body))
}
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, tt.body, string(body))

})
}
}

0 comments on commit 51f9235

Please sign in to comment.