"""AdamP for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.optimizers.AdamP')
class AdamP(optimizer_v2.OptimizerV2):
    """Construct a new AdamP optimizer.

    Follows the work of Byeongho Heo et al. [https://arxiv.org/abs/2006.08217]

    Each step computes the usual Adam moments, optionally applies a
    Nesterov-style look-ahead, and then, for variables of rank > 1, removes
    from the update its component along the (normalized) weights whenever
    the gradient is close to orthogonal to the weights. When that projection
    is applied, the weight decay is scaled by `wd_ratio`.

    Parameters
    ----------
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 0.001.
    beta_1: A float value or a constant float tensor, or a callable
        that takes no arguments and returns the actual value to use. The
        exponential decay rate for the 1st moment estimates. Defaults to 0.9.
    beta_2: A float value or a constant float tensor, or a callable
        that takes no arguments and returns the actual value to use. The
        exponential decay rate for the 2nd moment estimates. Defaults to
        0.999.
    epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults
        to 1e-8. A falsy value (e.g. `None` or 0) falls back to
        `backend_config.epsilon()`.
    weight_decay: A Tensor or a floating point value. The weight decay.
        Defaults to 0.
    delta: Threshold that determines whether a set of parameters is scale
        invariant or not. Defaults to 0.1.
    wd_ratio: Relative weight decay applied on scale-invariant parameters
        compared to that applied on scale-variant parameters. Defaults to
        0.1.
    nesterov: Boolean. When True, the update direction is
        `beta_1 * m_t + (1 - beta_1) * g_t` instead of `m_t`. Defaults to
        False.
    name: Optional name for the operations created when applying gradients.
        Defaults to `"AdamP"`.
    **kwargs: Keyword arguments. Allowed to be one of
        `"clipnorm"` or `"clipvalue"`.
        `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float)
        clips gradients by value.

    References
    ----------
    https://github.com/taki0112/AdamP-Tensorflow
    """

    _HAS_AGGREGATE_GRAD = True

    def __init__(self,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-8,
                 weight_decay=0.0,
                 delta=0.1,
                 wd_ratio=0.1,
                 nesterov=False,
                 name='AdamP',
                 **kwargs):
        super(AdamP, self).__init__(name, **kwargs)
        # An `lr` keyword argument, when given, takes precedence over
        # `learning_rate`.
        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
        self._set_hyper('beta_1', beta_1)
        self._set_hyper('beta_2', beta_2)
        self._set_hyper('delta', delta)
        self._set_hyper('wd_ratio', wd_ratio)
        self.epsilon = epsilon or backend_config.epsilon()
        self.weight_decay = weight_decay
        self.nesterov = nesterov

    def _create_slots(self, var_list):
        """Create the per-variable slots `m`, `v` and `p`."""
        # Create slots for the first and second moments.
        # Separate for-loops to respect the ordering of slot variables from v1.
        for var in var_list:
            self.add_slot(var, 'm')
        for var in var_list:
            self.add_slot(var, 'v')
        # NOTE(review): the 'p' slot is not read by any update rule in this
        # class. It is kept so the optimizer's weight list is unchanged.
        for var in var_list:
            self.add_slot(var, 'p')

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Precompute the per-(device, dtype) coefficients used by the updates.

        Extends the entry `apply_state[(var_device, var_dtype)]` with the
        learning rate, epsilon, weight decay, the betas, their powers at the
        current step, the bias corrections, `delta` and `wd_ratio`.
        """
        super(AdamP, self)._prepare_local(var_device, var_dtype, apply_state)

        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)

        lr = apply_state[(var_device, var_dtype)]['lr_t']
        bias_correction1 = 1 - beta_1_power
        bias_correction2 = 1 - beta_2_power

        delta = array_ops.identity(self._get_hyper('delta', var_dtype))
        wd_ratio = array_ops.identity(self._get_hyper('wd_ratio', var_dtype))

        apply_state[(var_device, var_dtype)].update(
            dict(
                lr=lr,
                epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
                weight_decay=ops.convert_to_tensor_v2(self.weight_decay, var_dtype),
                beta_1_t=beta_1_t,
                beta_1_power=beta_1_power,
                one_minus_beta_1_t=1 - beta_1_t,
                beta_2_t=beta_2_t,
                beta_2_power=beta_2_power,
                one_minus_beta_2_t=1 - beta_2_t,
                bias_correction1=bias_correction1,
                bias_correction2=bias_correction2,
                delta=delta,
                wd_ratio=wd_ratio))

    def set_weights(self, weights):
        """Set the optimizer weights, truncating an over-long weight list.

        NOTE(review): the arithmetic below assumes two slots per variable
        (it divides by 2), but `_create_slots` creates three ('m', 'v', 'p').
        The logic is left untouched; confirm the intended compatibility
        behaviour before relying on the truncation.
        """
        params = self.weights
        # If the weights are generated by Keras V1 optimizer, it includes vhats
        # optimizer has 2x + 1 variables. Filter vhats out for compatibility.
        num_vars = int((len(params) - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[:len(params)]
        super(AdamP, self).set_weights(weights)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one AdamP step to `var` using a dense gradient.

        Parameters
        ----------
        grad: Gradient tensor for `var`.
        var: The variable to update.
        apply_state: Optional dict of precomputed coefficients keyed by
            `(device, dtype)`; when the key is missing the coefficients are
            obtained from `_fallback_apply_state`.

        Returns
        -------
        A grouped op covering the variable update and both moment updates.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, 'm')
        m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
        m_t = state_ops.assign(m, m * coefficients['beta_1_t'] + m_scaled_g_values,
                               use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, 'v')
        v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
        v_t = state_ops.assign(v, v * coefficients['beta_2_t'] + v_scaled_g_values,
                               use_locking=self._use_locking)

        # Bias-corrected denominator and step size.
        denorm = (math_ops.sqrt(v_t) / math_ops.sqrt(coefficients['bias_correction2'])) + coefficients['epsilon']
        step_size = coefficients['lr'] / coefficients['bias_correction1']

        if self.nesterov:
            perturb = (coefficients['beta_1_t'] * m_t + coefficients['one_minus_beta_1_t'] * grad) / denorm
        else:
            perturb = m_t / denorm

        # Projection: only attempted for variables of rank > 1 whose gradient
        # has a statically known leading dimension (needed by _channel_view).
        wd_ratio = 1
        if len(var.shape) > 1 and grad.shape[0] is not None:
            perturb, wd_ratio = self._projection(var, grad, perturb, coefficients['delta'],
                                                 coefficients['wd_ratio'], coefficients['epsilon'])

        # Weight decay: shrink the variable before the gradient step.
        if self.weight_decay > 0:
            var = state_ops.assign(var, var * (1 - coefficients['lr'] * coefficients['weight_decay'] * wd_ratio),
                                   use_locking=self._use_locking)

        var_update = state_ops.assign_sub(var, step_size * perturb, use_locking=self._use_locking)

        return control_flow_ops.group(*[var_update, m_t, v_t])

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply one AdamP step to `var` using a sparse gradient.

        The sparse gradient (`grad` rows addressed by `indices`) is scattered
        into a tensor with one row per row of `var`, and the dense update is
        reused. Decaying every row of a moment and then adding the scaled
        gradient at `indices` is the same arithmetic as the dense moment
        update with a zero gradient on the untouched rows, so the moments and
        the plain Adam direction are unchanged by this.

        Working on the densified gradient also lets the Nesterov term and the
        projection see a gradient with the same leading dimension as `var`;
        the row-sliced values cannot be combined with the full-size moments
        or compared row-by-row against `var`.

        Parameters
        ----------
        grad: Gradient values, one row per entry of `indices`.
        var: The variable to update.
        indices: Row indices of `var` that `grad` refers to.
        apply_state: Optional dict of precomputed coefficients keyed by
            `(device, dtype)`.

        Returns
        -------
        A grouped op covering the variable update and both moment updates.
        """
        num_rows = var.shape[0]
        if num_rows is None:
            # Static row count unavailable; fall back to the runtime shape.
            num_rows = array_ops.shape(var)[0]
        # Rows that do not appear in `indices` receive a zero gradient.
        dense_grad = math_ops.unsorted_segment_sum(grad, indices, num_rows)
        return self._resource_apply_dense(dense_grad, var, apply_state=apply_state)

    def _channel_view(self, x):
        """Reshape `x` to 2-D, keeping its first dimension."""
        return array_ops.reshape(x, shape=[x.shape[0], -1])

    def _layer_view(self, x):
        """Flatten `x` into a single row."""
        return array_ops.reshape(x, shape=[1, -1])

    def _cosine_similarity(self, x, y, eps, view_func):
        """Return the absolute cosine similarity of `x` and `y` per row.

        Both inputs are first reshaped with `view_func`; `eps` is added to
        each norm to avoid division by zero.
        """
        x = view_func(x)
        y = view_func(y)

        x_norm = math_ops.euclidean_norm(x, axis=-1) + eps
        y_norm = math_ops.euclidean_norm(y, axis=-1) + eps
        dot = math_ops.reduce_sum(x * y, axis=-1)

        return math_ops.abs(dot) / x_norm / y_norm

    def _projection(self, var, grad, perturb, delta, wd_ratio, eps):
        """Project `perturb` when `grad` is nearly orthogonal to `var`.

        The channel-wise view is tested first; if its largest cosine
        similarity is below `delta / sqrt(row_size)` the channel-wise
        projection is used, otherwise the layer-wise test is tried.

        Returns
        -------
        A `(perturb, wd)` pair: the (possibly projected) update and the
        weight-decay multiplier to use for this variable.
        """
        # channel_view
        cosine_sim = self._cosine_similarity(grad, var, eps, self._channel_view)
        cosine_max = math_ops.reduce_max(cosine_sim)
        compare_val = delta / math_ops.sqrt(math_ops.cast(self._channel_view(var).shape[-1], dtype=delta.dtype))

        perturb, wd = control_flow_ops.cond(
            pred=cosine_max < compare_val,
            true_fn=lambda: self.channel_true_fn(var, perturb, wd_ratio, eps),
            false_fn=lambda: self.channel_false_fn(var, grad, perturb, delta, wd_ratio, eps))

        return perturb, wd

    def channel_true_fn(self, var, perturb, wd_ratio, eps):
        """Remove from `perturb` its component along the row-normalized `var`.

        Normalization and the dot product are taken per row of the channel
        view. Returns the projected update together with `wd_ratio`.
        """
        # Shape that broadcasts a per-row scalar back over `var`.
        expand_size = [-1] + [1] * (len(var.shape) - 1)
        var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._channel_view(var), axis=-1),
                                         shape=expand_size) + eps)
        perturb -= var_n * array_ops.reshape(math_ops.reduce_sum(self._channel_view(var_n * perturb), axis=-1),
                                             shape=expand_size)
        wd = wd_ratio

        return perturb, wd

    def channel_false_fn(self, var, grad, perturb, delta, wd_ratio, eps):
        """Fallback when the channel-wise test fails: try the layer-wise test.

        Returns the layer-wise projection and `wd_ratio` if the layer-wise
        cosine similarity is below `delta / sqrt(size)`, otherwise the
        unchanged update and a multiplier of one.
        """
        cosine_sim = self._cosine_similarity(grad, var, eps, self._layer_view)
        cosine_max = math_ops.reduce_max(cosine_sim)
        compare_val = delta / math_ops.sqrt(math_ops.cast(self._layer_view(var).shape[-1], dtype=delta.dtype))

        perturb, wd = control_flow_ops.cond(
            cosine_max < compare_val,
            true_fn=lambda: self.layer_true_fn(var, perturb, wd_ratio, eps),
            false_fn=lambda: self.identity_fn(perturb))

        return perturb, wd

    def layer_true_fn(self, var, perturb, wd_ratio, eps):
        """Remove from `perturb` its component along the normalized `var`.

        Normalization and the dot product are taken over the whole variable
        (layer view). Returns the projected update together with `wd_ratio`.
        """
        expand_size = [-1] + [1] * (len(var.shape) - 1)
        var_n = var / (array_ops.reshape(math_ops.euclidean_norm(self._layer_view(var), axis=-1),
                                         shape=expand_size) + eps)
        perturb -= var_n * array_ops.reshape(math_ops.reduce_sum(self._layer_view(var_n * perturb), axis=-1),
                                             shape=expand_size)
        wd = wd_ratio

        return perturb, wd

    def identity_fn(self, perturb):
        """Return `perturb` unchanged with a weight-decay multiplier of one.

        The multiplier is a scalar of `perturb`'s dtype so that this branch
        returns the same dtypes as `layer_true_fn`, the other branch of the
        `cond` it is used in.
        """
        wd = math_ops.cast(1.0, perturb.dtype)

        return perturb, wd

    def get_config(self):
        """Return the optimizer configuration as a serializable dict."""
        config = super(AdamP, self).get_config()
        config.update({
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'beta_1': self._serialize_hyperparameter('beta_1'),
            'beta_2': self._serialize_hyperparameter('beta_2'),
            'delta': self._serialize_hyperparameter('delta'),
            'wd_ratio': self._serialize_hyperparameter('wd_ratio'),
            'epsilon': self.epsilon,
            'weight_decay': self.weight_decay,
            'nesterov': self.nesterov
        })
        return config
# Name-to-class lookup for the optimizers defined in this module
# (presumably meant to be passed as `custom_objects` when loading a saved
# model -- confirm against callers).
optimizer_custom_objects = dict(AdamP=AdamP)