Estimator
- class EstimatorSpec(label, pred, head_name=None, loss=None, optimizer=None, classification=True)
EstimatorSpec is the data structure returned by model_fn.
- Parameters:
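label (Union[tf.Tensor, List[tf.Tensor]]) – sample labels; for multi-head models, pass a list.
pred (Union[tf.Tensor, List[tf.Tensor]]) – prediction results; for multi-head models, pass a list.
head_name (Union[str, List[str]]) – name of the prediction head; for multi-head models, pass a list.
loss (tf.Tensor) – the loss.
optimizer (tf.Optimizer) – optimizer for the dense part of the model.
classification (Union[bool, List[bool]]) – whether the model is a classification model; for multi-head models, pass a list.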
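The parameter list above lets each per-head argument be either a single value (one head) or a list (multi-head). The sketch below illustrates that convention without importing the real library: `check_heads` is a hypothetical helper, plain strings stand in for `tf.Tensor` objects, and it assumes (the source does not say so explicitly) that the lists must have exactly one entry per head and that a single `classification` bool applies to every head.

```python
from typing import Any, List, Optional


def _as_list(value: Any) -> List[Any]:
    """Wrap a single value in a list; copy lists and tuples as-is."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def check_heads(label: Any, pred: Any, head_name: Optional[Any] = None,
                classification: Any = True) -> int:
    """Hypothetical helper (not part of the library): return the number of heads.

    Assumption: label, pred and head_name are single values (one head) or
    lists with one entry per head; a single bool for classification is
    treated as applying to all heads.
    """
    num_heads = len(_as_list(pred))
    if len(_as_list(label)) != num_heads:
        raise ValueError("label and pred must have the same number of heads")
    if head_name is not None and len(_as_list(head_name)) != num_heads:
        raise ValueError("head_name must have one entry per head")
    if len(_as_list(classification)) not in (1, num_heads):
        raise ValueError("classification must be one bool or one bool per head")
    return num_heads


# Strings stand in for tf.Tensor objects in this sketch.
print(check_heads("label", "pred"))  # → 1 (single head)
print(check_heads(["click", "convert"], ["p_click", "p_convert"],
                  head_name=["ctr", "cvr"], classification=[True, True]))  # → 2
```

Treating a bare value as a one-element list keeps the single-head and multi-head cases on one code path; how the real `EstimatorSpec` handles mismatched lengths is not specified here.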
- class RunConfig(is_local=False, num_ps=0, num_workers=1, chief_timeout_secs=1800, operation_timeout_in_ms=-1, session_creation_timeout_secs=7200, enable_fused_layout=False, enable_model_dump=False, partial_recovery=False, max_retry_times=6, retry_wait_in_secs=5, bzid=None, base_name=None, ps_replica_num=None, enable_parameter_sync=False, model_dir='', restore_dir=None, restore_ckpt=None, save_checkpoints_secs=None, save_checkpoints_steps=None, max_rpc_deadline_millis=30000, dense_only_save_checkpoints_secs=None, dense_only_save_checkpoints_steps=None, checkpoints_max_to_keep=10, warmup_file='./warmup_file', enable_local_profiling=False, use_native_multi_hash_table=None, clear_nn=False, continue_training=False, reload_alias_map=None, enable_alias_map_auto_gen=None, kafka_topics=None, kafka_group_id=None, kafka_servers=None, disable_native_metrics=True, save_summary_steps=100, log_step_count_steps=100, vector_search_item_feature_name=None, enable_compute_pushdown=False)
Estimator-related configuration; the single entry point for all parameters that live outside the user's model.
- Parameters:
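enable_fused_layout (bool) – whether to enable layout fusion; enabling it speeds up execution.
save_checkpoints_secs (int) – interval, in seconds, between checkpoint saves.
dense_only_save_checkpoints_secs (int) – interval, in seconds, between saves of a checkpoint containing only the dense part of the model.
save_summary_steps – interval, in global steps, between summary saves (default 100).
log_step_count_steps – interval, in global steps, between loss log lines (default 100).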
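`save_summary_steps` and `log_step_count_steps` are counted in global steps, whereas `save_checkpoints_secs` and `dense_only_save_checkpoints_secs` are counted in wall-clock seconds and are not modeled here. The sketch below shows how the two step-based intervals translate into event counts. It is a hypothetical illustration, not the library's hook logic: `events_at_step` and `count_events` are invented names, and the trigger rule (an event fires when the global step is a positive multiple of its interval; `None` or `0` disables it) is an assumption.

```python
from typing import Dict, List, Optional


def _due(global_step: int, interval: Optional[int]) -> bool:
    """True when global_step is a positive multiple of a positive interval."""
    return bool(interval) and interval > 0 and global_step > 0 and global_step % interval == 0


def events_at_step(global_step: int,
                   save_summary_steps: Optional[int] = 100,
                   log_step_count_steps: Optional[int] = 100) -> List[str]:
    """Hypothetical: step-based events due at global_step (defaults match RunConfig)."""
    events = []
    if _due(global_step, save_summary_steps):
        events.append("save_summary")
    if _due(global_step, log_step_count_steps):
        events.append("log_loss")
    return events


def count_events(num_steps: int, **intervals: Optional[int]) -> Dict[str, int]:
    """Count how often each event fires over global steps 1..num_steps."""
    counts = {"save_summary": 0, "log_loss": 0}
    for step in range(1, num_steps + 1):
        for event in events_at_step(step, **intervals):
            counts[event] += 1
    return counts


print(count_events(1000))  # → {'save_summary': 10, 'log_loss': 10}
print(count_events(1000, save_summary_steps=500, log_step_count_steps=50))
# → {'save_summary': 2, 'log_loss': 20}
```

Under this assumption, lowering `log_step_count_steps` makes loss logging denser without changing how often summaries are written, since the two intervals are independent.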