diff --git a/api/factory.go b/api/factory.go index 70201733..3fa0837d 100644 --- a/api/factory.go +++ b/api/factory.go @@ -236,7 +236,7 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req // 批量文件任务 elems := make([]batchtfile.TaskElement, 0, len(files)) for _, file := range files { - elem, err := batchtfile.NewTaskElement(stor, req.Path, file) + elem, err := batchtfile.NewTaskElement(f.ctx, stor, req.Path, file) if err != nil { return nil, fmt.Errorf("failed to create task element: %w", err) } diff --git a/client/bot/handlers/utils/shortcut/tftask.go b/client/bot/handlers/utils/shortcut/tftask.go index df5efa00..cc107250 100644 --- a/client/bot/handlers/utils/shortcut/tftask.go +++ b/client/bot/handlers/utils/shortcut/tftask.go @@ -152,7 +152,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st } if !dirPath.NeedNewForAlbum() { storPath := path.Join(dirPath.String(), file.Name()) - elem, err := batchtfile.NewTaskElement(fileStor, storPath, file) + elem, err := batchtfile.NewTaskElement(ctx, fileStor, storPath, file) if err != nil { logger.Errorf("Failed to create task element: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ @@ -189,7 +189,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st albumStor := afiles[0].storage for _, af := range afiles { afstorPath := path.Join(dirPath, albumDir, af.file.Name()) - elem, err := batchtfile.NewTaskElement(albumStor, afstorPath, af.file) + elem, err := batchtfile.NewTaskElement(ctx, albumStor, afstorPath, af.file) if err != nil { logger.Errorf("Failed to create task element for album file: %s", err) ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{ diff --git a/cmd/upload/cmd.go b/cmd/upload/cmd.go index 7e1c7b51..2c7e7f94 100644 --- a/cmd/upload/cmd.go +++ b/cmd/upload/cmd.go @@ -113,7 +113,7 @@ func Upload(cmd *cobra.Command, args []string) error { reader = file } - if err := stor.Save(ctx, reader, uploadPath); err != nil { + if _, err := stor.Save(ctx, reader, uploadPath); err != nil { if progressUI != nil { progressUI.SetError(err) progressUI.Wait() diff --git a/common/utils/tgutil/chatinfo.go b/common/utils/tgutil/chatinfo.go new file mode 100644 index 00000000..27d411a1 --- /dev/null +++ b/common/utils/tgutil/chatinfo.go @@ -0,0 +1,131 @@ +package tgutil + +import ( + "github.com/celestix/gotgproto/ext" + "github.com/charmbracelet/log" + "github.com/gotd/td/tg" +) + +// ChatInfoFromExt extracts chat title and username for the given peer. +func ChatInfoFromExt(extCtx *ext.Context, peer tg.PeerClass) (title, username string) { + if extCtx == nil { + return + } + chatID := ChatIdFromPeer(peer) + if chatID == 0 { + return + } + + if extCtx.Entities != nil { + if t, u, found := lookupEntities(extCtx.Entities, chatID); found { + return t, u + } + } + + return fetchFromAPI(extCtx, peer, chatID) +} + +func formatUserName(firstName, lastName string) string { + if lastName != "" { + return firstName + " " + lastName + } + return firstName +} + +func lookupEntities(entities *tg.Entities, chatID int64) (title, username string, found bool) { + if ch, ok := entities.Channels[chatID]; ok { + return ch.Title, ch.Username, true + } + if ch, ok := entities.Chats[chatID]; ok { + return ch.Title, "", true // tg.Chat has no Username field + } + if u, ok := entities.Users[chatID]; ok { + return formatUserName(u.FirstName, u.LastName), u.Username, true + } + return "", "", false +} + +func fetchFromAPI(extCtx *ext.Context, peer tg.PeerClass, chatID int64) (title, username string) { + if extCtx.Raw == nil { + return + } + + var err error + switch peer.(type) { + case *tg.PeerChannel: + title, username, err = fetchChannel(extCtx, chatID) + case *tg.PeerChat: + title, username, err = fetchChat(extCtx, chatID) + case *tg.PeerUser: + title, username, err = fetchUser(extCtx, chatID) + } + if err != nil { + log.Debug("Failed to fetch chat info from API", "chatID", chatID, "error", err) + } + return +} + +func resolveInputPeer(extCtx *ext.Context, chatID int64) (tg.InputPeerClass, error) { + return extCtx.ResolveInputPeerById(chatID) +} + +func fetchChannel(extCtx *ext.Context, chatID int64) (string, string, error) { + inputPeer, err := resolveInputPeer(extCtx, chatID) + if err != nil { + return "", "", err + } + ch, ok := inputPeer.(*tg.InputPeerChannel) + if !ok { + return "", "", nil + } + + result, err := extCtx.Raw.ChannelsGetChannels(extCtx, []tg.InputChannelClass{ + &tg.InputChannel{ChannelID: chatID, AccessHash: ch.AccessHash}, + }) + if err != nil { + return "", "", err + } + for _, c := range result.GetChats() { + if channel, ok := c.(*tg.Channel); ok && channel.ID == chatID { + return channel.Title, channel.Username, nil + } + } + return "", "", nil +} + +func fetchChat(extCtx *ext.Context, chatID int64) (string, string, error) { + result, err := extCtx.Raw.MessagesGetFullChat(extCtx, chatID) + if err != nil { + return "", "", err + } + for _, c := range result.GetChats() { + if chat, ok := c.(*tg.Chat); ok && chat.ID == chatID { + return chat.Title, "", nil + } + } + return "", "", nil +} + +func fetchUser(extCtx *ext.Context, chatID int64) (string, string, error) { + inputPeer, err := resolveInputPeer(extCtx, chatID) + if err != nil { + return "", "", err + } + u, ok := inputPeer.(*tg.InputPeerUser) + if !ok { + return "", "", nil + } + + users, err := extCtx.Raw.UsersGetUsers(extCtx, []tg.InputUserClass{ + &tg.InputUser{UserID: chatID, AccessHash: u.AccessHash}, + }) + if err != nil { + return "", "", err + } + for _, user := range users { + if u, ok := user.(*tg.User); ok && u.ID == chatID { + return formatUserName(u.FirstName, u.LastName), u.Username, nil + } + } + return "", "", nil +} diff --git a/config.example.toml b/config.example.toml index 11dd746d..55b41230 100644 --- a/config.example.toml +++ b/config.example.toml @@ -4,6 +4,7 @@ workers = 4 # 同时下载文件数 retry = 3 # 下载失败重试次数 threads = 4 # 单个任务下载使用的最大线程数 stream = false # 使用流式传输模式, 建议仅在硬盘空间十分有限时使用. +save_metadata = false # 保存文件时同时生成 .meta.json 元数据侧车文件 [log] # 日志级别, 可选: debug, info, warn, error, fatal diff --git a/config/viper.go b/config/viper.go index 43c7ffc5..c2711bb0 100644 --- a/config/viper.go +++ b/config/viper.go @@ -22,6 +22,7 @@ type Config struct { NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"` Threads int `toml:"threads" mapstructure:"threads" json:"threads"` Stream bool `toml:"stream" mapstructure:"stream" json:"stream"` + SaveMetadata bool `toml:"save_metadata" mapstructure:"save_metadata" json:"save_metadata"` Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"` Log logConfig `toml:"log" mapstructure:"log" json:"log"` Aria2 aria2Config `toml:"aria2" mapstructure:"aria2" json:"aria2"` @@ -105,6 +106,7 @@ func Init(ctx context.Context, configFile ...string) error { "workers": 3, "retry": 3, "threads": 4, + "save_metadata": false, "log.level": "debug", // 缓存配置 diff --git a/core/tasks/aria2dl/execute.go b/core/tasks/aria2dl/execute.go index 43e53de3..3b129dbc 100644 --- a/core/tasks/aria2dl/execute.go +++ b/core/tasks/aria2dl/execute.go @@ -205,7 +205,7 @@ func (t *Task) transferFile(ctx context.Context, filePath string) error { logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath) - if err := t.Storage.Save(ctx, f, destPath); err != nil { + if _, err := t.Storage.Save(ctx, f, destPath); err != nil { return fmt.Errorf("failed to save file %s to storage: %w", fileName, err) } diff --git a/core/tasks/aria2dl/task_test.go b/core/tasks/aria2dl/task_test.go index 4e6ba656..aca1a709 100644 --- a/core/tasks/aria2dl/task_test.go +++ b/core/tasks/aria2dl/task_test.go @@ -29,9 +29,9 @@ func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig) return nil } -func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error { +func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) (string, error) { m.savePath = path - return nil + return path, nil } func (m *mockStorage) Exists(ctx context.Context, path string) bool { diff --git a/core/tasks/batchtfile/execute.go b/core/tasks/batchtfile/execute.go index 403d2502..0016bf2b 100644 --- a/core/tasks/batchtfile/execute.go +++ b/core/tasks/batchtfile/execute.go @@ -57,9 +57,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { if elem.stream { pr, pw := io.Pipe() defer pr.Close() + var actualPath string errg, uploadCtx := errgroup.WithContext(ctx) errg.Go(func() error { - return elem.Storage.Save(uploadCtx, pr, elem.Path) + var err error + actualPath, err = elem.Storage.Save(uploadCtx, pr, elem.Path) + return err }) wr := ioutil.NewProgressWriter(pw, func(n int) { t.downloaded.Add(int64(n)) @@ -79,6 +82,9 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return fmt.Errorf("failed to download file in stream mode: %w", err) } logger.Info("File downloaded successfully in stream mode") + if err := elem.saveMetadata(ctx, actualPath); err != nil { + logger.Warnf("failed to save metadata: %s", err) + } return nil } logger.Info("Starting file download") @@ -112,6 +118,7 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return fmt.Errorf("failed to get file stat: %w", err) } vctx := context.WithValue(ctx, ctxkey.ContentLength, fileStat.Size()) + var actualPath string err = retry.Retry(func() error { var file *os.File file, err = os.Open(elem.localPath) @@ -119,11 +126,17 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { return fmt.Errorf("failed to open cache file: %w", err) } defer file.Close() - if err = elem.Storage.Save(vctx, file, elem.Path); err != nil { + if actualPath, err = elem.Storage.Save(vctx, file, elem.Path); err != nil { logger.Errorf("Failed to save file: %s, retrying...", err) return err } return nil }, retry.Context(vctx), retry.RetryTimes(uint(config.C().Retry))) - return err + if err != nil { + return err + } + if err := elem.saveMetadata(ctx, actualPath); err != nil { + logger.Warnf("failed to save metadata: %s", err) + } + return nil } diff --git a/core/tasks/batchtfile/task.go b/core/tasks/batchtfile/task.go index c00d510d..caf07f70 100644 --- a/core/tasks/batchtfile/task.go +++ b/core/tasks/batchtfile/task.go @@ -1,12 +1,15 @@ package batchtfile import ( + "bytes" "context" "fmt" "path/filepath" "sync" "sync/atomic" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/pkg/metadata" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" @@ -24,6 +27,15 @@ type TaskElement struct { File tfile.TGFile localPath string stream bool + metadata []byte +} + +func (e *TaskElement) saveMetadata(ctx context.Context, actualPath string) error { + if len(e.metadata) == 0 { + return nil + } + _, err := e.Storage.Save(ctx, bytes.NewReader(e.metadata), actualPath+metadata.MetaSuffix) + return err } type Task struct { @@ -49,11 +61,25 @@ func (t *Task) Type() tasktype.TaskType { } func NewTaskElement( + ctx context.Context, stor storage.Storage, path string, file tfile.TGFile, ) (*TaskElement, error) { id := xid.New().String() + + var meta []byte + if config.C().SaveMetadata { + if fmsg, ok := file.(tfile.TGFileMessage); ok { + m := metadata.BuildFromMessage(ctx, fmsg.Message(), file.Name(), file.Size()) + var err error + meta, err = m.ToJSON() + if err != nil { + log.FromContext(ctx).Warnf("failed to marshal metadata: %s", err) + } + } + } + _, ok := stor.(storage.StorageCannotStream) if !config.C().Stream || ok { cachePath, err := filepath.Abs(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) @@ -66,14 +92,16 @@ func NewTaskElement( Path: path, File: file, localPath: cachePath, + metadata: meta, }, nil } return &TaskElement{ - ID: id, - Storage: stor, - Path: path, - File: file, - stream: true, + ID: id, + Storage: stor, + Path: path, + File: file, + stream: true, + metadata: meta, }, nil } diff --git a/core/tasks/directlinks/execute.go b/core/tasks/directlinks/execute.go index c5f10d59..e65a3707 100644 --- a/core/tasks/directlinks/execute.go +++ b/core/tasks/directlinks/execute.go @@ -130,7 +130,8 @@ func (t *Task) processLink(ctx context.Context, file *File) error { } ctx = context.WithValue(ctx, ctxkey.ContentLength, file.Size) if t.stream { - return t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name)) + _, err := t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name)) + return err } cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("direct_%s_%s", t.ID, file.Name))) @@ -166,7 +167,8 @@ func (t *Task) processLink(ctx context.Context, file *File) error { if err != nil { return fmt.Errorf("failed to seek cache file for resource %s: %w", file.URL, err) } - return t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name)) + _, err = t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name)) + return err }, retry.RetryTimes(uint(config.C().Retry)), retry.Context(ctx)) if ctx.Err() != nil { return ctx.Err() diff --git a/core/tasks/parsed/execute.go b/core/tasks/parsed/execute.go index f97e6d95..3be05f4f 100644 --- a/core/tasks/parsed/execute.go +++ b/core/tasks/parsed/execute.go @@ -94,7 +94,8 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er return resp.ContentLength }()) if t.stream { - return t.Stor.Save(ctx, resp.Body, path.Join(t.StorPath, resource.Filename)) + _, err := t.Stor.Save(ctx, resp.Body, path.Join(t.StorPath, resource.Filename)) + return err } cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("resource_%s_%s", t.ID, resource.Filename))) @@ -130,7 +131,8 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er if err != nil { return fmt.Errorf("failed to seek cache file for resource %s: %w", resource.URL, err) } - return t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, resource.Filename)) + _, err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, resource.Filename)) + return err }, retry.Context(ctx), retry.RetryTimes(uint(config.C().Retry))) if ctx.Err() != nil { return ctx.Err() diff --git a/core/tasks/telegraph/execute.go b/core/tasks/telegraph/execute.go index dfa327b0..5e35450c 100644 --- a/core/tasks/telegraph/execute.go +++ b/core/tasks/telegraph/execute.go @@ -75,12 +75,12 @@ func (t *Task) processPic(ctx context.Context, picUrl string, index int) error { if err != nil { return fmt.Errorf("failed to seek cache file for picture %s: %w", filename, err) } - err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename)) + _, err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename)) if err != nil { return fmt.Errorf("failed to save picture %s: %w", filename, err) } } else { - err = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename)) + _, err = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename)) } if err != nil { diff --git a/core/tasks/tfile/execute.go b/core/tasks/tfile/execute.go index 4efbfc09..a67a3e56 100644 --- a/core/tasks/tfile/execute.go +++ b/core/tasks/tfile/execute.go @@ -57,13 +57,15 @@ func (t *Task) Execute(ctx context.Context) error { return fmt.Errorf("failed to get file stat: %w", err) } vctx := context.WithValue(ctx, ctxkey.ContentLength, fileStat.Size()) + var actualPath string err = retry.Retry(func() error { file, err := os.Open(t.localPath) if err != nil { return fmt.Errorf("failed to open cache file: %w", err) } defer file.Close() - if err = t.Storage.Save(vctx, file, t.Path); err != nil { + actualPath, err = t.Storage.Save(vctx, file, t.Path) + if err != nil { return fmt.Errorf("failed to save file: %w", err) } return nil @@ -71,5 +73,8 @@ func (t *Task) Execute(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to save file after retries: %w", err) } + if err := t.saveMetadata(ctx, actualPath); err != nil { + logger.Warnf("failed to save metadata: %s", err) + } return nil } diff --git a/core/tasks/tfile/stream.go b/core/tasks/tfile/stream.go index 881c31e3..bbf95114 100644 --- a/core/tasks/tfile/stream.go +++ b/core/tasks/tfile/stream.go @@ -15,9 +15,12 @@ func executeStream(ctx context.Context, task *Task) error { pr, pw := io.Pipe() defer pr.Close() + var actualPath string errg, uploadCtx := errgroup.WithContext(ctx) errg.Go(func() error { - return task.Storage.Save(uploadCtx, pr, task.Path) + var err error + actualPath, err = task.Storage.Save(uploadCtx, pr, task.Path) + return err }) wr := newWriter(ctx, pw, task.Progress, task) errg.Go(func() error { @@ -40,5 +43,8 @@ func executeStream(ctx context.Context, task *Task) error { return err } logger.Info("File downloaded successfully in stream mode") + if err := task.saveMetadata(ctx, actualPath); err != nil { + logger.Warnf("failed to save metadata: %s", err) + } return nil } diff --git a/core/tasks/tfile/tftask.go b/core/tasks/tfile/tftask.go index b85585cb..0d0cc153 100644 --- a/core/tasks/tfile/tftask.go +++ b/core/tasks/tfile/tftask.go @@ -1,10 +1,13 @@ package tfile import ( + "bytes" "context" "fmt" "path/filepath" + "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/pkg/metadata" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/core" "github.com/krau/SaveAny-Bot/pkg/enums/tasktype" @@ -23,6 +26,7 @@ type Task struct { Progress ProgressTracker stream bool // true if the file should be downloaded in stream mode localPath string + metadata []byte // pre-built JSON metadata, nil if save_metadata is disabled } // Title implements core.Exectable. @@ -42,13 +46,24 @@ func NewTGFileTask( path string, progress ProgressTracker, ) (*Task, error) { + var meta []byte + if config.C().SaveMetadata { + if fmsg, ok := file.(tfile.TGFileMessage); ok { + m := metadata.BuildFromMessage(ctx, fmsg.Message(), file.Name(), file.Size()) + var err error + meta, err = m.ToJSON() + if err != nil { + log.FromContext(ctx).Warnf("failed to marshal metadata: %s", err) + } + } + } _, ok := stor.(storage.StorageCannotStream) if !config.C().Stream || ok { cachePath, err := filepath.Abs(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name()))) if err != nil { return nil, fmt.Errorf("failed to get absolute path for cache: %w", err) } - tfile := &Task{ + return &Task{ ID: id, Ctx: ctx, File: file, @@ -56,10 +71,10 @@ func NewTGFileTask( Path: path, Progress: progress, localPath: cachePath, - } - return tfile, nil + metadata: meta, + }, nil } - tfileTask := &Task{ + return &Task{ ID: id, Ctx: ctx, File: file, @@ -67,6 +82,14 @@ func NewTGFileTask( Path: path, Progress: progress, stream: true, + metadata: meta, + }, nil +} + +func (t *Task) saveMetadata(ctx context.Context, actualPath string) error { + if len(t.metadata) == 0 { + return nil } - return tfileTask, nil + _, err := t.Storage.Save(ctx, bytes.NewReader(t.metadata), actualPath+metadata.MetaSuffix) + return err } diff --git a/core/tasks/transfer/execute.go b/core/tasks/transfer/execute.go index dc57da67..983108ca 100644 --- a/core/tasks/transfer/execute.go +++ b/core/tasks/transfer/execute.go @@ -9,6 +9,7 @@ import ( "path/filepath" "github.com/charmbracelet/log" + "github.com/krau/SaveAny-Bot/pkg/metadata" "github.com/krau/SaveAny-Bot/config" "github.com/krau/SaveAny-Bot/pkg/enums/ctxkey" "github.com/krau/SaveAny-Bot/storage" @@ -91,8 +92,10 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { // Inject file size into context ctx = context.WithValue(ctx, ctxkey.ContentLength, size) + var actualPath string if config.C().Stream { - if err := elem.TargetStorage.Save(ctx, reader, storagePath); err != nil { + actualPath, err = elem.TargetStorage.Save(ctx, reader, storagePath) + if err != nil { return fmt.Errorf("failed to upload file to storage: %w", err) } } else { @@ -109,7 +112,8 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { } logger.Infof("Uploading file to storage (size: %d bytes)", size) - if err := elem.TargetStorage.Save(ctx, tempFile, storagePath); err != nil { + actualPath, err = elem.TargetStorage.Save(ctx, tempFile, storagePath) + if err != nil { return fmt.Errorf("failed to upload file to storage: %w", err) } } @@ -117,6 +121,11 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error { t.uploaded.Add(size) t.Progress.OnProgress(ctx, t) + // transfer companion metadata file if exists + if err := t.transferMetadata(ctx, elem, actualPath); err != nil { + logger.Warnf("failed to transfer metadata: %s", err) + } + logger.Info("File uploaded successfully") return nil } @@ -140,3 +149,19 @@ func (t *Task) downloadToTemp(reader io.Reader, filename string) (*os.File, erro return tempFile, nil } + +func (t *Task) transferMetadata(ctx context.Context, elem TaskElement, actualPath string) error { + readable, ok := elem.SourceStorage.(storage.StorageReadable) + if !ok { + return nil + } + metaSourcePath := elem.SourcePath + metadata.MetaSuffix + reader, _, err := readable.OpenFile(ctx, metaSourcePath) + if err != nil { + log.FromContext(ctx).Debugf("Failed to open metadata file %s, skipping: %s", metaSourcePath, err) + return nil + } + defer reader.Close() + _, err = elem.TargetStorage.Save(ctx, reader, actualPath+metadata.MetaSuffix) + return err +} diff --git a/core/tasks/ytdlp/execute.go b/core/tasks/ytdlp/execute.go index 20bb362d..da09c2ce 100644 --- a/core/tasks/ytdlp/execute.go +++ b/core/tasks/ytdlp/execute.go @@ -171,7 +171,7 @@ func (t *Task) transferFile(ctx context.Context, filePath string) error { logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath) - if err := t.Storage.Save(ctx, f, destPath); err != nil { + if _, err := t.Storage.Save(ctx, f, destPath); err != nil { return fmt.Errorf("failed to save file %s to storage: %w", fileName, err) } diff --git a/core/tasks/ytdlp/task_test.go b/core/tasks/ytdlp/task_test.go index dab045e1..36a99a7a 100644 --- a/core/tasks/ytdlp/task_test.go +++ b/core/tasks/ytdlp/task_test.go @@ -16,7 +16,7 @@ func (m *MockStorage) Init(ctx context.Context, cfg storcfg.StorageConfig) error func (m *MockStorage) Type() storenum.StorageType { return "mock" } func (m *MockStorage) Name() string { return "test-storage" } func (m *MockStorage) JoinStoragePath(p string) string { return "test-path" } -func (m *MockStorage) Save(ctx context.Context, reader io.Reader, path string) error { return nil } +func (m *MockStorage) Save(ctx context.Context, reader io.Reader, path string) (string, error) { return "", nil } func (m *MockStorage) Exists(ctx context.Context, path string) bool { return false } func TestNewTask(t *testing.T) { diff --git a/docs/content/en/deployment/configuration/_index.md b/docs/content/en/deployment/configuration/_index.md index 177bd68a..ce7e8f7f 100644 --- a/docs/content/en/deployment/configuration/_index.md +++ b/docs/content/en/deployment/configuration/_index.md @@ -45,6 +45,7 @@ Stream mode is very useful for deployment environments with limited disk space, - `workers`: Number of tasks to process simultaneously, default is 3. - `threads`: Number of threads used when downloading files, default is 4. Only effective when Stream mode is not enabled. - `retry`: Number of retries when a task fails, default is 3. +- `save_metadata`: Whether to save a `.meta.json` sidecar file alongside each saved file, containing message info (sender, date, tags, forward origin, etc.), default is `false`. - `proxy`: Global proxy configuration. After setting this, all network connections inside the program will try to use this proxy. Optional. ```toml @@ -53,6 +54,7 @@ stream = false workers = 3 threads = 4 retry = 3 +save_metadata = false proxy = "socks5://127.0.0.1:7890" ``` diff --git a/docs/content/zh/deployment/configuration/_index.md b/docs/content/zh/deployment/configuration/_index.md index e4709611..4849d494 100644 --- a/docs/content/zh/deployment/configuration/_index.md +++ b/docs/content/zh/deployment/configuration/_index.md @@ -44,6 +44,7 @@ Stream 模式对于磁盘空间有限的部署环境十分有用, 但也有一 - `workers`: 同时处理任务数量, 默认为 3 - `threads`: 下载文件时使用的线程数, 默认为 4. 仅在未启用 Stream 模式时生效. - `retry`: 任务失败时的重试次数, 默认为 3. +- `save_metadata`: 保存文件时是否同时生成 `.meta.json` 元数据侧车文件, 包含消息信息 (发送者、日期、标签、转发来源等), 默认为 `false`. - `proxy`: 全局代理配置, 配置后程序内一切网络连接将会尝试使用该代理, 可选. ```toml @@ -51,6 +52,7 @@ stream = false workers = 3 threads = 4 retry = 3 +save_metadata = false proxy = "socks5://127.0.0.1:7890" ``` diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go new file mode 100644 index 00000000..cbcc1352 --- /dev/null +++ b/pkg/metadata/metadata.go @@ -0,0 +1,183 @@ +package metadata + +import ( + "context" + "encoding/json" + "time" + + "github.com/gotd/td/tg" + "github.com/krau/SaveAny-Bot/common/utils/strutil" + "github.com/krau/SaveAny-Bot/common/utils/tgutil" +) + +type ForwardInfo struct { + Date string `json:"date,omitempty"` + ChatID int64 `json:"chat_id,omitempty"` + ChatTitle string `json:"chat_title,omitempty"` + ChatUsername string `json:"chat_username,omitempty"` + MessageID int `json:"message_id,omitempty"` + Author string `json:"author,omitempty"` +} + +type ReplyInfo struct { + MsgID int `json:"msg_id,omitempty"` + ChatID int64 `json:"chat_id,omitempty"` +} + +type FileMetadata struct { + MessageID int `json:"message_id"` + Date string `json:"date"` + EditDate string `json:"edit_date,omitempty"` + ChatID int64 `json:"chat_id"` + ChatTitle string `json:"chat_title,omitempty"` + ChatUsername string `json:"chat_username,omitempty"` + SenderID int64 `json:"sender_id,omitempty"` + Text string `json:"text,omitempty"` + MediaType string `json:"media_type"` + FileName string `json:"file_name,omitempty"` + FileSize int64 `json:"file_size,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + Duration float64 `json:"duration,omitempty"` + Title string `json:"title,omitempty"` + Performer string `json:"performer,omitempty"` + ForwardFrom *ForwardInfo `json:"forward_from,omitempty"` + ReplyTo *ReplyInfo `json:"reply_to,omitempty"` + Tags []string `json:"tags,omitempty"` + GroupID int64 `json:"group_id,omitempty"` + OriginalName string `json:"original_name,omitempty"` +} + +func BuildFromMessage(ctx context.Context, msg *tg.Message, fileName string, fileSize int64) FileMetadata { + m := FileMetadata{ + MessageID: msg.GetID(), + Date: func() string { + d := msg.GetDate() + if d == 0 { + return "" + } + return time.Unix(int64(d), 0).UTC().Format(time.RFC3339) + }(), + ChatID: tgutil.ChatIdFromPeer(msg.GetPeerID()), + Text: msg.GetMessage(), + FileSize: fileSize, + GroupID: func() int64 { id, _ := msg.GetGroupedID(); return id }(), + } + + m.ChatTitle, m.ChatUsername = tgutil.ChatInfoFromExt(tgutil.ExtFromContext(ctx), msg.GetPeerID()) + + // media type, mime type, original name, file attributes + if msg.Media != nil { + switch media := msg.Media.(type) { + case *tg.MessageMediaDocument: + m.MediaType = "document" + if doc, ok := media.Document.AsNotEmpty(); ok { + m.MimeType = doc.MimeType + for _, attr := range doc.Attributes { + switch a := attr.(type) { + case *tg.DocumentAttributeVideo: + m.Duration = a.GetDuration() + m.Width = a.GetW() + m.Height = a.GetH() + case *tg.DocumentAttributeAudio: + m.Duration = float64(a.GetDuration()) + if title, ok := a.GetTitle(); ok { + m.Title = title + } + if performer, ok := a.GetPerformer(); ok { + m.Performer = performer + } + case *tg.DocumentAttributeImageSize: + if m.Width == 0 { + m.Width = a.GetW() + m.Height = a.GetH() + } + } + } + } + case *tg.MessageMediaPhoto: + m.MediaType = "photo" + if photo, ok := media.Photo.AsNotEmpty(); ok { + for _, size := range photo.Sizes { + switch s := size.(type) { + case *tg.PhotoSize: + if s.W > m.Width || (s.W == m.Width && s.H > m.Height) { + m.Width, m.Height = s.W, s.H + } + case *tg.PhotoSizeProgressive: + if s.W > m.Width || (s.W == m.Width && s.H > m.Height) { + m.Width, m.Height = s.W, s.H + } + } + } + } + } + origName, _ := tgutil.GetMediaFileName(msg.Media) + m.OriginalName = origName + } + + // file name from the tfile layer (after applying user strategy) + m.FileName = fileName + + // tags from message text + if tags := strutil.ExtractTagsFromText(msg.GetMessage()); len(tags) > 0 { + m.Tags = tags + } + + // sender id + if from, ok := msg.GetFromID(); ok { + m.SenderID = tgutil.ChatIdFromPeer(from) + } + + // edit date + if d, ok := msg.GetEditDate(); ok && d != 0 { + m.EditDate = time.Unix(int64(d), 0).UTC().Format(time.RFC3339) + } + + // reply info + if reply, ok := msg.GetReplyTo(); ok { + if header, ok := reply.(*tg.MessageReplyHeader); ok { + msgID, _ := header.GetReplyToMsgID() + m.ReplyTo = &ReplyInfo{ + MsgID: msgID, + } + if peerID, ok := header.GetReplyToPeerID(); ok { + m.ReplyTo.ChatID = tgutil.ChatIdFromPeer(peerID) + } + } + } + + // forward info + if fwd, ok := msg.GetFwdFrom(); ok { + fwdDate := fwd.GetDate() + fi := &ForwardInfo{ + Date: func() string { + if fwdDate == 0 { + return "" + } + return time.Unix(int64(fwdDate), 0).UTC().Format(time.RFC3339) + }(), + } + if fromID, ok := fwd.GetFromID(); ok { + fi.ChatID = tgutil.ChatIdFromPeer(fromID) + fi.ChatTitle, fi.ChatUsername = tgutil.ChatInfoFromExt(tgutil.ExtFromContext(ctx), fromID) + } + if author, ok := fwd.GetPostAuthor(); ok { + fi.Author = author + } + if postID, ok := fwd.GetChannelPost(); ok { + fi.MessageID = postID + } + m.ForwardFrom = fi + } + + return m +} + +func (m FileMetadata) ToJSON() ([]byte, error) { + return json.MarshalIndent(m, "", " ") +} + +// MetaSuffix is the file extension appended to metadata sidecar files. +const MetaSuffix = ".meta.json" diff --git a/storage/alist/alist.go b/storage/alist/alist.go index f33a3e13..7194aee0 100644 --- a/storage/alist/alist.go +++ b/storage/alist/alist.go @@ -102,19 +102,20 @@ func (a *Alist) Name() string { return a.config.Name } -func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) error { +func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) (string, error) { a.logger.Infof("Saving file to %s", storagePath) - storagePath = a.JoinStoragePath(storagePath) - ext := path.Ext(storagePath) - base := strings.TrimSuffix(storagePath, ext) - candidate := storagePath + originalPath := storagePath + joinedPath := a.JoinStoragePath(storagePath) + ext := path.Ext(joinedPath) + base := strings.TrimSuffix(joinedPath, ext) + candidate := joinedPath for i := 1; a.Exists(ctx, candidate); i++ { candidate = fmt.Sprintf("%s_%d%s", base, i, ext) } req, err := http.NewRequestWithContext(ctx, http.MethodPut, a.baseURL+"/api/fs/put", reader) if err != nil { - return fmt.Errorf("failed to create request: %w", err) + return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Authorization", a.token) req.Header.Set("File-Path", url.PathEscape(candidate)) @@ -128,29 +129,32 @@ func (a *Alist) Save(ctx context.Context, reader io.Reader, storagePath string) resp, err := a.client.Do(req) if err != nil { - return fmt.Errorf("failed to send request: %w", err) + return "", fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to save file to Alist: %s", resp.Status) + return "", fmt.Errorf("failed to save file to Alist: %s", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("failed to read response body: %w", err) + return "", fmt.Errorf("failed to read response body: %w", err) } var putResp putResponse if err := json.Unmarshal(body, &putResp); err != nil { - return fmt.Errorf("failed to unmarshal put response: %w", err) + return "", fmt.Errorf("failed to unmarshal put response: %w", err) } if putResp.Code != http.StatusOK { - return fmt.Errorf("failed to save file to Alist: %d, %s", putResp.Code, putResp.Message) + return "", fmt.Errorf("failed to save file to Alist: %d, %s", putResp.Code, putResp.Message) } - return nil + if candidate != joinedPath { + return path.Join(path.Dir(originalPath), path.Base(candidate)), nil + } + return originalPath, nil } func (a *Alist) JoinStoragePath(p string) string { diff --git a/storage/local/local.go b/storage/local/local.go index c6488008..88664a96 100644 --- a/storage/local/local.go +++ b/storage/local/local.go @@ -49,31 +49,37 @@ func (l *Local) JoinStoragePath(path string) string { return filepath.Join(l.config.BasePath, path) } -func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) error { +func (l *Local) Save(ctx context.Context, r io.Reader, storagePath string) (string, error) { l.logger.Infof("Saving file to %s", storagePath) - storagePath = l.JoinStoragePath(storagePath) + originalPath := storagePath + joinedPath := l.JoinStoragePath(storagePath) - ext := filepath.Ext(storagePath) - base := strings.TrimSuffix(storagePath, ext) - candidate := storagePath + ext := filepath.Ext(joinedPath) + base := strings.TrimSuffix(joinedPath, ext) + candidate := joinedPath for i := 1; l.Exists(ctx, candidate); i++ { candidate = fmt.Sprintf("%s_%d%s", base, i, ext) } absPath, err := filepath.Abs(candidate) if err != nil { - return err + return "", err } if err := fileutil.CreateDir(filepath.Dir(absPath)); err != nil { - return err + return "", err } file, err := os.Create(absPath) if err != nil { - return err + return "", err } defer file.Close() - _, err = io.Copy(file, r) - return err + if _, err := io.Copy(file, r); err != nil { + return "", err + } + if candidate != joinedPath { + return filepath.Join(filepath.Dir(originalPath), filepath.Base(candidate)), nil + } + return originalPath, nil } func (l *Local) Exists(ctx context.Context, storagePath string) bool { diff --git a/storage/minio/client.go b/storage/minio/client.go index c4ee16bb..074646af 100644 --- a/storage/minio/client.go +++ b/storage/minio/client.go @@ -75,12 +75,13 @@ func (m *Minio) JoinStoragePath(p string) string { return strings.TrimPrefix(path.Join(m.config.BasePath, p), "/") } -func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error { +func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) (string, error) { m.logger.Infof("Saving file from reader to %s", storagePath) - storagePath = m.JoinStoragePath(storagePath) - ext := path.Ext(storagePath) - base := strings.TrimSuffix(storagePath, ext) - candidate := storagePath + originalPath := storagePath + joinedPath := m.JoinStoragePath(storagePath) + ext := path.Ext(joinedPath) + base := strings.TrimSuffix(joinedPath, ext) + candidate := joinedPath for i := 1; m.Exists(ctx, candidate); i++ { candidate = fmt.Sprintf("%s_%d%s", base, i, ext) if i > 10 { @@ -98,10 +99,13 @@ func (m *Minio) Save(ctx context.Context, r io.Reader, storagePath string) error } _, err := m.client.PutObject(ctx, m.config.BucketName, candidate, r, size, minio.PutObjectOptions{}) if err != nil { - return fmt.Errorf("failed to upload file to minio: %w", err) + return "", fmt.Errorf("failed to upload file to minio: %w", err) } - return nil + if candidate != joinedPath { + return path.Join(path.Dir(originalPath), path.Base(candidate)), nil + } + return originalPath, nil } func (m *Minio) Exists(ctx context.Context, storagePath string) bool { diff --git a/storage/minio/client_stub.go b/storage/minio/client_stub.go index 0c6601b7..a250e0a8 100644 --- a/storage/minio/client_stub.go +++ b/storage/minio/client_stub.go @@ -32,8 +32,8 @@ func (m *Minio) JoinStoragePath(p string) string { return strings.TrimPrefix(path.Join("", p), "/") } -func (m *Minio) Save(_ context.Context, _ io.Reader, _ string) error { - return fmt.Errorf("minio storage is not supported in this build") +func (m *Minio) Save(_ context.Context, _ io.Reader, _ string) (string, error) { + return "", fmt.Errorf("minio storage is not supported in this build") } func (m *Minio) Exists(_ context.Context, _ string) bool { diff --git a/storage/rclone/rclone.go b/storage/rclone/rclone.go index e705ea7d..d67dd756 100644 --- a/storage/rclone/rclone.go +++ b/storage/rclone/rclone.go @@ -101,7 +101,7 @@ func (r *Rclone) getRemotePath(storagePath string) string { return remote + fullPath } -func (r *Rclone) Save(ctx context.Context, reader io.Reader, storagePath string) error { +func (r *Rclone) Save(ctx context.Context, reader io.Reader, storagePath string) (string, error) { r.logger.Infof("Saving file to %s", storagePath) ext := path.Ext(storagePath) @@ -131,11 +131,11 @@ func (r *Rclone) Save(ctx context.Context, reader io.Reader, storagePath string) if err := cmd.Run(); err != nil { r.logger.Errorf("Failed to save file: %v, stderr: %s", err, stderr.String()) - return fmt.Errorf("%w: %s", ErrFailedToSaveFile, stderr.String()) + return "", fmt.Errorf("%w: %s", ErrFailedToSaveFile, stderr.String()) } r.logger.Infof("Successfully saved file to %s", candidate) - return nil + return candidate, nil } func (r *Rclone) Exists(ctx context.Context, storagePath string) bool { diff --git a/storage/s3/s3.go b/storage/s3/s3.go index 153cd569..17346e4a 100644 --- a/storage/s3/s3.go +++ b/storage/s3/s3.go @@ -63,12 +63,13 @@ func (m *S3) JoinStoragePath(p string) string { return strings.TrimPrefix(path.Join(m.config.BasePath, p), "/") } -func (m *S3) Save(ctx context.Context, r io.Reader, storagePath string) error { +func (m *S3) Save(ctx context.Context, r io.Reader, storagePath string) (string, error) { m.logger.Infof("Saving file from reader to %s", storagePath) - storagePath = m.JoinStoragePath(storagePath) - ext := path.Ext(storagePath) - base := strings.TrimSuffix(storagePath, ext) - candidate := storagePath + originalPath := storagePath + joinedPath := m.JoinStoragePath(storagePath) + ext := path.Ext(joinedPath) + base := strings.TrimSuffix(joinedPath, ext) + candidate := joinedPath // Unique filename for i := 1; m.Exists(ctx, candidate); i++ { @@ -90,10 +91,13 @@ func (m *S3) Save(ctx context.Context, r io.Reader, storagePath string) error { err := m.client.Put(ctx, candidate, r, size) if err != nil { - return fmt.Errorf("failed to upload file to S3: %w", err) + return "", fmt.Errorf("failed to upload file to S3: %w", err) } - return nil + if candidate != joinedPath { + return path.Join(path.Dir(originalPath), path.Base(candidate)), nil + } + return originalPath, nil } func (m *S3) Exists(ctx context.Context, storagePath string) bool { diff --git a/storage/s3/s3_test.go b/storage/s3/s3_test.go index 77749c8e..4123b88f 100644 --- a/storage/s3/s3_test.go +++ b/storage/s3/s3_test.go @@ -64,7 +64,7 @@ func TestS3(t *testing.T) { reader := bytes.NewReader(content) key := "foo/bar.txt" - if err := s.Save(ctx, reader, key); err != nil { + if _, err := s.Save(ctx, reader, key); err != nil { t.Fatalf("Save failed: %v", err) } @@ -76,8 +76,10 @@ func TestS3(t *testing.T) { t.Fatalf("Exists should return false for nonexistent key") } - if err := s.Save(ctx, bytes.NewReader(content), key); err != nil { + if actualPath, err := s.Save(ctx, bytes.NewReader(content), key); err != nil { t.Fatalf("Save with existing key failed: %v", err) + } else if actualPath != "foo/bar_1.txt" { + t.Fatalf("Expected renamed path foo/bar_1.txt, got %s", actualPath) } if !s.Exists(ctx, "foo/bar_1.txt") { @@ -86,7 +88,7 @@ func TestS3(t *testing.T) { var length int64 = int64(len(content)) ctx = context.WithValue(ctx, ctxkey.ContentLength, length) - if err := s.Save(ctx, bytes.NewReader(content), "size_test.txt"); err != nil { + if _, err := s.Save(ctx, bytes.NewReader(content), "size_test.txt"); err != nil { t.Fatalf("Save with content length failed: %v", err) } diff --git a/storage/storage.go b/storage/storage.go index 1d14cc18..d9e5892d 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -22,7 +22,7 @@ type Storage interface { Init(ctx context.Context, cfg storcfg.StorageConfig) error Type() storenum.StorageType Name() string - Save(ctx context.Context, reader io.Reader, storagePath string) error + Save(ctx context.Context, reader io.Reader, storagePath string) (string, error) Exists(ctx context.Context, storagePath string) bool } diff --git a/storage/telegram/telegram.go b/storage/telegram/telegram.go index 6e94c869..14bd07da 100644 --- a/storage/telegram/telegram.go +++ b/storage/telegram/telegram.go @@ -70,11 +70,12 @@ func (t *Telegram) Exists(ctx context.Context, storagePath string) bool { return false } -func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) error { +func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) (string, error) { storagePath = path.Clean(storagePath) + originalPath := storagePath tctx := tgutil.ExtFromContext(ctx) if tctx == nil { - return fmt.Errorf("failed to get telegram context") + return "", fmt.Errorf("failed to get telegram context") } size := func() int64 { if length := ctx.Value(ctxkey.ContentLength); length != nil { @@ -86,7 +87,7 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er }() if t.config.SkipLarge && size > MaxUploadFileSize { log.FromContext(ctx).Warnf("Skipping file larger than Telegram limit (%d bytes): %d bytes", MaxUploadFileSize, size) - return nil + return originalPath, nil } rs, seekable := r.(io.ReadSeeker) splitSize := t.config.SplitSizeMB * 1024 * 1024 @@ -95,7 +96,7 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er } if err := t.limiter.Wait(ctx); err != nil { - return fmt.Errorf("rate limit failed: %w", err) + return "", fmt.Errorf("rate limit failed: %w", err) } // 去除前导斜杠并分隔路径, 当 len(parts): @@ -122,21 +123,21 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er WithThreads(dlutil.BestThreads(size, config.C().Threads)) peer := tryGetInputPeer(tctx, chatID) if peer == nil || peer.Zero() { - return fmt.Errorf("failed to get input peer for chat ID %d", chatID) + return "", fmt.Errorf("failed to get input peer for chat ID %d", chatID) } var mtype *mimetype.MIME if seekable { var err error mtype, err = mimetype.DetectReader(rs) if err != nil { - return fmt.Errorf("failed to detect mimetype: %w", err) + return "", fmt.Errorf("failed to detect mimetype: %w", err) } if filename == "" { filename = xid.New().String() + mtype.Extension() } if _, err := rs.Seek(0, io.SeekStart); err != nil { - return fmt.Errorf("failed to seek reader: %w", err) + return "", fmt.Errorf("failed to seek reader: %w", err) } } if size > splitSize { @@ -152,7 +153,7 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er file, err = upler.Upload(ctx, uploader.NewUpload(filename, r, size)) } if err != nil { - return fmt.Errorf("failed to upload file to telegram: %w", err) + return "", fmt.Errorf("failed to upload file to telegram: %w", err) } caption := styling.Plain(filename) forceFile := t.config.ForceFile @@ -210,14 +211,17 @@ func (t *Telegram) Save(ctx context.Context, r io.Reader, storagePath string) er } sender := tctx.Sender _, err = sender.WithUploader(upler).To(peer).Media(ctx, media) - return err + if err != nil { + return "", err + } + return originalPath, nil } func (t *Telegram) CannotStream() string { return "Telegram storage must use a ReaderSeeker" } -func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) error { +func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, upler *uploader.Uploader, peer tg.InputPeerClass, fileSize, splitSize int64) (string, error) { tempId := xid.New().String() outputBase := filepath.Join(config.C().Temp.BasePath, tempId, strings.Split(filename, ".")[0]) defer func() { @@ -227,11 +231,11 @@ func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, u } }() if err := CreateSplitZip(ctx, r, fileSize, filename, outputBase, splitSize); err != nil { - return fmt.Errorf("failed to create split zip: %w", err) + return "", fmt.Errorf("failed to create split zip: %w", err) } matched, err := filepath.Glob(outputBase + ".z*") if err != nil { - return fmt.Errorf("failed to glob split files: %w", err) + return "", fmt.Errorf("failed to glob split files: %w", err) } inputFiles := make([]tg.InputFileClass, 0, len(matched)) for _, partPath := range matched { @@ -256,7 +260,7 @@ func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, u return nil }() if err != nil { - return fmt.Errorf("failed to upload split part %s: %w", partPath, err) + return "", fmt.Errorf("failed to upload split part %s: %w", partPath, err) } } if len(inputFiles) == 1 { @@ -270,7 +274,10 @@ func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, u WithUploader(upler). To(peer). Media(ctx, doc) - return err + if err != nil { + return "", err + } + return filename, nil } multiMedia := make([]message.MultiMediaOption, 0, len(inputFiles)) @@ -287,7 +294,10 @@ func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, u _, err = sender.WithUploader(upler). To(peer). Album(ctx, multiMedia[0], multiMedia[1:]...) - return err + if err != nil { + return "", err + } + return filename, nil } // more than 10 parts, send in batches, each batch up to 10 parts @@ -298,9 +308,8 @@ func (t *Telegram) splitUpload(ctx *ext.Context, r io.Reader, filename string, u To(peer). Album(ctx, batch[0], batch[1:]...) if err != nil { - return fmt.Errorf("failed to send album batch: %w", err) + return "", fmt.Errorf("failed to send album batch: %w", err) } } - return nil - + return filename, nil } diff --git a/storage/webdav/webdav.go b/storage/webdav/webdav.go index db4216d9..64510b3f 100644 --- a/storage/webdav/webdav.go +++ b/storage/webdav/webdav.go @@ -51,12 +51,13 @@ func (w *Webdav) JoinStoragePath(p string) string { return path.Join(w.config.BasePath, p) } -func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) error { +func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) (string, error) { w.logger.Infof("Saving file to %s", storagePath) - storagePath = w.JoinStoragePath(storagePath) - ext := path.Ext(storagePath) - base := strings.TrimSuffix(storagePath, ext) - candidate := storagePath + originalPath := storagePath + joinedPath := w.JoinStoragePath(storagePath) + ext := path.Ext(joinedPath) + base := strings.TrimSuffix(joinedPath, ext) + candidate := joinedPath for i := 1; w.Exists(ctx, candidate); i++ { candidate = fmt.Sprintf("%s_%d%s", base, i, ext) if i > 1000 { @@ -68,13 +69,16 @@ func (w *Webdav) Save(ctx context.Context, r io.Reader, storagePath string) erro if err := w.client.MkDir(ctx, path.Dir(candidate)); err != nil { w.logger.Errorf("Failed to create directory %s: %v", path.Dir(candidate), err) - return ErrFailedToCreateDirectory + return "", ErrFailedToCreateDirectory } if err := w.client.WriteFile(ctx, candidate, r); err != nil { w.logger.Errorf("Failed to write file %s: %v", candidate, err) - return ErrFailedToWriteFile + return "", ErrFailedToWriteFile } - return nil + if candidate != joinedPath { + return path.Join(path.Dir(originalPath), path.Base(candidate)), nil + } + return originalPath, nil } func (w *Webdav) Exists(ctx context.Context, storagePath string) bool {