From d56b49c13d0a2822a7e6c8248af45e7d2e1312d4 Mon Sep 17 00:00:00 2001 From: Maksym Pavlenko Date: Sat, 27 Mar 2021 15:39:46 -0700 Subject: [PATCH] Rewrite Docker hosts parser Signed-off-by: Maksym Pavlenko --- remotes/docker/config/hosts.go | 343 ++++++++++++++-------------- remotes/docker/config/hosts_test.go | 16 +- 2 files changed, 179 insertions(+), 180 deletions(-) diff --git a/remotes/docker/config/hosts.go b/remotes/docker/config/hosts.go index a96e64ff6..e3be0f44a 100644 --- a/remotes/docker/config/hosts.go +++ b/remotes/docker/config/hosts.go @@ -30,11 +30,12 @@ import ( "strings" "time" - "github.com/BurntSushi/toml" + "github.com/pelletier/go-toml" + "github.com/pkg/errors" + "github.com/containerd/containerd/errdefs" "github.com/containerd/containerd/log" "github.com/containerd/containerd/remotes/docker" - "github.com/pkg/errors" ) // UpdateClientFunc is a function that lets you to amend http Client behavior used by registry clients. @@ -264,7 +265,7 @@ func loadHostDir(ctx context.Context, hostsDir string) ([]hostConfig, error) { return loadCertFiles(ctx, hostsDir) } - hosts, err := parseHostsFile(ctx, hostsDir, b) + hosts, err := parseHostsFile(hostsDir, b) if err != nil { log.G(ctx).WithError(err).Error("failed to decode hosts.toml") // Fallback to checking certificate files @@ -274,7 +275,9 @@ func loadHostDir(ctx context.Context, hostsDir string) ([]hostConfig, error) { return hosts, nil } -type hostFileConfig struct { +// HostFileConfig describes a single host section within TOML file. +// Note: This struct needs to be public in order to be properly deserialized by TOML library. +type HostFileConfig struct { // Capabilities determine what operations a host is // capable of performing. Allowed values // - pull @@ -283,14 +286,14 @@ type hostFileConfig struct { Capabilities []string `toml:"capabilities"` // CACert can be a string or an array of strings - CACert toml.Primitive `toml:"ca"` + CACert interface{} `toml:"ca"` // TODO: Make this an array (two key types, one for pairs (multiple files), one for single file?) - Client toml.Primitive `toml:"client"` + Client interface{} `toml:"client"` SkipVerify *bool `toml:"skip_verify"` - Header map[string]toml.Primitive `toml:"header"` + Header map[string]interface{} `toml:"header"` // API (default: "docker") // API Version (default: "v2") @@ -300,187 +303,191 @@ type hostFileConfig struct { type configFile struct { // hostConfig holds defaults for all hosts as well as // for the default server - hostFileConfig + HostFileConfig // Server specifies the default server. When `host` is // also specified, those hosts are tried first. Server string `toml:"server"` // HostConfigs store the per-host configuration - HostConfigs map[string]hostFileConfig `toml:"host"` + HostConfigs map[string]HostFileConfig `toml:"host"` } -func parseHostsFile(ctx context.Context, baseDir string, b []byte) ([]hostConfig, error) { - var c configFile - md, err := toml.Decode(string(b), &c) +func parseHostsFile(baseDir string, b []byte) ([]hostConfig, error) { + tree, err := toml.LoadBytes(b) if err != nil { + return nil, errors.Wrap(err, "failed to parse TOML") + } + + var ( + c configFile + hosts []hostConfig + ) + + if err := tree.Unmarshal(&c); err != nil { return nil, err } - var orderedHosts []string - for _, key := range md.Keys() { - if len(key) >= 2 { - if key[0] == "host" && (len(orderedHosts) == 0 || orderedHosts[len(orderedHosts)-1] != key[1]) { - orderedHosts = append(orderedHosts, key[1]) - } - } + // Parse root host config + parsed, err := parseHostConfig(c.Server, baseDir, c.HostFileConfig) + if err != nil { + return nil, err } + hosts = append(hosts, parsed) - if c.HostConfigs == nil { - c.HostConfigs = map[string]hostFileConfig{} - } - if c.Server != "" { - c.HostConfigs[c.Server] = c.hostFileConfig - orderedHosts = append(orderedHosts, c.Server) - } else if len(orderedHosts) == 0 { - c.HostConfigs[""] = c.hostFileConfig - orderedHosts = append(orderedHosts, "") - } - hosts := make([]hostConfig, len(orderedHosts)) - for i, server := range orderedHosts { - hostConfig := c.HostConfigs[server] - - if server != "" { - if !strings.HasPrefix(server, "http") { - server = "https://" + server - } - u, err := url.Parse(server) - if err != nil { - return nil, errors.Errorf("unable to parse server %v", server) - } - hosts[i].scheme = u.Scheme - hosts[i].host = u.Host - - // TODO: Handle path based on registry protocol - // Define a registry protocol type - // OCI v1 - Always use given path as is - // Docker v2 - Always ensure ends with /v2/ - if len(u.Path) > 0 { - u.Path = path.Clean(u.Path) - if !strings.HasSuffix(u.Path, "/v2") { - u.Path = u.Path + "/v2" - } - } else { - u.Path = "/v2" - } - hosts[i].path = u.Path - } - hosts[i].skipVerify = hostConfig.SkipVerify - - if len(hostConfig.Capabilities) > 0 { - for _, c := range hostConfig.Capabilities { - switch strings.ToLower(c) { - case "pull": - hosts[i].capabilities |= docker.HostCapabilityPull - case "resolve": - hosts[i].capabilities |= docker.HostCapabilityResolve - case "push": - hosts[i].capabilities |= docker.HostCapabilityPush - default: - return nil, errors.Errorf("unknown capability %v", c) - } - } - } else { - hosts[i].capabilities = docker.HostCapabilityPull | docker.HostCapabilityResolve | docker.HostCapabilityPush - } - - baseKey := []string{} - if server != "" && server != c.Server { - baseKey = append(baseKey, "host", server) - } - caKey := append(baseKey, "ca") - if md.IsDefined(caKey...) { - switch t := md.Type(caKey...); t { - case "String": - var caCert string - if err := md.PrimitiveDecode(hostConfig.CACert, &caCert); err != nil { - return nil, errors.Wrap(err, "failed to decode \"ca\"") - } - hosts[i].caCerts = []string{makeAbsPath(caCert, baseDir)} - case "Array": - var caCerts []string - if err := md.PrimitiveDecode(hostConfig.CACert, &caCerts); err != nil { - return nil, errors.Wrap(err, "failed to decode \"ca\"") - } - for i, p := range caCerts { - caCerts[i] = makeAbsPath(p, baseDir) - } - - hosts[i].caCerts = caCerts - default: - return nil, errors.Errorf("invalid type %v for \"ca\"", t) - } - } - - clientKey := append(baseKey, "client") - if md.IsDefined(clientKey...) { - switch t := md.Type(clientKey...); t { - case "String": - var clientCert string - if err := md.PrimitiveDecode(hostConfig.Client, &clientCert); err != nil { - return nil, errors.Wrap(err, "failed to decode \"ca\"") - } - hosts[i].clientPairs = [][2]string{{makeAbsPath(clientCert, baseDir), ""}} - case "Array": - var clientCerts []interface{} - if err := md.PrimitiveDecode(hostConfig.Client, &clientCerts); err != nil { - return nil, errors.Wrap(err, "failed to decode \"ca\"") - } - for _, pairs := range clientCerts { - switch p := pairs.(type) { - case string: - hosts[i].clientPairs = append(hosts[i].clientPairs, [2]string{makeAbsPath(p, baseDir), ""}) - case []interface{}: - var pair [2]string - if len(p) > 2 { - return nil, errors.Errorf("invalid pair %v for \"client\"", p) - } - for pi, cp := range p { - s, ok := cp.(string) - if !ok { - return nil, errors.Errorf("invalid type %T for \"client\"", cp) - } - pair[pi] = makeAbsPath(s, baseDir) - } - hosts[i].clientPairs = append(hosts[i].clientPairs, pair) - default: - return nil, errors.Errorf("invalid type %T for \"client\"", p) - } - } - default: - return nil, errors.Errorf("invalid type %v for \"client\"", t) - } - } - - headerKey := append(baseKey, "header") - if md.IsDefined(headerKey...) { - header := http.Header{} - for key, prim := range hostConfig.Header { - switch t := md.Type(append(headerKey, key)...); t { - case "String": - var value string - if err := md.PrimitiveDecode(prim, &value); err != nil { - return nil, errors.Wrapf(err, "failed to decode header %q", key) - } - header[key] = []string{value} - case "Array": - var value []string - if err := md.PrimitiveDecode(prim, &value); err != nil { - return nil, errors.Wrapf(err, "failed to decode header %q", key) - } - - header[key] = value - default: - return nil, errors.Errorf("invalid type %v for header %q", t, key) - } - } - hosts[i].header = header + // Parse hosts array + for host, config := range c.HostConfigs { + parsed, err := parseHostConfig(host, baseDir, config) + if err != nil { + return nil, err } + hosts = append(hosts, parsed) } return hosts, nil } +func parseHostConfig(server string, baseDir string, config HostFileConfig) (hostConfig, error) { + var ( + result = hostConfig{} + err error + ) + + if server != "" { + if !strings.HasPrefix(server, "http") { + server = "https://" + server + } + u, err := url.Parse(server) + if err != nil { + return hostConfig{}, errors.Wrapf(err, "unable to parse server %v", server) + } + result.scheme = u.Scheme + result.host = u.Host + // TODO: Handle path based on registry protocol + // Define a registry protocol type + // OCI v1 - Always use given path as is + // Docker v2 - Always ensure ends with /v2/ + if len(u.Path) > 0 { + u.Path = path.Clean(u.Path) + if !strings.HasSuffix(u.Path, "/v2") { + u.Path = u.Path + "/v2" + } + } else { + u.Path = "/v2" + } + result.path = u.Path + } + + result.skipVerify = config.SkipVerify + + if len(config.Capabilities) > 0 { + for _, c := range config.Capabilities { + switch strings.ToLower(c) { + case "pull": + result.capabilities |= docker.HostCapabilityPull + case "resolve": + result.capabilities |= docker.HostCapabilityResolve + case "push": + result.capabilities |= docker.HostCapabilityPush + default: + return hostConfig{}, errors.Errorf("unknown capability %v", c) + } + } + } else { + result.capabilities = docker.HostCapabilityPull | docker.HostCapabilityResolve | docker.HostCapabilityPush + } + + if config.CACert != nil { + switch cert := config.CACert.(type) { + case string: + result.caCerts = []string{makeAbsPath(cert, baseDir)} + case []string: + for _, p := range cert { + result.caCerts = append(result.caCerts, makeAbsPath(p, baseDir)) + } + case []interface{}: + result.caCerts, err = makeStringSlice(cert, func(p string) string { + return makeAbsPath(p, baseDir) + }) + if err != nil { + return hostConfig{}, err + } + default: + return hostConfig{}, errors.Errorf("invalid type %v for \"ca\"", cert) + } + } + + if config.Client != nil { + switch client := config.Client.(type) { + case string: + result.clientPairs = [][2]string{{makeAbsPath(client, baseDir), ""}} + case []interface{}: + // []string or [][2]string + for _, pairs := range client { + switch p := pairs.(type) { + case string: + result.clientPairs = append(result.clientPairs, [2]string{makeAbsPath(p, baseDir), ""}) + case []interface{}: + slice, err := makeStringSlice(p, nil) + if err != nil { + return hostConfig{}, err + } + if len(slice) != 2 { + return hostConfig{}, errors.Errorf("invalid pair %v for \"client\"", p) + } + + var pair [2]string + copy(pair[:], slice) + result.clientPairs = append(result.clientPairs, pair) + default: + return hostConfig{}, errors.Errorf("invalid type %T for \"client\"", p) + } + } + default: + return hostConfig{}, errors.Errorf("invalid type %v for \"client\"", client) + } + } + + if config.Header != nil { + header := http.Header{} + for key, ty := range config.Header { + switch value := ty.(type) { + case string: + header[key] = []string{value} + case []interface{}: + header[key], err = makeStringSlice(value, nil) + if err != nil { + return hostConfig{}, err + } + default: + return hostConfig{}, errors.Errorf("invalid type %v for header %q", ty, key) + } + } + result.header = header + } + + return result, nil +} + +// makeStringSlice is a helper func to convert from []interface{} to []string. +// Additionally an optional cb func may be passed to perform string mapping. +func makeStringSlice(slice []interface{}, cb func(string) string) ([]string, error) { + out := make([]string, len(slice)) + for i, value := range slice { + str, ok := value.(string) + if !ok { + return nil, errors.Errorf("unable to cast %v to string", value) + } + + if cb != nil { + out[i] = cb(str) + } else { + out[i] = str + } + } + return out, nil +} + func makeAbsPath(p string, base string) string { if filepath.IsAbs(p) { return p diff --git a/remotes/docker/config/hosts_test.go b/remotes/docker/config/hosts_test.go index 9a94859ea..df3f3328c 100644 --- a/remotes/docker/config/hosts_test.go +++ b/remotes/docker/config/hosts_test.go @@ -26,6 +26,8 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/assert" + "github.com/containerd/containerd/log/logtest" "github.com/containerd/containerd/remotes/docker" ) @@ -74,8 +76,6 @@ func TestDefaultHosts(t *testing.T) { } func TestParseHostFile(t *testing.T) { - ctx := logtest.WithT(context.Background(), t) - const testtoml = ` server = "https://test-default.registry" ca = "/etc/path/default" @@ -170,7 +170,7 @@ ca = "/etc/path/default" header: http.Header{"x-custom-1": {"custom header"}}, }, } - hosts, err := parseHostsFile(ctx, "", []byte(testtoml)) + hosts, err := parseHostsFile("", []byte(testtoml)) if err != nil { t.Fatal(err) } @@ -181,15 +181,7 @@ ca = "/etc/path/default" } }() - if len(hosts) != len(expected) { - t.Fatalf("Unexpected number of hosts %d, expected %d", len(hosts), len(expected)) - } - - for i := range hosts { - if !compareHostConfig(hosts[i], expected[i]) { - t.Fatalf("Mismatch at host %d", i) - } - } + assert.ElementsMatch(t, expected, hosts) } func TestLoadCertFiles(t *testing.T) {