Skip to content

Commit

Permalink
fix: add a simple dal command
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkgos committed May 26, 2024
1 parent 326d052 commit 383e95d
Show file tree
Hide file tree
Showing 15 changed files with 708 additions and 17 deletions.
134 changes: 134 additions & 0 deletions cmd/ormat/command/dal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package command

import (
"bytes"
"cmp"
"errors"
"fmt"
"log/slog"
"path/filepath"
"strings"

"github.com/spf13/cobra"
"github.com/things-go/ens"
"github.com/things-go/ens/utils"
)

type dalOpt struct {
source
OutputDir string
PackageName string // 包名
ModelImportPath string // required, model导入路径
RepoImportPath string // required, repository导入路径
DalImportPath string // required, dal导入路径
CustomTemplate string // 自定义模板
ens.Option
}

type dalCmd struct {
cmd *cobra.Command
dalOpt
}

func newDakCmd() *dalCmd {
root := &dalCmd{}

cmd := &cobra.Command{
Use: "dal",
Short: "Generate dal from database",
Example: "ormat dal",
RunE: func(*cobra.Command, []string) error {
if root.CustomTemplate == "builtin-rapier" && root.RepoImportPath == "" {
return errors.New("使用builtin-rapier时repository导入路径, 不能为空")
}
schemaes, err := getSchema(&root.source)
if err != nil {
return err
}
daltpl, err := GetUsedTemplate(root.CustomTemplate)
if err != nil {
return err
}
packageName := cmp.Or(root.PackageName, utils.GetPkgName(root.OutputDir))
queryImportPath := strings.Join([]string{root.DalImportPath, "query"}, "/")

dal := Dal{
Package: packageName,
Imports: []string{root.ModelImportPath, queryImportPath, root.RepoImportPath},
ModelPrefix: utils.PkgName(root.ModelImportPath) + ".",
QueryPrefix: "query.",
RepoPrefix: utils.PkgName(root.RepoImportPath) + ".",
Entity: nil,
}
dalQuery := Dal{
Package: "query",
Imports: []string{},
ModelPrefix: utils.PkgName(root.ModelImportPath) + ".",
QueryPrefix: "",
RepoPrefix: "",
Entity: nil,
}
for _, entity := range schemaes.Entities {
dalFilename := joinFilename(root.OutputDir, entity.Name, ".go")
// _, err = os.Stat(dalFilename)
// if err == nil || os.IsExist(err) {
// slog.Warn("🐛 " + entity.Name + " already exists")
// continue
// }
dal.Entity = entity
buf := bytes.Buffer{}
err = daltpl.Execute(&buf, dal)
if err != nil {
return err
}

err = WriteFile(dalFilename, buf.Bytes())
if err != nil {
return fmt.Errorf("%v: %v", entity.Name, err)
}

buf.Reset()
dalQuery.Entity = entity
err = dalQueryTpl.Execute(&buf, dalQuery)
if err != nil {
return err
}
dalQueryFilename := joinFilename(filepath.Join(root.OutputDir, "query"), entity.Name, ".go")
err = WriteFile(dalQueryFilename, buf.Bytes())
if err != nil {
return err
}
slog.Info("👉 " + dalFilename)
slog.Info("👉 " + dalQueryFilename)
}

slog.Info("😄 generate success !!!")
return nil
},
}
// input file
cmd.Flags().StringSliceVarP(&root.InputFile, "input", "i", nil, "input file")
cmd.Flags().StringVarP(&root.Schema, "schema", "s", "file+mysql", "parser file driver, [file+mysql,file+tidb](仅input时有效)")
// database url
cmd.Flags().StringVarP(&root.URL, "url", "u", "", "mysql://root:[email protected]:3306/test")
cmd.Flags().StringSliceVarP(&root.Tables, "table", "t", nil, "only out custom table")
cmd.Flags().StringSliceVarP(&root.Exclude, "exclude", "e", nil, "exclude table pattern")

cmd.Flags().StringVarP(&root.OutputDir, "out", "o", "./dal", "out directory")
cmd.Flags().StringVar(&root.PackageName, "package", "", "package name")
cmd.Flags().StringVar(&root.CustomTemplate, "template", "builtin-rapier", "use custom template except [builtin-rapier, builtin-gorm]")
cmd.Flags().StringVar(&root.ModelImportPath, "modelImportPath", "", "model导入路径")
cmd.Flags().StringVar(&root.DalImportPath, "dalImportPath", "", "dal导入路径")
cmd.Flags().StringVar(&root.RepoImportPath, "repoImportPath", "", "repository导入路径")

cmd.Flags().BoolVar(&root.EnableInt, "enableInt", false, "使能int8,uint8,int16,uint16,int32,uint32输出为int,uint")
cmd.Flags().BoolVar(&root.EnableBoolInt, "enableBoolInt", false, "使能bool输出int")
cmd.Flags().BoolVar(&root.DisableNullToPoint, "disableNullToPoint", false, "禁用字段为null时输出指针类型,将输出为sql.Nullxx")
cmd.Flags().StringSliceVar(&root.EscapeName, "escapeName", nil, "escape name list")

cmd.MarkFlagsOneRequired("url", "input")
cmd.MarkFlagRequired("modelImportPath")
cmd.MarkFlagRequired("dalImportPath")
root.cmd = cmd
return root
}
6 changes: 3 additions & 3 deletions cmd/ormat/command/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ func newModelCmd() *modelCmd {
cmd.Flags().StringVar(&root.PackageName, "package", "", "package name")
cmd.Flags().BoolVar(&root.DisableDocComment, "disableDocComment", false, "禁用文档注释")

cmd.Flags().StringToStringVar(&root.Tags, "tags", map[string]string{"json": utils.StyleSmallCamelCase}, "tags标签,类型支持[smallCamelCase,camelCase,snakeCase,kebab]")
cmd.Flags().StringToStringVar(&root.Tags, "tags", map[string]string{"json": utils.StyleSmallCamelCase}, "tags标签,类型支持[smallCamelCase,pascalCase,snakeCase,kebab]")
cmd.Flags().BoolVar(&root.EnableInt, "enableInt", false, "使能int8,uint8,int16,uint16,int32,uint32输出为int,uint")
cmd.Flags().BoolVar(&root.EnableBoolInt, "enableBoolInt", false, "使能bool输出int")
cmd.Flags().BoolVar(&root.DisableNullToPoint, "disableNullToPoint", false, "禁用字段为null时输出指针类型,将输出为sql.Nullxx")
cmd.Flags().BoolVar(&root.DisableCommentTag, "disableCommentTag", false, "禁用注释放入tag标签中")
cmd.Flags().BoolVar(&root.EnableForeignKey, "enableForeignKey", false, "使用外键")
cmd.Flags().StringSliceVar(&root.EscapeName, "escapeName", nil, "exclude table pattern")
cmd.Flags().StringSliceVar(&root.EscapeName, "escapeName", nil, "escape name list")

cmd.Flags().BoolVar(&root.Merge, "merge", false, "merge in a file or not")
cmd.Flags().StringVar(&root.MergeFilename, "filename", "", "merge filename")

cmd.MarkPersistentFlagRequired("url") // nolint
cmd.MarkFlagsOneRequired("url", "input")

root.cmd = cmd
return root
Expand Down
4 changes: 2 additions & 2 deletions cmd/ormat/command/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type protoOpt struct {
// codegen
PackageName string // required, proto 包名
Options map[string]string // required, proto option
Style string // 字段代码风格, snakeCase, smallCamelCase, camelCase
Style string // 字段代码风格, snakeCase, smallCamelCase, pascalCase
DisableDocComment bool // 禁用doc注释
DisableBool bool // 禁用bool,使用int32
DisableTimestamp bool // 禁用google.protobuf.Timestamp,使用int64
Expand Down Expand Up @@ -82,7 +82,7 @@ func newProtoCmd() *protoCmd {

cmd.Flags().StringVar(&root.PackageName, "package", "", "proto package name")
cmd.Flags().StringToStringVar(&root.Options, "options", nil, "proto options key/value")
cmd.Flags().StringVar(&root.Style, "style", "", "字段代码风格, [snakeCase,smallCamelCase,camelCase]")
cmd.Flags().StringVar(&root.Style, "style", "", "字段代码风格, [snakeCase,smallCamelCase,pascalCase]")
cmd.Flags().BoolVar(&root.DisableDocComment, "disableDocComment", false, "禁用文档注释")
cmd.Flags().BoolVar(&root.DisableBool, "disableBool", false, "禁用bool,使用int32")
cmd.Flags().BoolVar(&root.DisableTimestamp, "disableTimestamp", false, "禁用google.protobuf.Timestamp,使用int64")
Expand Down
1 change: 1 addition & 0 deletions cmd/ormat/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func NewRootCmd() *RootCmd {
newModelCmd().cmd,
newProtoCmd().cmd,
newRapierCmd().cmd,
newDakCmd().cmd,
)
root.cmd = cmd
return root
Expand Down
77 changes: 77 additions & 0 deletions cmd/ormat/command/template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package command

import (
"embed"
"errors"
"text/template"

"github.com/things-go/ens"
"github.com/things-go/ens/utils"
)

//go:embed template/*.tpl
var Static embed.FS

var TemplateFuncs = template.FuncMap{
"add": func(a, b int) int { return a + b },
"snakecase": func(s string) string { return utils.SnakeCase(s) },
"kebabcase": func(s string) string { return utils.Kebab(s) },
"pascalcase": func(s string) string { return utils.PascalCase(s) },
"smallcamelcase": func(s string) string { return utils.SmallCamelCase(s) },
}
var (
tpl = template.Must(template.New("components").
Funcs(TemplateFuncs).
ParseFS(Static, "template/*.tpl"))
dalRapierTpl = tpl.Lookup("dal_rapier.tpl")
dalGormTpl = tpl.Lookup("dal_gorm.tpl")
dalQueryTpl = tpl.Lookup("dal_query.tpl")
)

type Dal struct {
Package string
Imports []string
ModelPrefix string
QueryPrefix string
RepoPrefix string
Entity *ens.EntityDescriptor
}

type DalQuery struct {
PackageName string
Imports []string
ModelQualifier string
Entity ens.EntityDescriptor
}

func GetUsedTemplate(t string) (*template.Template, error) {
switch t {
case "builtin-gorm":
return dalGormTpl, nil
case "builtin-rapier":
return dalRapierTpl, nil
default:
t, err := ParseTemplateFromFile(t)
if err != nil {
return nil, err
}
return t, nil
}
}

func ParseTemplateFromFile(filename string) (*template.Template, error) {
if filename == "" {
return nil, errors.New("required template filename")
}
tt, err := template.New("custom").
Funcs(TemplateFuncs).
ParseFiles(filename)
if err != nil {
return nil, err
}
ts := tt.Templates()
if len(ts) == 0 {
return nil, errors.New("not found any template")
}
return ts[0], nil
}
Loading

0 comments on commit 383e95d

Please sign in to comment.