PV-RNN has the weirdest bugs
PV-RNN is an architecture that evolved from VRNN, addressing some of its limitations.
    # Obtaining A1, A2 with BPTT
    A_loss = self.min_loss + 1.  # start above the threshold so the loop runs at least once
    n_iter = 0  # iteration counter (avoids shadowing the builtin `iter`)
    while A_loss > self.min_loss and n_iter < self.max_iter:
        A_optim.zero_grad()
        # self.zero_grad()
        n_iter += 1
        # predict the x sequence
        x_hat = self.fast_generate_from_posterior(A1, A2)
        # loss
        A_loss = self.mse_loss(x, x_hat)
        # backpropagation
        A_loss.backward()
        # update A1, A2
        A_optim.step()
    # Obtaining predictions and posterior parameters with A1, A2
    self.zero_grad()
    x_hat = self.fast_generate_from_posterior(A1, A2, store=True)
    return x_hat, A1, A2
The latent state adds stochasticity to the model, and the distribution behind that stochasticity evolves through time. The prior network maps the current RNN hidden state $d_t$ to the parameters of the distribution from which the latent state $z_t^p$ is obtained, and the values of the latent variables influence the next RNN state, from which the outputs are obtained. The network can thus produce outputs like a probabilistic finite-state machine. When the sequence reaches a point where a probabilistic state transition may happen, that situation is encoded in the RNN hidden state, from which the latent state is obtained. Because the transitions out of this state are stochastic, they cause prediction errors, which should in turn push the prior network to use more variance when producing the latent state.
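To make the loop between hidden state, prior, and latent state concrete, here is a minimal scalar sketch in plain Python. It is not the PV-RNN implementation: the functions `prior_params`, `sample_latent`, `next_hidden`, and `rollout`, along with every weight value, are hypothetical and chosen only for illustration. It assumes a Gaussian prior whose mean and standard deviation both depend on the hidden state, so the amount of noise injected can change from step to step.

```python
import math
import random


def prior_params(d, w_mu=0.8, w_logvar=-1.0, b_logvar=-2.0):
    """Hypothetical scalar prior network: hidden state d -> (mu, sigma)."""
    mu = math.tanh(w_mu * d)
    # The network outputs a log-variance; convert it to a standard deviation.
    sigma = math.exp(0.5 * (w_logvar * d + b_logvar))
    return mu, sigma


def sample_latent(mu, sigma, rng):
    """Reparameterized sample z = mu + sigma * eps, with eps ~ N(0, 1)."""
    return mu + sigma * rng.gauss(0.0, 1.0)


def next_hidden(d, z, w_d=0.5, w_z=1.5):
    """Hypothetical RNN update: the sampled latent feeds the next state."""
    return math.tanh(w_d * d + w_z * z)


def rollout(d0, steps, rng):
    """Generate a hidden-state trajectory using the prior alone."""
    d, trajectory = d0, [d0]
    for _ in range(steps):
        mu, sigma = prior_params(d)        # prior depends on the current state
        z = sample_latent(mu, sigma, rng)  # stochastic latent state
        d = next_hidden(d, z)              # latent influences the next state
        trajectory.append(d)
    return trajectory


rng = random.Random(0)
trajectories = [rollout(d0=0.0, steps=5, rng=rng) for _ in range(3)]
for trajectory in trajectories:
    print([round(d, 3) for d in trajectory])
```

Each rollout starts from the same hidden state but diverges as soon as a latent is sampled, which is the finite-state-machine-like branching described above. In this toy, a hidden state that maps to a larger `sigma` produces a wider spread of successor states.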