#! /usr/bin/python
# -*- coding: utf-8 -*-
"""AMSGrad implementation based on the paper "On the Convergence of Adam and Beyond".

Reddi, Kale and Kumar, ICLR 2018: https://openreview.net/pdf?id=ryQu7f-RZ
"""

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import (control_flow_ops, math_ops, resource_variable_ops, state_ops, variable_scope)
from tensorflow.python.training import optimizer

class AMSGrad(optimizer.Optimizer):
    """Implementation of the AMSGrad optimization algorithm.

    See: `On the Convergence of Adam and Beyond - [Reddi et al., 2018] <https://openreview.net/pdf?id=ryQu7f-RZ>`__.

    AMSGrad keeps Adam's moving averages ``m`` (first moment) and ``v`` (second
    moment), but it does not divide by ``sqrt(v)``. It divides by ``sqrt(vhat)``,
    where ``vhat = max(vhat, v)`` is a running element-wise maximum, so the
    denominator never decreases from one step to the next.

    Parameters
    ----------
    learning_rate: float
        A Tensor or a floating point value. The learning rate.
    beta1: float
        A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
    beta2: float
        A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
    epsilon: float
        A small constant for numerical stability.
        This epsilon is "epsilon hat" in the Kingma and Ba paper
        (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper.
    use_locking: bool
        If True use locks for update operations.
    name: str
        Optional name for the operations created when applying gradients.
        Defaults to "AMSGrad".
    """

    def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"):
        """Construct a new AMSGrad optimizer; see the class docstring for the arguments."""
        # The base class records ``use_locking`` and ``name``. Every update op
        # below reads them back as ``self._use_locking`` / ``self._name``.
        super(AMSGrad, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

        # Tensor versions of the hyper-parameters, built in _prepare().
        self._lr_t = None
        self._beta1_t = None
        self._beta2_t = None
        self._epsilon_t = None

        # Running products beta1**t and beta2**t used for bias correction.
        # They are created in _create_slots and advanced once per step in _finish.
        self._beta1_power = None
        self._beta2_power = None

    @staticmethod
    def _in_graph_mode():
        """Return True when ops are being added to a graph rather than executed eagerly."""
        # Newer TF 1.x releases expose ``executing_eagerly`` in place of ``in_graph_mode``.
        # Support both.
        if hasattr(context, "executing_eagerly"):
            return not context.executing_eagerly()
        return context.in_graph_mode()

    def _get_beta_accumulators(self):
        """Return the (beta1_power, beta2_power) accumulator variables."""
        return self._beta1_power, self._beta2_power

    def _non_slot_variables(self):
        """Expose the beta-power accumulators so the base class can track them."""
        return self._get_beta_accumulators()

    def _create_slots(self, var_list):
        """Create the beta-power accumulators and the m / v / vhat slots for every variable."""
        # Pick the colocation target deterministically.
        first_var = min(var_list, key=lambda x: x.name)

        # Build the accumulators the first time through. Also rebuild them when the
        # optimizer is reused in a different graph.
        create_new = self._beta1_power is None
        if not create_new and self._in_graph_mode():
            create_new = self._beta1_power.graph is not first_var.graph

        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False)

        # Create slots for the first moment, the second moment and the running max of the second moment.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)
            self._zeros_slot(v, "vhat", self._name)

    def _prepare(self):
        """Convert the Python/tensor hyper-parameters to tensors once per apply_gradients call."""
        self._lr_t = ops.convert_to_tensor(self._lr)
        self._beta1_t = ops.convert_to_tensor(self._beta1)
        self._beta2_t = ops.convert_to_tensor(self._beta2)
        self._epsilon_t = ops.convert_to_tensor(self._epsilon)

    def _hyper_params(self, dtype):
        """Cast the hyper-parameters to ``dtype``.

        Returns ``(lr, beta1_t, beta2_t, epsilon_t)``, where ``lr`` already includes
        the Adam-style bias correction ``sqrt(1 - beta2**t) / (1 - beta1**t)``.
        """
        beta1_power = math_ops.cast(self._beta1_power, dtype)
        beta2_power = math_ops.cast(self._beta2_power, dtype)
        lr_t = math_ops.cast(self._lr_t, dtype)
        beta1_t = math_ops.cast(self._beta1_t, dtype)
        beta2_t = math_ops.cast(self._beta2_t, dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, dtype)

        lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)
        return lr, beta1_t, beta2_t, epsilon_t

    def _apply_amsgrad_step(self, var, m_t, v_t, lr, epsilon_t):
        """Set vhat = max(vhat, v_t), then var -= lr * m_t / (sqrt(vhat) + epsilon)."""
        vhat = self.get_slot(var, "vhat")
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat), use_locking=self._use_locking)
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(var_update, m_t, v_t, vhat_t)

    def _apply_dense_shared(self, grad, var):
        """Build the full AMSGrad update for a dense gradient ``grad`` of ``var``."""
        lr, beta1_t, beta2_t, epsilon_t = self._hyper_params(var.dtype.base_dtype)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)

        return self._apply_amsgrad_step(var, m_t, v_t, lr, epsilon_t)

    def _apply_dense(self, grad, var):
        """Apply a dense gradient to a ref variable."""
        return self._apply_dense_shared(grad, var)

    def _resource_apply_dense(self, grad, var):
        """Apply a dense gradient to a resource variable."""
        # Pass the variable object, not ``var.handle``. Slots are looked up by
        # variable, and state_ops.assign / assign_sub delegate to the variable's
        # own assign methods when it is not a ref variable.
        return self._apply_dense_shared(grad, var)

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        """Build the AMSGrad update for a sparse gradient.

        Every entry of ``m`` and ``v`` is decayed first. Then the scaled gradient
        values are added at ``indices`` with ``scatter_add(ref, indices, updates)``.
        """
        lr, beta1_t, beta2_t, epsilon_t = self._hyper_params(var.dtype.base_dtype)

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        return self._apply_amsgrad_step(var, m_t, v_t, lr, epsilon_t)

    def _apply_sparse(self, grad, var):
        """Apply an ``IndexedSlices`` gradient to a ref variable."""
        return self._apply_sparse_shared(
            grad.values,
            var,
            grad.indices,
            lambda x, i, v: state_ops.scatter_add(x, i, v, use_locking=self._use_locking),
        )

    def _resource_scatter_add(self, x, i, v):
        """Scatter-add ``v`` into resource variable ``x`` at ``i`` and return the updated value."""
        with ops.control_dependencies([resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
            return x.value()

    def _resource_apply_sparse(self, grad, var, indices):
        """Apply a sparse gradient (values ``grad`` at ``indices``) to a resource variable."""
        return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)

    def _finish(self, update_ops, name_scope):
        """Advance beta1**t and beta2**t after all per-variable updates of this step have run."""
        with ops.control_dependencies(update_ops):
            with ops.colocate_with(self._beta1_power):
                update_beta1 = self._beta1_power.assign(
                    self._beta1_power * self._beta1_t, use_locking=self._use_locking
                )
                update_beta2 = self._beta2_power.assign(
                    self._beta2_power * self._beta2_t, use_locking=self._use_locking
                )
        return control_flow_ops.group(*(list(update_ops) + [update_beta1, update_beta2]), name=name_scope)