Model-Based Transfer Learning for Contextual Reinforcement Learning

Massachusetts Institute of Technology
NeurIPS 2024
Teaser Image


Model-Based Transfer Learning (MBTL) for solving contextual MDPs. The MBTL framework strategically selects which tasks to train on by modeling the generalization gap that arises when a trained model is transferred to different tasks. For each target task, we deploy the most effective trained model, indicated by solid arrows. We evaluate our framework on a suite of continuous and discrete tasks, including standard control benchmarks and urban traffic benchmarks.



"MIT researchers develop an efficient approach for training more reliable reinforcement learning models, focusing on complex tasks that involve variability." (MIT News)

Abstract

Deep reinforcement learning (RL) is a powerful approach to complex decision making. However, one issue that limits its practical application is its brittleness: it can fail to train in the presence of small changes in the environment. Motivated by the success of zero-shot transfer—where pre-trained models perform well on related tasks—we consider the problem of selecting a good set of training tasks to maximize generalization performance across a range of tasks. Given the high cost of training, it is critical to select training tasks strategically, but it is not well understood how to do so. We hence introduce Model-Based Transfer Learning (MBTL), which layers on top of existing RL methods to effectively solve contextual RL problems. MBTL models the generalization performance in two parts: 1) the performance set point, modeled using Gaussian processes, and 2) performance loss (generalization gap), modeled as a linear function of contextual similarity. MBTL combines these two pieces of information within a Bayesian optimization (BO) framework to strategically select training tasks. We show theoretically that the method exhibits sublinear regret in the number of training tasks and discuss conditions to further tighten regret bounds. We experimentally validate our methods using urban traffic and standard continuous control benchmarks. The experimental results suggest that MBTL can achieve up to 50x improved sample efficiency compared with canonical independent training and multi-task training. Further experiments demonstrate the efficacy of BO and the insensitivity to the underlying RL algorithm and hyperparameters. This work lays the foundations for investigating explicit modeling of generalization, thereby enabling principled yet effective methods for contextual RL.
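Sketched in equation form (our own notation, a simplification rather than the paper's exact formulation): let f(x) denote the training performance attainable on a source context x, modeled with a Gaussian process, and assume the generalization gap grows linearly with context distance at rate \theta. The zero-shot performance on a target context x' is then estimated as

    \hat{J}(x' \mid x) \;\approx\; f(x) \;-\; \theta \, |x - x'|,

and the Bayesian optimization acquisition favors the source context whose estimated transfer performance \hat{J}(\cdot \mid x) yields the largest marginal improvement over the best performance achieved by previously trained models.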

MBTL in a nutshell

  • Pick training tasks that generalize: MBTL reasons about how performance will transfer instead of treating every task as equally useful.
  • Model-driven selection: it predicts both where a policy will learn well and how much it will lose when moved to new contexts.
  • Compatible by design: MBTL wraps around standard RL algorithms, improving sample efficiency without re-engineering the base learner.

How MBTL works

A two-part model captures both training performance and the generalization gap, guiding which task to train next.

Multi-model training illustration
Multi-model training pool. We train several policies across tasks and hyperparameter settings, building a library of candidate models to transfer from.
Zero-shot transfer illustration
Zero-shot transfer checks. Each trained model is evaluated on related tasks without additional training to see how well it carries over.
Generalization gap illustration
Generalization gap modeling. We explicitly track the drop between in-task training performance and transfer performance to quantify how sensitive each task is.
Generalization performance illustration
Predictive performance model. A Gaussian process estimates performance set points, while a linear term captures the gap; their combination drives the acquisition function for picking the next task.
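As a concrete illustration, here is a minimal sketch of that two-part model in Python. This is not the released MBTL code: the context values, kernel, and gap slope theta are illustrative assumptions, and in MBTL the slope would be estimated from observed zero-shot transfer results rather than fixed by hand.

# Minimal sketch of the two-part performance model (illustrative, not official code).
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel

# Hypothetical 1-D context space (e.g., a road-length or pole-mass parameter).
contexts = np.linspace(0.0, 1.0, 21)

# Training performance observed on the source tasks trained so far.
observed_x = np.array([[0.1], [0.5], [0.9]])
observed_y = np.array([0.82, 0.90, 0.78])

# Part 1: Gaussian process over the performance set point f(x).
gp = GaussianProcessRegressor(kernel=ConstantKernel() * RBF(length_scale=0.2),
                              alpha=1e-3, normalize_y=True)
gp.fit(observed_x, observed_y)
set_point, set_point_std = gp.predict(contexts.reshape(-1, 1), return_std=True)

# Part 2: generalization gap assumed linear in context distance with slope theta.
theta = 0.4
source = 10  # candidate source task: contexts[10] == 0.5
estimated_transfer = set_point[source] - theta * np.abs(contexts - contexts[source])
print(estimated_transfer.round(2))

The GP supplies both a mean and an uncertainty estimate for untrained contexts, which is what the Bayesian optimization step in the next section relies on.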

MBTL decision loop

MBTL overview animation
Acquisition-guided training. MBTL estimates training performance with a Gaussian process, computes marginal generalization performance using the gap model, selects the task with the highest acquisition value, then updates the models using fresh zero-shot transfer results.
Animated walkthrough of MBTL iterations
Iterative rollout. The process repeats as models improve, steadily tightening uncertainty and focusing training on the most transferable tasks.
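The loop itself can be sketched as follows, again as an illustrative approximation rather than the official implementation; train_policy and zero_shot_return are hypothetical placeholders for the underlying RL training and evaluation routines.

# Minimal sketch of the MBTL decision loop under the same assumptions as above.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

contexts = np.linspace(0.0, 1.0, 21)        # candidate source-task contexts
theta = 0.4                                  # assumed linear gap slope
observed_x, observed_y = [], []              # (source context, training performance)
best_transfer = np.zeros(len(contexts))      # best zero-shot return per target so far
                                             # (returns assumed normalized to [0, 1])

def train_policy(context):
    """Placeholder for RL training on one context (e.g., PPO or DQN on that task)."""
    return None

def zero_shot_return(policy, context):
    """Placeholder for evaluating a trained policy on another context without retraining."""
    return float(np.random.rand())

gp = GaussianProcessRegressor(kernel=RBF(length_scale=0.2), alpha=1e-3, normalize_y=True)

for step in range(5):
    if observed_x:
        gp.fit(np.array(observed_x).reshape(-1, 1), np.array(observed_y))
        mean, std = gp.predict(contexts.reshape(-1, 1), return_std=True)
    else:  # no data yet: flat prior with high uncertainty
        mean, std = np.zeros(len(contexts)), np.ones(len(contexts))

    # Acquisition: optimistic set point minus the linear gap, credited only where it
    # would improve on the best transfer performance already achieved (marginal gain).
    acquisition = []
    for i, x in enumerate(contexts):
        estimate = (mean[i] + std[i]) - theta * np.abs(contexts - x)
        acquisition.append(np.maximum(estimate - best_transfer, 0.0).sum())
    source = int(np.argmax(acquisition))

    # Train on the selected source task, then refresh the models with new
    # zero-shot transfer results across all target tasks.
    policy = train_policy(contexts[source])
    transfer = np.array([zero_shot_return(policy, c) for c in contexts])
    best_transfer = np.maximum(best_transfer, transfer)
    observed_x.append(contexts[source])
    observed_y.append(transfer[source])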

Results

MBTL improves sample efficiency and transfer reliability across contextual RL benchmarks.

Traffic CMDP results
Traffic CMDP benchmarks. MBTL achieves strong generalization with fewer training episodes than independent and multi-task baselines, improving stability across diverse traffic contexts.
Control CMDP results
Control CMDP benchmarks. On standard control tasks, MBTL converges faster and maintains higher transfer performance, demonstrating that the approach is not tied to a specific domain.

BibTeX

@inproceedings{cho2024model,
  title={Model-Based Transfer Learning for Contextual Reinforcement Learning},
  author={Cho, Jung-Hoon and Jayawardana, Vindula and Li, Sirui and Wu, Cathy},
  booktitle={Thirty-Eighth Conference on Neural Information Processing Systems},
  year={2024}
}