diff --git a/shortcuts/im/helpers_network_test.go b/shortcuts/im/helpers_network_test.go index a068c5b74..b06c43cd0 100644 --- a/shortcuts/im/helpers_network_test.go +++ b/shortcuts/im/helpers_network_test.go @@ -573,6 +573,113 @@ func TestDownloadIMResourceToPathInvalidContentRange(t *testing.T) { } } +func TestDownloadIMResourceToPathInitialRangeMustMatchRequest(t *testing.T) { + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case strings.Contains(req.URL.Path, "tenant_access_token"): + return shortcutJSONResponse(200, map[string]interface{}{ + "code": 0, + "tenant_access_token": "tenant-token", + "expire": 7200, + }), nil + case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_initial/resources/file_initial"): + if got := req.Header.Get("Range"); got != "bytes=0-131071" { + return nil, fmt.Errorf("Range = %q, want bytes=0-131071", got) + } + return shortcutRawResponse(http.StatusPartialContent, []byte("bad"), http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Range": []string{"bytes 0-2/131082"}, + }), nil + default: + return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) + } + })) + + cmdutil.TestChdir(t, t.TempDir()) + _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_initial", "file_initial", "file", "out.bin", true) + if err == nil || !strings.Contains(err.Error(), "unexpected initial Content-Range") { + t.Fatalf("downloadIMResourceToPath() error = %v, want unexpected initial Content-Range", err) + } +} + +func TestDownloadIMResourceRangeChunksValidateContentRange(t *testing.T) { + tests := []struct { + name string + secondCR string + wantErrSub string + }{ + {name: "matching range succeeds", secondCR: "bytes 131072-131081/131082"}, + {name: "wrong start fails", secondCR: "bytes 0-9/131082", wantErrSub: "unexpected Content-Range"}, + {name: "wrong end fails", secondCR: "bytes 131072-131080/131082", wantErrSub: "unexpected Content-Range"}, + {name: "wrong total fails", secondCR: "bytes 131072-131081/999999", wantErrSub: "unexpected Content-Range"}, + {name: "malformed range fails", secondCR: "bytes 131072-131081/*", wantErrSub: "invalid Content-Range header on range response"}, + {name: "missing range fails", secondCR: "", wantErrSub: "invalid Content-Range header on range response"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var requestCount int + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case strings.Contains(req.URL.Path, "tenant_access_token"): + return shortcutJSONResponse(200, map[string]interface{}{ + "code": 0, + "tenant_access_token": "tenant-token", + "expire": 7200, + }), nil + case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_123/resources/file_123"): + requestCount++ + switch requestCount { + case 1: + if got := req.Header.Get("Range"); got != "bytes=0-131071" { + return nil, fmt.Errorf("first Range = %q, want bytes=0-131071", got) + } + return imResourceRangeResponse(http.StatusPartialContent, "bytes 0-131071/131082", strings.Repeat("a", int(probeChunkSize))), nil + case 2: + if got := req.Header.Get("Range"); got != "bytes=131072-131081" { + return nil, fmt.Errorf("second Range = %q, want bytes=131072-131081", got) + } + return imResourceRangeResponse(http.StatusPartialContent, tt.secondCR, "bbbbbbbbbb"), nil + default: + return nil, fmt.Errorf("unexpected resource request %d", requestCount) + } + default: + return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) + } + })) + + cmdutil.TestChdir(t, t.TempDir()) + _, size, err := downloadIMResourceToPath(context.Background(), runtime, "om_123", "file_123", "file", "out.bin", true) + if tt.wantErrSub != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErrSub) { + t.Fatalf("downloadIMResourceToPath() error = %v, want substring %q", err, tt.wantErrSub) + } + return + } + if err != nil { + t.Fatalf("downloadIMResourceToPath() unexpected error = %v", err) + } + if size != 131082 { + t.Fatalf("downloadIMResourceToPath() size = %d, want 131082", size) + } + }) + } +} + +func imResourceRangeResponse(status int, contentRange, body string) *http.Response { + resp := &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + } + if contentRange != "" { + resp.Header.Set("Content-Range", contentRange) + } + resp.Header.Set("Content-Type", "application/octet-stream") + return resp +} + func TestDownloadIMResourceToPathRangeChunkFailureCleansOutput(t *testing.T) { payload := bytes.Repeat([]byte("range-download-"), int((probeChunkSize+1024)/15)+1) payload = payload[:probeChunkSize+1024] @@ -605,6 +712,38 @@ func TestDownloadIMResourceToPathRangeChunkFailureCleansOutput(t *testing.T) { } } +func TestDownloadIMResourceToPathRangeChunkLengthMismatch(t *testing.T) { + totalSize := probeChunkSize + 20 + runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { + switch { + case strings.Contains(req.URL.Path, "/open-apis/im/v1/messages/om_len/resources/file_len"): + switch req.Header.Get("Range") { + case fmt.Sprintf("bytes=0-%d", probeChunkSize-1): + return shortcutRawResponse(http.StatusPartialContent, []byte(strings.Repeat("a", int(probeChunkSize)+1)), http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Range": []string{fmt.Sprintf("bytes 0-%d/%d", probeChunkSize-1, totalSize)}, + }), nil + case fmt.Sprintf("bytes=%d-%d", probeChunkSize, totalSize-1): + return shortcutRawResponse(http.StatusPartialContent, []byte(strings.Repeat("b", 19)), http.Header{ + "Content-Type": []string{"application/octet-stream"}, + "Content-Range": []string{fmt.Sprintf("bytes %d-%d/%d", probeChunkSize, totalSize-1, totalSize)}, + }), nil + default: + return nil, fmt.Errorf("unexpected range: %s", req.Header.Get("Range")) + } + default: + return nil, fmt.Errorf("unexpected request: %s", req.URL.String()) + } + })) + + cmdutil.TestChdir(t, t.TempDir()) + + _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_len", "file_len", "file", "out.bin", true) + if err == nil || !strings.Contains(err.Error(), "chunk size mismatch") { + t.Fatalf("downloadIMResourceToPath() error = %v, want chunk size mismatch", err) + } +} + func TestDownloadIMResourceToPathRangeOverflowCleansOutput(t *testing.T) { payload := []byte("overflow-payload") runtime := newBotShortcutRuntime(t, shortcutRoundTripFunc(func(req *http.Request) (*http.Response, error) { @@ -659,7 +798,7 @@ func TestDownloadIMResourceToPathRangeShortChunkSizeMismatch(t *testing.T) { cmdutil.TestChdir(t, t.TempDir()) _, _, err := downloadIMResourceToPath(context.Background(), runtime, "om_short", "file_short", "file", "out.bin", true) - if err == nil || !strings.Contains(err.Error(), "file size mismatch") { + if err == nil || !strings.Contains(err.Error(), "chunk size mismatch") { t.Fatalf("downloadIMResourceToPath() error = %v", err) } } diff --git a/shortcuts/im/helpers_test.go b/shortcuts/im/helpers_test.go index 114ae524f..5d349138f 100644 --- a/shortcuts/im/helpers_test.go +++ b/shortcuts/im/helpers_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net/http" + "path/filepath" "reflect" "strings" "testing" @@ -563,6 +564,7 @@ func TestNormalizeDownloadOutputPath(t *testing.T) { {name: "empty key", fileKey: " ", wantErr: "file-key cannot be empty"}, {name: "separator in key", fileKey: "dir/file", wantErr: "file-key cannot contain path separators"}, {name: "absolute path", fileKey: "file_123", outputPath: "/tmp/out.bin", wantErr: "absolute paths are not allowed"}, + {name: "windows rooted path", fileKey: "file_123", outputPath: `\tmp\out.bin`, wantErr: "absolute paths are not allowed"}, {name: "parent escape", fileKey: "file_123", outputPath: "../out.bin", wantErr: "path cannot escape the current working directory"}, {name: "empty path after clean", fileKey: "file_123", outputPath: " . ", wantErr: "path cannot be empty"}, } @@ -599,39 +601,43 @@ func TestDownloadIMResourceToPathHTTPClientError(t *testing.T) { } } -func TestParseTotalSize(t *testing.T) { +func TestParseContentRange(t *testing.T) { tests := []struct { name string contentRange string - want int64 + want contentRange wantErr string }{ - {name: "normal", contentRange: "bytes 0-131071/104857600", want: 104857600}, - {name: "single probe chunk", contentRange: "bytes 0-131071/131072", want: 131072}, - {name: "single small chunk", contentRange: "bytes 0-15/16", want: 16}, + {name: "normal", contentRange: "bytes 0-131071/104857600", want: contentRange{start: 0, end: 131071, total: 104857600}}, + {name: "single small chunk", contentRange: "bytes 0-15/16", want: contentRange{start: 0, end: 15, total: 16}}, {name: "empty", contentRange: "", wantErr: "content-range is empty"}, {name: "invalid prefix", contentRange: "items 0-15/16", wantErr: `unsupported content-range: "items 0-15/16"`}, - {name: "missing total", contentRange: "bytes 0-15/", wantErr: `unsupported content-range: "bytes 0-15/"`}, - {name: "wildcard", contentRange: "bytes */16", wantErr: `unsupported content-range: "bytes */16"`}, + {name: "missing slash", contentRange: "bytes 0-15", wantErr: `unsupported content-range: "bytes 0-15"`}, + {name: "missing range end", contentRange: "bytes 0-/16", wantErr: `unsupported content-range: "bytes 0-/16"`}, + {name: "wildcard range", contentRange: "bytes */16", wantErr: `unsupported content-range: "bytes */16"`}, {name: "unknown total size", contentRange: "bytes 0-99/*", wantErr: `unknown total size in content-range: "bytes 0-99/*"`}, + {name: "invalid start", contentRange: "bytes nope-15/16", wantErr: "parse range start:"}, + {name: "invalid end", contentRange: "bytes 0-nope/16", wantErr: "parse range end:"}, {name: "invalid total", contentRange: "bytes 0-15/not-a-number", wantErr: "parse total size:"}, + {name: "start after end", contentRange: "bytes 16-15/32", wantErr: "invalid content range: start 16 is after end 15"}, {name: "zero total size", contentRange: "bytes 0-0/0", wantErr: "invalid total size: 0"}, + {name: "end reaches total", contentRange: "bytes 0-16/16", wantErr: "invalid content range: end 16 is outside total 16"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := parseTotalSize(tt.contentRange) + got, err := parseContentRange(tt.contentRange) if tt.wantErr != "" { if err == nil || !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("parseTotalSize() error = %v, want substring %q", err, tt.wantErr) + t.Fatalf("parseContentRange() error = %v, want substring %q", err, tt.wantErr) } return } if err != nil { - t.Fatalf("parseTotalSize() unexpected error = %v", err) + t.Fatalf("parseContentRange() unexpected error = %v", err) } if got != tt.want { - t.Fatalf("parseTotalSize() = %d, want %d", got, tt.want) + t.Fatalf("parseContentRange() = %+v, want %+v", got, tt.want) } }) } @@ -682,7 +688,7 @@ func TestResolveIMResourceDownloadPath(t *testing.T) { {name: "default path, CD RFC5987", safePath: "file_xxx", contentDisposition: `attachment; filename*=UTF-8''%E5%AD%A3%E5%BA%A6%E6%8A%A5%E5%91%8A.xlsx`, want: "季度报告.xlsx"}, {name: "default path, no CD, MIME ext", safePath: "file_xxx", contentType: "application/pdf", want: "file_xxx.pdf"}, {name: "default path, no CD, unknown MIME", safePath: "file_xxx", contentType: "application/x-unknown", want: "file_xxx"}, - {name: "default path, CD with dir component", safePath: "downloads/file_xxx", contentDisposition: `attachment; filename="report.xlsx"`, want: "downloads/report.xlsx"}, + {name: "default path, CD with dir component", safePath: "downloads/file_xxx", contentDisposition: `attachment; filename="report.xlsx"`, want: filepath.Join("downloads", "report.xlsx")}, // User --output without extension: use CD filename's extension {name: "user path no ext, CD with ext", safePath: "myfile", contentDisposition: `attachment; filename="server.pdf"`, userSpecifiedOutput: true, want: "myfile.pdf"}, {name: "user path no ext, CD no ext, MIME ext", safePath: "myfile", contentDisposition: `attachment; filename="noext"`, contentType: "image/png", userSpecifiedOutput: true, want: "myfile.png"}, diff --git a/shortcuts/im/im_messages_resources_download.go b/shortcuts/im/im_messages_resources_download.go index abb3c3a54..b0fc09dea 100644 --- a/shortcuts/im/im_messages_resources_download.go +++ b/shortcuts/im/im_messages_resources_download.go @@ -95,13 +95,14 @@ func normalizeDownloadOutputPath(fileKey, outputPath string) (string, error) { if outputPath == "" { return fileKey, nil } - outputPath = filepath.Clean(strings.TrimSpace(outputPath)) + outputPath = strings.TrimSpace(outputPath) + if filepath.IsAbs(outputPath) || strings.HasPrefix(outputPath, "/") || strings.HasPrefix(outputPath, "\\") { + return "", fmt.Errorf("absolute paths are not allowed") + } + outputPath = filepath.Clean(outputPath) if outputPath == "." { return "", fmt.Errorf("path cannot be empty") } - if filepath.IsAbs(outputPath) { - return "", fmt.Errorf("absolute paths are not allowed") - } if outputPath == ".." || strings.HasPrefix(outputPath, ".."+string(filepath.Separator)) { return "", fmt.Errorf("path cannot escape the current working directory") } @@ -156,6 +157,8 @@ type rangeChunkReader struct { totalSize int64 delivered int64 current io.ReadCloser + chunkWant int64 + chunkRead int64 nextOffset int64 } @@ -174,6 +177,7 @@ func newRangeChunkReader( fileType: fileType, totalSize: totalSize, current: probeBody, + chunkWant: min(probeChunkSize, totalSize), nextOffset: probeChunkSize, } } @@ -183,6 +187,7 @@ func (r *rangeChunkReader) Read(p []byte) (int, error) { if r.current != nil { n, err := r.current.Read(p) r.delivered += int64(n) + r.chunkRead += int64(n) if r.delivered > r.totalSize { if err == io.EOF { @@ -194,11 +199,24 @@ func (r *rangeChunkReader) Read(p []byte) (int, error) { } return 0, output.ErrNetwork("chunk overflow: delivered %d, expected %d", r.delivered, r.totalSize) } + if r.chunkRead > r.chunkWant { + _ = r.current.Close() + r.current = nil + return 0, output.ErrNetwork("chunk size mismatch: expected %d bytes for current range, got more than %d", r.chunkWant, r.chunkWant) + } switch err { case nil: return n, nil case io.EOF: + if r.chunkRead != r.chunkWant { + closeErr := r.current.Close() + r.current = nil + if closeErr != nil { + return n, closeErr + } + return 0, output.ErrNetwork("chunk size mismatch: expected %d bytes for current range, got %d", r.chunkWant, r.chunkRead) + } closeErr := r.current.Close() r.current = nil if closeErr != nil { @@ -211,8 +229,12 @@ func (r *rangeChunkReader) Read(p []byte) (int, error) { return 0, io.EOF } if n > 0 { + r.chunkRead = 0 + r.chunkWant = 0 return n, nil } + r.chunkRead = 0 + r.chunkWant = 0 default: return n, err } @@ -240,8 +262,18 @@ func (r *rangeChunkReader) Read(p []byte) (int, error) { resp.Body.Close() return 0, output.ErrNetwork("unexpected status code: %d", resp.StatusCode) } + if err := validateContentRange(resp.Header.Get("Content-Range"), contentRange{ + start: r.nextOffset, + end: end, + total: r.totalSize, + }); err != nil { + resp.Body.Close() + return 0, output.ErrNetwork("invalid Content-Range header on range response: %s", err) + } r.current = resp.Body + r.chunkWant = end - r.nextOffset + 1 + r.chunkRead = 0 r.nextOffset = end + 1 } } @@ -286,13 +318,26 @@ func downloadIMResourceToPath(ctx context.Context, runtime *common.RuntimeContex ) switch downloadResp.StatusCode { case http.StatusPartialContent: - totalSize, err := parseTotalSize(downloadResp.Header.Get("Content-Range")) + firstRange, err := parseContentRange(downloadResp.Header.Get("Content-Range")) if err != nil { downloadResp.Body.Close() return "", 0, output.ErrNetwork("invalid Content-Range header on range response: %s", err) } - body = newRangeChunkReader(ctx, runtime, messageID, fileKey, fileType, downloadResp.Body, totalSize) - sizeBytes = totalSize + if firstRange.start != 0 { + downloadResp.Body.Close() + return "", 0, output.ErrNetwork("unexpected initial Content-Range: got %s, want start 0", firstRange) + } + wantFirstRange := contentRange{ + start: 0, + end: min(probeChunkSize-1, firstRange.total-1), + total: firstRange.total, + } + if firstRange != wantFirstRange { + downloadResp.Body.Close() + return "", 0, output.ErrNetwork("unexpected initial Content-Range: got %s, want %s", firstRange, wantFirstRange) + } + body = newRangeChunkReader(ctx, runtime, messageID, fileKey, fileType, downloadResp.Body, firstRange.total) + sizeBytes = firstRange.total case http.StatusOK: body = downloadResp.Body @@ -436,32 +481,72 @@ func downloadResponseError(resp *http.Response) error { return output.ErrNetwork("download failed: HTTP %d", resp.StatusCode) } -func parseTotalSize(contentRange string) (int64, error) { - contentRange = strings.TrimSpace(contentRange) - if contentRange == "" { - return 0, fmt.Errorf("content-range is empty") +type contentRange struct { + start int64 + end int64 + total int64 +} + +func (cr contentRange) String() string { + return fmt.Sprintf("bytes %d-%d/%d", cr.start, cr.end, cr.total) +} + +func validateContentRange(header string, want contentRange) error { + got, err := parseContentRange(header) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("unexpected Content-Range: got %s, want %s", got, want) + } + return nil +} + +func parseContentRange(header string) (contentRange, error) { + header = strings.TrimSpace(header) + if header == "" { + return contentRange{}, fmt.Errorf("content-range is empty") } - if !strings.HasPrefix(contentRange, "bytes ") { - return 0, fmt.Errorf("unsupported content-range: %q", contentRange) + if !strings.HasPrefix(header, "bytes ") { + return contentRange{}, fmt.Errorf("unsupported content-range: %q", header) } - parts := strings.SplitN(strings.TrimPrefix(contentRange, "bytes "), "/", 2) - if len(parts) != 2 || parts[1] == "" { - return 0, fmt.Errorf("unsupported content-range: %q", contentRange) + parts := strings.SplitN(strings.TrimPrefix(header, "bytes "), "/", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return contentRange{}, fmt.Errorf("unsupported content-range: %q", header) } if parts[0] == "*" { - return 0, fmt.Errorf("unsupported content-range: %q", contentRange) + return contentRange{}, fmt.Errorf("unsupported content-range: %q", header) } if parts[1] == "*" { - return 0, fmt.Errorf("unknown total size in content-range: %q", contentRange) + return contentRange{}, fmt.Errorf("unknown total size in content-range: %q", header) + } + + bounds := strings.SplitN(parts[0], "-", 2) + if len(bounds) != 2 || bounds[0] == "" || bounds[1] == "" { + return contentRange{}, fmt.Errorf("unsupported content-range: %q", header) } - totalSize, err := strconv.ParseInt(parts[1], 10, 64) + start, err := strconv.ParseInt(bounds[0], 10, 64) + if err != nil { + return contentRange{}, fmt.Errorf("parse range start: %w", err) + } + end, err := strconv.ParseInt(bounds[1], 10, 64) if err != nil { - return 0, fmt.Errorf("parse total size: %w", err) + return contentRange{}, fmt.Errorf("parse range end: %w", err) + } + total, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return contentRange{}, fmt.Errorf("parse total size: %w", err) + } + if total <= 0 { + return contentRange{}, fmt.Errorf("invalid total size: %d", total) + } + if start > end { + return contentRange{}, fmt.Errorf("invalid content range: start %d is after end %d", start, end) } - if totalSize <= 0 { - return 0, fmt.Errorf("invalid total size: %d", totalSize) + if end >= total { + return contentRange{}, fmt.Errorf("invalid content range: end %d is outside total %d", end, total) } - return totalSize, nil + return contentRange{start: start, end: end, total: total}, nil }