Branch-Train-MiX: Mixing Expert LLMs into a Mixture-of-Experts LLM
We investigate efficient methods for training Large Language Models (LLMs) to possess capabilities in multiple specialized domains, such as coding, math reasoning and world knowledge. Our method, named Branch-Train-MiX (BTX), starts from a seed model, which is branched to train experts in an embarrassingly parallel fashion with high throughput and reduced communication cost. After the individual experts are asynchronously trained, BTX brings together their feedforward parameters as experts in Mixture-of-Experts (MoE) layers and averages the remaining parameters, followed by an MoE-finetuning stage to learn token-level routing. BTX generalizes two special cases: the Branch-Train-Merge method, which does not have the MoE-finetuning stage to learn routing, and sparse upcycling, which omits the stage of training experts asynchronously. Compared to alternative approaches, BTX achieves the best accuracy-efficiency tradeoff.
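The merge-then-route step described above can be illustrated with a small, dependency-free sketch. It is not the authors' implementation: each model is assumed to be a dict of flat float lists standing in for tensors, feedforward parameters are identified by a hypothetical `.ffn.` naming convention, and the helper names `btx_merge`, `top_k_route` and `moe_forward` are illustrative. Routing uses a softmax over the top-k router logits, one common MoE gating choice; the router logits are passed in directly rather than computed by a learned projection.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical representation: each model maps parameter names to flat
# lists of floats (a stand-in for real tensors).
Params = Dict[str, List[float]]


def is_ffn(name: str) -> bool:
    """Assumed naming convention: feedforward parameters contain '.ffn.'."""
    return ".ffn." in name


def btx_merge(experts: List[Params]) -> Dict[str, list]:
    """Combine branched expert models into one MoE parameter set.

    FFN parameters are kept per expert (they become the MoE experts);
    every other parameter (attention, embeddings, norms, ...) is averaged
    element-wise across the experts.
    """
    merged: Dict[str, list] = {}
    for name in experts[0]:
        tensors = [e[name] for e in experts]
        if is_ffn(name):
            merged[name] = [list(t) for t in tensors]  # one copy per expert
        else:
            n = len(tensors)
            merged[name] = [sum(vals) / n for vals in zip(*tensors)]
    return merged


def top_k_route(logits: List[float], k: int) -> List[Tuple[int, float]]:
    """Pick the k highest-scoring experts for one token and softmax their logits."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in top]
    z = sum(exps)
    return [(i, e / z) for i, e in zip(top, exps)]


def moe_forward(
    x: List[float],
    router_logits: List[float],
    expert_fns: List[Callable[[List[float]], List[float]]],
    k: int = 2,
) -> List[float]:
    """Gate-weighted sum of the selected experts' outputs for a single token.

    Assumes each expert maps a vector to a vector of the same size, as an FFN does.
    """
    out = [0.0] * len(x)
    for idx, weight in top_k_route(router_logits, k):
        y = expert_fns[idx](x)
        out = [o + weight * v for o, v in zip(out, y)]
    return out
```

Only the router is new after merging; in this sketch its logits are an input, whereas in a real MoE layer they would come from a learned projection trained during the MoE-finetuning stage.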
Results
Task | Dataset | Model | Metric Name | Metric Value | Global Rank |
---|---|---|---|---|---|
Arithmetic Reasoning | GSM8K | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 37.1 | #137 |
Math Word Problem Solving | MATH | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 17.8 | #106 |
Code Generation | MBPP | Branch-Train-Merge 4x7B (top-2) | Accuracy | 42.6 | #77 |
Code Generation | MBPP | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 39.4 | #81 |
Multi-task Language Understanding | MMLU | Branch-Train-MiX 4x7B (sampling top-1 expert) | Average (%) | 53.2 | #32 |
Question Answering | TriviaQA | Branch-Train-MiX 4x7B (sampling top-2 experts) | EM | 57.1 | #40 |
Common Sense Reasoning | WinoGrande | Branch-Train-MiX 4x7B (sampling top-1 expert) | Accuracy | 70.6 | #38 |