from absl import logging, flags
from enum import Enum
import os, sys, six
import types
from datetime import datetime, timedelta
from typing import Dict, List, Iterable, Callable, Optional, Union
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.data.experimental.ops import matching_files
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.framework import load_library
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
import tensorflow.python.data.experimental.service as dsvc
from monolith.native_training.hooks import ckpt_hooks
from monolith.utils import get_libops_path
from monolith.native_training.monolith_export import monolith_export
from monolith.native_training.data.feature_utils import create_item_pool, string_to_variant, \
has_variant, kafka_resource_init, kafka_read_next
from monolith.native_training.data.feature_list import FeatureList
from monolith.native_training import native_task_context
from monolith.native_training.distribute import distributed_dataset
from kafka import KafkaConsumer
from threading import Thread, RLock
from queue import Queue, Empty
pb_datasource_ops = load_library.load_op_library(
get_libops_path('monolith/native_training/data/pb_data_ops.so'))
FLAGS = flags.FLAGS
POOL_KEY = "TF_ITEMPOOL"
class FeaturePruningType(object):
AS_IS = 0
PRUNING_FEATURE = 1
PRUNING_RAW_FEATURE = 2
@monolith_export
class PbType(Enum):
INSTANCE = 1
EXAMPLEBATCH = 2
EXAMPLE = 3
PLAINTEXT = 4
def to_name(self):
return self.name.lower()
def _get_params(name, default=None):
try:
if name == 'data_type':
attr_val = getattr(FLAGS, name)
if attr_val:
attr_val = attr_val.upper()
        if attr_val == 'EXAMPLE_BATCH':
return PbType.EXAMPLEBATCH
else:
return PbType[attr_val]
else:
return default
else:
return getattr(FLAGS, name)
  except Exception:
return default
class DatasetMetaclass(type):
def __call__(cls, *args, **kwargs):
if kwargs.get('topics_or_files', None):
value = kwargs['topics_or_files']
if isinstance(value, str):
kwargs['file_name'] = kwargs.get('file_name') or value
else:
kwargs['patterns'] = kwargs.get('patterns') or value
kwargs['topics'] = kwargs.get('topics') or value
if kwargs.get('buffer_size_or_group_id', None):
value = kwargs['buffer_size_or_group_id']
if isinstance(value, int):
kwargs['buffer_size'] = kwargs.get('buffer_size') or value
else:
kwargs['group_id'] = kwargs.get('group_id') or value
if kwargs.get('input_pb_type_or_servers', None):
value = kwargs['input_pb_type_or_servers']
if isinstance(value, (str, list)):
kwargs['servers'] = kwargs.get('servers') or value
else:
kwargs['input_pb_type'] = kwargs.get('input_pb_type') or value
    try:
      # If Kafka parameters are available (from kwargs or the command-line flags),
      # switch from batch to streaming training. Keep them in a separate list so the
      # original positional args are preserved if this path fails.
      kafka_args = [kwargs.pop('topics', FLAGS.kafka_topics.split(',')),
                    kwargs.pop('group_id', FLAGS.kafka_group_id),
                    kwargs.pop('servers', FLAGS.kafka_servers)]
      assert all(x is not None for x in kafka_args)
      logging.info('use KafkaDataset!')
      return KafkaDataset(*kafka_args, **kwargs)
    except Exception:
      logging.info("it's not streaming training")
if args is None or len(args) == 0:
if 'patterns' in kwargs and 'group_id' not in kwargs and 'servers' not in kwargs:
logging.info('use DistributedFilePBDataset!')
return DistributedFilePBDataset(**kwargs)
elif 'topics' in kwargs and 'group_id' in kwargs and 'servers' in kwargs:
logging.info('use KafkaDataset!')
return KafkaDataset(**kwargs)
elif 'file_name' in kwargs or len(kwargs) == 0:
return FilePBDataset(*args, **kwargs)
else:
return super(DatasetMetaclass, cls).__call__(*args, **kwargs)
elif isinstance(args[0], str):
logging.info('use FilePBDataset!')
return FilePBDataset(*args, **kwargs)
elif isinstance(args[0], (list, tuple)):
if len(args) > 1:
if isinstance(args[1], str):
logging.info('use KafkaDataset!')
return KafkaDataset(*args, **kwargs)
else:
logging.info('use DistributedFilePBDataset!')
return DistributedFilePBDataset(*args, **kwargs)
else:
if 'group_id' in kwargs or 'servers' in kwargs:
logging.info('use KafkaDataset!')
return KafkaDataset(*args, **kwargs)
else:
logging.info('use DistributedFilePBDataset!')
return DistributedFilePBDataset(*args, **kwargs)
else:
return super(DatasetMetaclass, cls).__call__(*args, **kwargs)
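
# Illustrative sketch (not invoked anywhere in this module) of how DatasetMetaclass
# dispatches PBDataset(...) calls when the Kafka command-line flags are not set.
# The paths, topic and broker addresses below are hypothetical placeholders.
def _pb_dataset_dispatch_sketch():
  # A single file name (str) -> FilePBDataset.
  file_ds = PBDataset('/tmp/part-00000.pb', input_pb_type=PbType.EXAMPLEBATCH)
  # A list of patterns without Kafka kwargs -> DistributedFilePBDataset.
  dist_ds = PBDataset(patterns=['/tmp/20220101/*'], input_pb_type=PbType.EXAMPLEBATCH)
  # topics + group_id + servers -> KafkaDataset (streaming).
  kafka_ds = PBDataset(topics=['demo_topic'], group_id='demo_group',
                       servers='localhost:9092', variant_type='example')
  return file_ds, dist_ds, kafka_ds
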
class PBDataset(metaclass=DatasetMetaclass):
def __init__(self,
topics_or_files: Union[str, List[str]] = '',
buffer_size_or_group_id: Union[int, str] = None,
input_pb_type_or_servers: Union[PbType, str] = None,
output_pb_type: PbType = None,
feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE,
disable_iterator_save_restore: bool = True,
*,
has_header=True, variant_type: str = None,
stream_timeout=-1, message_poll_timeout=10000, poll_batch_size: int = 1024,
filter_empty: bool = False, configuration=None, container: str = '', shared_name: str = '',
use_data_service: bool = False, cycle_length=None, block_length=None,
num_parallel_calls=None, deterministic=None,
**kwargs):
pass
@classmethod
def gen_patterns(cls, input_path: str = None, start_date:int = None, start_hour: int = None,
end_date: int = None, end_hour: int = None, is_hourly: bool = False, wildcard: str = '*') -> List[str]:
input_path = input_path or _get_params('input_path', None)
if not input_path:
return []
start_date = start_date or _get_params('start_date', None)
if not start_date:
return []
end_date = end_date or _get_params('end_date', None)
if not end_date:
end_date = datetime.today().strftime('%Y%m%d')
is_hourly = is_hourly if is_hourly is not None else _get_params('is_hourly', False)
start_hour = start_hour or _get_params('start_hour', 0) or 0
end_hour = end_hour or _get_params('end_hour', 0) or 0
wildcard = wildcard or _get_params('wildcard', '*')
start = datetime.strptime(f'{start_date}:{start_hour:02d}', '%Y%m%d:%H')
if is_hourly:
end = datetime.strptime(f'{end_date}:{end_hour:02d}', '%Y%m%d:%H')
else:
end = datetime.strptime(f'{end_date}:00', '%Y%m%d:%H')
delta = timedelta(hours=1) if is_hourly else timedelta(days=1)
cur = start
patterns = []
while cur < end:
if is_hourly:
pat = f"{cur.strftime('%Y%m%d/%H')}{wildcard}"
else:
pat = os.path.join(cur.strftime('%Y%m%d'), wildcard)
patterns.append(os.path.join(input_path, pat))
cur = cur + delta
return patterns
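
# A small worked example (not executed): gen_patterns expands a date range into
# per-day or per-hour glob patterns. The base path below is a hypothetical placeholder,
# and the expected results assume the date/hour flags are left unset.
def _gen_patterns_sketch():
  daily = PBDataset.gen_patterns(input_path='/data/warehouse', start_date=20220101,
                                 end_date=20220103, is_hourly=False, wildcard='*')
  # expected: ['/data/warehouse/20220101/*', '/data/warehouse/20220102/*']
  # (the end date itself is excluded because the loop stops at `cur < end`)
  hourly = PBDataset.gen_patterns(input_path='/data/warehouse', start_date=20220101,
                                  start_hour=0, end_date=20220101, end_hour=2,
                                  is_hourly=True)
  # expected: ['/data/warehouse/20220101/00*', '/data/warehouse/20220101/01*']
  return daily, hourly
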
class DynamicMatchingFilesDataset(dataset_ops.DatasetSource):
"""A `Dataset` that list the files according to the input patterns."""
def __init__(self, patterns: List[str]):
assert patterns is not None and len(patterns) > 0
self._patterns = ops.convert_to_tensor(
patterns, dtype=dtypes.string, name="patterns")
variant_tensor = pb_datasource_ops.dynamic_matching_files_dataset(self._patterns)
super(DynamicMatchingFilesDataset, self).__init__(variant_tensor)
@property
def element_spec(self):
return tensor_spec.TensorSpec([], dtypes.string)
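
# Minimal sketch (hypothetical pattern): DynamicMatchingFilesDataset yields one matched
# file name per element, which can then be interleaved into per-file readers.
def _dynamic_matching_files_sketch():
  files = DynamicMatchingFilesDataset(['/data/warehouse/20220101/*'])
  return files.interleave(lambda f: FilePBDataset(file_name=f, use_snappy=False),
                          cycle_length=4)
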
@monolith_export
class FilePBDataset(dataset_ops.DatasetSource):
"""从标准输入/pb文件中读取序列化数据, 并将其反序列化存于TF的Variant类型中. 这样做的好处是可以直接对PB对象进行过滤与修改,
不用等到parse以后. Monolith提供了一系列工具操作Variant变量, 如filter_by_fids, filter_by_value, negative_sample等
另外, InstanceReweightDataset/NegativeGenDataset 这些DataSet也可以直接作用于Variant
Args:
file_name (:obj:`str`): 文件名, 如果为空, 则从stdin读取数据
buffer_size (:obj:`int`): 读取文件时缓存大小, 默认100MB
input_pb_type (:obj:`str`): 输入pb类型, 可以是example/example_batch/instance
output_pb_type (:obj:`str`): 输入pb类型, 可以是example/instance/plaintext
Raises:
TypeError: 如果有任何参数与类型不匹配, 则抛TypeError
ValueError: 如果有任何值与期望不匹配, 则抛ValueError
"""
def __init__(
self,
file_name: str = "",
buffer_size: int = None,
input_pb_type: PbType = None,
output_pb_type: PbType = None,
feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE,
disable_iterator_save_restore: bool = True,
use_snappy: bool = None,
**kwargs):
input_pb_type = input_pb_type or _get_params('data_type', PbType.INSTANCE)
output_pb_type = output_pb_type or (PbType.INSTANCE if input_pb_type
== PbType.INSTANCE else PbType.EXAMPLE)
feature_name_list = []
feature_id_list = []
if input_pb_type in [PbType.EXAMPLEBATCH, PbType.EXAMPLE]:
try:
feature_list = FeatureList.parse()
for feature in feature_list:
name, slot = feature.feature_name, feature.slot
assert None not in [name, slot]
feature_name_list.append(name)
feature_id_list.append(slot)
except Exception as e:
logging.warning('Failed to parse feature_list.conf, %s', e)
self._file_name = file_name
self._buffer_size = buffer_size
self._input_pb_type = input_pb_type
self._output_pb_type = output_pb_type
self._out_type = tf.string if output_pb_type == PbType.PLAINTEXT else tf.variant
self._has_sort_id = kwargs.get('has_sort_id', _get_params('sort_id', True))
self._kafka_dump = kwargs.get('kafka_dump',
_get_params('kafka_dump', False))
logging.info('input_pb_type: %s, kafka_dump: %s, output_pb_type: %s',
self._input_pb_type, self._kafka_dump, self._output_pb_type)
self._kafka_dump_prefix = kwargs.get(
'kafka_dump_prefix', _get_params('kafka_dump_prefix', False))
self._lagrangex_header = kwargs.get('lagrangex_header',
_get_params('lagrangex_header', False))
if disable_iterator_save_restore and isinstance(file_name, str):
      # This covers the special case where the dataset uses stdin as the input.
      # In that case, we should disable the iterator ckpt save/restore.
if context.default_execution_mode == context.GRAPH_MODE:
ckpt_hooks.disable_iterator_save_restore()
default_buffer_size = 128 * 1024 * 1024 if input_pb_type == PbType.EXAMPLEBATCH else 64 * 1024 * 1024
if use_snappy is None:
if isinstance(file_name, str):
use_snappy = file_name.endswith('.snappy')
assert use_snappy is not None
variant_tensor = pb_datasource_ops.pb_dataset(
file_name=file_name,
use_snappy=use_snappy,
buffer_size=buffer_size or default_buffer_size,
input_pb_type=input_pb_type.to_name(),
output_pb_type=output_pb_type.to_name(),
has_sort_id=self._has_sort_id,
kafka_dump=self._kafka_dump,
kafka_dump_prefix=self._kafka_dump_prefix,
lagrangex_header=self._lagrangex_header,
feature_pruning_type=feature_pruning_type,
feature_name_list=feature_name_list,
feature_id_list=feature_id_list,
out_type=self._out_type,
)
logging.info("Start init of the pb instance dataset base.")
super().__init__(variant_tensor)
@property
def element_spec(self):
return tensor_spec.TensorSpec([], self._out_type)
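
# Minimal usage sketch (hypothetical file path, not executed): read an ExampleBatch
# file and emit Example variants, which downstream Variant ops can filter and parse.
def _file_pb_dataset_sketch():
  ds = FilePBDataset(file_name='/tmp/part-00000.pb',
                     input_pb_type=PbType.EXAMPLEBATCH,
                     output_pb_type=PbType.EXAMPLE)
  # Each element is a scalar tf.variant holding one deserialized Example.
  return ds.batch(256, drop_remainder=True)
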
class DistributedFilePBDataset(dataset_ops.DatasetSource):
def __init__(self,
patterns: Union[str, List[str]],
use_snappy=False,
buffer_size: int = None,
input_pb_type: PbType = None,
output_pb_type: PbType = None,
feature_pruning_type: int = FeaturePruningType.PRUNING_RAW_FEATURE,
exclude_fn: Callable[[tf.Tensor], bool] = None,
use_data_service: bool = False,
cycle_length=None,
block_length=None,
num_parallel_calls=None,
deterministic=None,
**kwargs):
if not patterns:
patterns = [""]
elif isinstance(patterns, str):
patterns = [patterns]
else:
logging.info(f'patterns: {patterns}')
patterns.sort()
enable_dynamic_sharding = kwargs.get('enable_dynamic_sharding', _get_params('enable_dynamic_sharding', False))
logging.info(f"enable_dynamic_sharding: {enable_dynamic_sharding}")
map_func = lambda file_name: FilePBDataset(
file_name=file_name, use_snappy=use_snappy, buffer_size=buffer_size,
input_pb_type=input_pb_type, output_pb_type=output_pb_type, feature_pruning_type=feature_pruning_type,
disable_iterator_save_restore=not enable_dynamic_sharding, **kwargs)
if use_data_service:
files_list = DynamicMatchingFilesDataset(patterns)
if exclude_fn is not None:
files_list = files_list.filter(predicate=exclude_fn)
dataset = files_list.interleave(map_func, cycle_length=cycle_length,
block_length=block_length,
num_parallel_calls=num_parallel_calls,
deterministic=deterministic)
elif enable_dynamic_sharding:
files_list = distributed_dataset.create_dynamic_sharding_dataset(patterns)
if exclude_fn is not None:
files_list = files_list.filter(predicate=exclude_fn)
dataset = files_list.flat_map(map_func)
else:
files_list = matching_files.MatchingFilesDataset(patterns)
if exclude_fn is not None:
files_list = files_list.filter(predicate=exclude_fn)
ctx = native_task_context.get()
if ctx is not None:
if ctx.num_workers > 1:
files_list = files_list.shard(ctx.num_workers, ctx.worker_index)
else:
shard_num = kwargs.get('shard_num', 1)
shard_index = kwargs.get('shard_index', 0)
if shard_num > 1:
files_list = files_list.shard(shard_num, shard_index)
      cycle_length = cycle_length or kwargs.get('cycle_length', _get_params('max_task_num_per_worker', 4))
      num_parallel_calls = num_parallel_calls or kwargs.get('num_parallel_calls', _get_params('max_task_num_per_worker', 4))
      block_length = block_length or kwargs.get('block_length', _get_params('block_length', 1))
dataset = files_list.interleave(map_func=map_func,
cycle_length=cycle_length,
block_length=block_length,
num_parallel_calls=num_parallel_calls,
deterministic=False)
self._dataset = dataset
super(DistributedFilePBDataset, self).__init__(variant_tensor=self._dataset._variant_tensor)
@property
def element_spec(self):
return self._dataset.element_spec
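
# Minimal sketch (hypothetical paths): combine gen_patterns with
# DistributedFilePBDataset so every worker reads its own shard of the matched files.
def _distributed_file_pb_dataset_sketch():
  patterns = PBDataset.gen_patterns(input_path='/data/warehouse',
                                    start_date=20220101, end_date=20220108)
  return DistributedFilePBDataset(patterns,
                                  input_pb_type=PbType.EXAMPLEBATCH,
                                  output_pb_type=PbType.EXAMPLE)
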
@monolith_export
class InstanceReweightDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""样本重加权, 并根据action给样本打标签, 使用方式为 dataset.instance_reweight
一个样本可能有多个action, 按`action_priority`, 找到最高优的action. 再用action找到对应的 `action:weight:label`,
让样本重复weight次(也有可能是0次, 即删除样本), 然后给样本打上label指定的标签
Args:
input_dataset (:obj:`dataset`): 输入数据集
action_priority (:obj:`str`): action用int表示, 以逗号分隔的int数组, 排在前面的优先级高
reweight (:obj:`str`): 基本单元是`action:weight:label`, 可以用逗号分隔多个基本单元
1) action: 动作, 用int表示, 与业务相关, 如download, install, click, exposure等
2) weight: 权重, 用int表示, 表示样本重复的次数
3) label: 标签, 一般用1/-1表示.
variant_type (:obj:`str`): 输入数据是variant类型的, 支持两种格式, instance/example
Raises:
TypeError: 如果有任何参数与类型不匹配, 则抛TypeError
ValueError: 如果有任何值与期望不匹配, 则抛ValueError
"""
def __init__(self,
input_dataset,
action_priority: str = None,
reweight: str = None,
variant_type: str = 'example'):
self._label_priority = action_priority
self._reweight = reweight
self._variant_type = variant_type
actions, weights, labels = [], [], []
for item in reweight.strip().split(','):
(action, weight, label) = item.strip().split(':')
actions.append(int(action))
weights.append(int(weight))
labels.append(int(label))
priorities = [int(p) for p in action_priority.strip().split(',')]
variant_tensor = pb_datasource_ops.instance_reweight_dataset(
input=input_dataset._variant_tensor,
method=0,
actions=actions,
weights=weights,
labels=labels,
priorities=priorities,
variant_type=variant_type)
logging.info("Start init of the pb instance dataset base.")
super(InstanceReweightDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return tensor_spec.TensorSpec([], dtypes.variant)
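
# Minimal sketch of the reweight string format (hypothetical action ids): action 2
# (say, click) is kept 3 times with label 1, action 1 (say, exposure) is kept once
# with label -1, and action 3 is dropped (weight 0). Action 3 has the highest priority.
def _instance_reweight_sketch(ds):
  return ds.instance_reweight(action_priority='3,2,1',
                              reweight='2:3:1,1:1:-1,3:0:1',
                              variant_type='example')
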
@monolith_export
class NegativeGenDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""负例生成. 有时, 样本中只有正例, 没有负例, 需要随机生成负例
推荐系统中的样本通常是由user侧, item侧两部分组成. 这里的做法是:
- 先收集每个样本的item侧信息, 生成一个item池子
- item池子并不是平铺的, 而是按某个特征(channel_slot)分类组织的. 如果在同一个channel随机取item得到的是hard负例, 在其它channel中抽样得到的是easy负例
- 并不是一开始就生成负例, 而是要等item池子积累到一定大小才开始生成负例
Args:
input_dataset (:obj:`dataset`): 输入数据集
neg_num (:obj:`int`): 为一个正例生成`neg_num`个负例
channel_feature (:obj:`string`): 用于当item分类的字段
per_channel (:obj:`bool`): 是否分类
start_num (:obj:`int`): 在item池子中积累多少个后才开始采样
max_iten_num (:obj:`int`): 每一个channel最多收集多注个item
item_features: (:obj:`List[str]`): item侧的特征名列表
positive_label: 正例的label, 仅为正例生成负例
negative_label: 生成的负例的被打上的label
Raises:
TypeError: 如果有任何参数与类型不匹配, 则抛TypeError
ValueError: 如果有任何值与期望不匹配, 则抛ValueError
"""
def __init__(self,
input_dataset,
neg_num: int,
per_channel: bool = False,
channel_feature: Union[int, str] = '',
item_features: Union[List[int], List[str]] = [],
start_num: int = 500,
max_item_num: int = 100000,
positive_label: int = 1,
negative_label: int = -1,
negative_action: int = -99999,
positive_actions: List[int] = [],
label_index: int = 0,
action_priority: str = '',
index_feature: Union[int, str] = '',
throw_origin: bool = False,
throw_origin_neg: bool = False,
cache_only_pos: bool = True,
real_neg_instance_weight: float = 1.0,
sampled_neg_instance_weight: float = -1.0,
unbias_sampled_neg: bool = True,
origin_neg_in_pool_proba: float = 1.0,
neg_sample_declay_factor: float = 1.0,
variant_type: str = 'example'):
pool = create_item_pool(start_num=start_num,
max_item_num_per_channel=max_item_num)
tf.compat.v1.add_to_collection(POOL_KEY, pool)
channel_feature = str(channel_feature)
item_features = [str(item) for item in item_features]
action_priority_items = action_priority.strip().split(',')
assert len(action_priority_items) == len(set(action_priority_items))
index_feature = str(index_feature)
assert variant_type in {'instance', 'example'}
assert label_index >= 0
variant_tensor = pb_datasource_ops.instance_negative_gen_dataset(
input=input_dataset._variant_tensor,
pool=pool,
neg_num=neg_num,
per_channel=per_channel,
channel_feature=channel_feature,
item_features=item_features,
label_index=label_index,
positive_label=positive_label,
negative_label=negative_label,
negative_action=negative_action,
action_priority=action_priority,
positive_actions=positive_actions,
index_feature=index_feature,
throw_origin=throw_origin,
throw_origin_neg=throw_origin_neg,
cache_only_pos=cache_only_pos,
real_neg_instance_weight=real_neg_instance_weight,
sampled_neg_instance_weight=sampled_neg_instance_weight,
unbias_sampled_neg=unbias_sampled_neg,
origin_neg_in_pool_proba=origin_neg_in_pool_proba,
neg_sample_declay_factor=neg_sample_declay_factor,
variant_type=variant_type)
super(NegativeGenDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return tensor_spec.TensorSpec([], dtypes.variant)
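
# Minimal sketch (hypothetical feature names): generate 5 negatives per positive,
# pooling items by a channel feature once 500 items have been collected.
def _negative_gen_sketch(ds):
  return ds.negative_gen(neg_num=5,
                         per_channel=True,
                         channel_feature='f_channel_id',
                         item_features=['f_item_id', 'f_item_tags'],
                         start_num=500,
                         max_item_num=100000,
                         positive_label=1,
                         negative_label=-1,
                         variant_type='example')
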
def instance_reweight(self,
action_priority: str,
reweight: str,
variant_type: str = 'example'):
return InstanceReweightDataset(self,
action_priority,
reweight,
variant_type=variant_type)
@monolith_export
class SplitFlowDataset(dataset_ops.UnaryUnchangedStructureDataset):
def __init__(self,
input_dataset,
data_flow: List[str],
index: int,
max_queue_size: int = 1024,
variant_type: str = 'example'):
variant_tensor = pb_datasource_ops.split_flow_dataset(input_dataset._variant_tensor,
data_flow=data_flow,
index=index,
max_queue_size=max_queue_size,
variant_type=variant_type)
super(SplitFlowDataset, self).__init__(input_dataset, variant_tensor)
@property
def element_spec(self):
return tensor_spec.TensorSpec([], dtypes.variant)
@monolith_export
class MergeFlowDataset(dataset_ops.DatasetV2):
def __init__(self,
input_dataset,
dataset_to_merge,
max_queue_size: int = 1024,
variant_type: str = 'example'):
self._input_dataset = input_dataset
self._dataset_to_merge = dataset_to_merge
output_types = dataset_ops.get_legacy_output_types(input_dataset)
for ds in dataset_to_merge:
ds_types = dataset_ops.get_legacy_output_types(ds)
if output_types != ds_types:
raise TypeError(
"Datasets to merge have different types %s and %s" %
(output_types, ds_types))
input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
flat_sequence = None
input_shapes_flatten = nest.flatten(input_shapes)
for ds in dataset_to_merge:
ds_shapes_flatten = nest.flatten(dataset_ops.get_legacy_output_shapes(ds))
if flat_sequence is None:
flat_sequence = [ts1.most_specific_compatible_shape(ts2)
for (ts1, ts2) in zip(input_shapes_flatten, ds_shapes_flatten)]
else:
tmp = [ts1.most_specific_compatible_shape(ts2)
for (ts1, ts2) in zip(input_shapes_flatten,ds_shapes_flatten)]
assert all(ts1 == ts2 for (ts1, ts2) in zip(flat_sequence, tmp))
output_shapes = nest.pack_sequence_as(input_shapes, flat_sequence)
output_classes = dataset_ops.get_legacy_output_classes(input_dataset)
for ds in dataset_to_merge:
ds_classes = dataset_ops.get_legacy_output_classes(ds)
if output_classes != ds_classes:
raise TypeError(
"Datasets to merge have different classes %s and %s" %
(output_classes, ds_classes))
self._structure = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
self._input_datasets = [input_dataset] + dataset_to_merge
input_dataset_variant = [ds._variant_tensor for ds in self._input_datasets]
data_flow = ['input_ds'] + ['ds_to_merge_{}'.format(i+1) for i in range(len(self._dataset_to_merge))]
variant_tensor = pb_datasource_ops.merge_flow_dataset(input_dataset_variant,
data_flow=data_flow,
max_queue_size=max_queue_size,
variant_type=variant_type)
super(MergeFlowDataset, self).__init__(variant_tensor)
def _inputs(self):
return self._input_datasets
@property
def element_spec(self):
return self._structure
def negative_gen(self,
neg_num: int,
per_channel: bool = False,
channel_feature: Union[int, str] = '',
item_features: Union[List[int], List[str]] = [],
start_num: int = 500,
max_item_num: int = 100000,
positive_label: int = 1,
negative_label: int = -1,
negative_action: int = -99999,
positive_actions: List[int] = [],
label_index: int = 0,
action_priority: str = '',
index_feature: Union[int, str] = '',
throw_origin: bool = False,
throw_origin_neg: bool = False,
cache_only_pos: bool = False,
real_neg_instance_weight: float = 1.0,
sampled_neg_instance_weight: float = -1.0,
unbias_sampled_neg: bool = True,
origin_neg_in_pool_proba: float = 1.0,
neg_sample_declay_factor: float = 1.0,
variant_type: str = 'example'):
return NegativeGenDataset(
self,
neg_num=neg_num,
per_channel=per_channel,
channel_feature=channel_feature,
item_features=item_features,
start_num=start_num,
max_item_num=max_item_num,
label_index=label_index,
positive_label=positive_label,
negative_label=negative_label,
negative_action=negative_action,
action_priority=action_priority,
positive_actions=positive_actions,
index_feature=index_feature,
throw_origin=throw_origin,
throw_origin_neg=throw_origin_neg,
cache_only_pos=cache_only_pos,
real_neg_instance_weight=real_neg_instance_weight,
sampled_neg_instance_weight=sampled_neg_instance_weight,
unbias_sampled_neg=unbias_sampled_neg,
origin_neg_in_pool_proba=origin_neg_in_pool_proba,
neg_sample_declay_factor=neg_sample_declay_factor,
variant_type=variant_type)
def split_flow(self,
data_flow: List[str],
index: int,
max_queue_size: int = 1024,
variant_type: str = 'example'):
return SplitFlowDataset(self,
data_flow=data_flow, index=index,
max_queue_size=max_queue_size, variant_type=variant_type)
def merge_flow(self,
dataset_to_merge,
max_queue_size: int = 1024,
variant_type: str = 'example'):
return MergeFlowDataset(self, dataset_to_merge,
max_queue_size=max_queue_size, variant_type=variant_type)
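
# Minimal sketch (hypothetical flow names), assuming each flow is obtained by splitting
# the same upstream variant dataset: route elements into two flows, transform them
# independently, then merge them back into a single stream.
def _split_merge_flow_sketch(ds):
  flow_a = ds.split_flow(data_flow=['flow_a', 'flow_b'], index=0)
  flow_b = ds.split_flow(data_flow=['flow_a', 'flow_b'], index=1)
  # Flow-specific transformations would go here.
  return flow_a.merge_flow(dataset_to_merge=[flow_b])
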
class KafkaGen(object):
def __init__(self, topics: List[str], group_id: str, servers: Union[str, List[str]],
stream_timeout: int = -1, message_poll_timeout: int = 10000, poll_batch_size: int = 1024):
if stream_timeout == -1:
stream_timeout = sys.maxsize
elif stream_timeout >= 0:
stream_timeout = max(stream_timeout, message_poll_timeout)
else:
      raise ValueError('stream_timeout must be -1 or a non-negative value')
if isinstance(topics, str):
topics = [topics]
self.topics, self.group_id, self.servers = topics, group_id, servers
self._lock = RLock()
    self._stop_iteration = False  # guarded by self._lock
    self._consumer: KafkaConsumer = None  # guarded by self._lock
self._queue = Queue(maxsize=1024)
self.message_poll_timeout = message_poll_timeout
self.poll_batch_size = poll_batch_size
self._max_stream_timeout_polls = int(stream_timeout / message_poll_timeout)
self._stream_timeout_polls = -1
@property
def consumer(self):
with self._lock:
if self._consumer is None:
self._consumer = KafkaConsumer(*self.topics, group_id=self.group_id,
bootstrap_servers=self.servers)
thread = Thread(target=self._poll)
thread.start()
return self._consumer
def __iter__(self):
return self
def __next__(self):
assert self.consumer is not None
    while True:
      try:
        # Queue.get expects seconds, while message_poll_timeout is in milliseconds.
        data = self._queue.get(timeout=self.message_poll_timeout / 1000)
        if data:
          return data
      except Empty:
        pass
      with self._lock:
        if self._stop_iteration:
          raise StopIteration
def __call__(self):
return self
def _poll(self):
while self._stream_timeout_polls < self._max_stream_timeout_polls:
try:
msg = self._consumer.poll(timeout_ms=self.message_poll_timeout,
max_records=self.poll_batch_size, update_offsets=True)
if msg:
poll_values = []
for part, values in msg.items():
part_vals = [value.value for value in values if value.value]
if part_vals:
poll_values.extend(part_vals)
if poll_values:
self._stream_timeout_polls = 0
self._queue.put(poll_values)
else:
self._stream_timeout_polls += 1
continue
else:
self._stream_timeout_polls += 1
except Exception as e:
logging.error(f'poll error: {e}')
break
with self._lock:
self._consumer.close()
self._stop_iteration = True
class PyKafkaDataset(dataset_ops.DatasetSource):
def __init__(self, topics, group_id, servers, *, has_header=True, variant_type: str = None,
stream_timeout=-1, message_poll_timeout=10000, poll_batch_size: int = 1024,
filter_empty: bool = False, **kwargs):
variant_type = variant_type or _get_params('data_type', PbType.INSTANCE).to_name()
self._has_sort_id = kwargs.get('has_sort_id', _get_params('sort_id', False))
self._kafka_dump = kwargs.get('kafka_dump',
_get_params('kafka_dump', False))
logging.info(f'pb_type: {variant_type}, kafka_dump: {self._kafka_dump}')
self._kafka_dump_prefix = kwargs.get(
'kafka_dump_prefix', _get_params('kafka_dump_prefix', False))
self._lagrangex_header = kwargs.get('lagrangex_header',
_get_params('lagrangex_header', False))
if context.default_execution_mode == context.GRAPH_MODE:
ckpt_hooks.disable_iterator_save_restore()
    kafka_gen = KafkaGen(topics, group_id, servers, stream_timeout,
                         message_poll_timeout, poll_batch_size)
dataset = tf.data.Dataset.from_generator(generator=kafka_gen, output_types=tf.string, output_shapes=None)
dataset = dataset.map(lambda v: string_to_variant(v,
variant_type=variant_type.lower(),
has_header=has_header,
lagrangex_header=self._lagrangex_header,
has_sort_id=self._has_sort_id,
kafka_dump=self._kafka_dump,
kafka_dump_prefix=self._kafka_dump_prefix),
num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE).unbatch()
if filter_empty:
dataset = dataset.filter(predicate=lambda x: has_variant(input=x, variant_type=variant_type.lower()))
self._dataset = dataset
super().__init__(self._dataset._variant_tensor)
@property
def element_spec(self):
return self._dataset.element_spec
class KafkaDataset(dataset_ops.DatasetSource):
def __init__(self, topics: List[str], group_id: str, servers: str, *, has_header=True, variant_type: str = None,
stream_timeout=-1, message_poll_timeout=10000, poll_batch_size: int = 1024,
filter_empty: bool = False, configuration=None, container: str = '', shared_name: str = '',
**kwargs):
variant_type = variant_type or _get_params('data_type', PbType.INSTANCE).to_name()
self._has_sort_id = kwargs.get('has_sort_id', _get_params('sort_id', False))
self._kafka_dump = kwargs.get('kafka_dump',
_get_params('kafka_dump', False))
logging.info(f'pb_type: {variant_type}, kafka_dump: {self._kafka_dump}')
self._kafka_dump_prefix = kwargs.get(
'kafka_dump_prefix', _get_params('kafka_dump_prefix', False))
self._lagrangex_header = kwargs.get('lagrangex_header',
_get_params('lagrangex_header', False))
if context.default_execution_mode == context.GRAPH_MODE:
ckpt_hooks.disable_iterator_save_restore()
self._chnids = kwargs.get('chnids', _get_params('chnids', None))
self._datasources = kwargs.get('datasources', _get_params('datasources', None))
self._default_datasource = kwargs.get('default_datasource', _get_params('default_datasource', ''))
with tf.name_scope("MonolithKafkaDataset"):
if stream_timeout == -1:
stream_timeout = sys.maxsize
elif stream_timeout >= 0:
stream_timeout = max(stream_timeout, message_poll_timeout)
else:
        raise ValueError(
            f"Invalid stream_timeout value: {stream_timeout}; set it to -1 to block indefinitely.")
metadata = list(configuration or [])
if group_id is not None:
metadata.append(f"group.id={group_id}")
if servers is not None:
metadata.append(f"bootstrap.servers={servers}")
if poll_batch_size is not None:
assert isinstance(poll_batch_size, int) and poll_batch_size > 0
metadata.append(f"batch.num.messages={poll_batch_size}")
resource = kafka_resource_init(topics=topics, metadata=metadata,
container=container, shared_name=shared_name)
self._resource = resource
dataset = tf.data.experimental.Counter()
dataset = dataset.map(
lambda i: kafka_read_next(
input=self._resource,
index=i,
message_poll_timeout=message_poll_timeout,
stream_timeout=stream_timeout,
)
)
dataset = dataset.apply(
tf.data.experimental.take_while(
lambda v: tf.greater(v.continue_fetch, 0)
)
)
dataset = dataset.map(lambda v: string_to_variant(v.message,
variant_type=variant_type.lower(),
has_header=has_header,
lagrangex_header=self._lagrangex_header,
has_sort_id=self._has_sort_id,
kafka_dump=self._kafka_dump,
kafka_dump_prefix=self._kafka_dump_prefix,
chnids=self._chnids,
datasources=self._datasources,
default_datasource=self._default_datasource),
num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE).unbatch()
if filter_empty:
dataset = dataset.filter(predicate=lambda x: has_variant(input=x, variant_type=variant_type.lower()))
self._dataset = dataset
super().__init__(self._dataset._variant_tensor)
@property
def element_spec(self):
return self._dataset.element_spec
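
# Minimal streaming sketch (hypothetical topic and broker addresses): consume Kafka
# messages and decode them into Example variants; filter_empty drops messages that
# contain no variant payload.
def _kafka_dataset_sketch():
  return KafkaDataset(topics=['demo_topic'],
                      group_id='demo_group',
                      servers='broker-1:9092,broker-2:9092',
                      variant_type='example',
                      poll_batch_size=512,
                      stream_timeout=-1,
                      filter_empty=True)
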
def distribute(self, target, *, job_name: str,
num_worker: int, worker_idx: int,
queue_device: str = "/device:CPU:0",
max_outstanding_requests: int = None):
try:
if FLAGS.kafka_topics is not None and FLAGS.kafka_group_id is not None:
return self
    except Exception:
      # The Kafka flags may be undefined or unparsed; in that case fall through to
      # data-service based distribution.
      pass
element_spec = self.element_spec
with tf.compat.v1.device(queue_device):
queue = tf.compat.v1.FIFOQueue(capacity=num_worker-1, dtypes=[tf.int64], shared_name=f'{job_name}_queue')
if worker_idx == 0:
      # The data service tries to register the dataset; if it has already been registered,
      # the existing dataset_id is returned directly, i.e. get-or-register. For data parallelism
      # the input pipeline must be identical across workers, so a FIFO queue is used here to
      # share the same dataset_id (and thus the same pipeline) within a job.
dataset_id = dsvc.register_dataset(target, self)
enqueue_op = queue.enqueue_many(vals=[dataset_id] * (num_worker-1))
      with tf.compat.v1.control_dependencies(control_inputs=[enqueue_op]):
# to share pipeline, job_name must be specified
return dsvc.from_dataset_id(processing_mode="distributed_epoch",
service=target, dataset_id=dataset_id, job_name=job_name,
element_spec=element_spec, max_outstanding_requests=max_outstanding_requests)
else:
dataset_id = queue.dequeue()
return dsvc.from_dataset_id(processing_mode="distributed_epoch",
service=target, dataset_id=dataset_id, job_name=job_name,
element_spec=element_spec, max_outstanding_requests=max_outstanding_requests)
Dataset.instance_reweight = instance_reweight
Dataset.negative_gen = negative_gen
Dataset.split_flow = split_flow
Dataset.merge_flow = merge_flow
Dataset.distribute = distribute
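
# Minimal sketch of sharing one input pipeline through the tf.data service (hypothetical
# dispatcher address, paths and worker layout): worker 0 registers the dataset, and the
# other workers fetch the registered dataset_id through the shared FIFO queue.
def _distribute_sketch(worker_idx, num_worker=4):
  ds = DistributedFilePBDataset(['/data/warehouse/20220101/*'],
                                input_pb_type=PbType.EXAMPLEBATCH,
                                output_pb_type=PbType.EXAMPLE)
  return ds.distribute('grpc://dsvc-dispatcher:5050',
                       job_name='train_input',
                       num_worker=num_worker,
                       worker_idx=worker_idx)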