Shortcuts

Source code for mmseg.apis.test

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import warnings

import mmcv
import numpy as np
import torch
from mmcv.engine import collect_results_cpu, collect_results_gpu
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info


def np2tmp(array, temp_file_name=None, tmpdir=None):
    """Save ndarray to local numpy file.

    Args:
        array (ndarray): Ndarray to save.
        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
            function will generate a file name with tempfile.NamedTemporaryFile
            to save ndarray. When a string without a ``.npy`` suffix is given,
            the suffix is appended (``np.save`` would append it anyway).
            Default: None.
        tmpdir (str): Temporary directory to save Ndarray files. Only used
            when ``temp_file_name`` is None. Default: None.

    Returns:
        str: The name of the numpy file that was actually written.
    """

    if temp_file_name is None:
        # Only the generated name is needed. Close the handle immediately
        # instead of relying on garbage collection; ``delete=False`` keeps the
        # (empty) file on disk so np.save can overwrite it.
        with tempfile.NamedTemporaryFile(
                suffix='.npy', delete=False, dir=tmpdir) as tmp_file:
            temp_file_name = tmp_file.name
    elif isinstance(temp_file_name, str) and \
            not temp_file_name.endswith('.npy'):
        # np.save appends '.npy' to string paths lacking it; mirror that so the
        # returned name points at the file that really exists.
        temp_file_name = temp_file_name + '.npy'
    np.save(temp_file_name, array)
    return temp_file_name


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    efficient_test=False,
                    opacity=0.5,
                    pre_eval=False,
                    format_only=False,
                    format_args=None):
    """Test with single GPU by progressive mode.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        show (bool): Whether show results during inference. Default: False.
        out_dir (str, optional): If specified, the results will be dumped into
            the directory to save output results.
        efficient_test (bool): Whether save the results as local numpy files
            to save CPU memory during evaluation. Mutually exclusive with
            pre_eval and format_only. Default: False.
        opacity(float): Opacity of painted segmentation map.
            Default 0.5. Must be in (0, 1] range.
        pre_eval (bool): Use dataset.pre_eval() function to generate
            pre_results for metric evaluation. Mutually exclusive with
            efficient_test and format_only. Default: False.
        format_only (bool): Only format result for results commit.
            Mutually exclusive with pre_eval and efficient_test.
            Default: False.
        format_args (dict, optional): Extra keyword arguments forwarded to
            ``dataset.format_results``. ``None`` means no extra arguments.
            Default: None.

    Returns:
        list: Per-sample results: the output of ``dataset.pre_eval`` when
            ``pre_eval`` is set, saved ``.npy`` file names when
            ``efficient_test`` is set, the output of
            ``dataset.format_results`` when ``format_only`` is set, and the
            raw model outputs otherwise.
    """
    # Avoid a shared mutable default; the dict is only unpacked below.
    if format_args is None:
        format_args = {}

    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')
    # when none of them is set true, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them could be true .'

    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The pipeline about how the data_loader retrieval samples from dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to dataset_fetcher to get data from dataset.
    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # we use batch_sampler to get correct data idx
    loader_indices = data_loader.batch_sampler

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if show or out_dir:
            img_tensor = data['img'][0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            for img, img_meta in zip(imgs, img_metas):
                # Crop to ``img_shape`` (presumably dropping any padding),
                # then resize back to the original image size.
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                # NOTE(review): the whole batch ``result`` is passed for every
                # image; presumably only correct for samples_per_gpu=1.
                model.module.show_result(
                    img_show,
                    result,
                    palette=dataset.PALETTE,
                    show=show,
                    out_file=out_file,
                    opacity=opacity)

        if efficient_test:
            result = [
                np2tmp(seg, tmpdir='.efficient_test') for seg in result
            ]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)
        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)

        results.extend(result)

        batch_size = len(result)
        for _ in range(batch_size):
            prog_bar.update()

    return results
def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   efficient_test=False,
                   pre_eval=False,
                   format_only=False,
                   format_args=None):
    """Test model with multiple gpus by progressive mode.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting
    'gpu_collect=True' it encodes results to gpu tensors and use gpu
    communication for results collection. On cpu mode it saves the results on
    different gpus to 'tmpdir' and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (utils.data.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode. Not used by ``efficient_test``,
            which always writes to ``.efficient_test``. Default: None.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
            Default: False.
        efficient_test (bool): Whether save the results as local numpy files
            to save CPU memory during evaluation. Mutually exclusive with
            pre_eval and format_only. Default: False.
        pre_eval (bool): Use dataset.pre_eval() function to generate
            pre_results for metric evaluation. Mutually exclusive with
            efficient_test and format_only. Default: False.
        format_only (bool): Only format result for results commit.
            Mutually exclusive with pre_eval and efficient_test.
            Default: False.
        format_args (dict, optional): Extra keyword arguments forwarded to
            ``dataset.format_results``. ``None`` means no extra arguments.
            Default: None.

    Returns:
        list: Results gathered by ``collect_results_gpu`` /
            ``collect_results_cpu`` (pre-eval results, saved file names,
            formatted results or raw model outputs depending on the flags).
            NOTE(review): presumably only meaningful on rank 0 — confirm the
            collect helpers' return value on other ranks.
    """
    # Avoid a shared mutable default; the dict is only unpacked below.
    if format_args is None:
        format_args = {}

    if efficient_test:
        warnings.warn(
            'DeprecationWarning: ``efficient_test`` will be deprecated, the '
            'evaluation is CPU memory friendly with pre_eval=True')
        mmcv.mkdir_or_exist('.efficient_test')
    # when none of them is set true, return segmentation results as
    # a list of np.array.
    assert [efficient_test, pre_eval, format_only].count(True) <= 1, \
        '``efficient_test``, ``pre_eval`` and ``format_only`` are mutually ' \
        'exclusive, only one of them could be true .'

    model.eval()
    results = []
    dataset = data_loader.dataset
    # The pipeline about how the data_loader retrieval samples from dataset:
    # sampler -> batch_sampler -> indices
    # The indices are passed to dataset_fetcher to get data from dataset.
    # data_fetcher -> collate_fn(dataset[index]) -> data_sample
    # we use batch_sampler to get correct data idx

    # batch_sampler based on DistributedSampler, the indices only point to data
    # samples of related machine.
    loader_indices = data_loader.batch_sampler

    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))

    for batch_indices, data in zip(loader_indices, data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if efficient_test:
            result = [
                np2tmp(seg, tmpdir='.efficient_test') for seg in result
            ]

        if format_only:
            result = dataset.format_results(
                result, indices=batch_indices, **format_args)
        if pre_eval:
            # TODO: adapt samples_per_gpu > 1.
            # only samples_per_gpu=1 valid now
            result = dataset.pre_eval(result, indices=batch_indices)

        results.extend(result)

        if rank == 0:
            # Only rank 0 reports progress; it approximates the global count
            # by assuming every rank processed a batch of the same size.
            batch_size = len(result) * world_size
            for _ in range(batch_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.