Skip to content

Commit d14a738

Browse files
authored
feat: Reject unknown fields in YAML config, allow fluent output without secret (#58)
* accessing internal yaml v3 methods, see golang/go#67361 * added tests for client worker's connection opening since the YAML change, all practically optional fields in config can be turned optional by removing the value checks.
1 parent f132d4f commit d14a738

File tree

9 files changed

+214
-19
lines changed

9 files changed

+214
-19
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.DS_Store
12
*.test
23
/.vscode
34
/BUILD/*

base/bconfig/configholder.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/relex/gotils/logger"
88
"github.com/relex/slog-agent/util"
9+
"github.com/relex/slog-agent/util/yamlinternal"
910
"gopkg.in/yaml.v3"
1011
)
1112

@@ -45,7 +46,7 @@ func (holder *ConfigHolder[C]) UnmarshalYAML(value *yaml.Node) error {
4546
}
4647
holder.Value = createFunc()
4748

48-
if err := value.Decode(holder.Value); err != nil {
49+
if err := yamlinternal.NodeDecodeKnownFields(value, holder.Value); err != nil {
4950
return util.NewYamlError(value, err.Error())
5051
}
5152
holder.Location = util.GetYamlLocation(value)

output/fluentdforward/clientworker.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,25 @@ func openForwardConnection(parentLogger logger.Logger, config UpstreamConfig) (b
5050
}
5151
connLogger.Info("connected to ", sock.RemoteAddr())
5252

53-
success, reason, herr := forwardprotocol.DoClientHandshake(sock, config.Secret, defs.ForwarderHandshakeTimeout)
54-
if herr != nil {
55-
if err := sock.Close(); err != nil && !util.IsNetworkClosed(err) {
56-
connLogger.Warn("error closing connection: ", err)
53+
if len(config.Secret) > 0 {
54+
success, reason, protocolErr := forwardprotocol.DoClientHandshake(sock, config.Secret, defs.ForwarderHandshakeTimeout)
55+
56+
var authErr error
57+
switch {
58+
case protocolErr != nil:
59+
authErr = fmt.Errorf("failed to handshake due to unexpected error: %w", protocolErr)
60+
case !success:
61+
authErr = fmt.Errorf("login rejected: %s", reason)
62+
default:
63+
authErr = nil
5764
}
58-
return nil, fmt.Errorf("failed to handshake due to error: %w", herr)
59-
}
60-
if !success {
61-
if err := sock.Close(); err != nil && !util.IsNetworkClosed(err) {
62-
connLogger.Warn("error closing connection: ", err)
65+
66+
if authErr != nil {
67+
if err := sock.Close(); err != nil && !util.IsNetworkClosed(err) {
68+
connLogger.Warn("error closing connection: ", err)
69+
}
70+
return nil, authErr
6371
}
64-
return nil, fmt.Errorf("failed to handshake due to authentication: %s", reason)
6572
}
6673

6774
return &forwardConnection{
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package fluentdforward
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/relex/fluentlib/server"
8+
"github.com/relex/fluentlib/server/receivers"
9+
"github.com/relex/gotils/logger"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestOpenClientConnection(t *testing.T) {
14+
recv := receivers.NewMessageWriter(os.Stdout)
15+
16+
t.Run("connect fails when protocols are different", func(t *testing.T) {
17+
srvCfg := server.Config{
18+
Address: "localhost:0",
19+
TLS: false,
20+
}
21+
srv, srvAddr := server.LaunchServer(logger.WithField("test", t.Name()), srvCfg, recv)
22+
defer srv.Shutdown()
23+
24+
_, err := openForwardConnection(logger.Root(), UpstreamConfig{
25+
Address: srvAddr.String(),
26+
TLS: true, // attempt to request TLS handshake
27+
})
28+
assert.ErrorContains(t, err, "failed to connect: tls:")
29+
})
30+
31+
t.Run("login fails when secrets are different", func(t *testing.T) {
32+
srvCfg := server.Config{
33+
Address: "localhost:0",
34+
Secret: "real pass",
35+
TLS: false,
36+
}
37+
srv, srvAddr := server.LaunchServer(logger.WithField("test", t.Name()), srvCfg, recv)
38+
defer srv.Shutdown()
39+
40+
_, err := openForwardConnection(logger.Root(), UpstreamConfig{
41+
Address: srvAddr.String(),
42+
TLS: false,
43+
Secret: "wrong pass",
44+
})
45+
assert.ErrorContains(t, err, "login rejected:")
46+
})
47+
}

output/fluentdforward/config.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,6 @@ func (cfg *Config) VerifyConfig(schema base.LogSchema) error {
129129
return fmt.Errorf(".upstream.address is invalid: %w", err)
130130
}
131131

132-
if cfg.Upstream.TLS && len(cfg.Upstream.Secret) == 0 {
133-
return fmt.Errorf(".upstream.secret is unspecified when tls=true")
134-
}
135-
136132
if cfg.Upstream.MaxDuration == 0 {
137133
return fmt.Errorf(".upstream.maxDuration is unspecified")
138134
}

run/loader_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ outputBufferPairs:
7777
upstream:
7878
address: %s
7979
tls: false
80+
secret: Hi
8081
maxDuration: 500ms
8182
`
8283

@@ -88,6 +89,11 @@ var sampleConf = assembleConfig(
8889
sampleOutputConf,
8990
)
9091

92+
var serverConfig = server.Config{
93+
Address: "localhost:0",
94+
Secret: "Hi",
95+
}
96+
9197
func TestLoader(t *testing.T) {
9298
logRecv, outBatchCh := receivers.NewMessageCollector(5 * time.Second)
9399

@@ -147,9 +153,7 @@ func runTestEnv(t *testing.T, logReceiver receivers.Receiver, confYML string,
147153
assert.NoError(t, confFileErr)
148154
defer os.Remove(confFile.Name())
149155

150-
srvConf := server.Config{}
151-
srvConf.Address = "localhost:0"
152-
srv, srvAddr := server.LaunchServer(logger.WithField("test", t.Name()), srvConf, logReceiver)
156+
srv, srvAddr := server.LaunchServer(logger.WithField("test", t.Name()), serverConfig, logReceiver)
153157

154158
_, writeErr := confFile.WriteString(fmt.Sprintf(confYML, bufDir, srvAddr.String()))
155159
assert.NoError(t, writeErr)

util/yaml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func UnmarshalYamlFile(path string, output interface{}) error {
5656
// UnmarshalYamlReader loads and unmarshals YAML from IO reader to interface or pointer to struct
5757
func UnmarshalYamlReader(reader io.Reader, output interface{}) error {
5858
decoder := yaml.NewDecoder(reader)
59-
decoder.KnownFields(true)
59+
decoder.KnownFields(true) // only works outside of custom unmarshalers
6060
return decoder.Decode(output)
6161
}
6262

util/yamlinternal/decode.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2024 RELEX Oy
2+
// Copyright (c) 2011-2019 Canonical Ltd
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package yamlinternal
17+
18+
import (
19+
"fmt"
20+
"reflect"
21+
22+
_ "unsafe"
23+
24+
"gopkg.in/yaml.v3"
25+
)
26+
27+
// decoder is a copy from gopkg.in/yaml.v3/decode.go
28+
//
29+
// FIXME: Can we get the actual type definition of the returned valued from newDecoder()? Attempts resulted in "<invalid reflect.Value>"
30+
//
31+
//lint:ignore U1000 all fields until the last used field must be kept for proper offsets
32+
type decoder struct {
33+
doc *yaml.Node
34+
aliases map[*yaml.Node]bool
35+
terrors []string
36+
37+
stringMapType reflect.Type
38+
generalMapType reflect.Type
39+
40+
knownFields bool
41+
uniqueKeys bool
42+
decodeCount int
43+
aliasCount int
44+
aliasDepth int
45+
46+
mergedFields map[interface{}]bool
47+
}
48+
49+
//go:linkname handleErr gopkg.in/yaml%2ev3.handleErr
50+
func handleErr(err *error)
51+
52+
//go:linkname newDecoder gopkg.in/yaml%2ev3.newDecoder
53+
func newDecoder() *decoder
54+
55+
//go:linkname unmarshal gopkg.in/yaml%2ev3.(*decoder).unmarshal
56+
func unmarshal(d *decoder, n *yaml.Node, out reflect.Value) (good bool)
57+
58+
// NodeDecodeKnownFields is yaml.Node.Decode with KnownFields=true, to disallow unknown fields in YAML source.
59+
func NodeDecodeKnownFields(node *yaml.Node, v interface{}) (err error) {
60+
d := newDecoder()
61+
d.knownFields = true
62+
defer handleErr(&err)
63+
out := reflect.ValueOf(v)
64+
if out.Kind() == reflect.Ptr && !out.IsNil() {
65+
out = out.Elem()
66+
}
67+
good := unmarshal(d, node, out)
68+
69+
if len(d.terrors) > 0 {
70+
return &yaml.TypeError{Errors: d.terrors}
71+
}
72+
73+
if !good {
74+
return fmt.Errorf("not good")
75+
}
76+
return nil
77+
}

util/yamlinternal/decode_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package yamlinternal
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"gopkg.in/yaml.v3"
8+
)
9+
10+
func TestNodeDecodeKnownFields(t *testing.T) {
11+
var correctNode yaml.Node
12+
correctYaml := `
13+
value: ok
14+
15+
list:
16+
- "Hi here"
17+
- "Hey correct"
18+
`
19+
20+
var incorrectNode yaml.Node
21+
incorrectYaml := `
22+
value: ok
23+
24+
list:
25+
- "Hi here"
26+
- "Hey incorrect"
27+
28+
unknown: 9
29+
`
30+
assert.NoError(t, yaml.Unmarshal([]byte(correctYaml), &correctNode))
31+
assert.NoError(t, yaml.Unmarshal([]byte(incorrectYaml), &incorrectNode))
32+
33+
type ResultType struct {
34+
Value string `yaml:"value"`
35+
List []string `yaml:"list"`
36+
}
37+
38+
t.Run("decode with builtin method to accept unknown fields", func(t *testing.T) {
39+
var result ResultType
40+
assert.NoError(t, incorrectNode.Decode(&result))
41+
assert.Equal(t, "ok", result.Value)
42+
if assert.Equal(t, 2, len(result.List)) {
43+
assert.Equal(t, "Hi here", result.List[0])
44+
assert.Equal(t, "Hey incorrect", result.List[1])
45+
}
46+
})
47+
48+
t.Run("decode with new method to reject unknown fields", func(t *testing.T) {
49+
var result ResultType
50+
assert.ErrorContains(t, NodeDecodeKnownFields(&incorrectNode, &result), "line 8: field unknown not found in type yamlinternal.ResultType")
51+
})
52+
53+
t.Run("decode with new method to accept known fields", func(t *testing.T) {
54+
var result ResultType
55+
assert.NoError(t, NodeDecodeKnownFields(&correctNode, &result))
56+
assert.Equal(t, "ok", result.Value)
57+
if assert.Equal(t, 2, len(result.List)) {
58+
assert.Equal(t, "Hi here", result.List[0])
59+
assert.Equal(t, "Hey correct", result.List[1])
60+
}
61+
})
62+
}

0 commit comments

Comments
 (0)