Source code for monolith.native_training.estimator

from absl import logging, flags
from dataclasses import dataclass
from dataclasses_json import dataclass_json
import collections
import copy
import getpass
import json
import numpy as np
import os
from typing import Dict, List, Optional, Union, get_type_hints
from kazoo.client import KazooClient

import tensorflow as tf

from monolith.native_training import env_utils
from monolith.agent_service.utils import AgentConfig
from monolith.agent_service.backends import ZKBackend
from monolith.native_training.zk_utils import default_zk_servers
from monolith.agent_service.replica_manager import ReplicaWatcher
from monolith.native_training.cpu_training import CpuTraining, create_exporter
from monolith.native_training.runner_utils import RunnerConfig, monolith_discovery
from monolith.native_training.cpu_training import local_train_internal
from monolith.native_training.cpu_training import distributed_train
from monolith.native_training.service_discovery import ServiceDiscoveryType
from monolith.native_training.monolith_export import monolith_export
from monolith.native_training.model_dump.dump_utils import DumpUtils
from monolith.core.hyperparams import InstantiableParams

from monolith.native_training.data.item_pool_hook import ItemPoolSaveRestoreHook
from monolith.native_training.utils import set_metric_prefix
from monolith.native_training.data.parsers import get_default_parser_ctx, ParserCtx
from monolith.native_training.model_export.export_context import \
  is_exporting, is_exporting_distributed, ExportMode
from monolith.native_training.zk_utils import MonolithKazooClient


@monolith_export
class EstimatorSpec(
    collections.namedtuple(
        'EstimatorSpec',
        ['label', 'pred', 'head_name', 'loss', 'optimizer', 'classification'])):
  """EstimatorSpec is the data structure returned by model_fn.

  Args:
    label (:obj:`Union[tf.Tensor, List[tf.Tensor]]`): Sample labels; a list can be used for multi-head models.
    pred (:obj:`Union[tf.Tensor, List[tf.Tensor]]`): Predictions; a list can be used for multi-head models.
    head_name (:obj:`Union[str, List[str]]`): Prediction name(s); a list can be used for multi-head models.
    loss (:obj:`tf.Tensor`): The loss.
    optimizer (:obj:`tf.Optimizer`): Optimizer for the dense part.
    classification (:obj:`Union[bool, List[bool]]`): Whether this is a classification model; a list can be used for multi-head models.
  """

  def __new__(cls,
              label,
              pred,
              head_name=None,
              loss=None,
              optimizer=None,
              classification=True):
    return super(EstimatorSpec, cls).__new__(cls,
                                             label=label,
                                             pred=pred,
                                             head_name=head_name,
                                             loss=loss,
                                             optimizer=optimizer,
                                             classification=classification)

  def _replace(self, **kwds):
    """Return a new EstimatorSpec replacing specified fields with new values."""
    # 'mode' is not a field of EstimatorSpec, so it can never be replaced;
    # passing it is always an error.
    if 'mode' in kwds:
      raise ValueError('mode of EstimatorSpec cannot be changed.')
    new_fields = map(kwds.pop, self._fields, list(self))
    return EstimatorSpec(*new_fields)
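
# Usage sketch (a hypothetical model_fn; the tensors and optimizer below are
# assumptions, not part of this module): a model_fn builds its graph and
# returns an EstimatorSpec describing labels, predictions, loss and the dense
# optimizer.
#
#   def model_fn(features, mode):
#     logits = tf.keras.layers.Dense(1)(features['dense_input'])
#     loss = tf.reduce_mean(
#         tf.nn.sigmoid_cross_entropy_with_logits(labels=features['label'],
#                                                 logits=logits))
#     return EstimatorSpec(label=features['label'],
#                          pred=tf.sigmoid(logits),
#                          head_name='ctr',
#                          loss=loss,
#                          optimizer=tf.compat.v1.train.AdagradOptimizer(0.01),
#                          classification=True)
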
@monolith_export
@dataclass_json
@dataclass
class RunConfig:
  """Estimator-related configuration; the unified entry point for parameters outside the user model.

  Args:
    chief_timeout_secs (:obj:`int`): Chief timeout, 1800 seconds by default.
    operation_timeout_in_ms (:obj:`int`): Operation timeout, -1 by default (never times out).
    session_creation_timeout_secs (:obj:`int`): Session-creation timeout, 7200 seconds by default.
    enable_fused_layout (:obj:`bool`): Whether to enable layout fusion to speed up computation; it helps in both training and serving. False by default.
    partial_recovery (:obj:`bool`): Whether to enable partial recovery.
    max_retry_times (:obj:`int`): Maximum number of restarts on failover, 6 by default (usually no need to change).
    retry_wait_in_secs (:obj:`int`): Interval between restarts on failover, 5 by default (usually no need to change).
    enable_parameter_sync (:obj:`bool`): Whether to enable parameter sync, False by default.
    model_dir (:obj:`str`): Directory the model is dumped to.
    restore_dir (:obj:`str`): Directory the model is restored from; only needs to be specified when it differs from the dump directory. By default the model is loaded from model_dir.
    restore_ckpt (:obj:`str`): Checkpoint version to restore; the latest version is loaded by default.
    save_checkpoints_secs (:obj:`int`): Save a checkpoint every this many seconds.
    save_checkpoints_steps (:obj:`int`): Save a checkpoint every this many steps.
    dense_only_save_checkpoints_secs (:obj:`int`): Save a dense-only checkpoint every this many seconds.
    dense_only_save_checkpoints_steps (:obj:`int`): Save a dense-only checkpoint every this many steps.
    checkpoints_max_to_keep (:obj:`int`): Maximum number of checkpoints to keep.
    use_native_multi_hash_table (:obj:`bool`): Please do not set this variable; it will be removed on 2023-01-01.
    clear_nn (:obj:`bool`): Whether to randomly re-initialize the dense part when reloading a model, False by default. The sparse part is not affected.
    continue_training (:obj:`bool`): When clear_nn is True, whether global_step keeps its value, False by default.
  """

  # basic
  is_local: bool = False
  num_ps: int = 0
  num_workers: int = 1
  chief_timeout_secs: int = 1800
  operation_timeout_in_ms: int = -1
  session_creation_timeout_secs: int = 7200
  enable_fused_layout: bool = False
  enable_model_dump: bool = False

  # failover
  partial_recovery: bool = False
  max_retry_times: int = 6
  retry_wait_in_secs: int = 5

  # for params sync
  bzid: str = None
  base_name: str = None
  ps_replica_num: int = None
  enable_parameter_sync: bool = False

  # checkpoint and export
  model_dir: str = ""
  restore_dir: str = None
  restore_ckpt: str = None
  save_checkpoints_secs: int = None
  save_checkpoints_steps: int = None
  max_rpc_deadline_millis: int = 30000
  dense_only_save_checkpoints_secs: int = None
  dense_only_save_checkpoints_steps: int = None
  checkpoints_max_to_keep: int = 10
  warmup_file: str = './warmup_file'
  enable_local_profiling: bool = False
  use_native_multi_hash_table: bool = None
  clear_nn: bool = False
  continue_training: bool = False
  reload_alias_map: Dict[str, int] = None
  enable_alias_map_auto_gen: bool = None
  kafka_topics: str = None
  kafka_group_id: str = None
  kafka_servers: str = None

  def to_runner_config(self) -> RunnerConfig:
    conf = RunnerConfig(restore_dir=self.restore_dir,
                        restore_ckpt=self.restore_ckpt,
                        model_dir=self.model_dir,
                        enable_fused_layout=self.enable_fused_layout,
                        enable_model_dump=self.enable_model_dump,
                        warmup_file=self.warmup_file,
                        enable_alias_map_auto_gen=self.enable_alias_map_auto_gen,
                        kafka_topics=self.kafka_topics,
                        kafka_group_id=self.kafka_group_id,
                        kafka_servers=self.kafka_servers)
    for name, _ in get_type_hints(RunConfig).items():
      value = getattr(self, name)
      if hasattr(conf, name) and value is not None:
        default = getattr(RunConfig, name)
        # Both conditions must hold: RunnerConfig values can be written from
        # the command line, and a RunConfig default must not overwrite a
        # command-line value.
        if value != default and getattr(conf, name) != value:
          setattr(conf, name, value)

    # In case the US team uses CONSUL.
    conf.discovery_type = ServiceDiscoveryType.ZK

    # Set default values for embedding prefetch/postpush.
    if conf.embedding_prefetch_capacity <= 0:
      conf.embedding_prefetch_capacity = 1
    if not conf.enable_embedding_postpush:
      conf.enable_embedding_postpush = True

    # [todo] Remove this when enable_realtime_training is renamed to
    # enable_parameter_sync.
    if self.enable_parameter_sync:
      if hasattr(conf, 'enable_realtime_training'):
        conf.enable_realtime_training = True
      elif hasattr(conf, 'enable_parameter_sync'):
        conf.enable_parameter_sync = True
      else:
        raise RuntimeError("enable_parameter_sync not set!")

    return conf

  def __post_init__(self):
    ser_data = self.to_json()
    DumpUtils().add_config(ser_data)
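
# Usage sketch (paths are hypothetical): configure a local run that checkpoints
# every 10 minutes, keeps the last 5 checkpoints, then convert to the
# RunnerConfig consumed by the training entry points.
#
#   conf = RunConfig(is_local=True,
#                    model_dir='/tmp/demo_model',
#                    save_checkpoints_secs=600,
#                    checkpoints_max_to_keep=5)
#   runner_conf = conf.to_runner_config()
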
@monolith_export
class Estimator(object):
  """Estimator unifies local and distributed modes. In addition, it helps
  initialize/save/restore variables, run hooks, write summaries, etc.

  Args:
    model (:obj:`Model`): A NativeModel object.
    conf (:obj:`RunConfig`): Configuration needed to run the model.
  """

  def __init__(self,
               model,
               conf: Union[RunConfig, RunnerConfig],
               warm_start_from=None):
    self._runner_conf = conf.to_runner_config() if isinstance(
        conf, RunConfig) else conf
    self._model = model
    self._task = None
    self._warm_start_from = warm_start_from
    self._sync_backend = None
    self._kazoo_client = None
    if isinstance(conf, RunConfig):
      self._enable_local_profiling = conf.enable_local_profiling
    else:
      self._enable_local_profiling = False
    logging.info(self._runner_conf)

    if self._runner_conf.is_local:
      # Local mode cannot access deep_insight.
      self._model.metrics.enable_deep_insight = False
    else:
      self._model.metrics.enable_deep_insight = True
      if self._runner_conf.deep_insight_name:
        self._model.metrics.deep_insight_name = self._runner_conf.deep_insight_name
      if self._runner_conf.deep_insight_target:
        self._model.metrics.deep_insight_target = self._runner_conf.deep_insight_target
      if self._runner_conf.deep_insight_sample_ratio:
        self._model.metrics.deep_insight_sample_ratio = self._runner_conf.deep_insight_sample_ratio

    if self._runner_conf.enable_realtime_training and self._runner_conf.server_type == 'ps':
      assert self._runner_conf.bzid, "Business id cannot be none while realtime training."
      assert self._runner_conf.base_name, "Base name cannot be none while realtime training."
      zk_servers = self._runner_conf.zk_server or os.environ.get(
          'zk_servers', default_zk_servers())
      if self._runner_conf.unified_serving:
        self._sync_backend = ZKBackend(self._runner_conf.bzid,
                                       zk_servers=zk_servers)
      else:
        assert self._runner_conf.base_name, "Base name cannot be none while realtime training."
        self._kazoo_client = MonolithKazooClient(hosts=zk_servers)
        self._kazoo_client.start()
        agent_config = AgentConfig(bzid=self._runner_conf.bzid,
                                   base_name=self._runner_conf.base_name,
                                   deploy_type='ps',
                                   num_ps=self._runner_conf.num_ps,
                                   dc_aware=self._runner_conf.dc_aware)
        replica_watcher = ReplicaWatcher(
            self._kazoo_client,
            agent_config,
            zk_watch_address_family=self._runner_conf.zk_watch_address_family)
        self._sync_backend = replica_watcher.to_sync_wrapper()

    if self._runner_conf.params_override:
      logging.info("Override: {}".format(self._runner_conf.params_override))
      params_override_dict = json.loads(self._runner_conf.params_override)
      if hasattr(model, 'p'):
        model.p.set(**params_override_dict)
      elif hasattr(model, 'params'):
        model.params.set(**params_override_dict)
      else:
        logging.warning('params_override error!')

    try:
      if not os.environ.get("HADOOP_HDFS_HOME"):
        env_utils.setup_hdfs_env()
    except Exception as e:
      logging.error('setup_hdfs_env fail {}!'.format(e))
    os.environ["TF_GRPC_WORKER_CACHE_THREADS"] = str(
        self._runner_conf.tf_grpc_worker_cache_threads)
    os.environ["MONOLITH_GRPC_WORKER_SERVICE_HANDLER_MULTIPLIER"] = str(
        self._runner_conf.monolith_grpc_worker_service_handler_multiplier)

    # private attr
    self.__est: Optional[tf.estimator.Estimator] = None

  @property
  def _sess_config(self):
    return self._est._session_config

  @property
  def model_dir(self):
    return self._runner_conf.model_dir

  @property
  def config(self):
    return self._est._config

  @property
  def _est(self):
    if self.__est is None:
      model = copy.deepcopy(self._model)
      model.mode = tf.estimator.ModeKeys.PREDICT
      self._task = CpuTraining(config=self._runner_conf,
                               task=model.instantiate())

      # The default estimator for predict/export_saved_model/import_saved_model.
      if 'TF_CONF' in os.environ:
        del os.environ['TF_CONF']
      self.__est = tf.estimator.Estimator(
          self._task.create_model_fn(),
          model_dir=self._runner_conf.model_dir,
          config=tf.estimator.RunConfig(log_step_count_steps=1),
          warm_start_from=self._warm_start_from)

    return self.__est

  def _init_fountain_env(self):
    if self._model.train.use_fountain and bool(
        self._runner_conf.fountain_zk_host) and bool(
            self._runner_conf.fountain_model_name):
      logging.info("Override Fountain Params: {}; {}".format(
          self._runner_conf.fountain_model_name,
          self._runner_conf.fountain_zk_host))
      self._model.train.fountain_zk_host = self._runner_conf.fountain_zk_host
      self._model.train.fountain_model_name = self._runner_conf.fountain_model_name

  def close(self):
    if self._sync_backend is not None:
      try:
        self._sync_backend.stop()
      except Exception as e:
        logging.error(e)
    if self._kazoo_client is not None:
      try:
        self._kazoo_client.stop()
      except Exception as e:
        logging.info(e)
      try:
        self._kazoo_client.close()
      except Exception as e:
        logging.info(e)

  def get_variable_value(self, name):
    return self._est.get_variable_value(name)

  def get_variable_names(self):
    return self._est.get_variable_names()

  def latest_checkpoint(self):
    return self._est.latest_checkpoint()

  def train(self, steps=None, max_steps=None):
    set_metric_prefix("monolith.training.{}".format(
        self._runner_conf.deep_insight_name))
    model = copy.deepcopy(self._model)
    if not isinstance(model, InstantiableParams):
      model.file_name = self._model.file_name
    model.mode = tf.estimator.ModeKeys.TRAIN
    if steps is not None:
      model.train.steps = steps
    if max_steps is not None:
      model.train.max_steps = max_steps

    self._init_fountain_env()
    if self._runner_conf.is_local:
      if not self._runner_conf.model_dir:
        model_dir = "/tmp/{}/{}".format(getpass.getuser(), model.name)
      else:
        model_dir = self._runner_conf.model_dir
      DumpUtils().record_params(model)
      self.__est = local_train_internal(model,
                                        self._runner_conf,
                                        model_dir=model_dir,
                                        steps=steps,
                                        profiling=self._enable_local_profiling)
      DumpUtils().dump(f'{self._runner_conf.model_dir}/model_dump')
    else:
      DumpUtils().enable = False
      logging.info(f'{model.p}')
      with monolith_discovery(self._runner_conf) as discovery:
        if self._sync_backend is not None:
          self._sync_backend.start()
          self._sync_backend.subscribe_model(self._runner_conf.model_name or
                                             model.metrics.deep_insight_name)
        logging.info("Environment vars: %s", os.environ)
        logging.info("Flags: %s", flags.FLAGS.flag_values_dict())
        self.__est = distributed_train(config=self._runner_conf,
                                       discovery=discovery,
                                       params=model,
                                       sync_backend=self._sync_backend)
    self.close()

  def evaluate(self, steps=None):
    model = copy.deepcopy(self._model)
    model.mode = tf.estimator.ModeKeys.EVAL
    if not isinstance(model, InstantiableParams):
      model.file_name = self._model.file_name
    self._runner_conf.mode = tf.estimator.ModeKeys.EVAL
    if steps is not None:
      model.train.steps = steps

    self._init_fountain_env()
    if self._runner_conf.is_local:
      DumpUtils().record_params(model)
      if not self._runner_conf.model_dir:
        model_dir = "/tmp/{}/{}".format(getpass.getuser(), model.name)
      else:
        model_dir = self._runner_conf.model_dir
      self.__est = local_train_internal(model,
                                        self._runner_conf,
                                        model_dir=model_dir,
                                        steps=steps,
                                        profiling=self._enable_local_profiling)
      DumpUtils().dump(f'{self._runner_conf.model_dir}/model_dump')
    else:
      DumpUtils().enable = False
      logging.info(f'{model.p}')
      logging.info("Environment vars: %s", os.environ)
      logging.info("Flags: %s", flags.FLAGS.flag_values_dict())
      with monolith_discovery(self._runner_conf) as discovery:
        self.__est = distributed_train(self._runner_conf,
                                       discovery,
                                       model,
                                       sync_backend=self._sync_backend)
    self.close()

  def predict(self,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
    est = self._est  # create the tf estimator
    input_fn = self._task.create_input_fn(tf.estimator.ModeKeys.PREDICT)
    est.predict(input_fn, predict_keys, hooks, checkpoint_path,
                yield_single_examples)
    self.close()

  def export_saved_model(self,
                         batch_size=64,
                         name=None,
                         dense_only: bool = False,
                         enable_fused_layout: bool = False):
    model = copy.deepcopy(self._model)
    runner_conf = copy.deepcopy(self._runner_conf)
    runner_conf.enable_fused_layout = enable_fused_layout
    model.name = name or "demo_export"
    model.train.per_replica_batch_size = batch_size
    model.mode = tf.estimator.ModeKeys.PREDICT

    model_dir = runner_conf.model_dir
    export_dir_base = os.path.join(model_dir, model.serving.export_dir_base)
    warmup_file = runner_conf.warmup_file
    task = CpuTraining(config=runner_conf, task=model.instantiate())
    exporter = create_exporter(task, model_dir, warmup_file, export_dir_base,
                               dense_only)
    serving_input_receiver_fn = task.create_serving_input_receiver_fn()
    with ParserCtx(enable_fused_layout=enable_fused_layout):
      exporter.export_saved_model(serving_input_receiver_fn)
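
# Usage sketch (MyModel is a hypothetical NativeModel subclass; the exact way a
# model instance is obtained depends on the model definition): train locally,
# then export a saved_model for serving.
#
#   model = MyModel()  # a NativeModel
#   conf = RunConfig(is_local=True, model_dir='/tmp/demo_model')
#   est = Estimator(model, conf)
#   est.train(steps=1000)
#   est.export_saved_model(batch_size=64, name='demo_export')
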
@monolith_export
def import_saved_model(saved_model_path: str,
                       input_name: str = "instances",
                       output_name: str = 'output',
                       signature: str = None):
  """Imports a saved_model for inference.

  Args:
    saved_model_path (:obj:`str`): Path to the saved_model.
  """

  class saved_model(object):

    def __init__(self, saved_model_path, signature, inputs, outputs):
      basename = os.path.basename(saved_model_path)
      if not basename.isnumeric():
        versions = []
        for subitem in tf.io.gfile.listdir(saved_model_path):
          if subitem.isnumeric():
            versions.append(int(subitem))
        if versions:
          versions.sort()
          saved_model_path = os.path.join(saved_model_path, str(versions[-1]))
        else:
          raise RuntimeError(f"no models in dir {saved_model_path}")

      self._saved_model_path = saved_model_path
      if signature:
        self._signature = signature
      else:
        self._signature = tf.compat.v1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY

      if inputs:
        self._inputs = inputs if isinstance(inputs, (list, tuple)) else [inputs]
      else:
        self._inputs = None

      if outputs:
        self._outputs = outputs if isinstance(outputs,
                                              (list, tuple)) else [outputs]
      else:
        self._outputs = None

    def __enter__(self):

      class infer(object):

        def __init__(self, graph, sess, placeholders, output_dict,
                     output_name_map):
          self._graph = graph
          self._sess = sess
          self._placeholders = placeholders
          self._output_dict = output_dict
          self._output_name_map = output_name_map

        def __call__(
            self, features: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
          with self._graph.as_default(), self._sess.as_default():
            if len(self._placeholders) == 1:
              # A single input: feed the payload to the one placeholder.
              placeholder = next(iter(self._placeholders.values()))
              result = self._sess.run(self._output_dict,
                                      feed_dict={placeholder: features})
            else:
              result = self._sess.run(
                  self._output_dict,
                  feed_dict={
                      self._placeholders[name]: feature
                      for name, feature in features.items()
                  })
            return {
                self._output_name_map[key]: tensor
                for key, tensor in result.items()
            }

      tag = tf.compat.v1.saved_model.tag_constants.SERVING
      graph = tf.compat.v1.Graph()
      sess = tf.compat.v1.Session(graph=graph)
      with graph.as_default(), sess.as_default():
        imported = tf.compat.v1.saved_model.load(sess, {tag},
                                                 self._saved_model_path)
        print(imported.signature_def, flush=True)
        signature_def = imported.signature_def[self._signature]

        placeholders: Dict[str, tf.Tensor] = {}
        for input_name in self._inputs:
          input_ph_name = signature_def.inputs[input_name].name
          input_ph = graph.get_tensor_by_name(input_ph_name)
          placeholders[input_name] = input_ph

        output_dict, output_name_map = {}, {}
        if self._outputs:
          for output_name in self._outputs:
            output_tensor_name = signature_def.outputs[output_name].name
            output_tensor = graph.get_tensor_by_name(output_tensor_name)
            if output_tensor_name.endswith(':0'):
              output_tensor_name = output_tensor_name[:-2]
            output_dict[output_tensor_name] = output_tensor
            output_name_map[output_tensor_name] = output_name
        else:
          for output_name, tensor in signature_def.outputs.items():
            output_tensor_name = tensor.name
            output_tensor = graph.get_tensor_by_name(output_tensor_name)
            if output_tensor_name.endswith(':0'):
              output_tensor_name = output_tensor_name[:-2]
            output_dict[output_tensor_name] = output_tensor
            output_name_map[output_tensor_name] = output_name

      logging.info('import_saved_model finished')
      return infer(graph, sess, placeholders, output_dict, output_name_map)

    def __exit__(self, exc_type, exc_val, exc_tb):
      logging.info('exit import_saved_model')

  return saved_model(saved_model_path, signature, input_name, output_name)
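
# Usage sketch (path and payload are hypothetical): import_saved_model returns
# a context manager whose __enter__ yields a callable; with the default single
# 'instances' input, the payload is fed to that one placeholder directly.
#
#   with import_saved_model('/tmp/demo_model/exported_models/entry') as infer_fn:
#     result = infer_fn(serialized_instances)
#     print(result)  # {'output': <np.ndarray>}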