Estimator

class EstimatorSpec(label, pred, head_name=None, loss=None, optimizer=None, classification=True)

EstimatorSpec 是 model_fn 返回的数据结构。

Parameters:
  • label (Union[tf.Tensor, List[tf.Tensor]]) – 样本标签, multi-head 可以使用列表

  • pred (Union[tf.Tensor, List[tf.Tensor]]) – 预测结果, multi-head 可以使用列表

  • head_name (Union[str, List[str]]) – predict名称, multi-head 可以使用列表

  • loss (tf.Tensor) – 损失

  • optimizer (tf.Optimizer) – dense 部分的优化器

  • classification (Union[bool, List[bool]]) – 是否为分类模型, multi-head可使用列表

class RunConfig(is_local=False, num_ps=0, num_workers=1, chief_timeout_secs=1800, operation_timeout_in_ms=-1, session_creation_timeout_secs=7200, enable_fused_layout=False, enable_model_dump=False, partial_recovery=False, max_retry_times=6, retry_wait_in_secs=5, bzid=None, base_name=None, ps_replica_num=None, enable_parameter_sync=False, model_dir='', restore_dir=None, restore_ckpt=None, save_checkpoints_secs=None, save_checkpoints_steps=None, max_rpc_deadline_millis=30000, dense_only_save_checkpoints_secs=None, dense_only_save_checkpoints_steps=None, checkpoints_max_to_keep=10, warmup_file='./warmup_file', enable_local_profiling=False, use_native_multi_hash_table=None, clear_nn=False, continue_training=False, reload_alias_map=None, enable_alias_map_auto_gen=None, kafka_topics=None, kafka_group_id=None, kafka_servers=None, disable_native_metrics=True, save_summary_steps=100, log_step_count_steps=100, vector_search_item_feature_name=None, enable_compute_pushdown=False)

Estimator相关配置, 用户模型外参数统一入口

Parameters:
  • enable_fused_layout (bool) – 是否打开 layout 融合,打开具有加速效果

  • save_checkpoints_secs (int) – 每过多少秒存一次 checkpoint

  • dense_only_save_checkpoints_secs (int) – 每过多少秒存一次 dense 部分 checkpoint

  • save_summary_steps – 每隔多少 global_step 保存一次 summary

  • log_step_count_steps – 每隔多少 global_step 打印一次 loss

class Estimator(model, conf, warm_start_from=None)

Estimator 类似 tf.estimator ,另外, Estimator可以帮助初始化/save/restore变量, 执行hooks, 写summary等

Parameters:
  • model (Model) – MonolithModel 对象

  • conf (RunConfig) – 运行模型所要的配置

Warning

  • Estimator 以后将不再被用户侧感知