From caca3898f40a382f2b0475ec3432722c92cba312 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 11 Dec 2023 10:48:15 -0500 Subject: [PATCH 1/2] feat: initial orgs & teams support Squashed commit of the following: commit 26d3df790c79a0803033ba71b7c394a27c401f1e Author: Ayman Bagabas Date: Thu Nov 30 10:33:32 2023 -0800 feat: add team collaborators (#19) Repository collaborators can be either individual users, or organization teams. When it's an org team, the repo must belong to the same organization that the team belongs to. This validation happens in the app logic. commit 8a598a1898228432114d81d158eb25578b112353 Merge: 9399cd6a60f9 5ae66bc932bd Author: Carlos Alexandro Becker Date: Tue Nov 14 09:37:42 2023 -0300 Merge pull request #15 from charmbracelet/orgs feat: initial organization support commit 5ae66bc932bde406b994877c7c32def5115d9afc Author: Carlos Alexandro Becker Date: Tue Nov 14 10:06:20 2023 +0000 fix: code review Signed-off-by: Carlos Alexandro Becker commit 4d16aecedec8b1a46ee1ea73c2d1cc4b51f0ebd1 Author: Carlos Alexandro Becker Date: Mon Nov 13 18:55:04 2023 +0000 wip team commit 3374c8cd24f3c03188cf1f9e27b9242fd9df00a1 Author: Carlos Alexandro Becker Date: Mon Nov 13 17:15:54 2023 +0000 wip team commit 890975daab842cab7207d68535c7ed5e2dec98c7 Author: Carlos Alexandro Becker Date: Mon Nov 13 15:59:22 2023 +0000 wip teams commit d67773949c1ca93ca392d156241f9f166e471723 Author: Carlos Alexandro Becker Date: Mon Nov 13 15:44:33 2023 +0000 fix: team signatures commit b3881958409bd98bc8861f4048826e9efaca7b6d Author: Carlos Alexandro Becker Date: Mon Nov 13 15:44:26 2023 +0000 fix: rename method commit 509585f7c6444e0a0106d1100204bee0ee65e70f Author: Carlos Alexandro Becker Date: Mon Nov 13 12:32:03 2023 +0000 wip commit ab0124e25c12cf2d6c8d75f793ac80a236bbbe82 Author: Carlos Alexandro Becker Date: Thu Nov 9 20:47:15 2023 +0000 fix: cr commit 3e6de9b54dbb4d0cefe3891f424976c760c39ee3 Author: Carlos Alexandro Becker Date: Thu Nov 9 20:44:46 2023 +0000 fixes commit 452815ea7231b3d362b95ea7ab2149492862974a Author: Carlos Alexandro Becker Date: Thu Nov 9 20:37:44 2023 +0000 wip commit aa25e4fc40104c9c390b2d69eb4fc33128938074 Author: Carlos Alexandro Becker Date: Thu Nov 9 20:35:39 2023 +0000 wip commit 28241cc64e57b2b3a30ed6897b1156a88965e99b Merge: 12a70b3e5e0c 9399cd6a60f9 Author: Carlos Alexandro Becker Date: Thu Nov 9 20:33:30 2023 +0000 Merge remote-tracking branch 'origin/orgs-teams' into orgs commit 9399cd6a60f9d3a6e6bf95aa4bb454161944c540 Author: Ayman Bagabas Date: Thu Nov 9 12:22:02 2023 -0500 feat: add email user relations and models commit a0715c42d628528292501591c6b5764a5ee80ff8 Author: Ayman Bagabas Date: Thu Nov 9 10:08:23 2023 -0500 fix: carlos comments commit 12a70b3e5e0cc625dd48511d035e632b551fd05b Author: Carlos Alexandro Becker Date: Thu Nov 9 13:40:37 2023 +0000 fix: admin commit 637c8bccd6877e3447a9e1f6f13e1a1105a3616f Author: Carlos Alexandro Becker Date: Thu Nov 9 12:57:01 2023 +0000 wip commit 4ec7653f4ba6fed4faee21f8416d2dbbe46ef3d3 Author: Carlos Alexandro Becker Date: Thu Nov 9 12:32:46 2023 +0000 fix: merge issues commit d8b8e22f98f5700d84d599e9a0483ab1d0572497 Merge: c2bf2721d2d0 777e451128b1 Author: Carlos Alexandro Becker Date: Thu Nov 9 12:27:26 2023 +0000 Merge remote-tracking branch 'origin/orgs-teams' into orgs commit c2bf2721d2d0b17e4cfa2cc1a45d9a21cf4197fd Author: Carlos Alexandro Becker Date: Thu Nov 9 02:25:19 2023 +0000 wip commit 92b5f57ec09a013b051dd6828dd1beb4ed95d641 Author: Carlos Alexandro Becker Date: Thu Nov 9 02:10:37 2023 +0000 wip commit 83f6cf906a5779168706277ec37e3bb84578b55b Author: Carlos Alexandro Becker Date: Wed Nov 8 19:42:02 2023 +0000 wip commit 777e451128b141304341e6497d6a96a6f4185f3a Author: Ayman Bagabas Date: Wed Nov 8 12:56:25 2023 -0500 fix: lint commit 50f2b054550489bf8df81a96f553764096734c89 Author: Ayman Bagabas Date: Wed Nov 8 12:53:24 2023 -0500 feat: add models and missing columns commit 84cb5889d269b135ed3e9ad1dbb660a6e562041f Author: Ayman Bagabas Date: Wed Nov 8 12:39:44 2023 -0500 fix(backend): update backend to use handles table commit af16adab439e5905704be986e32be48ef85d2846 Author: Carlos Alexandro Becker Date: Wed Nov 8 17:34:18 2023 +0000 wip: adding orgs commit a222f24860f25c25f92aacbe1d96294cb7d8627c Author: Ayman Bagabas Date: Wed Nov 8 07:33:06 2023 -0800 Add organizations and teams migration (#9) * feat(db): pre/post migration * feat(db): add create orgs/teams migration commit f7f521e9bff1efedd30be27acfde31137070bac5 Author: Ayman Bagabas Date: Tue Nov 7 16:44:09 2023 -0500 wip --- pkg/backend/access.go | 95 ++++++++++ pkg/backend/collab.go | 28 ++- pkg/backend/org.go | 66 +++++++ pkg/backend/team.go | 72 +++++++ pkg/backend/user.go | 150 +++++---------- pkg/db/handler.go | 1 + pkg/db/migrate/0001_create_tables.go | 64 +++---- pkg/db/migrate/0002_webhooks.go | 8 +- pkg/db/migrate/0003_migrate_lfs_objects.go | 8 +- pkg/db/migrate/0004_create_orgs_teams.go | 46 +++++ .../0004_create_orgs_teams_postgres.down.sql | 0 .../0004_create_orgs_teams_postgres.up.sql | 132 +++++++++++++ .../0004_create_orgs_teams_sqlite.down.sql | 0 .../0004_create_orgs_teams_sqlite.up.sql | 178 ++++++++++++++++++ pkg/db/migrate/migrate.go | 65 ++++--- pkg/db/migrate/migrations.go | 15 +- pkg/db/models/collab.go | 4 +- pkg/db/models/handle.go | 11 ++ pkg/db/models/org.go | 25 +++ pkg/db/models/repo.go | 1 + pkg/db/models/team.go | 23 +++ pkg/db/models/user.go | 13 +- pkg/git/lfs.go | 18 +- pkg/proto/org.go | 11 ++ pkg/proto/team.go | 11 ++ pkg/ssh/cmd/org.go | 90 +++++++++ pkg/ssh/cmd/team.go | 121 ++++++++++++ pkg/ssh/middleware.go | 1 + pkg/store/database/collab.go | 17 +- pkg/store/database/database.go | 10 +- pkg/store/database/handle.go | 88 +++++++++ pkg/store/database/org.go | 165 ++++++++++++++++ pkg/store/database/team.go | 155 +++++++++++++++ pkg/store/database/user.go | 94 ++++++--- pkg/store/handle.go | 19 ++ pkg/store/org.go | 24 +++ pkg/store/store.go | 3 + pkg/store/team.go | 22 +++ pkg/store/user.go | 5 + pkg/utils/utils.go | 16 +- pkg/web/git_lfs.go | 80 ++++++-- pkg/webhook/branch_tag.go | 7 +- pkg/webhook/collaborator.go | 14 +- pkg/webhook/push.go | 7 +- pkg/webhook/repository.go | 7 +- testscript/testdata/help.txtar | 1 + 46 files changed, 1750 insertions(+), 241 deletions(-) create mode 100644 pkg/backend/access.go create mode 100644 pkg/backend/org.go create mode 100644 pkg/backend/team.go create mode 100644 pkg/db/migrate/0004_create_orgs_teams.go create mode 100644 pkg/db/migrate/0004_create_orgs_teams_postgres.down.sql create mode 100644 pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql create mode 100644 pkg/db/migrate/0004_create_orgs_teams_sqlite.down.sql create mode 100644 pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql create mode 100644 pkg/db/models/handle.go create mode 100644 pkg/db/models/org.go create mode 100644 pkg/db/models/team.go create mode 100644 pkg/proto/org.go create mode 100644 pkg/proto/team.go create mode 100644 pkg/ssh/cmd/org.go create mode 100644 pkg/ssh/cmd/team.go create mode 100644 pkg/store/database/handle.go create mode 100644 pkg/store/database/org.go create mode 100644 pkg/store/database/team.go create mode 100644 pkg/store/handle.go create mode 100644 pkg/store/org.go create mode 100644 pkg/store/team.go diff --git a/pkg/backend/access.go b/pkg/backend/access.go new file mode 100644 index 000000000..719ef06ab --- /dev/null +++ b/pkg/backend/access.go @@ -0,0 +1,95 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/charmbracelet/soft-serve/pkg/sshutils" + "golang.org/x/crypto/ssh" +) + +// AccessLevel returns the access level of a user for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { + user, _ := d.User(ctx, username) + return d.AccessLevelForUser(ctx, repo, user) +} + +// AccessLevelByPublicKey returns the access level of a user's public key for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { + for _, k := range d.cfg.AdminKeys() { + if sshutils.KeysEqual(pk, k) { + return access.AdminAccess + } + } + + user, _ := d.UserByPublicKey(ctx, pk) + if user != nil { + return d.AccessLevel(ctx, repo, user.Username()) + } + + return d.AccessLevel(ctx, repo, "") +} + +// AccessLevelForUser returns the access level of a user for a repository. +// TODO: user repository ownership +func (d *Backend) AccessLevelForUser(ctx context.Context, repo string, user proto.User) access.AccessLevel { + var username string + anon := d.AnonAccess(ctx) + if user != nil { + username = user.Username() + } + + // If the user is an admin, they have admin access. + if user != nil && user.IsAdmin() { + return access.AdminAccess + } + + // If the repository exists, check if the user is a collaborator. + r := proto.RepositoryFromContext(ctx) + if r == nil { + r, _ = d.Repository(ctx, repo) + } + + if r != nil { + if user != nil { + // If the user is the owner, they have admin access. + if r.UserID() == user.ID() { + return access.AdminAccess + } + } + + // If the user is a collaborator, they have return their access level. + collabAccess, isCollab, _ := d.IsCollaborator(ctx, repo, username) + if isCollab { + if anon > collabAccess { + return anon + } + return collabAccess + } + + // If the repository is private, the user has no access. + if r.IsPrivate() { + return access.NoAccess + } + + // Otherwise, the user has read-only access. + return access.ReadOnlyAccess + } + + if user != nil { + // If the repository doesn't exist, the user has read/write access. + if anon > access.ReadWriteAccess { + return anon + } + + return access.ReadWriteAccess + } + + // If the user doesn't exist, give them the anonymous access level. + return anon +} diff --git a/pkg/backend/collab.go b/pkg/backend/collab.go index c9635ae37..46ad438f0 100644 --- a/pkg/backend/collab.go +++ b/pkg/backend/collab.go @@ -18,7 +18,7 @@ import ( // It implements backend.Backend. func (d *Backend) AddCollaborator(ctx context.Context, repo string, username string, level access.AccessLevel) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -50,19 +50,33 @@ func (d *Backend) AddCollaborator(ctx context.Context, repo string, username str func (d *Backend) Collaborators(ctx context.Context, repo string) ([]string, error) { repo = utils.SanitizeRepo(repo) var users []models.User + var usernames []string if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error users, err = d.store.ListCollabsByRepoAsUsers(ctx, tx, repo) - return err + if err != nil { + return err + } + + ids := make([]int64, len(users)) + for i, u := range users { + ids[i] = u.ID + } + + handles, err := d.store.ListHandlesForIDs(ctx, tx, ids) + if err != nil { + return err + } + + for _, h := range handles { + usernames = append(usernames, h.Handle) + } + + return nil }); err != nil { return nil, db.WrapError(err) } - var usernames []string - for _, u := range users { - usernames = append(usernames, u.Username) - } - return usernames, nil } diff --git a/pkg/backend/org.go b/pkg/backend/org.go new file mode 100644 index 000000000..2fd7ee10d --- /dev/null +++ b/pkg/backend/org.go @@ -0,0 +1,66 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/proto" +) + +// CreateOrg creates a new organization. +func (d *Backend) CreateOrg(ctx context.Context, owner proto.User, name, email string) (proto.Org, error) { + o, err := d.store.CreateOrg(ctx, d.db, owner.ID(), name, email) + if err != nil { + return org{}, err + } + return org{o}, err +} + +// ListOrgs lists all organizations for a user. +func (d *Backend) ListOrgs(ctx context.Context, user proto.User) ([]proto.Org, error) { + orgs, err := d.store.ListOrgs(ctx, d.db, user.ID()) + var r []proto.Org + for _, o := range orgs { + r = append(r, org{o}) + } + return r, err +} + +// FindOrganization finds an organization belonging to a user by name. +func (d *Backend) FindOrganization(ctx context.Context, user proto.User, name string) (proto.Org, error) { + o, err := d.store.FindOrgByHandle(ctx, d.db, user.ID(), name) + return org{o}, err +} + +// DeleteOrganization deletes an organization for a user. +func (d *Backend) DeleteOrganization(ctx context.Context, user proto.User, name string) error { + o, err := d.store.FindOrgByHandle(ctx, d.db, user.ID(), name) + if err != nil { + return err + } + return d.store.DeleteOrgByID(ctx, d.db, user.ID(), o.ID) +} + +type org struct { + o models.Organization +} + +var _ proto.Org = org{} + +// DisplayName implements proto.Org. +func (o org) DisplayName() string { + if o.o.Name.Valid { + return o.o.Name.String + } + return "" +} + +// ID implements proto.Org. +func (o org) ID() int64 { + return o.o.ID +} + +// Name implements proto.Org. +func (o org) Name() string { + return o.o.Handle.Handle +} diff --git a/pkg/backend/team.go b/pkg/backend/team.go new file mode 100644 index 000000000..188e2ace1 --- /dev/null +++ b/pkg/backend/team.go @@ -0,0 +1,72 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/proto" +) + +// CreateTeam creates a new team for an organization. +func (d *Backend) CreateTeam(ctx context.Context, org proto.Org, owner proto.User, name string) (proto.Team, error) { + m, err := d.store.CreateTeam(ctx, d.db, owner.ID(), org.ID(), name) + if err != nil { + return team{}, err + } + return team{m}, err +} + +// ListTeams lists all teams for a user. +func (d *Backend) ListTeams(ctx context.Context, user proto.User) ([]proto.Team, error) { + teams, err := d.store.ListTeams(ctx, d.db, user.ID()) + var r []proto.Team + for _, m := range teams { + r = append(r, team{m}) + } + return r, err +} + +// GetTeam gets a team by organization id and team name. +func (d *Backend) GetTeam(ctx context.Context, user proto.User, org proto.Org, name string) (proto.Team, error) { + m, err := d.store.FindTeamByOrgName(ctx, d.db, user.ID(), org.ID(), name) + if err != nil { + return team{}, err + } + return team{m}, err +} + +// FindTeam finds a team by name. +func (d *Backend) FindTeam(ctx context.Context, user proto.User, name string) ([]proto.Team, error) { + m, err := d.store.FindTeamByName(ctx, d.db, user.ID(), name) + var r []proto.Team + for _, m := range m { + r = append(r, team{m}) + } + return r, err +} + +// DeleteTeam deletes a team. +func (d *Backend) DeleteTeam(ctx context.Context, _ proto.User, team proto.Team) error { + return d.store.DeleteTeamByID(ctx, d.db, team.ID()) +} + +type team struct { + t models.Team +} + +var _ proto.Team = team{} + +// ID implements proto.Team. +func (t team) ID() int64 { + return t.t.ID +} + +// Name implements proto.Team. +func (t team) Name() string { + return t.t.Name +} + +// Org implements proto.Team. +func (t team) Org() int64 { + return t.t.OrganizationID +} diff --git a/pkg/backend/user.go b/pkg/backend/user.go index 75423048d..bc651d2fe 100644 --- a/pkg/backend/user.go +++ b/pkg/backend/user.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/charmbracelet/soft-serve/pkg/access" "github.com/charmbracelet/soft-serve/pkg/db" "github.com/charmbracelet/soft-serve/pkg/db/models" "github.com/charmbracelet/soft-serve/pkg/proto" @@ -15,102 +14,18 @@ import ( "golang.org/x/crypto/ssh" ) -// AccessLevel returns the access level of a user for a repository. -// -// It implements backend.Backend. -func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { - user, _ := d.User(ctx, username) - return d.AccessLevelForUser(ctx, repo, user) -} - -// AccessLevelByPublicKey returns the access level of a user's public key for a repository. -// -// It implements backend.Backend. -func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { - for _, k := range d.cfg.AdminKeys() { - if sshutils.KeysEqual(pk, k) { - return access.AdminAccess - } - } - - user, _ := d.UserByPublicKey(ctx, pk) - if user != nil { - return d.AccessLevel(ctx, repo, user.Username()) - } - - return d.AccessLevel(ctx, repo, "") -} - -// AccessLevelForUser returns the access level of a user for a repository. -// TODO: user repository ownership -func (d *Backend) AccessLevelForUser(ctx context.Context, repo string, user proto.User) access.AccessLevel { - var username string - anon := d.AnonAccess(ctx) - if user != nil { - username = user.Username() - } - - // If the user is an admin, they have admin access. - if user != nil && user.IsAdmin() { - return access.AdminAccess - } - - // If the repository exists, check if the user is a collaborator. - r := proto.RepositoryFromContext(ctx) - if r == nil { - r, _ = d.Repository(ctx, repo) - } - - if r != nil { - if user != nil { - // If the user is the owner, they have admin access. - if r.UserID() == user.ID() { - return access.AdminAccess - } - } - - // If the user is a collaborator, they have return their access level. - collabAccess, isCollab, _ := d.IsCollaborator(ctx, repo, username) - if isCollab { - if anon > collabAccess { - return anon - } - return collabAccess - } - - // If the repository is private, the user has no access. - if r.IsPrivate() { - return access.NoAccess - } - - // Otherwise, the user has read-only access. - return access.ReadOnlyAccess - } - - if user != nil { - // If the repository doesn't exist, the user has read/write access. - if anon > access.ReadWriteAccess { - return anon - } - - return access.ReadWriteAccess - } - - // If the user doesn't exist, give them the anonymous access level. - return anon -} - // User finds a user by username. // // It implements backend.Backend. func (d *Backend) User(ctx context.Context, username string) (proto.User, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } var m models.User var pks []ssh.PublicKey + var hl models.Handle if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByUsername(ctx, tx, username) @@ -119,6 +34,11 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -132,6 +52,7 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) return &user{ user: m, publicKeys: pks, + handle: hl, }, nil } @@ -139,6 +60,7 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.GetUserByID(ctx, tx, id) @@ -147,6 +69,11 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -160,6 +87,7 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { return &user{ user: m, publicKeys: pks, + handle: hl, }, nil } @@ -169,6 +97,7 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByPublicKey(ctx, tx, pk) @@ -177,6 +106,11 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -190,6 +124,7 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. return &user{ user: m, publicKeys: pks, + handle: hl, }, nil } @@ -198,6 +133,7 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.User, error) { var m models.User var pks []ssh.PublicKey + var hl models.Handle token = HashToken(token) if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { @@ -216,6 +152,11 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us } pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + if err != nil { + return err + } + + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { err = db.WrapError(err) @@ -229,6 +170,7 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us return &user{ user: m, publicKeys: pks, + handle: hl, }, nil } @@ -243,8 +185,18 @@ func (d *Backend) Users(ctx context.Context) ([]string, error) { return err } - for _, m := range ms { - users = append(users, m.Username) + ids := make([]int64, len(ms)) + for i, m := range ms { + ids[i] = m.ID + } + + handles, err := d.store.ListHandlesForIDs(ctx, tx, ids) + if err != nil { + return err + } + + for _, h := range handles { + users = append(users, h.Handle) } return nil @@ -260,7 +212,7 @@ func (d *Backend) Users(ctx context.Context) ([]string, error) { // It implements backend.Backend. func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -275,11 +227,6 @@ func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.Publ // // It implements backend.Backend. func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.UserOptions) (proto.User, error) { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return nil, err - } - if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys) }); err != nil { @@ -294,7 +241,7 @@ func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.Us // It implements backend.Backend. func (d *Backend) DeleteUser(ctx context.Context, username string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -321,7 +268,7 @@ func (d *Backend) RemovePublicKey(ctx context.Context, username string, pk ssh.P // ListPublicKeys lists the public keys of a user. func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } @@ -342,7 +289,7 @@ func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.Pu // It implements backend.Backend. func (d *Backend) SetUsername(ctx context.Context, username string, newUsername string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -358,7 +305,7 @@ func (d *Backend) SetUsername(ctx context.Context, username string, newUsername // It implements backend.Backend. func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -372,7 +319,7 @@ func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) err // SetPassword sets the password of a user. func (d *Backend) SetPassword(ctx context.Context, username string, rawPassword string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -391,6 +338,7 @@ func (d *Backend) SetPassword(ctx context.Context, username string, rawPassword type user struct { user models.User publicKeys []ssh.PublicKey + handle models.Handle } var _ proto.User = (*user)(nil) @@ -407,7 +355,7 @@ func (u *user) PublicKeys() []ssh.PublicKey { // Username implements proto.User func (u *user) Username() string { - return u.user.Username + return u.handle.Handle } // ID implements proto.User. diff --git a/pkg/db/handler.go b/pkg/db/handler.go index 981cadf21..ecad917ef 100644 --- a/pkg/db/handler.go +++ b/pkg/db/handler.go @@ -9,6 +9,7 @@ import ( // Handler is a database handler. type Handler interface { + DriverName() string Rebind(string) string Select(interface{}, string, ...interface{}) error diff --git a/pkg/db/migrate/0001_create_tables.go b/pkg/db/migrate/0001_create_tables.go index 8491328c0..590450291 100644 --- a/pkg/db/migrate/0001_create_tables.go +++ b/pkg/db/migrate/0001_create_tables.go @@ -20,71 +20,71 @@ const ( var createTables = Migration{ Version: createTablesVersion, Name: createTablesName, - Migrate: func(ctx context.Context, tx *db.Tx) error { + Migrate: func(ctx context.Context, h db.Handler) error { cfg := config.FromContext(ctx) insert := "INSERT " // Alter old tables (if exist) // This is to support prior versions of Soft Serve v0.6 - switch tx.DriverName() { + switch h.DriverName() { case "sqlite3", "sqlite": insert += "OR IGNORE " - hasUserTable := hasTable(tx, "user") + hasUserTable := hasTable(h, "user") if hasUserTable { - if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil { + if _, err := h.ExecContext(ctx, "ALTER TABLE user RENAME TO user_old"); err != nil { return err } } - if hasTable(tx, "public_key") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil { + if hasTable(h, "public_key") { + if _, err := h.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_key_old"); err != nil { return err } } - if hasTable(tx, "collab") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil { + if hasTable(h, "collab") { + if _, err := h.ExecContext(ctx, "ALTER TABLE collab RENAME TO collab_old"); err != nil { return err } } - if hasTable(tx, "repo") { - if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil { + if hasTable(h, "repo") { + if _, err := h.ExecContext(ctx, "ALTER TABLE repo RENAME TO repo_old"); err != nil { return err } } } - if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil { + if err := migrateUp(ctx, h, createTablesVersion, createTablesName); err != nil { return err } - switch tx.DriverName() { + switch h.DriverName() { case "sqlite3", "sqlite": - if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { return err } - if hasTable(tx, "user_old") { + if hasTable(h, "user_old") { sqlm := ` INSERT INTO users (id, username, admin, updated_at) SELECT id, username, admin, updated_at FROM user_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "public_key_old") { + if hasTable(h, "public_key_old") { // Check duplicate keys pks := []struct { ID string `db:"id"` PublicKey string `db:"public_key"` }{} - if err := tx.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil { + if err := h.SelectContext(ctx, &pks, "SELECT id, public_key FROM public_key_old"); err != nil { return err } @@ -100,53 +100,53 @@ var createTables = Migration{ INSERT INTO public_keys (id, user_id, public_key, created_at, updated_at) SELECT id, user_id, public_key, created_at, updated_at FROM public_key_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "repo_old") { + if hasTable(h, "repo_old") { sqlm := ` INSERT INTO repos (id, name, project_name, description, private,mirror, hidden, created_at, updated_at, user_id) SELECT id, name, project_name, description, private, mirror, hidden, created_at, updated_at, ( SELECT id FROM users WHERE admin = true ORDER BY id LIMIT 1 ) FROM repo_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if hasTable(tx, "collab_old") { + if hasTable(h, "collab_old") { sqlm := ` INSERT INTO collabs (id, user_id, repo_id, access_level, created_at, updated_at) SELECT id, user_id, repo_id, ` + strconv.Itoa(int(access.ReadWriteAccess)) + `, created_at, updated_at FROM collab_old; ` - if _, err := tx.ExecContext(ctx, sqlm); err != nil { + if _, err := h.ExecContext(ctx, sqlm); err != nil { return err } } - if _, err := tx.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { return err } } // Insert default user - insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") - if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil { + insertUser := h.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") + if _, err := h.ExecContext(ctx, insertUser, "admin", true); err != nil { return err } for _, k := range cfg.AdminKeys() { query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" - if tx.DriverName() == "postgres" { + if h.DriverName() == "postgres" { query += " ON CONFLICT DO NOTHING" } - query = tx.Rebind(query) + query = h.Rebind(query) ak := sshutils.MarshalAuthorizedKey(k) - if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil { + if _, err := h.ExecContext(ctx, query, 1, ak); err != nil { if errors.Is(db.WrapError(err), db.ErrDuplicateKey) { continue } @@ -156,7 +156,7 @@ var createTables = Migration{ // Insert default settings insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" - insertSettings = tx.Rebind(insertSettings) + insertSettings = h.Rebind(insertSettings) settings := []struct { Key string Value string @@ -167,14 +167,14 @@ var createTables = Migration{ } for _, s := range settings { - if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { + if _, err := h.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { return fmt.Errorf("inserting default settings %q: %w", s.Key, err) } } return nil }, - Rollback: func(ctx context.Context, tx *db.Tx) error { - return migrateDown(ctx, tx, createTablesVersion, createTablesName) + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, createTablesVersion, createTablesName) }, } diff --git a/pkg/db/migrate/0002_webhooks.go b/pkg/db/migrate/0002_webhooks.go index a525f6edc..257be3640 100644 --- a/pkg/db/migrate/0002_webhooks.go +++ b/pkg/db/migrate/0002_webhooks.go @@ -14,10 +14,10 @@ const ( var webhooks = Migration{ Name: webhooksName, Version: webhooksVersion, - Migrate: func(ctx context.Context, tx *db.Tx) error { - return migrateUp(ctx, tx, webhooksVersion, webhooksName) + Migrate: func(ctx context.Context, h db.Handler) error { + return migrateUp(ctx, h, webhooksVersion, webhooksName) }, - Rollback: func(ctx context.Context, tx *db.Tx) error { - return migrateDown(ctx, tx, webhooksVersion, webhooksName) + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, webhooksVersion, webhooksName) }, } diff --git a/pkg/db/migrate/0003_migrate_lfs_objects.go b/pkg/db/migrate/0003_migrate_lfs_objects.go index 70aaa79e8..bb0dde129 100644 --- a/pkg/db/migrate/0003_migrate_lfs_objects.go +++ b/pkg/db/migrate/0003_migrate_lfs_objects.go @@ -23,17 +23,17 @@ const ( var migrateLfsObjects = Migration{ Name: migrateLfsObjectsName, Version: migrateLfsObjectsVersion, - Migrate: func(ctx context.Context, tx *db.Tx) error { + Migrate: func(ctx context.Context, h db.Handler) error { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("migrate_lfs_objects") var repoIds []int64 - if err := tx.Select(&repoIds, "SELECT id FROM repos"); err != nil { + if err := h.Select(&repoIds, "SELECT id FROM repos"); err != nil { return err } for _, r := range repoIds { var objs []models.LFSObject - if err := tx.Select(&objs, "SELECT * FROM lfs_objects WHERE repo_id = ?", r); err != nil { + if err := h.Select(&objs, "SELECT * FROM lfs_objects WHERE repo_id = ?", r); err != nil { return err } objsp := filepath.Join(cfg.DataPath, "lfs", strconv.FormatInt(r, 10), "objects") @@ -50,7 +50,7 @@ var migrateLfsObjects = Migration{ } return nil }, - Rollback: func(ctx context.Context, tx *db.Tx) error { + Rollback: func(ctx context.Context, h db.Handler) error { return nil }, } diff --git a/pkg/db/migrate/0004_create_orgs_teams.go b/pkg/db/migrate/0004_create_orgs_teams.go new file mode 100644 index 000000000..738e2b561 --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams.go @@ -0,0 +1,46 @@ +package migrate + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/pkg/db" +) + +const ( + createOrgsTeamsName = "create_orgs_teams" + createOrgsTeamsVersion = 4 +) + +var createOrgsTeams = Migration{ + Name: createOrgsTeamsName, + Version: createOrgsTeamsVersion, + PreMigrate: func(ctx context.Context, h db.Handler) error { + if strings.HasPrefix(h.DriverName(), "sqlite") { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = OFF;"); err != nil { + return err + } + if _, err := h.ExecContext(ctx, "PRAGMA legacy_alter_table = ON;"); err != nil { + return err + } + } + return nil + }, + PostMigrate: func(ctx context.Context, h db.Handler) error { + if strings.HasPrefix(h.DriverName(), "sqlite") { + if _, err := h.ExecContext(ctx, "PRAGMA foreign_keys = ON;"); err != nil { + return err + } + if _, err := h.ExecContext(ctx, "PRAGMA legacy_alter_table = OFF;"); err != nil { + return err + } + } + return nil + }, + Migrate: func(ctx context.Context, h db.Handler) error { + return migrateUp(ctx, h, createOrgsTeamsVersion, createOrgsTeamsName) + }, + Rollback: func(ctx context.Context, h db.Handler) error { + return migrateDown(ctx, h, createOrgsTeamsVersion, createOrgsTeamsName) + }, +} diff --git a/pkg/db/migrate/0004_create_orgs_teams_postgres.down.sql b/pkg/db/migrate/0004_create_orgs_teams_postgres.down.sql new file mode 100644 index 000000000..e69de29bb diff --git a/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql new file mode 100644 index 000000000..74a10ffef --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql @@ -0,0 +1,132 @@ +CREATE TABLE IF NOT EXISTS handles ( + id SERIAL PRIMARY KEY, + handle TEXT NOT NULL UNIQUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS organizations ( + id SERIAL PRIMARY KEY, + name TEXT, + contact_email TEXT NOT NULL, + handle_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS organization_members ( + id SERIAL PRIMARY KEY, + org_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (org_id, user_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS teams ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + org_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (name, org_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS team_members ( + id SERIAL PRIMARY KEY, + team_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (team_id, user_id), + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS user_emails ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL, + email TEXT NOT NULL UNIQUE, + is_primary BOOLEAN NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- Add name to users table +ALTER TABLE users ADD COLUMN name TEXT; + +-- Add handle_id to users table +ALTER TABLE users ADD COLUMN handle_id INTEGER; +ALTER TABLE users ADD CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Migrate user username to handles +INSERT INTO handles (handle, updated_at) SELECT username, updated_at FROM users; + +-- Update handle_id for users +UPDATE users SET handle_id = handles.id FROM handles WHERE handles.handle = users.username; + +-- Make handle_id not null and unique +ALTER TABLE users ALTER COLUMN handle_id SET NOT NULL; +ALTER TABLE users ADD CONSTRAINT handle_id_unique UNIQUE (handle_id); + +-- Drop username from users +ALTER TABLE users DROP COLUMN username; + +-- Add org_id to repos table +ALTER TABLE repos ADD COLUMN org_id INTEGER; +ALTER TABLE repos ADD CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Alter user_id nullness in repos table +ALTER TABLE repos ALTER COLUMN user_id DROP NOT NULL; + +-- Check that both user_id and org_id can't be null +ALTER TABLE repos ADD CONSTRAINT user_id_org_id_not_null CHECK (user_id IS NULL <> org_id IS NULL); + +-- Add team_id to collabs table +ALTER TABLE collabs ADD COLUMN team_id INTEGER; +ALTER TABLE collabs ADD CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE; + +-- Alter user_id nullness in collabs table +ALTER TABLE collabs ALTER COLUMN user_id DROP NOT NULL; + +-- Check that both user_id and team_id can't be null +ALTER TABLE collabs ADD CONSTRAINT user_id_team_id_not_null CHECK (user_id IS NULL <> team_id IS NULL); + +-- Alter unique constraint on collabs table +ALTER TABLE collabs DROP CONSTRAINT collabs_user_id_repo_id_key; +ALTER TABLE collabs ADD CONSTRAINT collabs_user_id_repo_id_team_id_key UNIQUE (user_id, repo_id, team_id); diff --git a/pkg/db/migrate/0004_create_orgs_teams_sqlite.down.sql b/pkg/db/migrate/0004_create_orgs_teams_sqlite.down.sql new file mode 100644 index 000000000..e69de29bb diff --git a/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql new file mode 100644 index 000000000..d9a41eb0d --- /dev/null +++ b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql @@ -0,0 +1,178 @@ +CREATE TABLE IF NOT EXISTS handles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + handle TEXT NOT NULL UNIQUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS organizations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + contact_email TEXT NOT NULL, + handle_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS organization_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (org_id, user_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS teams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + org_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (name, org_id), + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS team_members ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + team_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (team_id, user_id), + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS user_emails ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + email TEXT NOT NULL UNIQUE, + is_primary BOOLEAN NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +ALTER TABLE users RENAME TO _users_old; + +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + handle_id INTEGER NOT NULL UNIQUE, + admin BOOLEAN NOT NULL, + password TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + CONSTRAINT handle_id_fk + FOREIGN KEY(handle_id) REFERENCES handles(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +-- Migrate user username to handles +INSERT INTO handles (handle, updated_at) SELECT username, updated_at FROM _users_old; + +-- Migrate users +INSERT INTO users (id, handle_id, admin, password, created_at, updated_at) SELECT id, ( + SELECT id FROM handles WHERE handle = _users_old.username +), admin, password, created_at, updated_at FROM _users_old; + +-- Drop old table +DROP TABLE _users_old; + +ALTER TABLE repos RENAME TO _repos_old; + +CREATE TABLE IF NOT EXISTS repos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + project_name TEXT NOT NULL, + description TEXT NOT NULL, + private BOOLEAN NOT NULL, + mirror BOOLEAN NOT NULL, + hidden BOOLEAN NOT NULL, + user_id INTEGER, + org_id INTEGER, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT org_id_fk + FOREIGN KEY(org_id) REFERENCES organizations(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_org_id_not_null + CHECK (user_id IS NULL <> org_id IS NULL) +); + +-- Migrate repos +INSERT INTO repos (id, name, project_name, description, private, mirror, hidden, user_id, created_at, updated_at) +SELECT id, name, project_name, description, private, mirror, hidden, user_id, created_at, updated_at +FROM _repos_old; + +-- Drop old table +DROP TABLE _repos_old; + +-- Alter collabs table +ALTER TABLE collabs RENAME TO _collabs_old; + +CREATE TABLE IF NOT EXISTS collabs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER, + team_id INTEGER, + repo_id INTEGER NOT NULL, + access_level INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, team_id, repo_id), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT team_id_fk + FOREIGN KEY(team_id) REFERENCES teams(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT user_id_team_id_not_null + CHECK (user_id IS NULL <> team_id IS NULL) +); + +-- Migrate collabs +INSERT INTO collabs (id, user_id, team_id, repo_id, access_level, created_at, updated_at) +SELECT id, user_id, NULL, repo_id, access_level, created_at, updated_at +FROM _collabs_old; + +-- Drop old table +DROP TABLE _collabs_old; diff --git a/pkg/db/migrate/migrate.go b/pkg/db/migrate/migrate.go index 18bc1780d..276aac132 100644 --- a/pkg/db/migrate/migrate.go +++ b/pkg/db/migrate/migrate.go @@ -11,15 +11,17 @@ import ( ) // MigrateFunc is a function that executes a migration. -type MigrateFunc func(ctx context.Context, tx *db.Tx) error // nolint:revive +type MigrateFunc func(ctx context.Context, h db.Handler) error // nolint:revive // Migration is a struct that contains the name of the migration and the // function to execute it. type Migration struct { - Version int64 - Name string - Migrate MigrateFunc - Rollback MigrateFunc + Version int64 + Name string + PreMigrate MigrateFunc + Migrate MigrateFunc + PostMigrate MigrateFunc + Rollback MigrateFunc } // Migrations is a database model to store migrations. @@ -62,37 +64,50 @@ func (Migrations) schema(driverName string) string { // Migrate runs the migrations. func Migrate(ctx context.Context, dbx *db.DB) error { logger := log.FromContext(ctx).WithPrefix("migrate") - return dbx.TransactionContext(ctx, func(tx *db.Tx) error { - if !hasTable(tx, "migrations") { - if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil { - return err - } + + if !hasTable(dbx, "migrations") { + if _, err := dbx.Exec(Migrations{}.schema(dbx.DriverName())); err != nil { + return err } + } - var migrs Migrations - if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { - if !errors.Is(err, sql.ErrNoRows) { - return err - } + var migrs Migrations + if err := dbx.Get(&migrs, dbx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err } + } - for _, m := range migrations { - if m.Version <= migrs.Version { - continue - } + for _, m := range migrations { + if m.Version <= migrs.Version { + continue + } - logger.Infof("running migration %d. %s", m.Version, m.Name) - if err := m.Migrate(ctx, tx); err != nil { + logger.Infof("running migration %d. %s", m.Version, m.Name) + if m.PreMigrate != nil { + if err := m.PreMigrate(ctx, dbx); err != nil { return err } + } + + if err := dbx.TransactionContext(ctx, func(tx *db.Tx) error { + return m.Migrate(ctx, tx) + }); err != nil { + return err + } - if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + if m.PostMigrate != nil { + if err := m.PostMigrate(ctx, dbx); err != nil { return err } } - return nil - }) + if _, err := dbx.Exec(dbx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + return err + } + } + + return nil } // Rollback rolls back a migration. @@ -124,7 +139,7 @@ func Rollback(ctx context.Context, dbx *db.DB) error { }) } -func hasTable(tx *db.Tx, tableName string) bool { +func hasTable(tx db.Handler, tableName string) bool { var query string switch tx.DriverName() { case "sqlite3", "sqlite": diff --git a/pkg/db/migrate/migrations.go b/pkg/db/migrate/migrations.go index e2598b414..ff0ea08e8 100644 --- a/pkg/db/migrate/migrations.go +++ b/pkg/db/migrate/migrations.go @@ -18,15 +18,16 @@ var migrations = []Migration{ createTables, webhooks, migrateLfsObjects, + createOrgsTeams, } -func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { +func execMigration(ctx context.Context, h db.Handler, version int, name string, down bool) error { direction := "up" if down { direction = "down" } - driverName := tx.DriverName() + driverName := h.DriverName() if driverName == "sqlite3" { driverName = "sqlite" } @@ -37,19 +38,19 @@ func execMigration(ctx context.Context, tx *db.Tx, version int, name string, dow return err } - if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil { + if _, err := h.ExecContext(ctx, string(sqlstr)); err != nil { return err } return nil } -func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error { - return execMigration(ctx, tx, version, name, false) +func migrateUp(ctx context.Context, h db.Handler, version int, name string) error { + return execMigration(ctx, h, version, name, false) } -func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error { - return execMigration(ctx, tx, version, name, true) +func migrateDown(ctx context.Context, h db.Handler, version int, name string) error { + return execMigration(ctx, h, version, name, true) } var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") diff --git a/pkg/db/models/collab.go b/pkg/db/models/collab.go index 5141e5986..149143152 100644 --- a/pkg/db/models/collab.go +++ b/pkg/db/models/collab.go @@ -1,6 +1,7 @@ package models import ( + "database/sql" "time" "github.com/charmbracelet/soft-serve/pkg/access" @@ -10,7 +11,8 @@ import ( type Collab struct { ID int64 `db:"id"` RepoID int64 `db:"repo_id"` - UserID int64 `db:"user_id"` + UserID sql.NullInt64 `db:"user_id"` + TeamID sql.NullInt64 `db:"team_id"` AccessLevel access.AccessLevel `db:"access_level"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` diff --git a/pkg/db/models/handle.go b/pkg/db/models/handle.go new file mode 100644 index 000000000..99b532dc7 --- /dev/null +++ b/pkg/db/models/handle.go @@ -0,0 +1,11 @@ +package models + +import "time" + +// Handle represents a name handle. +type Handle struct { + ID int64 `db:"id"` + Handle string `db:"handle"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/org.go b/pkg/db/models/org.go new file mode 100644 index 000000000..596aa13cc --- /dev/null +++ b/pkg/db/models/org.go @@ -0,0 +1,25 @@ +package models + +import ( + "database/sql" + "time" +) + +// Organization represents an organization in the system. +type Organization struct { + ID int64 `db:"id"` + Name sql.NullString `db:"name"` + ContactEmail string `db:"contact_email"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + Handle Handle `db:"handle"` +} + +// OrganizationMember represents a member of an organization. +type OrganizationMember struct { + ID int64 `db:"id"` + OrganizationID int64 `db:"org_id"` + UserID int64 `db:"user_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/repo.go b/pkg/db/models/repo.go index 88300bd3e..e58803931 100644 --- a/pkg/db/models/repo.go +++ b/pkg/db/models/repo.go @@ -15,6 +15,7 @@ type Repo struct { Mirror bool `db:"mirror"` Hidden bool `db:"hidden"` UserID sql.NullInt64 `db:"user_id"` + OrgID sql.NullInt64 `db:"org_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } diff --git a/pkg/db/models/team.go b/pkg/db/models/team.go new file mode 100644 index 000000000..de04b7a8d --- /dev/null +++ b/pkg/db/models/team.go @@ -0,0 +1,23 @@ +package models + +import ( + "time" +) + +// Team represents a team in an organization. +type Team struct { + ID int64 `db:"id"` + Name string `db:"name"` + OrganizationID int64 `db:"org_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +// TeamMember represents a member of a team. +type TeamMember struct { + ID int64 `db:"id"` + TeamID int64 `db:"team_id"` + UserID int64 `db:"user_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/db/models/user.go b/pkg/db/models/user.go index 5ca0d3d9f..3d107a2eb 100644 --- a/pkg/db/models/user.go +++ b/pkg/db/models/user.go @@ -8,9 +8,20 @@ import ( // User represents a user. type User struct { ID int64 `db:"id"` - Username string `db:"username"` + Name sql.NullString `db:"name"` Admin bool `db:"admin"` Password sql.NullString `db:"password"` + HandleID int64 `db:"handle_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } + +// UserEmail represents a user's email address. +type UserEmail struct { + ID int64 `db:"id"` + UserID int64 `db:"user_id"` + Email string `db:"email"` + IsPrimary bool `db:"is_primary"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/pkg/git/lfs.go b/pkg/git/lfs.go index 4b0065f3e..67adc5c3b 100644 --- a/pkg/git/lfs.go +++ b/pkg/git/lfs.go @@ -255,6 +255,11 @@ func (l *lfsLockBackend) Create(path string, refname string) (transfer.Lock, err } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { // Return conflict (409) if the lock already exists. @@ -286,6 +291,11 @@ func (l *lfsLockBackend) FromID(id string) (transfer.Lock, error) { } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { if errors.Is(err, db.ErrRecordNotFound) { @@ -312,6 +322,11 @@ func (l *lfsLockBackend) FromPath(path string) (transfer.Lock, error) { } lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + if err != nil { + return db.WrapError(err) + } + + lock.handle, err = l.store.GetHandleByUserID(l.ctx, tx, lock.owner.ID) return db.WrapError(err) }); err != nil { if errors.Is(err, db.ErrRecordNotFound) { @@ -410,6 +425,7 @@ func (l *lfsLockBackend) Unlock(lock transfer.Lock) error { type LFSLock struct { lock models.LFSLock owner models.User + handle models.Handle backend *lfsLockBackend } @@ -459,7 +475,7 @@ func (l *LFSLock) ID() string { // OwnerName implements transfer.Lock. func (l *LFSLock) OwnerName() string { - return l.owner.Username + return l.handle.Handle } // Path implements transfer.Lock. diff --git a/pkg/proto/org.go b/pkg/proto/org.go new file mode 100644 index 000000000..2809fd8af --- /dev/null +++ b/pkg/proto/org.go @@ -0,0 +1,11 @@ +package proto + +// Org is an interface representing a organization. +type Org interface { + // ID returns the user's ID. + ID() int64 + // Name returns the org's name. + Name() string + // DisplayName + DisplayName() string +} diff --git a/pkg/proto/team.go b/pkg/proto/team.go new file mode 100644 index 000000000..516da581f --- /dev/null +++ b/pkg/proto/team.go @@ -0,0 +1,11 @@ +package proto + +// Team is an interface representing a team. +type Team interface { + // ID returns the user's ID. + ID() int64 + // Name returns the org's name. + Name() string + // Parent organization's ID. + Org() int64 +} diff --git a/pkg/ssh/cmd/org.go b/pkg/ssh/cmd/org.go new file mode 100644 index 000000000..251d92c3f --- /dev/null +++ b/pkg/ssh/cmd/org.go @@ -0,0 +1,90 @@ +package cmd + +import ( + "github.com/charmbracelet/soft-serve/pkg/backend" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/spf13/cobra" +) + +// OrgCommand returns a command for managing organizations. +func OrgCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "org", + Aliases: []string{"orgs", "organization", "organizations"}, + Short: "Manage organizations", + } + + cmd.AddCommand(&cobra.Command{ + Use: "create NAME EMAIL", + Short: "Create a new organization", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + owner := proto.UserFromContext(ctx) + if owner == nil { + return proto.ErrUnauthorized + } + _, err := be.CreateOrg(ctx, owner, args[0], args[1]) + return err + }, + }) + cmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List organizations", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + orgs, err := be.ListOrgs(ctx, user) + if err != nil { + return err + } + for _, o := range orgs { + cmd.Println(o.Name()) + } + return nil + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "delete NAME", + Short: "Delete organization", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + return be.DeleteOrganization(ctx, user, args[0]) + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "get NAME", + Short: "Show organization", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + cmd.Println(org.Name()) + return nil + }, + }) + + return cmd +} diff --git a/pkg/ssh/cmd/team.go b/pkg/ssh/cmd/team.go new file mode 100644 index 000000000..c0d1c2919 --- /dev/null +++ b/pkg/ssh/cmd/team.go @@ -0,0 +1,121 @@ +package cmd + +import ( + "github.com/charmbracelet/soft-serve/pkg/backend" + "github.com/charmbracelet/soft-serve/pkg/proto" + "github.com/spf13/cobra" +) + +// TeamCommand returns a command for managing teams. +func TeamCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "team", + Aliases: []string{"teams"}, + Short: "Manage teams", + } + + cmd.AddCommand(&cobra.Command{ + Use: "create ORG NAME", + Short: "Create a new team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.CreateTeam(ctx, org, user, args[1]) + if err != nil { + return err + } + + cmd.Println("Created", team.Name()) + + return err + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List teams", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + teams, err := be.ListTeams(ctx, user) + if err != nil { + return err + } + for _, o := range teams { + cmd.Println(o.Name()) + } + return nil + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "delete ORG NAME", + Short: "Delete team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.GetTeam(ctx, user, org, args[1]) + if err != nil { + return err + } + + return be.DeleteTeam(ctx, user, team) + }, + }) + + cmd.AddCommand(&cobra.Command{ + Use: "get ORG NAME", + Short: "Show team", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + user := proto.UserFromContext(ctx) + if user == nil { + return proto.ErrUnauthorized + } + + org, err := be.FindOrganization(ctx, user, args[0]) + if err != nil { + return err + } + + team, err := be.GetTeam(ctx, user, org, args[1]) + if err != nil { + return err + } + + cmd.Println(org.Name(), "/", team.Name()) + return nil + }, + }) + + return cmd +} diff --git a/pkg/ssh/middleware.go b/pkg/ssh/middleware.go index 285a980a7..446d9d5b1 100644 --- a/pkg/ssh/middleware.go +++ b/pkg/ssh/middleware.go @@ -104,6 +104,7 @@ func CommandMiddleware(sh ssh.Handler) ssh.Handler { cmd.RepoCommand(), cmd.SettingsCommand(), cmd.UserCommand(), + cmd.OrgCommand(), cmd.InfoCommand(), cmd.PubkeyCommand(), cmd.SetUsernameCommand(), diff --git a/pkg/store/database/collab.go b/pkg/store/database/collab.go index 9068edef2..ede99515c 100644 --- a/pkg/store/database/collab.go +++ b/pkg/store/database/collab.go @@ -18,7 +18,7 @@ var _ store.CollaboratorStore = (*collabStore)(nil) // AddCollabByUsernameAndRepo implements store.CollaboratorStore. func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string, level access.AccessLevel) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -28,7 +28,9 @@ func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx db.Handle VALUES ( ?, ( - SELECT id FROM users WHERE username = ? + SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + ) ), ( SELECT id FROM repos WHERE name = ? @@ -44,7 +46,7 @@ func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx db.Handle var m models.Collab username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return models.Collab{}, err } @@ -56,9 +58,10 @@ func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx db.Handle FROM collabs INNER JOIN users ON users.id = collabs.user_id + INNER JOIN handles ON handles.id = users.handle_id INNER JOIN repos ON repos.id = collabs.repo_id WHERE - users.username = ? AND repos.name = ? + handles.handle = ? AND repos.name = ? `), username, repo) return m, err @@ -106,7 +109,7 @@ func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx db.Handler, // RemoveCollabByUsernameAndRepo implements store.CollaboratorStore. func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } @@ -116,7 +119,9 @@ func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx db.Han collabs WHERE user_id = ( - SELECT id FROM users WHERE username = ? + SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + ) ) AND repo_id = ( SELECT id FROM repos WHERE name = ? ) diff --git a/pkg/store/database/database.go b/pkg/store/database/database.go index 02c4bd19c..de03192a7 100644 --- a/pkg/store/database/database.go +++ b/pkg/store/database/database.go @@ -18,16 +18,20 @@ type datastore struct { *settingsStore *repoStore *userStore + *orgStore + *teamStore *collabStore *lfsStore *accessTokenStore *webhookStore + *handleStore } // New returns a new store.Store database. func New(ctx context.Context, db *db.DB) store.Store { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("store") + handles := &handleStore{} s := &datastore{ ctx: ctx, @@ -37,10 +41,14 @@ func New(ctx context.Context, db *db.DB) store.Store { settingsStore: &settingsStore{}, repoStore: &repoStore{}, - userStore: &userStore{}, + userStore: &userStore{handles}, + orgStore: &orgStore{handles}, + teamStore: &teamStore{}, collabStore: &collabStore{}, lfsStore: &lfsStore{}, accessTokenStore: &accessTokenStore{}, + webhookStore: &webhookStore{}, + handleStore: handles, } return s diff --git a/pkg/store/database/handle.go b/pkg/store/database/handle.go new file mode 100644 index 000000000..a2b166e4e --- /dev/null +++ b/pkg/store/database/handle.go @@ -0,0 +1,88 @@ +package database + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" + "github.com/charmbracelet/soft-serve/pkg/utils" + "github.com/jmoiron/sqlx" +) + +type handleStore struct{} + +var _ store.HandleStore = &handleStore{} + +// CreateHandle implements store.HandleStore. +func (*handleStore) CreateHandle(ctx context.Context, h db.Handler, handle string) (int64, error) { + handle = strings.ToLower(handle) + if err := utils.ValidateHandle(handle); err != nil { + return 0, err + } + + var id int64 + query := h.Rebind("INSERT INTO handles (handle, updated_at) VALUES (?, CURRENT_TIMESTAMP) RETURNING id;") + err := h.GetContext(ctx, &id, query, handle) + return id, db.WrapError(err) +} + +// DeleteHandle implements store.HandleStore. +func (*handleStore) DeleteHandle(ctx context.Context, h db.Handler, id int64) error { + query := h.Rebind("DELETE FROM handles WHERE id = ?;") + _, err := h.ExecContext(ctx, query, id) + return db.WrapError(err) +} + +// GetHandleByHandle implements store.HandleStore. +func (*handleStore) GetHandleByHandle(ctx context.Context, h db.Handler, handle string) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE handle = ?;") + err := h.GetContext(ctx, &hl, query, handle) + return hl, db.WrapError(err) +} + +// GetHandleByID implements store.HandleStore. +func (*handleStore) GetHandleByID(ctx context.Context, h db.Handler, id int64) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE id = ?;") + err := h.GetContext(ctx, &hl, query, id) + return hl, db.WrapError(err) +} + +// GetHandleByUserID implements store.HandleStore. +func (*handleStore) GetHandleByUserID(ctx context.Context, h db.Handler, userID int64) (models.Handle, error) { + var hl models.Handle + query := h.Rebind("SELECT * FROM handles WHERE id = (SELECT handle_id FROM users WHERE id = ?);") + err := h.GetContext(ctx, &hl, query, userID) + return hl, db.WrapError(err) +} + +// UpdateHandle implements store.HandleStore. +func (*handleStore) UpdateHandle(ctx context.Context, h db.Handler, id int64, handle string) error { + handle = strings.ToLower(handle) + if err := utils.ValidateHandle(handle); err != nil { + return err + } + query := h.Rebind("UPDATE handles SET handle = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?;") + _, err := h.ExecContext(ctx, query, handle, id) + return db.WrapError(err) +} + +// ListHandlesForIDs implements store.HandleStore. +func (*handleStore) ListHandlesForIDs(ctx context.Context, h db.Handler, ids []int64) ([]models.Handle, error) { + var hls []models.Handle + if len(ids) == 0 { + return hls, nil + } + + query, args, err := sqlx.In("SELECT * FROM handles WHERE id IN (?)", ids) + if err != nil { + return nil, db.WrapError(err) + } + + query = h.Rebind(query) + err = h.SelectContext(ctx, &hls, query, args...) + return hls, db.WrapError(err) +} diff --git a/pkg/store/database/org.go b/pkg/store/database/org.go new file mode 100644 index 000000000..de6fd592a --- /dev/null +++ b/pkg/store/database/org.go @@ -0,0 +1,165 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" +) + +var _ store.OrgStore = (*orgStore)(nil) + +type orgStore struct{ *handleStore } + +// UpdateOrgContactEmail implements store.OrgStore. +func (*orgStore) UpdateOrgContactEmail(ctx context.Context, h db.Handler, org int64, email string) error { + query := h.Rebind(` + UPDATE organizations + SET + contact_email = ? + WHERE + id = ? + `) + + _, err := h.ExecContext(ctx, query, email, org) + return err +} + +// ListOrgs implements store.OrgStore. +func (*orgStore) ListOrgs(ctx context.Context, h db.Handler, uid int64) ([]models.Organization, error) { + var m []models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.org_id = o.id + WHERE + o.user_id = ? + `) + err := h.SelectContext(ctx, &m, query, uid) + return m, err +} + +// Delete implements store.OrgStore. +func (s *orgStore) DeleteOrgByID(ctx context.Context, h db.Handler, user, id int64) error { + _, err := s.getOrgByIDWithAccess(ctx, h, user, id, access.AdminAccess) + if err != nil { + return err + } + query := h.Rebind(`DELETE FROM organizations WHERE id = ?;`) + _, err = h.ExecContext(ctx, query, id) + return err +} + +// Create implements store.OrgStore. +func (s *orgStore) CreateOrg(ctx context.Context, h db.Handler, user int64, name, email string) (models.Organization, error) { + handle, err := s.CreateHandle(ctx, h, name) + if err != nil { + return models.Organization{}, err + } + + query := h.Rebind(` + INSERT INTO + organizations (handle_id, contact_email, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP) RETURNING id; + `) + + var id int64 + if err := h.GetContext(ctx, &id, query, handle, email); err != nil { + return models.Organization{}, err + } + if err := s.AddUserToOrg(ctx, h, id, user, access.AdminAccess); err != nil { + return models.Organization{}, err + } + + return s.GetOrgByID(ctx, h, user, id) +} + +func (*orgStore) UpdateUserAccessInOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error { + query := h.Rebind(` + UPDATE organization_members + WHERE + organization_id = ? + AND user_id = ? + SET + access_level = ? + `) + _, err := h.ExecContext(ctx, query, org, user, lvl) + return err +} + +func (*orgStore) RemoveUserFromOrg(ctx context.Context, h db.Handler, org, user int64) error { + query := h.Rebind(` + DELETE FROM organization_members + WHERE + organization_id = ? + AND user_id = ? + `) + _, err := h.ExecContext(ctx, query, org, user) + return err +} + +func (*orgStore) AddUserToOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error { + query := h.Rebind(` + INSERT INTO + organization_members ( + organization_id, + user_id, + access_level, + updated_at + ) + VALUES + (?, ?, ?, CURRENT_TIMESTAMP); + `) + _, err := h.ExecContext(ctx, query, org, user, lvl) + return err +} + +// FindByName implements store.OrgStore. +func (*orgStore) FindOrgByHandle(ctx context.Context, h db.Handler, user int64, name string) (models.Organization, error) { + var m models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.organization_id = o.id + WHERE + om.user_id = ? + AND h.handle = ?; + `) + err := h.GetContext(ctx, &m, query, user, name) + return m, err +} + +// GetByID implements store.OrgStore. +func (s *orgStore) GetOrgByID(ctx context.Context, h db.Handler, user, id int64) (models.Organization, error) { + return s.getOrgByIDWithAccess(ctx, h, user, id, access.ReadOnlyAccess) +} + +func (*orgStore) getOrgByIDWithAccess(ctx context.Context, h db.Handler, user, id int64, level access.AccessLevel) (models.Organization, error) { + var m models.Organization + query := h.Rebind(` + SELECT + o.*, + h AS handle + FROM + organizations o + JOIN handles h ON h.id = o.handle_id + JOIN organization_members om ON om.organization_id = o.id + WHERE + om.user_id = ? + AND id = ? + AND om.access_level >= ?; + `) + err := h.GetContext(ctx, &m, query, user, id, level) + return m, err +} diff --git a/pkg/store/database/team.go b/pkg/store/database/team.go new file mode 100644 index 000000000..bf103bf51 --- /dev/null +++ b/pkg/store/database/team.go @@ -0,0 +1,155 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" + "github.com/charmbracelet/soft-serve/pkg/store" +) + +// TODO: should we return all the org's teams if the user is an org admin? +// if so, need to join organization_members too on the selects below. + +var _ store.TeamStore = (*teamStore)(nil) + +type teamStore struct{} + +// RemoveUserFromTeam implements store.TeamStore. +func (*teamStore) RemoveUserFromTeam(ctx context.Context, h db.Handler, team int64, user int64) error { + // TODO: caller perms + query := h.Rebind(` + DELETE FROM team_members + WHERE + team_id = ? + AND user_id = ? + `) + _, err := h.ExecContext(ctx, query, team, user) + return err +} + +// UpdateUserAccessInTeam implements store.TeamStore. +func (*teamStore) UpdateUserAccessInTeam(ctx context.Context, h db.Handler, team int64, user int64, lvl access.AccessLevel) error { + // TODO: caller perms + query := h.Rebind(` + UPDATE team_members + WHERE + team_id = ? + AND user_id = ? + SET + access_level = ? + `) + _, err := h.ExecContext(ctx, query, team, user, lvl) + return err +} + +// AddUserToTeam implements store.TeamStore. +func (*teamStore) AddUserToTeam(ctx context.Context, h db.Handler, team int64, user int64, lvl access.AccessLevel) error { + // TODO: caller perms + query := h.Rebind(` + INSERT INTO + team_members (team_id, user_id, access_level, updated_at) + VALUES + (?, ?, ?, CURRENT_TIMESTAMP); + `) + _, err := h.ExecContext(ctx, query, team, user, lvl) + return err +} + +// CreateTeam implements store.TeamStore. +func (s *teamStore) CreateTeam(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) { + // TODO: caller perms + // TODO: what the access_level column does on team? + query := h.Rebind(` + INSERT INTO + teams (organization_id, name) + VALUES + (?, ?) RETURNING * + `) + var team models.Team + if err := h.GetContext(ctx, &team, query, org, name); err != nil { + return models.Team{}, err + } + return team, s.AddUserToTeam(ctx, h, team.ID, user, access.AdminAccess) +} + +// DeleteTeamByID implements store.TeamStore. +func (*teamStore) DeleteTeamByID(ctx context.Context, h db.Handler, id int64) error { + // TODO: caller perms + query := h.Rebind(` + DELETE FROM teams + WHERE + id = ? + `) + _, err := h.ExecContext(ctx, query, id) + return err +} + +// FindTeamByName implements store.TeamStore. +func (*teamStore) FindTeamByName(ctx context.Context, h db.Handler, uid int64, name string) ([]models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.name = ? + `) + var teams []models.Team + err := h.SelectContext(ctx, &teams, query, uid, name) + return teams, err +} + +// FindTeamByOrgName implements store.TeamStore. +func (*teamStore) FindTeamByOrgName(ctx context.Context, h db.Handler, user int64, org int64, name string) (models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.organization_id = ? + AND t.name = ? + `) + var team models.Team + err := h.GetContext(ctx, &team, query, user, org, name) + return team, err +} + +// GetTeamByID implements store.TeamStore. +func (*teamStore) GetTeamByID(ctx context.Context, h db.Handler, uid, id int64) (models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + AND t.id = ? + `) + var team models.Team + err := h.GetContext(ctx, &team, query, uid, id) + return team, err +} + +// ListTeams implements store.TeamStore. +func (*teamStore) ListTeams(ctx context.Context, h db.Handler, uid int64) ([]models.Team, error) { + query := h.Rebind(` + SELECT + t.* + FROM + teams t + JOIN team_members tm ON tm.team_id = t.id + WHERE + tm.user_id = ? + `) + var teams []models.Team + err := h.SelectContext(ctx, &teams, query, uid) + return teams, err +} diff --git a/pkg/store/database/user.go b/pkg/store/database/user.go index 86b161f47..b815d546a 100644 --- a/pkg/store/database/user.go +++ b/pkg/store/database/user.go @@ -12,19 +12,21 @@ import ( "golang.org/x/crypto/ssh" ) -type userStore struct{} +type userStore struct{ *handleStore } var _ store.UserStore = (*userStore)(nil) // AddPublicKeyByUsername implements store.UserStore. func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } var userID int64 - if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil { + if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users + INNER JOIN handles ON handles.id = users.handle_id + WHERE handles.handle = ?;`), username); err != nil { return err } @@ -37,23 +39,31 @@ func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, use } // CreateUser implements store.UserStore. -func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { +func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error { + handleID, err := s.CreateHandle(ctx, tx, username) + if err != nil { return err } - query := tx.Rebind(`INSERT INTO users (username, admin, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP) RETURNING id;`) + query := tx.Rebind(` + INSERT INTO + users (handle_id, admin, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP) RETURNING id; + `) var userID int64 - if err := tx.GetContext(ctx, &userID, query, username, isAdmin); err != nil { + if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil { return err } for _, pk := range pks { - query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP);`) + query := tx.Rebind(` + INSERT INTO + public_keys (user_id, public_key, updated_at) + VALUES + (?, ?, CURRENT_TIMESTAMP); + `) ak := sshutils.MarshalAuthorizedKey(pk) _, err := tx.ExecContext(ctx, query, userID, ak) if err != nil { @@ -67,11 +77,11 @@ func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string // DeleteUserByUsername implements store.UserStore. func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`DELETE FROM users WHERE username = ?;`) + query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) _, err := tx.ExecContext(ctx, query, username) return err } @@ -98,12 +108,12 @@ func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh // FindUserByUsername implements store.UserStore. func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return models.User{}, err } var m models.User - query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`) + query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) err := tx.GetContext(ctx, &m, query, username) return m, err } @@ -153,14 +163,14 @@ func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id // ListPublicKeysByUsername implements store.UserStore. func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return nil, err } var aks []string query := tx.Rebind(`SELECT public_key FROM public_keys INNER JOIN users ON users.id = public_keys.user_id - WHERE users.username = ? + WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?) ORDER BY public_keys.id ASC;`) err := tx.SelectContext(ctx, &aks, query, username) if err != nil { @@ -182,12 +192,14 @@ func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, u // RemovePublicKeyByUsername implements store.UserStore. func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } query := tx.Rebind(`DELETE FROM public_keys - WHERE user_id = (SELECT id FROM users WHERE username = ?) + WHERE user_id = (SELECT id FROM users WHERE handle_id = ( + SELECT id FROM handles WHERE handle = ? + )) AND public_key = ?;`) _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk)) return err @@ -196,11 +208,11 @@ func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, // SetAdminByUsername implements store.UserStore. func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`) _, err := tx.ExecContext(ctx, query, isAdmin, username) return err } @@ -208,16 +220,16 @@ func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, usernam // SetUsernameByUsername implements store.UserStore. func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } newUsername = strings.ToLower(newUsername) - if err := utils.ValidateUsername(newUsername); err != nil { + if err := utils.ValidateHandle(newUsername); err != nil { return err } - query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`) _, err := tx.ExecContext(ctx, query, newUsername, username) return err } @@ -232,11 +244,41 @@ func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int // SetUserPasswordByUsername implements store.UserStore. func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error { username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { + if err := utils.ValidateHandle(username); err != nil { return err } - query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`) + query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`) _, err := tx.ExecContext(ctx, query, password, username) return err } + +// AddUserEmail implements store.UserStore. +func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error { + query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP);`) + _, err := tx.ExecContext(ctx, query, userID, email, isPrimary) + return err +} + +// ListUserEmails implements store.UserStore. +func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) { + var ms []models.UserEmail + query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`) + err := tx.SelectContext(ctx, &ms, query, userID) + return ms, err +} + +// UpdateUserEmail implements store.UserStore. +func (*userStore) UpdateUserEmail(ctx context.Context, tx db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error { + query := tx.Rebind(`UPDATE user_emails SET email = ?, is_primary = ?, updated_at = CURRENT_TIMESTAMP WHERE user_id = ? AND email = ?;`) + _, err := tx.ExecContext(ctx, query, newEmail, isPrimary, userID, oldEmail) + return err +} + +// DeleteUserEmail implements store.UserStore. +func (*userStore) DeleteUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { + query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ?;`) + _, err := tx.ExecContext(ctx, query, userID, email) + return err +} diff --git a/pkg/store/handle.go b/pkg/store/handle.go new file mode 100644 index 000000000..570a56022 --- /dev/null +++ b/pkg/store/handle.go @@ -0,0 +1,19 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// HandleStore is a store for username handles. +type HandleStore interface { + GetHandleByID(ctx context.Context, h db.Handler, id int64) (models.Handle, error) + GetHandleByHandle(ctx context.Context, h db.Handler, handle string) (models.Handle, error) + GetHandleByUserID(ctx context.Context, h db.Handler, userID int64) (models.Handle, error) + ListHandlesForIDs(ctx context.Context, h db.Handler, ids []int64) ([]models.Handle, error) + UpdateHandle(ctx context.Context, h db.Handler, id int64, handle string) error + CreateHandle(ctx context.Context, h db.Handler, handle string) (int64, error) + DeleteHandle(ctx context.Context, h db.Handler, id int64) error +} diff --git a/pkg/store/org.go b/pkg/store/org.go new file mode 100644 index 000000000..4657fb6f1 --- /dev/null +++ b/pkg/store/org.go @@ -0,0 +1,24 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// OrgStore is a store for organizations. +type OrgStore interface { + CreateOrg(ctx context.Context, h db.Handler, user int64, name, email string) (models.Organization, error) + ListOrgs(ctx context.Context, h db.Handler, user int64) ([]models.Organization, error) + GetOrgByID(ctx context.Context, h db.Handler, user, id int64) (models.Organization, error) + FindOrgByHandle(ctx context.Context, h db.Handler, user int64, name string) (models.Organization, error) + DeleteOrgByID(ctx context.Context, h db.Handler, user, id int64) error + AddUserToOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error + RemoveUserFromOrg(ctx context.Context, h db.Handler, org, user int64) error + UpdateUserAccessInOrg(ctx context.Context, h db.Handler, org, user int64, lvl access.AccessLevel) error + UpdateOrgContactEmail(ctx context.Context, h db.Handler, org int64, email string) error + // TODO: rename org? + // XXX: what else? +} diff --git a/pkg/store/store.go b/pkg/store/store.go index 41490cbf8..a002e5c38 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -4,9 +4,12 @@ package store type Store interface { RepositoryStore UserStore + OrgStore + TeamStore CollaboratorStore SettingStore LFSStore AccessTokenStore WebhookStore + HandleStore } diff --git a/pkg/store/team.go b/pkg/store/team.go new file mode 100644 index 000000000..026f87091 --- /dev/null +++ b/pkg/store/team.go @@ -0,0 +1,22 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/pkg/access" + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/models" +) + +// TeamStore is a store for teams. +type TeamStore interface { + CreateTeam(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) + ListTeams(ctx context.Context, h db.Handler, user int64) ([]models.Team, error) + GetTeamByID(ctx context.Context, h db.Handler, user, id int64) (models.Team, error) + FindTeamByOrgName(ctx context.Context, h db.Handler, user, org int64, name string) (models.Team, error) + FindTeamByName(ctx context.Context, h db.Handler, user int64, name string) ([]models.Team, error) + DeleteTeamByID(ctx context.Context, h db.Handler, id int64) error + AddUserToTeam(ctx context.Context, h db.Handler, team, user int64, lvl access.AccessLevel) error + RemoveUserFromTeam(ctx context.Context, h db.Handler, team, user int64) error + UpdateUserAccessInTeam(ctx context.Context, h db.Handler, team, user int64, lvl access.AccessLevel) error +} diff --git a/pkg/store/user.go b/pkg/store/user.go index 260a7279c..ac8df8ef6 100644 --- a/pkg/store/user.go +++ b/pkg/store/user.go @@ -25,4 +25,9 @@ type UserStore interface { ListPublicKeysByUsername(ctx context.Context, h db.Handler, username string) ([]ssh.PublicKey, error) SetUserPassword(ctx context.Context, h db.Handler, userID int64, password string) error SetUserPasswordByUsername(ctx context.Context, h db.Handler, username string, password string) error + + AddUserEmail(ctx context.Context, h db.Handler, userID int64, email string, isPrimary bool) error + ListUserEmails(ctx context.Context, h db.Handler, userID int64) ([]models.UserEmail, error) + UpdateUserEmail(ctx context.Context, h db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error + DeleteUserEmail(ctx context.Context, h db.Handler, userID int64, email string) error } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index a98cb139c..5fb46f8c4 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -17,19 +17,19 @@ func SanitizeRepo(repo string) string { return repo } -// ValidateUsername returns an error if any of the given usernames are invalid. -func ValidateUsername(username string) error { - if username == "" { - return fmt.Errorf("username cannot be empty") +// ValidateHandle returns an error if any of the given usernames are invalid. +func ValidateHandle(handle string) error { + if handle == "" { + return fmt.Errorf("cannot be empty") } - if !unicode.IsLetter(rune(username[0])) { - return fmt.Errorf("username must start with a letter") + if !unicode.IsLetter(rune(handle[0])) { + return fmt.Errorf("must start with a letter") } - for _, r := range username { + for _, r := range handle { if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '-' { - return fmt.Errorf("username can only contain letters, numbers, and hyphens") + return fmt.Errorf("can only contain letters, numbers, and hyphens") } } diff --git a/pkg/web/git_lfs.go b/pkg/web/git_lfs.go index dccd87f95..1fb1136b2 100644 --- a/pkg/web/git_lfs.go +++ b/pkg/web/git_lfs.go @@ -520,7 +520,15 @@ func serviceLfsLocksCreate(w http.ResponseWriter, r *http.Request) { }) return } - lockOwner.Name = owner.Username + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + lockOwner.Name = handle.Handle } errResp.Lock.Owner = lockOwner } @@ -632,6 +640,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + renderJSON(w, http.StatusOK, lfs.LockListResponse{ Locks: []lfs.Lock{ { @@ -639,7 +656,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, }, }, @@ -670,6 +687,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + renderJSON(w, http.StatusOK, lfs.LockListResponse{ Locks: []lfs.Lock{ { @@ -677,7 +703,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, }, }, @@ -695,11 +721,11 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { } lockList := make([]lfs.Lock, len(locks)) - users := map[int64]models.User{} + users := map[int64]userModel{} for i, lock := range locks { owner, ok := users[lock.UserID] if !ok { - owner, err = datastore.GetUserByID(ctx, dbx, lock.UserID) + user, err := datastore.GetUserByID(ctx, dbx, lock.UserID) if err != nil { logger.Error("error getting lock owner", "err", err) renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ @@ -707,7 +733,15 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { }) return } - users[lock.UserID] = owner + handle, err := datastore.GetHandleByUserID(ctx, dbx, user.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + users[lock.UserID] = userModel{User: user, Handle: handle} } lockList[i] = lfs.Lock{ @@ -715,7 +749,7 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: owner.Handle.Handle, }, } } @@ -730,6 +764,11 @@ func serviceLfsLocksGet(w http.ResponseWriter, r *http.Request) { renderJSON(w, http.StatusOK, resp) } +type userModel struct { + models.User + models.Handle +} + // POST: /.git/info/lfs/objects/locks/verify func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { if !isLfs(r) { @@ -786,11 +825,11 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { return } - users := map[int64]models.User{} + users := map[int64]userModel{} for _, lock := range locks { owner, ok := users[lock.UserID] if !ok { - owner, err = datastore.GetUserByID(ctx, dbx, lock.UserID) + user, err := datastore.GetUserByID(ctx, dbx, lock.UserID) if err != nil { logger.Error("error getting lock owner", "err", err) renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ @@ -798,7 +837,15 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { }) return } - users[lock.UserID] = owner + handle, err := datastore.GetHandleByUserID(ctx, dbx, user.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + users[lock.UserID] = userModel{User: user, Handle: handle} } l := lfs.Lock{ @@ -806,7 +853,7 @@ func serviceLfsLocksVerify(w http.ResponseWriter, r *http.Request) { Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: owner.Handle.Handle, }, } @@ -893,13 +940,22 @@ func serviceLfsLocksDelete(w http.ResponseWriter, r *http.Request) { return } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + logger.Error("error getting lock owner handle", "err", err) + renderJSON(w, http.StatusInternalServerError, lfs.ErrorResponse{ + Message: "internal server error", + }) + return + } + // Delete another user's lock l := lfs.Lock{ ID: strconv.FormatInt(lock.ID, 10), Path: lock.Path, LockedAt: lock.CreatedAt, Owner: lfs.Owner{ - Name: owner.Username, + Name: handle.Handle, }, } if req.Force { diff --git a/pkg/webhook/branch_tag.go b/pkg/webhook/branch_tag.go index 89771e9fe..6630f00ff 100644 --- a/pkg/webhook/branch_tag.go +++ b/pkg/webhook/branch_tag.go @@ -75,8 +75,13 @@ func NewBranchTagEvent(ctx context.Context, user proto.User, repo proto.Reposito return BranchTagEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return BranchTagEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, err = proto.RepositoryDefaultBranch(repo) if err != nil { return BranchTagEvent{}, err diff --git a/pkg/webhook/collaborator.go b/pkg/webhook/collaborator.go index e7b737b16..32f7e2f75 100644 --- a/pkg/webhook/collaborator.go +++ b/pkg/webhook/collaborator.go @@ -2,6 +2,7 @@ package webhook import ( "context" + "fmt" "github.com/charmbracelet/soft-serve/pkg/access" "github.com/charmbracelet/soft-serve/pkg/db" @@ -63,8 +64,13 @@ func NewCollaboratorEvent(ctx context.Context, user proto.User, repo proto.Repos return CollaboratorEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return CollaboratorEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, err = proto.RepositoryDefaultBranch(repo) if err != nil { return CollaboratorEvent{}, err @@ -76,7 +82,11 @@ func NewCollaboratorEvent(ctx context.Context, user proto.User, repo proto.Repos } payload.AccessLevel = collab.AccessLevel - payload.Collaborator.ID = collab.UserID + if !collab.UserID.Valid { + return CollaboratorEvent{}, fmt.Errorf("collaborator user ID is invalid") + } + + payload.Collaborator.ID = collab.UserID.Int64 payload.Collaborator.Username = collabUsername return payload, nil diff --git a/pkg/webhook/push.go b/pkg/webhook/push.go index 6ef6062cd..f225a00d1 100644 --- a/pkg/webhook/push.go +++ b/pkg/webhook/push.go @@ -66,8 +66,13 @@ func NewPushEvent(ctx context.Context, user proto.User, repo proto.Repository, r return PushEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return PushEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle // Find commits. r, err := repo.Open() diff --git a/pkg/webhook/repository.go b/pkg/webhook/repository.go index 3cad5c39f..3c0879944 100644 --- a/pkg/webhook/repository.go +++ b/pkg/webhook/repository.go @@ -74,8 +74,13 @@ func NewRepositoryEvent(ctx context.Context, user proto.User, repo proto.Reposit return RepositoryEvent{}, db.WrapError(err) } + handle, err := datastore.GetHandleByUserID(ctx, dbx, owner.ID) + if err != nil { + return RepositoryEvent{}, db.WrapError(err) + } + payload.Repository.Owner.ID = owner.ID - payload.Repository.Owner.Username = owner.Username + payload.Repository.Owner.Username = handle.Handle payload.Repository.DefaultBranch, _ = proto.RepositoryDefaultBranch(repo) return payload, nil diff --git a/testscript/testdata/help.txtar b/testscript/testdata/help.txtar index d6756ca02..bdc9ce046 100644 --- a/testscript/testdata/help.txtar +++ b/testscript/testdata/help.txtar @@ -23,6 +23,7 @@ Available Commands: help Help about any command info Show your info jwt Generate a JSON Web Token + org Manage organizations pubkey Manage your public keys repo Manage repositories set-username Set your username From 8fbd41a53198f872ac0d160eac15cee836ac030b Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Mon, 11 Dec 2023 13:47:28 -0500 Subject: [PATCH 2/2] feat: add user email support --- pkg/backend/user.go | 122 +++++++++++++++++- .../0004_create_orgs_teams_postgres.up.sql | 9 +- .../0004_create_orgs_teams_sqlite.up.sql | 5 +- pkg/proto/user.go | 17 +++ pkg/ssh/cmd/org.go | 2 +- pkg/ssh/cmd/user.go | 76 ++++++++++- pkg/store/database/org.go | 9 ++ pkg/store/database/user.go | 56 ++++++-- pkg/store/user.go | 6 +- pkg/utils/utils.go | 22 ++++ testscript/script_test.go | 2 + testscript/testdata/user_management.txtar | 42 +++++- 12 files changed, 345 insertions(+), 23 deletions(-) diff --git a/pkg/backend/user.go b/pkg/backend/user.go index bc651d2fe..1ef1dd405 100644 --- a/pkg/backend/user.go +++ b/pkg/backend/user.go @@ -26,6 +26,7 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) var m models.User var pks []ssh.PublicKey var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByUsername(ctx, tx, username) @@ -38,6 +39,15 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) return err } + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { @@ -53,6 +63,7 @@ func (d *Backend) User(ctx context.Context, username string) (proto.User, error) user: m, publicKeys: pks, handle: hl, + emails: ems, }, nil } @@ -61,6 +72,7 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { var m models.User var pks []ssh.PublicKey var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.GetUserByID(ctx, tx, id) @@ -73,6 +85,15 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { return err } + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { @@ -88,6 +109,7 @@ func (d *Backend) UserByID(ctx context.Context, id int64) (proto.User, error) { user: m, publicKeys: pks, handle: hl, + emails: ems, }, nil } @@ -98,6 +120,7 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. var m models.User var pks []ssh.PublicKey var hl models.Handle + var ems []proto.UserEmail if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { var err error m, err = d.store.FindUserByPublicKey(ctx, tx, pk) @@ -110,6 +133,15 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. return err } + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { @@ -125,6 +157,7 @@ func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto. user: m, publicKeys: pks, handle: hl, + emails: ems, }, nil } @@ -134,6 +167,7 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us var m models.User var pks []ssh.PublicKey var hl models.Handle + var ems []proto.UserEmail token = HashToken(token) if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { @@ -156,6 +190,15 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us return err } + emails, err := d.store.ListUserEmails(ctx, tx, m.ID) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + hl, err = d.store.GetHandleByUserID(ctx, tx, m.ID) return err }); err != nil { @@ -171,6 +214,7 @@ func (d *Backend) UserByAccessToken(ctx context.Context, token string) (proto.Us user: m, publicKeys: pks, handle: hl, + emails: ems, }, nil } @@ -228,7 +272,7 @@ func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.Publ // It implements backend.Backend. func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.UserOptions) (proto.User, error) { if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { - return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys) + return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys, opts.Emails) }); err != nil { return nil, db.WrapError(err) } @@ -335,10 +379,60 @@ func (d *Backend) SetPassword(ctx context.Context, username string, rawPassword ) } +// AddUserEmail adds an email to a user. +func (d *Backend) AddUserEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.AddUserEmail(ctx, tx, user.ID(), email, false) + }), + ) +} + +// ListUserEmails lists the emails of a user. +func (d *Backend) ListUserEmails(ctx context.Context, user proto.User) ([]proto.UserEmail, error) { + var ems []proto.UserEmail + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + emails, err := d.store.ListUserEmails(ctx, tx, user.ID()) + if err != nil { + return err + } + + for _, e := range emails { + ems = append(ems, &userEmail{e}) + } + + return nil + }); err != nil { + return nil, db.WrapError(err) + } + + return ems, nil +} + +// RemoveUserEmail deletes an email for a user. +// The deleted email must not be the primary email. +func (d *Backend) RemoveUserEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.RemoveUserEmail(ctx, tx, user.ID(), email) + }), + ) +} + +// SetUserPrimaryEmail sets the primary email of a user. +func (d *Backend) SetUserPrimaryEmail(ctx context.Context, user proto.User, email string) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetUserPrimaryEmail(ctx, tx, user.ID(), email) + }), + ) +} + type user struct { user models.User publicKeys []ssh.PublicKey handle models.Handle + emails []proto.UserEmail } var _ proto.User = (*user)(nil) @@ -371,3 +465,29 @@ func (u *user) Password() string { return "" } + +// Emails implements proto.User. +func (u *user) Emails() []proto.UserEmail { + return u.emails +} + +type userEmail struct { + email models.UserEmail +} + +var _ proto.UserEmail = (*userEmail)(nil) + +// Email implements proto.UserEmail. +func (e *userEmail) Email() string { + return e.email.Email +} + +// ID implements proto.UserEmail. +func (e *userEmail) ID() int64 { + return e.email.ID +} + +// IsPrimary implements proto.UserEmail. +func (e *userEmail) IsPrimary() bool { + return e.email.IsPrimary +} diff --git a/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql index 74a10ffef..ea52bd366 100644 --- a/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql +++ b/pkg/db/migrate/0004_create_orgs_teams_postgres.up.sql @@ -69,7 +69,7 @@ CREATE TABLE IF NOT EXISTS user_emails ( id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, email TEXT NOT NULL UNIQUE, - is_primary BOOLEAN NOT NULL, + is_primary BOOLEAN NOT NULL DEFAULT false, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL, CONSTRAINT user_id_fk @@ -78,6 +78,9 @@ CREATE TABLE IF NOT EXISTS user_emails ( ON UPDATE CASCADE ); +-- Create unique index for primary email +CREATE UNIQUE INDEX user_emails_user_id_is_primary_idx ON user_emails (user_id) WHERE is_primary; + -- Add name to users table ALTER TABLE users ADD COLUMN name TEXT; @@ -112,7 +115,7 @@ ALTER TABLE repos ADD CONSTRAINT org_id_fk ALTER TABLE repos ALTER COLUMN user_id DROP NOT NULL; -- Check that both user_id and org_id can't be null -ALTER TABLE repos ADD CONSTRAINT user_id_org_id_not_null CHECK (user_id IS NULL <> org_id IS NULL); +ALTER TABLE repos ADD CONSTRAINT user_id_org_id_not_null CHECK ((user_id IS NULL) <> (org_id IS NULL)); -- Add team_id to collabs table ALTER TABLE collabs ADD COLUMN team_id INTEGER; @@ -125,7 +128,7 @@ ALTER TABLE collabs ADD CONSTRAINT team_id_fk ALTER TABLE collabs ALTER COLUMN user_id DROP NOT NULL; -- Check that both user_id and team_id can't be null -ALTER TABLE collabs ADD CONSTRAINT user_id_team_id_not_null CHECK (user_id IS NULL <> team_id IS NULL); +ALTER TABLE collabs ADD CONSTRAINT user_id_team_id_not_null CHECK ((user_id IS NULL) <> (team_id IS NULL)); -- Alter unique constraint on collabs table ALTER TABLE collabs DROP CONSTRAINT collabs_user_id_repo_id_key; diff --git a/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql index d9a41eb0d..17e4a88ff 100644 --- a/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql +++ b/pkg/db/migrate/0004_create_orgs_teams_sqlite.up.sql @@ -71,7 +71,7 @@ CREATE TABLE IF NOT EXISTS user_emails ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, email TEXT NOT NULL UNIQUE, - is_primary BOOLEAN NOT NULL, + is_primary BOOLEAN NOT NULL DEFAULT false, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL, CONSTRAINT user_id_fk @@ -80,6 +80,9 @@ CREATE TABLE IF NOT EXISTS user_emails ( ON UPDATE CASCADE ); +-- Create unique index for primary email +CREATE UNIQUE INDEX user_emails_user_id_is_primary_idx ON user_emails (user_id) WHERE is_primary; + ALTER TABLE users RENAME TO _users_old; CREATE TABLE IF NOT EXISTS users ( diff --git a/pkg/proto/user.go b/pkg/proto/user.go index 7b334122d..c6c65b1bf 100644 --- a/pkg/proto/user.go +++ b/pkg/proto/user.go @@ -14,6 +14,8 @@ type User interface { PublicKeys() []ssh.PublicKey // Password returns the user's password hash. Password() string + // Emails returns the user's emails. + Emails() []UserEmail } // UserOptions are options for creating a user. @@ -22,4 +24,19 @@ type UserOptions struct { Admin bool // PublicKeys are the user's public keys. PublicKeys []ssh.PublicKey + // Emails are the user's emails. + // The first email in the slice will be set as the user's primary email. + Emails []string +} + +// UserEmail represents a user's email address. +type UserEmail interface { + // ID returns the email's ID. + ID() int64 + + // Email returns the email address. + Email() string + + // IsPrimary returns whether the email is the user's primary email. + IsPrimary() bool } diff --git a/pkg/ssh/cmd/org.go b/pkg/ssh/cmd/org.go index 251d92c3f..1d1a5fe15 100644 --- a/pkg/ssh/cmd/org.go +++ b/pkg/ssh/cmd/org.go @@ -33,7 +33,7 @@ func OrgCommand() *cobra.Command { Use: "list", Short: "List organizations", Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() be := backend.FromContext(ctx) user := proto.UserFromContext(ctx) diff --git a/pkg/ssh/cmd/user.go b/pkg/ssh/cmd/user.go index 21981f9cb..9944aa243 100644 --- a/pkg/ssh/cmd/user.go +++ b/pkg/ssh/cmd/user.go @@ -22,9 +22,9 @@ func UserCommand() *cobra.Command { var admin bool var key string userCreateCommand := &cobra.Command{ - Use: "create USERNAME", + Use: "create USERNAME [EMAIL]", Short: "Create a new user", - Args: cobra.ExactArgs(1), + Args: cobra.MinimumNArgs(1), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { var pubkeys []ssh.PublicKey @@ -45,6 +45,10 @@ func UserCommand() *cobra.Command { PublicKeys: pubkeys, } + if len(args) > 1 { + opts.Emails = append(opts.Emails, args[1]) + } + _, err := be.CreateUser(ctx, username, opts) return err }, @@ -166,6 +170,14 @@ func UserCommand() *cobra.Command { cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk)) } + emails := user.Emails() + if len(emails) > 0 { + cmd.Printf("Emails:\n") + for _, e := range emails { + cmd.Printf(" %s (primary: %v)\n", e.Email(), e.IsPrimary()) + } + } + return nil }, } @@ -185,6 +197,63 @@ func UserCommand() *cobra.Command { }, } + userAddEmailCommand := &cobra.Command{ + Use: "add-email USERNAME EMAIL", + Short: "Add an email to a user", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.AddUserEmail(ctx, u, email) + }, + } + + userRemoveEmailCommand := &cobra.Command{ + Use: "remove-email USERNAME EMAIL", + Short: "Remove an email from a user", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.RemoveUserEmail(ctx, u, email) + }, + } + + userSetPrimaryEmailCommand := &cobra.Command{ + Use: "set-primary-email USERNAME EMAIL", + Short: "Set a user's primary email", + Args: cobra.ExactArgs(2), + PersistentPreRunE: checkIfAdmin, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + username := args[0] + email := args[1] + u, err := be.User(ctx, username) + if err != nil { + return err + } + + return be.SetUserPrimaryEmail(ctx, u, email) + }, + } + cmd.AddCommand( userCreateCommand, userAddPubkeyCommand, @@ -194,6 +263,9 @@ func UserCommand() *cobra.Command { userRemovePubkeyCommand, userSetAdminCommand, userSetUsernameCommand, + userAddEmailCommand, + userRemoveEmailCommand, + userSetPrimaryEmailCommand, ) return cmd diff --git a/pkg/store/database/org.go b/pkg/store/database/org.go index de6fd592a..60571dc9d 100644 --- a/pkg/store/database/org.go +++ b/pkg/store/database/org.go @@ -7,6 +7,7 @@ import ( "github.com/charmbracelet/soft-serve/pkg/db" "github.com/charmbracelet/soft-serve/pkg/db/models" "github.com/charmbracelet/soft-serve/pkg/store" + "github.com/charmbracelet/soft-serve/pkg/utils" ) var _ store.OrgStore = (*orgStore)(nil) @@ -15,6 +16,10 @@ type orgStore struct{ *handleStore } // UpdateOrgContactEmail implements store.OrgStore. func (*orgStore) UpdateOrgContactEmail(ctx context.Context, h db.Handler, org int64, email string) error { + if err := utils.ValidateEmail(email); err != nil { + return err + } + query := h.Rebind(` UPDATE organizations SET @@ -58,6 +63,10 @@ func (s *orgStore) DeleteOrgByID(ctx context.Context, h db.Handler, user, id int // Create implements store.OrgStore. func (s *orgStore) CreateOrg(ctx context.Context, h db.Handler, user int64, name, email string) (models.Organization, error) { + if err := utils.ValidateEmail(email); err != nil { + return models.Organization{}, err + } + handle, err := s.CreateHandle(ctx, h, name) if err != nil { return models.Organization{}, err diff --git a/pkg/store/database/user.go b/pkg/store/database/user.go index b815d546a..f3ed59f39 100644 --- a/pkg/store/database/user.go +++ b/pkg/store/database/user.go @@ -2,6 +2,7 @@ package database import ( "context" + "fmt" "strings" "github.com/charmbracelet/soft-serve/pkg/db" @@ -39,7 +40,7 @@ func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, use } // CreateUser implements store.UserStore. -func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error { +func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error { handleID, err := s.CreateHandle(ctx, tx, username) if err != nil { return err @@ -71,6 +72,12 @@ func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username stri } } + for i, e := range emails { + if err := s.AddUserEmail(ctx, tx, userID, e, i == 0); err != nil { + return err + } + } + return nil } @@ -255,6 +262,9 @@ func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, // AddUserEmail implements store.UserStore. func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error { + if err := utils.ValidateEmail(email); err != nil { + return err + } query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP);`) _, err := tx.ExecContext(ctx, query, userID, email, isPrimary) @@ -269,16 +279,40 @@ func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int6 return ms, err } -// UpdateUserEmail implements store.UserStore. -func (*userStore) UpdateUserEmail(ctx context.Context, tx db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error { - query := tx.Rebind(`UPDATE user_emails SET email = ?, is_primary = ?, updated_at = CURRENT_TIMESTAMP WHERE user_id = ? AND email = ?;`) - _, err := tx.ExecContext(ctx, query, newEmail, isPrimary, userID, oldEmail) - return err +// RemoveUserEmail implements store.UserStore. +func (*userStore) RemoveUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { + var e models.UserEmail + query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ? RETURNING *;`) + if err := tx.GetContext(ctx, &e, query, userID, email); err != nil { + return err + } + + if e.IsPrimary { + return fmt.Errorf("cannot remove primary email") + } else if e.ID == 0 { + return db.ErrRecordNotFound + } + + return nil } -// DeleteUserEmail implements store.UserStore. -func (*userStore) DeleteUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { - query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ?;`) - _, err := tx.ExecContext(ctx, query, userID, email) - return err +// SetUserPrimaryEmail implements store.UserStore. +func (*userStore) SetUserPrimaryEmail(ctx context.Context, tx db.Handler, userID int64, email string) error { + query := tx.Rebind(`UPDATE user_emails SET is_primary = FALSE WHERE user_id = ?;`) + _, err := tx.ExecContext(ctx, query, userID) + if err != nil { + return err + } + + var emailID int64 + query = tx.Rebind(`UPDATE user_emails SET is_primary = TRUE WHERE user_id = ? AND email = ? RETURNING id;`) + if err := tx.GetContext(ctx, &emailID, query, userID, email); err != nil { + return err + } + + if emailID == 0 { + return db.ErrRecordNotFound + } + + return nil } diff --git a/pkg/store/user.go b/pkg/store/user.go index ac8df8ef6..6e6f1c2a5 100644 --- a/pkg/store/user.go +++ b/pkg/store/user.go @@ -15,7 +15,7 @@ type UserStore interface { FindUserByPublicKey(ctx context.Context, h db.Handler, pk ssh.PublicKey) (models.User, error) FindUserByAccessToken(ctx context.Context, h db.Handler, token string) (models.User, error) GetAllUsers(ctx context.Context, h db.Handler) ([]models.User, error) - CreateUser(ctx context.Context, h db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error + CreateUser(ctx context.Context, h db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error DeleteUserByUsername(ctx context.Context, h db.Handler, username string) error SetUsernameByUsername(ctx context.Context, h db.Handler, username string, newUsername string) error SetAdminByUsername(ctx context.Context, h db.Handler, username string, isAdmin bool) error @@ -28,6 +28,6 @@ type UserStore interface { AddUserEmail(ctx context.Context, h db.Handler, userID int64, email string, isPrimary bool) error ListUserEmails(ctx context.Context, h db.Handler, userID int64) ([]models.UserEmail, error) - UpdateUserEmail(ctx context.Context, h db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error - DeleteUserEmail(ctx context.Context, h db.Handler, userID int64, email string) error + RemoveUserEmail(ctx context.Context, h db.Handler, userID int64, email string) error + SetUserPrimaryEmail(ctx context.Context, h db.Handler, userID int64, email string) error } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 5fb46f8c4..0b6660e88 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,12 +1,20 @@ package utils import ( + "errors" "fmt" + "net/mail" "path" "strings" "unicode" ) +var ( + + // ErrInvalidEmail indicates that an email address is invalid. + ErrInvalidEmail = errors.New("invalid email address") +) + // SanitizeRepo returns a sanitized version of the given repository name. func SanitizeRepo(repo string) string { repo = strings.TrimPrefix(repo, "/") @@ -50,3 +58,17 @@ func ValidateRepo(repo string) error { return nil } + +// ValidateEmail returns an error if the given email address is invalid. +func ValidateEmail(email string) error { + if strings.ContainsAny(email, " <>") { + return ErrInvalidEmail + } + + _, err := mail.ParseAddress(email) + if err != nil { + return fmt.Errorf("%w: %s", ErrInvalidEmail, err) + } + + return nil +} diff --git a/testscript/script_test.go b/testscript/script_test.go index 32ebe77c1..cec209d7f 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -76,6 +76,7 @@ func TestScript(t *testing.T) { key, admin1 := mkkey("admin1") _, admin2 := mkkey("admin2") _, user1 := mkkey("user1") + _, user2 := mkkey("user2") testscript.Run(t, testscript.Params{ Dir: "./testdata/", @@ -117,6 +118,7 @@ func TestScript(t *testing.T) { e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey()) e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey()) e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey()) + e.Setenv("USER2_AUTHORIZED_KEY", user2.AuthorizedKey()) e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts")) e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config")) diff --git a/testscript/testdata/user_management.txtar b/testscript/testdata/user_management.txtar index f65397090..15f1863c9 100644 --- a/testscript/testdata/user_management.txtar +++ b/testscript/testdata/user_management.txtar @@ -1,7 +1,7 @@ # vi: set ft=conf # convert crlf to lf on windows -[windows] dos2unix info.txt admin_key_list1.txt admin_key_list2.txt list1.txt list2.txt foo_info1.txt foo_info2.txt foo_info3.txt foo_info4.txt foo_info5.txt +[windows] dos2unix info.txt admin_key_list1.txt admin_key_list2.txt list1.txt list2.txt foo_info1.txt foo_info2.txt foo_info3.txt foo_info4.txt foo_info5.txt bar_info.txt # start soft serve exec soft serve & @@ -68,6 +68,38 @@ soft user delete foo2 soft user list cmpenv stdout list1.txt +# create a new user with an invalid email +! soft user create bar --key "$USER2_AUTHORIZED_KEY" "foobar" +stderr 'invalid email address.*' + +# create a new user with a valid email +soft user create bar --key "$USER2_AUTHORIZED_KEY" "foo@bar.baz" +! stdout . +# add email to existing user +soft user add-email bar "foobar@fubar.baz" +! stdout . +# add existing email +! soft user add-email bar "foobar@fubar.baz" +stderr 'duplicate key.*' + +# get new user info +soft user info bar +cmpenv stdout bar_info.txt + +# remove primary email from user +! soft user remove-email bar "foo@bar.baz" +stderr 'cannot remove primary email.*' + +# set primary email that doesn't exist +! soft user set-primary-email bar "foobar@foofoo.foo" +stderr 'no rows in result set.*' +# set primary email +soft user set-primary-email bar "foobar@fubar.baz" +! stdout . +# remove other email +soft user remove-email bar "foo@bar.baz" +! stdout . + # stop the server [windows] stopserver [windows] ! stderr . @@ -112,3 +144,11 @@ $ADMIN1_AUTHORIZED_KEY $ADMIN2_AUTHORIZED_KEY -- admin_key_list2.txt -- $ADMIN1_AUTHORIZED_KEY +-- bar_info.txt -- +Username: bar +Admin: false +Public keys: + $USER2_AUTHORIZED_KEY +Emails: + foo@bar.baz (primary: true) + foobar@fubar.baz (primary: false)