您当前的位置: 首页 >  tensorflow

Better Bench

暂无认证

  • 2浏览

    0关注

    695博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【Tensorflow+keras】Keras 用Class类封装的模型如何调试call子函数的模型内部变量

Better Bench 发布时间:2021-06-08 11:04:49 ,浏览量:2

1 引言

keras搭建神经网络模型有三种方式,第一种是使用sequential,第二种函数API,第三种是Class。第二种在IDE直接家断点就可以调试。但是在Class封装的神经网络中,如下,添加断点后,运行是不会进入到调试的。

# 模型
class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # 在这设置断点
        n = self.w * K.sqrt(x)
        return m + n
# 主函数
import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))
2 实现

添加断点后,通过单独调用Class中的call类,并传入实参,就可以进入到call函数进行调试查看

# 主函数
import tensorflow as tf
import keras
import keras.backend as K

test_input = np.ones((100,1)
model = test_layer()
test = model.call(test_input)
关注
打赏
1665674626
查看更多评论
立即登录/注册

微信扫码登录

0.0409s