深度解析TensorFlow组件Estimator:构建自定义Estimator

Have you ever wondered what’s the magic behind the tutorials on Large-scale Linear Modelsand Wide & Deep Learning? I hope this post would at least point you to the right direction.

你是否思考过TensorFlow的 tutorial 和其背后的“魔力”?希望这篇文章至少能给你思考的正确方向。

TensorFlow的基本概念可以去查看TensorFlow官方文档。这里将帮你更好的理解TensorFlow Learn中 estimator 的工作原理,并指导你构建适合自己特定应用的estimator。

BaseEstimator和Estimator的理解

BaseEstimator是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit()partial_fit()evaluate()predict()的基本功能,处理不同类型的输入数据批量拉取(Note:未来learn.DataFrame 将替代DataFeeder)。它通过dtypes来检查输入数据的兼容,考虑输入数据是否稀疏需要使用 estimators.tensor_signature

BaseEstimator为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑。_get_train_ops()_get_eval_ops()_get_predict_ops()放在子类中实现,给Estimator自定义带来了更大的自由。BaseEstimator也是分布式的。

TensorFlow模块中Estimator的实现给我们重写BaseEstimator子类提供了很好的范本。
例如,Estimator中的 _get_train_ops() 载入 featurestargets 作为输入,返回训练Operation和损失Tensor的一个 tuple。如果你想完成自己的 estimator,并且用于非监督机器学习训练,这时你就可以自由决定targets是否可忽略。

类似地,子类中的 _get_eval_ops() 可自定义metric来评估每步的训练。在TensorFlow的high-level模块中可发现一打适用的metric。它们会返回Tensor对象的字典,表示指定metric的评价ops。

_get_predict_ops() 可实现自定义的prediction,例如 概率 v.s. 实际预测输出。它将返回一个Tensor或者Tensor对象的字典,表示预测ops。你可以很轻松的使用父类的predict() 函数实现像 transform() 的功能。

Estimator示例

逻辑回归(LogisticRegressor)

Estimator已经提供了自定义estimator大部分实现。例如,LogisticRegressor仅需实现自己的metric即可,比如AUC,accuracy,precision和recall。开发者使用LogisticRegressor子类即可实现二值分类问题。

随机森林(TensorForestEstimator)

TensorForestEstimator已经增加到TensorFlow Learn。contrib.tensor_forest 详细的实现了随机森林算法(Random Forests)评估器,并对外提供high-level API使得开发者构建随机森林评估器更简单。

例如,开发者只需传入params到构造器,params 使用 params.fill() 来填充,而不用传入所有的超参数,Tensor Forest自己的RandomForestGraphs使用这些参数来构建整幅图。

class TensorForestEstimator(estimator.BaseEstimator):
  """An estimator that can train and evaluate a random forest."""

  def __init__(self, params, device_assigner=None, model_dir=None,
               graph_builder_class=tensor_forest.RandomForestGraphs,
               master='', accuracy_metric=None,
               tf_random_seed=None, verbose=1,
               config=None):
    self.params = params.fill()

随机森林算法的接口实现有许多细节,_get_predict_ops()利用tensor_forest.RandomForestGraphs来构建随机森林图,调用graph_builder.inference_graph来获取预测ops。

def _get_predict_ops(self, features):
    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner, training=False,
        **self.construction_args)
    features, spec = data_ops.ParseDataTensorOrDict(features)

    return graph_builder.inference_graph(features, data_spec=spec)

类似地,使用 graph_builder.training_loss 来实现_get_train_ops()。注意,TensorForestEstimator使用了tensor_forest.data.data_ops的模块功能,比如 ParseDataTensorOrDict和ParseLabelTensorOrDict解析输入特征和标签。

其它用例

K-means聚类的estimator刚加入项目,放在contrib.factorization.python.ops.kmeans。 更多的例子可以在learn.estimators中找到。

强烈推荐你领悟代码整体结构,开始实现自己的estimator之旅!

参考:http://terrytangyuan.github.io/2016/03/14/scikit-flow-intro

王海良@Chatopera 聊天机器人 机器学习 智能客服
Chatopera 联合创始人 & CEO,运营聊天机器人平台 https://bot.chatopera.com,让聊天机器人上线!2015年开始探索聊天机器人的商业应用,实现基于自然语言交互的流程引擎、语音识别、自然语言理解,2018年出版《智能问答与深度学习》一书。