chore: structure unifies the way to handle top-level and sub structs

This commit is contained in:
wwqgtxx
2026-02-10 11:42:44 +08:00
parent c3399fd346
commit 9fda032a28
4 changed files with 81 additions and 65 deletions

View File

@@ -14,7 +14,7 @@ type RealityOptions struct {
PublicKey string `proxy:"public-key"`
ShortID string `proxy:"short-id"`
SupportX25519MLKEM768 bool `proxy:"support-x25519mlkem768"`
SupportX25519MLKEM768 bool `proxy:"support-x25519mlkem768,omitempty"`
}
func (o RealityOptions) Parse() (*tlsC.RealityConfig, error) {

View File

@@ -77,8 +77,8 @@ type WireGuardOption struct {
}
type WireGuardPeerOption struct {
Server string `proxy:"server"`
Port int `proxy:"port"`
Server string `proxy:"server,omitempty"`
Port int `proxy:"port,omitempty"`
PublicKey string `proxy:"public-key,omitempty"`
PreSharedKey string `proxy:"pre-shared-key,omitempty"`
Reserved []uint8 `proxy:"reserved,omitempty"`

View File

@@ -7,6 +7,7 @@ import (
"encoding/base64"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
)
@@ -38,58 +39,7 @@ func (d *Decoder) Decode(src map[string]any, dst any) error {
if reflect.TypeOf(dst).Kind() != reflect.Ptr {
return fmt.Errorf("decode must recive a ptr struct")
}
t := reflect.TypeOf(dst).Elem()
v := reflect.ValueOf(dst).Elem()
for idx := 0; idx < v.NumField(); idx++ {
field := t.Field(idx)
if field.Anonymous {
if err := d.decodeStruct(field.Name, src, v.Field(idx)); err != nil {
return err
}
continue
}
tag := field.Tag.Get(d.option.TagName)
key, omitKey, found := strings.Cut(tag, ",")
omitempty := found && omitKey == "omitempty"
// As a special case, if the field tag is "-", the field is always omitted.
// Note that a field with name "-" can still be generated using the tag "-,".
if key == "-" {
continue
}
value, ok := src[key]
if !ok {
if d.option.KeyReplacer != nil {
key = d.option.KeyReplacer.Replace(key)
}
for _strKey := range src {
strKey := _strKey
if d.option.KeyReplacer != nil {
strKey = d.option.KeyReplacer.Replace(strKey)
}
if strings.EqualFold(key, strKey) {
value = src[_strKey]
ok = true
break
}
}
}
if !ok || value == nil {
if omitempty {
continue
}
return fmt.Errorf("key '%s' missing", key)
}
err := d.decode(key, value, v.Field(idx))
if err != nil {
return err
}
}
return nil
return d.decode("", src, reflect.ValueOf(dst).Elem())
}
// isNil returns true if the input is nil or a typed nil pointer.
@@ -456,6 +406,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
dataValKeysUnused[dataValKey.Interface()] = struct{}{}
}
targetValKeysUnused := make(map[any]struct{})
errors := make([]string, 0)
// This slice will keep track of all the structs we'll be decoding.
@@ -479,10 +430,16 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
for i := 0; i < structType.NumField(); i++ {
fieldType := structType.Field(i)
fieldKind := fieldType.Type.Kind()
fieldVal := structVal.Field(i)
if fieldVal.Kind() == reflect.Ptr && fieldVal.Elem().Kind() == reflect.Struct {
// Handle embedded struct pointers as embedded structs.
fieldVal = fieldVal.Elem()
}
// If "squash" is specified in the tag, we squash the field down.
squash := false
squash := fieldVal.Kind() == reflect.Struct && fieldType.Anonymous
// We always parse the tags cause we're looking for other tags too
tagParts := strings.Split(fieldType.Tag.Get(d.option.TagName), ",")
for _, tag := range tagParts[1:] {
if tag == "squash" {
@@ -492,17 +449,17 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}
if squash {
if fieldKind != reflect.Struct {
if fieldVal.Kind() != reflect.Struct {
errors = append(errors,
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldKind).Error())
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()).Error())
} else {
structs = append(structs, structVal.FieldByName(fieldType.Name))
structs = append(structs, fieldVal)
}
continue
}
// Normal struct field, store it away
fields = append(fields, field{fieldType, structVal.Field(i)})
fields = append(fields, field{fieldType, fieldVal})
}
}
@@ -511,8 +468,8 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
field, fieldValue := f.field, f.val
fieldName := field.Name
tagValue := field.Tag.Get(d.option.TagName)
tagValue = strings.SplitN(tagValue, ",", 2)[0]
tagParts := strings.Split(field.Tag.Get(d.option.TagName), ",")
tagValue := tagParts[0]
if tagValue != "" {
fieldName = tagValue
}
@@ -521,6 +478,13 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
continue
}
omitempty := false
for _, tag := range tagParts[1:] {
if tag == "omitempty" {
omitempty = true
}
}
rawMapKey := reflect.ValueOf(fieldName)
rawMapVal := dataVal.MapIndex(rawMapKey)
if !rawMapVal.IsValid() {
@@ -548,7 +512,10 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
if !rawMapVal.IsValid() {
// There was no matching key in the map for the value in
// the struct. Just ignore.
// the struct. Remember it for potential errors and metadata.
if !omitempty {
targetValKeysUnused[fieldName] = struct{}{}
}
continue
}
}
@@ -570,7 +537,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
// If the name is empty string, then we're at the root, and we
// don't dot-join the fields.
if name != "" {
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
fieldName = name + "." + fieldName
}
if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
@@ -578,6 +545,17 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}
}
if len(targetValKeysUnused) > 0 {
keys := make([]string, 0, len(targetValKeysUnused))
for rawKey := range targetValKeysUnused {
keys = append(keys, rawKey.(string))
}
sort.Strings(keys)
err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", "))
errors = append(errors, err.Error())
}
if len(errors) > 0 {
return fmt.Errorf(strings.Join(errors, ","))
}

View File

@@ -139,6 +139,27 @@ func TestStructure_Nest(t *testing.T) {
assert.Equal(t, s.BazOptional, goal)
}
func TestStructure_DoubleNest(t *testing.T) {
rawMap := map[string]any{
"bar": map[string]any{
"foo": 1,
},
}
goal := BazOptional{
Foo: 1,
}
s := &struct {
Bar struct {
BazOptional
} `test:"bar"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.Equal(t, s.Bar.BazOptional, goal)
}
func TestStructure_SliceNilValue(t *testing.T) {
rawMap := map[string]any{
"foo": 1,
@@ -228,6 +249,23 @@ func TestStructure_Pointer(t *testing.T) {
assert.Nil(t, s.Bar)
}
func TestStructure_PointerStruct(t *testing.T) {
rawMap := map[string]any{
"foo": "foo",
}
s := &struct {
Foo *string `test:"foo,omitempty"`
Bar *Baz `test:"bar,omitempty"`
}{}
err := decoder.Decode(rawMap, s)
assert.Nil(t, err)
assert.NotNil(t, s.Foo)
assert.Equal(t, "foo", *s.Foo)
assert.Nil(t, s.Bar)
}
type num struct {
a int
}