import os
import cv2
import math
import yaml
import torch
import numpy as np
import argparse
import pathlib
import logging
from collections import OrderedDict
from PIL import Image
from pathlib import Path
from torchvision.utils import make_grid
from skimage.metrics import structural_similarity as ssim
from xfusion.inference.model.edvr_models import EDVRSTFTempRank
from xfusion.train.basicsr.data.xray_dataset import XrayVideoTestDatasetSTF
from xfusion.inference.dataset.dist_util import get_dist_info
from xfusion.train.basicsr.utils import get_root_logger
from xfusion.utils import yaml_load
from xfusion import config
from xfusion.train.basicsr.utils import tensor2img
[docs]
def inference_pipeline(args):
# To be changed:
# - all arguments are passed as args as set in the xfusion/config.py file
# - relative paths are removed
os.chdir(Path(__file__).parent)
b0 = int(args.b0)
topk_times = 5
lo_frame_sep = int(args.lo_frame_sep)
hi_frame_sep = int(args.hi_frame_sep)
img_class = str(args.img_class)
#opt_path = args.opt + '/' + rf'config_{img_class}.yml'
opt_path = args.opt
# builds and runs correctly up to here. Please adjust below
opt = yaml_load(opt_path)
opt['model_type'] = args.model_type
opt['path']['pretrain_network_g'] = args.model_file
test_set_name = opt['name']
gt_size = opt['datasets']['val']['gt_size']
dataroot = config.get_inf_data_dirs(img_class)
inf_home_dir = Path(dataroot).parent
out_dir = inf_home_dir / f'{test_set_name}_stf_lr_r_{lo_frame_sep}_hr_d_{hi_frame_sep*2}_b0_{b0}'
out_dir.mkdir(exist_ok=True,parents=True)
log_file = os.path.join(out_dir, f"inference.log")
logger = get_root_logger(logger_name=__name__, log_level=logging.INFO, log_file=log_file)
logger.info(f'inference under {args.infer_mode} mode')
logger.info(f'LR frame separation is {lo_frame_sep}')
logger.info(f'HR frame separation is {hi_frame_sep}')
logger.info(f"path to config file is: {opt_path}")
logger.info(f'inference data dir is: {dataroot}')
opt['datasets']['val']['dataroot_lq'] = [os.path.join(dataroot,'LR')]
opt['datasets']['val']['dataroot_gt'] = [os.path.join(dataroot,'HR')]
# default input file structure is: */dataset[n]/HR and */dataset[n]/LR
# the postfixes are appended to the parent directory
opt['datasets']['val']['dataroot_lq'] = [dr+f'_b0_{b0}' if Path(dr).name.lower() != 'lr' else str(Path(dr).parents[0])+f'_b0_{b0}'+f"/{Path(dr).name}" for dr in opt['datasets']['val']['dataroot_lq']]
opt['datasets']['val']['dataroot_gt'] = [dr+f'_b0_{0}' if Path(dr).name.lower() != 'hr' else str(Path(dr).parents[0])+f'_b0_{b0}'+f"/{Path(dr).name}" for dr in opt['datasets']['val']['dataroot_gt']]
logger.info(f"data paths are: {opt['datasets']['val']['dataroot_lq']} for LR and {opt['datasets']['val']['dataroot_gt']} for HR")
opt['dist'] = False
opt['manual_seed'] = 10
torch.manual_seed(opt['manual_seed'])
model_config = opt['network_g']
model_config['num_frame'] = topk_times
model_config['num_frame_hi'] = 0
model_config['center_frame_idx'] = 1
model = EDVRSTFTempRank(**model_config)
model.load_state_dict(torch.load(os.path.join(str(pathlib.Path.home()),opt['path']['pretrain_network_g']))['params'])
try:
model.cuda()
except:
logger.info("no gpu detected")
dataset_opt = opt['datasets']['val']
dataset_opt['scale'] = 4
dataset_opt['gt_size'] = gt_size[:2]
logger.info(dataset_opt)
test_set = XrayVideoTestDatasetSTF(dataset_opt)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=0)
dataset = test_loader.dataset
rank, world_size = get_dist_info()
psnrs, aads, atts, ssims, masks = [], [], [], [], []
for idx in range(rank+1, len(dataset), world_size):
results = []
val_data_ = dataset[idx]
val_data_['lq'].unsqueeze_(0)
if 'gt' in list(val_data_.keys()):
val_data_['gt'].unsqueeze_(0)
else:
val_data_['image'].unsqueeze_(0)
if 'gt' in list(val_data_.keys()):
val_data = {'lq':val_data_['lq'].cuda(), 'gt':val_data_['gt'].cuda(), 'hq':val_data_['hq'][None,:,:,:].cuda()}
else:
val_data = {'lq':val_data_['lq'].cuda(), 'image':val_data_['image'].cuda(), 'hq':val_data_['hq'][None,:,:,:].cuda()}
val_data['lq'] = torch.cat((dataset[max(0,int(idx-lo_frame_sep))]['lq'][model_config['center_frame_idx'],:,:,:].unsqueeze(0),\
val_data['lq'][0,model_config['center_frame_idx'],:,:,:].unsqueeze(0).cpu(),\
dataset[min(len(dataset)-1,int(idx+lo_frame_sep))]['lq'][model_config['center_frame_idx'],:,:,:].unsqueeze(0)),dim=0).unsqueeze(0).cuda()
if 'gt' in list(dataset[max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))].keys()):
gt_key = 'gt'
else:
gt_key = 'image'
if len(dataset[max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))][gt_key].size()) == 4:
val_data['hq'] = torch.cat((dataset[max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))][gt_key][0,:,:,:].unsqueeze(0), dataset[min(gt_size[2],int((idx//(hi_frame_sep*2)+1)*hi_frame_sep*2))][gt_key][0,:,:,:].unsqueeze(0))).unsqueeze(0).cuda()
elif len(dataset[max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))][gt_key].size()) == 3:
val_data['hq'] = torch.cat((dataset[max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))][gt_key].unsqueeze(0), dataset[min(gt_size[2],int((idx//(hi_frame_sep*2)+1)*hi_frame_sep*2))][gt_key].unsqueeze(0))).unsqueeze(0).cuda()
if (idx == max(0,int(idx//(hi_frame_sep*2)*hi_frame_sep*2))) or (idx == min(len(dataset)-1,int(idx+lo_frame_sep))):
masks.append(1)
else:
masks.append(0)
with torch.no_grad():
results = model(val_data)
result, corr = results['out'], results['corr_score']
result_img = (tensor2img(result).astype(float))
result_img[result_img>255] = 255
if 'gt' in list(val_data.keys()):
gt = val_data['gt'].detach().cpu()
else:
gt = val_data['image'].detach().cpu()
hi_img = tensor2img(gt).astype(float)
if len(hi_img.shape) == 3:
hi_img = hi_img[:,:,0]
elif len(hi_img.shape) == 2:
pass
diff = (result_img - hi_img)
mse = np.mean((diff)**2)
psnr = 10. * np.log10(255. * 255. / mse)
aad = abs(diff).mean()
_ssim = ssim(result_img,hi_img, data_range=255)
logger.info(f"psnr is {psnr} dB")
logger.info(f"aad is {aad}")
logger.info(f"ssim is {_ssim}")
psnrs.append(psnr)
aads.append(aad)
ssims.append(_ssim)
logger.info(f"attention high is {corr.detach().cpu().squeeze().numpy()[-1]}")
atts.append(corr.detach().cpu().squeeze().numpy())
Image.fromarray(result_img.astype(np.uint8)).save(f"{(out_dir / Path(val_data_['lq_path']).stem)}_{psnr}.png")
att_all = np.vstack(atts)
if att_all.shape[1] == 4:
result_dict = {'psnr':psnrs, 'aad': aads, 'ssim': ssims, 't-1 lo':att_all[:,0], 't lo': att_all[:,1], 't+1 lo': att_all[:,2], 't hi': att_all[:,3]}
elif att_all.shape[1] == 5:
result_dict = {'psnr':psnrs, 'aad': aads, 'ssim': ssims, 't-1 lo':att_all[:,0], 't lo': att_all[:,1], 't+1 lo': att_all[:,2], 't-1 hi': att_all[:,3], 't+1 hi': att_all[:,4]}
if masks:
result_dict['mask'] = masks
import pandas as pd
pd.DataFrame(result_dict).to_csv(out_dir / f'error_{test_set_name}_stf_lr_r_{lo_frame_sep}_hr_d_{2*hi_frame_sep}_b0_{b0}.csv')
logger.info('done')