模型参数

用户使用 Monolith 时需要关注的参数大致分为 2 类。

1. 训练任务参数

这部分参数一般伴随「训练任务」设置。

1.1 RunConfig

estimator.RunConfigmodel.pymain 函数中。

def main(_):
  est_config = RunConfig(dense_only_save_checkpoints_secs=600,
                         enable_fused_layout=True)
  model = Model()
  estimator = Estimator(model, est_config)
  estimator.train()

if __name__ == "__main__":
  tf.compat.v1.disable_eager_execution()
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  app.run(main)

这部分参数拟逐步收敛,后续不再向用户暴露,会在前端提交页面以用户交互的方式指定。

1.2 前端提交页面参数

  • 任务类型选择:决定任务类型是训练还是 eval

  • Checkpoint 加载:决定该任务从哪个 checkpoint load 参数以训练或者 eval

  • Dense 参数重新初始化:决定该任务训练时是否将 dense 参数重新初始化。一般模型迭代中,用户改了 dense 结构导致无法 load 之前的参数,但又想 load sparse 参数以加快训练收敛时会使用此功能。

2. 自定义模型代码内参数

这部分参数一般伴随「模型」设置,出现在自定义模型的 __init__ 方法中。

class MyModel(MonolithModel):
  def __init__(self, params=None):
    super(MyModel, self).__init__(params)
    # data pipline(自定义模型 input_fn 中使用)
    self.batch_size = 256    # 批大小,无默认值
    self.shuffle_size = 100  # shuffle buffer size

    # training(自定义模型 model_fn 中使用)
    self.default_occurrence_threshold = 2     # 特征过滤阈值
    self.sample_bias = True                   # 是否校正样本采样偏差

    # training(Monolith 框架内部生效)
    self.set_train_clip_norm(1000.0)          # Clip by global norm 阈值,默认 250.0
    self.set_train_dense_weight_decay(0.0001) # Dense 权值衰减,默认 0.001

可细分为两类

  • 用户自己定义,自己使用的参数(比如 self.shuffle_size = 100

  • 框架内部定义,用户设置的参数(比如 self.set_train_clip_norm(1000.0)