Skip to content

Commit

Permalink
feat: add rapier
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkgos committed May 9, 2024
1 parent f52434c commit b54e83a
Show file tree
Hide file tree
Showing 17 changed files with 425 additions and 88 deletions.
6 changes: 3 additions & 3 deletions cmd/ormat/command/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func newProtoCmd() *protoCmd {
}
schemas := &proto.Schema{
Name: "",
Messages: make([]*proto.Message, 0, 128),
Entities: make([]*proto.Message, 0, 128),
}
for _, filename := range root.InputFile {
tmpSchema, err := func() (*proto.Schema, error) {
Expand All @@ -83,7 +83,7 @@ func newProtoCmd() *protoCmd {
slog.Warn("🧐 parse failed !!!", slog.String("file", filename), slog.Any("error", err))
continue
}
schemas.Messages = append(schemas.Messages, tmpSchema.Messages...)
schemas.Entities = append(schemas.Entities, tmpSchema.Entities...)
}
return schemas, nil
}
Expand All @@ -99,7 +99,7 @@ func newProtoCmd() *protoCmd {
if err != nil {
return err
}
for _, msg := range sc.Messages {
for _, msg := range sc.Entities {
codegen := &proto.CodeGen{
Messages: []*proto.Message{msg},
ByName: "ormat",
Expand Down
150 changes: 150 additions & 0 deletions cmd/ormat/command/rapier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package command

import (
"context"
"errors"
"fmt"
"log/slog"
"os"

"ariga.io/atlas/sql/schema"
"github.com/spf13/cobra"
"github.com/things-go/ens/driver"
"github.com/things-go/ens/rapier"
)

type rapierOpt struct {
// sql file
InputFile []string
Schema string
// database url
Url string
Tables []string
Exclude []string

// output directory
OutputDir string

// codegen
PackageName string // required, proto 包名
ModelImportPath string // required, model导入路径
DisableDocComment bool // 禁用doc注释
EnableInt bool // 使能int8,uint8,int16,uint16,int32,uint32输出为int,uint
EnableIntegerInt bool // 使能int32,uint32输出为int,uint
EnableBoolInt bool // 使能bool输出int
}

type rapierCmd struct {
cmd *cobra.Command
rapierOpt
}

func newRapierCmd() *rapierCmd {
root := &rapierCmd{}

rapierSchema := func() (*rapier.Schema, error) {
if root.Url != "" {
d, err := LoadDriver(root.Url)
if err != nil {
return nil, err
}
return d.InspectRapier(context.Background(), &driver.InspectOption{
URL: root.Url,
InspectOptions: schema.InspectOptions{
Mode: schema.InspectTables,
Tables: root.Tables,
Exclude: root.Exclude,
},
})
}
if len(root.InputFile) > 0 {
d, err := driver.LoadDriver(root.Schema)
if err != nil {
return nil, err
}
schemas := &rapier.Schema{
Name: "",
Entities: make([]*rapier.Struct, 0, 128),
}
for _, filename := range root.InputFile {
tmpSchema, err := func() (*rapier.Schema, error) {
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return d.InspectRapier(context.Background(), &driver.InspectOption{
URL: "",
Data: string(content),
InspectOptions: schema.InspectOptions{},
})
}()
if err != nil {
slog.Warn("🧐 parse failed !!!", slog.String("file", filename), slog.Any("error", err))
continue
}
schemas.Entities = append(schemas.Entities, tmpSchema.Entities...)
}
return schemas, nil
}
return nil, errors.New("at least one of [url input] is required")
}

cmd := &cobra.Command{
Use: "rapier",
Short: "Generate rapier from database/file",
Example: "ormat rapier",
RunE: func(*cobra.Command, []string) error {
sc, err := rapierSchema()
if err != nil {
return err
}
for _, msg := range sc.Entities {
codegen := &rapier.CodeGen{
Entities: []*rapier.Struct{msg},
ByName: "ormat",
Version: version,
PackageName: root.PackageName,
ModelImportPath: root.ModelImportPath,
DisableDocComment: root.DisableDocComment,
EnableInt: root.EnableInt,
EnableIntegerInt: root.EnableIntegerInt,
EnableBoolInt: root.EnableBoolInt,
}

data, err := codegen.Gen().FormatSource()
if err != nil {
return err
}
filename := joinFilename(root.OutputDir, msg.TableName, ".rapier.gen.go")
err = WriteFile(filename, data)
if err != nil {
return fmt.Errorf("%v: %w", msg.TableName, err)
}
slog.Info("👉 " + filename)
}
return nil
},
}

cmd.Flags().StringSliceVarP(&root.InputFile, "input", "i", nil, "input file")
cmd.Flags().StringVarP(&root.Schema, "schema", "s", "file+mysql", "parser file driver, [file+mysql,file+tidb](仅input时有效)")

// database url
cmd.Flags().StringVarP(&root.Url, "url", "u", "", "mysql://root:[email protected]:3306/test")
cmd.Flags().StringSliceVarP(&root.Tables, "table", "t", nil, "only out custom table(仅url时有效)")
cmd.Flags().StringSliceVarP(&root.Exclude, "exclude", "e", nil, "exclude table pattern(仅url时有效)")

cmd.Flags().StringVarP(&root.OutputDir, "out", "o", "./repository", "out directory")

cmd.Flags().StringVar(&root.PackageName, "package", "repository", "proto package name")
cmd.Flags().StringVar(&root.ModelImportPath, "modelImportPath", "", "model导入路径")
cmd.Flags().BoolVar(&root.DisableDocComment, "enableInt", false, "禁用文档注释")
cmd.Flags().BoolVar(&root.EnableInt, "disableBool", false, "使能int8,uint8,int16,uint16,int32,uint32输出为int,uint")
cmd.Flags().BoolVar(&root.EnableIntegerInt, "enableIntegerInt", false, "使能int32,uint32输出为int,uint")
cmd.Flags().BoolVar(&root.EnableBoolInt, "enableBoolInt", false, "使能bool输出int")

cmd.MarkFlagsOneRequired("url", "input")

root.cmd = cmd
return root
}
1 change: 1 addition & 0 deletions cmd/ormat/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func NewRootCmd() *RootCmd {
newBuildCmd().cmd,
newGenCmd().cmd,
newProtoCmd().cmd,
newRapierCmd().cmd,
)
root.cmd = cmd
return root
Expand Down
4 changes: 3 additions & 1 deletion cmd/ormat/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"github.com/things-go/ens/cmd/ormat/command"
)

var root = command.NewRootCmd()

func main() {
err := command.NewRootCmd().Execute()
err := root.Execute()
if err != nil {
os.Exit(1)
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/rapier.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func (g *CodeGen) GenRapier(modelImportPath string) *CodeGen {
pkgQualifierPrefix := ""
if p := ens.PkgName(modelImportPath); p != "" {
if p := utils.PkgName(modelImportPath); p != "" {
pkgQualifierPrefix = p + "."
}
if !g.disableDocComment {
Expand Down
2 changes: 2 additions & 0 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"ariga.io/atlas/sql/schema"
"github.com/things-go/ens"
"github.com/things-go/ens/proto"
"github.com/things-go/ens/rapier"
)

const (
Expand All @@ -22,6 +23,7 @@ var drivers sync.Map
type Driver interface {
InspectSchema(context.Context, *InspectOption) (*ens.MixinSchema, error)
InspectProto(context.Context, *InspectOption) (*proto.Schema, error)
InspectRapier(ctx context.Context, arg *InspectOption) (*rapier.Schema, error)
}

func RegisterDriver(name string, d Driver) {
Expand Down
29 changes: 28 additions & 1 deletion driver/mysql/def_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package mysql
import (
"ariga.io/atlas/sql/mysql"
"ariga.io/atlas/sql/schema"
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/things-go/ens"
"github.com/things-go/ens/internal/sqlx"
"github.com/things-go/ens/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/things-go/ens/rapier"
"github.com/things-go/ens/utils"
)

func autoIncrement(attrs []schema.Attr) bool {
Expand Down Expand Up @@ -76,3 +79,27 @@ func IntoProto(tb *schema.Table) *proto.Message {
Fields: fields,
}
}

func IntoRapier(tb *schema.Table) *rapier.Struct {
// * columns
fields := make([]*rapier.StructField, 0, len(tb.Columns))
for _, col := range tb.Columns {
goType := intoGoType(col.Type.Raw)

t := goType.Type.IntoRapierType()

fields = append(fields, &rapier.StructField{
Type: t,
GoName: utils.CamelCase(col.Name),
Nullable: col.Type.Null,
ColumnName: col.Name,
Comment: sqlx.MustComment(col.Attrs),
})
}
return &rapier.Struct{
GoName: utils.CamelCase(tb.Name),
TableName: tb.Name,
Comment: sqlx.MustComment(tb.Attrs),
Fields: fields,
}
}
23 changes: 20 additions & 3 deletions driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/things-go/ens"
"github.com/things-go/ens/driver"
"github.com/things-go/ens/proto"
"github.com/things-go/ens/rapier"

_ "ariga.io/atlas/sql/mysql"
_ "github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -38,13 +39,29 @@ func (self *MySQL) InspectProto(ctx context.Context, arg *driver.InspectOption)
if err != nil {
return nil, err
}
messages := make([]*proto.Message, 0, len(schemaes.Tables))
entities := make([]*proto.Message, 0, len(schemaes.Tables))
for _, tb := range schemaes.Tables {
messages = append(messages, IntoProto(tb))
entities = append(entities, IntoProto(tb))
}
return &proto.Schema{
Name: schemaes.Name,
Messages: messages,
Entities: entities,
}, nil
}

// InspectRapier implements driver.Driver.
func (self *MySQL) InspectRapier(ctx context.Context, arg *driver.InspectOption) (*rapier.Schema, error) {
schemaes, err := self.inspectSchema(ctx, arg)
if err != nil {
return nil, err
}
entities := make([]*rapier.Struct, 0, len(schemaes.Tables))
for _, tb := range schemaes.Tables {
entities = append(entities, IntoRapier(tb))
}
return &rapier.Schema{
Name: schemaes.Name,
Entities: entities,
}, nil
}

Expand Down
17 changes: 15 additions & 2 deletions driver/mysql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/things-go/ens/driver"
"github.com/things-go/ens/internal/sqlx"
"github.com/things-go/ens/proto"
"github.com/things-go/ens/rapier"
"github.com/xwb1989/sqlparser"
)

Expand All @@ -32,15 +33,27 @@ func (self *SQL) InspectSchema(ctx context.Context, arg *driver.InspectOption) (
}, nil
}

// InspectSchema implements driver.Driver.
// InspectProto implements driver.Driver.
func (self *SQL) InspectProto(ctx context.Context, arg *driver.InspectOption) (*proto.Schema, error) {
table, err := self.inspectSchema(ctx, arg)
if err != nil {
return nil, err
}
return &proto.Schema{
Name: "",
Messages: []*proto.Message{IntoProto(table)},
Entities: []*proto.Message{IntoProto(table)},
}, nil
}

// InspectRapier implements driver.Driver.
func (self *SQL) InspectRapier(ctx context.Context, arg *driver.InspectOption) (*rapier.Schema, error) {
table, err := self.inspectSchema(ctx, arg)
if err != nil {
return nil, err
}
return &rapier.Schema{
Name: "",
Entities: []*rapier.Struct{IntoRapier(table)},
}, nil
}

Expand Down
Loading

0 comments on commit b54e83a

Please sign in to comment.