Deep-copy functions autogeneration.

This commit is contained in:
Wojciech Tyczynski
2015-05-13 15:52:53 +02:00
parent 11b058cc7d
commit b2280db724
14 changed files with 1049 additions and 13 deletions

View File

@@ -47,9 +47,6 @@ func generateConversions(t *testing.T, version string) bytes.Buffer {
if err := g.WriteConversionFunctions(functionsWriter); err != nil {
t.Fatalf("couldn't generate conversion functions: %v", err)
}
if err := functionsWriter.Flush(); err != nil {
t.Fatalf("error while flushing writer")
}
if err := g.RegisterConversionFunctions(functionsWriter); err != nil {
t.Fatalf("couldn't generate conversion function names: %v", err)
}
@@ -81,7 +78,7 @@ func readLinesUntil(t *testing.T, reader *bufio.Reader, stop string, buffer *byt
return nil
}
func bufferExistingConversions(t *testing.T, fileName string) bytes.Buffer {
func bufferExistingGeneratedCode(t *testing.T, fileName string) bytes.Buffer {
file, err := os.Open(fileName)
if err != nil {
t.Fatalf("couldn't open file %s", fileName)
@@ -139,7 +136,7 @@ func TestNoManualChangesToGenerateConversions(t *testing.T) {
for _, version := range versions {
fileName := fmt.Sprintf("../../pkg/api/%s/conversion_generated.go", version)
existingFunctions := bufferExistingConversions(t, fileName)
existingFunctions := bufferExistingGeneratedCode(t, fileName)
generatedFunctions := generateConversions(t, version)
functionsTxt := fmt.Sprintf("%s.functions.txt", version)

View File

@@ -0,0 +1,87 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package runtime_test
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"os"
"testing"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
_ "github.com/GoogleCloudPlatform/kubernetes/pkg/api/v1"
_ "github.com/GoogleCloudPlatform/kubernetes/pkg/api/v1beta3"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/golang/glog"
)
func generateDeepCopies(t *testing.T, version string) bytes.Buffer {
g := runtime.NewDeepCopyGenerator(api.Scheme.Raw())
g.OverwritePackage(version, "")
testedVersion := version
if version == "api" {
testedVersion = api.Scheme.Raw().InternalVersion
}
for _, knownType := range api.Scheme.KnownTypes(testedVersion) {
if err := g.AddType(knownType); err != nil {
glog.Errorf("error while generating deep-copy functions for %v: %v", knownType, err)
}
}
var functions bytes.Buffer
functionsWriter := bufio.NewWriter(&functions)
if err := g.WriteImports(functionsWriter, version); err != nil {
t.Fatalf("couldn't generate deep-copy function imports: %v", err)
}
if err := g.WriteDeepCopyFunctions(functionsWriter); err != nil {
t.Fatalf("couldn't generate deep-copy functions: %v", err)
}
if err := g.RegisterDeepCopyFunctions(functionsWriter, version); err != nil {
t.Fatalf("couldn't generate deep-copy function names: %v", err)
}
if err := functionsWriter.Flush(); err != nil {
t.Fatalf("error while flushing writer")
}
return functions
}
func TestNoManualChangesToGenerateDeepCopies(t *testing.T) {
versions := []string{"api", "v1beta3", "v1"}
for _, version := range versions {
fileName := ""
if version == "api" {
fileName = "../../pkg/api/deep_copy_generated.go"
} else {
fileName = fmt.Sprintf("../../pkg/api/%s/deep_copy_generated.go", version)
}
existingFunctions := bufferExistingGeneratedCode(t, fileName)
generatedFunctions := generateDeepCopies(t, version)
functionsTxt := fmt.Sprintf("%s.deep_copy.txt", version)
ioutil.WriteFile(functionsTxt, generatedFunctions.Bytes(), os.FileMode(0644))
if ok := compareBuffers(t, functionsTxt, existingFunctions, generatedFunctions); ok {
os.Remove(functionsTxt)
}
}
}

View File

@@ -0,0 +1,481 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package runtime
import (
"fmt"
"io"
"reflect"
"sort"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/conversion"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
)
type DeepCopyGenerator interface {
// Adds a type to a generator.
// If the type is non-struct, it will return an error, otherwise deep-copy
// functions for this type and all nested types will be generated.
AddType(inType reflect.Type) error
// Writes all imports that are necessary for deep-copy function and
// their registration.
WriteImports(w io.Writer, pkg string) error
// Writes deel-copy functions for all types added via AddType() method
// and their nested types.
WriteDeepCopyFunctions(w io.Writer) error
// Writes an init() function that registers all the generated deep-copy
// functions.
RegisterDeepCopyFunctions(w io.Writer, pkg string) error
// When generating code, all references to "pkg" package name will be
// replaced with "overwrite". It is used mainly to replace references
// to name of the package in which the code will be created with empty
// string.
OverwritePackage(pkg, overwrite string)
}
func NewDeepCopyGenerator(scheme *conversion.Scheme) DeepCopyGenerator {
return &deepCopyGenerator{
scheme: scheme,
copyables: make(map[reflect.Type]bool),
imports: util.StringSet{},
pkgOverwrites: make(map[string]string),
}
}
type deepCopyGenerator struct {
scheme *conversion.Scheme
copyables map[reflect.Type]bool
imports util.StringSet
pkgOverwrites map[string]string
}
func (g *deepCopyGenerator) addAllRecursiveTypes(inType reflect.Type) error {
if _, found := g.copyables[inType]; found {
return nil
}
switch inType.Kind() {
case reflect.Map:
if err := g.addAllRecursiveTypes(inType.Key()); err != nil {
return err
}
if err := g.addAllRecursiveTypes(inType.Elem()); err != nil {
return err
}
case reflect.Slice, reflect.Ptr:
if err := g.addAllRecursiveTypes(inType.Elem()); err != nil {
return err
}
case reflect.Interface:
g.imports.Insert(inType.PkgPath())
return nil
case reflect.Struct:
g.imports.Insert(inType.PkgPath())
if !strings.HasPrefix(inType.PkgPath(), "github.com/GoogleCloudPlatform/kubernetes") {
return nil
}
for i := 0; i < inType.NumField(); i++ {
inField := inType.Field(i)
if err := g.addAllRecursiveTypes(inField.Type); err != nil {
return err
}
}
g.copyables[inType] = true
default:
// Simple types should be copied automatically.
}
return nil
}
func (g *deepCopyGenerator) AddType(inType reflect.Type) error {
if inType.Kind() != reflect.Struct {
return fmt.Errorf("non-struct copies are not supported")
}
return g.addAllRecursiveTypes(inType)
}
func (g *deepCopyGenerator) WriteImports(w io.Writer, pkg string) error {
var packages []string
packages = append(packages, "github.com/GoogleCloudPlatform/kubernetes/pkg/api")
packages = append(packages, "github.com/GoogleCloudPlatform/kubernetes/pkg/conversion")
for key := range g.imports {
packages = append(packages, key)
}
sort.Strings(packages)
buffer := newBuffer()
indent := 0
buffer.addLine("import (\n", indent)
for _, importPkg := range packages {
if strings.HasSuffix(importPkg, pkg) {
continue
}
buffer.addLine(fmt.Sprintf("\"%s\"\n", importPkg), indent+1)
}
buffer.addLine(")\n", indent)
buffer.addLine("\n", indent)
if err := buffer.flushLines(w); err != nil {
return err
}
return nil
}
type byPkgAndName []reflect.Type
func (s byPkgAndName) Len() int {
return len(s)
}
func (s byPkgAndName) Less(i, j int) bool {
fullNameI := s[i].PkgPath() + "/" + s[i].Name()
fullNameJ := s[j].PkgPath() + "/" + s[j].Name()
return fullNameI < fullNameJ
}
func (s byPkgAndName) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (g *deepCopyGenerator) typeName(inType reflect.Type) string {
switch inType.Kind() {
case reflect.Map:
return fmt.Sprintf("map[%s]%s", g.typeName(inType.Key()), g.typeName(inType.Elem()))
case reflect.Slice:
return fmt.Sprintf("[]%s", g.typeName(inType.Elem()))
case reflect.Ptr:
return fmt.Sprintf("*%s", g.typeName(inType.Elem()))
default:
typeWithPkg := fmt.Sprintf("%s", inType)
slices := strings.Split(typeWithPkg, ".")
if len(slices) == 1 {
// Default package.
return slices[0]
}
if len(slices) == 2 {
pkg := slices[0]
if val, found := g.pkgOverwrites[pkg]; found {
pkg = val
}
if pkg != "" {
pkg = pkg + "."
}
return pkg + slices[1]
}
panic("Incorrect type name: " + typeWithPkg)
}
}
func (g *deepCopyGenerator) deepCopyFunctionName(inType reflect.Type) string {
funcNameFormat := "deepCopy_%s_%s"
inPkg := packageForName(inType)
funcName := fmt.Sprintf(funcNameFormat, inPkg, inType.Name())
return funcName
}
func (g *deepCopyGenerator) writeHeader(b *buffer, inType reflect.Type, indent int) {
format := "func %s(in %s, out *%s, c *conversion.Cloner) error {\n"
stmt := fmt.Sprintf(format, g.deepCopyFunctionName(inType), g.typeName(inType), g.typeName(inType))
b.addLine(stmt, indent)
}
func (g *deepCopyGenerator) writeFooter(b *buffer, indent int) {
b.addLine("return nil\n", indent+1)
b.addLine("}\n", indent)
}
func (g *deepCopyGenerator) WriteDeepCopyFunctions(w io.Writer) error {
var keys []reflect.Type
for key := range g.copyables {
keys = append(keys, key)
}
sort.Sort(byPkgAndName(keys))
buffer := newBuffer()
indent := 0
for _, inType := range keys {
if err := g.writeDeepCopyForType(buffer, inType, indent); err != nil {
return err
}
buffer.addLine("\n", 0)
}
if err := buffer.flushLines(w); err != nil {
return err
}
return nil
}
func (g *deepCopyGenerator) writeDeepCopyForMap(b *buffer, inField reflect.StructField, indent int) error {
ifFormat := "if in.%s != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent)
newFormat := "out.%s = make(%s)\n"
newStmt := fmt.Sprintf(newFormat, inField.Name, g.typeName(inField.Type))
b.addLine(newStmt, indent+1)
forFormat := "for key, val := range in.%s {\n"
forStmt := fmt.Sprintf(forFormat, inField.Name)
b.addLine(forStmt, indent+1)
switch inField.Type.Key().Kind() {
case reflect.Map, reflect.Ptr, reflect.Slice, reflect.Interface, reflect.Struct:
return fmt.Errorf("not supported")
default:
switch inField.Type.Elem().Kind() {
case reflect.Map, reflect.Ptr, reflect.Slice, reflect.Interface, reflect.Struct:
if _, found := g.copyables[inField.Type.Elem()]; found {
newFormat := "newVal := new(%s)\n"
newStmt := fmt.Sprintf(newFormat, g.typeName(inField.Type.Elem()))
b.addLine(newStmt, indent+2)
assignFormat := "if err := %s(val, newVal, c); err != nil {\n"
funcName := g.deepCopyFunctionName(inField.Type.Elem())
assignStmt := fmt.Sprintf(assignFormat, funcName)
b.addLine(assignStmt, indent+2)
b.addLine("return err\n", indent+3)
b.addLine("}\n", indent+2)
setFormat := "out.%s[key] = *newVal\n"
setStmt := fmt.Sprintf(setFormat, inField.Name)
b.addLine(setStmt, indent+2)
} else {
ifStmt := "if newVal, err := c.DeepCopy(val); err != nil {\n"
b.addLine(ifStmt, indent+2)
b.addLine("return err\n", indent+3)
b.addLine("} else {\n", indent+2)
assignFormat := "out.%s[key] = newVal.(%s)\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, g.typeName(inField.Type.Elem()))
b.addLine(assignStmt, indent+3)
b.addLine("}\n", indent+2)
}
default:
assignFormat := "out.%s[key] = val\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name)
b.addLine(assignStmt, indent+2)
}
}
b.addLine("}\n", indent+1)
b.addLine("} else {\n", indent)
elseFormat := "out.%s = nil\n"
elseStmt := fmt.Sprintf(elseFormat, inField.Name)
b.addLine(elseStmt, indent+1)
b.addLine("}\n", indent)
return nil
}
func (g *deepCopyGenerator) writeDeepCopyForPtr(b *buffer, inField reflect.StructField, indent int) error {
ifFormat := "if in.%s != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent)
switch inField.Type.Elem().Kind() {
case reflect.Map, reflect.Ptr, reflect.Slice, reflect.Interface, reflect.Struct:
if _, found := g.copyables[inField.Type.Elem()]; found {
newFormat := "out.%s = new(%s)\n"
newStmt := fmt.Sprintf(newFormat, inField.Name, g.typeName(inField.Type.Elem()))
b.addLine(newStmt, indent+1)
assignFormat := "if err := %s(*in.%s, out.%s, c); err != nil {\n"
funcName := g.deepCopyFunctionName(inField.Type.Elem())
assignStmt := fmt.Sprintf(assignFormat, funcName, inField.Name, inField.Name)
b.addLine(assignStmt, indent+1)
b.addLine("return err\n", indent+2)
b.addLine("}\n", indent+1)
} else {
ifFormat := "if newVal, err := c.DeepCopy(in.%s); err != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent+1)
b.addLine("return err\n", indent+2)
b.addLine("} else {\n", indent+1)
assignFormat := "out.%s = newVal.(%s)\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, g.typeName(inField.Type))
b.addLine(assignStmt, indent+2)
b.addLine("}\n", indent+1)
}
default:
newFormat := "out.%s = new(%s)\n"
newStmt := fmt.Sprintf(newFormat, inField.Name, g.typeName(inField.Type.Elem()))
b.addLine(newStmt, indent+1)
assignFormat := "*out.%s = *in.%s\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, inField.Name)
b.addLine(assignStmt, indent+1)
}
b.addLine("} else {\n", indent)
elseFormat := "out.%s = nil\n"
elseStmt := fmt.Sprintf(elseFormat, inField.Name)
b.addLine(elseStmt, indent+1)
b.addLine("}\n", indent)
return nil
}
func (g *deepCopyGenerator) writeDeepCopyForSlice(b *buffer, inField reflect.StructField, indent int) error {
ifFormat := "if in.%s != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent)
newFormat := "out.%s = make(%s, len(in.%s))\n"
newStmt := fmt.Sprintf(newFormat, inField.Name, g.typeName(inField.Type), inField.Name)
b.addLine(newStmt, indent+1)
forFormat := "for i := range in.%s {\n"
forStmt := fmt.Sprintf(forFormat, inField.Name)
b.addLine(forStmt, indent+1)
switch inField.Type.Elem().Kind() {
case reflect.Map, reflect.Ptr, reflect.Slice, reflect.Interface, reflect.Struct:
if _, found := g.copyables[inField.Type.Elem()]; found {
assignFormat := "if err := %s(in.%s[i], &out.%s[i], c); err != nil {\n"
funcName := g.deepCopyFunctionName(inField.Type.Elem())
assignStmt := fmt.Sprintf(assignFormat, funcName, inField.Name, inField.Name)
b.addLine(assignStmt, indent+2)
b.addLine("return err\n", indent+3)
b.addLine("}\n", indent+2)
} else {
ifFormat := "if newVal, err := c.DeepCopy(in.%s[i]); err != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent+2)
b.addLine("return err\n", indent+3)
b.addLine("} else {\n", indent+2)
assignFormat := "out.%s[i] = newVal.(%s)\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, g.typeName(inField.Type.Elem()))
b.addLine(assignStmt, indent+3)
b.addLine("}\n", indent+2)
}
default:
assignFormat := "out.%s[i] = in.%s[i]\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, inField.Name)
b.addLine(assignStmt, indent+2)
}
b.addLine("}\n", indent+1)
b.addLine("} else {\n", indent)
elseFormat := "out.%s = nil\n"
elseStmt := fmt.Sprintf(elseFormat, inField.Name)
b.addLine(elseStmt, indent+1)
b.addLine("}\n", indent)
return nil
}
func (g *deepCopyGenerator) writeDeepCopyForStruct(b *buffer, inType reflect.Type, indent int) error {
for i := 0; i < inType.NumField(); i++ {
inField := inType.Field(i)
switch inField.Type.Kind() {
case reflect.Map:
if err := g.writeDeepCopyForMap(b, inField, indent); err != nil {
return err
}
case reflect.Ptr:
if err := g.writeDeepCopyForPtr(b, inField, indent); err != nil {
return err
}
case reflect.Slice:
if err := g.writeDeepCopyForSlice(b, inField, indent); err != nil {
return err
}
case reflect.Interface:
ifFormat := "if newVal, err := c.DeepCopy(in.%s); err != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent)
b.addLine("return err\n", indent+1)
b.addLine("} else {\n", indent)
copyFormat := "out.%s = newVal.(%s)\n"
copyStmt := fmt.Sprintf(copyFormat, inField.Name, g.typeName(inField.Type))
b.addLine(copyStmt, indent+1)
b.addLine("}\n", indent)
case reflect.Struct:
if _, found := g.copyables[inField.Type]; found {
ifFormat := "if err := %s(in.%s, &out.%s, c); err != nil {\n"
funcName := g.deepCopyFunctionName(inField.Type)
ifStmt := fmt.Sprintf(ifFormat, funcName, inField.Name, inField.Name)
b.addLine(ifStmt, indent)
b.addLine("return err\n", indent+1)
b.addLine("}\n", indent)
} else {
ifFormat := "if newVal, err := c.DeepCopy(in.%s); err != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
b.addLine(ifStmt, indent)
b.addLine("return err\n", indent+1)
b.addLine("} else {\n", indent)
assignFormat := "out.%s = newVal.(%s)\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, g.typeName(inField.Type))
b.addLine(assignStmt, indent+1)
b.addLine("}\n", indent)
}
default:
// This should handle all simple types.
assignFormat := "out.%s = in.%s\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, inField.Name)
b.addLine(assignStmt, indent)
}
}
return nil
}
func (g *deepCopyGenerator) writeDeepCopyForType(b *buffer, inType reflect.Type, indent int) error {
g.writeHeader(b, inType, indent)
switch inType.Kind() {
case reflect.Struct:
if err := g.writeDeepCopyForStruct(b, inType, indent+1); err != nil {
return err
}
default:
return fmt.Errorf("type not supported: %v", inType)
}
g.writeFooter(b, indent)
return nil
}
func (g *deepCopyGenerator) writeRegisterHeader(b *buffer, pkg string, indent int) {
b.addLine("func init() {\n", indent)
registerFormat := "err := %sScheme.AddGeneratedDeepCopyFuncs(\n"
if pkg == "api" {
b.addLine(fmt.Sprintf(registerFormat, ""), indent+1)
} else {
b.addLine(fmt.Sprintf(registerFormat, "api."), indent+1)
}
}
func (g *deepCopyGenerator) writeRegisterFooter(b *buffer, indent int) {
b.addLine(")\n", indent+1)
b.addLine("if err != nil {\n", indent+1)
b.addLine("// if one of the deep copy functions is malformed, detect it immediately.\n", indent+2)
b.addLine("panic(err)\n", indent+2)
b.addLine("}\n", indent+1)
b.addLine("}\n", indent)
b.addLine("\n", indent)
}
func (g *deepCopyGenerator) RegisterDeepCopyFunctions(w io.Writer, pkg string) error {
var keys []reflect.Type
for key := range g.copyables {
keys = append(keys, key)
}
sort.Sort(byPkgAndName(keys))
buffer := newBuffer()
indent := 0
g.writeRegisterHeader(buffer, pkg, indent)
for _, inType := range keys {
funcStmt := fmt.Sprintf("%s,\n", g.deepCopyFunctionName(inType))
buffer.addLine(funcStmt, indent+2)
}
g.writeRegisterFooter(buffer, indent)
if err := buffer.flushLines(w); err != nil {
return err
}
return nil
}
func (g *deepCopyGenerator) OverwritePackage(pkg, overwrite string) {
g.pkgOverwrites[pkg] = overwrite
}

View File

@@ -321,6 +321,19 @@ func (s *Scheme) AddGeneratedConversionFuncs(conversionFuncs ...interface{}) err
return s.raw.AddGeneratedConversionFuncs(conversionFuncs...)
}
// AddDeepCopyFuncs adds a function to the list of deep-copy functions.
// For the expected format of deep-copy function, see the comment for
// Copier.RegisterDeepCopyFunction.
func (s *Scheme) AddDeepCopyFuncs(deepCopyFuncs ...interface{}) error {
return s.raw.AddDeepCopyFuncs(deepCopyFuncs...)
}
// Similar to AddDeepCopyFuncs, but registers deep-copy functions that were
// automatically generated.
func (s *Scheme) AddGeneratedDeepCopyFuncs(deepCopyFuncs ...interface{}) error {
return s.raw.AddGeneratedDeepCopyFuncs(deepCopyFuncs...)
}
// AddFieldLabelConversionFunc adds a conversion function to convert field selectors
// of the given kind from the given version to internal version representation.
func (s *Scheme) AddFieldLabelConversionFunc(version, kind string, conversionFunc FieldLabelConversionFunc) error {
@@ -347,6 +360,11 @@ func (s *Scheme) AddDefaultingFuncs(defaultingFuncs ...interface{}) error {
return s.raw.AddDefaultingFuncs(defaultingFuncs...)
}
// Performs a deep copy of the given object.
func (s *Scheme) DeepCopy(src interface{}) (interface{}, error) {
return s.raw.DeepCopy(src)
}
// Convert will attempt to convert in into out. Both must be pointers.
// For easy testing of conversion functions. Returns an error if the conversion isn't
// possible.