From 46a724a6552d14a08c3e99b91f2a1dc8088d2bad Mon Sep 17 00:00:00 2001 From: Johannes Rauh Date: Wed, 4 Dec 2024 15:03:17 +0100 Subject: [PATCH] Build SET clause from columns --- database/query_builder.go | 32 +++++++++++++++++++---------- database/update.go | 42 ++++++++++++++++++++++++++++----------- 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/database/query_builder.go b/database/query_builder.go index 3982f09..9a5d9c7 100644 --- a/database/query_builder.go +++ b/database/query_builder.go @@ -184,41 +184,51 @@ func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string { } func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + table := stmt.Table() if table == "" { table = TableName(stmt.Entity()) } - set := stmt.Set() - if set == "" { - return "", errors.New("set cannot be empty") - } + where := stmt.Where() if where == "" { - return "", errors.New("cannot use UpdateStatement() without where statement - use UpdateAllStatement() instead") + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead") + } + + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) } return fmt.Sprintf( - `UPDATE "%s" SET %s%s`, + `UPDATE "%s" SET %s WHERE %s`, table, - set, + strings.Join(set, ", "), where, ), nil } func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + table := stmt.Table() if table == "" { table = TableName(stmt.Entity()) } - set := stmt.Set() - if set == "" { - return "", errors.New("set cannot be empty") - } + where := stmt.Where() if where != "" { return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead") } + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + return fmt.Sprintf( `UPDATE "%s" SET %s`, table, diff --git a/database/update.go b/database/update.go index d206a58..e1ee3a4 100644 --- a/database/update.go +++ b/database/update.go @@ -8,8 +8,12 @@ type UpdateStatement interface { // Overrides the table name provided by the entity. SetTable(table string) UpdateStatement - // SetSet sets the set clause for the UPDATE statement. - SetSet(set string) UpdateStatement + // SetColumns sets the columns to be updated. + SetColumns(columns ...string) UpdateStatement + + // SetExcludedColumns sets the columns to be excluded from the UPDATE statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) UpdateStatement // SetWhere sets the where clause for the UPDATE statement. SetWhere(where string) UpdateStatement @@ -20,8 +24,11 @@ type UpdateStatement interface { // Table returns the table name for the UPDATE statement. Table() string - // Set returns the set clause for the UPDATE statement. - Set() string + // Columns returns the columns to be updated. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the UPDATE statement. + ExcludedColumns() []string // Where returns the where clause for the UPDATE statement. Where() string @@ -39,10 +46,11 @@ func NewUpdateStatement(entity Entity) UpdateStatement { // updateStatement is the default implementation of the UpdateStatement interface. type updateStatement struct { - entity Entity - table string - set string - where string + entity Entity + table string + columns []string + excludedColumns []string + where string } func (u *updateStatement) SetTable(table string) UpdateStatement { @@ -51,8 +59,14 @@ func (u *updateStatement) SetTable(table string) UpdateStatement { return u } -func (u *updateStatement) SetSet(set string) UpdateStatement { - u.set = set +func (u *updateStatement) SetColumns(columns ...string) UpdateStatement { + u.columns = columns + + return u +} + +func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement { + u.excludedColumns = columns return u } @@ -71,8 +85,12 @@ func (u *updateStatement) Table() string { return u.table } -func (u *updateStatement) Set() string { - return u.set +func (u *updateStatement) Columns() []string { + return u.columns +} + +func (u *updateStatement) ExcludedColumns() []string { + return u.excludedColumns } func (u *updateStatement) Where() string {