Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 196 additions & 49 deletions core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from core.settings import data_root, settings
from apps.base.models import FileCodes, UploadChunk
from core.utils import get_file_url, sanitize_filename
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, StreamingResponse


class FileStorageInterface:
Expand Down Expand Up @@ -144,10 +144,20 @@ async def get_file_response(self, file_code: FileCodes):
filename = f"{file_code.prefix}{file_code.suffix}"
encoded_filename = quote(filename, safe='')
content_disposition = f"attachment; filename*=UTF-8''{encoded_filename}"

# 尝试获取文件系统大小,如果成功则设置 Content-Length
headers = {"Content-Disposition": content_disposition}
try:
content_length = file_path.stat().st_size
headers["Content-Length"] = str(content_length)
except Exception:
# 如果获取文件大小失败,则不提供 Content-Length
pass

return FileResponse(
file_path,
media_type="application/octet-stream",
headers={"Content-Disposition": content_disposition},
headers=headers,
filename=filename # 保留原始文件名以备某些场景使用
)

Expand Down Expand Up @@ -296,12 +306,29 @@ async def delete_file(self, file_code: FileCodes):
async def get_file_response(self, file_code: FileCodes):
try:
filename = file_code.prefix + file_code.suffix
content_length = None # 初始化为 None,表示未知大小

async with self.session.client(
"s3",
endpoint_url=self.endpoint_url,
region_name=self.region_name,
config=Config(signature_version=self.signature_version),
) as s3:
# 尝试获取文件大小(HEAD请求)
try:
head_response = await s3.head_object(
Bucket=self.bucket_name,
Key=await file_code.get_file_path()
)
# 从HEAD响应中获取Content-Length
if 'ContentLength' in head_response:
content_length = head_response['ContentLength']
elif 'Content-Length' in head_response['ResponseMetadata']['HTTPHeaders']:
content_length = int(head_response['ResponseMetadata']['HTTPHeaders']['Content-Length'])
except Exception:
# 如果HEAD请求失败,则不提供 Content-Length
pass

link = await s3.generate_presigned_url(
"get_object",
Params={
Expand All @@ -310,20 +337,42 @@ async def get_file_response(self, file_code: FileCodes):
},
ExpiresIn=3600,
)
tmp = io.BytesIO()
async with aiohttp.ClientSession() as session:
async with session.get(link) as resp:
tmp.write(await resp.read())
tmp.seek(0)
content = tmp.read()
tmp.close()
return Response(
content,

# 创建ClientSession并传递给生成器复用
session = aiohttp.ClientSession()

async def stream_generator():
try:
async with session.get(link) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status,
detail=f"从S3获取文件失败: {resp.status}"
)
# 设置块大小(例如64KB)
chunk_size = 65536
while True:
chunk = await resp.content.read(chunk_size)
if not chunk:
break
yield chunk
finally:
await session.close()

from fastapi.responses import StreamingResponse
encoded_filename = quote(filename, safe='')
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if content_length is not None:
headers["Content-Length"] = str(content_length)
return StreamingResponse(
stream_generator(),
media_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"'
},
headers=headers
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=503, detail="服务代理下载异常,请稍后再试")

Expand Down Expand Up @@ -602,20 +651,51 @@ async def get_file_response(self, file_code: FileCodes):
link = await asyncio.to_thread(
self._get_file_url, await file_code.get_file_path(), filename
)
tmp = io.BytesIO()
async with aiohttp.ClientSession() as session:
async with session.get(link) as resp:
tmp.write(await resp.read())
tmp.seek(0)
content = tmp.read()
tmp.close()
return Response(
content,

content_length = None # 初始化为 None,表示未知大小

# 创建ClientSession并复用
session = aiohttp.ClientSession()

# 尝试发送HEAD请求获取Content-Length
try:
async with session.head(link) as resp:
if resp.status == 200 and 'Content-Length' in resp.headers:
content_length = int(resp.headers['Content-Length'])
except Exception:
# 如果HEAD请求失败,则不提供 Content-Length
pass

async def stream_generator():
try:
async with session.get(link) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status,
detail=f"从OneDrive获取文件失败: {resp.status}"
)
chunk_size = 65536
while True:
chunk = await resp.content.read(chunk_size)
if not chunk:
break
yield chunk
finally:
await session.close()

