最近由于需求,需要重载Keras的Model类,代码逻辑是好好的,但是最后运行的时候出现了NoImplementError这个错误,现实的是self.compute_output_shape没有在子类当中实现。代码如下:
from keras import Model
class ACModel(Model):
'''
Comment model class for actor and critic model
'''
def __init__(self,state_size,action_size):
super().__init__()
self.state_siz=state_size
self.action_size=action_size
self.dense_1=Dense(100,activation='relu')
self.policy_logits=Dense(self.action_size,activation='softmax')
# output of probabilities of actions in disceate space
self.dense_2=Dense(100,activation='relu')
self.value=Dense(1) # output of value by critic-net according action
def call(self, inputs, mask=None):
'''over write call() method in Model class.
must handle inputs'''
x=self.dense_1(inputs) # input is states,
# this layer(dense_1) is shared layer in action and critic network
logits=self.policy_logits(x) # output actions probabilities
x=self.dense_2(x)
value=self.value(x) # value produced by critic
return logits, value
model=ACModel(4,1)
res=model(tf.constant([[1,1,1,1]],dtype=tf.float32))
#最后会抛出NoImplementError错误
查看keras以及别人的实现,发现很多都是继承的tf.keras.Model这个类。按理来说是不应该的,Keras本身就是从tf当中高度集成的,但是实验之后发现真的有不一样。改变后的代码如下
import tensorflow as tf
class ACModel(tf.keras.Model):
'''
Comment model class for actor and critic model
'''
def __init__(self,state_size,action_size):
super().__init__()
self.state_siz=state_size
self.action_size=action_size
self.dense_1=Dense(100,activation='relu')
self.policy_logits=Dense(self.action_size,activation='softmax')
# output of probabilities of actions in disceate space
self.dense_2=Dense(100,activation='relu')
self.value=Dense(1) # output of value by critic-net according action
def call(self, inputs):
'''over write call() method in Model class.
must handle inputs'''
x=self.dense_1(inputs) # input is states,
# this layer(dense_1) is shared layer in action and critic network
logits=self.policy_logits(x) # output actions probabilities
x=self.dense_2(x)
value=self.value(x) # value produced by critic
return logits, value
model=ACModel(4,1)
res=model(tf.constant([[1,1,1,1]],dtype=tf.float32))
这下不会报错了,但其实仔细查看文档还是可以发现区别的。keras.Model当中的call方法的函数签名是call(self, inputs, mask=None),而tf.keras.Model当中call的签名是call(self, input)。这个原因可能是tensorflow和keras的版本没有统一的原因。
PS. tensorflow的版本是2.1.0,keras的版本是2.3.1。