From 74bae5fdc2d79fd90a4c6093a06de79d786caeed Mon Sep 17 00:00:00 2001 From: Bhunter <180028024+bhunter234@users.noreply.github.com> Date: Sun, 5 Jan 2025 18:57:18 +0100 Subject: [PATCH] refactor: add UUID validation and improve file handling in copy and move operations --- .../migrations/20250105180250_index.sql | 5 + pkg/services/file.go | 96 ++++++++++--------- pkg/services/file_query_builder.go | 7 +- 3 files changed, 62 insertions(+), 46 deletions(-) create mode 100644 internal/database/migrations/20250105180250_index.sql diff --git a/internal/database/migrations/20250105180250_index.sql b/internal/database/migrations/20250105180250_index.sql new file mode 100644 index 00000000..56a09662 --- /dev/null +++ b/internal/database/migrations/20250105180250_index.sql @@ -0,0 +1,5 @@ +-- +goose Up +-- +goose StatementBegin +DROP INDEX IF EXISTS teldrive.idx_files_unique_file; + +-- +goose StatementEnd diff --git a/pkg/services/file.go b/pkg/services/file.go index f71b2e4a..f1b11b87 100644 --- a/pkg/services/file.go +++ b/pkg/services/file.go @@ -13,8 +13,10 @@ import ( "strconv" "strings" + "github.com/google/uuid" "github.com/gotd/td/telegram" "github.com/gotd/td/tg" + "github.com/jackc/pgx/v5/pgtype" "github.com/tgdrive/teldrive/internal/api" "github.com/tgdrive/teldrive/internal/auth" "github.com/tgdrive/teldrive/internal/category" @@ -68,6 +70,10 @@ func randInt64() (int64, error) { b := &buffer{Buf: buf[:]} return b.long() } +func isUUID(str string) bool { + _, err := uuid.Parse(str) + return err == nil +} type fullFileDB struct { models.File @@ -182,15 +188,18 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap return nil, &apiError{err: err} } - var destRes []models.File - - if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination). - Scan(&destRes).Error; err != nil { - return nil, &apiError{err: err} + var parentId string + if !isUUID(req.Destination) { + var destRes []models.File + if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination). + Scan(&destRes).Error; err != nil { + return nil, &apiError{err: err} + } + parentId = destRes[0].Id + } else { + parentId = req.Destination } - dest := destRes[0] - dbFile := models.File{} dbFile.Name = req.NewName.Or(file.Name) @@ -201,7 +210,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap dbFile.UserID = userId dbFile.Status = "active" dbFile.ParentID = sql.NullString{ - String: dest.Id, + String: parentId, Valid: true, } dbFile.ChannelID = &channelId @@ -227,7 +236,6 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi ) if fileIn.Path.Value != "" { - path = strings.TrimSpace(fileIn.Path.Value) path = strings.ReplaceAll(path, "//", "/") if path != "/" { path = strings.TrimSuffix(path, "/") @@ -323,18 +331,10 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error { userId, _ := auth.GetUser(ctx) - if req.Source.Value == "" && len(req.Ids) == 0 { - return &apiError{err: errors.New("source or ids is required"), code: 409} - } - if req.Source.Value != "" && len(req.Ids) == 0 { - if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil { - return &apiError{err: err} - } - } else if req.Source.Value == "" && len(req.Ids) > 0 { - if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil { - return &apiError{err: err} - } + if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil { + return &apiError{err: err} } + return nil } @@ -417,20 +417,23 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error { func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error { userId, _ := auth.GetUser(ctx) - if req.Source.Value == "" && len(req.Ids) == 0 { - return &apiError{err: errors.New("source or ids is required"), code: 409} + items := pgtype.Array[string]{ + Elements: req.Ids, + Valid: true, + Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}}, } - if req.Source.Value != "" && len(req.Ids) > 0 { - if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil { + if !isUUID(req.Destination) { + r, err := a.getFileFromPath(req.Destination, userId) + if err != nil { return &apiError{err: err} } + req.Destination = r.Id } - if req.Source.Value == "" && len(req.Ids) == 0 { - if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value, - req.Destination, userId).Error; err != nil { - return &apiError{err: err} - } + if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId). + Update("parent_id", req.Destination).Error; err != nil { + return &apiError{err: err} } + return nil } @@ -469,10 +472,6 @@ func (a *apiService) FilesStream(ctx context.Context, params api.FilesStreamPara func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, params api.FilesUpdateParams) (*api.File, error) { - var ( - files []models.File - chain *gorm.DB - ) updateDb := models.File{} if req.Name.Value != "" { updateDb.Name = req.Name.Value @@ -495,18 +494,17 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param updateDb.UpdatedAt = req.UpdatedAt.Value } - chain = a.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", params.ID).Updates(updateDb) - - if chain.Error != nil { - return nil, &apiError{err: chain.Error} - } - if chain.RowsAffected == 0 { - return nil, &apiError{err: errors.New("file not found"), code: 404} + if err := a.db.Model(&models.File{}).Where("id = ?", params.ID).Updates(updateDb).Error; err != nil { + return nil, &apiError{err: err} } a.cache.Delete(fmt.Sprintf("files:%s", params.ID)) - return mapper.ToFileOut(files[0], false), nil + file := models.File{} + if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil { + return nil, &apiError{err: err} + } + return mapper.ToFileOut(file, false), nil } func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error { @@ -515,10 +513,8 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd var file models.File updatePayload := models.File{ - UpdatedAt: req.UpdatedAt, - Size: utils.Ptr(req.Size), + Size: utils.Ptr(req.Size), } - if req.ChannelId.Value == 0 { channelId, err := getDefaultChannel(a.db, a.cache, userId) if err != nil { @@ -539,11 +535,21 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd } updatePayload.Parts = datatypes.NewJSONSlice(parts) } + if req.Name.Value != "" { + updatePayload.Name = req.Name.Value + } + if req.ParentId.Value != "" { + updatePayload.ParentID = sql.NullString{ + String: req.ParentId.Value, + Valid: true, + } + } err := a.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil { return err } - if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil { + if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload). + Update("updated_at", req.UpdatedAt).Error; err != nil { return err } if req.UploadId.Value != "" { diff --git a/pkg/services/file_query_builder.go b/pkg/services/file_query_builder.go index d2320d2a..19d0f2bc 100644 --- a/pkg/services/file_query_builder.go +++ b/pkg/services/file_query_builder.go @@ -97,7 +97,12 @@ func (afb *fileQueryBuilder) applyFileSpecificFilters(query *gorm.DB, filesQuery } if filesQuery.ParentId.Value != "" { - query = query.Where("parent_id = ?", filesQuery.ParentId.Value) + if filesQuery.ParentId.Value == "nil" { + query = query.Where("parent_id is NULL") + } else { + query = query.Where("parent_id = ?", filesQuery.ParentId.Value) + } + } if filesQuery.ParentId.Value == "" && filesQuery.Path.Value != "" && filesQuery.Query.Value == "" {