# Source code for monolith.native_training.layers.advanced_activations

import types

import tensorflow.keras.initializers as initializers
import tensorflow.keras.constraints as constraints
from tensorflow.python.keras.activations import exponential
from tensorflow.python.keras.activations import gelu
from tensorflow.python.keras.activations import hard_sigmoid
from tensorflow.python.keras.activations import linear
from tensorflow.python.keras.activations import selu
from tensorflow.python.keras.activations import sigmoid
from tensorflow.python.keras.activations import softplus
from tensorflow.python.keras.activations import softsign
from tensorflow.python.keras.activations import swish
from tensorflow.python.keras.activations import tanh
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.layers.advanced_activations import ReLU
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.advanced_activations import ELU
from tensorflow.python.keras.layers.advanced_activations import Softmax
from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU
from tensorflow.python.keras.layers.advanced_activations import PReLU

from monolith.native_training.utils import params as _params
from monolith.native_training.monolith_export import monolith_export

# Names exported by ``from <this module> import *``.
__all__ = [
    # Layer classes imported from Keras.
    'ReLU',
    'LeakyReLU',
    'ELU',
    'Softmax',
    'ThresholdedReLU',
    'PReLU',
    # Layer classes defined in this module around Keras activation functions.
    'Exponential',
    'Gelu',
    'HardSigmoid',
    'Linear',
    'Selu',
    'Sigmoid',
    'Sigmoid2',
    'Softplus',
    'Softsign',
    'Swish',
    'Tanh',
]

# Each class below is a Keras ``Layer`` whose ``call`` applies exactly one
# activation function to its input. Docstrings are deliberately omitted so
# that ``__doc__`` stays ``None``, as it is for classes built with ``type()``.


class Tanh(Layer):

  def call(self, x):
    return tanh(x)


class Sigmoid(Layer):

  def call(self, x):
    return sigmoid(x)


class Sigmoid2(Layer):

  def call(self, x):
    # Sigmoid output multiplied by two.
    return sigmoid(x) * 2


class Linear(Layer):

  def call(self, x):
    return linear(x)


class Gelu(Layer):

  def call(self, x):
    return gelu(x)


class Selu(Layer):

  def call(self, x):
    return selu(x)


class Softsign(Layer):

  def call(self, x):
    return softsign(x)


class Softplus(Layer):

  def call(self, x):
    return softplus(x)


class Exponential(Layer):

  def call(self, x):
    return exponential(x)


class HardSigmoid(Layer):

  def call(self, x):
    return hard_sigmoid(x)


class Swish(Layer):

  def call(self, x):
    return swish(x)

# Give every activation layer class a ``params`` attribute.
# ``types.MethodType(_params, cls)`` binds ``cls`` as the first argument, so
# calling ``cls.params(...)`` invokes ``_params(cls, ...)``.
for _layer_cls in (
    ReLU,
    PReLU,
    LeakyReLU,
    ELU,
    Softmax,
    ThresholdedReLU,
    Tanh,
    Sigmoid,
    Sigmoid2,
    Linear,
    Gelu,
    Selu,
    Softsign,
    Softplus,
    Exponential,
    HardSigmoid,
    Swish,
):
  _layer_cls.params = types.MethodType(_params, _layer_cls)
# Do not leave the loop variable behind in the module namespace.
del _layer_cls

# Registry mapping a lower-case activation name to its layer class.
# ``hard_sigmoid`` and ``hardsigmoid`` are two spellings of the same layer.
__all_activations = dict(
    exponential=Exponential,
    gelu=Gelu,
    hard_sigmoid=HardSigmoid,
    hardsigmoid=HardSigmoid,
    linear=Linear,
    selu=Selu,
    sigmoid=Sigmoid,
    sigmoid2=Sigmoid2,
    softplus=Softplus,
    softsign=Softsign,
    swish=Swish,
    tanh=Tanh,
    leakyrelu=LeakyReLU,
    relu=ReLU,
    elu=ELU,
    softmax=Softmax,
    thresholdedrelu=ThresholdedReLU,
    prelu=PReLU,
)
# Every name accepted by the registry, as a set.
ALL_ACTIVATION_NAMES = set(__all_activations)


