TensorFlow：为同一模型结构分别加载两套不同的模型参数

TensorFlow 模型的保存

以下代码仅用于说明思路（基于 TensorFlow 1.x 的 API），不保证可以直接运行。

# Build two copies of the same CNN architecture under separate variable
# scopes, so each copy owns its own variables ('model1/...' and 'model2/...').
# Assumes `import tensorflow as tf` (TF 1.x graph-mode API) and a `CNN`
# class defined elsewhere whose variables are created when it is called.
# NOTE(review): `input` stands for the model's input tensor (e.g. a
# tf.placeholder); as written it shadows the Python builtin `input`.
import os

with tf.variable_scope('model1'):
    cnn1 = CNN()
    predict1 = cnn1(input)
with tf.variable_scope('model2'):
    cnn2 = CNN()
    predict2 = cnn2(input)

# Ensemble output: the average of the two models' predictions.
predict = (predict1 + predict2) / 2

# Split the global variables by scope. Match on 'model1/' (with the slash)
# rather than 'model1'; otherwise scopes such as 'model10' or 'model1_1'
# would be picked up as well.
t_var = tf.global_variables()
var_model1 = [var for var in t_var if var.name.startswith('model1/')]
var_model2 = [var for var in t_var if var.name.startswith('model2/')]

# One Saver per model, each restricted to that model's variables, so the two
# parameter sets end up in separate checkpoints.
saver_model1 = tf.train.Saver(var_list=var_model1)
saver_model2 = tf.train.Saver(var_list=var_model2)

# Saver.save() fails when the parent directory of the checkpoint path is missing.
os.makedirs('checkpoint', exist_ok=True)

with tf.Session() as sess:
    # Variables must hold values (initialized or trained) before they can be saved.
    sess.run(tf.global_variables_initializer())
    # ... train the two models here ...
    saver_model1.save(sess, 'checkpoint/model1')
    saver_model2.save(sess, 'checkpoint/model2')

题外话：DeepSpeech 模型参数的加载以及重新保存，请参见 “Black Box attack”。

TensorFlow 模型的加载

# Rebuild exactly the same graph that was used for saving: the variable names,
# including the 'model1/' and 'model2/' scope prefixes, must match the names
# stored in the checkpoints.
# Assumes `import tensorflow as tf` (TF 1.x graph-mode API) and a `CNN`
# class defined elsewhere.
# NOTE(review): `input` stands for the model's input tensor (e.g. a
# tf.placeholder); as written it shadows the Python builtin `input`.
with tf.variable_scope('model1'):
    cnn1 = CNN()
    predict1 = cnn1(input)
with tf.variable_scope('model2'):
    cnn2 = CNN()
    predict2 = cnn2(input)

# Ensemble output: the average of the two models' predictions.
predict = (predict1 + predict2) / 2

# Select each model's variables by scope. The trailing '/' keeps scopes like
# 'model10' or 'model1_1' from matching as well.
t_var = tf.global_variables()
var_model1 = [var for var in t_var if var.name.startswith('model1/')]
var_model2 = [var for var in t_var if var.name.startswith('model2/')]

sess = tf.Session()

# One Saver per model, each restricted to that model's variables.
saver_model1 = tf.train.Saver(var_list=var_model1)
saver_model2 = tf.train.Saver(var_list=var_model2)

# Restoring assigns the saved values, so these variables need no initializer.
# Any variable outside the two scopes is NOT restored and would still have to
# be initialized before use.
saver_model1.restore(sess, 'checkpoint/model1')
saver_model2.restore(sess, 'checkpoint/model2')
# `sess` stays open for later use, e.g. sess.run(predict, feed_dict=...).

Table of Contents