encoded_filename = quote(filename, safe='')
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if content_length is not None:
headers["Content-Length"] = str(content_length)
return StreamingResponse(
stream_generator(),
media_type="application/octet-stream",
headers={
"Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"'
},
headers=headers
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=503, detail="服务代理下载异常,请稍后再试")

Expand Down Expand Up @@ -776,11 +856,54 @@ async def get_file_url(self, file_code: FileCodes):
async def get_file_response(self, file_code: FileCodes):
try:
filename = file_code.prefix + file_code.suffix
content = await self.operator.read(await file_code.get_file_path())
content_length = None # 初始化为 None,表示未知大小

# 尝试获取文件大小
try:
stat_result = await self.operator.stat(await file_code.get_file_path())
if hasattr(stat_result, 'content_length') and stat_result.content_length:
content_length = stat_result.content_length
elif hasattr(stat_result, 'size') and stat_result.size:
content_length = stat_result.size
except Exception:
# 如果获取大小失败,则不提供 Content-Length
pass

# 尝试使用流式读取器
try:
# OpenDAL 可能提供 reader 方法返回一个异步读取器
reader = await self.operator.reader(await file_code.get_file_path())
except AttributeError:
# 如果 reader 方法不存在,回退到全量读取(兼容旧版本)
content = await self.operator.read(await file_code.get_file_path())
encoded_filename = quote(filename, safe='')
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if content_length is not None:
headers["Content-Length"] = str(content_length)
return Response(
content, headers=headers, media_type="application/octet-stream"
)

async def stream_generator():
chunk_size = 65536
while True:
chunk = await reader.read(chunk_size)
if not chunk:
break
yield chunk

encoded_filename = quote(filename, safe='')
headers = {
"Content-Disposition": f'attachment; filename="{filename}"'}
return Response(
content, headers=headers, media_type="application/octet-stream"
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if content_length is not None:
headers["Content-Length"] = str(content_length)
return StreamingResponse(
stream_generator(),
media_type="application/octet-stream",
headers=headers
)
except Exception as e:
logger.info(e)
Expand Down Expand Up @@ -969,26 +1092,50 @@ async def get_file_response(self, file_code: FileCodes):
try:
filename = file_code.prefix + file_code.suffix
url = self._build_url(await file_code.get_file_path())
async with aiohttp.ClientSession(headers={
content_length = None # 初始化为 None,表示未知大小

# 创建ClientSession并复用(包含认证头)
session = aiohttp.ClientSession(headers={
"Authorization": f"Basic {base64.b64encode(f'{settings.webdav_username}:{settings.webdav_password}'.encode()).decode()}"
}) as session:
async with session.get(url) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status,
detail=f"文件获取失败{resp.status}: {await resp.text()}",
)
# 读取内容到内存
content = await resp.read()
return Response(
content=content,
media_type=resp.headers.get(
"Content-Type", "application/octet-stream"
),
headers={
"Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode()}"'
},
)
})

# 尝试发送HEAD请求获取Content-Length
try:
async with session.head(url) as resp:
if resp.status == 200 and 'Content-Length' in resp.headers:
content_length = int(resp.headers['Content-Length'])
except Exception:
# 如果HEAD请求失败,则不提供 Content-Length
pass

async def stream_generator():
try:
async with session.get(url) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status,
detail=f"文件获取失败{resp.status}: {await resp.text()}",
)
chunk_size = 65536
while True:
chunk = await resp.content.read(chunk_size)
if not chunk:
break
yield chunk
finally:
await session.close()

encoded_filename = quote(filename, safe='')
headers = {
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
}
if content_length is not None:
headers["Content-Length"] = str(content_length)
return StreamingResponse(
stream_generator(),
media_type="application/octet-stream",
headers=headers
)
except aiohttp.ClientError as e:
raise HTTPException(
status_code=503, detail=f"WebDAV连接异常: {str(e)}")
Expand Down