estimator

estimator

class EstimatorSpec(label, pred, head_name=None, loss=None, optimizer=None, classification=True)[source]

EstimatorSpec是model_fn返回的数据结构。

Parameters:
  • label (Union[tf.Tensor,List[tf.Tensor]]) – 样本标签,multi-head可以使用列表

  • pred (Union[tf.Tensor,List[tf.Tensor]]) – 预测结果,multi-head可以使用列表

  • head_name (Union[str,List[str]]) – predict名称,multi-head可以使用列表

  • loss (tf.Tensor) – 损失

  • optimizer (tf.Optimizer) – dense部分的优化器

  • classification (Union[bool,List[bool]]) – 是否为分类模型,multi-head可使用列表

_replace(**kwds)[source]

Return a new EstimatorSpec replacing specified fields with new values.

class RunConfig(is_local=False, num_ps=0, num_workers=1, chief_timeout_secs=1800, operation_timeout_in_ms=-1, session_creation_timeout_secs=7200, enable_fused_layout=False, enable_model_dump=False, partial_recovery=False, max_retry_times=6, retry_wait_in_secs=5, bzid=None, base_name=None, ps_replica_num=None, enable_parameter_sync=False, model_dir='', restore_dir=None, restore_ckpt=None, save_checkpoints_secs=None, save_checkpoints_steps=None, max_rpc_deadline_millis=30000, dense_only_save_checkpoints_secs=None, dense_only_save_checkpoints_steps=None, checkpoints_max_to_keep=10, warmup_file='./warmup_file', enable_local_profiling=False, use_native_multi_hash_table=None, clear_nn=False, continue_training=False, reload_alias_map=None, enable_alias_map_auto_gen=None, kafka_topics=None, kafka_group_id=None, kafka_servers=None)[source]

Estimator相关配置,用户模型外参数统一入口

Parameters:
  • chief_timeout_secs (int) – chief 超时时长,默认为 1800秒

  • operation_timeout_in_ms (int) – 操作超时时长,默认为 -1,不会超时

  • session_creation_timeout_secs (int) – session创建超时时长,默认为7200秒

  • enable_fused_layout (bool) – 是否打开layout融合,加速计算,training & serving 阶段都有帮助,默认为 False

  • partial_recovery (bool) – 是否开启部分恢复

  • max_retry_times (int) – 发生容错时,最大重启次数,默认为 6 (一般不需要修改)

  • retry_wait_in_secs (int) – 发生容错时,重启时间间隔,默认为 5 (一般不需要修改)

  • enable_parameter_sync (bool) – 是否开启参数同步,默认为False

  • model_dir (str) – 模型dump目录

  • restore_dir (str) – 模型加载目录,当dump目录与加载目录不同时才需指定,默认从model_dir中加载模型

  • restore_ckpt (str) – 加载checkpoint版本,默认加载最新版

  • save_checkpoints_secs (int) – 每过多少秒存一个checkpoint

  • save_checkpoints_steps (int) – 每过多少step存一个checkpoint

  • dense_only_save_checkpoints_secs (int) – 每过多少秒存一个dense部分checkpoint

  • dense_only_save_checkpoints_steps (int) – 每过多少step存一个dense部分checkpoint

  • checkpoints_max_to_keep (int) – 最多保存多少个checkpoint

  • use_native_multi_hash_table (bool) – 请不要指定这个变量,将于2023-1-1移除

  • clear_nn (bool) – 是否在reload模型时将dense部分随机初始化,默认为false。不会对 sparase 部分有影响

  • continue_training (bool) – 是clear_nn为true时,global_step是否继续保持,默认为false。

class Estimator(model, conf, warm_start_from=None)[source]

利用Estimator可以实现local模式与分布式模式的统一,另外,Estimator可以帮助初始化/save/restore变量,执行hooks,写summary等

Parameters:
  • model (Model) – NativeModel对象

  • conf (RunConfig) – 运行模型所要的配置

import_saved_model(saved_model_path, input_name='instances', output_name='output', signature=None)[source]

导出saved_model

Parameters:

saved_model_path (str) – saved_model路径