Lyft’s Reinforcement Learning Platform

While there are some fundamental differences between RL and supervised learning, we were able to extend our existing model training and serving systems to accommodate the new technique. The big advantage of this approach is that it allows us to leverage many proven platform components.

Architecture

A separate blog post gives a more detailed overview of our model hosting solution, LyftLearn Serving. Here we focus on the modifications required to support RL models, which include:

  • Providing the action space for every scoring request
  • Logging the action propensities
  • Emitting business event data from the application that allows for calculation of the reward, e.g. that a recommendation was clicked on. This step can be skipped if the reward data can be inferred from application metrics that are already logged. However, if those metrics depend on an ETL pipeline, the recency of training data will be limited by that schedule.
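The first two requirements can be illustrated with a minimal sketch. It assumes a hypothetical `policy_fn` that returns one probability per candidate action; the function name, the event fields, and the uniform toy policy are illustrative assumptions, not LyftLearn Serving's actual API.

```python
import random
import uuid


def uniform_policy(context, actions):
    """Toy policy (hypothetical): every action is equally likely."""
    return [1.0 / len(actions)] * len(actions)


def score_request(context, actions, policy_fn, rng=random):
    """Sample one action from the request's action space and build the log event."""
    probs = policy_fn(context, actions)
    if len(probs) != len(actions) or abs(sum(probs) - 1.0) > 1e-9:
        raise ValueError("policy must return one probability per action, summing to 1")
    index = rng.choices(range(len(actions)), weights=probs, k=1)[0]
    event = {
        "uuid": str(uuid.uuid4()),          # later used to join with reward data
        "context": context,
        "actions": actions,                 # the action space sent with the request
        "chosen_action": actions[index],
        "propensities": probs,              # probability of every action
        "chosen_propensity": probs[index],
    }
    return actions[index], event
```

Logging the propensity of the chosen action alongside the full action space is what makes the event usable for training and evaluation later on.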

Architecture diagram showing how models are registered to the backend, synced to the serving instance where they accept requests from the client. The logged data is fed into a data warehouse from where it’s consumed by the policy update job which registers the updated model to the backend.

RL Platform System Architecture

There are two entry points for adding models to the system. The Experimentation Interface allows for kicking off an experiment with an untrained bandit model. This blank model only starts learning in production, as it observes feedback for the actions it takes, and is typically used for more efficient A/B tests. The Model Development path is more suitable for sophisticated models that are developed in source-controlled repositories and potentially pre-trained offline. This flow is very similar to the existing supervised learning model development and deployment.

The models are registered with the Model Database and loaded into the LyftLearn Serving instances utilizing the existing model syncing capabilities.

In Policy Update, the events for the model scores and their responses on the client application side are pulled from the Data Warehouse and joined, and a customer-provided reward function is applied. This data is used to incrementally update the latest model instance.
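The join-and-reward step could look roughly like the sketch below. It uses plain Python lists in place of Data Warehouse queries; the event field names (`uuid`, `chosen_action`, `propensity`, `event`) and the `click_reward` function are illustrative assumptions, not the platform's actual schema.

```python
def click_reward(score_event, responses):
    """Example customer-provided reward function (hypothetical): 1.0 if clicked."""
    return 1.0 if any(r.get("event") == "click" for r in responses) else 0.0


def join_and_reward(score_events, response_events, reward_fn):
    """Join scoring events with client-side responses and attach a reward."""
    # Index the responses by the scoring UUID so the join is a dictionary lookup.
    responses_by_uuid = {}
    for response in response_events:
        responses_by_uuid.setdefault(response["uuid"], []).append(response)

    rows = []
    for score in score_events:
        responses = responses_by_uuid.get(score["uuid"], [])
        rows.append({
            "uuid": score["uuid"],
            "context": score["context"],
            "action": score["chosen_action"],
            "propensity": score["propensity"],
            "reward": reward_fn(score, responses),
        })
    return rows
```

Scoring events without any matching response are kept and receive whatever the reward function returns for an empty response list, here 0.0.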

Finally, the retrained model is written back to S3 and promoted in the Model Database. The Policy Update is orchestrated by a Model CI/CD workflow definition which schedules the training job and takes care of promoting the new model.

Library

We leverage open-source libraries like Vowpal Wabbit and RLlib for modeling. In addition, we created our own internal RL library with model definitions, data processing, and bandit algorithm implementations to integrate with our infrastructure and make it easy for model creators to get started.

Vowpal Wabbit

For modeling contextual bandits (CBs), we chose the Vowpal Wabbit library for its maturity: it builds on a decade of research and development and is currently maintained by Microsoft Research. While it is not the most user-friendly ML library, with some odd text-based interfaces, it comes with a wealth of battle-tested features, such as multiple exploration algorithms, policy evaluation methods, and advanced capabilities like conditional contextual bandits. The authors are also prolific researchers in the field. For a comparison of different bandit techniques, the Contextual Bandit Bake-off paper is a great starting point.

Lyft RL Library

In order to integrate the VW Contextual Bandits and other RL models into our ML ecosystem, we created a library with the following components.

Model class hierarchy showing how the RL base model extends the general model class and different RL implementations like Vowpal Wabbit extend the base RL model

Model Class Inheritance Tree

Core

The core layer adapts the RL-specific components to the existing interfaces of supervised learning models. This includes the RL base model class definition, which extends the generic model class and overrides the model loading, training and testing components to follow RL patterns. Additionally, it contains the data models for events and the model response as well as utilities for extracting data from logged events, transforming training data and processing rewards.
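As a rough illustration of this layering, the sketch below adapts an RL model to a generic predict/train interface. All class and method names (`Model`, `RLBaseModel`, `choose`, `update`, `GreedyMeanBandit`) are hypothetical stand-ins; the internal library's real interfaces are not shown in this post.

```python
from abc import ABC, abstractmethod


class Model(ABC):
    """Stand-in for the generic model interface (hypothetical)."""

    @abstractmethod
    def predict(self, features): ...

    @abstractmethod
    def train(self, data): ...


class RLBaseModel(Model):
    """Maps the generic interface onto RL-style choose/update calls."""

    def predict(self, features):
        # RL scoring needs the action space in addition to the context.
        return self.choose(features["context"], features["actions"])

    def train(self, data):
        # RL training consumes logged events one at a time (incremental update).
        for event in data:
            self.update(event["context"], event["action"],
                        event["reward"], event["propensity"])

    @abstractmethod
    def choose(self, context, actions): ...

    @abstractmethod
    def update(self, context, action, reward, propensity): ...


class GreedyMeanBandit(RLBaseModel):
    """Toy implementation: tracks the mean reward per action, ignores context."""

    def __init__(self):
        self.counts = {}
        self.means = {}

    def choose(self, context, actions):
        return max(actions, key=lambda a: self.means.get(a, 0.0))

    def update(self, context, action, reward, propensity):
        n = self.counts.get(action, 0) + 1
        self.counts[action] = n
        mean = self.means.get(action, 0.0)
        self.means[action] = mean + (reward - mean) / n
```

The point of the design is that callers written against the generic interface can serve and retrain an RL model without knowing about actions or propensities.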

Library

The library layer adds implementations of the abstract core base classes for particular applications, such as Vowpal Wabbit or our own MAB algorithms. For VW, this includes using the library’s serialization schemes, emitting performance metrics and feature weights as well as translating feature dictionaries to VW’s text-based format.
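To give a feel for the translation step, here is a minimal sketch that turns feature dictionaries into a multi-line example in VW's text format for action-dependent features (as used with `--cb_explore_adf`). The helper names and the `Context`/`Action` namespace names are assumptions for illustration; the sketch also assumes feature names and string values contain no spaces, colons, or pipes, which would need escaping in real use.

```python
def _format_features(features):
    """Render a feature dict as VW tokens: name=value for categorical, name:value for numeric."""
    tokens = []
    for name, value in sorted(features.items()):
        if isinstance(value, (str, bool)):
            tokens.append(f"{name}={value}")
        else:
            tokens.append(f"{name}:{value}")
    return " ".join(tokens)


def to_vw_adf_example(context, actions, chosen_index=None, cost=None, probability=None):
    """Build one multi-line VW example: a shared context line plus one line per action.

    For training, the chosen action's line carries a label of the form
    action:cost:probability. VW minimizes cost, so a reward is passed as its negative.
    """
    lines = [f"shared |Context {_format_features(context)}"]
    for i, action in enumerate(actions):
        label = f"0:{cost}:{probability} " if i == chosen_index else ""
        lines.append(f"{label}|Action {_format_features(action)}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(to_vw_adf_example(
        {"hour": 9, "city": "sf"},
        [{"article": "sports"}, {"article": "politics"}],
        chosen_index=0, cost=-1.0, probability=0.8,
    ))
```

Calling the function without `chosen_index` yields an unlabeled example, which is the form needed at scoring time.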

Evaluation

Another important component is the evaluation tooling for model development. This includes Shapley value-based feature importance analysis, which is helpful for context feature selection, as well as customizations for the Coba Off-Policy Evaluation framework discussed in a later section.

Serving

RL models use the same scoring API endpoints as traditional models. The difference is in the additional arguments passed in the request body for RL models. Each model supports a model handler to process input data before passing it to the actual model artifact. This mechanism is used to perform the necessary feature transformations for the VW models and to translate the output back into our expected format.

Diagram showing how the RL models share the same foundation as supervised learning ones except for an RL specific shim

Serving layers in LyftLearn Serving

Training

There are two phases for training a model: warm-starting before the actual launch and continuous updates during the lifetime of the model.

Warm Start

The bandit model can be pre-trained offline on log data of an existing policy. The logging policy does not have to be a bandit but can be some heuristic that the CB model is supposed to replace. Ideally, the logged actions include their propensities. However, learning works without them as well, as long as reward data can be associated with the selected actions.

Warm-starting avoids the costly exploration phase and kick-starts the model performance without preventing it from adapting to changes in the environment over time. While warm-starting helps with reducing regret, it is not necessary: models can start out by exploring all actions evenly and then adjust their exploration based on recurring training cycles.
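The difference between a cold and a warm start can be shown with a toy context-free bandit. This is a minimal epsilon-greedy sketch under our own assumptions (class name, logged data as `(action, reward)` pairs without propensities), not the VW-based models described in this post.

```python
class EpsilonGreedyBandit:
    """Toy bandit that explores uniformly until it has seen any reward data."""

    def __init__(self, actions, epsilon=0.1):
        self.actions = list(actions)
        self.epsilon = epsilon
        self.counts = {a: 0 for a in self.actions}
        self.means = {a: 0.0 for a in self.actions}

    def update(self, action, reward):
        """Incrementally update the running mean reward of one action."""
        self.counts[action] += 1
        self.means[action] += (reward - self.means[action]) / self.counts[action]

    def warm_start(self, logged):
        """Pre-train on (action, reward) pairs logged by an existing policy."""
        for action, reward in logged:
            self.update(action, reward)

    def probabilities(self):
        """Return the probability of selecting each action."""
        n = len(self.actions)
        if not any(self.counts.values()):
            # Cold start: no data yet, so explore all actions evenly.
            return {a: 1.0 / n for a in self.actions}
        best = max(self.actions, key=lambda a: self.means[a])
        probs = {a: self.epsilon / n for a in self.actions}
        probs[best] += 1.0 - self.epsilon
        return probs
```

A cold model assigns equal probability to every action, while a model warm-started on logs from a heuristic policy already concentrates most of its traffic on the action that performed best in those logs, and both keep updating as new rewards arrive.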

Continuous Updates

For continuous training, we use the same Model CI/CD pipeline that is used for the automatic retraining of supervised learning models. The model update queries join all model scoring events since the last training cycle with the relevant reward data. The model scoring event includes the context features, selected action, and the probabilities of all actions as well as a UUID.

Reward data can come from two sources: the business application can emit it explicitly for the purpose of training the model, or business metrics that are captured anyway can be used to calculate the reward. We just have to make sure to link the model action to a particular outcome, typically by joining on a session ID that’s used as the model scoring UUID. Ideally, the business application does not emit the reward directly but rather logs the metrics for calculating it, e.g. an article being selected. This allows for different reward functions to be evaluated and tweaked over time, e.g. giving higher rewards for newer or longer articles. For applications with delayed rewards, the data joining logic needs to be a bit more sophisticated.
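Because the raw metrics are logged rather than the reward itself, reward functions can be swapped without touching the application. The sketch below shows two interchangeable reward functions and a simple attribution window for delayed outcomes; the metric names (`clicked`, `article_age_days`), the weighting, and the 24-hour window are all illustrative assumptions.

```python
from datetime import datetime, timedelta


def click_reward(metrics):
    """Baseline reward: 1.0 if the article was selected, else 0.0."""
    return 1.0 if metrics.get("clicked") else 0.0


def freshness_weighted_reward(metrics, max_age_days=7):
    """Alternative reward: a click on a newer article is worth up to 2.0."""
    if not metrics.get("clicked"):
        return 0.0
    age = metrics.get("article_age_days", max_age_days)
    freshness = max(0.0, 1.0 - age / max_age_days)
    return 1.0 + freshness


def within_attribution_window(scored_at, outcome_at, window=timedelta(hours=24)):
    """Only credit outcomes that happen after scoring and inside the window."""
    return timedelta(0) <= outcome_at - scored_at <= window
```

With delayed rewards, a check like `within_attribution_window` would also tell the update job which scoring events are old enough to be trained on, since a missing outcome inside an open window is not yet a zero reward.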

Once all the necessary data is joined, we extract the training fields into a data frame, perform some processing, including data cleaning and normalization, and then update the model. In order to evaluate the training progress, we emit VW’s internal training loss as well as the changes in feature weights.

Upon completion of the training cycle, the new model artifact is registered as the latest version and loaded into the model serving instance with a zero-downtime hot-swap.
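The post does not describe how the hot-swap is implemented. One common pattern, sketched here purely as an assumption, is to keep the active model behind a single reference that is replaced in one step, so in-flight requests finish on the old model while new requests pick up the new one. The `ModelHolder` class and its methods are hypothetical.

```python
import threading


class ModelHolder:
    """Holds the active model and allows replacing it without stopping requests."""

    def __init__(self, model, version):
        self._lock = threading.Lock()
        self._current = (version, model)

    def swap(self, model, version):
        """Install a newer model; ignore stale versions. Returns True if swapped."""
        with self._lock:
            if version <= self._current[0]:
                return False
            # Replace the whole (version, model) pair in a single assignment.
            self._current = (version, model)
            return True

    def predict(self, features):
        # Read the pair once so a concurrent swap cannot mix versions mid-request.
        version, model = self._current
        return version, model(features)
```

For example, a syncing thread could call `swap` whenever a newer version appears in the model registry, while request handlers only ever call `predict`.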
