Add support for service address ranges to Azure load balancers.
This commit is contained in:
		@@ -474,25 +474,41 @@ func (az *Cloud) reconcileLoadBalancer(lb network.LoadBalancer, pip *network.Pub
 | 
			
		||||
func (az *Cloud) reconcileSecurityGroup(sg network.SecurityGroup, clusterName string, service *api.Service) (network.SecurityGroup, bool, error) {
 | 
			
		||||
	serviceName := getServiceName(service)
 | 
			
		||||
	wantLb := len(service.Spec.Ports) > 0
 | 
			
		||||
	expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports))
 | 
			
		||||
 | 
			
		||||
	sourceRanges, err := serviceapi.GetLoadBalancerSourceRanges(service)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return sg, false, err
 | 
			
		||||
	}
 | 
			
		||||
	var sourceAddressPrefixes []string
 | 
			
		||||
	if sourceRanges == nil || serviceapi.IsAllowAll(sourceRanges) {
 | 
			
		||||
		sourceAddressPrefixes = []string{"Internet"}
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, ip := range sourceRanges {
 | 
			
		||||
			sourceAddressPrefixes = append(sourceAddressPrefixes, ip.String())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports)*len(sourceAddressPrefixes))
 | 
			
		||||
 | 
			
		||||
	for i, port := range service.Spec.Ports {
 | 
			
		||||
		securityRuleName := getRuleName(service, port)
 | 
			
		||||
		_, securityProto, _, err := getProtocolsFromKubernetesProtocol(port.Protocol)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return sg, false, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		expectedSecurityRules[i] = network.SecurityRule{
 | 
			
		||||
			Name: to.StringPtr(securityRuleName),
 | 
			
		||||
			Properties: &network.SecurityRulePropertiesFormat{
 | 
			
		||||
				Protocol:                 securityProto,
 | 
			
		||||
				SourcePortRange:          to.StringPtr("*"),
 | 
			
		||||
				DestinationPortRange:     to.StringPtr(strconv.Itoa(int(port.Port))),
 | 
			
		||||
				SourceAddressPrefix:      to.StringPtr("Internet"),
 | 
			
		||||
				DestinationAddressPrefix: to.StringPtr("*"),
 | 
			
		||||
				Access:    network.Allow,
 | 
			
		||||
				Direction: network.Inbound,
 | 
			
		||||
			},
 | 
			
		||||
		for j := range sourceAddressPrefixes {
 | 
			
		||||
			ix := i*len(sourceAddressPrefixes) + j
 | 
			
		||||
			expectedSecurityRules[ix] = network.SecurityRule{
 | 
			
		||||
				Name: to.StringPtr(securityRuleName),
 | 
			
		||||
				Properties: &network.SecurityRulePropertiesFormat{
 | 
			
		||||
					Protocol:                 securityProto,
 | 
			
		||||
					SourcePortRange:          to.StringPtr("*"),
 | 
			
		||||
					DestinationPortRange:     to.StringPtr(strconv.Itoa(int(port.Port))),
 | 
			
		||||
					SourceAddressPrefix:      to.StringPtr(sourceAddressPrefixes[j]),
 | 
			
		||||
					DestinationAddressPrefix: to.StringPtr("*"),
 | 
			
		||||
					Access:    network.Allow,
 | 
			
		||||
					Direction: network.Inbound,
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -187,6 +187,23 @@ func TestReconcileSecurityGroupRemoveServiceRemovesPort(t *testing.T) {
 | 
			
		||||
	validateSecurityGroup(t, sg, svcUpdated)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestReconcileSecurityWithSourceRanges(t *testing.T) {
 | 
			
		||||
	az := getTestCloud()
 | 
			
		||||
	svc := getTestService("servicea", 80, 443)
 | 
			
		||||
	svc.Spec.LoadBalancerSourceRanges = []string{
 | 
			
		||||
		"192.168.0.1/24",
 | 
			
		||||
		"10.0.0.1/32",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sg := getTestSecurityGroup(svc)
 | 
			
		||||
	sg, _, err := az.reconcileSecurityGroup(sg, testClusterName, &svc)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Unexpected error: %q", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	validateSecurityGroup(t, sg, svc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTestCloud() *Cloud {
 | 
			
		||||
	return &Cloud{
 | 
			
		||||
		Config: Config{
 | 
			
		||||
@@ -269,18 +286,30 @@ func getTestLoadBalancer(services ...api.Service) network.LoadBalancer {
 | 
			
		||||
	return lb
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getServiceSourceRanges(service *api.Service) []string {
 | 
			
		||||
	if len(service.Spec.LoadBalancerSourceRanges) == 0 {
 | 
			
		||||
		return []string{"Internet"}
 | 
			
		||||
	}
 | 
			
		||||
	return service.Spec.LoadBalancerSourceRanges
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTestSecurityGroup(services ...api.Service) network.SecurityGroup {
 | 
			
		||||
	rules := []network.SecurityRule{}
 | 
			
		||||
 | 
			
		||||
	for _, service := range services {
 | 
			
		||||
		for _, port := range service.Spec.Ports {
 | 
			
		||||
			ruleName := getRuleName(&service, port)
 | 
			
		||||
			rules = append(rules, network.SecurityRule{
 | 
			
		||||
				Name: to.StringPtr(ruleName),
 | 
			
		||||
				Properties: &network.SecurityRulePropertiesFormat{
 | 
			
		||||
					DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)),
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			sources := getServiceSourceRanges(&service)
 | 
			
		||||
			for _, src := range sources {
 | 
			
		||||
				rules = append(rules, network.SecurityRule{
 | 
			
		||||
					Name: to.StringPtr(ruleName),
 | 
			
		||||
					Properties: &network.SecurityRulePropertiesFormat{
 | 
			
		||||
						SourceAddressPrefix:  to.StringPtr(src),
 | 
			
		||||
						DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)),
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -344,7 +373,7 @@ func validateLoadBalancer(t *testing.T, loadBalancer network.LoadBalancer, servi
 | 
			
		||||
 | 
			
		||||
	lenRules := len(*loadBalancer.Properties.LoadBalancingRules)
 | 
			
		||||
	if lenRules != expectedRuleCount {
 | 
			
		||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules)
 | 
			
		||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n%v", expectedRuleCount, lenRules, loadBalancer.Properties.LoadBalancingRules)
 | 
			
		||||
	}
 | 
			
		||||
	lenProbes := len(*loadBalancer.Properties.Probes)
 | 
			
		||||
	if lenProbes != expectedRuleCount {
 | 
			
		||||
@@ -356,25 +385,30 @@ func validateSecurityGroup(t *testing.T, securityGroup network.SecurityGroup, se
 | 
			
		||||
	expectedRuleCount := 0
 | 
			
		||||
	for _, svc := range services {
 | 
			
		||||
		for _, wantedRule := range svc.Spec.Ports {
 | 
			
		||||
			expectedRuleCount++
 | 
			
		||||
			wantedRuleName := getRuleName(&svc, wantedRule)
 | 
			
		||||
			foundRule := false
 | 
			
		||||
			for _, actualRule := range *securityGroup.Properties.SecurityRules {
 | 
			
		||||
				if strings.EqualFold(*actualRule.Name, wantedRuleName) &&
 | 
			
		||||
					*actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) {
 | 
			
		||||
					foundRule = true
 | 
			
		||||
					break
 | 
			
		||||
			sources := getServiceSourceRanges(&svc)
 | 
			
		||||
 | 
			
		||||
			for _, source := range sources {
 | 
			
		||||
				expectedRuleCount++
 | 
			
		||||
				wantedRuleName := getRuleName(&svc, wantedRule)
 | 
			
		||||
				foundRule := false
 | 
			
		||||
				for _, actualRule := range *securityGroup.Properties.SecurityRules {
 | 
			
		||||
					if strings.EqualFold(*actualRule.Name, wantedRuleName) &&
 | 
			
		||||
						*actualRule.Properties.SourceAddressPrefix == source &&
 | 
			
		||||
						*actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) {
 | 
			
		||||
						foundRule = true
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				if !foundRule {
 | 
			
		||||
					t.Errorf("Expected security group rule but didn't find it: %q", wantedRuleName)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if !foundRule {
 | 
			
		||||
				t.Errorf("Expected security group rule but didn't find it: %q", wantedRuleName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lenRules := len(*securityGroup.Properties.SecurityRules)
 | 
			
		||||
	if lenRules != expectedRuleCount {
 | 
			
		||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules)
 | 
			
		||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n", expectedRuleCount, lenRules)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user