@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""Performs a backpropagation."""
(data, ) = ctx.saved_tensors
data = data.double()
grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
return grad_output * grad
def forward(ctx, data: torch.Tensor) -> torch.Tensor:
"""Performs a forward pass."""
ctx.save_for_backward(data)
return torch.where(data <= 0.0, data * torch.exp(data), data)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
"""Performs a backpropagation."""
(data, ) = ctx.saved_tensors
data = data.double()
grad = torch.where(data >= 0.0, 1.0, torch.exp(data) * (data + 1))
return grad_output * grad
The exponential is computed twice (once in the forward pass and once in the backward pass), and over the whole input range (both positive and negative values).
I think it could be computed only once (e.g. by saving the forward result for reuse in backward), and only for the negative values, for efficiency.