DeepSpeech中transcribe文件分析

本文分析 DeepSpeech 的 transcribe 功能中 logits 的定义,主要围绕 transcribe.py 文件展开。

#!/usr/bin/env python

# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import os
import sys
import json
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging
import tensorflow.compat.v1 as tfv1
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)

from multiprocessing import Process, cpu_count
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals
from util.audio import AudioFile
from util.feeding import split_audio_file
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar


def fail(message, code=1):
    """Log *message* as an error, then terminate the process with exit status *code*.

    Raising SystemExit is exactly what sys.exit(code) does, so callers never
    get control back after calling this helper.
    """
    log_error(message)
    raise SystemExit(code)


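# transcribe_file takes the path of the audio file to transcribe and the path of
# the transcription log (.tlog) to write; no target text is required here.
# (Target text is only needed for training, which reads it from .csv files.)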

def transcribe_file(audio_path, tlog_path):
    """Transcribe one audio file and write the result as a JSON transcription log.

    The audio is split into segments by ``split_audio_file``. The segments run
    through the acoustic model built by ``create_model``, with weights restored
    from the checkpoint directory. Each batch is then decoded with a CTC beam
    search that uses the language-model ``Scorer``.

    Args:
        audio_path: Path of the audio file to transcribe.
        tlog_path: Output path. It receives a JSON list of
            ``{'start': int, 'end': int, 'transcript': str}`` objects sorted
            by ``start``.

    Terminates the process through ``fail`` if no valid checkpoint is found.
    """
    # Imported lazily to avoid a circular import with DeepSpeech.py.
    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
    initialize_globals()
    # Language-model scorer consumed by the CTC beam-search decoder.
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
    # Decode with one worker per CPU when the CPU count can be determined.
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1
    with AudioFile(audio_path, as_path=True) as wav_path:
        # Build the batched feature dataset from the audio. The
        # vad_aggressiveness, outlier_duration_ms and outlier_batch_size flags
        # are registered in this script's ``__main__`` section, so FLAGS has them.
        data_set = split_audio_file(wav_path,
                                    batch_size=FLAGS.batch_size,
                                    aggressiveness=FLAGS.vad_aggressiveness,
                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
                                    outlier_batch_size=FLAGS.outlier_batch_size)
        iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
                                                   output_classes=data_set.output_classes)
        # One batch per get_next(): (start times, end times, features, feature lengths).
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
        # No dropout at inference time (create_model is given six dropout rates).
        no_dropout = [None] * 6
        # TODO: an ensemble model would have to replace this call.
        # NOTE(review): the logits are described as time-major
        # [n_steps, batch_size, n_hidden_6] (reshaped from
        # [n_steps*batch_size, n_hidden_6]) - confirm in DeepSpeech.py.
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
        # Swap the first two axes (to batch-major) and apply softmax; this
        # tensor is only used for decoding below.
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        # Ensure a global_step variable exists before the Saver is built, so it
        # is part of the variable set the Saver restores.
        tf.train.get_or_create_global_step()
        saver = tf.train.Saver()
        with tf.Session(config=Config.session_config) as session:
            # Prefer the best-validation checkpoint, fall back to the most recent one.
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
            if not loaded:
                loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
            if not loaded:
                fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
                     .format(FLAGS.checkpoint_dir))
            session.run(iterator.make_initializer(data_set))
            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = \
                        session.run([batch_time_start, batch_time_end, transposed, batch_x_len])
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted.
                    break
                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer)
                # Keep the text of the first (presumably best-scoring) candidate per item.
                decoded = [d[0][1] for d in decoded]
                transcripts.extend(zip(starts, ends, decoded))
            transcripts.sort(key=lambda t: t[0])
            transcripts = [{'start': int(start),
                            'end': int(end),
                            'transcript': transcript} for start, end, transcript in transcripts]
            with open(tlog_path, 'w') as tlog_file:
                json.dump(transcripts, tlog_file, default=float)


def transcribe_many(path_pairs):
    """Transcribe several audio files, each one in its own child process.

    Each child is joined before the next one starts, so files are handled
    sequentially. Presumably a separate process is used so every transcription
    starts with a fresh TensorFlow graph; confirm if relying on this.

    Args:
        path_pairs: Sequence of ``(src_path, dst_path)`` tuples: the audio file
            to transcribe and the transcription log to write.
    """
    total = len(path_pairs)
    pbar = create_progressbar(prefix='Transcribing files | ', max_value=total).start()
    for i, (src_path, dst_path) in enumerate(path_pairs):
        p = Process(target=transcribe_file, args=(src_path, dst_path))
        p.start()
        p.join()
        if p.exitcode != 0:
            # Do not report a failed child (e.g. one that called fail()) as
            # transcribed; log it and continue with the remaining files.
            log_error('Failed to transcribe file {} of {} from "{}" (exit code {})'
                      .format(i + 1, total, src_path, p.exitcode))
        else:
            log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, total, src_path, dst_path))
        # i + 1 files have been processed at this point.
        pbar.update(i + 1)
    pbar.finish()


def transcribe_one(src_path, dst_path):
    """Transcribe a single audio file in the current process and log completion.

    Args:
        src_path: Audio file to transcribe.
        dst_path: Transcription log (.tlog) to write.
    """
    transcribe_file(src_path, dst_path)
    done_message = 'Transcribed file "{}" to "{}"'.format(src_path, dst_path)
    log_info(done_message)


def resolve(base_path, spec_path):
    """Resolve *spec_path* against *base_path*.

    Returns None if spec_path is None, spec_path unchanged if it is already
    absolute, and otherwise spec_path joined onto base_path.
    """
    if spec_path is None or os.path.isabs(spec_path):
        return spec_path
    return os.path.join(base_path, spec_path)


def main(_):
    """Validate the --src/--dst/--force flags and dispatch the transcription.

    A ``.catalog`` source is a JSON list of entries with ``'audio'`` and
    ``'tlog'`` paths; relative paths are resolved against the catalog's
    directory. Any other source is treated as a single audio file.
    """
    if not FLAGS.src:
        fail('You have to specify which file or catalog to transcribe via the --src flag.')
    src_path = os.path.abspath(FLAGS.src)
    if not os.path.isfile(src_path):
        fail('Path in --src not existing')

    if not src_path.endswith('.catalog'):
        # Single file: by default the log goes next to it with a .tlog suffix.
        if FLAGS.dst:
            dst_path = os.path.abspath(FLAGS.dst)
        else:
            dst_path = os.path.splitext(src_path)[0] + '.tlog'
        if not os.path.isfile(dst_path):
            if os.path.isdir(os.path.dirname(dst_path)):
                transcribe_one(src_path, dst_path)
            else:
                fail('Missing destination directory')
        elif FLAGS.force:
            transcribe_one(src_path, dst_path)
        else:
            # An existing log without --force exits with status 0 (not an error).
            fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
        return

    if FLAGS.dst:
        fail('Parameter --dst not supported if --src points to a catalog')
    catalog_dir = os.path.dirname(src_path)
    with open(src_path, 'r') as catalog_file:
        raw_entries = json.load(catalog_file)
    pairs = [(resolve(catalog_dir, entry['audio']), resolve(catalog_dir, entry['tlog']))
             for entry in raw_entries]
    if not all(os.path.isfile(audio) for audio, _tlog in pairs):
        fail('Missing source file(s) in catalog')
    if not FLAGS.force and any(os.path.isfile(tlog) for _audio, tlog in pairs):
        fail('Destination file(s) from catalog already existing, use --force for overwriting')
    if not all(os.path.isdir(os.path.dirname(tlog)) for _audio, tlog in pairs):
        fail('Missing destination directory for at least one catalog entry')
    transcribe_many(pairs)


if __name__ == '__main__':
    # Register the project-wide flags from util.flags first; the
    # transcription-specific flags below are added on top of them.
    create_flags()
    # Source: a single audio file or a .catalog file.
    # NOTE(review): the help text also mentions directories, but main() rejects
    # any --src that is not a regular file (os.path.isfile check).
    tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
                                          'for audio files. If --dst not set, transcription logs (.tlog) will be '
                                          'written in-place using the source filenames with '
                                          'suffix ".tlog" instead of ".wav".')
    # Destination log path for a single file; main() rejects it for catalogs.
    tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
                                          'If --src is a directory, this one also has to be a directory '
                                          'and the required sub-dir tree of --src will get replicated.')
    tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
                                                'transcription logs (.tlog)')
    # Splitting/batching parameters intended for util.feeding.split_audio_file.
    tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
                                                         'split audio')
    tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
    tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
    tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
    # Parse the command-line flags and invoke main().
    tf.app.run(main)

关于global_step的坑

global_step 通常参与 learning rate 的计算(例如学习率衰减),因此必须正确设置。fine-tuning 时,global_step 一般应从 0 开始计数;继续训练时,则应从上次保存的断点处接着计数。TensorFlow 中恢复(restore)模型的方法如下:

作者:leaf
链接:https://www.zhihu.com/question/269968195/answer/351000240
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

# Method 1: restore with a plain Saver.
# NOTE(review): assumes `from tensorflow.python.training import saver as tf_saver`.
# Pass a list of variables to restore only those (suitable for fine-tuning);
# pass None to restore every saved variable (suitable for resuming training).
# An empty list is rejected by the Saver ("No variables to save").
variables_to_restore = None
saver = tf_saver.Saver(variables_to_restore)
saver.restore(session, model_path)
# Restoring this way does not take care of global_step, even when resuming
# training, so it has to be set explicitly: parse the step from the checkpoint
# name and assign it to global_step.
global_step = tf.Variable(0, name="global_step", trainable=False, dtype=tf.int32)
# ckpt.model_checkpoint_path looks like 'model.ckpt-175409'; take the part after
# the LAST dash so basenames that contain dashes themselves still parse.
step = int(os.path.basename(ckpt.model_checkpoint_path).rsplit('-', 1)[-1])
# In graph mode tf.assign only builds an op; it must be run in the session to
# actually change the variable.
# NOTE(review): step + 1 skips one step relative to the saved value; use `step`
# if the checkpoint suffix already counts completed steps.
session.run(tf.assign(global_step, step + 1))
# Method 2: let slim manage global_step.
global_step = tf.train.get_or_create_global_step()
# or: global_step = slim.create_global_step()
# Both register global_step in the GraphKeys.GLOBAL_STEP collection, so it can
# be retrieved later (e.g. with tf.train.get_global_step()).
slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master='',
            is_chief=True,
            init_fn=tf_utils.get_init_fn()
            )
# If logdir already contains a checkpoint, global_step is restored from it by default.
# A custom init_fn (here tf_utils.get_init_fn()) selects which variables to
# restore, which is what fine-tuning needs.
Table of Contents