#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import itertools
import json
import sys
from multiprocessing import cpu_count

import absl.app
import numpy as np
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip

from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_progress, create_progressbar
def sparse_tensor_value_to_texts(value, alphabet):
    r"""
    Decode a sparse label tensor value into a list of Python strings.

    ``value`` is anything exposing ``indices``, ``values`` and ``dense_shape``
    attributes, such as a :class:`tf.SparseTensor` or the value that
    ``session.run`` returns for one. Its labels are turned into text with
    ``alphabet`` by delegating to :func:`sparse_tuple_to_texts`.
    """
    sparse_triple = (value.indices, value.values, value.dense_shape)
    return sparse_tuple_to_texts(sparse_triple, alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
    """Decode a sparse ``(indices, values, dense_shape)`` triple into strings.

    The first component of each coordinate in ``indices`` selects an output
    row. The label at the same position in ``values`` is appended to that row,
    in the order the entries appear. ``dense_shape[0]`` fixes the number of
    rows, so a row with no entries is decoded from an empty list.

    Args:
        sp_tuple: indexable holding ``indices``, ``values`` and ``dense_shape``
            at positions 0, 1 and 2.
        alphabet: object whose ``decode`` method turns one row of labels into
            a string.

    Returns:
        A list with one decoded string per row.
    """
    row_coords, labels = sp_tuple[0], sp_tuple[1]
    row_count = sp_tuple[2][0]

    per_row = [[] for _ in range(row_count)]
    for position, coord in enumerate(row_coords):
        per_row[coord[0]].append(labels[position])

    return list(map(alphabet.decode, per_row))
def evaluate(test_csvs, create_model, try_loading):
    """Evaluate a trained model on one or more test CSV files.

    Builds the model graph on top of a single reinitializable dataset iterator
    and restores a training checkpoint. For each CSV it then runs every batch
    through the model, computes the CTC loss and beam-search decodes the
    softmaxed logits, using a KenLM scorer when ``FLAGS.lm_binary_path`` is
    set. It prints a WER/CER report per CSV and collects the report samples.

    Args:
        test_csvs: list of CSV paths. Each one becomes its own dataset and
            gets its own report.
        create_model: callable that builds the model. It is called with
            ``batch_x``, ``batch_size``, ``seq_length`` and ``dropout``, and
            must return a pair whose first element is the logits tensor.
        try_loading: callable ``(session, saver, checkpoint_name, caption)``
            that returns a truthy value when a checkpoint was restored.

    Returns:
        A list of the samples returned by ``calculate_report`` for every CSV,
        in the order of ``test_csvs``. The list is empty when ``test_csvs``
        is empty.

    Exits the process with status 1 if no checkpoint could be loaded.
    """
    # Nothing to evaluate; also avoids indexing test_sets[0] below.
    if not test_csvs:
        return []

    # Optional language-model scorer, built from the KenLM binary and trie.
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                        Config.alphabet)
    else:
        scorer = None

    # One dataset per CSV file. The caller-supplied test_csvs is used as
    # given. It used to be overwritten with FLAGS.test_files, which silently
    # ignored the argument.
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False)
                 for csv in test_csvs]

    # A single iterator shared by all test sets. Its structure is taken from
    # the first set, so all sets are assumed to share it.
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                 tfv1.data.get_output_shapes(test_sets[0]),
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))

    # Running one of these ops points the shared iterator at its dataset.
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One dropout rate per layer; None disables dropout for evaluation.
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    # Presumably created so the Saver's variable set matches the training
    # checkpoint, which tracks a global step. TODO confirm.
    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from a training checkpoint. Try the best-validation
        # checkpoint first ('auto'/'best'), then the most recent ('auto'/'last').
        loaded = False
        if FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
            sys.exit(1)

        def run_test(init_op, dataset):
            """Evaluate one dataset, print its WER/CER report and return its samples."""
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Consume batches until the dataset raises OutOfRangeError.
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep element [0][1] per utterance. Presumably this is the text
                # of the best-scoring beam, with each beam a (score, text) pair.
                # Confirm against ds_ctcdecoder.
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.cer, sample.loss))
                print(' - wav: file://%s' % sample.wav_filename)
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
def main(_):
    """absl entry point: evaluate the model on ``--test_files`` and optionally save the results.

    ``--test_files`` is a comma-separated list of CSV files laid out as::

        wav_filename,wav_filesize,transcript
        data/00001.wav, 76448, open the door

    When ``--test_output_file`` is set, the report samples returned by
    ``evaluate`` are written to it as JSON.
    """
    # Initialize DeepSpeech's global parameters.
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import

    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON. default=float converts values json cannot
        # serialize natively (e.g. NumPy floats) into Python floats. The with
        # block makes sure the file is flushed and closed.
        with open(FLAGS.test_output_file, 'w') as fout:
            json.dump(samples, fout, default=float)
if __name__ == '__main__':
    # Flags must be registered before absl.app.run parses the command line
    # and then hands control to main().
    create_flags()
    absl.app.run(main)