Add support for service address ranges to Azure load balancers.
This commit is contained in:
		@@ -474,27 +474,43 @@ func (az *Cloud) reconcileLoadBalancer(lb network.LoadBalancer, pip *network.Pub
 | 
				
			|||||||
func (az *Cloud) reconcileSecurityGroup(sg network.SecurityGroup, clusterName string, service *api.Service) (network.SecurityGroup, bool, error) {
 | 
					func (az *Cloud) reconcileSecurityGroup(sg network.SecurityGroup, clusterName string, service *api.Service) (network.SecurityGroup, bool, error) {
 | 
				
			||||||
	serviceName := getServiceName(service)
 | 
						serviceName := getServiceName(service)
 | 
				
			||||||
	wantLb := len(service.Spec.Ports) > 0
 | 
						wantLb := len(service.Spec.Ports) > 0
 | 
				
			||||||
	expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports))
 | 
					
 | 
				
			||||||
 | 
						sourceRanges, err := serviceapi.GetLoadBalancerSourceRanges(service)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return sg, false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var sourceAddressPrefixes []string
 | 
				
			||||||
 | 
						if sourceRanges == nil || serviceapi.IsAllowAll(sourceRanges) {
 | 
				
			||||||
 | 
							sourceAddressPrefixes = []string{"Internet"}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							for _, ip := range sourceRanges {
 | 
				
			||||||
 | 
								sourceAddressPrefixes = append(sourceAddressPrefixes, ip.String())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						expectedSecurityRules := make([]network.SecurityRule, len(service.Spec.Ports)*len(sourceAddressPrefixes))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i, port := range service.Spec.Ports {
 | 
						for i, port := range service.Spec.Ports {
 | 
				
			||||||
		securityRuleName := getRuleName(service, port)
 | 
							securityRuleName := getRuleName(service, port)
 | 
				
			||||||
		_, securityProto, _, err := getProtocolsFromKubernetesProtocol(port.Protocol)
 | 
							_, securityProto, _, err := getProtocolsFromKubernetesProtocol(port.Protocol)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return sg, false, err
 | 
								return sg, false, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							for j := range sourceAddressPrefixes {
 | 
				
			||||||
		expectedSecurityRules[i] = network.SecurityRule{
 | 
								ix := i*len(sourceAddressPrefixes) + j
 | 
				
			||||||
 | 
								expectedSecurityRules[ix] = network.SecurityRule{
 | 
				
			||||||
				Name: to.StringPtr(securityRuleName),
 | 
									Name: to.StringPtr(securityRuleName),
 | 
				
			||||||
				Properties: &network.SecurityRulePropertiesFormat{
 | 
									Properties: &network.SecurityRulePropertiesFormat{
 | 
				
			||||||
					Protocol:                 securityProto,
 | 
										Protocol:                 securityProto,
 | 
				
			||||||
					SourcePortRange:          to.StringPtr("*"),
 | 
										SourcePortRange:          to.StringPtr("*"),
 | 
				
			||||||
					DestinationPortRange:     to.StringPtr(strconv.Itoa(int(port.Port))),
 | 
										DestinationPortRange:     to.StringPtr(strconv.Itoa(int(port.Port))),
 | 
				
			||||||
				SourceAddressPrefix:      to.StringPtr("Internet"),
 | 
										SourceAddressPrefix:      to.StringPtr(sourceAddressPrefixes[j]),
 | 
				
			||||||
					DestinationAddressPrefix: to.StringPtr("*"),
 | 
										DestinationAddressPrefix: to.StringPtr("*"),
 | 
				
			||||||
					Access:    network.Allow,
 | 
										Access:    network.Allow,
 | 
				
			||||||
					Direction: network.Inbound,
 | 
										Direction: network.Inbound,
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// update security rules
 | 
						// update security rules
 | 
				
			||||||
	dirtySg := false
 | 
						dirtySg := false
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -187,6 +187,23 @@ func TestReconcileSecurityGroupRemoveServiceRemovesPort(t *testing.T) {
 | 
				
			|||||||
	validateSecurityGroup(t, sg, svcUpdated)
 | 
						validateSecurityGroup(t, sg, svcUpdated)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestReconcileSecurityWithSourceRanges(t *testing.T) {
 | 
				
			||||||
 | 
						az := getTestCloud()
 | 
				
			||||||
 | 
						svc := getTestService("servicea", 80, 443)
 | 
				
			||||||
 | 
						svc.Spec.LoadBalancerSourceRanges = []string{
 | 
				
			||||||
 | 
							"192.168.0.1/24",
 | 
				
			||||||
 | 
							"10.0.0.1/32",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sg := getTestSecurityGroup(svc)
 | 
				
			||||||
 | 
						sg, _, err := az.reconcileSecurityGroup(sg, testClusterName, &svc)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Unexpected error: %q", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						validateSecurityGroup(t, sg, svc)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTestCloud() *Cloud {
 | 
					func getTestCloud() *Cloud {
 | 
				
			||||||
	return &Cloud{
 | 
						return &Cloud{
 | 
				
			||||||
		Config: Config{
 | 
							Config: Config{
 | 
				
			||||||
@@ -269,20 +286,32 @@ func getTestLoadBalancer(services ...api.Service) network.LoadBalancer {
 | 
				
			|||||||
	return lb
 | 
						return lb
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getServiceSourceRanges(service *api.Service) []string {
 | 
				
			||||||
 | 
						if len(service.Spec.LoadBalancerSourceRanges) == 0 {
 | 
				
			||||||
 | 
							return []string{"Internet"}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return service.Spec.LoadBalancerSourceRanges
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getTestSecurityGroup(services ...api.Service) network.SecurityGroup {
 | 
					func getTestSecurityGroup(services ...api.Service) network.SecurityGroup {
 | 
				
			||||||
	rules := []network.SecurityRule{}
 | 
						rules := []network.SecurityRule{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, service := range services {
 | 
						for _, service := range services {
 | 
				
			||||||
		for _, port := range service.Spec.Ports {
 | 
							for _, port := range service.Spec.Ports {
 | 
				
			||||||
			ruleName := getRuleName(&service, port)
 | 
								ruleName := getRuleName(&service, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								sources := getServiceSourceRanges(&service)
 | 
				
			||||||
 | 
								for _, src := range sources {
 | 
				
			||||||
				rules = append(rules, network.SecurityRule{
 | 
									rules = append(rules, network.SecurityRule{
 | 
				
			||||||
					Name: to.StringPtr(ruleName),
 | 
										Name: to.StringPtr(ruleName),
 | 
				
			||||||
					Properties: &network.SecurityRulePropertiesFormat{
 | 
										Properties: &network.SecurityRulePropertiesFormat{
 | 
				
			||||||
 | 
											SourceAddressPrefix:  to.StringPtr(src),
 | 
				
			||||||
						DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)),
 | 
											DestinationPortRange: to.StringPtr(fmt.Sprintf("%d", port.Port)),
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sg := network.SecurityGroup{
 | 
						sg := network.SecurityGroup{
 | 
				
			||||||
		Properties: &network.SecurityGroupPropertiesFormat{
 | 
							Properties: &network.SecurityGroupPropertiesFormat{
 | 
				
			||||||
@@ -344,7 +373,7 @@ func validateLoadBalancer(t *testing.T, loadBalancer network.LoadBalancer, servi
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	lenRules := len(*loadBalancer.Properties.LoadBalancingRules)
 | 
						lenRules := len(*loadBalancer.Properties.LoadBalancingRules)
 | 
				
			||||||
	if lenRules != expectedRuleCount {
 | 
						if lenRules != expectedRuleCount {
 | 
				
			||||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules)
 | 
							t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n%v", expectedRuleCount, lenRules, loadBalancer.Properties.LoadBalancingRules)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	lenProbes := len(*loadBalancer.Properties.Probes)
 | 
						lenProbes := len(*loadBalancer.Properties.Probes)
 | 
				
			||||||
	if lenProbes != expectedRuleCount {
 | 
						if lenProbes != expectedRuleCount {
 | 
				
			||||||
@@ -356,11 +385,15 @@ func validateSecurityGroup(t *testing.T, securityGroup network.SecurityGroup, se
 | 
				
			|||||||
	expectedRuleCount := 0
 | 
						expectedRuleCount := 0
 | 
				
			||||||
	for _, svc := range services {
 | 
						for _, svc := range services {
 | 
				
			||||||
		for _, wantedRule := range svc.Spec.Ports {
 | 
							for _, wantedRule := range svc.Spec.Ports {
 | 
				
			||||||
 | 
								sources := getServiceSourceRanges(&svc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, source := range sources {
 | 
				
			||||||
				expectedRuleCount++
 | 
									expectedRuleCount++
 | 
				
			||||||
				wantedRuleName := getRuleName(&svc, wantedRule)
 | 
									wantedRuleName := getRuleName(&svc, wantedRule)
 | 
				
			||||||
				foundRule := false
 | 
									foundRule := false
 | 
				
			||||||
				for _, actualRule := range *securityGroup.Properties.SecurityRules {
 | 
									for _, actualRule := range *securityGroup.Properties.SecurityRules {
 | 
				
			||||||
					if strings.EqualFold(*actualRule.Name, wantedRuleName) &&
 | 
										if strings.EqualFold(*actualRule.Name, wantedRuleName) &&
 | 
				
			||||||
 | 
											*actualRule.Properties.SourceAddressPrefix == source &&
 | 
				
			||||||
						*actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) {
 | 
											*actualRule.Properties.DestinationPortRange == fmt.Sprintf("%d", wantedRule.Port) {
 | 
				
			||||||
						foundRule = true
 | 
											foundRule = true
 | 
				
			||||||
						break
 | 
											break
 | 
				
			||||||
@@ -371,10 +404,11 @@ func validateSecurityGroup(t *testing.T, securityGroup network.SecurityGroup, se
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	lenRules := len(*securityGroup.Properties.SecurityRules)
 | 
						lenRules := len(*securityGroup.Properties.SecurityRules)
 | 
				
			||||||
	if lenRules != expectedRuleCount {
 | 
						if lenRules != expectedRuleCount {
 | 
				
			||||||
		t.Errorf("Expected the loadbalancer to have %d rules. Found %d.", expectedRuleCount, lenRules)
 | 
							t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n", expectedRuleCount, lenRules)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user