Skip to content

Commit

Permalink
Build SET clause from columns
Browse files Browse the repository at this point in the history
  • Loading branch information
jrauh01 committed Dec 4, 2024
1 parent 148319b commit 46a724a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 23 deletions.
32 changes: 21 additions & 11 deletions database/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,41 +184,51 @@ func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string {
}

func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) {
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())

table := stmt.Table()
if table == "" {
table = TableName(stmt.Entity())
}
set := stmt.Set()
if set == "" {
return "", errors.New("set cannot be empty")
}

where := stmt.Where()
if where == "" {
return "", errors.New("cannot use UpdateStatement() without where statement - use UpdateAllStatement() instead")
return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead")
}

var set []string

for _, col := range columns {
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
}

return fmt.Sprintf(
`UPDATE "%s" SET %s%s`,
`UPDATE "%s" SET %s WHERE %s`,
table,
set,
strings.Join(set, ", "),
where,
), nil
}

func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) {
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())

table := stmt.Table()
if table == "" {
table = TableName(stmt.Entity())
}
set := stmt.Set()
if set == "" {
return "", errors.New("set cannot be empty")
}

where := stmt.Where()
if where != "" {
return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead")
}

var set []string

for _, col := range columns {
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
}

return fmt.Sprintf(
`UPDATE "%s" SET %s`,
table,
Expand Down
42 changes: 30 additions & 12 deletions database/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ type UpdateStatement interface {
// Overrides the table name provided by the entity.
SetTable(table string) UpdateStatement

// SetSet sets the set clause for the UPDATE statement.
SetSet(set string) UpdateStatement
// SetColumns sets the columns to be updated.
SetColumns(columns ...string) UpdateStatement

// SetExcludedColumns sets the columns to be excluded from the UPDATE statement.
// Excludes also columns set by SetColumns.
SetExcludedColumns(columns ...string) UpdateStatement

// SetWhere sets the where clause for the UPDATE statement.
SetWhere(where string) UpdateStatement
Expand All @@ -20,8 +24,11 @@ type UpdateStatement interface {
// Table returns the table name for the UPDATE statement.
Table() string

// Set returns the set clause for the UPDATE statement.
Set() string
// Columns returns the columns to be updated.
Columns() []string

// ExcludedColumns returns the columns to be excluded from the UPDATE statement.
ExcludedColumns() []string

// Where returns the where clause for the UPDATE statement.
Where() string
Expand All @@ -39,10 +46,11 @@ func NewUpdateStatement(entity Entity) UpdateStatement {

// updateStatement is the default implementation of the UpdateStatement interface.
type updateStatement struct {
entity Entity
table string
set string
where string
entity Entity
table string
columns []string
excludedColumns []string
where string
}

func (u *updateStatement) SetTable(table string) UpdateStatement {
Expand All @@ -51,8 +59,14 @@ func (u *updateStatement) SetTable(table string) UpdateStatement {
return u
}

func (u *updateStatement) SetSet(set string) UpdateStatement {
u.set = set
func (u *updateStatement) SetColumns(columns ...string) UpdateStatement {
u.columns = columns

return u
}

func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement {
u.excludedColumns = columns

return u
}
Expand All @@ -71,8 +85,12 @@ func (u *updateStatement) Table() string {
return u.table
}

func (u *updateStatement) Set() string {
return u.set
func (u *updateStatement) Columns() []string {
return u.columns
}

func (u *updateStatement) ExcludedColumns() []string {
return u.excludedColumns
}

func (u *updateStatement) Where() string {
Expand Down

0 comments on commit 46a724a

Please sign in to comment.