Rewrite Docker hosts parser
Signed-off-by: Maksym Pavlenko <pavlenko.maksym@gmail.com>
This commit is contained in:
		| @@ -30,11 +30,12 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/BurntSushi/toml" | ||||
| 	"github.com/pelletier/go-toml" | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	"github.com/containerd/containerd/errdefs" | ||||
| 	"github.com/containerd/containerd/log" | ||||
| 	"github.com/containerd/containerd/remotes/docker" | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
|  | ||||
| // UpdateClientFunc is a function that lets you to amend http Client behavior used by registry clients. | ||||
| @@ -264,7 +265,7 @@ func loadHostDir(ctx context.Context, hostsDir string) ([]hostConfig, error) { | ||||
| 		return loadCertFiles(ctx, hostsDir) | ||||
| 	} | ||||
|  | ||||
| 	hosts, err := parseHostsFile(ctx, hostsDir, b) | ||||
| 	hosts, err := parseHostsFile(hostsDir, b) | ||||
| 	if err != nil { | ||||
| 		log.G(ctx).WithError(err).Error("failed to decode hosts.toml") | ||||
| 		// Fallback to checking certificate files | ||||
| @@ -274,7 +275,9 @@ func loadHostDir(ctx context.Context, hostsDir string) ([]hostConfig, error) { | ||||
| 	return hosts, nil | ||||
| } | ||||
|  | ||||
| type hostFileConfig struct { | ||||
| // HostFileConfig describes a single host section within TOML file. | ||||
| // Note: This struct needs to be public in order to be properly deserialized by TOML library. | ||||
| type HostFileConfig struct { | ||||
| 	// Capabilities determine what operations a host is | ||||
| 	// capable of performing. Allowed values | ||||
| 	//  - pull | ||||
| @@ -283,14 +286,14 @@ type hostFileConfig struct { | ||||
| 	Capabilities []string `toml:"capabilities"` | ||||
|  | ||||
| 	// CACert can be a string or an array of strings | ||||
| 	CACert toml.Primitive `toml:"ca"` | ||||
| 	CACert interface{} `toml:"ca"` | ||||
|  | ||||
| 	// TODO: Make this an array (two key types, one for pairs (multiple files), one for single file?) | ||||
| 	Client toml.Primitive `toml:"client"` | ||||
| 	Client interface{} `toml:"client"` | ||||
|  | ||||
| 	SkipVerify *bool `toml:"skip_verify"` | ||||
|  | ||||
| 	Header map[string]toml.Primitive `toml:"header"` | ||||
| 	Header map[string]interface{} `toml:"header"` | ||||
|  | ||||
| 	// API (default: "docker") | ||||
| 	// API Version (default: "v2") | ||||
| @@ -300,187 +303,191 @@ type hostFileConfig struct { | ||||
| type configFile struct { | ||||
| 	// hostConfig holds defaults for all hosts as well as | ||||
| 	// for the default server | ||||
| 	hostFileConfig | ||||
| 	HostFileConfig | ||||
|  | ||||
| 	// Server specifies the default server. When `host` is | ||||
| 	// also specified, those hosts are tried first. | ||||
| 	Server string `toml:"server"` | ||||
|  | ||||
| 	// HostConfigs store the per-host configuration | ||||
| 	HostConfigs map[string]hostFileConfig `toml:"host"` | ||||
| 	HostConfigs map[string]HostFileConfig `toml:"host"` | ||||
| } | ||||
|  | ||||
| func parseHostsFile(ctx context.Context, baseDir string, b []byte) ([]hostConfig, error) { | ||||
| 	var c configFile | ||||
| 	md, err := toml.Decode(string(b), &c) | ||||
| func parseHostsFile(baseDir string, b []byte) ([]hostConfig, error) { | ||||
| 	tree, err := toml.LoadBytes(b) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "failed to parse TOML") | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		c     configFile | ||||
| 		hosts []hostConfig | ||||
| 	) | ||||
|  | ||||
| 	if err := tree.Unmarshal(&c); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var orderedHosts []string | ||||
| 	for _, key := range md.Keys() { | ||||
| 		if len(key) >= 2 { | ||||
| 			if key[0] == "host" && (len(orderedHosts) == 0 || orderedHosts[len(orderedHosts)-1] != key[1]) { | ||||
| 				orderedHosts = append(orderedHosts, key[1]) | ||||
| 			} | ||||
| 		} | ||||
| 	// Parse root host config | ||||
| 	parsed, err := parseHostConfig(c.Server, baseDir, c.HostFileConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	hosts = append(hosts, parsed) | ||||
|  | ||||
| 	if c.HostConfigs == nil { | ||||
| 		c.HostConfigs = map[string]hostFileConfig{} | ||||
| 	} | ||||
| 	if c.Server != "" { | ||||
| 		c.HostConfigs[c.Server] = c.hostFileConfig | ||||
| 		orderedHosts = append(orderedHosts, c.Server) | ||||
| 	} else if len(orderedHosts) == 0 { | ||||
| 		c.HostConfigs[""] = c.hostFileConfig | ||||
| 		orderedHosts = append(orderedHosts, "") | ||||
| 	} | ||||
| 	hosts := make([]hostConfig, len(orderedHosts)) | ||||
| 	for i, server := range orderedHosts { | ||||
| 		hostConfig := c.HostConfigs[server] | ||||
|  | ||||
| 		if server != "" { | ||||
| 			if !strings.HasPrefix(server, "http") { | ||||
| 				server = "https://" + server | ||||
| 			} | ||||
| 			u, err := url.Parse(server) | ||||
| 			if err != nil { | ||||
| 				return nil, errors.Errorf("unable to parse server %v", server) | ||||
| 			} | ||||
| 			hosts[i].scheme = u.Scheme | ||||
| 			hosts[i].host = u.Host | ||||
|  | ||||
| 			// TODO: Handle path based on registry protocol | ||||
| 			// Define a registry protocol type | ||||
| 			//   OCI v1    - Always use given path as is | ||||
| 			//   Docker v2 - Always ensure ends with /v2/ | ||||
| 			if len(u.Path) > 0 { | ||||
| 				u.Path = path.Clean(u.Path) | ||||
| 				if !strings.HasSuffix(u.Path, "/v2") { | ||||
| 					u.Path = u.Path + "/v2" | ||||
| 				} | ||||
| 			} else { | ||||
| 				u.Path = "/v2" | ||||
| 			} | ||||
| 			hosts[i].path = u.Path | ||||
| 		} | ||||
| 		hosts[i].skipVerify = hostConfig.SkipVerify | ||||
|  | ||||
| 		if len(hostConfig.Capabilities) > 0 { | ||||
| 			for _, c := range hostConfig.Capabilities { | ||||
| 				switch strings.ToLower(c) { | ||||
| 				case "pull": | ||||
| 					hosts[i].capabilities |= docker.HostCapabilityPull | ||||
| 				case "resolve": | ||||
| 					hosts[i].capabilities |= docker.HostCapabilityResolve | ||||
| 				case "push": | ||||
| 					hosts[i].capabilities |= docker.HostCapabilityPush | ||||
| 				default: | ||||
| 					return nil, errors.Errorf("unknown capability %v", c) | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			hosts[i].capabilities = docker.HostCapabilityPull | docker.HostCapabilityResolve | docker.HostCapabilityPush | ||||
| 		} | ||||
|  | ||||
| 		baseKey := []string{} | ||||
| 		if server != "" && server != c.Server { | ||||
| 			baseKey = append(baseKey, "host", server) | ||||
| 		} | ||||
| 		caKey := append(baseKey, "ca") | ||||
| 		if md.IsDefined(caKey...) { | ||||
| 			switch t := md.Type(caKey...); t { | ||||
| 			case "String": | ||||
| 				var caCert string | ||||
| 				if err := md.PrimitiveDecode(hostConfig.CACert, &caCert); err != nil { | ||||
| 					return nil, errors.Wrap(err, "failed to decode \"ca\"") | ||||
| 				} | ||||
| 				hosts[i].caCerts = []string{makeAbsPath(caCert, baseDir)} | ||||
| 			case "Array": | ||||
| 				var caCerts []string | ||||
| 				if err := md.PrimitiveDecode(hostConfig.CACert, &caCerts); err != nil { | ||||
| 					return nil, errors.Wrap(err, "failed to decode \"ca\"") | ||||
| 				} | ||||
| 				for i, p := range caCerts { | ||||
| 					caCerts[i] = makeAbsPath(p, baseDir) | ||||
| 				} | ||||
|  | ||||
| 				hosts[i].caCerts = caCerts | ||||
| 			default: | ||||
| 				return nil, errors.Errorf("invalid type %v for \"ca\"", t) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		clientKey := append(baseKey, "client") | ||||
| 		if md.IsDefined(clientKey...) { | ||||
| 			switch t := md.Type(clientKey...); t { | ||||
| 			case "String": | ||||
| 				var clientCert string | ||||
| 				if err := md.PrimitiveDecode(hostConfig.Client, &clientCert); err != nil { | ||||
| 					return nil, errors.Wrap(err, "failed to decode \"ca\"") | ||||
| 				} | ||||
| 				hosts[i].clientPairs = [][2]string{{makeAbsPath(clientCert, baseDir), ""}} | ||||
| 			case "Array": | ||||
| 				var clientCerts []interface{} | ||||
| 				if err := md.PrimitiveDecode(hostConfig.Client, &clientCerts); err != nil { | ||||
| 					return nil, errors.Wrap(err, "failed to decode \"ca\"") | ||||
| 				} | ||||
| 				for _, pairs := range clientCerts { | ||||
| 					switch p := pairs.(type) { | ||||
| 					case string: | ||||
| 						hosts[i].clientPairs = append(hosts[i].clientPairs, [2]string{makeAbsPath(p, baseDir), ""}) | ||||
| 					case []interface{}: | ||||
| 						var pair [2]string | ||||
| 						if len(p) > 2 { | ||||
| 							return nil, errors.Errorf("invalid pair %v for \"client\"", p) | ||||
| 						} | ||||
| 						for pi, cp := range p { | ||||
| 							s, ok := cp.(string) | ||||
| 							if !ok { | ||||
| 								return nil, errors.Errorf("invalid type %T for \"client\"", cp) | ||||
| 							} | ||||
| 							pair[pi] = makeAbsPath(s, baseDir) | ||||
| 						} | ||||
| 						hosts[i].clientPairs = append(hosts[i].clientPairs, pair) | ||||
| 					default: | ||||
| 						return nil, errors.Errorf("invalid type %T for \"client\"", p) | ||||
| 					} | ||||
| 				} | ||||
| 			default: | ||||
| 				return nil, errors.Errorf("invalid type %v for \"client\"", t) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		headerKey := append(baseKey, "header") | ||||
| 		if md.IsDefined(headerKey...) { | ||||
| 			header := http.Header{} | ||||
| 			for key, prim := range hostConfig.Header { | ||||
| 				switch t := md.Type(append(headerKey, key)...); t { | ||||
| 				case "String": | ||||
| 					var value string | ||||
| 					if err := md.PrimitiveDecode(prim, &value); err != nil { | ||||
| 						return nil, errors.Wrapf(err, "failed to decode header %q", key) | ||||
| 					} | ||||
| 					header[key] = []string{value} | ||||
| 				case "Array": | ||||
| 					var value []string | ||||
| 					if err := md.PrimitiveDecode(prim, &value); err != nil { | ||||
| 						return nil, errors.Wrapf(err, "failed to decode header %q", key) | ||||
| 					} | ||||
|  | ||||
| 					header[key] = value | ||||
| 				default: | ||||
| 					return nil, errors.Errorf("invalid type %v for header %q", t, key) | ||||
| 				} | ||||
| 			} | ||||
| 			hosts[i].header = header | ||||
| 	// Parse hosts array | ||||
| 	for host, config := range c.HostConfigs { | ||||
| 		parsed, err := parseHostConfig(host, baseDir, config) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		hosts = append(hosts, parsed) | ||||
| 	} | ||||
|  | ||||
| 	return hosts, nil | ||||
| } | ||||
|  | ||||
| func parseHostConfig(server string, baseDir string, config HostFileConfig) (hostConfig, error) { | ||||
| 	var ( | ||||
| 		result = hostConfig{} | ||||
| 		err    error | ||||
| 	) | ||||
|  | ||||
| 	if server != "" { | ||||
| 		if !strings.HasPrefix(server, "http") { | ||||
| 			server = "https://" + server | ||||
| 		} | ||||
| 		u, err := url.Parse(server) | ||||
| 		if err != nil { | ||||
| 			return hostConfig{}, errors.Wrapf(err, "unable to parse server %v", server) | ||||
| 		} | ||||
| 		result.scheme = u.Scheme | ||||
| 		result.host = u.Host | ||||
| 		// TODO: Handle path based on registry protocol | ||||
| 		// Define a registry protocol type | ||||
| 		//   OCI v1    - Always use given path as is | ||||
| 		//   Docker v2 - Always ensure ends with /v2/ | ||||
| 		if len(u.Path) > 0 { | ||||
| 			u.Path = path.Clean(u.Path) | ||||
| 			if !strings.HasSuffix(u.Path, "/v2") { | ||||
| 				u.Path = u.Path + "/v2" | ||||
| 			} | ||||
| 		} else { | ||||
| 			u.Path = "/v2" | ||||
| 		} | ||||
| 		result.path = u.Path | ||||
| 	} | ||||
|  | ||||
| 	result.skipVerify = config.SkipVerify | ||||
|  | ||||
| 	if len(config.Capabilities) > 0 { | ||||
| 		for _, c := range config.Capabilities { | ||||
| 			switch strings.ToLower(c) { | ||||
| 			case "pull": | ||||
| 				result.capabilities |= docker.HostCapabilityPull | ||||
| 			case "resolve": | ||||
| 				result.capabilities |= docker.HostCapabilityResolve | ||||
| 			case "push": | ||||
| 				result.capabilities |= docker.HostCapabilityPush | ||||
| 			default: | ||||
| 				return hostConfig{}, errors.Errorf("unknown capability %v", c) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		result.capabilities = docker.HostCapabilityPull | docker.HostCapabilityResolve | docker.HostCapabilityPush | ||||
| 	} | ||||
|  | ||||
| 	if config.CACert != nil { | ||||
| 		switch cert := config.CACert.(type) { | ||||
| 		case string: | ||||
| 			result.caCerts = []string{makeAbsPath(cert, baseDir)} | ||||
| 		case []string: | ||||
| 			for _, p := range cert { | ||||
| 				result.caCerts = append(result.caCerts, makeAbsPath(p, baseDir)) | ||||
| 			} | ||||
| 		case []interface{}: | ||||
| 			result.caCerts, err = makeStringSlice(cert, func(p string) string { | ||||
| 				return makeAbsPath(p, baseDir) | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				return hostConfig{}, err | ||||
| 			} | ||||
| 		default: | ||||
| 			return hostConfig{}, errors.Errorf("invalid type %v for \"ca\"", cert) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if config.Client != nil { | ||||
| 		switch client := config.Client.(type) { | ||||
| 		case string: | ||||
| 			result.clientPairs = [][2]string{{makeAbsPath(client, baseDir), ""}} | ||||
| 		case []interface{}: | ||||
| 			// []string or [][2]string | ||||
| 			for _, pairs := range client { | ||||
| 				switch p := pairs.(type) { | ||||
| 				case string: | ||||
| 					result.clientPairs = append(result.clientPairs, [2]string{makeAbsPath(p, baseDir), ""}) | ||||
| 				case []interface{}: | ||||
| 					slice, err := makeStringSlice(p, nil) | ||||
| 					if err != nil { | ||||
| 						return hostConfig{}, err | ||||
| 					} | ||||
| 					if len(slice) != 2 { | ||||
| 						return hostConfig{}, errors.Errorf("invalid pair %v for \"client\"", p) | ||||
| 					} | ||||
|  | ||||
| 					var pair [2]string | ||||
| 					copy(pair[:], slice) | ||||
| 					result.clientPairs = append(result.clientPairs, pair) | ||||
| 				default: | ||||
| 					return hostConfig{}, errors.Errorf("invalid type %T for \"client\"", p) | ||||
| 				} | ||||
| 			} | ||||
| 		default: | ||||
| 			return hostConfig{}, errors.Errorf("invalid type %v for \"client\"", client) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if config.Header != nil { | ||||
| 		header := http.Header{} | ||||
| 		for key, ty := range config.Header { | ||||
| 			switch value := ty.(type) { | ||||
| 			case string: | ||||
| 				header[key] = []string{value} | ||||
| 			case []interface{}: | ||||
| 				header[key], err = makeStringSlice(value, nil) | ||||
| 				if err != nil { | ||||
| 					return hostConfig{}, err | ||||
| 				} | ||||
| 			default: | ||||
| 				return hostConfig{}, errors.Errorf("invalid type %v for header %q", ty, key) | ||||
| 			} | ||||
| 		} | ||||
| 		result.header = header | ||||
| 	} | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| // makeStringSlice is a helper func to convert from []interface{} to []string. | ||||
| // Additionally an optional cb func may be passed to perform string mapping. | ||||
| func makeStringSlice(slice []interface{}, cb func(string) string) ([]string, error) { | ||||
| 	out := make([]string, len(slice)) | ||||
| 	for i, value := range slice { | ||||
| 		str, ok := value.(string) | ||||
| 		if !ok { | ||||
| 			return nil, errors.Errorf("unable to cast %v to string", value) | ||||
| 		} | ||||
|  | ||||
| 		if cb != nil { | ||||
| 			out[i] = cb(str) | ||||
| 		} else { | ||||
| 			out[i] = str | ||||
| 		} | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func makeAbsPath(p string, base string) string { | ||||
| 	if filepath.IsAbs(p) { | ||||
| 		return p | ||||
|   | ||||
| @@ -26,6 +26,8 @@ import ( | ||||
| 	"path/filepath" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
|  | ||||
| 	"github.com/containerd/containerd/log/logtest" | ||||
| 	"github.com/containerd/containerd/remotes/docker" | ||||
| ) | ||||
| @@ -74,8 +76,6 @@ func TestDefaultHosts(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestParseHostFile(t *testing.T) { | ||||
| 	ctx := logtest.WithT(context.Background(), t) | ||||
|  | ||||
| 	const testtoml = ` | ||||
| server = "https://test-default.registry" | ||||
| ca = "/etc/path/default" | ||||
| @@ -170,7 +170,7 @@ ca = "/etc/path/default" | ||||
| 			header:       http.Header{"x-custom-1": {"custom header"}}, | ||||
| 		}, | ||||
| 	} | ||||
| 	hosts, err := parseHostsFile(ctx, "", []byte(testtoml)) | ||||
| 	hosts, err := parseHostsFile("", []byte(testtoml)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -181,15 +181,7 @@ ca = "/etc/path/default" | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	if len(hosts) != len(expected) { | ||||
| 		t.Fatalf("Unexpected number of hosts %d, expected %d", len(hosts), len(expected)) | ||||
| 	} | ||||
|  | ||||
| 	for i := range hosts { | ||||
| 		if !compareHostConfig(hosts[i], expected[i]) { | ||||
| 			t.Fatalf("Mismatch at host %d", i) | ||||
| 		} | ||||
| 	} | ||||
| 	assert.ElementsMatch(t, expected, hosts) | ||||
| } | ||||
|  | ||||
| func TestLoadCertFiles(t *testing.T) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Maksym Pavlenko
					Maksym Pavlenko