Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/shfmt: implement --from-json and improve -tojson #900

Merged
merged 1 commit into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 225 additions & 34 deletions cmd/shfmt/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package main

import (
"encoding/json"
"go/ast"
"fmt"
"io"
"reflect"

Expand All @@ -14,68 +14,259 @@ import (

func writeJSON(w io.Writer, node syntax.Node, pretty bool) error {
val := reflect.ValueOf(node)
v, _ := encode(val)
encVal, _ := encode(val)
enc := json.NewEncoder(w)
if pretty {
enc.SetIndent("", "\t")
}
return enc.Encode(v)
return enc.Encode(encVal.Interface())
}

func encode(val reflect.Value) (interface{}, string) {
func encode(val reflect.Value) (reflect.Value, string) {
switch val.Kind() {
case reflect.Ptr:
elem := val.Elem()
if !elem.IsValid() {
return nil, ""
break
}
return encode(elem)
case reflect.Interface:
if val.IsNil() {
return nil, ""
break
}
v, tname := encode(val.Elem())
m := v.(map[string]interface{})
m["Type"] = tname
return m, ""
enc, tname := encode(val.Elem())
if tname != "" {
enc.Elem().Field(0).SetString(tname)
}
return enc, ""
case reflect.Struct:
m := make(map[string]interface{}, val.NumField()+1)
// Construct a new struct with an optional Type, Pos and End,
// and then all the visible fields which aren't positions.
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
ftyp := typ.Field(i)
if ftyp.Type.Name() == "Pos" {
continue
}
if !ast.IsExported(ftyp.Name) {
continue
fields := []reflect.StructField{typeField, posField, endField}
for _, field := range reflect.VisibleFields(typ) {
typ := anyType
if field.Type == posType {
typ = exportedPosType
}
fval := val.Field(i)
v, _ := encode(fval)
m[ftyp.Name] = v
fields = append(fields, reflect.StructField{
Name: field.Name,
Type: typ,
Tag: `json:",omitempty"`,
})
}
encTyp := reflect.StructOf(fields)
enc := reflect.New(encTyp).Elem()

// Pos methods are defined on struct pointer receivers.
for _, name := range [...]string{"Pos", "End"} {
for i, name := range [...]string{"Pos", "End"} {
if fn := val.Addr().MethodByName(name); fn.IsValid() {
m[name] = translatePos(fn.Call(nil)[0])
encodePos(enc.Field(1+i), fn.Call(nil)[0])
}
}
// Do the rest of the fields.
for i := 3; i < encTyp.NumField(); i++ {
ftyp := encTyp.Field(i)
fval := val.FieldByName(ftyp.Name)
if ftyp.Type == exportedPosType {
encodePos(enc.Field(i), fval)
} else {
encElem, _ := encode(fval)
if encElem.IsValid() {
enc.Field(i).Set(encElem)
}
}
}
return m, typ.Name()

// Addr helps prevent an allocation as we use interface{} fields.
return enc.Addr(), typ.Name()
case reflect.Slice:
l := make([]interface{}, val.Len())
for i := 0; i < val.Len(); i++ {
n := val.Len()
if n == 0 {
break
}
enc := reflect.MakeSlice(anySliceType, n, n)
for i := 0; i < n; i++ {
elem := val.Index(i)
l[i], _ = encode(elem)
encElem, _ := encode(elem)
enc.Index(i).Set(encElem)
}
return enc, ""
case reflect.Bool:
if val.Bool() {
return val, ""
}
case reflect.String:
if val.String() != "" {
return val, ""
}
case reflect.Uint32:
if val.Uint() != 0 {
return val, ""
}
return l, ""
default:
return val.Interface(), ""
panic(val.Kind().String())
}
return noValue, ""
}

// Shared reflection values used by the JSON encoder and decoder in this file.
var (
// noValue is the zero reflect.Value; encode returns it for values that
// should be left out of the output, and callers test it with IsValid.
noValue reflect.Value

anyType = reflect.TypeOf((*interface{})(nil)).Elem() // interface{}
anySliceType = reflect.SliceOf(anyType) // []interface{}
posType = reflect.TypeOf((*syntax.Pos)(nil)).Elem() // syntax.Pos
exportedPosType = reflect.TypeOf((*exportedPos)(nil)) // *exportedPos

// typeField, posField and endField are the three fields that encode places
// at the start of every struct type it builds, in this order. All of them
// carry the omitempty JSON option.
//
// TODO(v4): derived fields like Type, Pos, and End should have clearly
// different names to prevent confusion. For example: _type, _pos, _end.
typeField = reflect.StructField{
Name: "Type",
Type: reflect.TypeOf((*string)(nil)).Elem(),
Tag: `json:",omitempty"`,
}
posField = reflect.StructField{
Name: "Pos",
Type: exportedPosType,
Tag: `json:",omitempty"`,
}
endField = reflect.StructField{
Name: "End",
Type: exportedPosType,
Tag: `json:",omitempty"`,
}
)

// exportedPos is the JSON form of a position: encodePos fills it from a
// position's Offset, Line and Col methods, and decodePos reads the same three
// keys back. The field order matters, since encodePos fills fields by index.
type exportedPos struct {
	Offset uint
	Line   uint
	Col    uint
}

// encodePos stores an exported copy of the position held by val into encPtr.
//
// The position is queried through reflection via its IsValid, Offset, Line
// and Col methods. If IsValid reports false, encPtr is left untouched.
// Otherwise encPtr is set to a newly allocated *exportedPos whose fields are
// filled from Offset, Line and Col, in that order.
func encodePos(encPtr, val reflect.Value) {
	call := func(method string) reflect.Value {
		return val.MethodByName(method).Call(nil)[0]
	}
	if valid := call("IsValid").Bool(); !valid {
		return
	}

	ptr := reflect.New(exportedPosType.Elem())
	encPtr.Set(ptr)

	// The method names line up with the field order of exportedPos.
	pos := ptr.Elem()
	for i, method := range [...]string{"Offset", "Line", "Col"} {
		pos.Field(i).Set(call(method))
	}
}

// decodePos rebuilds a position from its decoded JSON object and stores it
// in val.
//
// enc must hold float64 values under the "Offset", "Line" and "Col" keys;
// a missing key or a value of any other type makes the type assertion panic.
func decodePos(val reflect.Value, enc map[string]interface{}) {
	// encoding/json decodes every JSON number into a float64 when the
	// destination is an interface{}, hence the assertion and conversion.
	num := func(key string) uint {
		return uint(enc[key].(float64))
	}
	pos := syntax.NewPos(num("Offset"), num("Line"), num("Col"))
	val.Set(reflect.ValueOf(pos))
}

func translatePos(val reflect.Value) map[string]interface{} {
return map[string]interface{}{
"Offset": val.MethodByName("Offset").Call(nil)[0].Uint(),
"Line": val.MethodByName("Line").Call(nil)[0].Uint(),
"Col": val.MethodByName("Col").Call(nil)[0].Uint(),
// readJSON reads one JSON value from r and decodes it into a new
// *syntax.File, which is returned as the resulting node.
//
// Any error from the JSON decoder or from decode is returned as-is, together
// with a nil node.
func readJSON(r io.Reader) (syntax.Node, error) {
	// Decode into a generic value first; decode then walks it alongside
	// the destination struct using reflection.
	var raw interface{}
	dec := json.NewDecoder(r)
	if err := dec.Decode(&raw); err != nil {
		return nil, err
	}

	file := new(syntax.File)
	err := decode(reflect.ValueOf(file), raw)
	if err != nil {
		return nil, err
	}
	return file, nil
}

// nodeByName maps the name found under the "Type" key of an encoded JSON
// object to the corresponding type in the syntax package. decode looks names
// up here to allocate a new value of that type, and reports an "unknown type"
// error for names that are not listed.
var nodeByName = map[string]reflect.Type{
"Word": reflect.TypeOf((*syntax.Word)(nil)).Elem(),

"Lit": reflect.TypeOf((*syntax.Lit)(nil)).Elem(),
"SglQuoted": reflect.TypeOf((*syntax.SglQuoted)(nil)).Elem(),
"DblQuoted": reflect.TypeOf((*syntax.DblQuoted)(nil)).Elem(),
"ParamExp": reflect.TypeOf((*syntax.ParamExp)(nil)).Elem(),
"CmdSubst": reflect.TypeOf((*syntax.CmdSubst)(nil)).Elem(),
"CallExpr": reflect.TypeOf((*syntax.CallExpr)(nil)).Elem(),
"ArithmExp": reflect.TypeOf((*syntax.ArithmExp)(nil)).Elem(),
"ProcSubst": reflect.TypeOf((*syntax.ProcSubst)(nil)).Elem(),
"ExtGlob": reflect.TypeOf((*syntax.ExtGlob)(nil)).Elem(),
"BraceExp": reflect.TypeOf((*syntax.BraceExp)(nil)).Elem(),

"ArithmCmd": reflect.TypeOf((*syntax.ArithmCmd)(nil)).Elem(),
"BinaryCmd": reflect.TypeOf((*syntax.BinaryCmd)(nil)).Elem(),
"IfClause": reflect.TypeOf((*syntax.IfClause)(nil)).Elem(),
"ForClause": reflect.TypeOf((*syntax.ForClause)(nil)).Elem(),
"WhileClause": reflect.TypeOf((*syntax.WhileClause)(nil)).Elem(),
"CaseClause": reflect.TypeOf((*syntax.CaseClause)(nil)).Elem(),
"Block": reflect.TypeOf((*syntax.Block)(nil)).Elem(),
"Subshell": reflect.TypeOf((*syntax.Subshell)(nil)).Elem(),
"FuncDecl": reflect.TypeOf((*syntax.FuncDecl)(nil)).Elem(),
"TestClause": reflect.TypeOf((*syntax.TestClause)(nil)).Elem(),
"DeclClause": reflect.TypeOf((*syntax.DeclClause)(nil)).Elem(),
"LetClause": reflect.TypeOf((*syntax.LetClause)(nil)).Elem(),
"TimeClause": reflect.TypeOf((*syntax.TimeClause)(nil)).Elem(),
"CoprocClause": reflect.TypeOf((*syntax.CoprocClause)(nil)).Elem(),
"TestDecl": reflect.TypeOf((*syntax.TestDecl)(nil)).Elem(),

"UnaryArithm": reflect.TypeOf((*syntax.UnaryArithm)(nil)).Elem(),
"BinaryArithm": reflect.TypeOf((*syntax.BinaryArithm)(nil)).Elem(),
"ParenArithm": reflect.TypeOf((*syntax.ParenArithm)(nil)).Elem(),

"UnaryTest": reflect.TypeOf((*syntax.UnaryTest)(nil)).Elem(),
"BinaryTest": reflect.TypeOf((*syntax.BinaryTest)(nil)).Elem(),
"ParenTest": reflect.TypeOf((*syntax.ParenTest)(nil)).Elem(),

"WordIter": reflect.TypeOf((*syntax.WordIter)(nil)).Elem(),
"CStyleLoop": reflect.TypeOf((*syntax.CStyleLoop)(nil)).Elem(),
}

// decode populates val from enc, a value obtained by decoding JSON into an
// interface{} with encoding/json.
//
// The dynamic type of enc selects the behavior:
//
//   - map[string]interface{}: val must end up referring to a struct. A nil
//     pointer is allocated first; a non-empty "Type" key is looked up in
//     nodeByName and a new value of that type is stored in val. Every other
//     key except "Type", "Pos" and "End" must name a settable struct field,
//     which is decoded recursively. Fields of type syntax.Pos are handed to
//     decodePos after their shape has been validated.
//   - []interface{}: val must be a settable slice; each element is decoded
//     into a new element value and appended.
//   - float64: val must be a settable unsigned integer, and the number must
//     be a non-negative integer that fits in it.
//   - anything else (string, bool): stored directly if assignable to val.
//     A nil enc leaves val untouched.
//
// Input that does not match the shape of val yields an error rather than a
// reflection panic, since the JSON typically comes from outside the program.
func decode(val reflect.Value, enc interface{}) error {
	switch enc := enc.(type) {
	case map[string]interface{}:
		if val.Kind() == reflect.Ptr && val.IsNil() {
			if !val.CanSet() {
				return fmt.Errorf("cannot allocate %s: value is not settable", val.Type())
			}
			val.Set(reflect.New(val.Type().Elem()))
		}
		if typeName, _ := enc["Type"].(string); typeName != "" {
			typ := nodeByName[typeName]
			if typ == nil {
				return fmt.Errorf("unknown type: %q", typeName)
			}
			// The named node must fit where it is being stored, such as an
			// interface field implemented by a pointer to the node type.
			node := reflect.New(typ)
			if !val.CanSet() || !node.Type().AssignableTo(val.Type()) {
				return fmt.Errorf("cannot use type %q as %s", typeName, val.Type())
			}
			val.Set(node)
		}
		for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface {
			if val.IsNil() {
				// An interface field without a "Type" key: there is no
				// way to know which concrete node to allocate.
				return fmt.Errorf("missing type for %s", val.Type())
			}
			val = val.Elem()
		}
		if val.Kind() != reflect.Struct {
			return fmt.Errorf("cannot decode object into %s", val.Type())
		}
		for name, fv := range enc {
			switch name {
			case "Type", "Pos", "End":
				// Type is already used above. Pos and End came from method calls.
				continue
			}
			fval := val.FieldByName(name)
			if !fval.IsValid() || !fval.CanSet() {
				return fmt.Errorf("unknown field for %s: %q", val.Type(), name)
			}
			if fval.Type() == posType {
				// Validate the shape up front; decodePos itself uses
				// unchecked type assertions.
				posEnc, ok := fv.(map[string]interface{})
				if !ok {
					return fmt.Errorf("invalid position for %s.%s: %T", val.Type(), name, fv)
				}
				for _, key := range [...]string{"Offset", "Line", "Col"} {
					if _, ok := posEnc[key].(float64); !ok {
						return fmt.Errorf("invalid position for %s.%s: bad or missing %q", val.Type(), name, key)
					}
				}
				decodePos(fval, posEnc)
				continue
			}
			if err := decode(fval, fv); err != nil {
				return err
			}
		}
	case []interface{}:
		if val.Kind() != reflect.Slice || !val.CanSet() {
			return fmt.Errorf("cannot decode array into %s", val.Type())
		}
		for _, encElem := range enc {
			elem := reflect.New(val.Type().Elem()).Elem()
			if err := decode(elem, encElem); err != nil {
				return err
			}
			val.Set(reflect.Append(val, elem))
		}
	case float64:
		// Tokens and thus operators are uint32, but encoding/json defaults to float64.
		// TODO: reject invalid operators.
		switch val.Kind() {
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		default:
			return fmt.Errorf("cannot decode number into %s", val.Type())
		}
		// Reject negative, fractional and out-of-range numbers instead of
		// silently truncating them.
		u := uint64(enc)
		if enc < 0 || float64(u) != enc || val.OverflowUint(u) || !val.CanSet() {
			return fmt.Errorf("invalid value for %s: %v", val.Type(), enc)
		}
		val.SetUint(u)
	default:
		if enc == nil {
			return nil
		}
		ev := reflect.ValueOf(enc)
		if !val.CanSet() || !ev.Type().AssignableTo(val.Type()) {
			return fmt.Errorf("cannot decode %T into %s", enc, val.Type())
		}
		val.Set(ev)
	}
	return nil
}
70 changes: 70 additions & 0 deletions cmd/shfmt/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2017, Daniel Martí <[email protected]>
// See LICENSE for licensing information

package main

import (
"bytes"
"os"
"strings"
"testing"

qt "github.com/frankban/quicktest"

"mvdan.cc/sh/v3/syntax"
)

// TestRoundtripJSON checks that a shell script survives a full round trip
// through the JSON encoding: shell source -> syntax tree -> JSON -> syntax
// tree -> shell source, and then JSON once more.
//
// The inputs are testdata/json.sh and the golden file testdata/json.json.
// When *update is set, the golden file is rewritten with the freshly encoded
// JSON, and the remaining steps are checked against that new content.
func TestRoundtripJSON(t *testing.T) {
	t.Parallel()

	// Read testdata files.
	inputShell, err := os.ReadFile("testdata/json.sh")
	qt.Assert(t, err, qt.IsNil)
	inputJSON, err := os.ReadFile("testdata/json.json")
	if !*update { // allow it to not exist
		qt.Assert(t, err, qt.IsNil)
	}
	sb := new(strings.Builder)

	// Parse the shell source and check that it is well formatted.
	parser := syntax.NewParser(syntax.KeepComments(true))
	node, err := parser.Parse(bytes.NewReader(inputShell), "")
	qt.Assert(t, err, qt.IsNil)

	printer := syntax.NewPrinter()
	sb.Reset()
	err = printer.Print(sb, node)
	qt.Assert(t, err, qt.IsNil)
	qt.Assert(t, sb.String(), qt.Equals, string(inputShell))

	// Validate writing the pretty JSON.
	sb.Reset()
	err = writeJSON(sb, node, true)
	qt.Assert(t, err, qt.IsNil)
	got := sb.String()
	if *update {
		err := os.WriteFile("testdata/json.json", []byte(got), 0o666)
		qt.Assert(t, err, qt.IsNil)
		// Continue with the content that was just written. The bytes read
		// at the top are stale (or empty if the file did not exist yet),
		// which would make the round-trip checks below fail while updating.
		inputJSON = []byte(got)
	} else {
		qt.Assert(t, got, qt.Equals, string(inputJSON))
	}

	// Ensure we don't use the originally parsed node again.
	node = nil

	// Validate reading the pretty JSON and check that it formats the same.
	node2, err := readJSON(bytes.NewReader(inputJSON))
	qt.Assert(t, err, qt.IsNil)

	sb.Reset()
	err = printer.Print(sb, node2)
	qt.Assert(t, err, qt.IsNil)
	qt.Assert(t, sb.String(), qt.Equals, string(inputShell))

	// Validate that emitting the JSON again produces the same result.
	sb.Reset()
	err = writeJSON(sb, node2, true)
	qt.Assert(t, err, qt.IsNil)
	got = sb.String()
	qt.Assert(t, got, qt.Equals, string(inputJSON))
}
Loading