diff --git a/decode.go b/decode.go index f4f9130..35b7d35 100644 --- a/decode.go +++ b/decode.go @@ -95,6 +95,27 @@ func (d *decoder) value(val reflect.Value) error { if elemField.CanAddr() { err = d.value(elemField.Addr()) } + } else if typField.Type.Kind() == reflect.Ptr { + recQuery := make(url.Values) + pref := qstring + "." + for k, v := range d.data { + if strings.HasPrefix(k, pref) { + key := strings.Replace(k, pref, "", 1) + recQuery[key] = v + } + } + if len(recQuery) > 0 { + if elemField.IsNil() { + elemField.Set(reflect.New(typField.Type.Elem())) + elemField = elemField.Elem() + } + temp := d.data + d.data = recQuery + if elemField.CanAddr() { + err = d.value(elemField.Addr()) + } + d.data = temp + } } if err != nil { return err @@ -152,6 +173,12 @@ func (d *decoder) coerce(query string, target reflect.Kind, field reflect.Value) if err == nil { field.Set(reflect.ValueOf(t)) } + case ComparativeString: + s := ComparativeString{} + err = s.Parse(query) + if err == nil { + field.Set(reflect.ValueOf(s)) + } default: d.value(field) } diff --git a/decode_test.go b/decode_test.go index 1c8b04e..1be8fd2 100644 --- a/decode_test.go +++ b/decode_test.go @@ -212,3 +212,36 @@ func TestUnmarshaller(t *testing.T) { } } } + +func TestUnmarshalEmbeddedStruct(t *testing.T) { + testIO := []struct { + inp url.Values + err interface{} + expected *RecursiveStruct + }{ + { + url.Values{"object.value": []string{"embedded-example"}, "value": []string{"example"},}, + nil, + &RecursiveStruct{ + Object: &RecursiveStruct{ + Value: "embedded-example", + }, + Value: "example", + }, + }, + } + s := &RecursiveStruct{} + for _, test := range testIO { + err := Unmarshal(test.inp, s) + if err != test.err { + t.Errorf("Expected Unmarshaller to return %s, but got %s instead", test.err, err) + } + if !(test.expected.Value == s.Value) { + t.Errorf("Expected Unmarshaller to return %s, but got %s instead", test.expected.Value, s.Value) + } + if !(test.expected.Object.Value == s.Object.Value) { + t.Errorf("Expected Unmarshaller to return %s, but got %s instead", test.expected.Object.Value, s.Object.Value) + } + } + +} diff --git a/doc_test.go b/doc_test.go index 7acbcac..1442c76 100644 --- a/doc_test.go +++ b/doc_test.go @@ -1,12 +1,10 @@ -package qstring_test +package qstring import ( "fmt" "net/url" "os" "time" - - "github.com/dyninc/qstring" ) func ExampleUnmarshal() { @@ -19,7 +17,7 @@ func ExampleUnmarshal() { query := &Query{} qValues, _ := url.ParseQuery("names=foo&names=bar&limit=50&page=1") - err := qstring.Unmarshal(qValues, query) + err := Unmarshal(qValues, query) if err != nil { panic("Unable to Parse Query String") } @@ -41,7 +39,7 @@ func ExampleMarshalString() { Limit: 50, Page: 1, } - q, _ := qstring.MarshalString(query) + q, _ := MarshalString(query) os.Stdout.Write([]byte(q)) // Output: limit=50&names=foo&names=bar&page=1 } @@ -62,7 +60,7 @@ func ExampleUnmarshal_complex() { } query := &Query{} qValues, _ := url.ParseQuery("names=foo&names=bar&limit=50&page=1&ids=1&ids=2&created=2006-01-02T15:04:05Z") - err := qstring.Unmarshal(qValues, query) + err := Unmarshal(qValues, query) if err != nil { panic("Unable to Parse Query String") } @@ -70,13 +68,13 @@ func ExampleUnmarshal_complex() { func ExampleComparativeTime() { type DateQuery struct { - Created qstring.ComparativeTime - Modified qstring.ComparativeTime + Created ComparativeTime + Modified ComparativeTime } var query DateQuery qValues, _ := url.ParseQuery("created=>=2006-01-02T15:04:05Z&modified=<=2016-01-01T15:04Z") - err := qstring.Unmarshal(qValues, &query) + err := Unmarshal(qValues, &query) if err != nil { panic("Unable to Parse Query String") } diff --git a/encode.go b/encode.go index 9bc437a..300614b 100644 --- a/encode.go +++ b/encode.go @@ -1,6 +1,8 @@ package qstring import ( + "encoding" + "fmt" "net/url" "reflect" "strconv" @@ -71,6 +73,9 @@ func (e *encoder) marshal() (url.Values, error) { } } +var textMarshallerElem = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +var stringerElem = reflect.TypeOf(new(fmt.Stringer)).Elem() + func (e *encoder) value(val reflect.Value) (url.Values, error) { elem := val.Elem() typ := elem.Type() @@ -93,16 +98,21 @@ func (e *encoder) value(val reflect.Value) (url.Values, error) { continue } + // verify if the element type implements compatible interfaces + if val, ok := compatibleInterfaceValue(elemField); ok { + output.Set(qstring, val) + continue + } // only do work if the current fields query string parameter was provided switch k := typField.Type.Kind(); k { - default: - output.Set(qstring, marshalValue(elemField, k)) case reflect.Slice: output[qstring] = marshalSlice(elemField) case reflect.Ptr: marshalStruct(output, qstring, reflect.Indirect(elemField), k) case reflect.Struct: marshalStruct(output, qstring, elemField, k) + default: + output.Set(qstring, marshalValue(elemField, k)) } } return output, err @@ -116,7 +126,21 @@ func marshalSlice(field reflect.Value) []string { return out } +func compatibleInterfaceValue(field reflect.Value) (string, bool) { + if field.Type().Implements(textMarshallerElem) { + byt, _ := field.Interface().(encoding.TextMarshaler).MarshalText() + return string(byt), true + } + if field.Type().Implements(stringerElem) { + return field.Interface().(fmt.Stringer).String(), true + } + return "", false +} + func marshalValue(field reflect.Value, source reflect.Kind) string { + if val, ok := compatibleInterfaceValue(field); ok { + return val + } switch source { case reflect.String: return field.String() @@ -130,6 +154,9 @@ func marshalValue(field reflect.Value, source reflect.Kind) string { return strconv.FormatFloat(field.Float(), 'G', -1, 64) case reflect.Struct: switch field.Interface().(type) { + case encoding.TextMarshaler: + byt, _ := field.Interface().(encoding.TextMarshaler).MarshalText() + return string(byt) case time.Time: return field.Interface().(time.Time).Format(time.RFC3339) case ComparativeTime: @@ -154,6 +181,9 @@ func marshalStruct(output url.Values, qstring string, field reflect.Value, sourc return err } for key, list := range vals { + if qstring != "" { + key = qstring + "." + key + } output[key] = list } } diff --git a/encode_test.go b/encode_test.go index cd12cd5..3796578 100644 --- a/encode_test.go +++ b/encode_test.go @@ -294,3 +294,74 @@ func TestMarshaller(t *testing.T) { } } } + +type MyText []byte +type MyOtherText []byte + +type MyStruct struct { + Text MyText + Other MyOtherText +} + +func (m MyText) MarshalText() ([]byte, error) { + return []byte(m), nil +} +func (m MyOtherText) String() string { + return string(m) +} + +func TestMarshalTextMarshalType(t *testing.T) { + el := MyStruct{Text: MyText("example string"), Other: MyOtherText("second example")} + + result, err := MarshalString(&el) + if err != nil { + t.Fatalf("Unable to marshal type %T: %s", el, err.Error()) + } + + var unescaped string + unescaped, err = url.QueryUnescape(result) + if err != nil { + t.Fatalf("Unable to unescape query string %q: %q", result, err.Error()) + } + + // ensure fields we expect to be present are + expected := []string{"text=example string", "other=second example"} + for _, q := range expected { + if !strings.Contains(unescaped, q) { + t.Errorf("Expected query string %s to contain %s", unescaped, q) + } + } +} + +type RecursiveStruct struct { + Object *RecursiveStruct `qstring:"object,omitempty"` + Value string `qstring:"value"` +} + +func TestMarshalEmbeddedStruct(t *testing.T) { + rec := RecursiveStruct{ + Value: "example", + Object: &RecursiveStruct{ + Value: "embedded-example", + }, + } + + result, err := MarshalString(&rec) + if err != nil { + t.Fatalf("Unable to marshal type %T: %s", rec, err.Error()) + } + + var unescaped string + unescaped, err = url.QueryUnescape(result) + if err != nil { + t.Fatalf("Unable to unescape query string %q: %q", result, err.Error()) + } + + // ensure fields we expect to be present are + expected := []string{"value=example", "object.value=embedded-example"} + for _, q := range expected { + if !strings.Contains(unescaped, q) { + t.Errorf("Expected query string %s to contain %s", unescaped, q) + } + } +} diff --git a/fields.go b/fields.go index c50a2bd..f593310 100644 --- a/fields.go +++ b/fields.go @@ -7,29 +7,57 @@ import ( "time" ) +const ( + operatorEquals = "=" + operatorGreater = ">" + operatorGreaterEq = ">=" + operatorLesser = "<" + operatorLesserEq = "<=" + operatorLike = "~" + operatorDifferent = "!" +) + // parseOperator parses a leading logical operator out of the provided string func parseOperator(s string) string { + if len(s) == 0 { + return "" + } switch s[0] { - case 60: // "<" + case operatorLesser[0]: // "<" + if 1 == len(s) { + return operatorLesser + } switch s[1] { - case 61: // "=" - return "<=" + case operatorEquals[0]: // "=" + return operatorLesserEq default: - return "<" + return operatorLesser + } + case operatorGreater[0]: // ">" + if 1 == len(s) { + return operatorGreater } - case 62: // ">" switch s[1] { - case 61: // "=" - return ">=" + case operatorEquals[0]: // "=" + return operatorGreaterEq default: - return ">" + return operatorGreater } + case operatorLike[0]: // "~" + return operatorLike + case operatorDifferent[0]: // "!" + return operatorDifferent default: // no operator found, default to "=" - return "=" + return operatorEquals } } +type ComparativeString struct { + Operator string + Str string +} + // ComparativeTime is a field that can be used for specifying a query parameter // which includes a conditional operator and a timestamp type ComparativeTime struct { @@ -70,3 +98,37 @@ func (c *ComparativeTime) Parse(query string) error { func (c ComparativeTime) String() string { return fmt.Sprintf("%s%s", c.Operator, c.Time.Format(time.RFC3339)) } + +// Parse is used to parse a query string into a ComparativeString instance +func (c *ComparativeString) Parse(query string) error { + c.Operator = parseOperator(query) + + if len(c.Operator) > 0 && c.Operator != operatorDifferent && c.Operator != operatorLike && c.Operator != operatorEquals { + return errors.New(fmt.Sprintf("qstring: Invalid operator for %T", c)) + } + if c.Operator == operatorEquals { + c.Operator = "" + } + + // if no operator was provided and we defaulted to an equality operator + if !strings.HasPrefix(query, c.Operator) { + query = fmt.Sprintf("=%s", query) + } + + var err error + c.Str = query[len(c.Operator):] + if err != nil { + return err + } + + return nil +} + +// String returns this ComparativeString instance in the form of the query +// parameter that it came in on +func (c ComparativeString) String() string { + if c.Operator == operatorEquals { + c.Operator = "" + } + return fmt.Sprintf("%s%s", c.Operator, c.Str) +} diff --git a/fields_test.go b/fields_test.go index 76bf82c..a80c697 100644 --- a/fields_test.go +++ b/fields_test.go @@ -1,6 +1,7 @@ package qstring import ( + "fmt" "net/url" "strings" "testing" @@ -108,3 +109,88 @@ func TestComparativeTimeMarshalString(t *testing.T) { } } } + +func TestComparativeStringUnmarshal(t *testing.T) { + type Query struct { + Equals ComparativeString + Similar ComparativeString + Different ComparativeString + } + + val1 := "stringValue1" + equalsVal := fmt.Sprintf("%s", val1) + val2 := "stringValue2" + similarVal := fmt.Sprintf("~%s", val2) + val3 := "stringValue3" + diffVal := fmt.Sprintf("!%s", val3) + + query := url.Values{ + "equals": []string{equalsVal}, + "different": []string{diffVal}, + "similar": []string{similarVal}, + } + + params := &Query{} + err := Unmarshal(query, params) + if err != nil { + t.Fatal(err.Error()) + } + + equals := params.Equals.String() + if equals != equalsVal { + t.Errorf("Expected equals val of %s, got %s instead.", equalsVal, equals) + } + similar := params.Similar.String() + if similar != similarVal { + t.Errorf("Expected similar val of %s, got %s instead.", similarVal, similar) + } + diff := params.Different.String() + if diff != diffVal { + t.Errorf("Expected different val of %s, got %s instead.", diffVal, diff) + } +} + +func TestComparativeStringMarshalString(t *testing.T) { + type Query struct { + Equals ComparativeString + Similar ComparativeString + Different ComparativeString + } + + val1 := "stringValue1" + equalsVal := fmt.Sprintf("%s", val1) + equals := &ComparativeString{} + equals.Parse(equalsVal) + + val2 := "stringValue2" + similarVal := fmt.Sprintf("~%s", val2) + similar := &ComparativeString{} + similar.Parse(similarVal) + + val3 := "stringValue3" + diffVal := fmt.Sprintf("!%s", val3) + different := &ComparativeString{} + different.Parse(diffVal) + + q := &Query{*equals, *similar, *different} + result, err := MarshalString(q) + if err != nil { + t.Fatalf("Unable to marshal comparative timestamp: %s", err.Error()) + } + + var unescaped string + unescaped, err = url.QueryUnescape(result) + if err != nil { + t.Fatalf("Unable to unescape query string %q: %q", result, err.Error()) + } + expected := []string{ + fmt.Sprintf("different=!%s", val3), + fmt.Sprintf("equals=%s", val1), + fmt.Sprintf("similar=~%s", val2), + } + for _, ts := range expected { + if !strings.Contains(unescaped, ts) { + t.Errorf("Expected query string %s to contain %s", unescaped, ts) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5b4fa48 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/mariusor/qstring + +go 1.13