Merge pull request #122498 from Gekko0114/close
Allow framework plugins to be closed
This commit is contained in:
		@@ -652,6 +652,9 @@ type Framework interface {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// SetPodNominator sets the PodNominator
 | 
						// SetPodNominator sets the PodNominator
 | 
				
			||||||
	SetPodNominator(nominator PodNominator)
 | 
						SetPodNominator(nominator PodNominator)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Close calls Close method of each plugin.
 | 
				
			||||||
 | 
						Close() error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Handle provides data and some tools that plugins can use. It is
 | 
					// Handle provides data and some tools that plugins can use. It is
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,9 @@ package runtime
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -66,6 +68,9 @@ type frameworkImpl struct {
 | 
				
			|||||||
	postBindPlugins      []framework.PostBindPlugin
 | 
						postBindPlugins      []framework.PostBindPlugin
 | 
				
			||||||
	permitPlugins        []framework.PermitPlugin
 | 
						permitPlugins        []framework.PermitPlugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// pluginsMap contains all plugins, by name.
 | 
				
			||||||
 | 
						pluginsMap map[string]framework.Plugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clientSet       clientset.Interface
 | 
						clientSet       clientset.Interface
 | 
				
			||||||
	kubeConfig      *restclient.Config
 | 
						kubeConfig      *restclient.Config
 | 
				
			||||||
	eventRecorder   events.EventRecorder
 | 
						eventRecorder   events.EventRecorder
 | 
				
			||||||
@@ -297,7 +302,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler
 | 
				
			|||||||
		PluginConfig:             make([]config.PluginConfig, 0, len(pg)),
 | 
							PluginConfig:             make([]config.PluginConfig, 0, len(pg)),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pluginsMap := make(map[string]framework.Plugin)
 | 
						f.pluginsMap = make(map[string]framework.Plugin)
 | 
				
			||||||
	for name, factory := range r {
 | 
						for name, factory := range r {
 | 
				
			||||||
		// initialize only needed plugins.
 | 
							// initialize only needed plugins.
 | 
				
			||||||
		if !pg.Has(name) {
 | 
							if !pg.Has(name) {
 | 
				
			||||||
@@ -315,21 +320,21 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, fmt.Errorf("initializing plugin %q: %w", name, err)
 | 
								return nil, fmt.Errorf("initializing plugin %q: %w", name, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		pluginsMap[name] = p
 | 
							f.pluginsMap[name] = p
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		f.fillEnqueueExtensions(p)
 | 
							f.fillEnqueueExtensions(p)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// initialize plugins per individual extension points
 | 
						// initialize plugins per individual extension points
 | 
				
			||||||
	for _, e := range f.getExtensionPoints(profile.Plugins) {
 | 
						for _, e := range f.getExtensionPoints(profile.Plugins) {
 | 
				
			||||||
		if err := updatePluginList(e.slicePtr, *e.plugins, pluginsMap); err != nil {
 | 
							if err := updatePluginList(e.slicePtr, *e.plugins, f.pluginsMap); err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// initialize multiPoint plugins to their expanded extension points
 | 
						// initialize multiPoint plugins to their expanded extension points
 | 
				
			||||||
	if len(profile.Plugins.MultiPoint.Enabled) > 0 {
 | 
						if len(profile.Plugins.MultiPoint.Enabled) > 0 {
 | 
				
			||||||
		if err := f.expandMultiPointPlugins(logger, profile, pluginsMap); err != nil {
 | 
							if err := f.expandMultiPointPlugins(logger, profile); err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -341,7 +346,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler
 | 
				
			|||||||
		return nil, fmt.Errorf("at least one bind plugin is needed for profile with scheduler name %q", profile.SchedulerName)
 | 
							return nil, fmt.Errorf("at least one bind plugin is needed for profile with scheduler name %q", profile.SchedulerName)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := getScoreWeights(f, pluginsMap, append(profile.Plugins.Score.Enabled, profile.Plugins.MultiPoint.Enabled...)); err != nil {
 | 
						if err := getScoreWeights(f, append(profile.Plugins.Score.Enabled, profile.Plugins.MultiPoint.Enabled...)); err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -405,14 +410,29 @@ func (f *frameworkImpl) SetPodNominator(n framework.PodNominator) {
 | 
				
			|||||||
	f.PodNominator = n
 | 
						f.PodNominator = n
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close closes each plugin, when they implement io.Closer interface.
 | 
				
			||||||
 | 
					func (f *frameworkImpl) Close() error {
 | 
				
			||||||
 | 
						var errs []error
 | 
				
			||||||
 | 
						for name, plugin := range f.pluginsMap {
 | 
				
			||||||
 | 
							if closer, ok := plugin.(io.Closer); ok {
 | 
				
			||||||
 | 
								err := closer.Close()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									errs = append(errs, fmt.Errorf("%s failed to close: %w", name, err))
 | 
				
			||||||
 | 
									// We try to close all plugins even if we got errors from some.
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return errors.Join(errs...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// getScoreWeights makes sure that, between MultiPoint-Score plugin weights and individual Score
 | 
					// getScoreWeights makes sure that, between MultiPoint-Score plugin weights and individual Score
 | 
				
			||||||
// plugin weights there is not an overflow of MaxTotalScore.
 | 
					// plugin weights there is not an overflow of MaxTotalScore.
 | 
				
			||||||
func getScoreWeights(f *frameworkImpl, pluginsMap map[string]framework.Plugin, plugins []config.Plugin) error {
 | 
					func getScoreWeights(f *frameworkImpl, plugins []config.Plugin) error {
 | 
				
			||||||
	var totalPriority int64
 | 
						var totalPriority int64
 | 
				
			||||||
	scorePlugins := reflect.ValueOf(&f.scorePlugins).Elem()
 | 
						scorePlugins := reflect.ValueOf(&f.scorePlugins).Elem()
 | 
				
			||||||
	pluginType := scorePlugins.Type().Elem()
 | 
						pluginType := scorePlugins.Type().Elem()
 | 
				
			||||||
	for _, e := range plugins {
 | 
						for _, e := range plugins {
 | 
				
			||||||
		pg := pluginsMap[e.Name]
 | 
							pg := f.pluginsMap[e.Name]
 | 
				
			||||||
		if !reflect.TypeOf(pg).Implements(pluginType) {
 | 
							if !reflect.TypeOf(pg).Implements(pluginType) {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -469,7 +489,7 @@ func (os *orderedSet) delete(s string) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *config.KubeSchedulerProfile, pluginsMap map[string]framework.Plugin) error {
 | 
					func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *config.KubeSchedulerProfile) error {
 | 
				
			||||||
	// initialize MultiPoint plugins
 | 
						// initialize MultiPoint plugins
 | 
				
			||||||
	for _, e := range f.getExtensionPoints(profile.Plugins) {
 | 
						for _, e := range f.getExtensionPoints(profile.Plugins) {
 | 
				
			||||||
		plugins := reflect.ValueOf(e.slicePtr).Elem()
 | 
							plugins := reflect.ValueOf(e.slicePtr).Elem()
 | 
				
			||||||
@@ -495,7 +515,7 @@ func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *con
 | 
				
			|||||||
		multiPointEnabled := newOrderedSet()
 | 
							multiPointEnabled := newOrderedSet()
 | 
				
			||||||
		overridePlugins := newOrderedSet()
 | 
							overridePlugins := newOrderedSet()
 | 
				
			||||||
		for _, ep := range profile.Plugins.MultiPoint.Enabled {
 | 
							for _, ep := range profile.Plugins.MultiPoint.Enabled {
 | 
				
			||||||
			pg, ok := pluginsMap[ep.Name]
 | 
								pg, ok := f.pluginsMap[ep.Name]
 | 
				
			||||||
			if !ok {
 | 
								if !ok {
 | 
				
			||||||
				return fmt.Errorf("%s %q does not exist", pluginType.Name(), ep.Name)
 | 
									return fmt.Errorf("%s %q does not exist", pluginType.Name(), ep.Name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -539,17 +559,17 @@ func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *con
 | 
				
			|||||||
		// part 1
 | 
							// part 1
 | 
				
			||||||
		for _, name := range slice.CopyStrings(enabledSet.list) {
 | 
							for _, name := range slice.CopyStrings(enabledSet.list) {
 | 
				
			||||||
			if overridePlugins.has(name) {
 | 
								if overridePlugins.has(name) {
 | 
				
			||||||
				newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name]))
 | 
									newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name]))
 | 
				
			||||||
				enabledSet.delete(name)
 | 
									enabledSet.delete(name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// part 2
 | 
							// part 2
 | 
				
			||||||
		for _, name := range multiPointEnabled.list {
 | 
							for _, name := range multiPointEnabled.list {
 | 
				
			||||||
			newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name]))
 | 
								newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name]))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// part 3
 | 
							// part 3
 | 
				
			||||||
		for _, name := range enabledSet.list {
 | 
							for _, name := range enabledSet.list {
 | 
				
			||||||
			newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name]))
 | 
								newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name]))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		plugins.Set(newPlugins)
 | 
							plugins.Set(newPlugins)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -54,6 +54,7 @@ const (
 | 
				
			|||||||
	testPlugin                        = "test-plugin"
 | 
						testPlugin                        = "test-plugin"
 | 
				
			||||||
	permitPlugin                      = "permit-plugin"
 | 
						permitPlugin                      = "permit-plugin"
 | 
				
			||||||
	bindPlugin                        = "bind-plugin"
 | 
						bindPlugin                        = "bind-plugin"
 | 
				
			||||||
 | 
						testCloseErrorPlugin              = "test-close-error-plugin"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testProfileName              = "test-profile"
 | 
						testProfileName              = "test-profile"
 | 
				
			||||||
	testPercentageOfNodesToScore = 35
 | 
						testPercentageOfNodesToScore = 35
 | 
				
			||||||
@@ -238,6 +239,25 @@ func (pl *TestPlugin) Bind(ctx context.Context, state *framework.CycleState, p *
 | 
				
			|||||||
	return framework.NewStatus(framework.Code(pl.inj.BindStatus), injectReason)
 | 
						return framework.NewStatus(framework.Code(pl.inj.BindStatus), injectReason)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestCloseErrorPlugin(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
 | 
				
			||||||
 | 
						return &TestCloseErrorPlugin{name: testCloseErrorPlugin}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TestCloseErrorPlugin implements for Close test.
 | 
				
			||||||
 | 
					type TestCloseErrorPlugin struct {
 | 
				
			||||||
 | 
						name string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pl *TestCloseErrorPlugin) Name() string {
 | 
				
			||||||
 | 
						return pl.name
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var errClose = errors.New("close err")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pl *TestCloseErrorPlugin) Close() error {
 | 
				
			||||||
 | 
						return errClose
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TestPreFilterPlugin only implements PreFilterPlugin interface.
 | 
					// TestPreFilterPlugin only implements PreFilterPlugin interface.
 | 
				
			||||||
type TestPreFilterPlugin struct {
 | 
					type TestPreFilterPlugin struct {
 | 
				
			||||||
	PreFilterCalled int
 | 
						PreFilterCalled int
 | 
				
			||||||
@@ -379,6 +399,7 @@ var registry = func() Registry {
 | 
				
			|||||||
	r.Register(testPlugin, newTestPlugin)
 | 
						r.Register(testPlugin, newTestPlugin)
 | 
				
			||||||
	r.Register(queueSortPlugin, newQueueSortPlugin)
 | 
						r.Register(queueSortPlugin, newQueueSortPlugin)
 | 
				
			||||||
	r.Register(bindPlugin, newBindPlugin)
 | 
						r.Register(bindPlugin, newBindPlugin)
 | 
				
			||||||
 | 
						r.Register(testCloseErrorPlugin, newTestCloseErrorPlugin)
 | 
				
			||||||
	return r
 | 
						return r
 | 
				
			||||||
}()
 | 
					}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3261,6 +3282,53 @@ func TestListPlugins(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestClose(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name    string
 | 
				
			||||||
 | 
							plugins *config.Plugins
 | 
				
			||||||
 | 
							wantErr error
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "close doesn't return error",
 | 
				
			||||||
 | 
								plugins: &config.Plugins{
 | 
				
			||||||
 | 
									MultiPoint: config.PluginSet{
 | 
				
			||||||
 | 
										Enabled: []config.Plugin{
 | 
				
			||||||
 | 
											{Name: testPlugin, Weight: 5},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "close returns error",
 | 
				
			||||||
 | 
								plugins: &config.Plugins{
 | 
				
			||||||
 | 
									MultiPoint: config.PluginSet{
 | 
				
			||||||
 | 
										Enabled: []config.Plugin{
 | 
				
			||||||
 | 
											{Name: testPlugin, Weight: 5},
 | 
				
			||||||
 | 
											{Name: testCloseErrorPlugin},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantErr: errClose,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tc := range tests {
 | 
				
			||||||
 | 
							t.Run(tc.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								_, ctx := ktesting.NewTestContext(t)
 | 
				
			||||||
 | 
								ctx, cancel := context.WithCancel(ctx)
 | 
				
			||||||
 | 
								defer cancel()
 | 
				
			||||||
 | 
								fw, err := NewFramework(ctx, registry, &config.KubeSchedulerProfile{Plugins: tc.plugins})
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatalf("Unexpected error during calling NewFramework, got %v", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								err = fw.Close()
 | 
				
			||||||
 | 
								if !errors.Is(err, tc.wantErr) {
 | 
				
			||||||
 | 
									t.Fatalf("Unexpected error from Close(), got: %v, want: %v", err, tc.wantErr)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func buildScoreConfigDefaultWeights(ps ...string) *config.Plugins {
 | 
					func buildScoreConfigDefaultWeights(ps ...string) *config.Plugins {
 | 
				
			||||||
	return buildScoreConfigWithWeights(defaultWeights, ps...)
 | 
						return buildScoreConfigWithWeights(defaultWeights, ps...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -70,6 +70,18 @@ func (m Map) HandlesSchedulerName(name string) bool {
 | 
				
			|||||||
	return ok
 | 
						return ok
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close closes all frameworks registered in this map.
 | 
				
			||||||
 | 
					func (m Map) Close() error {
 | 
				
			||||||
 | 
						var errs []error
 | 
				
			||||||
 | 
						for name, f := range m {
 | 
				
			||||||
 | 
							err := f.Close()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								errs = append(errs, fmt.Errorf("framework %s failed to close: %w", name, err))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return errors.Join(errs...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewRecorderFactory returns a RecorderFactory for the broadcaster.
 | 
					// NewRecorderFactory returns a RecorderFactory for the broadcaster.
 | 
				
			||||||
func NewRecorderFactory(b events.EventBroadcaster) RecorderFactory {
 | 
					func NewRecorderFactory(b events.EventBroadcaster) RecorderFactory {
 | 
				
			||||||
	return func(name string) events.EventRecorder {
 | 
						return func(name string) events.EventRecorder {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -416,6 +416,12 @@ func (sched *Scheduler) Run(ctx context.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	<-ctx.Done()
 | 
						<-ctx.Done()
 | 
				
			||||||
	sched.SchedulingQueue.Close()
 | 
						sched.SchedulingQueue.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the plugins satisfy the io.Closer interface, they are closed.
 | 
				
			||||||
 | 
						err := sched.Profiles.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logger.Error(err, "Failed to close plugins")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewInformerFactory creates a SharedInformerFactory and initializes a scheduler specific
 | 
					// NewInformerFactory creates a SharedInformerFactory and initializes a scheduler specific
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user