在搭建自己的Model的时候,我们有时候需要自己从稍微低层的部分进行搭建,而不是直接用Sequential搭建模型,或者是使用Model(inputs,outputs)的方式搭建,例如下面这个简单的例子:
# from tensorflow.python import keras
# from tensorflow.keras.layers import Dense,Input
# 直接使用keras或者是从tensorflow当中导入keras,两种方式二选一
import keras
from keras.layers import Dense,Input
class ActorCriticSharedModel(keras.Model):
'''
Comment model class
'''
def __init__(self, state_size, action_size):
super().__init__()
self.state_size = state_size
self.action_size = action_size
self.dense_1 = Dense(100, activation='relu')
self.policy_logits = Dense(self.action_size)
self.dense_2 = Dense(100, activation='relu')
self.value = Dense(1) # output of value by critic-net according action
def call(self, inputs, training=None, mask=None):
'''over write call() method in Model class.
must handle inputs'''
x = self.dense_1(inputs) # input is states,
# this layer(dense_1) is shared layer in action and critic network
logits = self.policy_logits(x) # softmax(logits) is probabilities of output actions
x = self.dense_2(inputs)
value = self.value(x) # value produced by critic
return logits, value # logits--without softmax activation
对__init__
以及call
的方法进行重载之后,就可以很简单的进行搭建模型,如下所示:
model=ActorCriticSharedModel(5,1)
搭建一个简单的5输入,1输出的网络,但是查看网络的可训练参数model.trainableweights
是会显示空list的。其中的原因是因为在搭建网络时,没有执行参数初始化的操作。此时只需要简单的使用随机数初始化一下就可以,例如:model(tf.convert_to_tensor(np.random.random((1, 5),dtype=tf.float32))
。此时再去查看可训练参数model.trainable_weights
就是显示正常。
换而言之,如果自己采用的是Sequential或者说是Model(inputs,outputs)的方式搭建的模型,那么模型就会自动的执行参数的初始化,直接查看model.trainable_weights
就可以。