Skip to content

Commit

Permalink
refactor: add UUID validation and improve file handling in copy and m…
Browse files Browse the repository at this point in the history
…ove operations
  • Loading branch information
bhunter234 committed Jan 5, 2025
1 parent 77e463b commit 74bae5f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 46 deletions.
5 changes: 5 additions & 0 deletions internal/database/migrations/20250105180250_index.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- +goose Up
-- +goose StatementBegin
DROP INDEX IF EXISTS teldrive.idx_files_unique_file;

-- +goose StatementEnd
96 changes: 51 additions & 45 deletions pkg/services/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import (
"strconv"
"strings"

"github.com/google/uuid"
"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
"github.com/jackc/pgx/v5/pgtype"
"github.com/tgdrive/teldrive/internal/api"
"github.com/tgdrive/teldrive/internal/auth"
"github.com/tgdrive/teldrive/internal/category"
Expand Down Expand Up @@ -68,6 +70,10 @@ func randInt64() (int64, error) {
b := &buffer{Buf: buf[:]}
return b.long()
}
func isUUID(str string) bool {
_, err := uuid.Parse(str)
return err == nil
}

type fullFileDB struct {
models.File
Expand Down Expand Up @@ -182,15 +188,18 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
return nil, &apiError{err: err}
}

var destRes []models.File

if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
Scan(&destRes).Error; err != nil {
return nil, &apiError{err: err}
var parentId string
if !isUUID(req.Destination) {
var destRes []models.File
if err := a.db.Raw("select * from teldrive.create_directories(?, ?)", userId, req.Destination).
Scan(&destRes).Error; err != nil {
return nil, &apiError{err: err}
}
parentId = destRes[0].Id
} else {
parentId = req.Destination
}

dest := destRes[0]

dbFile := models.File{}

dbFile.Name = req.NewName.Or(file.Name)
Expand All @@ -201,7 +210,7 @@ func (a *apiService) FilesCopy(ctx context.Context, req *api.FileCopy, params ap
dbFile.UserID = userId
dbFile.Status = "active"
dbFile.ParentID = sql.NullString{
String: dest.Id,
String: parentId,
Valid: true,
}
dbFile.ChannelID = &channelId
Expand All @@ -227,7 +236,6 @@ func (a *apiService) FilesCreate(ctx context.Context, fileIn *api.File) (*api.Fi
)

if fileIn.Path.Value != "" {
path = strings.TrimSpace(fileIn.Path.Value)
path = strings.ReplaceAll(path, "//", "/")
if path != "/" {
path = strings.TrimSuffix(path, "/")
Expand Down Expand Up @@ -323,18 +331,10 @@ func (a *apiService) FilesCreateShare(ctx context.Context, req *api.FileShareCre

func (a *apiService) FilesDelete(ctx context.Context, req *api.FileDelete) error {
userId, _ := auth.GetUser(ctx)
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
}
if req.Source.Value != "" && len(req.Ids) == 0 {
if err := a.db.Exec("call teldrive.delete_folder_recursive($1 , $2)", req.Source.Value, userId).Error; err != nil {
return &apiError{err: err}
}
} else if req.Source.Value == "" && len(req.Ids) > 0 {
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
return &apiError{err: err}
}
if err := a.db.Exec("call teldrive.delete_files_bulk($1 , $2)", req.Ids, userId).Error; err != nil {
return &apiError{err: err}
}

return nil
}

Expand Down Expand Up @@ -417,20 +417,23 @@ func (a *apiService) FilesMkdir(ctx context.Context, req *api.FileMkDir) error {

func (a *apiService) FilesMove(ctx context.Context, req *api.FileMove) error {
userId, _ := auth.GetUser(ctx)
if req.Source.Value == "" && len(req.Ids) == 0 {
return &apiError{err: errors.New("source or ids is required"), code: 409}
items := pgtype.Array[string]{
Elements: req.Ids,
Valid: true,
Dims: []pgtype.ArrayDimension{{Length: int32(len(req.Ids)), LowerBound: 1}},
}
if req.Source.Value != "" && len(req.Ids) > 0 {
if err := a.db.Exec("select * from teldrive.move_items($1 , $2 , $3)", req.Ids, req.Destination, userId).Error; err != nil {
if !isUUID(req.Destination) {
r, err := a.getFileFromPath(req.Destination, userId)
if err != nil {
return &apiError{err: err}
}
req.Destination = r.Id
}
if req.Source.Value == "" && len(req.Ids) == 0 {
if err := a.db.Exec("select * from teldrive.move_directory(? , ? , ?)", req.Source.Value,
req.Destination, userId).Error; err != nil {
return &apiError{err: err}
}
if err := a.db.Model(&models.File{}).Where("id = any(?)", items).Where("user_id = ?", userId).
Update("parent_id", req.Destination).Error; err != nil {
return &apiError{err: err}
}

return nil

}
Expand Down Expand Up @@ -469,10 +472,6 @@ func (a *apiService) FilesStream(ctx context.Context, params api.FilesStreamPara

func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, params api.FilesUpdateParams) (*api.File, error) {

var (
files []models.File
chain *gorm.DB
)
updateDb := models.File{}
if req.Name.Value != "" {
updateDb.Name = req.Name.Value
Expand All @@ -495,18 +494,17 @@ func (a *apiService) FilesUpdate(ctx context.Context, req *api.FileUpdate, param
updateDb.UpdatedAt = req.UpdatedAt.Value
}

chain = a.db.Model(&files).Clauses(clause.Returning{}).Where("id = ?", params.ID).Updates(updateDb)

if chain.Error != nil {
return nil, &apiError{err: chain.Error}
}
if chain.RowsAffected == 0 {
return nil, &apiError{err: errors.New("file not found"), code: 404}
if err := a.db.Model(&models.File{}).Where("id = ?", params.ID).Updates(updateDb).Error; err != nil {
return nil, &apiError{err: err}
}

a.cache.Delete(fmt.Sprintf("files:%s", params.ID))

return mapper.ToFileOut(files[0], false), nil
file := models.File{}
if err := a.db.Where("id = ?", params.ID).First(&file).Error; err != nil {
return nil, &apiError{err: err}
}
return mapper.ToFileOut(file, false), nil
}

func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpdate, params api.FilesUpdatePartsParams) error {
Expand All @@ -515,10 +513,8 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
var file models.File

updatePayload := models.File{
UpdatedAt: req.UpdatedAt,
Size: utils.Ptr(req.Size),
Size: utils.Ptr(req.Size),
}

if req.ChannelId.Value == 0 {
channelId, err := getDefaultChannel(a.db, a.cache, userId)
if err != nil {
Expand All @@ -539,11 +535,21 @@ func (a *apiService) FilesUpdateParts(ctx context.Context, req *api.FilePartsUpd
}
updatePayload.Parts = datatypes.NewJSONSlice(parts)
}
if req.Name.Value != "" {
updatePayload.Name = req.Name.Value
}
if req.ParentId.Value != "" {
updatePayload.ParentID = sql.NullString{
String: req.ParentId.Value,
Valid: true,
}
}
err := a.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("id = ?", params.ID).First(&file).Error; err != nil {
return err
}
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).Error; err != nil {
if err := tx.Model(models.File{}).Where("id = ?", params.ID).Updates(updatePayload).
Update("updated_at", req.UpdatedAt).Error; err != nil {
return err
}
if req.UploadId.Value != "" {
Expand Down
7 changes: 6 additions & 1 deletion pkg/services/file_query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ func (afb *fileQueryBuilder) applyFileSpecificFilters(query *gorm.DB, filesQuery
}

if filesQuery.ParentId.Value != "" {
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
if filesQuery.ParentId.Value == "nil" {
query = query.Where("parent_id is NULL")
} else {
query = query.Where("parent_id = ?", filesQuery.ParentId.Value)
}

}

if filesQuery.ParentId.Value == "" && filesQuery.Path.Value != "" && filesQuery.Query.Value == "" {
Expand Down

0 comments on commit 74bae5f

Please sign in to comment.