Learning Associative Inference Using Fast Weight Memory

Humans can quickly associate stimuli to solve problems in novel contexts. Our novel neural network model learns state representations of facts that can be composed to perform such associative inference. To this end, we augment the LSTM model with an associative memory, dubbed Fast Weight Memory (FWM). Through differentiable operations at every step of a given input sequence, the LSTM updates and maintains compositional associations stored in the rapidly changing FWM weights. Our model is trained end-to-end by gradient descent and yields excellent performance on compositional language reasoning problems, meta-reinforcement-learning for POMDPs, and small-scale word-level language modelling.
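The core mechanism can be pictured as a small fast-weight matrix that an LSTM controller writes to and reads from at every time step, while the slow LSTM weights are trained by gradient descent as usual. The sketch below is a minimal illustration of that idea under simplifying assumptions, not the paper's exact formulation: the real FWM uses higher-order (composed-key) associations and a multi-step read, whereas here the class name, layer sizes, and the simple gated outer-product write rule are invented for illustration.

```python
import torch
import torch.nn as nn


class FastWeightMemorySketch(nn.Module):
    """Illustrative sketch of an LSTM augmented with a fast-weight memory.

    At every step the LSTM emits a write key/value and a read query.
    The fast-weight matrix F is updated with a gated, Hebbian-style
    outer-product rule and then queried to produce a memory readout.
    This is a simplification of the FWM idea, not the paper's model.
    """

    def __init__(self, input_size: int, hidden_size: int, key_size: int):
        super().__init__()
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.to_write = nn.Linear(hidden_size, 2 * key_size + 1)  # key, value, write gate
        self.to_query = nn.Linear(hidden_size, key_size)
        self.out = nn.Linear(hidden_size + key_size, hidden_size)
        self.key_size = key_size

    def forward(self, x):
        # x: (seq_len, batch, input_size)
        batch = x.size(1)
        h = x.new_zeros(batch, self.lstm.hidden_size)
        c = torch.zeros_like(h)
        # Fast weights are per-sequence state, reset for every new sequence.
        F = x.new_zeros(batch, self.key_size, self.key_size)
        outputs = []
        for x_t in x:
            h, c = self.lstm(x_t, (h, c))
            k, v, beta = self.to_write(h).split(
                [self.key_size, self.key_size, 1], dim=-1)
            k, v = torch.tanh(k), torch.tanh(v)
            beta = torch.sigmoid(beta).unsqueeze(-1)  # write strength in (0, 1)
            # Retrieve the value currently bound to k, then overwrite it
            # with the new value (delta-rule-style association update).
            old_v = torch.einsum('bij,bj->bi', F, k)
            F = F + beta * torch.einsum('bi,bj->bij', v - old_v, k)
            # Read: query the fast weights with a separate query vector.
            q = torch.tanh(self.to_query(h))
            read = torch.einsum('bij,bj->bi', F, q)
            outputs.append(self.out(torch.cat([h, read], dim=-1)))
        return torch.stack(outputs)
```

A quick usage check of this sketch: `FastWeightMemorySketch(64, 128, 32)(torch.randn(10, 4, 64))` returns a `(10, 4, 128)` tensor. The key design point it illustrates is the two time scales: the fast weights `F` change at every input step and store sequence-specific associations, while the module's ordinary parameters change only through end-to-end gradient descent.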

ICLR 2021

Datasets

catbAbI · Penn Treebank (Word Level) · WikiText-2

Results from the Paper


| Task | Dataset | Model | Metric | Value | Global Rank |
|---|---|---|---|---|---|
| Question Answering | catbAbI LM-mode | Metalearned Neural Memory (plastic) | Accuracy (mean) | 69.3% | #4 |
| Question Answering | catbAbI LM-mode | AWD-Transformer XL | Accuracy (mean) | 90.23% | #2 |
| Question Answering | catbAbI LM-mode | Fast Weight Memory | Accuracy (mean) | 93.04% | #1 |
| Question Answering | catbAbI LM-mode | AWD-LSTM | Accuracy (mean) | 80.15% | #3 |
| Question Answering | catbAbI QA-mode | AWD-LSTM | 1:1 Accuracy | 80.88% | #4 |
| Question Answering | catbAbI QA-mode | Fast Weight Memory | 1:1 Accuracy | 96.75% | #1 |
| Question Answering | catbAbI QA-mode | Metalearned Neural Memory (plastic) | 1:1 Accuracy | 88.97% | #2 |
| Question Answering | catbAbI QA-mode | AWD-Transformer XL | 1:1 Accuracy | 87.66% | #3 |
| Language Modelling | Penn Treebank (Word Level) | AWD-FWM (Schlag et al., 2020) | Validation perplexity | 56.76 | #18 |
| Language Modelling | Penn Treebank (Word Level) | AWD-FWM (Schlag et al., 2020) | Test perplexity | 54.48 | #21 |
| Language Modelling | Penn Treebank (Word Level) | AWD-FWM (Schlag et al., 2020) | Params | 24M | #7 |
| Language Modelling | WikiText-2 | AWD-FWM (Schlag et al., 2020) | Validation perplexity | 54.48 | #13 |
| Language Modelling | WikiText-2 | AWD-FWM (Schlag et al., 2020) | Test perplexity | 61.65 | #27 |
| Language Modelling | WikiText-2 | AWD-FWM (Schlag et al., 2020) | Params | 37M | #9 |

Methods

LSTM · Fast Weight Memory (FWM)