Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (f *TaskFactory) createTGFilesTask(taskID string, createdAt time.Time, req
// 批量文件任务
elems := make([]batchtfile.TaskElement, 0, len(files))
for _, file := range files {
elem, err := batchtfile.NewTaskElement(stor, req.Path, file)
elem, err := batchtfile.NewTaskElement(f.ctx, stor, req.Path, file)
if err != nil {
return nil, fmt.Errorf("failed to create task element: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions client/bot/handlers/utils/shortcut/tftask.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
}
if !dirPath.NeedNewForAlbum() {
storPath := path.Join(dirPath.String(), file.Name())
elem, err := batchtfile.NewTaskElement(fileStor, storPath, file)
elem, err := batchtfile.NewTaskElement(ctx, fileStor, storPath, file)
if err != nil {
logger.Errorf("Failed to create task element: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
Expand Down Expand Up @@ -189,7 +189,7 @@ func CreateAndAddBatchTGFileTaskWithEdit(ctx *ext.Context, userID int64, stor st
albumStor := afiles[0].storage
for _, af := range afiles {
afstorPath := path.Join(dirPath, albumDir, af.file.Name())
elem, err := batchtfile.NewTaskElement(albumStor, afstorPath, af.file)
elem, err := batchtfile.NewTaskElement(ctx, albumStor, afstorPath, af.file)
if err != nil {
logger.Errorf("Failed to create task element for album file: %s", err)
ctx.EditMessage(userID, &tg.MessagesEditMessageRequest{
Expand Down
2 changes: 1 addition & 1 deletion cmd/upload/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func Upload(cmd *cobra.Command, args []string) error {
reader = file
}

if err := stor.Save(ctx, reader, uploadPath); err != nil {
if _, err := stor.Save(ctx, reader, uploadPath); err != nil {
if progressUI != nil {
progressUI.SetError(err)
progressUI.Wait()
Expand Down
131 changes: 131 additions & 0 deletions common/utils/tgutil/chatinfo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package tgutil

import (
"github.com/celestix/gotgproto/ext"
"github.com/charmbracelet/log"
"github.com/gotd/td/tg"
)

// ChatInfoFromExt extracts chat title and username for the given peer.
func ChatInfoFromExt(extCtx *ext.Context, peer tg.PeerClass) (title, username string) {
if extCtx == nil {
return
}
chatID := ChatIdFromPeer(peer)
if chatID == 0 {
return
}

if extCtx.Entities != nil {
if t, u, found := lookupEntities(extCtx.Entities, chatID); found {
return t, u
}
}

return fetchFromAPI(extCtx, peer, chatID)
}
Comment on lines +19 to +26

func formatUserName(firstName, lastName string) string {
if lastName != "" {
return firstName + " " + lastName
}
return firstName
}

func lookupEntities(entities *tg.Entities, chatID int64) (title, username string, found bool) {
if ch, ok := entities.Channels[chatID]; ok {
return ch.Title, ch.Username, true
}
if ch, ok := entities.Chats[chatID]; ok {
return ch.Title, "", true // tg.Chat has no Username field
}
if u, ok := entities.Users[chatID]; ok {
return formatUserName(u.FirstName, u.LastName), u.Username, true
}
return "", "", false
}

func fetchFromAPI(extCtx *ext.Context, peer tg.PeerClass, chatID int64) (title, username string) {
if extCtx.Raw == nil {
return
}

var err error
switch peer.(type) {
case *tg.PeerChannel:
title, username, err = fetchChannel(extCtx, chatID)
case *tg.PeerChat:
title, username, err = fetchChat(extCtx, chatID)
case *tg.PeerUser:
title, username, err = fetchUser(extCtx, chatID)
}
if err != nil {
log.Debug("Failed to fetch chat info from API", "chatID", chatID, "error", err)
}
return
}

func resolveInputPeer(extCtx *ext.Context, chatID int64) (tg.InputPeerClass, error) {
return extCtx.ResolveInputPeerById(chatID)
}

func fetchChannel(extCtx *ext.Context, chatID int64) (string, string, error) {
inputPeer, err := resolveInputPeer(extCtx, chatID)
if err != nil {
return "", "", err
}
ch, ok := inputPeer.(*tg.InputPeerChannel)
if !ok {
return "", "", nil
}

result, err := extCtx.Raw.ChannelsGetChannels(extCtx, []tg.InputChannelClass{
&tg.InputChannel{ChannelID: chatID, AccessHash: ch.AccessHash},
})
if err != nil {
return "", "", err
}
for _, c := range result.GetChats() {
if channel, ok := c.(*tg.Channel); ok && channel.ID == chatID {
return channel.Title, channel.Username, nil
}
}
return "", "", nil
}

func fetchChat(extCtx *ext.Context, chatID int64) (string, string, error) {
result, err := extCtx.Raw.MessagesGetFullChat(extCtx, chatID)
if err != nil {
return "", "", err
}
for _, c := range result.GetChats() {
if chat, ok := c.(*tg.Chat); ok && chat.ID == chatID {
return chat.Title, "", nil
}
}
return "", "", nil
}

func fetchUser(extCtx *ext.Context, chatID int64) (string, string, error) {
inputPeer, err := resolveInputPeer(extCtx, chatID)
if err != nil {
return "", "", err
}
u, ok := inputPeer.(*tg.InputPeerUser)
if !ok {
return "", "", nil
}

users, err := extCtx.Raw.UsersGetUsers(extCtx, []tg.InputUserClass{
&tg.InputUser{UserID: chatID, AccessHash: u.AccessHash},
})
if err != nil {
return "", "", err
}
for _, user := range users {
if u, ok := user.(*tg.User); ok && u.ID == chatID {
return formatUserName(u.FirstName, u.LastName), u.Username, nil
}
}
return "", "", nil
}
Comment on lines +72 to +131
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is recommended to add an appropriate caching mechanism to the fetchXXX methods, as they are invoked within synchronous interaction flows.

1 change: 1 addition & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ workers = 4 # 同时下载文件数
retry = 3 # 下载失败重试次数
threads = 4 # 单个任务下载使用的最大线程数
stream = false # 使用流式传输模式, 建议仅在硬盘空间十分有限时使用.
save_metadata = false # 保存文件时同时生成 .meta.json 元数据侧车文件

[log]
# 日志级别, 可选: debug, info, warn, error, fatal
Expand Down
2 changes: 2 additions & 0 deletions config/viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Config struct {
NoCleanCache bool `toml:"no_clean_cache" mapstructure:"no_clean_cache" json:"no_clean_cache"`
Threads int `toml:"threads" mapstructure:"threads" json:"threads"`
Stream bool `toml:"stream" mapstructure:"stream" json:"stream"`
SaveMetadata bool `toml:"save_metadata" mapstructure:"save_metadata" json:"save_metadata"`
Proxy string `toml:"proxy" mapstructure:"proxy" json:"proxy"`
Log logConfig `toml:"log" mapstructure:"log" json:"log"`
Aria2 aria2Config `toml:"aria2" mapstructure:"aria2" json:"aria2"`
Expand Down Expand Up @@ -105,6 +106,7 @@ func Init(ctx context.Context, configFile ...string) error {
"workers": 3,
"retry": 3,
"threads": 4,
"save_metadata": false,
"log.level": "debug",

// 缓存配置
Expand Down
2 changes: 1 addition & 1 deletion core/tasks/aria2dl/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (t *Task) transferFile(ctx context.Context, filePath string) error {

logger.Infof("Transferring file %s to %s:%s", fileName, t.Storage.Name(), destPath)

if err := t.Storage.Save(ctx, f, destPath); err != nil {
if _, err := t.Storage.Save(ctx, f, destPath); err != nil {
return fmt.Errorf("failed to save file %s to storage: %w", fileName, err)
}

Expand Down
4 changes: 2 additions & 2 deletions core/tasks/aria2dl/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ func (m *mockStorage) Init(ctx context.Context, config storconfig.StorageConfig)
return nil
}

func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) error {
func (m *mockStorage) Save(ctx context.Context, reader io.Reader, path string) (string, error) {
m.savePath = path
return nil
return path, nil
}

func (m *mockStorage) Exists(ctx context.Context, path string) bool {
Expand Down
19 changes: 16 additions & 3 deletions core/tasks/batchtfile/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
if elem.stream {
pr, pw := io.Pipe()
defer pr.Close()
var actualPath string
errg, uploadCtx := errgroup.WithContext(ctx)
errg.Go(func() error {
return elem.Storage.Save(uploadCtx, pr, elem.Path)
var err error
actualPath, err = elem.Storage.Save(uploadCtx, pr, elem.Path)
return err
})
wr := ioutil.NewProgressWriter(pw, func(n int) {
t.downloaded.Add(int64(n))
Expand All @@ -79,6 +82,9 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
return fmt.Errorf("failed to download file in stream mode: %w", err)
}
logger.Info("File downloaded successfully in stream mode")
if err := elem.saveMetadata(ctx, actualPath); err != nil {
logger.Warnf("failed to save metadata: %s", err)
}
return nil
}
logger.Info("Starting file download")
Expand Down Expand Up @@ -112,18 +118,25 @@ func (t *Task) processElement(ctx context.Context, elem TaskElement) error {
return fmt.Errorf("failed to get file stat: %w", err)
}
vctx := context.WithValue(ctx, ctxkey.ContentLength, fileStat.Size())
var actualPath string
err = retry.Retry(func() error {
var file *os.File
file, err = os.Open(elem.localPath)
if err != nil {
return fmt.Errorf("failed to open cache file: %w", err)
}
defer file.Close()
if err = elem.Storage.Save(vctx, file, elem.Path); err != nil {
if actualPath, err = elem.Storage.Save(vctx, file, elem.Path); err != nil {
logger.Errorf("Failed to save file: %s, retrying...", err)
return err
}
return nil
}, retry.Context(vctx), retry.RetryTimes(uint(config.C().Retry)))
return err
if err != nil {
return err
}
if err := elem.saveMetadata(ctx, actualPath); err != nil {
logger.Warnf("failed to save metadata: %s", err)
}
return nil
}
38 changes: 33 additions & 5 deletions core/tasks/batchtfile/task.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package batchtfile

import (
"bytes"
"context"
"fmt"
"path/filepath"
"sync"
"sync/atomic"

"github.com/charmbracelet/log"
"github.com/krau/SaveAny-Bot/pkg/metadata"
"github.com/krau/SaveAny-Bot/config"
"github.com/krau/SaveAny-Bot/core"
"github.com/krau/SaveAny-Bot/pkg/enums/tasktype"
Expand All @@ -24,6 +27,15 @@ type TaskElement struct {
File tfile.TGFile
localPath string
stream bool
metadata []byte
}

func (e *TaskElement) saveMetadata(ctx context.Context, actualPath string) error {
if len(e.metadata) == 0 {
return nil
}
_, err := e.Storage.Save(ctx, bytes.NewReader(e.metadata), actualPath+metadata.MetaSuffix)
return err
}

type Task struct {
Expand All @@ -49,11 +61,25 @@ func (t *Task) Type() tasktype.TaskType {
}

func NewTaskElement(
ctx context.Context,
stor storage.Storage,
path string,
file tfile.TGFile,
) (*TaskElement, error) {
id := xid.New().String()

var meta []byte
if config.C().SaveMetadata {
if fmsg, ok := file.(tfile.TGFileMessage); ok {
m := metadata.BuildFromMessage(ctx, fmsg.Message(), file.Name(), file.Size())
var err error
meta, err = m.ToJSON()
if err != nil {
log.FromContext(ctx).Warnf("failed to marshal metadata: %s", err)
}
}
}

_, ok := stor.(storage.StorageCannotStream)
if !config.C().Stream || ok {
cachePath, err := filepath.Abs(filepath.Join(config.C().Temp.BasePath, fmt.Sprintf("%s_%s", id, file.Name())))
Expand All @@ -66,14 +92,16 @@ func NewTaskElement(
Path: path,
File: file,
localPath: cachePath,
metadata: meta,
}, nil
}
return &TaskElement{
ID: id,
Storage: stor,
Path: path,
File: file,
stream: true,
ID: id,
Storage: stor,
Path: path,
File: file,
stream: true,
metadata: meta,
}, nil
}

Expand Down
6 changes: 4 additions & 2 deletions core/tasks/directlinks/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
}
ctx = context.WithValue(ctx, ctxkey.ContentLength, file.Size)
if t.stream {
return t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name))
_, err := t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name))
return err
}
cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath,
fmt.Sprintf("direct_%s_%s", t.ID, file.Name)))
Expand Down Expand Up @@ -166,7 +167,8 @@ func (t *Task) processLink(ctx context.Context, file *File) error {
if err != nil {
return fmt.Errorf("failed to seek cache file for resource %s: %w", file.URL, err)
}
return t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name))
_, err = t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name))
return err
}, retry.RetryTimes(uint(config.C().Retry)), retry.Context(ctx))
if ctx.Err() != nil {
return ctx.Err()
Expand Down
6 changes: 4 additions & 2 deletions core/tasks/parsed/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
return resp.ContentLength
}())
if t.stream {
return t.Stor.Save(ctx, resp.Body, path.Join(t.StorPath, resource.Filename))
_, err := t.Stor.Save(ctx, resp.Body, path.Join(t.StorPath, resource.Filename))
return err
}
cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath,
fmt.Sprintf("resource_%s_%s", t.ID, resource.Filename)))
Expand Down Expand Up @@ -130,7 +131,8 @@ func (t *Task) processResource(ctx context.Context, resource parser.Resource) er
if err != nil {
return fmt.Errorf("failed to seek cache file for resource %s: %w", resource.URL, err)
}
return t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, resource.Filename))
_, err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, resource.Filename))
return err
}, retry.Context(ctx), retry.RetryTimes(uint(config.C().Retry)))
if ctx.Err() != nil {
return ctx.Err()
Expand Down
4 changes: 2 additions & 2 deletions core/tasks/telegraph/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ func (t *Task) processPic(ctx context.Context, picUrl string, index int) error {
if err != nil {
return fmt.Errorf("failed to seek cache file for picture %s: %w", filename, err)
}
err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename))
_, err = t.Stor.Save(ctx, cacheFile, path.Join(t.StorPath, filename))
if err != nil {
return fmt.Errorf("failed to save picture %s: %w", filename, err)
}
} else {
err = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename))
_, err = t.Stor.Save(ctx, body, path.Join(t.StorPath, filename))
}

if err != nil {
Expand Down
Loading