# DeepSpeech中evaluate.py文件分析 (An analysis of the evaluate.py file in DeepSpeech)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import itertools
import json
import sys

from multiprocessing import cpu_count

import absl.app
import numpy as np
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip

from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_progress, create_progressbar


def sparse_tensor_value_to_texts(value, alphabet):
    r"""
    Convert a sparse tensor value into a list of Python strings.

    ``value`` must expose ``indices``, ``values`` and ``dense_shape``
    attributes (as a :class:`tf.SparseTensor` value does). The three
    attributes are packed into a tuple and handed to
    ``sparse_tuple_to_texts``, which decodes the tokens with ``alphabet``.
    """
    sparse_triple = (value.indices, value.values, value.dense_shape)
    return sparse_tuple_to_texts(sparse_triple, alphabet)


def sparse_tuple_to_texts(sp_tuple, alphabet):
    r"""
    Decode an ``(indices, values, dense_shape)`` tuple into a list of strings.

    ``dense_shape[0]`` fixes the number of output entries. Each value is
    appended, in input order, to the entry selected by the first component of
    its index; every entry is then passed through ``alphabet.decode``.
    Entries that receive no values are decoded from an empty list.
    """
    coords = sp_tuple[0]
    tokens = sp_tuple[1]
    entry_count = sp_tuple[2][0]

    # One (initially empty) token list per batch entry
    token_rows = [[] for _ in range(entry_count)]
    for position, coord in enumerate(coords):
        row = coord[0]
        token_rows[row].append(tokens[position])

    texts = []
    for row_tokens in token_rows:
        texts.append(alphabet.decode(row_tokens))
    return texts


def evaluate(test_csvs, create_model, try_loading):
    r"""
    Run the model over each test CSV and print a WER/CER/loss report.

    For every CSV a dataset is created and fed through the graph built by
    ``create_model``. The softmaxed logits are decoded with
    ``ctc_beam_search_decoder_batch`` (with a ``Scorer`` when
    ``FLAGS.lm_binary_path`` is set) and the results are summarised by
    ``calculate_report``.

    :param test_csvs: CSV paths to evaluate, either as an iterable of paths or
        as a single comma-separated string. When empty or ``None``, the value
        of ``FLAGS.test_files`` (split on commas) is used instead.
    :param create_model: callable invoked with the keyword arguments
        ``batch_x``, ``batch_size``, ``seq_length`` and ``dropout``; it must
        return a pair whose first element is the logits tensor.
    :param try_loading: callable invoked as
        ``try_loading(session, saver, checkpoint_filename, caption)``; a truthy
        result means the checkpoint was restored.
    :return: list of the per-sample report entries produced by
        ``calculate_report`` for all CSVs, in the order the CSVs were given.

    Calls ``sys.exit(1)`` when no checkpoint could be loaded.
    """
    # Path to the language model binary file created with KenLM
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                        Config.alphabet)
    else:
        scorer = None

    # Honour the caller's list of CSVs; only fall back to the command-line
    # flag when nothing was passed in. (Previously the argument was always
    # overwritten by FLAGS.test_files, which made the parameter meaningless.)
    if not test_csvs:
        test_csvs = FLAGS.test_files.split(',')
    elif isinstance(test_csvs, str):
        test_csvs = test_csvs.split(',')
    else:
        # Materialize once: the sequence is iterated twice below
        test_csvs = list(test_csvs)

    # One dataset per CSV file
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]

    # Build a single reinitializable iterator from the structure of the first dataset
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                 tfv1.data.get_output_shapes(test_sets[0]),
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))

    # One initializer op per dataset, used to point the iterator at that dataset
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    # Tensors yielding the next batch of data
    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    # Get or create the global_step
    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = False
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
            sys.exit(1)

        def run_test(init_op, dataset):
            """Evaluate one dataset: decode every batch, print the report, return its samples."""
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep element [0][1] of each decoder result
                # (presumably the text of the top-ranked beam; confirm against ds_ctcdecoder)
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.cer, sample.loss))
                print(' - wav: file://%s' % sample.wav_filename)
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples


def main(_):
    """
    Entry point passed to ``absl.app.run``: evaluate the files in ``--test_files``.

    Initializes the global configuration, checks that ``FLAGS.test_files`` is
    set (logging an error and exiting with status 1 otherwise), runs
    ``evaluate`` on the comma-separated file list and, when
    ``FLAGS.test_output_file`` is set, writes the returned samples to that
    file as JSON.

    :param _: positional argument supplied by ``absl.app.run``; unused.
    """
    # Initialize DeepSpeech's global configuration
    initialize_globals()

    # Each test CSV is expected to look like this (header, then one row per clip):
    #   wav_filename,wav_filesize,transcript
    #   data/00001.wav, 76448, open the door
    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats.
        # The context manager guarantees the file is flushed and closed.
        with open(FLAGS.test_output_file, 'w') as fout:
            json.dump(samples, fout, default=float)


# Script entry point. create_flags() runs before absl.app.run(main) hands
# control to main(); presumably it defines the command-line flags that absl
# then parses into FLAGS (confirm in util.flags).
if __name__ == '__main__':
    create_flags()
    absl.app.run(main)
# Table of Contents