diff options
-rw-r--r-- | cmap/map.go (renamed from map.go) | 2 | ||||
-rw-r--r-- | cmap/map_test.go (renamed from map_test.go) | 6 | ||||
-rw-r--r-- | scfg/reader.go | 154 | ||||
-rw-r--r-- | scfg/reader_test.go | 174 | ||||
-rw-r--r-- | scfg/scfg.go | 55 | ||||
-rw-r--r-- | scfg/struct.go | 79 | ||||
-rw-r--r-- | scfg/unmarshal.go | 349 | ||||
-rw-r--r-- | scfg/unmarshal_test.go | 251 | ||||
-rw-r--r-- | scfg/writer.go | 107 | ||||
-rw-r--r-- | scfg/writer_test.go | 158 |
10 files changed, 1331 insertions, 4 deletions
@@ -5,7 +5,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package lindenii_common +package cmap import ( "sync" diff --git a/map_test.go b/cmap/map_test.go index b131207..c407c18 100644 --- a/map_test.go +++ b/cmap/map_test.go @@ -4,7 +4,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package lindenii_common_test +package cmap_test import ( "math/rand" @@ -12,13 +12,13 @@ import ( "sync" "testing" - lindenii_common "go.lindenii.runxiyu.org/lindenii-common" + "go.lindenii.runxiyu.org/lindenii-common/cmap" ) func TestConcurrentRange(t *testing.T) { const mapSize = 1 << 10 - m := new(lindenii_common.Map[int64, int64]) + m := new(cmap.Map[int64, int64]) for n := int64(1); n <= mapSize; n++ { m.Store(n, int64(n)) } diff --git a/scfg/reader.go b/scfg/reader.go new file mode 100644 index 0000000..53a0cc4 --- /dev/null +++ b/scfg/reader.go @@ -0,0 +1,154 @@ +package scfg + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" +) + +// This limits the max block nesting depth to prevent stack overflows. +const maxNestingDepth = 1000 + +// Load loads a configuration file. +func Load(path string) (Block, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + return Read(f) +} + +// Read parses a configuration file from an io.Reader. +func Read(r io.Reader) (Block, error) { + scanner := bufio.NewScanner(r) + + dec := decoder{scanner: scanner} + block, closingBrace, err := dec.readBlock() + if err != nil { + return nil, err + } else if closingBrace { + return nil, fmt.Errorf("line %v: unexpected '}'", dec.lineno) + } + + return block, scanner.Err() +} + +type decoder struct { + scanner *bufio.Scanner + lineno int + blockDepth int +} + +// readBlock reads a block. closingBrace is true if parsing stopped on '}' +// (otherwise, it stopped on Scanner.Scan). +func (dec *decoder) readBlock() (block Block, closingBrace bool, err error) { + dec.blockDepth++ + defer func() { + dec.blockDepth-- + }() + + if dec.blockDepth >= maxNestingDepth { + return nil, false, fmt.Errorf("exceeded max block depth") + } + + for dec.scanner.Scan() { + dec.lineno++ + + l := dec.scanner.Text() + words, err := splitWords(l) + if err != nil { + return nil, false, fmt.Errorf("line %v: %v", dec.lineno, err) + } else if len(words) == 0 { + continue + } + + if len(words) == 1 && l[len(l)-1] == '}' { + closingBrace = true + break + } + + var d *Directive + if words[len(words)-1] == "{" && l[len(l)-1] == '{' { + words = words[:len(words)-1] + + var name string + params := words + if len(words) > 0 { + name, params = words[0], words[1:] + } + + startLineno := dec.lineno + childBlock, childClosingBrace, err := dec.readBlock() + if err != nil { + return nil, false, err + } else if !childClosingBrace { + return nil, false, fmt.Errorf("line %v: unterminated block", startLineno) + } + + // Allows callers to tell apart "no block" and "empty block" + if childBlock == nil { + childBlock = Block{} + } + + d = &Directive{Name: name, Params: params, Children: childBlock, lineno: dec.lineno} + } else { + d = &Directive{Name: words[0], Params: words[1:], lineno: dec.lineno} + } + block = append(block, d) + } + + return block, closingBrace, nil +} + +func splitWords(l string) ([]string, error) { + var ( + words []string + sb strings.Builder + escape bool + quote rune + wantWSP bool + ) + for _, ch := range l { + switch { + case escape: + sb.WriteRune(ch) + escape = false + case wantWSP && (ch != ' ' && ch != '\t'): + return words, fmt.Errorf("atom not allowed after quoted string") + case ch == '\\': + escape = true + case quote != 0 && ch == quote: + quote = 0 + wantWSP = true + if sb.Len() == 0 { + words = append(words, "") + } + case quote == 0 && len(words) == 0 && sb.Len() == 0 && ch == '#': + return nil, nil + case quote == 0 && (ch == '\'' || ch == '"'): + if sb.Len() > 0 { + return words, fmt.Errorf("quoted string not allowed after atom") + } + quote = ch + case quote == 0 && (ch == ' ' || ch == '\t'): + if sb.Len() > 0 { + words = append(words, sb.String()) + } + sb.Reset() + wantWSP = false + default: + sb.WriteRune(ch) + } + } + if quote != 0 { + return words, fmt.Errorf("unterminated quoted string") + } + if sb.Len() > 0 { + words = append(words, sb.String()) + } + return words, nil +} diff --git a/scfg/reader_test.go b/scfg/reader_test.go new file mode 100644 index 0000000..5a4af28 --- /dev/null +++ b/scfg/reader_test.go @@ -0,0 +1,174 @@ +package scfg + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +var readTests = []struct { + name string + src string + want Block +}{ + { + name: "flat", + src: `dir1 param1 param2 "" param3 +dir2 +dir3 param1 + +# comment +dir4 "param 1" 'param 2'`, + want: Block{ + {Name: "dir1", Params: []string{"param1", "param2", "", "param3"}}, + {Name: "dir2", Params: []string{}}, + {Name: "dir3", Params: []string{"param1"}}, + {Name: "dir4", Params: []string{"param 1", "param 2"}}, + }, + }, + { + name: "simpleBlocks", + src: `block1 { + dir1 param1 param2 + dir2 param1 +} + +block2 { +} + +block3 { + # comment +} + +block4 param1 "param2" { + dir1 +}`, + want: Block{ + { + Name: "block1", + Params: []string{}, + Children: Block{ + {Name: "dir1", Params: []string{"param1", "param2"}}, + {Name: "dir2", Params: []string{"param1"}}, + }, + }, + {Name: "block2", Params: []string{}, Children: Block{}}, + {Name: "block3", Params: []string{}, Children: Block{}}, + { + Name: "block4", + Params: []string{"param1", "param2"}, + Children: Block{ + {Name: "dir1", Params: []string{}}, + }, + }, + }, + }, + { + name: "nested", + src: `block1 { + block2 { + dir1 param1 + } + + block3 { + } +} + +block4 { + block5 { + block6 param1 { + dir1 + } + } + + dir1 +}`, + want: Block{ + { + Name: "block1", + Params: []string{}, + Children: Block{ + { + Name: "block2", + Params: []string{}, + Children: Block{ + {Name: "dir1", Params: []string{"param1"}}, + }, + }, + { + Name: "block3", + Params: []string{}, + Children: Block{}, + }, + }, + }, + { + Name: "block4", + Params: []string{}, + Children: Block{ + { + Name: "block5", + Params: []string{}, + Children: Block{{ + Name: "block6", + Params: []string{"param1"}, + Children: Block{{ + Name: "dir1", + Params: []string{}, + }}, + }}, + }, + { + Name: "dir1", + Params: []string{}, + }, + }, + }, + }, + }, + { + name: "quotes", + src: `"a \b ' \" c" 'd \e \' " f' a\"b`, + want: Block{ + {Name: "a b ' \" c", Params: []string{"d e ' \" f", "a\"b"}}, + }, + }, + { + name: "quotes-2", + src: `dir arg1 "arg2" ` + `\"\"`, + want: Block{ + {Name: "dir", Params: []string{"arg1", "arg2", "\"\""}}, + }, + }, + { + name: "quotes-3", + src: `dir arg1 "\"\"\"\"" arg3`, + want: Block{ + {Name: "dir", Params: []string{"arg1", `"` + "\"\"" + `"`, "arg3"}}, + }, + }, +} + +func TestRead(t *testing.T) { + for _, test := range readTests { + t.Run(test.name, func(t *testing.T) { + r := strings.NewReader(test.src) + got, err := Read(r) + if err != nil { + t.Fatalf("Read() = %v", err) + } + stripLineno(got) + if !reflect.DeepEqual(got, test.want) { + t.Error(fmt.Sprintf("Read() = %#v but want %#v", got, test.want)) + } + }) + } +} + +func stripLineno(block Block) { + for _, dir := range block { + dir.lineno = 0 + stripLineno(dir.Children) + } +} diff --git a/scfg/scfg.go b/scfg/scfg.go new file mode 100644 index 0000000..6f20b49 --- /dev/null +++ b/scfg/scfg.go @@ -0,0 +1,55 @@ +// Package scfg parses and formats configuration files. +package scfg + +import ( + "fmt" +) + +// Block is a list of directives. +type Block []*Directive + +// GetAll returns a list of directives with the provided name. +func (blk Block) GetAll(name string) []*Directive { + l := make([]*Directive, 0, len(blk)) + for _, child := range blk { + if child.Name == name { + l = append(l, child) + } + } + return l +} + +// Get returns the first directive with the provided name. +func (blk Block) Get(name string) *Directive { + for _, child := range blk { + if child.Name == name { + return child + } + } + return nil +} + +// Directive is a configuration directive. +type Directive struct { + Name string + Params []string + + Children Block + + lineno int +} + +// ParseParams extracts parameters from the directive. It errors out if the +// user hasn't provided enough parameters. +func (d *Directive) ParseParams(params ...*string) error { + if len(d.Params) < len(params) { + return fmt.Errorf("directive %q: want %v params, got %v", d.Name, len(params), len(d.Params)) + } + for i, ptr := range params { + if ptr == nil { + continue + } + *ptr = d.Params[i] + } + return nil +} diff --git a/scfg/struct.go b/scfg/struct.go new file mode 100644 index 0000000..8b39e34 --- /dev/null +++ b/scfg/struct.go @@ -0,0 +1,79 @@ +package scfg + +import ( + "fmt" + "reflect" + "strings" + "sync" +) + +// structInfo contains scfg metadata for structs. +type structInfo struct { + param int // index of field storing parameters + children map[string]int // indices of fields storing child directives +} + +var ( + structCacheMutex sync.Mutex + structCache = make(map[reflect.Type]*structInfo) +) + +func getStructInfo(t reflect.Type) (*structInfo, error) { + structCacheMutex.Lock() + defer structCacheMutex.Unlock() + + if info := structCache[t]; info != nil { + return info, nil + } + + info := &structInfo{ + param: -1, + children: make(map[string]int), + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Anonymous { + return nil, fmt.Errorf("scfg: anonymous struct fields are not supported") + } else if !f.IsExported() { + continue + } + + tag := f.Tag.Get("scfg") + parts := strings.Split(tag, ",") + k, options := parts[0], parts[1:] + if k == "-" { + continue + } else if k == "" { + k = f.Name + } + + isParam := false + for _, opt := range options { + switch opt { + case "param": + isParam = true + default: + return nil, fmt.Errorf("scfg: invalid option %q in struct tag", opt) + } + } + + if isParam { + if info.param >= 0 { + return nil, fmt.Errorf("scfg: param option specified multiple times in struct tag in %v", t) + } + if parts[0] != "" { + return nil, fmt.Errorf("scfg: name must be empty when param option is specified in struct tag in %v", t) + } + info.param = i + } else { + if _, ok := info.children[k]; ok { + return nil, fmt.Errorf("scfg: key %q specified multiple times in struct tag in %v", k, t) + } + info.children[k] = i + } + } + + structCache[t] = info + return info, nil +} diff --git a/scfg/unmarshal.go b/scfg/unmarshal.go new file mode 100644 index 0000000..a99910f --- /dev/null +++ b/scfg/unmarshal.go @@ -0,0 +1,349 @@ +package scfg + +import ( + "encoding" + "fmt" + "io" + "reflect" + "strconv" +) + +// Decoder reads and decodes an scfg document from an input stream. +type Decoder struct { + r io.Reader +} + +// NewDecoder returns a new decoder which reads from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r} +} + +// Decode reads scfg document from the input and stores it in the value pointed +// to by v. +// +// If v is nil or not a pointer, Decode returns an error. +// +// Blocks can be unmarshaled to: +// +// - Maps. Each directive is unmarshaled into a map entry. The map key must +// be a string. +// - Structs. Each directive is unmarshaled into a struct field. +// +// Duplicate directives are not allowed, unless the struct field or map value +// is a slice of values representing a directive: structs or maps. +// +// Directives can be unmarshaled to: +// +// - Maps. The children block is unmarshaled into the map. Parameters are not +// allowed. +// - Structs. The children block is unmarshaled into the struct. Parameters +// are allowed if one of the struct fields contains the "param" option in +// its tag. +// - Slices. Parameters are unmarshaled into the slice. Children blocks are +// not allowed. +// - Arrays. Parameters are unmarshaled into the array. The number of +// parameters must match exactly the length of the array. Children blocks +// are not allowed. +// - Strings, booleans, integers, floating-point values, values implementing +// encoding.TextUnmarshaler. Only a single parameter is allowed and is +// unmarshaled into the value. Children blocks are not allowed. +// +// The decoding of each struct field can be customized by the format string +// stored under the "scfg" key in the struct field's tag. The tag contains the +// name of the field possibly followed by a comma-separated list of options. +// The name may be empty in order to specify options without overriding the +// default field name. As a special case, if the field name is "-", the field +// is ignored. The "param" option specifies that directive parameters are +// stored in this field (the name must be empty). +func (dec *Decoder) Decode(v interface{}) error { + block, err := Read(dec.r) + if err != nil { + return err + } + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("scfg: invalid value for unmarshaling") + } + + return unmarshalBlock(block, rv) +} + +func unmarshalBlock(block Block, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + dirsByName := make(map[string][]*Directive, len(block)) + for _, dir := range block { + dirsByName[dir.Name] = append(dirsByName[dir.Name], dir) + } + + switch v.Kind() { + case reflect.Map: + if t.Key().Kind() != reflect.String { + return fmt.Errorf("scfg: map key type must be string") + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } else if v.Len() > 0 { + clearMap(v) + } + + for name, dirs := range dirsByName { + mv := reflect.New(t.Elem()).Elem() + if err := unmarshalDirectiveList(dirs, mv); err != nil { + return err + } + v.SetMapIndex(reflect.ValueOf(name), mv) + } + case reflect.Struct: + si, err := getStructInfo(t) + if err != nil { + return err + } + + for name, dirs := range dirsByName { + fieldIndex, ok := si.children[name] + if !ok { + return newUnmarshalDirectiveError(dirs[0], "unknown directive") + } + fv := v.Field(fieldIndex) + if err := unmarshalDirectiveList(dirs, fv); err != nil { + return err + } + } + default: + return fmt.Errorf("scfg: unsupported type for unmarshaling blocks: %v", t) + } + + return nil +} + +func unmarshalDirectiveList(dirs []*Directive, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + if v.Kind() != reflect.Slice || !isDirectiveType(t.Elem()) { + if len(dirs) > 1 { + return newUnmarshalDirectiveError(dirs[1], "directive must not be specified more than once") + } + return unmarshalDirective(dirs[0], v) + } + + sv := reflect.MakeSlice(t, len(dirs), len(dirs)) + for i, dir := range dirs { + if err := unmarshalDirective(dir, sv.Index(i)); err != nil { + return err + } + } + v.Set(sv) + return nil +} + +// isDirectiveType checks whether a type can only be unmarshaled as a +// directive, not as a parameter. Accepting too many types here would result in +// ambiguities, see: +// https://lists.sr.ht/~emersion/public-inbox/%3C20230629132458.152205-1-contact%40emersion.fr%3E#%3Ch4Y2peS_YBqY3ar4XlmPDPiNBFpYGns3EBYUx3_6zWEhV2o8_-fBQveRujGADWYhVVCucHBEryFGoPtpC3d3mQ-x10pWnFogfprbQTSvtxc=@emersion.fr%3E +func isDirectiveType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + textUnmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + if reflect.PtrTo(t).Implements(textUnmarshalerType) { + return false + } + + switch t.Kind() { + case reflect.Struct, reflect.Map: + return true + default: + return false + } +} + +func unmarshalDirective(dir *Directive, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + if v.CanAddr() { + if _, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + if len(dir.Children) != 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero children") + } + return unmarshalParamList(dir, v) + } + } + + switch v.Kind() { + case reflect.Map: + if len(dir.Params) > 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero parameters") + } + if err := unmarshalBlock(dir.Children, v); err != nil { + return err + } + case reflect.Struct: + si, err := getStructInfo(t) + if err != nil { + return err + } + + if si.param >= 0 { + if err := unmarshalParamList(dir, v.Field(si.param)); err != nil { + return err + } + } else { + if len(dir.Params) > 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero parameters") + } + } + + if err := unmarshalBlock(dir.Children, v); err != nil { + return err + } + default: + if len(dir.Children) != 0 { + return newUnmarshalDirectiveError(dir, "directive requires zero children") + } + if err := unmarshalParamList(dir, v); err != nil { + return err + } + } + return nil +} + +func unmarshalParamList(dir *Directive, v reflect.Value) error { + switch v.Kind() { + case reflect.Slice: + t := v.Type() + sv := reflect.MakeSlice(t, len(dir.Params), len(dir.Params)) + for i, param := range dir.Params { + if err := unmarshalParam(param, sv.Index(i)); err != nil { + return newUnmarshalParamError(dir, i, err) + } + } + v.Set(sv) + case reflect.Array: + if len(dir.Params) != v.Len() { + return newUnmarshalDirectiveError(dir, fmt.Sprintf("directive requires exactly %v parameters", v.Len())) + } + for i, param := range dir.Params { + if err := unmarshalParam(param, v.Index(i)); err != nil { + return newUnmarshalParamError(dir, i, err) + } + } + default: + if len(dir.Params) != 1 { + return newUnmarshalDirectiveError(dir, "directive requires exactly one parameter") + } + if err := unmarshalParam(dir.Params[0], v); err != nil { + return newUnmarshalParamError(dir, 0, err) + } + } + + return nil +} + +func unmarshalParam(param string, v reflect.Value) error { + v = unwrapPointers(v) + t := v.Type() + + // TODO: improve our logic following: + // https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/encoding/json/decode.go;drc=b9b8cecbfc72168ca03ad586cc2ed52b0e8db409;l=421 + if v.CanAddr() { + if v, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok { + return v.UnmarshalText([]byte(param)) + } + } + + switch v.Kind() { + case reflect.String: + v.Set(reflect.ValueOf(param)) + case reflect.Bool: + switch param { + case "true": + v.Set(reflect.ValueOf(true)) + case "false": + v.Set(reflect.ValueOf(false)) + default: + return fmt.Errorf("invalid bool parameter %q", param) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(param, 10, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(i).Convert(t)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u, err := strconv.ParseUint(param, 10, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(u).Convert(t)) + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(param, t.Bits()) + if err != nil { + return fmt.Errorf("invalid %v parameter: %v", t, err) + } + v.Set(reflect.ValueOf(f).Convert(t)) + default: + return fmt.Errorf("unsupported type for unmarshaling parameter: %v", t) + } + + return nil +} + +func unwrapPointers(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func clearMap(v reflect.Value) { + for _, k := range v.MapKeys() { + v.SetMapIndex(k, reflect.Value{}) + } +} + +type unmarshalDirectiveError struct { + lineno int + name string + msg string +} + +func newUnmarshalDirectiveError(dir *Directive, msg string) *unmarshalDirectiveError { + return &unmarshalDirectiveError{ + name: dir.Name, + lineno: dir.lineno, + msg: msg, + } +} + +func (err *unmarshalDirectiveError) Error() string { + return fmt.Sprintf("line %v, directive %q: %v", err.lineno, err.name, err.msg) +} + +type unmarshalParamError struct { + lineno int + directive string + paramIndex int + err error +} + +func newUnmarshalParamError(dir *Directive, paramIndex int, err error) *unmarshalParamError { + return &unmarshalParamError{ + directive: dir.Name, + lineno: dir.lineno, + paramIndex: paramIndex, + err: err, + } +} + +func (err *unmarshalParamError) Error() string { + return fmt.Sprintf("line %v, directive %q, parameter %v: %v", err.lineno, err.directive, err.paramIndex+1, err.err) +} diff --git a/scfg/unmarshal_test.go b/scfg/unmarshal_test.go new file mode 100644 index 0000000..5b8df32 --- /dev/null +++ b/scfg/unmarshal_test.go @@ -0,0 +1,251 @@ +package scfg_test + +import ( + "fmt" + "log" + "reflect" + "strings" + "testing" + + "go.lindenii.runxiyu.org/lindenii-common/scfg" +) + +func ExampleDecoder() { + var data struct { + Foo int `scfg:"foo"` + Bar struct { + Param string `scfg:",param"` + Baz string `scfg:"baz"` + } `scfg:"bar"` + } + + raw := `foo 42 +bar asdf { + baz hello +} +` + + r := strings.NewReader(raw) + if err := scfg.NewDecoder(r).Decode(&data); err != nil { + log.Fatal(err) + } + + fmt.Printf("Foo = %v\n", data.Foo) + fmt.Printf("Bar.Param = %v\n", data.Bar.Param) + fmt.Printf("Bar.Baz = %v\n", data.Bar.Baz) + + // Output: + // Foo = 42 + // Bar.Param = asdf + // Bar.Baz = hello +} + +type nestedStructInner struct { + Bar string `scfg:"bar"` +} + +type structParams struct { + Params []string `scfg:",param"` + Bar string +} + +type textUnmarshaler struct { + text string +} + +func (tu *textUnmarshaler) UnmarshalText(text []byte) error { + tu.text = string(text) + return nil +} + +type textUnmarshalerParams struct { + Params []textUnmarshaler `scfg:",param"` +} + +var barStr = "bar" + +var unmarshalTests = []struct { + name string + raw string + want interface{} +}{ + { + name: "stringMap", + raw: `hello world +foo bar`, + want: map[string]string{ + "hello": "world", + "foo": "bar", + }, + }, + { + name: "simpleStruct", + raw: `MyString asdf +MyBool true +MyInt -42 +MyUint 42 +MyFloat 3.14`, + want: struct { + MyString string + MyBool bool + MyInt int + MyUint uint + MyFloat float32 + }{ + MyString: "asdf", + MyBool: true, + MyInt: -42, + MyUint: 42, + MyFloat: 3.14, + }, + }, + { + name: "simpleStructTag", + raw: `foo bar`, + want: struct { + Foo string `scfg:"foo"` + }{ + Foo: "bar", + }, + }, + { + name: "sliceParams", + raw: `Foo a s d f`, + want: struct { + Foo []string + }{ + Foo: []string{"a", "s", "d", "f"}, + }, + }, + { + name: "arrayParams", + raw: `Foo a s d f`, + want: struct { + Foo [4]string + }{ + Foo: [4]string{"a", "s", "d", "f"}, + }, + }, + { + name: "pointers", + raw: `Foo bar`, + want: struct { + Foo *string + }{ + Foo: &barStr, + }, + }, + { + name: "nestedMap", + raw: `foo { + bar baz +}`, + want: struct { + Foo map[string]string `scfg:"foo"` + }{ + Foo: map[string]string{"bar": "baz"}, + }, + }, + { + name: "nestedStruct", + raw: `foo { + bar baz +}`, + want: struct { + Foo nestedStructInner `scfg:"foo"` + }{ + Foo: nestedStructInner{ + Bar: "baz", + }, + }, + }, + { + name: "structParams", + raw: `Foo param1 param2 { + Bar baz +}`, + want: struct { + Foo structParams + }{ + Foo: structParams{ + Params: []string{"param1", "param2"}, + Bar: "baz", + }, + }, + }, + { + name: "textUnmarshaler", + raw: `Foo param1 +Bar param2 +Baz param3`, + want: struct { + Foo []textUnmarshaler + Bar *textUnmarshaler + Baz textUnmarshalerParams + }{ + Foo: []textUnmarshaler{{"param1"}}, + Bar: &textUnmarshaler{"param2"}, + Baz: textUnmarshalerParams{ + Params: []textUnmarshaler{{"param3"}}, + }, + }, + }, + { + name: "directiveStructSlice", + raw: `Foo param1 param2 { + Bar baz +} +Foo param3 param4`, + want: struct { + Foo []structParams + }{ + Foo: []structParams{ + { + Params: []string{"param1", "param2"}, + Bar: "baz", + }, + { + Params: []string{"param3", "param4"}, + }, + }, + }, + }, + { + name: "directiveMapSlice", + raw: `Foo { + key1 param1 +} +Foo { + key2 param2 +}`, + want: struct { + Foo []map[string]string + }{ + Foo: []map[string]string{ + {"key1": "param1"}, + {"key2": "param2"}, + }, + }, + }, +} + +func TestUnmarshal(t *testing.T) { + for _, tc := range unmarshalTests { + tc := tc // capture variable + t.Run(tc.name, func(t *testing.T) { + testUnmarshal(t, tc.raw, tc.want) + }) + } +} + +func testUnmarshal(t *testing.T, raw string, want interface{}) { + out := reflect.New(reflect.TypeOf(want)) + r := strings.NewReader(raw) + if err := scfg.NewDecoder(r).Decode(out.Interface()); err != nil { + t.Fatalf("Decode() = %v", err) + } + got := out.Elem().Interface() + if !reflect.DeepEqual(got, want) { + t.Errorf("Decode() = \n%#v\n but want \n%#v", got, want) + } +} diff --git a/scfg/writer.go b/scfg/writer.go new file mode 100644 index 0000000..97148a3 --- /dev/null +++ b/scfg/writer.go @@ -0,0 +1,107 @@ +package scfg + +import ( + "errors" + "io" + "strings" +) + +var ( + errDirEmptyName = errors.New("scfg: directive with empty name") +) + +// Write writes a parsed configuration to the provided io.Writer. +func Write(w io.Writer, blk Block) error { + enc := newEncoder(w) + err := enc.encodeBlock(blk) + return err +} + +// encoder write SCFG directives to an output stream. +type encoder struct { + w io.Writer + lvl int + err error +} + +// newEncoder returns a new encoder that writes to w. +func newEncoder(w io.Writer) *encoder { + return &encoder{w: w} +} + +func (enc *encoder) push() { + enc.lvl++ +} + +func (enc *encoder) pop() { + enc.lvl-- +} + +func (enc *encoder) writeIndent() { + for i := 0; i < enc.lvl; i++ { + enc.write([]byte("\t")) + } +} + +func (enc *encoder) write(p []byte) { + if enc.err != nil { + return + } + _, enc.err = enc.w.Write(p) +} + +func (enc *encoder) encodeBlock(blk Block) error { + for _, dir := range blk { + enc.encodeDir(*dir) + } + return enc.err +} + +func (enc *encoder) encodeDir(dir Directive) error { + if enc.err != nil { + return enc.err + } + + if dir.Name == "" { + enc.err = errDirEmptyName + return enc.err + } + + enc.writeIndent() + enc.write([]byte(maybeQuote(dir.Name))) + for _, p := range dir.Params { + enc.write([]byte(" ")) + enc.write([]byte(maybeQuote(p))) + } + + if len(dir.Children) > 0 { + enc.write([]byte(" {\n")) + enc.push() + enc.encodeBlock(dir.Children) + enc.pop() + + enc.writeIndent() + enc.write([]byte("}")) + } + enc.write([]byte("\n")) + + return enc.err +} + +const specialChars = "\"\\\r\n'{} \t" + +func maybeQuote(s string) string { + if s == "" || strings.ContainsAny(s, specialChars) { + var sb strings.Builder + sb.WriteByte('"') + for _, ch := range s { + if strings.ContainsRune(`"\`, ch) { + sb.WriteByte('\\') + } + sb.WriteRune(ch) + } + sb.WriteByte('"') + return sb.String() + } + return s +} diff --git a/scfg/writer_test.go b/scfg/writer_test.go new file mode 100644 index 0000000..b27d513 --- /dev/null +++ b/scfg/writer_test.go @@ -0,0 +1,158 @@ +package scfg + +import ( + "bytes" + "testing" +) + +func TestWrite(t *testing.T) { + for _, tc := range []struct { + src Block + want string + err error + }{ + { + src: Block{}, + want: "", + }, + { + src: Block{{ + Name: "dir", + Children: Block{{ + Name: "blk1", + Params: []string{"p1", `"p2"`}, + Children: Block{ + { + Name: "sub1", + Params: []string{"arg11", "arg12"}, + }, + { + Name: "sub2", + Params: []string{"arg21", "arg22"}, + }, + { + Name: "sub3", + Params: []string{"arg31", "arg32"}, + Children: Block{ + { + Name: "sub-sub1", + }, + { + Name: "sub-sub2", + Params: []string{"arg321", "arg322"}, + }, + }, + }, + }, + }}, + }}, + want: `dir { + blk1 p1 "\"p2\"" { + sub1 arg11 arg12 + sub2 arg21 arg22 + sub3 arg31 arg32 { + sub-sub1 + sub-sub2 arg321 arg322 + } + } +} +`, + }, + { + src: Block{{Name: "dir1"}}, + want: "dir1\n", + }, + { + src: Block{{Name: "dir\"1"}}, + want: "\"dir\\\"1\"\n", + }, + { + src: Block{{Name: "dir'1"}}, + want: "\"dir'1\"\n", + }, + { + src: Block{{Name: "dir:}"}}, + want: "\"dir:}\"\n", + }, + { + src: Block{{Name: "dir:{"}}, + want: "\"dir:{\"\n", + }, + { + src: Block{{Name: "dir\t1"}}, + want: `"dir` + "\t" + `1"` + "\n", + }, + { + src: Block{{Name: "dir 1"}}, + want: "\"dir 1\"\n", + }, + { + src: Block{{Name: "dir1", Params: []string{"arg1", "arg2", `"arg3"`}}}, + want: "dir1 arg1 arg2 " + `"\"arg3\""` + "\n", + }, + { + src: Block{{Name: "dir1", Params: []string{"arg1", "arg 2", "arg'3"}}}, + want: "dir1 arg1 \"arg 2\" \"arg'3\"\n", + }, + { + src: Block{{Name: "dir1", Params: []string{"arg1", "", "arg3"}}}, + want: "dir1 arg1 \"\" arg3\n", + }, + { + src: Block{{Name: "dir1", Params: []string{"arg1", `"` + "\"\"" + `"`, "arg3"}}}, + want: "dir1 arg1 " + `"\"\"\"\""` + " arg3\n", + }, + { + src: Block{{ + Name: "dir1", + Children: Block{ + {Name: "sub1"}, + {Name: "sub2", Params: []string{"arg1", "arg2"}}, + }, + }}, + want: `dir1 { + sub1 + sub2 arg1 arg2 +} +`, + }, + { + src: Block{{Name: ""}}, + err: errDirEmptyName, + }, + { + src: Block{{ + Name: "dir", + Children: Block{ + {Name: "sub1"}, + {Name: "", Children: Block{{Name: "sub21"}}}, + }, + }}, + err: errDirEmptyName, + }, + } { + t.Run("", func(t *testing.T) { + var buf bytes.Buffer + err := Write(&buf, tc.src) + switch { + case err != nil && tc.err != nil: + if got, want := err.Error(), tc.err.Error(); got != want { + t.Fatalf("invalid error:\ngot= %q\nwant=%q", got, want) + } + return + case err == nil && tc.err != nil: + t.Fatalf("got err=nil, want=%q", tc.err.Error()) + case err != nil && tc.err == nil: + t.Fatalf("could not marshal: %+v", err) + case err == nil && tc.err == nil: + // ok. + } + if got, want := buf.String(), tc.want; got != want { + t.Fatalf( + "invalid marshal representation:\ngot:\n%s\nwant:\n%s\n---", + got, want, + ) + } + }) + } +} |