I am trying to convert a saved model from TensorFlow 1 to TensorFlow 2. I am migrating the code to TensorFlow 2, as described in the TensorFlow docs. However, I would like to simply update my model_weights.ckpt to TensorFlow 2. Some weights (Linear, Embedding) already have a shape that matches the TensorFlow 2 layout, but I am struggling to transform the weights of my GRUCell.
How can I convert GRUCell weights from compat.v1.nn.rnn_cell.GRUCell to keras.layers.GRUCell?

The TensorFlow 1 GRUCell has four weights (S is the input size, H is the hidden size):
gru_cell/gates/kernel:0 of shape (S + H, 2 x H)
gru_cell/gates/bias:0 of shape (2 x H,)
gru_cell/candidate/kernel:0 of shape (S + H, H)
gru_cell/candidate/bias:0 of shape (H,)
I would like to have weights with the shapes used by the TensorFlow 2 API (or the PyTorch API), i.e. a GRUCell with the following weights:
gru_cell/kernel:0 of shape (S, 3 x H)
gru_cell/recurrent_kernel:0 of shape (H, 3 x H)
gru_cell/bias:0 of shape (2, 3 x H)
GRUCell with the TensorFlow 1 API

import tensorflow as tf
SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])
# GRU cell
gru = tf.compat.v1.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
# Hidden state
state = gru.zero_state(BATCH_SIZE, tf.float32)
# Forward
output, state = gru(inputs, state)
for weight in gru.weights:
    print(weight.name, weight.shape)
Output:
gru_cell/gates/kernel:0 (516, 1024)
gru_cell/gates/bias:0 (1024,)
gru_cell/candidate/kernel:0 (516, 512)
gru_cell/candidate/bias:0 (512,)
GRUCell with the TensorFlow 2 API

import tensorflow as tf
SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])
# GRU cell
gru = tf.keras.layers.GRUCell(HIDDEN_SIZE)
# Hidden state
state = tf.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=tf.float32)
# Forward
output, state = gru(inputs, state)
# Display the weights
for weight in gru.weights:
    print(weight.name, weight.shape)
Output:
gru_cell/kernel:0 (4, 1536)
gru_cell/recurrent_kernel:0 (512, 1536)
gru_cell/bias:0 (2, 1536)
I tried using the _convert_rnn_weights TensorFlow function to convert the desired weights. It works, but only for CuDNN weights, so I can't use it in my case.

For the benefit of the community, I am providing the solution here even though it is already presented on GitHub.
In short, the weights of compat.v1.nn.rnn_cell.GRUCell and keras.layers.GRUCell are not compatible with each other. We don't have a function to convert between them, so if you really want to do it, you will need to do it manually.
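As a rough illustration of what "manually" could mean, here is a minimal NumPy sketch. It rests on several assumptions that are not stated in the post. First, the target Keras cell is built with reset_after=False, so its bias has shape (3 x H,) rather than (2, 3 x H); this is assumed because the v1 cell applies the reset gate before the candidate matmul. Second, the v1 gate columns are ordered [reset, update]. Third, Keras orders its columns [update, reset, candidate]. The names convert_v1_gru_weights, v1_step and keras_step are hypothetical helpers, not TensorFlow functions. The two step functions only re-implement the cell equations in NumPy so the rearrangement can be checked numerically.

```python
import numpy as np


def convert_v1_gru_weights(gates_kernel, gates_bias, cand_kernel, cand_bias, input_size):
    """Rearrange v1 GRUCell weights into a Keras-style layout (assumes reset_after=False)."""
    S = input_size
    H = cand_bias.shape[0]
    # Assumed v1 gate order: [reset, update].
    r_k, u_k = gates_kernel[:, :H], gates_kernel[:, H:]
    r_b, u_b = gates_bias[:H], gates_bias[H:]
    # Rows [0, S) multiply the input, rows [S, S + H) multiply the hidden state.
    # Assumed Keras column order: [update, reset, candidate].
    kernel = np.concatenate([u_k[:S], r_k[:S], cand_kernel[:S]], axis=1)
    recurrent_kernel = np.concatenate([u_k[S:], r_k[S:], cand_kernel[S:]], axis=1)
    bias = np.concatenate([u_b, r_b, cand_bias])
    return kernel, recurrent_kernel, bias


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def v1_step(x, h, gates_kernel, gates_bias, cand_kernel, cand_bias):
    """One step of the v1 cell equations, written in NumPy for checking only."""
    H = h.shape[1]
    gates = sigmoid(np.concatenate([x, h], axis=1) @ gates_kernel + gates_bias)
    r, u = gates[:, :H], gates[:, H:]
    c = np.tanh(np.concatenate([x, r * h], axis=1) @ cand_kernel + cand_bias)
    return u * h + (1.0 - u) * c


def keras_step(x, h, kernel, recurrent_kernel, bias):
    """One step of the Keras-style equations with the reset gate applied before the matmul."""
    H = h.shape[1]
    xw = x @ kernel + bias
    z = sigmoid(xw[:, :H] + h @ recurrent_kernel[:, :H])
    r = sigmoid(xw[:, H:2 * H] + h @ recurrent_kernel[:, H:2 * H])
    hh = np.tanh(xw[:, 2 * H:] + (r * h) @ recurrent_kernel[:, 2 * H:])
    return z * h + (1.0 - z) * hh


# Small random example (sizes chosen for illustration only).
rng = np.random.default_rng(0)
S, H, B = 4, 8, 2
gates_kernel = rng.normal(size=(S + H, 2 * H))
gates_bias = rng.normal(size=(2 * H,))
cand_kernel = rng.normal(size=(S + H, H))
cand_bias = rng.normal(size=(H,))

kernel, recurrent_kernel, bias = convert_v1_gru_weights(
    gates_kernel, gates_bias, cand_kernel, cand_bias, input_size=S
)

x = rng.normal(size=(B, S))
h = rng.normal(size=(B, H))
old_out = v1_step(x, h, gates_kernel, gates_bias, cand_kernel, cand_bias)
new_out = keras_step(x, h, kernel, recurrent_kernel, bias)
print("max abs difference:", np.abs(old_out - new_out).max())
```

Under these assumptions, the three arrays could then be loaded into a built tf.keras.layers.GRUCell(HIDDEN_SIZE, reset_after=False) with set_weights([kernel, recurrent_kernel, bias]). With the TensorFlow 2 default reset_after=True, the candidate is computed differently, so this rearrangement alone is not expected to reproduce the v1 outputs exactly.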
Math-wise, if you have the NumPy values of the v1 weights, the formulas are:
B = batch_size
H = state_size