
Check out our repo for all the code referenced in this blog!

Recommender systems are used across many industries to surface the most relevant products to users. Implementations vary, but at their core these systems sort a universe of items by their relevance to a user, based on user information, item information, or both.

One well-known approach to this sorting problem is the Learning-to-Rank (LTR) model, whose objective is to order a list of examples by each item's relevance to a particular user. It differs from traditional machine learning models, where the task is to generate a prediction (classification/regression) for a single instance. Instead, an LTR model learns from potentially sparse user-item interaction examples to prioritize a list of items: we are evaluating the ordering of a collection of items relative to each other. The metrics used to measure the quality of such a model are therefore different from the traditional ones.
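For example, a standard ranking metric is Normalized Discounted Cumulative Gain (NDCG), which rewards placing highly relevant items near the top of the list. Here is a minimal numpy sketch of one common variant (linear gain):

import numpy as np

def dcg(relevance):
  """Discounted cumulative gain: relevance discounted by log2 of rank."""
  ranks = np.arange(1, len(relevance) + 1)
  return np.sum(relevance / np.log2(ranks + 1))

def ndcg(relevance):
  """DCG of the model's ordering divided by DCG of the ideal ordering."""
  ideal = np.sort(relevance)[::-1]
  return dcg(relevance) / dcg(ideal)

# Relevance labels of a list, in the order the model ranked it:
print(ndcg(np.array([3.0, 2.0, 3.0, 0.0, 1.0])))  # ~0.97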

Historically, implementing a production-worthy ranking model from scratch has been a challenge. Luckily, the TensorFlow team rolled out TF-Ranking, a library designed to train and deploy large-scale ranking models. It includes support for state-of-the-art components such as ranking-specific loss functions (pointwise, pairwise, and listwise losses) and ranking-specific metrics like MRR and NDCG. The TF-Ranking library leverages the tf.estimator API, an abstraction that simplifies the implementation of the model's entire lifecycle.
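For instance, setting up a pairwise loss and a few ranking metrics for an estimator takes only a handful of lines (a sketch; the keys below come from tfr.losses.RankingLossKey and tfr.metrics.RankingMetricKey):

import tensorflow_ranking as tfr

# A pairwise loss: learns from the relative order of item pairs in each list.
loss_fn = tfr.losses.make_loss_fn(
    tfr.losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS)

# Ranking metrics: MRR plus NDCG truncated at a few list lengths.
metric_fns = {
    "metric/mrr": tfr.metrics.make_ranking_metric_fn(
        tfr.metrics.RankingMetricKey.MRR)
}
for topn in (1, 5, 10):
  metric_fns["metric/ndcg@%d" % topn] = tfr.metrics.make_ranking_metric_fn(
      tfr.metrics.RankingMetricKey.NDCG, topn=topn)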

The library is also flexible enough to incorporate other model architectures for better ranking performance. The team recently open-sourced the code implementation for this paper, in which researchers built an LTR model by fine-tuning BERT representations of query-document pairs within TF-Ranking. We wanted to explore this new research by training an LTR model with TF-Ranking to rank movies from the MovieLens dataset.

Let’s begin by preparing the data for training.

The Data

As mentioned above, we use the MovieLens 100K dataset, which we've enhanced to include title descriptions for all movies in the set. This dataset is hosted on Google Drive here. To set up our training set, we need an abstraction for representing user (context) and item (example) information and how they are related. TF-Ranking works with tf.Example protos, specifically the ExampleListWithContext (ELWC) protobuffer. This format stores the context as a tf.Example proto and the items as a list of tf.Example protos. In this case, the context is our user information, i.e., age, sex, and occupation.

For each movie in a user's history, we then build a tf.Example proto holding the movie's features, with the user's rating stored as the relevance score. Each movie's title description is stored in BERT format, i.e., transformed into tokens. These protos hold our example information. The repo for this blog includes a script that creates tfrecords in this format; a simplified version is sketched below.
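The ELWC proto itself ships with the TensorFlow Serving APIs. Here is a minimal sketch of what our script does for a single record (assuming the BERT tokenization has already produced the input_ids, input_mask, and segment_ids values):

import tensorflow as tf
from tensorflow_serving.apis import input_pb2

def _bytes_feature(values):
  return tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[v.encode("utf-8") for v in values]))

def _int64_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

elwc = input_pb2.ExampleListWithContext()

# Context: a single tf.Example holding the user features.
elwc.context.CopyFrom(tf.train.Example(features=tf.train.Features(feature={
    "user_id": _bytes_feature(["123"]),
    "agegroup": _bytes_feature(["adult"]),
    "sex": _bytes_feature(["female"]),
    "occupation": _bytes_feature(["engineer"]),
})))

# Examples: one tf.Example per movie, with the rating as the relevance label.
movie = tf.train.Example(features=tf.train.Features(feature={
    "movie_title": _bytes_feature(["Star", "Trek"]),
    "title_description": _bytes_feature(["The", "brash", "James", "T", "Kirk"]),
    "relevance": _int64_feature([10]),
    "input_ids": _int64_feature([101, 324, 543, 654, 767, 234]),
    "input_mask": _int64_feature([1] * 6),
    "segment_ids": _int64_feature([0] * 6),
}))
elwc.examples.add().CopyFrom(movie)

Printed as text, the resulting record looks like this: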

examples {
  features {
    feature {
      key: "input_ids"
      value {
        int64_list {
          value: 101
          value: 324
          value: 543
          value: 654
          value: 767
          value: 234
        }
      }
    }
    feature {
      key: "input_mask"
      value {
        int64_list {
          value: 1
          value: 1
          value: 1
          value: 1
          value: 1
          value: 1
        }
      }
    }
    feature {
      key: "movie_title"
      value {
        bytes_list {
          value: "Star"
          value: "Trek"
        }
      }
    }
    feature {
      key: "relevance"
      value {
        int64_list {
          value: 10
        }
      }
    }
    feature {
      key: "segment_ids"
      value {
        int64_list {
          value: 0
          value: 0
          value: 0
          value: 0
          value: 0
          value: 0
        }
      }
    }
    feature {
      key: "title_description"
      value {
        bytes_list {
          value: "The"
          value: "brash"
          value: "James"
          value: "T"
          value: "Kirk"
        }
      }
    }
  }
}
context {
  features {
    feature {
      key: "agegroup"
      value {
        bytes_list {
          value: "adult"
        }
      }
    }
    feature {
      key: "occupation"
      value {
        bytes_list {
          value: "engineer"
        }
      }
    }
    feature {
      key: "sex"
      value {
        bytes_list {
          value: "female"
        }
      }
    }
    feature {
      key: "user_id"
      value {
        bytes_list {
          value: "123"
        }
      }
    }
  }
}

The Model

The extension module of the TF-Ranking library includes a TFRBertRankingNetwork component that builds most of the network architecture. We modified this component to incorporate the context data from the ELWCs we created earlier into the scoring function.

class TFRBertRankingNetwork(tfrkeras_network.UnivariateRankingNetwork):
  """A TFRBertRankingNetwork scoring based univariate ranking network."""

  def __init__(self,
               context_feature_columns,
               example_feature_columns,
               bert_config_file,
               bert_max_seq_length,
               bert_output_dropout,
               name="tfrbert",
               **kwargs):
    # ...
    # ... initializing BERT related variables
    # ...

  def score(self, context_features=None, example_features=None, training=True):
    """Univariate scoring of context and one example to generate a score."""

    @tf.function
    def get_inputs():
      # Flatten each context (user) feature into a dense input tensor.
      context_inputs = [
          tf.compat.v1.layers.flatten(context_features[name])
          for name in sorted(context_feature_columns())
      ]

      # Rename our ELWC features to the inputs the BERT encoder expects.
      example_inputs = {
          "input_word_ids": tf.cast(example_features["input_ids"], tf.int32),
          "input_mask": tf.cast(example_features["input_mask"], tf.int32),
          "input_type_ids": tf.cast(example_features["segment_ids"], tf.int32)
      }

      # The `bert_encoder` returns a tuple of (sequence_output, cls_output).
      _, cls_output = self._bert_encoder(example_inputs, training=training)

      # Concatenate the flattened context features with BERT's pooled output.
      return tf.concat(context_inputs + [cls_output], axis=1)

    result = get_inputs()
    output = self._dropout_layer(result, training=training)
    return self._score_layer(output)

With this class, we can initialize, configure, and compile a TFR-BERT network into an estimator for training.

# Initializing Network

network = TFRBertRankingNetwork(
  context_feature_columns=context_feature_columns(),
  example_feature_columns=example_feature_columns(),
  bert_config_file=hparams.get("bert_config_file"),
  bert_max_seq_length=hparams.get("bert_max_seq_length"),
  bert_output_dropout=hparams.get("dropout_rate"),
  name=_NETWORK_NAME)

# After initializing the config, loss, metrics, optimizer, and ranker
# the full estimator gets created like this:
tfr.keras.estimator.model_to_estimator(
  model=ranker,
  model_dir=hparams.get("model_dir"),
  config=config,
  warm_start_from=util.get_warm_start_settings(exclude=_NETWORK_NAME)) # warm start for BERT
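For reference, the elided setup looks roughly like this (a sketch following the TF-Ranking Keras API; the hparams keys and values are illustrative):

# Loss, metrics, and optimizer for the ranker.
loss = tfr.keras.losses.get(hparams.get("loss"))
metrics = tfr.keras.metrics.default_keras_metrics()
optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.get("learning_rate"))

# Compile the network into a Keras ranking model.
ranker = tfr.keras.model.create_keras_model(
    network=network,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    size_feature_name=_SIZE)

# Checkpointing configuration for the estimator.
config = tf.estimator.RunConfig(
    model_dir=hparams.get("model_dir"),
    save_checkpoints_secs=hparams.get("checkpoint_secs"))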

Finally, we use the RankingPipeline class from the TF-Ranking library to pull all the training components together. This component includes a train_and_eval() method that handles training and evaluating the ranking model in one call.

bert_ranking_pipeline = tfr.ext.pipeline.RankingPipeline(
  context_feature_columns=context_feature_columns(),
  example_feature_columns=example_feature_columns(),
  hparams=hparams,
  estimator=get_estimator(hparams),
  label_feature_name="relevance",
  label_feature_type=tf.int64,
  size_feature_name=_SIZE)

bert_ranking_pipeline.train_and_eval(local_training=FLAGS.local_training)

The pipeline automatically configures TensorBoard to monitor metrics such as NDCG at various list lengths.
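Pointing TensorBoard at the model directory is enough to follow along during training (assuming the model_dir is ./model, as in our setup):

$ tensorboard --logdir ./model --port 6006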


Predict

The RankingPipeline also automatically saves the best model by loss in the SavedModel format during training. You can easily serve this model with a TensorFlow Serving docker container to rank movies for a user.

$ docker run -t --rm -p 8501:8501 -v "$(pwd)/model/export/best_model_by_loss:/models/tfrbert" -e MODEL_NAME=tfrbert tensorflow/serving &

Running inference for this kind of model requires serializing the entire ELWC (example list with context proto) into a base64-encoded string so the request to the server is formed correctly (see this issue). We've included a handy inference script here. The output is a list of relevance scores, one for each movie in a sample's example list.
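A minimal sketch of such a request (assuming the serving container above and an elwc proto built as shown earlier; the signature name depends on how the model was exported):

import base64
import json
import requests

# Serialize the whole ELWC proto and base64-encode it for the REST API.
payload = json.dumps({
    "signature_name": "serving_default",
    "instances": [
        {"b64": base64.b64encode(elwc.SerializeToString()).decode("utf-8")}
    ],
})

response = requests.post(
    "http://localhost:8501/v1/models/tfrbert:predict", data=payload)
print(response.json()["predictions"])  # one list of relevance scores per ELWC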

So, to recommend movies to our users, we can pass the list of movies a user hasn't watched, along with their user context data, to the ranker. We can then suggest the top N ranked movies for them to enjoy next.
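Turning the returned scores back into titles is then just a sort (hypothetical scores and candidates, for illustration):

import numpy as np

# Candidate movies the user hasn't watched, with scores from the ranker.
candidates = ["Star Trek: First Contact", "Fargo", "Heat", "Toy Story"]
scores = [1.3, 0.2, 0.7, 1.9]

top_n = 2
best = np.argsort(scores)[::-1][:top_n]
print([candidates[i] for i in best])  # ['Toy Story', 'Star Trek: First Contact']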

The TF-Ranking library makes training and deploying ranking models simpler, and it is actively expanding as more developers contribute additional models for improving the ranking task.