Efficient Exploration via State Marginal Matching

Exploration is critical to a reinforcement learning agent's performance in its environment. Prior exploration methods are often based on using heuristic auxiliary predictions to guide policy behavior, lacking a mathematically grounded objective with clear properties. In contrast, we recast exploration as a problem of State Marginal Matching (SMM), where we aim to learn a policy for which the state marginal distribution matches a given target state distribution. In most cases the target is a uniform distribution, but it can incorporate prior knowledge if available. In effect, SMM amortizes the cost of learning to explore in a given environment. The SMM objective can be viewed as a two-player, zero-sum game between a state density model and a parametric policy, an idea that we use to build an algorithm for optimizing the SMM objective. Using this formalism, we further show that prior work approximately maximizes the SMM objective, offering an explanation for the success of these methods. On both simulated and real-world tasks, we demonstrate that agents that directly optimize the SMM objective explore faster and adapt more quickly to new tasks than agents trained with prior exploration methods.
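To make the two-player view concrete, below is a minimal sketch of the alternating optimization on a toy problem: a density player estimates the policy's state marginal ρ_π(s), and a policy player maximizes the pseudo-reward log p*(s) − log ρ̂(s), whose expectation under ρ_π equals −KL(ρ_π ‖ p*). The chain MDP, count-based density estimate, Boltzmann policy, and value-iteration best response are assumptions made for brevity, not the paper's implementation (which uses a learned state density model and a parametric RL policy).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D chain MDP: N states, two actions (left/right), every episode starts at state 0.
# All environment and optimizer choices here are illustrative, not the paper's setup.
N, gamma, horizon, temp = 10, 0.95, 50, 1.0
target = np.full(N, 1.0 / N)              # p*(s): uniform target state marginal

def step(s, a):
    """Deterministic transition, clipped at the ends of the chain."""
    return int(np.clip(s + (1 if a == 1 else -1), 0, N - 1))

def policy(q, s):
    """Boltzmann policy over the two actions at state s."""
    z = np.exp((q[s] - q[s].max()) / temp)
    return z / z.sum()

def state_marginal(q, episodes=100):
    """Empirical estimate of rho_pi(s) from rollouts of the current policy."""
    visits = np.zeros(N)
    for _ in range(episodes):
        s = 0
        for _ in range(horizon):
            s = step(s, rng.choice(2, p=policy(q, s)))
            visits[s] += 1
    return visits / visits.sum()

q_values = np.zeros((N, 2))
rho_avg = np.full(N, 1.0 / N)             # running density estimate over past iterations
for it in range(1, 31):
    # Player 1 (density model): estimate the policy's state marginal; averaging
    # over iterations plays the role of fictitious play in the zero-sum game.
    rho_avg += (state_marginal(q_values) - rho_avg) / it
    # Player 2 (policy): maximize the pseudo-reward log p*(s) - log rho(s),
    # whose expectation under rho_pi equals -KL(rho_pi || p*).
    reward = np.log(target) - np.log(rho_avg + 1e-6)
    for _ in range(50):                   # approximate best response via value iteration
        q_values = np.array([[reward[step(s, a)] + gamma * q_values[step(s, a)].max()
                              for a in range(2)]
                             for s in range(N)])

print("state marginal after training:", np.round(state_marginal(q_values, 500), 3))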


Datasets


Task: Unsupervised Reinforcement Learning    Model: SMM

Dataset                        Metric (mean normalized return)    Metric Value    Global Rank
URLB (pixels, 10^5 frames)     Walker                             6.07±6.14       #10
URLB (pixels, 10^5 frames)     Quadruped                          22.52±6.44      #6
URLB (pixels, 10^5 frames)     Jaco                               0.99±0.61       #7
URLB (pixels, 10^6 frames)     Walker                             6.61±6.70       #9
URLB (pixels, 10^6 frames)     Quadruped                          21.21±6.10      #7
URLB (pixels, 10^6 frames)     Jaco                               0.99±0.61       #8
URLB (pixels, 2*10^6 frames)   Walker                             6.61±6.70       #10
URLB (pixels, 2*10^6 frames)   Quadruped                          21.21±6.10      #8
URLB (pixels, 2*10^6 frames)   Jaco                               0.99±0.61       #9
URLB (pixels, 5*10^5 frames)   Walker                             6.31±6.44       #9
URLB (pixels, 5*10^5 frames)   Quadruped                          21.18±6.13      #8
URLB (pixels, 5*10^5 frames)   Jaco                               0.99±0.61       #7
URLB (states, 10^5 frames)     Walker                             57.84±26.88     #9
URLB (states, 10^5 frames)     Quadruped                          35.53±10.16     #2
URLB (states, 10^5 frames)     Jaco                               26.06±6.40      #8
URLB (states, 10^6 frames)     Walker                             72.60±32.07     #8
URLB (states, 10^6 frames)     Quadruped                          37.37±4.30      #6
URLB (states, 10^6 frames)     Jaco                               29.96±1.37      #8
URLB (states, 2*10^6 frames)   Walker                             77.13±29.55     #2
URLB (states, 2*10^6 frames)   Quadruped                          29.95±7.59      #7
URLB (states, 2*10^6 frames)   Jaco                               21.87±2.77      #8
URLB (states, 5*10^5 frames)   Walker                             73.64±33.56     #8
URLB (states, 5*10^5 frames)   Quadruped                          37.20±12.78     #4
URLB (states, 5*10^5 frames)   Jaco                               31.95±2.95      #8
