diff --git a/gotenv.go b/gotenv.go index d961355..1191d35 100644 --- a/gotenv.go +++ b/gotenv.go @@ -55,12 +55,12 @@ func Must(fn func(filenames ...string) error, filenames ...string) { } // Apply is a function to load an io Reader then export the valid variables into environment variables if they do not exist. -func Apply(r Reader) error { +func Apply(r io.Reader) error { return parset(r, false) } // OverApply is a function to load an io Reader then export and override the valid variables into environment variables. -func OverApply(r Reader) error { +func OverApply(r io.Reader) error { return parset(r, true) } @@ -86,7 +86,7 @@ func loadenv(override bool, filenames ...string) error { } // parse and set :) -func parset(r Reader, override bool) error { +func parset(r io.Reader, override bool) error { env, err := strictParse(r, override) if err != nil { return err @@ -112,7 +112,7 @@ func setenv(key, val string, override bool) { // Parse is a function to parse line by line any io.Reader supplied and returns the valid Env key/value pair of valid variables. // It expands the value of a variable from the environment variable but does not set the value to the environment itself. // This function is skipping any invalid lines and only processing the valid one. -func Parse(r Reader) Env { +func Parse(r io.Reader) Env { env, _ := strictParse(r, false) return env } @@ -120,7 +120,7 @@ func Parse(r Reader) Env { // StrictParse is a function to parse line by line any io.Reader supplied and returns the valid Env key/value pair of valid variables. // It expands the value of a variable from the environment variable but does not set the value to the environment itself. // This function is returning an error if there are any invalid lines. -func StrictParse(r Reader) (Env, error) { +func StrictParse(r io.Reader) (Env, error) { return strictParse(r, false) } @@ -208,31 +208,32 @@ func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) return eol, data[:idx], nil } -type Reader interface { - io.Reader - io.ReaderAt -} - -func strictParse(r Reader, override bool) (Env, error) { +func strictParse(r io.Reader, override bool) (Env, error) { env := make(Env) - // We chooes a different scanner depending on file encoding. - var scanner *bufio.Scanner + buf := new(bytes.Buffer) + tee := io.TeeReader(r, buf) // There can be a maximum of 3 BOM bytes. bomByteBuffer := make([]byte, 3) - if _, err := r.ReadAt(bomByteBuffer, 0); err != nil { + _, err := tee.Read(bomByteBuffer) + if err != nil && err != io.EOF { return env, err } + z := io.MultiReader(buf, r) + + // We chooes a different scanner depending on file encoding. + var scanner *bufio.Scanner + if bytes.HasPrefix(bomByteBuffer, bomUTF8) { - scanner = bufio.NewScanner(transform.NewReader(r, unicode.UTF8BOM.NewDecoder())) + scanner = bufio.NewScanner(transform.NewReader(z, unicode.UTF8BOM.NewDecoder())) } else if bytes.HasPrefix(bomByteBuffer, bomUTF16LE) { - scanner = bufio.NewScanner(transform.NewReader(r, unicode.UTF16(unicode.LittleEndian, unicode.ExpectBOM).NewDecoder())) + scanner = bufio.NewScanner(transform.NewReader(z, unicode.UTF16(unicode.LittleEndian, unicode.ExpectBOM).NewDecoder())) } else if bytes.HasPrefix(bomByteBuffer, bomUTF16BE) { - scanner = bufio.NewScanner(transform.NewReader(r, unicode.UTF16(unicode.BigEndian, unicode.ExpectBOM).NewDecoder())) + scanner = bufio.NewScanner(transform.NewReader(z, unicode.UTF16(unicode.BigEndian, unicode.ExpectBOM).NewDecoder())) } else { - scanner = bufio.NewScanner(r) + scanner = bufio.NewScanner(z) } scanner.Split(splitLines) diff --git a/gotenv_test.go b/gotenv_test.go index 1a4d041..0025552 100644 --- a/gotenv_test.go +++ b/gotenv_test.go @@ -3,6 +3,7 @@ package gotenv_test import ( "bufio" "errors" + "io" "os" "strings" "testing" @@ -242,34 +243,26 @@ func TestStrictParse(t *testing.T) { } type failingReader struct { - gotenv.Reader + io.Reader } func (fr failingReader) Read(p []byte) (n int, err error) { return 0, errors.New("you shall not read") } -func (fr failingReader) ReadAt(p []byte, off int64) (n int, err error) { - return 0, errors.New("you shall not read") -} - func TestStrictParse_PassThroughErrors(t *testing.T) { _, err := gotenv.StrictParse(&failingReader{}) assert.Error(t, err) } type infiniteReader struct { - gotenv.Reader + io.Reader } func (er infiniteReader) Read(p []byte) (n int, err error) { return len(p), nil } -func (er infiniteReader) ReadAt(p []byte, off int64) (n int, err error) { - return len(p), nil -} - func TestStrictParse_NoTokenPassThroughErrors(t *testing.T) { _, err := gotenv.StrictParse(&infiniteReader{}) assert.Error(t, err)