@monolith_export
def get(identifier):
  """Resolve an identifier to an activation layer.

  Args:
    identifier: What to resolve. Accepted forms:
      * ``None`` -- returned unchanged.
      * ``str`` -- a registered activation name (matched case-insensitively),
        or a Python expression that evaluates to a config dict (for example
        the string produced by ``serialize``).
      * ``dict`` -- a config dict, handed to ``deserialize``.
      * callable -- a layer class, a layer instance, or a plain function.

  Returns:
    ``None`` for ``None``. For a name or config, a newly constructed layer.
    For a callable, either a newly constructed layer or the callable itself
    when it cannot be mapped to a registered layer (see inline comments).

  Raises:
    TypeError: If ``identifier`` is a string that is neither a registered
      name nor an expression yielding a dict, or is of an unsupported type.
    KeyError: If ``identifier`` is a ``Layer`` instance without a ``params``
      attribute whose lower-cased class name is not registered.
  """
  if identifier is None:
    return None

  if isinstance(identifier, str):
    key = identifier.lower()
    if key in __all_activations:
      return __all_activations[key]()

    # SECURITY NOTE(review): eval() executes arbitrary Python. It is kept for
    # backward compatibility with existing serialized strings; do not pass
    # untrusted input here. Consider ast.literal_eval if all producers only
    # emit literals -- confirm before switching.
    try:
      evaled = eval(identifier)
    except Exception as err:
      # An unknown name such as 'foo' used to leak NameError/SyntaxError from
      # eval(); report it as the same TypeError used for every other
      # uninterpretable identifier.
      raise TypeError(
          'Could not interpret activation function identifier: {}'.format(
              identifier)) from err
    if isinstance(evaled, dict):
      return deserialize(evaled)
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            identifier))

  if isinstance(identifier, dict):
    return deserialize(identifier)

  if callable(identifier):
    if hasattr(identifier, 'params'):
      try:
        # A Layer subclass is instantiated with default arguments; anything
        # else carrying ``params`` is returned unchanged.
        if issubclass(identifier, Layer):
          return identifier()
        return identifier
      except Exception:
        # issubclass() raises TypeError for non-class objects (e.g. a layer
        # instance), and construction itself may fail; best effort is to
        # hand the object back unchanged.
        return identifier

    if isinstance(identifier, Layer):
      # Builds a fresh default instance of the registered class with the
      # same lower-cased class name; the configuration of the passed
      # instance is not copied.
      name = identifier.__class__.__name__.lower()
      return __all_activations[name]()

    try:
      # Plain function: look it up by its ``__name__`` (not lower-cased).
      return __all_activations[identifier.__name__]()
    except Exception:
      # No ``__name__``, unregistered name, or construction failure: fall
      # back to the callable itself.
      return identifier

  raise TypeError(
      'Could not interpret activation function identifier: {}'.format(
          identifier))
@monolith_export
def serialize(activation):
  """Serialize an activation layer to a string.

  The result is the ``repr`` of a config dict holding the layer's class name
  under ``'name'`` plus the constructor arguments read from the layer.

  Args:
    activation: The activation layer instance to serialize.

  Returns:
    A ``str`` containing the ``repr`` of the config dict, or ``None`` when
    ``activation`` is not an instance of one of the layer types handled
    below.
  """
  # Layers without constructor arguments: the class name is the whole config.
  # Sigmoid2 is included so that every registered layer can be serialized.
  if isinstance(activation, (Linear, Exponential, Selu, HardSigmoid, Gelu,
                             Sigmoid, Sigmoid2, Softplus, Softsign, Swish,
                             Tanh)):
    return repr({'name': activation.__class__.__name__})
  elif isinstance(activation, (LeakyReLU, ELU)):
    return repr({
        'name': activation.__class__.__name__,
        'alpha': float(activation.alpha)
    })
  elif isinstance(activation, ReLU):
    return repr({
        'name': 'ReLU',
        'max_value': activation.max_value,
        'negative_slope': float(activation.negative_slope),
        'threshold': float(activation.threshold)
    })
  elif isinstance(activation, PReLU):
    return repr({
        'name':
            'PReLU',
        'alpha_initializer':
            initializers.serialize(activation.alpha_initializer),
        # NOTE(review): the regularizer goes through initializers.serialize;
        # left unchanged here -- confirm whether regularizers.serialize was
        # intended.
        'alpha_regularizer':
            initializers.serialize(activation.alpha_regularizer),
        'alpha_constraint':
            constraints.serialize(activation.alpha_constraint),
        'shared_axes':
            activation.shared_axes
    })
  elif isinstance(activation, Softmax):
    return repr({'name': 'Softmax', 'axis': activation.axis})
  elif isinstance(activation, ThresholdedReLU):
    return repr({'name': 'ThresholdedReLU', 'theta': float(activation.theta)})
  else:
    return None
@monolith_export
def deserialize(identifier):
  """Build an activation layer from a config.

  Args:
    identifier: ``None``, a config dict, or a Python expression (such as the
      string produced by ``serialize``) that evaluates to a config dict. The
      dict must contain a ``'name'`` key naming a registered activation
      (matched case-insensitively); all remaining items are passed to the
      layer constructor as keyword arguments.

  Returns:
    ``None`` if ``identifier`` is ``None``, otherwise a newly constructed
    layer.

  Raises:
    AssertionError: If the resolved config is not a dict or its name is not
      registered (only while assertions are enabled).
    KeyError: If the config has no ``'name'`` key.
  """
  if identifier is None:
    return None

  if not isinstance(identifier, dict):
    # SECURITY NOTE(review): eval() executes arbitrary Python. It is kept for
    # backward compatibility; do not pass untrusted input here. Consider
    # ast.literal_eval if all producers only emit literals -- confirm first.
    identifier = eval(identifier)

  # Assertions are kept so callers keep seeing AssertionError on bad input.
  assert isinstance(identifier, dict), 'activation config must be a dict'

  # Work on a shallow copy: popping 'name' from the caller's dict would make
  # a second deserialize() call with the same config fail with KeyError.
  config = dict(identifier)
  name = config.pop('name').lower()
  assert name in __all_activations, 'unknown activation: {}'.format(name)
  return __all_activations[name](**config)