Contrastive BERT (CoBERL) is a reinforcement learning agent that combines a new contrastive loss with a hybrid LSTM-transformer architecture to tackle the challenge of improving data efficiency in RL. It uses bidirectional masked prediction together with a generalization of recent contrastive methods to learn better representations for transformers in RL, without the need for hand-engineered data augmentations.
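The bidirectional masked prediction is BERT-style: some timesteps of the encoded sequence are hidden, and the model must reconstruct them from the surrounding (unmasked) context. A minimal NumPy sketch of the masking step, with an assumed 15% masking rate and a zero vector standing in for the mask token:

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 10, 4
embeddings = rng.normal(size=(T, D))   # encoder embeddings Y_t, one row per timestep

mask_prob = 0.15                       # masking rate (illustrative assumption)
mask = rng.random(T) < mask_prob       # boolean mask over timesteps
masked = embeddings.copy()
masked[mask] = 0.0                     # replace masked timesteps with a mask token

# Training would ask the transformer to reconstruct embeddings[mask]
# from the unmasked context in `masked`.
```

Unmasked timesteps pass through unchanged; only the masked positions carry a prediction target.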
For the architecture, a residual network encodes observations into embeddings $Y_{t}$. $Y_{t}$ is fed through a causally masked GTrXL transformer, which computes the predicted masked inputs $X_{t}$ and passes them, together with $Y_{t}$, to a learnt gate. The output of the gate is passed through a single LSTM layer to produce the values used for computing the RL loss. A contrastive loss is computed using the predicted masked inputs $X_{t}$, with $Y_{t}$ as targets; for this, the causal mask of the transformer is not used.
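The contrastive loss aligns each predicted masked input $X_{t}$ with its own target $Y_{t}$ while pushing it away from the embeddings of other timesteps. The following is a simplified stand-in (an InfoNCE-style objective over cosine similarities, not the exact ReLIC formulation used in the paper):

```python
import numpy as np

def info_nce(x, y, temperature=0.1):
    """Each prediction x[t] should match its own target y[t]
    against all other timesteps acting as negatives."""
    # L2-normalise so dot products are cosine similarities.
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    logits = x @ y.T / temperature               # (T, T) similarity matrix
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positive pairs lie on the diagonal (x[t] paired with y[t]).
    return -np.mean(np.diag(log_probs))

rng = np.random.default_rng(0)
T, D = 8, 16
y = rng.normal(size=(T, D))                   # encoder targets Y_t
x_good = y + 0.01 * rng.normal(size=(T, D))   # predictions close to targets
x_bad = rng.normal(size=(T, D))               # unrelated predictions
assert info_nce(x_good, y) < info_nce(x_bad, y)
```

Because the negatives come from other timesteps of the same trajectory, no hand-engineered data augmentations are needed to form contrastive pairs.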
Source: CoBERL: Contrastive BERT for Reinforcement Learning
| Component | Type |
|---|---|
| GTrXL | RL Transformers |
| LSTM | Recurrent Neural Networks |
| ReLIC | Self-Supervised Learning |
| Residual Connection | Skip Connections |