diff --git a/pkg/cloudprovider/providers/azure/azure_instance_metadata.go b/pkg/cloudprovider/providers/azure/azure_instance_metadata.go index a9d29ec5a51..ee60c619845 100644 --- a/pkg/cloudprovider/providers/azure/azure_instance_metadata.go +++ b/pkg/cloudprovider/providers/azure/azure_instance_metadata.go @@ -100,7 +100,7 @@ func (i *InstanceMetadata) queryMetadataBytes(path, format string) ([]byte, erro q := req.URL.Query() q.Add("format", format) - q.Add("api-version", "2017-04-02") + q.Add("api-version", "2017-12-01") req.URL.RawQuery = q.Encode() resp, err := client.Do(req) diff --git a/pkg/cloudprovider/providers/azure/azure_standard.go b/pkg/cloudprovider/providers/azure/azure_standard.go index 9591ca92a67..c7e3a91cf6a 100644 --- a/pkg/cloudprovider/providers/azure/azure_standard.go +++ b/pkg/cloudprovider/providers/azure/azure_standard.go @@ -408,14 +408,29 @@ func (as *availabilitySet) GetInstanceTypeByNodeName(name string) (string, error return string(machine.HardwareProfile.VMSize), nil } -// GetZoneByNodeName gets zone from instance view. +// GetZoneByNodeName gets availability zone for the specified node. If the node is not running +// with availability zone, then it returns fault domain. func (as *availabilitySet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) { vm, err := as.getVirtualMachine(types.NodeName(name)) if err != nil { return cloudprovider.Zone{}, err } - failureDomain := strconv.Itoa(int(*vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain)) + var failureDomain string + if vm.Zones != nil && len(*vm.Zones) > 0 { + // Get availability zone for the node. + zones := *vm.Zones + zoneID, err := strconv.Atoi(zones[0]) + if err != nil { + return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %v", zones, err) + } + + failureDomain = as.makeZone(zoneID) + } else { + // Availability zone is not used for the node, falling back to fault domain. + failureDomain = strconv.Itoa(int(*vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain)) + } + zone := cloudprovider.Zone{ FailureDomain: failureDomain, Region: *(vm.Location), diff --git a/pkg/cloudprovider/providers/azure/azure_test.go b/pkg/cloudprovider/providers/azure/azure_test.go index c8b0eb7be1b..979e6a2bdd6 100644 --- a/pkg/cloudprovider/providers/azure/azure_test.go +++ b/pkg/cloudprovider/providers/azure/azure_test.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "math" + "net" "net/http" "net/http/httptest" "reflect" @@ -1667,35 +1668,82 @@ func validateEmptyConfig(t *testing.T, config string) { t.Errorf("got incorrect value for CloudProviderRateLimit") } } + func TestGetZone(t *testing.T) { - data := `{"ID":"_azdev","UD":"0","FD":"99"}` - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, data) - })) - defer ts.Close() - - cloud := &Cloud{} - cloud.Location = "eastus" - - zone, err := cloud.getZoneFromURL(ts.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) + cloud := &Cloud{ + Config: Config{ + Location: "eastus", + }, + metadata: &InstanceMetadata{}, } - if zone.FailureDomain != "99" { - t.Errorf("Unexpected value: %s, expected '99'", zone.FailureDomain) + testcases := []struct { + name string + zone string + faultDomain string + expected string + }{ + { + name: "GetZone should get real zone if only node's zone is set", + zone: "1", + expected: "eastus-1", + }, + { + name: "GetZone should get real zone if both node's zone and FD are set", + zone: "1", + faultDomain: "99", + expected: "eastus-1", + }, + { + name: "GetZone should get faultDomain if node's zone isn't set", + faultDomain: "99", + expected: "99", + }, } - if zone.Region != cloud.Location { - t.Errorf("Expected: %s, saw: %s", cloud.Location, zone.Region) + + for _, test := range testcases { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Errorf("Test [%s] unexpected error: %v", test.name, err) + } + + mux := http.NewServeMux() + mux.Handle("/v1/InstanceInfo/FD", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, test.faultDomain) + })) + mux.Handle("/instance/compute/zone", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, test.zone) + })) + go func() { + http.Serve(listener, mux) + }() + defer listener.Close() + + cloud.metadata.baseURL = "http://" + listener.Addr().String() + "/" + zone, err := cloud.GetZone(context.Background()) + if err != nil { + t.Errorf("Test [%s] unexpected error: %v", test.name, err) + } + if zone.FailureDomain != test.expected { + t.Errorf("Test [%s] unexpected zone: %s, expected %q", test.name, zone.FailureDomain, test.expected) + } + if zone.Region != cloud.Location { + t.Errorf("Test [%s] unexpected region: %s, expected: %s", test.name, zone.Region, cloud.Location) + } } } func TestFetchFaultDomain(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, `{"ID":"_azdev","UD":"0","FD":"99"}`) + fmt.Fprint(w, "99") })) defer ts.Close() - faultDomain, err := fetchFaultDomain(ts.URL) + cloud := &Cloud{} + cloud.metadata = &InstanceMetadata{ + baseURL: ts.URL + "/", + } + + faultDomain, err := cloud.fetchFaultDomain() if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -1707,23 +1755,6 @@ func TestFetchFaultDomain(t *testing.T) { } } -func TestDecodeInstanceInfo(t *testing.T) { - response := `{"ID":"_azdev","UD":"0","FD":"99"}` - - faultDomain, err := readFaultDomain(strings.NewReader(response)) - if err != nil { - t.Errorf("Unexpected error in ReadFaultDomain: %v", err) - } - - if faultDomain == nil { - t.Error("Fault domain was unexpectedly nil") - } - - if *faultDomain != "99" { - t.Error("got incorrect fault domain") - } -} - func TestGetNodeNameByProviderID(t *testing.T) { az := getTestCloud() providers := []struct { diff --git a/pkg/cloudprovider/providers/azure/azure_vmss.go b/pkg/cloudprovider/providers/azure/azure_vmss.go index 0cf6bf841a8..2461772b74e 100644 --- a/pkg/cloudprovider/providers/azure/azure_vmss.go +++ b/pkg/cloudprovider/providers/azure/azure_vmss.go @@ -211,7 +211,11 @@ func (ss *scaleSet) GetInstanceTypeByNodeName(name string) (string, error) { return "", nil } -// GetZoneByNodeName gets cloudprovider.Zone by node name. +// GetZoneByNodeName gets availability zone for the specified node. If the node is not running +// with availability zone, then it returns fault domain. +// TODO(feiskyer): Add availability zone support of VirtualMachineScaleSetVM +// after it is released in Azure Go SDK. +// Refer https://github.com/Azure/azure-sdk-for-go/pull/2224. func (ss *scaleSet) GetZoneByNodeName(name string) (cloudprovider.Zone, error) { managedByAS, err := ss.isNodeManagedByAvailabilitySet(name) if err != nil { diff --git a/pkg/cloudprovider/providers/azure/azure_zones.go b/pkg/cloudprovider/providers/azure/azure_zones.go index 1cf1bb929fb..29a15251965 100644 --- a/pkg/cloudprovider/providers/azure/azure_zones.go +++ b/pkg/cloudprovider/providers/azure/azure_zones.go @@ -18,39 +18,61 @@ package azure import ( "context" - "encoding/json" - "io" - "io/ioutil" - "net/http" + "fmt" + "strconv" + "strings" "sync" + "github.com/golang/glog" "k8s.io/apimachinery/pkg/types" "k8s.io/kubernetes/pkg/cloudprovider" ) -const instanceInfoURL = "http://169.254.169.254/metadata/v1/InstanceInfo" +const ( + faultDomainURI = "v1/InstanceInfo/FD" + zoneMetadataURI = "instance/compute/zone" +) var faultMutex = &sync.Mutex{} var faultDomain *string -type instanceInfo struct { - ID string `json:"ID"` - UpdateDomain string `json:"UD"` - FaultDomain string `json:"FD"` +// makeZone returns the zone value in format of -. +func (az *Cloud) makeZone(zoneID int) string { + return fmt.Sprintf("%s-%d", strings.ToLower(az.Location), zoneID) } -// GetZone returns the Zone containing the current failure zone and locality region that the program is running in +// GetZone returns the Zone containing the current availability zone and locality region that the program is running in. +// If the node is not running with availability zones, then it will fall back to fault domain. func (az *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) { - return az.getZoneFromURL(instanceInfoURL) + zone, err := az.metadata.Text(zoneMetadataURI) + if err != nil { + return cloudprovider.Zone{}, err + } + + if zone == "" { + glog.V(3).Infof("Availability zone is not enabled for the node, falling back to fault domain") + return az.getZoneFromFaultDomain() + } + + zoneID, err := strconv.Atoi(zone) + if err != nil { + return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone ID %q: %v", zone, err) + } + + return cloudprovider.Zone{ + FailureDomain: az.makeZone(zoneID), + Region: az.Location, + }, nil } -// This is injectable for testing. -func (az *Cloud) getZoneFromURL(url string) (cloudprovider.Zone, error) { +// getZoneFromFaultDomain gets fault domain for the instance. +// Fault domain is the fallback when availability zone is not enabled for the node. +func (az *Cloud) getZoneFromFaultDomain() (cloudprovider.Zone, error) { faultMutex.Lock() defer faultMutex.Unlock() if faultDomain == nil { var err error - faultDomain, err = fetchFaultDomain(url) + faultDomain, err = az.fetchFaultDomain() if err != nil { return cloudprovider.Zone{}, err } @@ -81,24 +103,11 @@ func (az *Cloud) GetZoneByNodeName(ctx context.Context, nodeName types.NodeName) return az.vmSet.GetZoneByNodeName(string(nodeName)) } -func fetchFaultDomain(url string) (*string, error) { - resp, err := http.Get(url) +func (az *Cloud) fetchFaultDomain() (*string, error) { + faultDomain, err := az.metadata.Text(faultDomainURI) if err != nil { return nil, err } - defer resp.Body.Close() - return readFaultDomain(resp.Body) -} -func readFaultDomain(reader io.Reader) (*string, error) { - var instanceInfo instanceInfo - body, err := ioutil.ReadAll(reader) - if err != nil { - return nil, err - } - err = json.Unmarshal(body, &instanceInfo) - if err != nil { - return nil, err - } - return &instanceInfo.FaultDomain, nil + return &faultDomain, nil }