diff --git a/core/storage.py b/core/storage.py index 0e820bbc..4c6b7334 100644 --- a/core/storage.py +++ b/core/storage.py @@ -23,7 +23,7 @@ from core.settings import data_root, settings from apps.base.models import FileCodes, UploadChunk from core.utils import get_file_url, sanitize_filename -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, StreamingResponse class FileStorageInterface: @@ -144,10 +144,20 @@ async def get_file_response(self, file_code: FileCodes): filename = f"{file_code.prefix}{file_code.suffix}" encoded_filename = quote(filename, safe='') content_disposition = f"attachment; filename*=UTF-8''{encoded_filename}" + + # 尝试获取文件系统大小,如果成功则设置 Content-Length + headers = {"Content-Disposition": content_disposition} + try: + content_length = file_path.stat().st_size + headers["Content-Length"] = str(content_length) + except Exception: + # 如果获取文件大小失败,则不提供 Content-Length + pass + return FileResponse( file_path, media_type="application/octet-stream", - headers={"Content-Disposition": content_disposition}, + headers=headers, filename=filename # 保留原始文件名以备某些场景使用 ) @@ -296,12 +306,29 @@ async def delete_file(self, file_code: FileCodes): async def get_file_response(self, file_code: FileCodes): try: filename = file_code.prefix + file_code.suffix + content_length = None # 初始化为 None,表示未知大小 + async with self.session.client( "s3", endpoint_url=self.endpoint_url, region_name=self.region_name, config=Config(signature_version=self.signature_version), ) as s3: + # 尝试获取文件大小(HEAD请求) + try: + head_response = await s3.head_object( + Bucket=self.bucket_name, + Key=await file_code.get_file_path() + ) + # 从HEAD响应中获取Content-Length + if 'ContentLength' in head_response: + content_length = head_response['ContentLength'] + elif 'Content-Length' in head_response['ResponseMetadata']['HTTPHeaders']: + content_length = int(head_response['ResponseMetadata']['HTTPHeaders']['Content-Length']) + except Exception: + # 如果HEAD请求失败,则不提供 Content-Length + pass + link = await s3.generate_presigned_url( "get_object", Params={ @@ -310,20 +337,42 @@ async def get_file_response(self, file_code: FileCodes): }, ExpiresIn=3600, ) - tmp = io.BytesIO() - async with aiohttp.ClientSession() as session: - async with session.get(link) as resp: - tmp.write(await resp.read()) - tmp.seek(0) - content = tmp.read() - tmp.close() - return Response( - content, + + # 创建ClientSession并传递给生成器复用 + session = aiohttp.ClientSession() + + async def stream_generator(): + try: + async with session.get(link) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, + detail=f"从S3获取文件失败: {resp.status}" + ) + # 设置块大小(例如64KB) + chunk_size = 65536 + while True: + chunk = await resp.content.read(chunk_size) + if not chunk: + break + yield chunk + finally: + await session.close() + + from fastapi.responses import StreamingResponse + encoded_filename = quote(filename, safe='') + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if content_length is not None: + headers["Content-Length"] = str(content_length) + return StreamingResponse( + stream_generator(), media_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"' - }, + headers=headers ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=503, detail="服务代理下载异常,请稍后再试") @@ -602,20 +651,51 @@ async def get_file_response(self, file_code: FileCodes): link = await asyncio.to_thread( self._get_file_url, await file_code.get_file_path(), filename ) - tmp = io.BytesIO() - async with aiohttp.ClientSession() as session: - async with session.get(link) as resp: - tmp.write(await resp.read()) - tmp.seek(0) - content = tmp.read() - tmp.close() - return Response( - content, + + content_length = None # 初始化为 None,表示未知大小 + + # 创建ClientSession并复用 + session = aiohttp.ClientSession() + + # 尝试发送HEAD请求获取Content-Length + try: + async with session.head(link) as resp: + if resp.status == 200 and 'Content-Length' in resp.headers: + content_length = int(resp.headers['Content-Length']) + except Exception: + # 如果HEAD请求失败,则不提供 Content-Length + pass + + async def stream_generator(): + try: + async with session.get(link) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, + detail=f"从OneDrive获取文件失败: {resp.status}" + ) + chunk_size = 65536 + while True: + chunk = await resp.content.read(chunk_size) + if not chunk: + break + yield chunk + finally: + await session.close() + + encoded_filename = quote(filename, safe='') + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if content_length is not None: + headers["Content-Length"] = str(content_length) + return StreamingResponse( + stream_generator(), media_type="application/octet-stream", - headers={ - "Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"' - }, + headers=headers ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=503, detail="服务代理下载异常,请稍后再试") @@ -776,11 +856,54 @@ async def get_file_url(self, file_code: FileCodes): async def get_file_response(self, file_code: FileCodes): try: filename = file_code.prefix + file_code.suffix - content = await self.operator.read(await file_code.get_file_path()) + content_length = None # 初始化为 None,表示未知大小 + + # 尝试获取文件大小 + try: + stat_result = await self.operator.stat(await file_code.get_file_path()) + if hasattr(stat_result, 'content_length') and stat_result.content_length: + content_length = stat_result.content_length + elif hasattr(stat_result, 'size') and stat_result.size: + content_length = stat_result.size + except Exception: + # 如果获取大小失败,则不提供 Content-Length + pass + + # 尝试使用流式读取器 + try: + # OpenDAL 可能提供 reader 方法返回一个异步读取器 + reader = await self.operator.reader(await file_code.get_file_path()) + except AttributeError: + # 如果 reader 方法不存在,回退到全量读取(兼容旧版本) + content = await self.operator.read(await file_code.get_file_path()) + encoded_filename = quote(filename, safe='') + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if content_length is not None: + headers["Content-Length"] = str(content_length) + return Response( + content, headers=headers, media_type="application/octet-stream" + ) + + async def stream_generator(): + chunk_size = 65536 + while True: + chunk = await reader.read(chunk_size) + if not chunk: + break + yield chunk + + encoded_filename = quote(filename, safe='') headers = { - "Content-Disposition": f'attachment; filename="{filename}"'} - return Response( - content, headers=headers, media_type="application/octet-stream" + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if content_length is not None: + headers["Content-Length"] = str(content_length) + return StreamingResponse( + stream_generator(), + media_type="application/octet-stream", + headers=headers ) except Exception as e: logger.info(e) @@ -969,26 +1092,50 @@ async def get_file_response(self, file_code: FileCodes): try: filename = file_code.prefix + file_code.suffix url = self._build_url(await file_code.get_file_path()) - async with aiohttp.ClientSession(headers={ + content_length = None # 初始化为 None,表示未知大小 + + # 创建ClientSession并复用(包含认证头) + session = aiohttp.ClientSession(headers={ "Authorization": f"Basic {base64.b64encode(f'{settings.webdav_username}:{settings.webdav_password}'.encode()).decode()}" - }) as session: - async with session.get(url) as resp: - if resp.status != 200: - raise HTTPException( - status_code=resp.status, - detail=f"文件获取失败{resp.status}: {await resp.text()}", - ) - # 读取内容到内存 - content = await resp.read() - return Response( - content=content, - media_type=resp.headers.get( - "Content-Type", "application/octet-stream" - ), - headers={ - "Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode()}"' - }, - ) + }) + + # 尝试发送HEAD请求获取Content-Length + try: + async with session.head(url) as resp: + if resp.status == 200 and 'Content-Length' in resp.headers: + content_length = int(resp.headers['Content-Length']) + except Exception: + # 如果HEAD请求失败,则不提供 Content-Length + pass + + async def stream_generator(): + try: + async with session.get(url) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, + detail=f"文件获取失败{resp.status}: {await resp.text()}", + ) + chunk_size = 65536 + while True: + chunk = await resp.content.read(chunk_size) + if not chunk: + break + yield chunk + finally: + await session.close() + + encoded_filename = quote(filename, safe='') + headers = { + "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}" + } + if content_length is not None: + headers["Content-Length"] = str(content_length) + return StreamingResponse( + stream_generator(), + media_type="application/octet-stream", + headers=headers + ) except aiohttp.ClientError as e: raise HTTPException( status_code=503, detail=f"WebDAV连接异常: {str(e)}")