diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go index c9ae2ca33e5244..0d34d5aa1e8152 100644 --- a/src/encoding/asn1/marshal.go +++ b/src/encoding/asn1/marshal.go @@ -5,10 +5,12 @@ package asn1 import ( + "bytes" "errors" "fmt" "math/big" "reflect" + "sort" "time" "unicode/utf8" ) @@ -78,6 +80,48 @@ func (m multiEncoder) Encode(dst []byte) { } } +type setEncoder []encoder + +func (s setEncoder) Len() int { + var size int + for _, e := range s { + size += e.Len() + } + return size +} + +func (s setEncoder) Encode(dst []byte) { + // Per X690 Section 11.6: The encodings of the component values of a + // set-of value shall appear in ascending order, the encodings being + // compared as octet strings with the shorter components being padded + // at their trailing end with 0-octets. + // + // First we encode each element to its TLV encoding and then use + // octetSort to get the ordering expected by X690 DER rules before + // writing the sorted encodings out to dst. + l := make([][]byte, len(s)) + for i, e := range s { + l[i] = make([]byte, e.Len()) + e.Encode(l[i]) + } + + sort.Slice(l, func(i, j int) bool { + // Since we are using bytes.Compare to compare TLV encodings we + // don't need to right pad s[i] and s[j] to the same length as + // suggested in X690. If len(s[i]) < len(s[j]) the length octet of + // s[i], which is the first determining byte, will inherently be + // smaller than the length octet of s[j]. This lets us skip the + // padding step. + return bytes.Compare(l[i], l[j]) < 0 + }) + + var off int + for _, b := range l { + copy(dst[off:], b) + off += len(b) + } +} + type taggedEncoder struct { // scratch contains temporary space for encoding the tag and length of // an element in order to avoid extra allocations. @@ -511,6 +555,9 @@ func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error } } + if params.set { + return setEncoder(m), nil + } return multiEncoder(m), nil } case reflect.String: @@ -618,6 +665,15 @@ func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) { tag = TagSet } + // makeField can be called for a slice that should be treated as a SET + // but doesn't have params.set set, for instance when using a slice + // with the SET type name suffix. In this case getUniversalType returns + // TagSet, but makeBody doesn't know about that so will treat the slice + // as a sequence. To work around this we set params.set. + if tag == TagSet && !params.set { + params.set = true + } + t := new(taggedEncoder) t.body, err = makeBody(v, params) diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go index a77826a7b034fd..529052285f5950 100644 --- a/src/encoding/asn1/marshal_test.go +++ b/src/encoding/asn1/marshal_test.go @@ -319,3 +319,60 @@ func BenchmarkMarshal(b *testing.B) { } } } + +func TestSetEncoder(t *testing.T) { + testStruct := struct { + Strings []string `asn1:"set"` + }{ + Strings: []string{"a", "aa", "b", "bb", "c", "cc"}, + } + + // Expected ordering of the SET should be: + // a, b, c, aa, bb, cc + + output, err := Marshal(testStruct) + if err != nil { + t.Errorf("%v", err) + } + + expectedOrder := []string{"a", "b", "c", "aa", "bb", "cc"} + var resultStruct struct { + Strings []string `asn1:"set"` + } + rest, err := Unmarshal(output, &resultStruct) + if err != nil { + t.Errorf("%v", err) + } + if len(rest) != 0 { + t.Error("Unmarshal returned extra garbage") + } + if !reflect.DeepEqual(expectedOrder, resultStruct.Strings) { + t.Errorf("Unexpected SET content. got: %s, want: %s", resultStruct.Strings, expectedOrder) + } +} + +func TestSetEncoderSETSliceSuffix(t *testing.T) { + type testSetSET []string + testSet := testSetSET{"a", "aa", "b", "bb", "c", "cc"} + + // Expected ordering of the SET should be: + // a, b, c, aa, bb, cc + + output, err := Marshal(testSet) + if err != nil { + t.Errorf("%v", err) + } + + expectedOrder := testSetSET{"a", "b", "c", "aa", "bb", "cc"} + var resultSet testSetSET + rest, err := Unmarshal(output, &resultSet) + if err != nil { + t.Errorf("%v", err) + } + if len(rest) != 0 { + t.Error("Unmarshal returned extra garbage") + } + if !reflect.DeepEqual(expectedOrder, resultSet) { + t.Errorf("Unexpected SET content. got: %s, want: %s", resultSet, expectedOrder) + } +}