Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make GRPC ClientConn and Server interfaces in generated code #675

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions grpc/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package grpc

import (
"context"
grpc "google.golang.org/grpc"
)

type Server interface {
RegisterService(sd *grpc.ServiceDesc, ss interface{})
}

type ClientConn interface {
Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error
NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
}
65 changes: 39 additions & 26 deletions plugin/marshalto/marshalto.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,33 +607,46 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
p.encodeKey(fieldNumber, wireType)
}
case descriptor.FieldDescriptorProto_TYPE_STRING:
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.callVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.Out()
p.P(`}`)
} else if proto3 {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.In()
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.callVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.Out()
p.P(`}`)
} else if !nullable {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.callVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
if !gogoproto.IsCustomType(field) {
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
p.P(`copy(dAtA[i:], `, val, `)`)
p.callVarint(`len(`, val, `)`)
p.encodeKey(fieldNumber, wireType)
p.Out()
p.P(`}`)
} else if proto3 {
p.P(`if len(m.`, fieldname, `) > 0 {`)
p.In()
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.callVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
p.Out()
p.P(`}`)
} else if !nullable {
p.P(`i -= len(m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], m.`, fieldname, `)`)
p.callVarint(`len(m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
} else {
p.P(`i -= len(*m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], *m.`, fieldname, `)`)
p.callVarint(`len(*m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
}
} else {
p.P(`i -= len(*m.`, fieldname, `)`)
p.P(`copy(dAtA[i:], *m.`, fieldname, `)`)
p.callVarint(`len(*m.`, fieldname, `)`)
p.encodeKey(fieldNumber, wireType)
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.forward(val, true, protoSizer)
p.encodeKey(fieldNumber, wireType)
p.Out()
p.P(`}`)
} else {
p.forward(`m.`+fieldname, true, protoSizer)
p.encodeKey(fieldNumber, wireType)
}
}
case descriptor.FieldDescriptorProto_TYPE_GROUP:
panic(fmt.Errorf("marshaler does not support group %v", fieldname))
Expand Down
52 changes: 33 additions & 19 deletions plugin/size/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,26 +334,40 @@ func (p *size) generateField(proto3 bool, file *generator.FileDescriptor, messag
p.P(`n+=`, strconv.Itoa(key+1))
}
case descriptor.FieldDescriptorProto_TYPE_STRING:
if repeated {
p.P(`for _, s := range m.`, fieldname, ` { `)
p.In()
p.P(`l = len(s)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else if proto3 {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.In()
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else if nullable {
p.P(`l=len(*m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
if !gogoproto.IsCustomType(field) {
if repeated {
p.P(`for _, s := range m.`, fieldname, ` { `)
p.In()
p.P(`l = len(s)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else if proto3 {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`if l > 0 {`)
p.In()
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else if nullable {
p.P(`l=len(*m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
}
} else {
p.P(`l=len(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
if repeated {
p.P(`for _, e := range m.`, fieldname, ` { `)
p.In()
p.P(`l=e.`, sizeName, `()`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else {
p.P(`l=m.`, fieldname, `.`, sizeName, `()`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
}
}
case descriptor.FieldDescriptorProto_TYPE_GROUP:
panic(fmt.Errorf("size does not support group %v", fieldname))
Expand Down
70 changes: 59 additions & 11 deletions plugin/unmarshal/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,18 @@ func (p *unmarshal) declareMapField(varName string, nullable bool, customType bo
case descriptor.FieldDescriptorProto_TYPE_BOOL:
p.P(`var `, varName, ` bool`)
case descriptor.FieldDescriptorProto_TYPE_STRING:
cast, _ := p.GoType(nil, field)
cast = strings.Replace(cast, "*", "", 1)
p.P(`var `, varName, ` `, cast)
if customType {
_, ctyp, err := generator.GetCustomType(field)
if err != nil {
panic(err)
}
p.P(`var `, varName, `1 `, ctyp)
p.P(`var `, varName, ` = &`, varName, `1`)
} else {
cast, _ := p.GoType(nil, field)
cast = strings.Replace(cast, "*", "", 1)
p.P(`var `, varName, ` `, cast)
}
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
if gogoproto.IsStdTime(field) {
p.P(varName, ` := new(time.Time)`)
Expand Down Expand Up @@ -652,15 +661,54 @@ func (p *unmarshal) field(file *generator.FileDescriptor, msg *generator.Descrip
p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`)
p.Out()
p.P(`}`)
if oneof {
p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, `(dAtA[iNdEx:postIndex])}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(dAtA[iNdEx:postIndex]))`)
} else if proto3 || !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(dAtA[iNdEx:postIndex])`)
if !gogoproto.IsCustomType(field) {
if oneof {
p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{`, typ, `(dAtA[iNdEx:postIndex])}`)
} else if repeated {
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, `, typ, `(dAtA[iNdEx:postIndex]))`)
} else if proto3 || !nullable {
p.P(`m.`, fieldname, ` = `, typ, `(dAtA[iNdEx:postIndex])`)
} else {
p.P(`s := `, typ, `(dAtA[iNdEx:postIndex])`)
p.P(`m.`, fieldname, ` = &s`)
}
} else {
p.P(`s := `, typ, `(dAtA[iNdEx:postIndex])`)
p.P(`m.`, fieldname, ` = &s`)
_, ctyp, err := generator.GetCustomType(field)
if err != nil {
panic(err)
}
if oneof {
p.P(`var vv `, ctyp)
p.P(`v := &vv`)
p.P(`if err := v.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`)
p.In()
p.P(`return err`)
p.Out()
p.P(`}`)
p.P(`m.`, fieldname, ` = &`, p.OneOfTypeName(msg, field), `{*v}`)
} else if repeated {
p.P(`var v `, ctyp)
p.P(`m.`, fieldname, ` = append(m.`, fieldname, `, v)`)
p.P(`if err := m.`, fieldname, `[len(m.`, fieldname, `)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`)
p.In()
p.P(`return err`)
p.Out()
p.P(`}`)
} else if nullable {
p.P(`var v `, ctyp)
p.P(`m.`, fieldname, ` = &v`)
p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`)
p.In()
p.P(`return err`)
p.Out()
p.P(`}`)
} else {
p.P(`if err := m.`, fieldname, `.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {`)
p.In()
p.P(`return err`)
p.Out()
p.P(`}`)
}
}
p.P(`iNdEx = postIndex`)
case descriptor.FieldDescriptorProto_TYPE_GROUP:
Expand Down
21 changes: 12 additions & 9 deletions protoc-gen-gogo/grpc/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ const generatedCodeVersion = 4
// Paths for packages used by code generated in this file,
// relative to the import_prefix of the generator.Generator.
const (
contextPkgPath = "context"
grpcPkgPath = "google.golang.org/grpc"
codePkgPath = "google.golang.org/grpc/codes"
statusPkgPath = "google.golang.org/grpc/status"
contextPkgPath = "context"
grpcPkgPath = "google.golang.org/grpc"
codePkgPath = "google.golang.org/grpc/codes"
statusPkgPath = "google.golang.org/grpc/status"
gogoGrpcPkgPath = "github.com/gogo/protobuf/grpc"
)

func init() {
Expand All @@ -77,8 +78,9 @@ func (g *grpc) Name() string {
// They may vary from the final path component of the import path
// if the name is used by other packages.
var (
contextPkg string
grpcPkg string
contextPkg string
grpcPkg string
gogoGrpcPkg string
)

// Init initializes the plugin.
Expand Down Expand Up @@ -109,6 +111,7 @@ func (g *grpc) Generate(file *generator.FileDescriptor) {

contextPkg = string(g.gen.AddImport(contextPkgPath))
grpcPkg = string(g.gen.AddImport(grpcPkgPath))
gogoGrpcPkg = string(g.gen.AddImport(gogoGrpcPkgPath))

g.P("// Reference imports to suppress errors if they are not otherwise used.")
g.P("var _ ", contextPkg, ".Context")
Expand Down Expand Up @@ -172,15 +175,15 @@ func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.Servi

// Client structure.
g.P("type ", unexport(servName), "Client struct {")
g.P("cc *", grpcPkg, ".ClientConn")
g.P("cc ", gogoGrpcPkg, ".ClientConn")
g.P("}")
g.P()

// NewClient factory.
if deprecated {
g.P(deprecationComment)
}
g.P("func New", servName, "Client (cc *", grpcPkg, ".ClientConn) ", servName, "Client {")
g.P("func New", servName, "Client (cc ", gogoGrpcPkg, ".ClientConn) ", servName, "Client {")
g.P("return &", unexport(servName), "Client{cc}")
g.P("}")
g.P()
Expand Down Expand Up @@ -227,7 +230,7 @@ func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.Servi
if deprecated {
g.P(deprecationComment)
}
g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ") {")
g.P("func Register", servName, "Server(s ", gogoGrpcPkg, ".Server, srv ", serverType, ") {")
g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
g.P("}")
g.P()
Expand Down
7 changes: 4 additions & 3 deletions protoc-gen-gogo/testdata/deprecated/deprecated.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions protoc-gen-gogo/testdata/grpc/grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions protoc-gen-gogo/testdata/grpc/grpc_empty.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion test/castvalue/combos/unmarshaler/castvalue.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading