# -*- coding: utf-8 -*-
# @Time : 2025/3/22 21:30
# @Author : sjh
# @Site :
# @File : infer2.py
# @Comment :
import argparse
import os
import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from infer import show_depth_point
from models import __models__
from utils import tensor2numpy, make_nograd_func # 假设这些函数在utils中定义
# 设置可见的GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# 解析命令行参数
def parse_args():
parser = argparse.ArgumentParser(description='Accurate and Efficient Stereo Matching via Attention Concatenation Volume (Fast-ACV)')
parser.add_argument('--model', default='Fast_ACVNet_plus', choices=__models__.keys(), help='select a model structure')
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--left_img', default='111/im0.png', help='path to the left image')
parser.add_argument('--right_img', default='111/im1.png', help='path to the right image')
parser.add_argument('--loadckpt', default='sceneflow.ckpt', help='load the weights from a specific checkpoint')
return parser.parse_args()
# 图像预处理函数
def get_transform():
"""
返回一个包含图像预处理的transform
:return: 图像预处理转换操作
"""
mean = [0.485, 0.456, 0.406] # 通常用于ImageNet的预训练模型
std = [0.229, 0.224, 0.225]
return transforms.Compose([
transforms.ToTensor(), # 将PIL图像转换为Tensor
transforms.Normalize(mean=mean, std=std), # 进行标准化
])
def preprocess_image(img_path, target_size=(640, 352)):
"""
读取图像并进行预处理,包括调整大小、转换为RGB和标准化。
:param img_path: 图像文件路径
:param target_size: 目标图像大小
:return: 预处理后的图像张量
"""
# 读取图像并转换为RGB格式
img = Image.open(img_path).convert('RGB')
img = img.resize(target_size) # 调整大小
transform = get_transform() # 获取预处理的transform
img = transform(img)
return img.unsqueeze(0).cuda() # 增加batch维度并移至GPU
# 加载模型和预训练权重
def load_model(model_name, maxdisp, checkpoint_path):
"""
加载指定的模型及其预训练权重
:param model_name: 模型名称
:param maxdisp: 最大视差值
:param checkpoint_path: 预训练权重文件路径
:return: 加载的模型
"""
model = __models__[model_name](maxdisp, False) # 实例化模型
model = nn.DataParallel(model).cuda() # 使用DataParallel并移至GPU
print(f"Loading model from {checkpoint_path}...")
state_dict = torch.load(checkpoint_path) # 加载预训练权重
model.load_state_dict(state_dict['model']) # 加载权重
return model
# 推理函数
def infer(left_img_path, right_img_path, model, save_dir='./output'):
"""
进行视差图的推理,保存结果
:param left_img_path: 左图像路径
:param right_img_path: 右图像路径
:param model: 已加载的模型
:param save_dir: 保存结果的目录
"""
# 预处理图像
left_img = preprocess_image(left_img_path)
right_img = preprocess_image(right_img_path)
print(left_img.shape)
# 进行推理
disp_est = test_sample(left_img, right_img, model)
# 将视差图转换为numpy数组
disp_est_np = tensor2numpy(disp_est).squeeze()
show_depth_point(disp_est_np, cv2.imread(left_img_path))
# 推理单个样本
@make_nograd_func
def test_sample(left_img, right_img, model):
"""
在模型上运行推理,返回最终的视差图。
:param left_img: 左图像张量
:param right_img: 右图像张量
:param model: 模型
:return: 视差图
"""
model.eval() # 设置模型为评估模式
disp_ests = model(left_img, right_img)
return disp_ests[-1] # 返回最后一个视差图
if __name__ == '__main__':
# 解析命令行参数
args = parse_args()
# 加载模型和预训练权重
model = load_model(args.model, args.maxdisp, args.loadckpt)
# 进行推理并保存结果
infer(args.left_img, args.right_img, model)