estimator¶
estimator¶
- class EstimatorSpec(label, pred, head_name=None, loss=None, optimizer=None, classification=True)[source]¶
EstimatorSpec是model_fn返回的数据结构。
- Parameters:
label (
Union[tf.Tensor,List[tf.Tensor]]
) – 样本标签,multi-head可以使用列表pred (
Union[tf.Tensor,List[tf.Tensor]]
) – 预测结果,multi-head可以使用列表head_name (
Union[str,List[str]]
) – predict名称,multi-head可以使用列表loss (
tf.Tensor
) – 损失optimizer (
tf.Optimizer
) – dense部分的优化器classification (
Union[bool,List[bool]]
) – 是否为分类模型,multi-head可使用列表
- class RunConfig(is_local=False, num_ps=0, num_workers=1, chief_timeout_secs=1800, operation_timeout_in_ms=-1, session_creation_timeout_secs=7200, enable_fused_layout=False, enable_model_dump=False, partial_recovery=False, max_retry_times=6, retry_wait_in_secs=5, bzid=None, base_name=None, ps_replica_num=None, enable_parameter_sync=False, model_dir='', restore_dir=None, restore_ckpt=None, save_checkpoints_secs=None, save_checkpoints_steps=None, max_rpc_deadline_millis=30000, dense_only_save_checkpoints_secs=None, dense_only_save_checkpoints_steps=None, checkpoints_max_to_keep=10, warmup_file='./warmup_file', enable_local_profiling=False, use_native_multi_hash_table=None, clear_nn=False, continue_training=False, reload_alias_map=None, enable_alias_map_auto_gen=None, kafka_topics=None, kafka_group_id=None, kafka_servers=None)[source]¶
Estimator相关配置,用户模型外参数统一入口
- Parameters:
chief_timeout_secs (
int
) – chief 超时时长,默认为 1800秒operation_timeout_in_ms (
int
) – 操作超时时长,默认为 -1,不会超时session_creation_timeout_secs (
int
) – session创建超时时长,默认为7200秒enable_fused_layout (
bool
) – 是否打开layout融合,加速计算,training & serving 阶段都有帮助,默认为 Falsepartial_recovery (
bool
) – 是否开启部分恢复max_retry_times (
int
) – 发生容错时,最大重启次数,默认为 6 (一般不需要修改)retry_wait_in_secs (
int
) – 发生容错时,重启时间间隔,默认为 5 (一般不需要修改)enable_parameter_sync (
bool
) – 是否开启参数同步,默认为Falsemodel_dir (
str
) – 模型dump目录restore_dir (
str
) – 模型加载目录,当dump目录与加载目录不同时才需指定,默认从model_dir中加载模型restore_ckpt (
str
) – 加载checkpoint版本,默认加载最新版save_checkpoints_secs (
int
) – 每过多少秒存一个checkpointsave_checkpoints_steps (
int
) – 每过多少step存一个checkpointdense_only_save_checkpoints_secs (
int
) – 每过多少秒存一个dense部分checkpointdense_only_save_checkpoints_steps (
int
) – 每过多少step存一个dense部分checkpointcheckpoints_max_to_keep (
int
) – 最多保存多少个checkpointuse_native_multi_hash_table (
bool
) – 请不要指定这个变量,将于2023-1-1移除clear_nn (
bool
) – 是否在reload模型时将dense部分随机初始化,默认为false。不会对 sparase 部分有影响continue_training (
bool
) – 是clear_nn为true时,global_step是否继续保持,默认为false。