diff --git a/pkg/backend/access.go b/pkg/backend/access.go new file mode 100644 index 000000000..719ef06ab --- /dev/null +++ b/pkg/backend/access.go @@ -0,0 +1,95 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/charmbracelet/soft-serve/pkg/sshutils" + "golang.org/x/crypto/ssh" +) + +// AccessLevel returns the access level of a user for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { + user, _ := d.User(ctx, username) + return d.AccessLevelForUser(ctx, repo, user) +} + +// AccessLevelByPublicKey returns the access level of a user's public key for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { + for _, k := range d.cfg.AdminKeys() { + if sshutils.KeysEqual(pk, k) { + return access.AdminAccess + } + } + + user, _ := d.UserByPublicKey(ctx, pk) + if user != nil { + return d.AccessLevel(ctx, repo, user.Username()) + } + + return d.AccessLevel(ctx, repo, "") +} + +// AccessLevelForUser returns the access level of a user for a repository. +// TODO: user repository ownership +func (d *Backend) AccessLevelForUser(ctx context.Context, repo string, user proto.User) access.AccessLevel { + var username string + anon := d.AnonAccess(ctx) + if user != nil { + username = user.Username() + } + + // If the user is an admin, they have admin access. + if user != nil && user.IsAdmin() { + return access.AdminAccess + } + + // If the repository exists, check if the user is a collaborator. + r := proto.RepositoryFromContext(ctx) + if r == nil { + r, _ = d.Repository(ctx, repo) + } + + if r != nil { + if user != nil { + // If the user is the owner, they have admin access. + if r.UserID() == user.ID() { + return access.AdminAccess + } + } + + // If the user is a collaborator, they have return their access level. + collabAccess, isCollab, _ := d.IsCollaborator(ctx, repo, username) + if isCollab { + if anon > collabAccess { + return anon + } + return collabAccess + } + + // If the repository is private, the user has no access. + if r.IsPrivate() { + return access.NoAccess + } + + // Otherwise, the user has read-only access. + return access.ReadOnlyAccess + } + + if user != nil { + // If the repository doesn't exist, the user has read/write access. + if anon > access.ReadWriteAccess { + return anon + } + + return access.ReadWriteAccess + } + + // If the user doesn't exist, give them the anonymous access level. + return anon +} diff --git a/pkg/backend/collab.go b/pkg/backend/collab.go index c9635ae37..46ad438f0 100644 --- a/pkg/backend/collab.go +++ b/pkg/backend/collab.go @@ -18,7 +18,7 @@ import ( // It implements backend.Backend. func (d *Backend) AddCollaborator(ctx context.Context, repo string, username string, level access.AccessLevel) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -50,19 +50,33 @@ func (d *Backend) AddCollaborator(ctx context.Context, repo string, username str func (d *Backend) Collaborators(ctx context.Context, repo string) ([]string, error) { repo = utils.SanitizeRepo(repo) var users []models.User + var usernames []string if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error users, err = d.store.ListCollabsByRepoAsUsers(ctx, tx, repo) - return err + if err != nil { + return err + } + + ids := make([]int64, len(users)) + for i, u := range users { + ids[i] = u.ID + } + + handles, err := d.store.ListHandlesForIDs(ctx, tx, ids) + if err != nil { + return err + } + + for _, h := range handles { + usernames = append(usernames, h.Handle) + } + + return nil }); err != nil { return nil, db.WrapError(err) } - var usernames []string - for _, u := range users { - usernames = append(usernames, u.Username) - } - return usernames, nil } diff --git a/pkg/backend/org.go b/pkg/backend/org.go new file mode 100644 index 000000000..2fd7ee10d --- /dev/null +++ b/pkg/backend/org.go @@ -0,0 +1,66 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/proto" +) + +// CreateOrg creates a new organization. +func (d *Backend) CreateOrg(ctx context.Context, owner proto.User, name, email string) (proto.Org, error) { + o, err := d.store.CreateOrg(ctx, d.db, owner.ID(), name, email) + if err != nil { + return org{}, err + } + return org{o}, err +} + +// ListOrgs lists all organizations for a user. +func (d *Backend) ListOrgs(ctx context.Context, user proto.User) ([]proto.Org, error) { + orgs, err := d.store.ListOrgs(ctx, d.db, user.ID()) + var r []proto.Org + for _, o := range orgs { + r = append(r, org{o}) + } + return r, err +} + +// FindOrganization finds an organization belonging to a user by name. +func (d *Backend) FindOrganization(ctx context.Context, user proto.User, name string) (proto.Org, error) { + o, err := d.store.FindOrgByHandle(ctx, d.db, user.ID(), name) + return org{o}, err +} + +// DeleteOrganization deletes an organization for a user. +func (d *Backend) DeleteOrganization(ctx context.Context, user proto.User, name string) error { + o, err := d.store.FindOrgByHandle(ctx, d.db, user.ID(), name) + if err != nil { + return err + } + return d.store.DeleteOrgByID(ctx, d.db, user.ID(), o.ID) +} + +type org struct { + o models.Organization +} + +var _ proto.Org = org{} + +// DisplayName implements proto.Org. +func (o org) DisplayName() string { + if o.o.Name.Valid { + return o.o.Name.String + } + return "" +} + +// ID implements proto.Org. +func (o org) ID() int64 { + return o.o.ID +} + +// Name implements proto.Org. +func (o org) Name() string { + return o.o.Handle.Handle +} diff --git a/pkg/backend/team.go b/pkg/backend/team.go new file mode 100644 index 000000000..188e2ace1 --- /dev/null +++ b/pkg/backend/team.go @@ -0,0 +1,72 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/proto" +) + +// CreateTeam creates a new team for an organization. +func (d *Backend) CreateTeam(ctx context.Context, org proto.Org, owner proto.User, name string) (proto.Team, error) { + m, err := d.store.CreateTeam(ctx, d.db, owner.ID(), org.ID(), name) + if err != nil { + return team{}, err + } + return team{m}, err +} + +// ListTeams lists all teams for a user. +func (d *Backend) ListTeams(ctx context.Context, user proto.User) ([]proto.Team, error) { + teams, err := d.store.ListTeams(ctx, d.db, user.ID()) + var r []proto.Team + for _, m := range teams { + r = append(r, team{m}) + } + return r, err +} + +// GetTeam gets a team by organization id and team name. +func (d *Backend) GetTeam(ctx context.Context, user proto.User, org proto.Org, name string) (proto.Team, error) { + m, err := d.store.FindTeamByOrgName(ctx, d.db, user.ID(), org.ID(), name) + if err != nil { + return team{}, err + } + return team{m}, err +} + +// FindTeam finds a team by name. +func (d *Backend) FindTeam(ctx context.Context, user proto.User, name string) ([]proto.Team, error) { + m, err := d.store.FindTeamByName(ctx, d.db, user.ID(), name) + var r []proto.Team + for _, m := range m { + r = append(r, team{m}) + } + return r, err +} + +// DeleteTeam deletes a team. +func (d *Backend) DeleteTeam(ctx context.Context, _ proto.User, team proto.Team) error { + return d.store.DeleteTeamByID(ctx, d.db, team.ID()) +} + +type team struct { + t models.Team +} + +var _ proto.Team = team{} + +// ID implements proto.Team. +func (t team) ID() int64 { + return t.t.ID +} + +// Name implements proto.Team. +func (t team) Name() string { + return t.t.Name +} + +// Org implements proto.Team. +func (t team) Org() int64 { + return t.t.OrganizationID +} diff --git a/pkg/backend/user.go b/pkg/backend/user.go index 75423048d..1ef1dd405 100644 --- a/pkg/backend/user.go +++ b/pkg/backend/user.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/charmbracelet/soft-serve/pkg/access" "github.com/charmbracelet/soft-serve/pkg/db" "github.com/charmbracelet/soft-serve/pkg/db/models" "github.com/charmbracelet/soft-serve/pkg/proto" @@ -15,102 +14,19 @@ import ( "golang.org/x/crypto/ssh" ) -// AccessLevel returns the access level of a user for a repository. -// -// It implements backend.Backend. -func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { - user, _ := d.User(ctx, username) - return d.AccessLevelForUser(ctx, repo, user) -} - -// AccessLevelByPublicKey returns the access level of a user's public key for a repository. -// -// It implements backend.Backend. -func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { - for _, k := range d.cfg.AdminKeys() { - if sshutils.KeysEqual(pk, k) { - return access.AdminAccess - } - } - - user, _ := d.UserByPublicKey(ctx, pk) - if user != nil { - return d.AccessLevel(ctx, repo, user.Username()) - } - - return d.AccessLevel(ctx, repo, "") -} - -// AccessLevelForUser returns the access level of a user for a repository. -// TODO: user repository ownership -func (d *Backend) AccessLevelForUser(ctx context.Context, repo string, user proto.User) access.AccessLevel { - var username string - anon := d.AnonAccess(ctx) - if user != nil { - username = user.Username() - } - - // If the user is an admin, they have admin access. - if user != nil && user.IsAdmin() { - return access.AdminAccess - } - - // If the repository exists, check if the user is a collaborator. - r := proto.RepositoryFromContext(ctx) - if r == nil { - r, _ = d.Repository(ctx, repo) - } - - if r != nil { - if user != nil { - // If the user is the owner, they have admin access. - if r.UserID() == user.ID() { - return access.AdminAccess - } - } - - // If the user is a collaborator, they have return their access level. - collabAccess, isCollab, _ := d.IsCollaborator(ctx, repo, username) - if isCollab { - if anon > collabAccess { - return anon - } - return collabAccess - } - - // If the repository is private, the user has no access. - if r.IsPrivate() { - return access.NoAccess - } - - // Otherwise, the user has read-only access. - return access.ReadOnlyAccess - } - - if user != nil { - // If the repository doesn't exist, the user has read/write access. - if anon > access.ReadWriteAccess { - return anon - } - - return access.ReadWriteAccess - } - - // If the user doesn't exist, give them the anonymous access level. - return anon -} - // User finds a user by username. // // It implements backend.Backend. func (d *Backend) User(ctx context.Context, username string) (proto.User, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } var m models.User var pks []ssh.PublicKey + var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByUsername(ctx, tx, username) @@ -119,6 +35,20 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -132,6 +62,8 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) return &user{ user: m, publicKeys: pks, + handle: hl, + emails: ems, }, nil } @@ -139,6 +71,8 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.GetUserByID(ctx, tx, id) @@ -147,6 +81,20 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -160,6 +108,8 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { return &user{ user: m, publicKeys: pks, + handle: hl, + emails: ems, }, nil } @@ -169,6 +119,8 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByPublicKey(ctx, tx, pk) @@ -177,6 +129,20 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -190,6 +156,8 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. return &user{ user: m, publicKeys: pks, + handle: hl, + emails: ems, }, nil } @@ -198,6 +166,8 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle + var ems []proto.UserEmail token = HashToken(token) if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { @@ -216,6 +186,20 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -229,6 +213,8 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us return &user{ user: m, publicKeys: pks, + handle: hl, + emails: ems, }, nil } @@ -243,8 +229,18 @@ func (d *Backend) Users(ctx context.Context) ([]string, error) { return err } - for _, m := range ms { - users = append(users, m.Username) + ids := make([]int64, len(ms)) + for i, m := range ms { + ids[i] = m.ID + } + + handles, err := d.store.ListHandlesForIDs(ctx, tx, ids) + if err != nil { + return err + } + + for _, h := range handles { + users = append(users, h.Handle) } return nil @@ -260,7 +256,7 @@ func (d *Backend) Users(ctx context.Context) ([]string, error) { // It implements backend.Backend. func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -275,13 +271,8 @@ func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.Publ // // It implements backend.Backend. func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.UserOptions) (proto.User, error) { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return nil, err - } - if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { - return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys) + return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys, opts.Emails) }); err != nil { return nil, db.WrapError(err) } @@ -294,7 +285,7 @@ func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.Us // It implements backend.Backend. func (d *Backend) DeleteUser(ctx context.Context, username string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -321,7 +312,7 @@ func (d *Backend) RemovePublicKey(ctx context.Context, username string, pk ssh.P // ListPublicKeys lists the public keys of a user. func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } @@ -342,7 +333,7 @@ func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.Pu // It implements backend.Backend. func (d *Backend) SetUsername(ctx context.Context, username string, newUsername string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -358,7 +349,7 @@ func (d *Backend) SetUsername(ctx context.Context, username string, newUsername // It implements backend.Backend. func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -372,7 +363,7 @@ func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) err // SetPassword sets the password of a user. func (d *Backend) SetPassword(ctx context.Context, username string, rawPassword string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -388,9 +379,60 @@ func (d *Backend) SetPassword(ctx context.Context, username string, rawPassword ) } +// AddUserEmail adds an email to a user. +func (d *Backend) AddUserEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.AddUserEmail(ctx, tx, user.ID(), email, false) + }), + ) +} + +// ListUserEmails lists the emails of a user. +func (d *Backend) ListUserEmails(ctx context.Context, user proto.User) ([]proto.UserEmail, error) { + var ems []proto.UserEmail + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + emails, err := d.store.ListUserEmails(ctx, tx, user.ID()) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + return nil + }); err != nil { + return nil, db.WrapError(err) + } + + return ems, nil +} + +// RemoveUserEmail deletes an email for a user. +// The deleted email must not be the primary email. +func (d *Backend) RemoveUserEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.RemoveUserEmail(ctx, tx, user.ID(), email) + }), + ) +} + +// SetUserPrimaryEmail sets the primary email of a user. +func (d *Backend) SetUserPrimaryEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetUserPrimaryEmail(ctx, tx, user.ID(), email) + }), + ) +} + type user struct { user models.User publicKeys []ssh.PublicKey + handle models.Handle + emails []proto.UserEmail } var _ proto.User = (*user)(nil) @@ -407,7 +449,7 @@ func (u *user) PublicKeys() []ssh.PublicKey { // Username implements proto.User func (u *user) Username() string { - return u.user.Username + return u.handle.Handle } // ID implements proto.User. @@ -423,3 +465,29 @@ func (u *user) Password() string { return "" } + +// Emails implements proto.User. +func (u *user) Emails() []proto.UserEmail { + return u.emails +} + +type userEmail struct { + email models.UserEmail +} + +var _ proto.UserEmail = (*userEmail)(nil) + +// Email implements proto.UserEmail. +func (e *userEmail) Email() string { + return e.email.Email +} + +// ID implements proto.UserEmail. +func (e *userEmail) ID() int64 { + return e.email.ID +} + +// IsPrimary implements proto.UserEmail. +func (e *userEmail) IsPrimary() bool { + return e.email.IsPrimary +} diff --git a/pkg/db/handler.go b/pkg/db/handler.go index 981cadf21..ecad917ef 100644 --- a/pkg/db/handler.go +++ b/pkg/db/handler.go @@ -9,6 +9,7 @@ import ( // Handler is a database handler. type Handler interface { + DriverName() string Rebind(string) string Select(interface{}, string, ...interface{}) error diff --git a/pkg/db/migrate/0001_create_tables.go b/pkg/db/migrate/0001_create_tables.go index 8491328c0..590450291 100644 --- a/pkg/db/migrate/0001_create_tables.go +++ b/pkg/db/migrate/0001_create_tables.go @@ -20,71 +20,71 @@ const ( var createTables = Migration{ Version: createTablesVersion, Name: createTablesName, - Migrate: func(ctx context.Context, tx *db.Tx) error { + Migrate: func(ctx context.Context, h db.Handler) error { cfg := config.FromContext(ctx) insert := "INSERT " // Alter old tables (if exist) // This is to support prior versions of Soft Serve v0.6 - switch tx.DriverName() { + switch h.DriverName() { case "sqlite3", "sqlite": insert += "OR IGNORE " - hasUserTable := hasTable(tx, "user") + hasUserTable := hasTable(h, "user") if hasUserTable { - if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil { + if _, err := h.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil { return err } } - if hasTable(tx, "public_key") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil { + if hasTable(h, "public_key") { + if _, err := h.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil { return err } } - if hasTable(tx, "collab") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil { + if hasTable(h, "collab") { + if _, err := h.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil { return err } } - if hasTable(tx, "repo") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil { + if hasTable(h, "repo") { + if _, err := h.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil { return err } } } - if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil { + if err := migrateUp(ctx, h, createTablesVersion, createTablesName); err != nil { return err } - switch tx.DriverName() { + switch h.DriverName() { case "sqlite3", "sqlite": - if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { return err } - if hasTable(tx, "user_old") { + if hasTable(h, "user_old") { sqlm := ` INSERT INTO users (id, username, admin, updated_at) SELECT id, username, admin, updated_at FROM user_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "public_key_old") { + if hasTable(h, "public_key_old") { // Check duplicate keys pks := []struct { ID string `db:"id"` PublicKey string `db:"public_key"` }{} - if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil { + if err := h.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil { return err } @@ -100,53 +100,53 @@ var createTables = Migration{ INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at) SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "repo_old") { + if hasTable(h, "repo_old") { sqlm := ` INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id) SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, ( SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1 ) FROM repo_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "collab_old") { + if hasTable(h, "collab_old") { sqlm := ` INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at) SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { return err } } // Insert default user - insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") - if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil { + insertUser := h.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") + if _, err := h.ExecContext(ctx, insertUser, "admin", true); err != nil { return err } for _, k := range cfg.AdminKeys() { query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" - if tx.DriverName() == "postgres" { + if h.DriverName() == "postgres" { query += " ON CONFLICT DO NOTHING" } - query = tx.Rebind(query) + query = h.Rebind(query) ak := sshutils.MarshalAuthorizedKey(k) - if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil { + if _, err := h.ExecContext(ctx, query, 1, ak); err != nil { if errors.Is(db.WrapError(err), db.ErrDuplicateKey) { continue } @@ -156,7 +156,7 @@ var createTables = Migration{ // Insert default settings insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" - insertSettings = tx.Rebind(insertSettings) + insertSettings = h.Rebind(insertSettings) settings := []struct { Key string Value string @@ -167,14 +167,14 @@ var createTables = Migration{ } for _, s := range settings { - if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { + if _, err := h.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { return fmt.Errorf("inserting default settings %q: %w", s.Key, err) } } return nil }, - Rollback: func(ctx context.Context, tx *db.Tx) error { - return migrateDown(ctx, tx, createTablesVersion, createTablesName) + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, createTablesVersion, createTablesName) }, } diff --git a/pkg/db/migrate/0002_webhooks.go b/pkg/db/migrate/0002_webhooks.go index a525f6edc..257be3640 100644 --- a/pkg/db/migrate/0002_webhooks.go +++ b/pkg/db/migrate/0002_webhooks.go @@ -14,10 +14,10 @@ const ( var webhooks = Migration{ Name: webhooksName, Version: webhooksVersion, - Migrate: func(ctx context.Context, tx *db.Tx) error { - return migrateUp(ctx, tx, webhooksVersion, webhooksName) + Migrate: func(ctx context.Context, h db.Handler) error { + return migrateUp(ctx, h, webhooksVersion, webhooksName) }, - Rollback: func(ctx context.Context, tx *db.Tx) error { - return migrateDown(ctx, tx, webhooksVersion, webhooksName) + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, webhooksVersion, webhooksName) }, } diff --git a/pkg/db/migrate/0003_migrate_lfs_objects.go b/pkg/db/migrate/0003_migrate_lfs_objects.go index 70aaa79e8..bb0dde129 100644 --- a/pkg/db/migrate/0003_migrate_lfs_objects.go +++ b/pkg/db/migrate/0003_migrate_lfs_objects.go @@ -23,17 +23,17 @@ const ( var migrateLfsObjects = Migration{ Name: migrateLfsObjectsName, Version: migrateLfsObjectsVersion, - Migrate: func(ctx context.Context, tx *db.Tx) error { + Migrate: func(ctx context.Context, h db.Handler) error { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("migrate_lfs_objects") var repoIds []int64 - if err := tx.Select(&repoIds, "SELECT id FROM repos"); err != nil { + if err := h.Select(&repoIds, "SELECT id FROM repos"); err != nil { return err } for _, r := range repoIds { var objs []models.LFSObject - if err := tx.Select(&objs, "SELECT * FROM lfs_objects WHERE repo_id = ?", r); err != nil { + if err := h.Select(&objs, "SELECT * FROM lfs_objects WHERE repo_id = ?", r); err != nil { return err } objsp := filepath.Join(cfg.DataPath, "lfs", strconv.FormatInt(r, 10), "objects") @@ -50,7 +50,7 @@ var migrateLfsObjects = Migration{ } return nil }, - Rollback: func(ctx context.Context, tx *db.Tx) error { + Rollback: func(ctx context.Context, h db.Handler) error { return nil }, } diff --git a/pkg/db/migrate/0004_create_orgs_teams.go b/pkg/db/migrate/0004_create_orgs_teams.go new file mode 100644 index 000000000..738e2b561 --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams.go @@ -0,0 +1,46 @@ +package migrate + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/pkg/db" +) + +const ( + createOrgsTeamsName = "create_orgs_teams" + createOrgsTeamsVersion = 4 +) + +var createOrgsTeams = Migration{ + Name: createOrgsTeamsName, + Version: createOrgsTeamsVersion, + PreMigrate: func(ctx context.Context, h db.Handler) error { + if strings.HasPrefix(h.DriverName(), "sqlite") { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = OFF;"); err != nil { + return err + } + if _, err := h.ExecContext(ctx, "PRAGMA legacy_alter_table = ON;"); err != nil { + return err + } + } + return nil + }, + PostMigrate: func(ctx context.Context, h db.Handler) error { + if strings.HasPrefix(h.DriverName(), "sqlite") { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = ON;"); err != nil { + return err + } + if _, err := h.ExecContext(ctx, "PRAGMA legacy_alter_table = OFF;"); err != nil { + return err + } + } + return nil + }, + Migrate: func(ctx context.Context, h db.Handler) error { + return migrateUp(ctx, h, createOrgsTeamsVersion, createOrgsTeamsName) + }, + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, createOrgsTeamsVersion, createOrgsTeamsName) + }, +} diff --git a/pkg/db/migrate/0004_create_orgs_teams_postgres.down.sql b/pkg/db/migrate/0004_create_orgs_teams_postgres.down.sql new file mode 100644 index 000000000..e69de29bb diff --git a/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql new file mode 100644 index 000000000..ea52bd366 --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql @@ -0,0 +1,135 @@ +CREATE TABLE IF NOT EXISTS handles ( + id SERIAL PRIMARY KEY, + handle TEXT NOT NULL UNIQUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS organizations ( + id SERIAL PRIMARY KEY, + name TEXT, + contact_email TEXT NOT NULL, + handle_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS organization_members ( + id SERIAL PRIMARY KEY, + org_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (org_id, user_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS teams ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + org_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (name, org_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS team_members ( + id SERIAL PRIMARY KEY, + team_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (team_id, user_id), + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS user_emails ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL, + email TEXT NOT NULL UNIQUE, + is_primary BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- Create unique index for primary email +CREATE UNIQUE INDEX user_emails_user_id_is_primary_idx ON user_emails (user_id) WHERE is_primary; + +-- Add name to users table +ALTER TABLE users ADD COLUMN name TEXT; + +-- Add handle_id to users table +ALTER TABLE users ADD COLUMN handle_id INTEGER; +ALTER TABLE users ADD CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Migrate user username to handles +INSERT INTO handles (handle, updated_at) SELECT username, updated_at FROM users; + +-- Update handle_id for users +UPDATE users SET handle_id = handles.id FROM handles WHERE handles.handle = users.username; + +-- Make handle_id not null and unique +ALTER TABLE users ALTER COLUMN handle_id SET NOT NULL; +ALTER TABLE users ADD CONSTRAINT handle_id_unique UNIQUE (handle_id); + +-- Drop username from users +ALTER TABLE users DROP COLUMN username; + +-- Add org_id to repos table +ALTER TABLE repos ADD COLUMN org_id INTEGER; +ALTER TABLE repos ADD CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Alter user_id nullness in repos table +ALTER TABLE repos ALTER COLUMN user_id DROP NOT NULL; + +-- Check that both user_id and org_id can't be null +ALTER TABLE repos ADD CONSTRAINT user_id_org_id_not_null CHECK ((user_id IS NULL) <> (org_id IS NULL)); + +-- Add team_id to collabs table +ALTER TABLE collabs ADD COLUMN team_id INTEGER; +ALTER TABLE collabs ADD CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Alter user_id nullness in collabs table +ALTER TABLE collabs ALTER COLUMN user_id DROP NOT NULL; + +-- Check that both user_id and team_id can't be null +ALTER TABLE collabs ADD CONSTRAINT user_id_team_id_not_null CHECK ((user_id IS NULL) <> (team_id IS NULL)); + +-- Alter unique constraint on collabs table +ALTER TABLE collabs DROP CONSTRAINT collabs_user_id_repo_id_key; +ALTER TABLE collabs ADD CONSTRAINT collabs_user_id_repo_id_team_id_key UNIQUE (user_id, repo_id, team_id); diff --git a/pkg/db/migrate/0004_create_orgs_teams_sqlite.down.sql b/pkg/db/migrate/0004_create_orgs_teams_sqlite.down.sql new file mode 100644 index 000000000..e69de29bb diff --git a/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql new file mode 100644 index 000000000..17e4a88ff --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql @@ -0,0 +1,181 @@ +CREATE TABLE IF NOT EXISTS handles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + handle TEXT NOT NULL UNIQUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS organizations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + contact_email TEXT NOT NULL, + handle_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS organization_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (org_id, user_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS teams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + org_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (name, org_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS team_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + team_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (team_id, user_id), + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS user_emails ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + email TEXT NOT NULL UNIQUE, + is_primary BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- Create unique index for primary email +CREATE UNIQUE INDEX user_emails_user_id_is_primary_idx ON user_emails (user_id) WHERE is_primary; + +ALTER TABLE users RENAME TO _users_old; + +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + handle_id INTEGER NOT NULL UNIQUE, + admin BOOLEAN NOT NULL, + password TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- Migrate user username to handles +INSERT INTO handles (handle, updated_at) SELECT username, updated_at FROM _users_old; + +-- Migrate users +INSERT INTO users (id, handle_id, admin, password, created_at, updated_at) SELECT id, ( + SELECT id FROM handles WHERE handle = _users_old.username +), admin, password, created_at, updated_at FROM _users_old; + +-- Drop old table +DROP TABLE _users_old; + +ALTER TABLE repos RENAME TO _repos_old; + +CREATE TABLE IF NOT EXISTS repos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + project_name TEXT NOT NULL, + description TEXT NOT NULL, + private BOOLEAN NOT NULL, + mirror BOOLEAN NOT NULL, + hidden BOOLEAN NOT NULL, + user_id INTEGER, + org_id INTEGER, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_org_id_not_null + CHECK (user_id IS NULL <> org_id IS NULL) +); + +-- Migrate repos +INSERT INTO repos (id, name, project_name, description, private, mirror, hidden, user_id, created_at, updated_at) +SELECT id, name, project_name, description, private, mirror, hidden, user_id, created_at, updated_at +FROM _repos_old; + +-- Drop old table +DROP TABLE _repos_old; + +-- Alter collabs table +ALTER TABLE collabs RENAME TO _collabs_old; + +CREATE TABLE IF NOT EXISTS collabs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER, + team_id INTEGER, + repo_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, team_id, repo_id), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_team_id_not_null + CHECK (user_id IS NULL <> team_id IS NULL) +); + +-- Migrate collabs +INSERT INTO collabs (id, user_id, team_id, repo_id, access_level, created_at, updated_at) +SELECT id, user_id, NULL, repo_id, access_level, created_at, updated_at +FROM _collabs_old; + +-- Drop old table +DROP TABLE _collabs_old; diff --git a/pkg/db/migrate/migrate.go b/pkg/db/migrate/migrate.go index 18bc1780d..276aac132 100644 --- a/pkg/db/migrate/migrate.go +++ b/pkg/db/migrate/migrate.go @@ -11,15 +11,17 @@ import ( ) // MigrateFunc is a function that executes a migration. -type MigrateFunc func(ctx context.Context, tx *db.Tx) error // nolint:revive +type MigrateFunc func(ctx context.Context, h db.Handler) error // nolint:revive // Migration is a struct that contains the name of the migration and the // function to execute it. type Migration struct { - Version int64 - Name string - Migrate MigrateFunc - Rollback MigrateFunc + Version int64 + Name string + PreMigrate MigrateFunc + Migrate MigrateFunc + PostMigrate MigrateFunc + Rollback MigrateFunc } // Migrations is a database model to store migrations. @@ -62,37 +64,50 @@ func (Migrations) schema(driverName string) string { // Migrate runs the migrations. func Migrate(ctx context.Context, dbx *db.DB) error { logger := log.FromContext(ctx).WithPrefix("migrate") - return dbx.TransactionContext(ctx, func(tx *db.Tx) error { - if !hasTable(tx, "migrations") { - if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil { - return err - } + + if !hasTable(dbx, "migrations") { + if _, err := dbx.Exec(Migrations{}.schema(dbx.DriverName())); err != nil { + return err } + } - var migrs Migrations - if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { - if !errors.Is(err, sql.ErrNoRows) { - return err - } + var migrs Migrations + if err := dbx.Get(&migrs, dbx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err } + } - for _, m := range migrations { - if m.Version <= migrs.Version { - continue - } + for _, m := range migrations { + if m.Version <= migrs.Version { + continue + } - logger.Infof("running migration %d. %s", m.Version, m.Name) - if err := m.Migrate(ctx, tx); err != nil { + logger.Infof("running migration %d. %s", m.Version, m.Name) + if m.PreMigrate != nil { + if err := m.PreMigrate(ctx, dbx); err != nil { return err } + } + + if err := dbx.TransactionContext(ctx, func(tx *db.Tx) error { + return m.Migrate(ctx, tx) + }); err != nil { + return err + } - if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + if m.PostMigrate != nil { + if err := m.PostMigrate(ctx, dbx); err != nil { return err } } - return nil - }) + if _, err := dbx.Exec(dbx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + return err + } + } + + return nil } // Rollback rolls back a migration. @@ -124,7 +139,7 @@ func Rollback(ctx context.Context, dbx *db.DB) error { }) } -func hasTable(tx *db.Tx, tableName string) bool { +func hasTable(tx db.Handler, tableName string) bool { var query string switch tx.DriverName() { case "sqlite3", "sqlite": diff --git a/pkg/db/migrate/migrations.go b/pkg/db/migrate/migrations.go index e2598b414..ff0ea08e8 100644 --- a/pkg/db/migrate/migrations.go +++ b/pkg/db/migrate/migrations.go @@ -18,15 +18,16 @@ var migrations = []Migration{ createTables, webhooks, migrateLfsObjects, + createOrgsTeams, } -func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { +func execMigration(ctx context.Context, h db.Handler, version int, name string, down bool) error { direction := "up" if down { direction = "down" } - driverName := tx.DriverName() + driverName := h.DriverName() if driverName == "sqlite3" { driverName = "sqlite" } @@ -37,19 +38,19 @@ func execMigration(ctx context.Context, tx *db.Tx, version int, name string, dow return err } - if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil { + if _, err := h.ExecContext(ctx, string(sqlstr)); err != nil { return err } return nil } -func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error { - return execMigration(ctx, tx, version, name, false) +func migrateUp(ctx context.Context, h db.Handler, version int, name string) error { + return execMigration(ctx, h, version, name, false) } -func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error { - return execMigration(ctx, tx, version, name, true) +func migrateDown(ctx context.Context, h db.Handler, version int, name string) error { + return execMigration(ctx, h, version, name, true) } var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") diff --git a/pkg/db/models/collab.go b/pkg/db/models/collab.go index 5141e5986..149143152 100644 --- a/pkg/db/models/collab.go +++ b/pkg/db/models/collab.go @@ -1,6 +1,7 @@ package models import ( + "database/sql" "time" "github.com/charmbracelet/soft-serve/pkg/access" @@ -10,7 +11,8 @@ import ( type Collab struct { ID int64 `db:"id"` RepoID int64 `db:"repo_id"` - UserID int64 `db:"user_id"` + UserID sql.NullInt64 `db:"user_id"` + TeamID sql.NullInt64 `db:"team_id"` AccessLevel access.AccessLevel `db:"access_level"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` diff --git a/pkg/db/models/handle.go b/pkg/db/models/handle.go new file mode 100644 index 000000000..99b532dc7 --- /dev/null +++ b/pkg/db/models/handle.go @@ -0,0 +1,11 @@ +package models + +import "time" + +// Handle represents a name handle. +type Handle struct { + ID int64 `db:"id"` + Handle string `db:"handle"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/org.go b/pkg/db/models/org.go new file mode 100644 index 000000000..596aa13cc --- /dev/null +++ b/pkg/db/models/org.go @@ -0,0 +1,25 @@ +package models + +import ( + "database/sql" + "time" +) + +// Organization represents an organization in the system. +type Organization struct { + ID int64 `db:"id"` + Name sql.NullString `db:"name"` + ContactEmail string `db:"contact_email"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + Handle Handle `db:"handle"` +} + +// OrganizationMember represents a member of an organization. +type OrganizationMember struct { + ID int64 `db:"id"` + OrganizationID int64 `db:"org_id"` + UserID int64 `db:"user_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/repo.go b/pkg/db/models/repo.go index 88300bd3e..e58803931 100644 --- a/pkg/db/models/repo.go +++ b/pkg/db/models/repo.go @@ -15,6 +15,7 @@ type Repo struct { Mirror bool `db:"mirror"` Hidden bool `db:"hidden"` UserID sql.NullInt64 `db:"user_id"` + OrgID sql.NullInt64 `db:"org_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } diff --git a/pkg/db/models/team.go b/pkg/db/models/team.go new file mode 100644 index 000000000..de04b7a8d --- /dev/null +++ b/pkg/db/models/team.go @@ -0,0 +1,23 @@ +package models + +import ( + "time" +) + +// Team represents a team in an organization. +type Team struct { + ID int64 `db:"id"` + Name string `db:"name"` + OrganizationID int64 `db:"org_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +// TeamMember represents a member of a team. +type TeamMember struct { + ID int64 `db:"id"` + TeamID int64 `db:"team_id"` + UserID int64 `db:"user_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/user.go b/pkg/db/models/user.go index 5ca0d3d9f..3d107a2eb 100644 --- a/pkg/db/models/user.go +++ b/pkg/db/models/user.go @@ -8,9 +8,20 @@ import ( // User represents a user. type User struct { ID int64 `db:"id"` - Username string `db:"username"` + Name sql.NullString `db:"name"` Admin bool `db:"admin"` Password sql.NullString `db:"password"` + HandleID int64 `db:"handle_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } + +// UserEmail represents a user's email address. +type UserEmail struct { + ID int64 `db:"id"` + UserID int64 `db:"user_id"` + Email string `db:"email"` + IsPrimary bool `db:"is_primary"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/git/lfs.go b/pkg/git/lfs.go index 4b0065f3e..67adc5c3b 100644 --- a/pkg/git/lfs.go +++ b/pkg/git/lfs.go @@ -255,6 +255,11 @@ func (l *lfsLockBackend) Create(path string, refname string) (transfer.Lock, err } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { // Return conflict (409) if the lock already exists. @@ -286,6 +291,11 @@ func (l *lfsLockBackend) FromID(id string) (transfer.Lock, error) { } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { if errors.Is(err, db.ErrRecordNotFound) { @@ -312,6 +322,11 @@ func (l *lfsLockBackend) FromPath(path string) (transfer.Lock, error) { } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { if errors.Is(err, db.ErrRecordNotFound) { @@ -410,6 +425,7 @@ func (l *lfsLockBackend) Unlock(lock transfer.Lock) error { type LFSLock struct { lock models.LFSLock owner models.User + handle models.Handle backend *lfsLockBackend } @@ -459,7 +475,7 @@ func (l *LFSLock) ID() string { // OwnerName implements transfer.Lock. func (l *LFSLock) OwnerName() string { - return l.owner.Username + return l.handle.Handle } // Path implements transfer.Lock. diff --git a/pkg/proto/org.go b/pkg/proto/org.go new file mode 100644 index 000000000..2809fd8af --- /dev/null +++ b/pkg/proto/org.go @@ -0,0 +1,11 @@ +package proto + +// Org is an interface representing a organization. +type Org interface { + // ID returns the user's ID. + ID() int64 + // Name returns the org's name. + Name() string + // DisplayName + DisplayName() string +} diff --git a/pkg/proto/team.go b/pkg/proto/team.go new file mode 100644 index 000000000..516da581f --- /dev/null +++ b/pkg/proto/team.go @@ -0,0 +1,11 @@ +package proto + +// Team is an interface representing a team. +type Team interface { + // ID returns the user's ID. + ID() int64 + // Name returns the org's name. + Name() string + // Parent organization's ID. + Org() int64 +} diff --git a/pkg/proto/user.go b/pkg/proto/user.go index 7b334122d..c6c65b1bf 100644 --- a/pkg/proto/user.go +++ b/pkg/proto/user.go @@ -14,6 +14,8 @@ type User interface { PublicKeys() []ssh.PublicKey // Password returns the user's password hash. Password() string + // Emails returns the user's emails. + Emails() []UserEmail } // UserOptions are options for creating a user. @@ -22,4 +24,19 @@ type UserOptions struct { Admin bool // PublicKeys are the user's public keys. PublicKeys []ssh.PublicKey + // Emails are the user's emails. + // The first email in the slice will be set as the user's primary email. + Emails []string +} + +// UserEmail represents a user's email address. +type UserEmail interface { + // ID returns the email's ID. + ID() int64 + + // Email returns the email address. + Email() string + + // IsPrimary returns whether the email is the user's primary email. + IsPrimary() bool } diff --git a/pkg/ssh/cmd/org.go b/pkg/ssh/cmd/org.go new file mode 100644 index 000000000..1d1a5fe15 --- /dev/null +++ b/pkg/ssh/cmd/org.go @@ -0,0 +1,90 @@ +package cmd + +import ( + "github.com/charmbracelet/soft-serve/pkg/backend" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/spf13/cobra" +) + +// OrgCommand returns a command for managing organizations. +func OrgCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "org", + Aliases: []string{"orgs", "organization", "organizations"}, + Short: "Manage organizations", + } + + cmd.AddCommand(&cobra.Command{ + Use: "create NAME EMAIL", + Short: "Create a new organization", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + owner := proto.UserFromContext(ctx) + if owner == nil { + return proto.ErrUnauthorized + } + _, err := be.CreateOrg(ctx, owner, args[0], args[1]) + return err + }, + }) + cmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List organizations", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + orgs, err := be.ListOrgs(ctx, user) + if err != nil { + return err + } + for _, o := range orgs { + cmd.Println(o.Name()) + } + return nil + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "delete NAME", + Short: "Delete organization", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + return be.DeleteOrganization(ctx, user, args[0]) + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "get NAME", + Short: "Show organization", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + cmd.Println(org.Name()) + return nil + }, + }) + + return cmd +} diff --git a/pkg/ssh/cmd/team.go b/pkg/ssh/cmd/team.go new file mode 100644 index 000000000..c0d1c2919 --- /dev/null +++ b/pkg/ssh/cmd/team.go @@ -0,0 +1,121 @@ +package cmd + +import ( + "github.com/charmbracelet/soft-serve/pkg/backend" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/spf13/cobra" +) + +// TeamCommand returns a command for managing teams. +func TeamCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "team", + Aliases: []string{"teams"}, + Short: "Manage teams", + } + + cmd.AddCommand(&cobra.Command{ + Use: "create ORG NAME", + Short: "Create a new team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.CreateTeam(ctx, org, user, args[1]) + if err != nil { + return err + } + + cmd.Println("Created", team.Name()) + + return err + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List teams", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + teams, err := be.ListTeams(ctx, user) + if err != nil { + return err + } + for _, o := range teams { + cmd.Println(o.Name()) + } + return nil + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "delete ORG NAME", + Short: "Delete team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.GetTeam(ctx, user, org, args[1]) + if err != nil { + return err + } + + return be.DeleteTeam(ctx, user, team) + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "get ORG NAME", + Short: "Show team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.GetTeam(ctx, user, org, args[1]) + if err != nil { + return err + } + + cmd.Println(org.Name(), "/", team.Name()) + return nil + }, + }) + + return cmd +} diff --git a/pkg/ssh/cmd/user.go b/pkg/ssh/cmd/user.go index 21981f9cb..9944aa243 100644 --- a/pkg/ssh/cmd/user.go +++ b/pkg/ssh/cmd/user.go @@ -22,9 +22,9 @@ func UserCommand() *cobra.Command { var admin bool var key string userCreateCommand := &cobra.Command{ - Use: "create USERNAME", + Use: "create USERNAME [EMAIL]", Short: "Create a new user", - Args: cobra.ExactArgs(1), + Args: cobra.MinimumNArgs(1), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { var pubkeys []ssh.PublicKey @@ -45,6 +45,10 @@ func UserCommand() *cobra.Command { PublicKeys: pubkeys, } + if len(args) > 1 { + opts.Emails = append(opts.Emails, args[1]) + } + _, err := be.CreateUser(ctx, username, opts) return err }, @@ -166,6 +170,14 @@ func UserCommand() *cobra.Command { cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk)) } + emails := user.Emails() + if len(emails) > 0 { + cmd.Printf("Emails:\n") + for _, e := range emails { + cmd.Printf(" %s (primary: %v)\n", e.Email(), e.IsPrimary()) + } + } + return nil }, } @@ -185,6 +197,63 @@ func UserCommand() *cobra.Command { }, } + userAddEmailCommand := &cobra.Command{ + Use: "add-email USERNAME EMAIL", + Short: "Add an email to a user", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.AddUserEmail(ctx, u, email) + }, + } + + userRemoveEmailCommand := &cobra.Command{ + Use: "remove-email USERNAME EMAIL", + Short: "Remove an email from a user", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.RemoveUserEmail(ctx, u, email) + }, + } + + userSetPrimaryEmailCommand := &cobra.Command{ + Use: "set-primary-email USERNAME EMAIL", + Short: "Set a user's primary email", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.SetUserPrimaryEmail(ctx, u, email) + }, + } + cmd.AddCommand( userCreateCommand, userAddPubkeyCommand, @@ -194,6 +263,9 @@ func UserCommand() *cobra.Command { userRemovePubkeyCommand, userSetAdminCommand, userSetUsernameCommand, + userAddEmailCommand, + userRemoveEmailCommand, + userSetPrimaryEmailCommand, ) return cmd diff --git a/pkg/ssh/middleware.go b/pkg/ssh/middleware.go index 285a980a7..446d9d5b1 100644 --- a/pkg/ssh/middleware.go +++ b/pkg/ssh/middleware.go @@ -104,6 +104,7 @@ func CommandMiddleware(sh ssh.Handler) ssh.Handler { cmd.RepoCommand(), cmd.SettingsCommand(), cmd.UserCommand(), + cmd.OrgCommand(), cmd.InfoCommand(), cmd.PubkeyCommand(), cmd.SetUsernameCommand(), diff --git a/pkg/store/database/collab.go b/pkg/store/database/collab.go index 9068edef2..ede99515c 100644 --- a/pkg/store/database/collab.go +++ b/pkg/store/database/collab.go @@ -18,7 +18,7 @@ var _ store.CollaboratorStore = (*collabStore)(nil) // AddCollabByUsernameAndRepo implements store.CollaboratorStore. func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string, level access.AccessLevel) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -28,7 +28,9 @@ func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx db.Handle VALUES ( ?, ( - SELECT id FROM users WHERE username = ? + SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + ) ), ( SELECT id FROM repos WHERE name = ? @@ -44,7 +46,7 @@ func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx db.Handle var m models.Collab username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return models.Collab{}, err } @@ -56,9 +58,10 @@ func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx db.Handle FROM collabs INNER JOIN users ON users.id = collabs.user_id + INNER JOIN handles ON handles.id = users.handle_id INNER JOIN repos ON repos.id = collabs.repo_id WHERE - users.username = ? AND repos.name = ? + handles.handle = ? AND repos.name = ? `), username, repo) return m, err @@ -106,7 +109,7 @@ func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx db.Handler, // RemoveCollabByUsernameAndRepo implements store.CollaboratorStore. func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -116,7 +119,9 @@ func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx db.Han collabs WHERE user_id = ( - SELECT id FROM users WHERE username = ? + SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + ) ) AND repo_id = ( SELECT id FROM repos WHERE name = ? ) diff --git a/pkg/store/database/database.go b/pkg/store/database/database.go index 02c4bd19c..de03192a7 100644 --- a/pkg/store/database/database.go +++ b/pkg/store/database/database.go @@ -18,16 +18,20 @@ type datastore struct { *settingsStore *repoStore *userStore + *orgStore + *teamStore *collabStore *lfsStore *accessTokenStore *webhookStore + *handleStore } // New returns a new store.Store database. func New(ctx context.Context, db *db.DB) store.Store { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("store") + handles := &handleStore{} s := &datastore{ ctx: ctx, @@ -37,10 +41,14 @@ func New(ctx context.Context, db *db.DB) store.Store { settingsStore: &settingsStore{}, repoStore: &repoStore{}, - userStore: &userStore{}, + userStore: &userStore{handles}, + orgStore: &orgStore{handles}, + teamStore: &teamStore{}, collabStore: &collabStore{}, lfsStore: &lfsStore{}, accessTokenStore: &accessTokenStore{}, + webhookStore: &webhookStore{}, + handleStore: handles, } return s diff --git a/pkg/store/database/handle.go b/pkg/store/database/handle.go new file mode 100644 index 000000000..a2b166e4e --- /dev/null +++ b/pkg/store/database/handle.go @@ -0,0 +1,88 @@ +package database + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" + "github.com/charmbracelet/soft-serve/pkg/utils" + "github.com/jmoiron/sqlx" +) + +type handleStore struct{} + +var _ store.HandleStore = &handleStore{} + +// CreateHandle implements store.HandleStore. +func (*handleStore) CreateHandle(ctx context.Context, h db.Handler, handle string) (int64, error) { + handle = strings.ToLower(handle) + if err := utils.ValidateHandle(handle); err != nil { + return 0, err + } + + var id int64 + query := h.Rebind("INSERT INTO handles (handle, updated_at) VALUES (?, CURRENT_TIMESTAMP) RETURNING id;") + err := h.GetContext(ctx, &id, query, handle) + return id, db.WrapError(err) +} + +// DeleteHandle implements store.HandleStore. +func (*handleStore) DeleteHandle(ctx context.Context, h db.Handler, id int64) error { + query := h.Rebind("DELETE FROM handles WHERE id = ?;") + _, err := h.ExecContext(ctx, query, id) + return db.WrapError(err) +} + +// GetHandleByHandle implements store.HandleStore. +func (*handleStore) GetHandleByHandle(ctx context.Context, h db.Handler, handle string) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE handle = ?;") + err := h.GetContext(ctx, &hl, query, handle) + return hl, db.WrapError(err) +} + +// GetHandleByID implements store.HandleStore. +func (*handleStore) GetHandleByID(ctx context.Context, h db.Handler, id int64) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE id = ?;") + err := h.GetContext(ctx, &hl, query, id) + return hl, db.WrapError(err) +} + +// GetHandleByUserID implements store.HandleStore. +func (*handleStore) GetHandleByUserID(ctx context.Context, h db.Handler, userID int64) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE id = (SELECT handle_id FROM users WHERE id = ?);") + err := h.GetContext(ctx, &hl, query, userID) + return hl, db.WrapError(err) +} + +// UpdateHandle implements store.HandleStore. +func (*handleStore) UpdateHandle(ctx context.Context, h db.Handler, id int64, handle string) error { + handle = strings.ToLower(handle) + if err := utils.ValidateHandle(handle); err != nil { + return err + } + query := h.Rebind("UPDATE handles SET handle = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?;") + _, err := h.ExecContext(ctx, query, handle, id) + return db.WrapError(err) +} + +// ListHandlesForIDs implements store.HandleStore. +func (*handleStore) ListHandlesForIDs(ctx context.Context, h db.Handler, ids []int64) ([]models.Handle, error) { + var hls []models.Handle + if len(ids) == 0 { + return hls, nil + } + + query, args, err := sqlx.In("SELECT * FROM handles WHERE id IN (?)", ids) + if err != nil { + return nil, db.WrapError(err) + } + + query = h.Rebind(query) + err = h.SelectContext(ctx, &hls, query, args...) + return hls, db.WrapError(err) +} diff --git a/pkg/store/database/org.go b/pkg/store/database/org.go new file mode 100644 index 000000000..60571dc9d --- /dev/null +++ b/pkg/store/database/org.go @@ -0,0 +1,174 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" + "github.com/charmbracelet/soft-serve/pkg/utils" +) + +var _ store.OrgStore = (*orgStore)(nil) + +type orgStore struct{ *handleStore } + +// UpdateOrgContactEmail implements store.OrgStore. +func (*orgStore) UpdateOrgContactEmail(ctx context.Context, h db.Handler, org int64, email string) error { + if err := utils.ValidateEmail(email); err != nil { + return err + } + + query := h.Rebind(` + UPDATE organizations + SET + contact_email = ? + WHERE + id = ? + `) + + _, err := h.ExecContext(ctx, query, email, org) + return err +} + +// ListOrgs implements store.OrgStore. +func (*orgStore) ListOrgs(ctx context.Context, h db.Handler, uid int64) ([]models.Organization, error) { + var m []models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.org_id = o.id + WHERE + o.user_id = ? + `) + err := h.SelectContext(ctx, &m, query, uid) + return m, err +} + +// Delete implements store.OrgStore. +func (s *orgStore) DeleteOrgByID(ctx context.Context, h db.Handler, user, id int64) error { + _, err := s.getOrgByIDWithAccess(ctx, h, user, id, access.AdminAccess) + if err != nil { + return err + } + query := h.Rebind(`DELETE FROM organizations WHERE id = ?;`) + _, err = h.ExecContext(ctx, query, id) + return err +} + +// Create implements store.OrgStore. +func (s *orgStore) CreateOrg(ctx context.Context, h db.Handler, user int64, name, email string) (models.Organization, error) { + if err := utils.ValidateEmail(email); err != nil { + return models.Organization{}, err + } + + handle, err := s.CreateHandle(ctx, h, name) + if err != nil { + return models.Organization{}, err + } + + query := h.Rebind(` + INSERT INTO + organizations (handle_id, contact_email, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP) RETURNING id; + `) + + var id int64 + if err := h.GetContext(ctx, &id, query, handle, email); err != nil { + return models.Organization{}, err + } + if err := s.AddUserToOrg(ctx, h, id, user, access.AdminAccess); err != nil { + return models.Organization{}, err + } + + return s.GetOrgByID(ctx, h, user, id) +} + +func (*orgStore) UpdateUserAccessInOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error { + query := h.Rebind(` + UPDATE organization_members + WHERE + organization_id = ? + AND user_id = ? + SET + access_level = ? + `) + _, err := h.ExecContext(ctx, query, org, user, lvl) + return err +} + +func (*orgStore) RemoveUserFromOrg(ctx context.Context, h db.Handler, org, user int64) error { + query := h.Rebind(` + DELETE FROM organization_members + WHERE + organization_id = ? + AND user_id = ? + `) + _, err := h.ExecContext(ctx, query, org, user) + return err +} + +func (*orgStore) AddUserToOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error { + query := h.Rebind(` + INSERT INTO + organization_members ( + organization_id, + user_id, + access_level, + updated_at + ) + VALUES + (?, ?, ?, CURRENT_TIMESTAMP); + `) + _, err := h.ExecContext(ctx, query, org, user, lvl) + return err +} + +// FindByName implements store.OrgStore. +func (*orgStore) FindOrgByHandle(ctx context.Context, h db.Handler, user int64, name string) (models.Organization, error) { + var m models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.organization_id = o.id + WHERE + om.user_id = ? + AND h.handle = ?; + `) + err := h.GetContext(ctx, &m, query, user, name) + return m, err +} + +// GetByID implements store.OrgStore. +func (s *orgStore) GetOrgByID(ctx context.Context, h db.Handler, user, id int64) (models.Organization, error) { + return s.getOrgByIDWithAccess(ctx, h, user, id, access.ReadOnlyAccess) +} + +func (*orgStore) getOrgByIDWithAccess(ctx context.Context, h db.Handler, user, id int64, level access.AccessLevel) (models.Organization, error) { + var m models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.organization_id = o.id + WHERE + om.user_id = ? + AND id = ? + AND om.access_level >= ?; + `) + err := h.GetContext(ctx, &m, query, user, id, level) + return m, err +} diff --git a/pkg/store/database/team.go b/pkg/store/database/team.go new file mode 100644 index 000000000..bf103bf51 --- /dev/null +++ b/pkg/store/database/team.go @@ -0,0 +1,155 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" +) + +// TODO: should we return all the org's teams if the user is an org admin? +// if so, need to join organization_members too on the selects below. + +var _ store.TeamStore = (*teamStore)(nil) + +type teamStore struct{} + +// RemoveUserFromTeam implements store.TeamStore. +func (*teamStore) RemoveUserFromTeam(ctx context.Context, h db.Handler, team int64, user int64) error { + // TODO: caller perms + query := h.Rebind(` + DELETE FROM team_members + WHERE + team_id = ? + AND user_id = ? + `) + _, err := h.ExecContext(ctx, query, team, user) + return err +} + +// UpdateUserAccessInTeam implements store.TeamStore. +func (*teamStore) UpdateUserAccessInTeam(ctx context.Context, h db.Handler, team int64, user int64, lvl access.AccessLevel) error { + // TODO: caller perms + query := h.Rebind(` + UPDATE team_members + WHERE + team_id = ? + AND user_id = ? + SET + access_level = ? + `) + _, err := h.ExecContext(ctx, query, team, user, lvl) + return err +} + +// AddUserToTeam implements store.TeamStore. +func (*teamStore) AddUserToTeam(ctx context.Context, h db.Handler, team int64, user int64, lvl access.AccessLevel) error { + // TODO: caller perms + query := h.Rebind(` + INSERT INTO + team_members (team_id, user_id, access_level, updated_at) + VALUES + (?, ?, ?, CURRENT_TIMESTAMP); + `) + _, err := h.ExecContext(ctx, query, team, user, lvl) + return err +} + +// CreateTeam implements store.TeamStore. +func (s *teamStore) CreateTeam(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) { + // TODO: caller perms + // TODO: what the access_level column does on team? + query := h.Rebind(` + INSERT INTO + teams (organization_id, name) + VALUES + (?, ?) RETURNING * + `) + var team models.Team + if err := h.GetContext(ctx, &team, query, org, name); err != nil { + return models.Team{}, err + } + return team, s.AddUserToTeam(ctx, h, team.ID, user, access.AdminAccess) +} + +// DeleteTeamByID implements store.TeamStore. +func (*teamStore) DeleteTeamByID(ctx context.Context, h db.Handler, id int64) error { + // TODO: caller perms + query := h.Rebind(` + DELETE FROM teams + WHERE + id = ? + `) + _, err := h.ExecContext(ctx, query, id) + return err +} + +// FindTeamByName implements store.TeamStore. +func (*teamStore) FindTeamByName(ctx context.Context, h db.Handler, uid int64, name string) ([]models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.name = ? + `) + var teams []models.Team + err := h.SelectContext(ctx, &teams, query, uid, name) + return teams, err +} + +// FindTeamByOrgName implements store.TeamStore. +func (*teamStore) FindTeamByOrgName(ctx context.Context, h db.Handler, user int64, org int64, name string) (models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.organization_id = ? + AND t.name = ? + `) + var team models.Team + err := h.GetContext(ctx, &team, query, user, org, name) + return team, err +} + +// GetTeamByID implements store.TeamStore. +func (*teamStore) GetTeamByID(ctx context.Context, h db.Handler, uid, id int64) (models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.id = ? + `) + var team models.Team + err := h.GetContext(ctx, &team, query, uid, id) + return team, err +} + +// ListTeams implements store.TeamStore. +func (*teamStore) ListTeams(ctx context.Context, h db.Handler, uid int64) ([]models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + `) + var teams []models.Team + err := h.SelectContext(ctx, &teams, query, uid) + return teams, err +} diff --git a/pkg/store/database/user.go b/pkg/store/database/user.go index 86b161f47..f3ed59f39 100644 --- a/pkg/store/database/user.go +++ b/pkg/store/database/user.go @@ -2,6 +2,7 @@ package database import ( "context" + "fmt" "strings" "github.com/charmbracelet/soft-serve/pkg/db" @@ -12,19 +13,21 @@ import ( "golang.org/x/crypto/ssh" ) -type userStore struct{} +type userStore struct{ *handleStore } var _ store.UserStore = (*userStore)(nil) // AddPublicKeyByUsername implements store.UserStore. func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } var userID int64 - if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil { + if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users + INNER JOIN handles ON handles.id = users.handle_id + WHERE handles.handle = ?;`), username); err != nil { return err } @@ -37,23 +40,31 @@ func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, use } // CreateUser implements store.UserStore. -func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { +func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error { + handleID, err := s.CreateHandle(ctx, tx, username) + if err != nil { return err } - query := tx.Rebind(`INSERT INTO users (username, admin, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP) RETURNING id;`) + query := tx.Rebind(` + INSERT INTO + users (handle_id, admin, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP) RETURNING id; + `) var userID int64 - if err := tx.GetContext(ctx, &userID, query, username, isAdmin); err != nil { + if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil { return err } for _, pk := range pks { - query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP);`) + query := tx.Rebind(` + INSERT INTO + public_keys (user_id, public_key, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP); + `) ak := sshutils.MarshalAuthorizedKey(pk) _, err := tx.ExecContext(ctx, query, userID, ak) if err != nil { @@ -61,17 +72,23 @@ func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string } } + for i, e := range emails { + if err := s.AddUserEmail(ctx, tx, userID, e, i == 0); err != nil { + return err + } + } + return nil } // DeleteUserByUsername implements store.UserStore. func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`DELETE FROM users WHERE username = ?;`) + query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) _, err := tx.ExecContext(ctx, query, username) return err } @@ -98,12 +115,12 @@ func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh // FindUserByUsername implements store.UserStore. func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return models.User{}, err } var m models.User - query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`) + query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) err := tx.GetContext(ctx, &m, query, username) return m, err } @@ -153,14 +170,14 @@ func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id // ListPublicKeysByUsername implements store.UserStore. func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } var aks []string query := tx.Rebind(`SELECT public_key FROM public_keys INNER JOIN users ON users.id = public_keys.user_id - WHERE users.username = ? + WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?) ORDER BY public_keys.id ASC;`) err := tx.SelectContext(ctx, &aks, query, username) if err != nil { @@ -182,12 +199,14 @@ func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, u // RemovePublicKeyByUsername implements store.UserStore. func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } query := tx.Rebind(`DELETE FROM public_keys - WHERE user_id = (SELECT id FROM users WHERE username = ?) + WHERE user_id = (SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + )) AND public_key = ?;`) _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk)) return err @@ -196,11 +215,11 @@ func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, // SetAdminByUsername implements store.UserStore. func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`) _, err := tx.ExecContext(ctx, query, isAdmin, username) return err } @@ -208,16 +227,16 @@ func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, usernam // SetUsernameByUsername implements store.UserStore. func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } newUsername = strings.ToLower(newUsername) - if err := utils.ValidateUsername(newUsername); err != nil { + if err := utils.ValidateHandle(newUsername); err != nil { return err } - query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`) _, err := tx.ExecContext(ctx, query, newUsername, username) return err } @@ -232,11 +251,68 @@ func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int // SetUserPasswordByUsername implements store.UserStore. func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) _, err := tx.ExecContext(ctx, query, password, username) return err } + +// AddUserEmail implements store.UserStore. +func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error { + if err := utils.ValidateEmail(email); err != nil { + return err + } + query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP);`) + _, err := tx.ExecContext(ctx, query, userID, email, isPrimary) + return err +} + +// ListUserEmails implements store.UserStore. +func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) { + var ms []models.UserEmail + query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`) + err := tx.SelectContext(ctx, &ms, query, userID) + return ms, err +} + +// RemoveUserEmail implements store.UserStore. +func (*userStore) RemoveUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { + var e models.UserEmail + query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ? RETURNING *;`) + if err := tx.GetContext(ctx, &e, query, userID, email); err != nil { + return err + } + + if e.IsPrimary { + return fmt.Errorf("cannot remove primary email") + } else if e.ID == 0 { + return db.ErrRecordNotFound + } + + return nil +} + +// SetUserPrimaryEmail implements store.UserStore. +func (*userStore) SetUserPrimaryEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { + query := tx.Rebind(`UPDATE user_emails SET is_primary = FALSE WHERE user_id = ?;`) + _, err := tx.ExecContext(ctx, query, userID) + if err != nil { + return err + } + + var emailID int64 + query = tx.Rebind(`UPDATE user_emails SET is_primary = TRUE WHERE user_id = ? AND email = ? RETURNING id;`) + if err := tx.GetContext(ctx, &emailID, query, userID, email); err != nil { + return err + } + + if emailID == 0 { + return db.ErrRecordNotFound + } + + return nil +} diff --git a/pkg/store/handle.go b/pkg/store/handle.go new file mode 100644 index 000000000..570a56022 --- /dev/null +++ b/pkg/store/handle.go @@ -0,0 +1,19 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// HandleStore is a store for username handles. +type HandleStore interface { + GetHandleByID(ctx context.Context, h db.Handler, id int64) (models.Handle, error) + GetHandleByHandle(ctx context.Context, h db.Handler, handle string) (models.Handle, error) + GetHandleByUserID(ctx context.Context, h db.Handler, userID int64) (models.Handle, error) + ListHandlesForIDs(ctx context.Context, h db.Handler, ids []int64) ([]models.Handle, error) + UpdateHandle(ctx context.Context, h db.Handler, id int64, handle string) error + CreateHandle(ctx context.Context, h db.Handler, handle string) (int64, error) + DeleteHandle(ctx context.Context, h db.Handler, id int64) error +} diff --git a/pkg/store/org.go b/pkg/store/org.go new file mode 100644 index 000000000..4657fb6f1 --- /dev/null +++ b/pkg/store/org.go @@ -0,0 +1,24 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// OrgStore is a store for organizations. +type OrgStore interface { + CreateOrg(ctx context.Context, h db.Handler, user int64, name, email string) (models.Organization, error) + ListOrgs(ctx context.Context, h db.Handler, user int64) ([]models.Organization, error) + GetOrgByID(ctx context.Context, h db.Handler, user, id int64) (models.Organization, error) + FindOrgByHandle(ctx context.Context, h db.Handler, user int64, name string) (models.Organization, error) + DeleteOrgByID(ctx context.Context, h db.Handler, user, id int64) error + AddUserToOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error + RemoveUserFromOrg(ctx context.Context, h db.Handler, org, user int64) error + UpdateUserAccessInOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error + UpdateOrgContactEmail(ctx context.Context, h db.Handler, org int64, email string) error + // TODO: rename org? + // XXX: what else? +} diff --git a/pkg/store/store.go b/pkg/store/store.go index 41490cbf8..a002e5c38 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -4,9 +4,12 @@ package store type Store interface { RepositoryStore UserStore + OrgStore + TeamStore CollaboratorStore SettingStore LFSStore AccessTokenStore WebhookStore + HandleStore } diff --git a/pkg/store/team.go b/pkg/store/team.go new file mode 100644 index 000000000..026f87091 --- /dev/null +++ b/pkg/store/team.go @@ -0,0 +1,22 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// TeamStore is a store for teams. +type TeamStore interface { + CreateTeam(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) + ListTeams(ctx context.Context, h db.Handler, user int64) ([]models.Team, error) + GetTeamByID(ctx context.Context, h db.Handler, user, id int64) (models.Team, error) + FindTeamByOrgName(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) + FindTeamByName(ctx context.Context, h db.Handler, user int64, name string) ([]models.Team, error) + DeleteTeamByID(ctx context.Context, h db.Handler, id int64) error + AddUserToTeam(ctx context.Context, h db.Handler, team, user int64, lvl access.AccessLevel) error + RemoveUserFromTeam(ctx context.Context, h db.Handler, team, user int64) error + UpdateUserAccessInTeam(ctx context.Context, h db.Handler, team, user int64, lvl access.AccessLevel) error +} diff --git a/pkg/store/user.go b/pkg/store/user.go index 260a7279c..6e6f1c2a5 100644 --- a/pkg/store/user.go +++ b/pkg/store/user.go @@ -15,7 +15,7 @@ type UserStore interface { FindUserByPublicKey(ctx context.Context, h db.Handler, pk ssh.PublicKey) (models.User, error) FindUserByAccessToken(ctx context.Context, h db.Handler, token string) (models.User, error) GetAllUsers(ctx context.Context, h db.Handler) ([]models.User, error) - CreateUser(ctx context.Context, h db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error + CreateUser(ctx context.Context, h db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error DeleteUserByUsername(ctx context.Context, h db.Handler, username string) error SetUsernameByUsername(ctx context.Context, h db.Handler, username string, newUsername string) error SetAdminByUsername(ctx context.Context, h db.Handler, username string, isAdmin bool) error @@ -25,4 +25,9 @@ type UserStore interface { ListPublicKeysByUsername(ctx context.Context, h db.Handler, username string) ([]ssh.PublicKey, error) SetUserPassword(ctx context.Context, h db.Handler, userID int64, password string) error SetUserPasswordByUsername(ctx context.Context, h db.Handler, username string, password string) error + + AddUserEmail(ctx context.Context, h db.Handler, userID int64, email string, isPrimary bool) error + ListUserEmails(ctx context.Context, h db.Handler, userID int64) ([]models.UserEmail, error) + RemoveUserEmail(ctx context.Context, h db.Handler, userID int64, email string) error + SetUserPrimaryEmail(ctx context.Context, h db.Handler, userID int64, email string) error } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index a98cb139c..0b6660e88 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,12 +1,20 @@ package utils import ( + "errors" "fmt" + "net/mail" "path" "strings" "unicode" ) +var ( + + // ErrInvalidEmail indicates that an email address is invalid. + ErrInvalidEmail = errors.New("invalid email address") +) + // SanitizeRepo returns a sanitized version of the given repository name. func SanitizeRepo(repo string) string { repo = strings.TrimPrefix(repo, "/") @@ -17,19 +25,19 @@ func SanitizeRepo(repo string) string { return repo } -// ValidateUsername returns an error if any of the given usernames are invalid. -func ValidateUsername(username string) error { - if username == "" { - return fmt.Errorf("username cannot be empty") +// ValidateHandle returns an error if any of the given usernames are invalid. +func ValidateHandle(handle string) error { + if handle == "" { + return fmt.Errorf("cannot be empty") } - if !unicode.IsLetter(rune(username[0])) { - return fmt.Errorf("username must start with a letter") + if !unicode.IsLetter(rune(handle[0])) { + return fmt.Errorf("must start with a letter") } - for _, r := range username { + for _, r := range handle { if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' { - return fmt.Errorf("username can only contain letters, numbers, and hyphens") + return fmt.Errorf("can only contain letters, numbers, and hyphens") } } @@ -50,3 +58,17 @@ func ValidateRepo(repo string) error { return nil } + +// ValidateEmail returns an error if the given email address is invalid. +func ValidateEmail(email string) error { + if strings.ContainsAny(email, " <>") { + return ErrInvalidEmail + } + + _, err := mail.ParseAddress(email) + if err != nil { + return fmt.Errorf("%w: %s", ErrInvalidEmail, err) + } + + return nil +} diff --git a/pkg/web/git_lfs.go b/pkg/web/git_lfs.go index dccd87f95..1fb1136b2 100644 --- a/pkg/web/git_lfs.go +++ b/pkg/web/git_lfs.go @@ -520,7 +520,15 @@ func serviceLfsLocksCreate(w http.ResponseWriter, r *http.Request) { }) return } - lockOwner.Name = owner.Username + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + lockOwner.Name = handle.Handle } errResp.Lock.Owner = lockOwner } @@ -632,6 +640,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + renderJSON(w, http.StatusOK, lfs.LockListResponse{ Locks: []lfs.Lock{ { @@ -639,7 +656,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, }, }, @@ -670,6 +687,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + renderJSON(w, http.StatusOK, lfs.LockListResponse{ Locks: []lfs.Lock{ { @@ -677,7 +703,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, }, }, @@ -695,11 +721,11 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { } lockList := make([]lfs.Lock, len(locks)) - users := map[int64]models.User{} + users := map[int64]userModel{} for i, lock := range locks { owner, ok := users[lock.UserID] if !ok { - owner, err = datastore.GetUserByID(ctx, dbx, lock.UserID) + user, err := datastore.GetUserByID(ctx, dbx, lock.UserID) if err != nil { logger.Error("error getting lock owner", "err", err) renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ @@ -707,7 +733,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { }) return } - users[lock.UserID] = owner + handle, err := datastore.GetHandleByUserID(ctx, dbx, user.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + users[lock.UserID] = userModel{User: user, Handle: handle} } lockList[i] = lfs.Lock{ @@ -715,7 +749,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: owner.Handle.Handle, }, } } @@ -730,6 +764,11 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { renderJSON(w, http.StatusOK, resp) } +type userModel struct { + models.User + models.Handle +} + // POST: /.git/info/lfs/objects/locks/verify func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { if !isLfs(r) { @@ -786,11 +825,11 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { return } - users := map[int64]models.User{} + users := map[int64]userModel{} for _, lock := range locks { owner, ok := users[lock.UserID] if !ok { - owner, err = datastore.GetUserByID(ctx, dbx, lock.UserID) + user, err := datastore.GetUserByID(ctx, dbx, lock.UserID) if err != nil { logger.Error("error getting lock owner", "err", err) renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ @@ -798,7 +837,15 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { }) return } - users[lock.UserID] = owner + handle, err := datastore.GetHandleByUserID(ctx, dbx, user.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + users[lock.UserID] = userModel{User: user, Handle: handle} } l := lfs.Lock{ @@ -806,7 +853,7 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: owner.Handle.Handle, }, } @@ -893,13 +940,22 @@ func serviceLfsLocksDelete(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + // Delete another user's lock l := lfs.Lock{ ID: strconv.FormatInt(lock.ID, 10), Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, } if req.Force { diff --git a/pkg/webhook/branch_tag.go b/pkg/webhook/branch_tag.go index 89771e9fe..6630f00ff 100644 --- a/pkg/webhook/branch_tag.go +++ b/pkg/webhook/branch_tag.go @@ -75,8 +75,13 @@ func NewBranchTagEvent(ctx context.Context, user proto.User, repo proto.Reposito return BranchTagEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return BranchTagEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, err = proto.RepositoryDefaultBranch(repo) if err != nil { return BranchTagEvent{}, err diff --git a/pkg/webhook/collaborator.go b/pkg/webhook/collaborator.go index e7b737b16..32f7e2f75 100644 --- a/pkg/webhook/collaborator.go +++ b/pkg/webhook/collaborator.go @@ -2,6 +2,7 @@ package webhook import ( "context" + "fmt" "github.com/charmbracelet/soft-serve/pkg/access" "github.com/charmbracelet/soft-serve/pkg/db" @@ -63,8 +64,13 @@ func NewCollaboratorEvent(ctx context.Context, user proto.User, repo proto.Repos return CollaboratorEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return CollaboratorEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, err = proto.RepositoryDefaultBranch(repo) if err != nil { return CollaboratorEvent{}, err @@ -76,7 +82,11 @@ func NewCollaboratorEvent(ctx context.Context, user proto.User, repo proto.Repos } payload.AccessLevel = collab.AccessLevel - payload.Collaborator.ID = collab.UserID + if !collab.UserID.Valid { + return CollaboratorEvent{}, fmt.Errorf("collaborator user ID is invalid") + } + + payload.Collaborator.ID = collab.UserID.Int64 payload.Collaborator.Username = collabUsername return payload, nil diff --git a/pkg/webhook/push.go b/pkg/webhook/push.go index 6ef6062cd..f225a00d1 100644 --- a/pkg/webhook/push.go +++ b/pkg/webhook/push.go @@ -66,8 +66,13 @@ func NewPushEvent(ctx context.Context, user proto.User, repo proto.Repository, r return PushEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return PushEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle // Find commits. r, err := repo.Open() diff --git a/pkg/webhook/repository.go b/pkg/webhook/repository.go index 3cad5c39f..3c0879944 100644 --- a/pkg/webhook/repository.go +++ b/pkg/webhook/repository.go @@ -74,8 +74,13 @@ func NewRepositoryEvent(ctx context.Context, user proto.User, repo proto.Reposit return RepositoryEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return RepositoryEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, _ = proto.RepositoryDefaultBranch(repo) return payload, nil diff --git a/testscript/script_test.go b/testscript/script_test.go index 32ebe77c1..cec209d7f 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -76,6 +76,7 @@ func TestScript(t *testing.T) { key, admin1 := mkkey("admin1") _, admin2 := mkkey("admin2") _, user1 := mkkey("user1") + _, user2 := mkkey("user2") testscript.Run(t, testscript.Params{ Dir: "./testdata/", @@ -117,6 +118,7 @@ func TestScript(t *testing.T) { e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey()) e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey()) e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey()) + e.Setenv("USER2_AUTHORIZED_KEY", user2.AuthorizedKey()) e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts")) e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config")) diff --git a/testscript/testdata/help.txtar b/testscript/testdata/help.txtar index d6756ca02..bdc9ce046 100644 --- a/testscript/testdata/help.txtar +++ b/testscript/testdata/help.txtar @@ -23,6 +23,7 @@ Available Commands: help Help about any command info Show your info jwt Generate a JSON Web Token + org Manage organizations pubkey Manage your public keys repo Manage repositories set-username Set your username diff --git a/testscript/testdata/user_management.txtar b/testscript/testdata/user_management.txtar index f65397090..15f1863c9 100644 --- a/testscript/testdata/user_management.txtar +++ b/testscript/testdata/user_management.txtar @@ -1,7 +1,7 @@ # vi: set ft=conf # convert crlf to lf on windows -[windows] dos2unix info.txt admin_key_list1.txt admin_key_list2.txt list1.txt list2.txt foo_info1.txt foo_info2.txt foo_info3.txt foo_info4.txt foo_info5.txt +[windows] dos2unix info.txt admin_key_list1.txt admin_key_list2.txt list1.txt list2.txt foo_info1.txt foo_info2.txt foo_info3.txt foo_info4.txt foo_info5.txt bar_info.txt # start soft serve exec soft serve & @@ -68,6 +68,38 @@ soft user delete foo2 soft user list cmpenv stdout list1.txt +# create a new user with an invalid email +! soft user create bar --key "$USER2_AUTHORIZED_KEY" "foobar" +stderr 'invalid email address.*' + +# create a new user with a valid email +soft user create bar --key "$USER2_AUTHORIZED_KEY" "foo@bar.baz" +! stdout . +# add email to existing user +soft user add-email bar "foobar@fubar.baz" +! stdout . +# add existing email +! soft user add-email bar "foobar@fubar.baz" +stderr 'duplicate key.*' + +# get new user info +soft user info bar +cmpenv stdout bar_info.txt + +# remove primary email from user +! soft user remove-email bar "foo@bar.baz" +stderr 'cannot remove primary email.*' + +# set primary email that doesn't exist +! soft user set-primary-email bar "foobar@foofoo.foo" +stderr 'no rows in result set.*' +# set primary email +soft user set-primary-email bar "foobar@fubar.baz" +! stdout . +# remove other email +soft user remove-email bar "foo@bar.baz" +! stdout . + # stop the server [windows] stopserver [windows] ! stderr . @@ -112,3 +144,11 @@ $ADMIN1_AUTHORIZED_KEY $ADMIN2_AUTHORIZED_KEY -- admin_key_list2.txt -- $ADMIN1_AUTHORIZED_KEY +-- bar_info.txt -- +Username: bar +Admin: false +Public keys: + $USER2_AUTHORIZED_KEY +Emails: + foo@bar.baz (primary: true) + foobar@fubar.baz (primary: false)