Skip to content

Commit

Permalink
feat: add proto implement
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkgos committed May 8, 2024
1 parent ccadb12 commit 397f861
Show file tree
Hide file tree
Showing 28 changed files with 861 additions and 414 deletions.
10 changes: 3 additions & 7 deletions cmd/ormat/command/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package command

import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"ariga.io/atlas/sql/schema"
"github.com/spf13/cobra"
Expand All @@ -30,15 +28,13 @@ func newBuildCmd() *buildCmd {

getSchema := func() ens.Schemaer {
innerParseFromFile := func(filename string) (ens.Schemaer, error) {
var d driver.Driver

content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
d, ok := driver.LoadDriver(root.Schema)
if !ok {
return nil, fmt.Errorf("unsupported schema, only support [%v]", strings.Join(driver.DriverNames(), ", "))
d, err := driver.LoadDriver(root.Schema)
if err != nil {
return nil, err
}
return d.InspectSchema(context.Background(), &driver.InspectOption{
URL: "",
Expand Down
7 changes: 3 additions & 4 deletions cmd/ormat/command/helper.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package command

import (
"fmt"
"net/url"
"os"
"path"
Expand All @@ -16,9 +15,9 @@ func LoadDriver(URL string) (driver.Driver, error) {
if err != nil {
return nil, err
}
d, ok := driver.LoadDriver(u.Scheme)
if !ok {
return nil, fmt.Errorf("unsupported schema, only support [%v]", strings.Join(driver.DriverNames(), ", "))
d, err := driver.LoadDriver(u.Scheme)
if err != nil {
return nil, err
}
return d, nil
}
Expand Down
3 changes: 0 additions & 3 deletions cmd/ormat/command/helper_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,4 @@ func InitFlagSetForConfig(s *pflag.FlagSet, cc *Config) {
s.StringVar(&cc.Package, "package", "", "package name")
s.StringToStringVar(&cc.Options, "options", nil, "options key value")
s.BoolVarP(&cc.DisableDocComment, "disableDocComment", "d", false, "禁用文档注释")

s.BoolVar(&cc.EnableGogo, "enableGogo", false, "使能用 gogo proto (仅输出 proto 有效)")
s.BoolVar(&cc.EnableSea, "enableSea", false, "使能用 seaql (仅输出 proto 有效)")
}
143 changes: 143 additions & 0 deletions cmd/ormat/command/proto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package command

import (
"context"
"errors"
"fmt"
"log/slog"
"os"

"ariga.io/atlas/sql/schema"
"github.com/spf13/cobra"
"github.com/things-go/ens/driver"
"github.com/things-go/ens/proto"
)

type protoOpt struct {
// sql file
InputFile []string
Schema string
// database url
Url string
Tables []string
Exclude []string

// output directory
OutputDir string

// codegen
PackageName string // required, proto 包名
Options map[string]string // required, proto option
DisableDocComment bool // 禁用doc注释
DisableBool bool // 禁用bool,使用int32
DisableTimestamp bool // 禁用google.protobuf.Timestamp,使用int64
}

type protoCmd struct {
cmd *cobra.Command
protoOpt
}

func newProtoCmd() *protoCmd {
root := &protoCmd{}

protoSchema := func() (*proto.Schema, error) {
if root.Url != "" {
d, err := LoadDriver(root.Url)
if err != nil {
return nil, err
}
return d.InspectProto(context.Background(), &driver.InspectOption{
URL: root.Url,
InspectOptions: schema.InspectOptions{
Mode: schema.InspectTables,
Tables: root.Tables,
Exclude: root.Exclude,
},
})
}
if len(root.InputFile) > 0 {
d, err := driver.LoadDriver(root.Schema)
if err != nil {
return nil, err
}
schemas := &proto.Schema{
Name: "",
Messages: make([]*proto.Message, 0, 128),
}
for _, filename := range root.InputFile {
tmpSchema, err := func() (*proto.Schema, error) {
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return d.InspectProto(context.Background(), &driver.InspectOption{
URL: "",
Data: string(content),
InspectOptions: schema.InspectOptions{},
})
}()
if err != nil {
slog.Warn("🧐 parse failed !!!", slog.String("file", filename), slog.Any("error", err))
continue
}
schemas.Messages = append(schemas.Messages, tmpSchema.Messages...)
}
return schemas, nil
}
return nil, errors.New("at least one of [url input] is required")
}

cmd := &cobra.Command{
Use: "proto",
Short: "Generate proto from database",
Example: "ormat proto",
RunE: func(*cobra.Command, []string) error {
sc, err := protoSchema()
if err != nil {
return err
}
for _, msg := range sc.Messages {
codegen := &proto.CodeGen{
Messages: []*proto.Message{msg},
ByName: "ormat",
Version: version,
PackageName: root.PackageName,
Options: root.Options,
DisableDocComment: root.DisableDocComment,
DisableBool: root.DisableBool,
DisableTimestamp: root.DisableTimestamp,
}
data := codegen.Gen().Bytes()
filename := joinFilename(root.OutputDir, msg.TableName, ".proto")
err := WriteFile(filename, data)
if err != nil {
return fmt.Errorf("%v: %w", msg.TableName, err)
}
slog.Info("👉 " + filename)
}
return nil
},
}

cmd.Flags().StringSliceVarP(&root.InputFile, "input", "i", nil, "input file")
cmd.Flags().StringVarP(&root.Schema, "schema", "s", "file+mysql", "parser file driver, [file+mysql,file+tidb](仅input时有效)")

// database url
cmd.Flags().StringVarP(&root.Url, "url", "u", "", "mysql://root:[email protected]:3306/test")
cmd.Flags().StringSliceVarP(&root.Tables, "table", "t", nil, "only out custom table(仅url时有效)")
cmd.Flags().StringSliceVarP(&root.Exclude, "exclude", "e", nil, "exclude table pattern(仅url时有效)")

cmd.Flags().StringVarP(&root.OutputDir, "out", "o", "./mapper", "out directory")

cmd.Flags().StringVar(&root.PackageName, "package", "mapper", "proto package name")
cmd.Flags().StringToStringVar(&root.Options, "options", nil, "proto options key/value")
cmd.Flags().BoolVar(&root.DisableDocComment, "disableDocComment", false, "禁用文档注释")
cmd.Flags().BoolVar(&root.DisableBool, "disableBool", false, "禁用bool,使用int32")
cmd.Flags().BoolVar(&root.DisableTimestamp, "disableTimestamp", false, "禁用google.protobuf.Timestamp,使用int64")

cmd.MarkFlagsOneRequired("url", "input")

root.cmd = cmd
return root
}
2 changes: 1 addition & 1 deletion cmd/ormat/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewRootCmd() *RootCmd {
newSqlCmd().cmd,
newBuildCmd().cmd,
newGenCmd().cmd,
newUpgradeCmd().cmd,
newProtoCmd().cmd,
)
root.cmd = cmd
return root
Expand Down
Loading

0 comments on commit 397f861

Please sign in to comment.