A Mercury module for automatic differentiation, supporting both forward and backward differentiation.
The module adapts the approach presented in https://github.com/qobi/AD-Rosetta-Stone/. Interestingly, most of the functional approaches to backward differentiation described there use references to update a tape in place, whereas this library implements a purer functional approach for the fan-out and reverse phases. [For the technical details: I add an integer tag to each tape and then collect the sensitivity values to extract the gradients.]
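To make this concrete, the sketch below composes the exported helpers `determine_fanout`, `reverse_phase` and `extract_gradients` (documented at the end of the interface below) into a hand-rolled reverse pass. This is an illustration of the idea, not necessarily the code inside `gradient_R`: in particular, `make_tape/5` with empty factor and tape lists standing for an input variable node, and epsilon 0 as the base derivative level, are my assumptions. It assumes a module context that imports `ad`, `list` and `int`.

```mercury
% Hypothetical sketch of the pure-functional reverse pass.
% gradient_R performs this plumbing internally; details may differ.
:- func sketch_gradient((func(v_ad_number) = ad_number), v_ad_number) = v_ad_number.

sketch_gradient(F, Theta) = Gradient :-
    % Tag each input with its position (0, 1, ...) so that sensitivities
    % can be matched back to the inputs when the gradients are extracted.
    % Epsilon 0 is assumed to be the base derivative level.
    list.map_foldl(
        (pred(X::in, Tape::out, I::in, J::out) is det :-
            Tape = make_tape(I, 0, X, [], []),
            J = I + 1
        ), Theta, Tagged, 0, _),
    Y0 = F(Tagged),                        % build the tape
    Y1 = determine_fanout(Y0),             % fan-out phase
    Y2 = reverse_phase(base(1.0), Y1),     % reverse phase, seeded with sensitivity 1.0
    Gradient = extract_gradients(Y2).      % collect sensitivities, ordered by tag
```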
The module also includes some simple optimisers and implements one of the extended examples (mlp_R) from the same repository.
The general type is `ad_number`, which includes `base(float)` for numerical values. Full documentation is given below.
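As a minimal first example using only the documented API (`hello_ad` is a hypothetical module name):

```mercury
:- module hello_ad.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.
:- implementation.
:- import_module ad.
:- import_module float.

main(!IO) :-
    % d/dx sin(x*x) at x = 2.0 by forward differentiation;
    % the exact derivative is 2*x*cos(x*x) = 4.0*cos(4.0) ~ -2.6146.
    derivative_F(func(X) = sin(X*X), base(2.0), Deriv),
    print_line(Deriv, !IO).
```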
Some test code is here:
```mercury
:- module test_ad.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.
:- implementation.
:- use_module math.
:- import_module ad.
:- import_module ad.v.
:- import_module list.
:- import_module float.

main(!IO) :-
    % forward-mode derivative of exp(2x) at x = 1
    derivative_F(func(X) = exp(base(2.0)*X), base(1.0), GradF),
    print_line("Expected: ", !IO),
    print_line(base(math.exp(2.0)*2.0), !IO),
    print_line(GradF, !IO),
    % reverse-mode gradient of exp(2a) + b
    gradient_R(func(List) = Y :- (if List = [A, B] then Y = exp(base(2.0)*A) + B else Y = base(0.0)),
        from_list([1.0, 3.0]), Grad0),
    print_line("Expected: ", !IO),
    print_line([base(math.exp(2.0)*2.0), base(1.0)], !IO),
    print_line(Grad0, !IO),
    % reverse-mode gradient of b + a^3
    gradient_R(func(List) = Y :- (if List = [A, B] then Y = B + A*A*A else Y = base(0.0)),
        [base(1.1), base(2.3)], Grad),
    print_line("Expected: ", !IO),
    print_line([base(3.0*1.1*1.1), base(1.0)], !IO),
    print_line(Grad, !IO),
    % reverse-mode gradient of exp(b + a^3)
    gradient_R(func(List) = Y :- (if List = [A, B] then Y = exp(B + A*A*A) else Y = base(0.0)),
        [base(1.1), base(2.3)], Grad2),
    print_line("Expected: ", !IO),
    print_line([base(math.exp(2.3+1.1*1.1*1.1)*(3.0*1.1*1.1)),
        base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad2, !IO),
    % the same gradient by forward mode
    gradient_F(func(List) = Y :- (if List = [A, B] then Y = exp(B + A*A*A) else Y = base(0.0)),
        [base(1.1), base(2.3)], Grad3),
    print_line("Expected: ", !IO),
    print_line([base(math.exp(2.3+1.1*1.1*1.1)*(3.0*1.1*1.1)),
        base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad3, !IO),
    % argmin of a quadratic bowl, forward and reverse modes
    multivariate_argmin_F(func(AB) = Y :-
        (if AB = [A, B] then Y = A*A + (B-base(1.0))*(B-base(1.0)) else Y = base(0.0)),
        [base(1.0), base(2.0)], Y4),
    print_line("Expected: ", !IO),
    print_line([base(0.0), base(1.0)], !IO),
    print_line(Y4, !IO),
    multivariate_argmin_R(func(AB) = Y :-
        (if AB = [A, B] then Y = A*A + (B-base(1.0))*(B-base(1.0)) else Y = base(0.0)),
        [base(1.0), base(2.0)], Y5),
    print_line("Expected: ", !IO),
    print_line([base(0.0), base(1.0)], !IO),
    print_line(Y5, !IO),
    % argmin of the Rosenbrock function, reverse and forward modes
    multivariate_argmin_R(rosenbrock, [base(-3.0), base(4.0)], Y6),
    print_line("Expected: ", !IO),
    print_line([base(1.0), base(1.0)], !IO),
    print_line(Y6, !IO),
    multivariate_argmin_F(rosenbrock, [base(-3.0), base(4.0)], Y7),
    print_line("Expected: ", !IO),
    print_line([base(1.0), base(1.0)], !IO),
    print_line(Y7, !IO),
    % gradient of atan2, checked against finite differences
    gradient_F(func(List) = Y :- (if List = [A, B] then Y = atan2(A, B) else Y = base(0.0)),
        [base(0.5), base(0.6)], Y8),
    print_line("Expected: ", !IO),
    Y8_Y = base(fdiff(func(Yi) = math.atan2(Yi, 0.6), 0.5)),
    Y8_X = base(fdiff(func(Xi) = math.atan2(0.5, Xi), 0.6)),
    print_line([Y8_Y, Y8_X], !IO),
    print_line(Y8, !IO).

:- func rosenbrock(v_ad_number) = ad_number.
rosenbrock(In) = Result :-
    ( if In = [X, Y] then
        A = base(1.0),
        B = base(100.0),
        Result = (A-X)*(A-X) + B*(Y-X*X)*(Y-X*X)
    else
        Result = base(0.0)
    ).
```
Running the code and its output:

```sh
mmc --make test_ad libad && ./test_ad
```
```
Expected: 
base(14.7781121978613)
base(14.7781121978613)
Expected: 
[base(14.7781121978613), base(1.0)]
[base(14.7781121978613), base(1.0)]
Expected: 
[base(3.630000000000001), base(1.0)]
[base(3.630000000000001), base(1.0)]
Expected: 
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
Expected: 
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
Expected: 
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
Expected: 
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
Expected: 
[base(1.0), base(1.0)]
[base(0.9999914554400818), base(0.9999828755368568)]
Expected: 
[base(1.0), base(1.0)]
[base(0.9999914554355536), base(0.9999828755391909)]
Expected: 
[base(0.9836065574087004), base(-0.8196721312081489)]
[base(0.9836065573770492), base(-0.819672131147541)]
```
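The 5-argument versions of `derivative_F`, `gradient_F` and `gradient_R` thread an epsilon counter (`int::in, int::out`), which keeps the perturbations of nested derivatives distinct. A hedged sketch of nested use, assuming a module that imports `ad` (`second_derivative` is a hypothetical predicate; the threading discipline is my reading of the interface, since the tests above only exercise the 3-argument versions):

```mercury
% Hypothetical sketch: the second derivative of x^3 at x = 2.0 by nesting
% forward mode.  The inner call is started at a higher epsilon than the
% outer call so that the two perturbations stay distinct.
:- pred second_derivative(ad_number::out) is det.

second_derivative(D2) :-
    derivative_F(
        (func(X) = D1 :-
            derivative_F(func(Y) = Y*Y*Y, X, D1, 1, _)
        ),
        base(2.0), D2, 0, _).   % expect base(12.0), since (x^3)'' = 6x
```

The full interface of `ad.m` follows.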
```mercury
%--------------------------------------------------%
% Copyright (C) 2023 Mark Clements.
% This file is distributed under the terms specified in LICENSE.
%--------------------------------------------------%
%
% File: ad.m.
% Authors: mclements
% Stability: low.
%
% This module defines backward and forward automatic differentiation.
%
%--------------------------------------------------%

:- module ad.
:- interface.
:- import_module list.
:- import_module float.

%% main representation type
:- type ad_number
    --->    dual_number(
                int,                % epsilon (used for order of derivative)
                ad_number,          % value
                ad_number           % derivative
            )
    ;       tape(
                int,                % variable order (new)
                int,                % epsilon (used for order of derivative)
                ad_number,          % value
                list(ad_number),    % factors
                list(ad_number),    % tape
                int,                % fanout
                ad_number           % sensitivity
            )
    ;       base(float).

%% vector of ad_numbers
:- type v_ad_number == list(ad_number).
%% matrix of ad_numbers
:- type m_ad_number == list(list(ad_number)).
%% vector of floats
:- type v_float == list(float).
%% matrix of floats
:- type m_float == list(list(float)).

%% make_dual_number(Tag, Value, Derivative) constructs a dual_number
:- func make_dual_number(int, ad_number, ad_number) = ad_number.

%% make_tape(Tag, Epsilon, Value, Factors, Tapes) constructs a tape
:- func make_tape(int, int, ad_number, v_ad_number, v_ad_number) = ad_number.

%% defined functions and predicates for differentiation
:- func (ad_number::in) + (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) - (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) * (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) / (ad_number::in) = (ad_number::out) is det.
:- func pow(ad_number, ad_number) = ad_number.
:- pred (ad_number::in) < (ad_number::in) is semidet.
:- pred (ad_number::in) =< (ad_number::in) is semidet.
:- pred (ad_number::in) > (ad_number::in) is semidet.
:- pred (ad_number::in) >= (ad_number::in) is semidet.
:- pred (ad_number::in) == (ad_number::in) is semidet. % equality
:- func exp(ad_number) = ad_number is det.
:- func ln(ad_number) = ad_number is det.
:- func log2(ad_number) = ad_number is det.
:- func log10(ad_number) = ad_number is det.
:- func log(ad_number, ad_number) = ad_number is det.
:- func sqrt(ad_number) = ad_number is det.
:- func sin(ad_number) = ad_number is det.
:- func cos(ad_number) = ad_number is det.
:- func tan(ad_number) = ad_number is det.
:- func asin(ad_number) = ad_number is det.
:- func acos(ad_number) = ad_number is det.
:- func atan(ad_number) = ad_number is det.
:- func atan2(ad_number, ad_number) = ad_number is det.
:- func sinh(ad_number) = ad_number is det.
:- func cosh(ad_number) = ad_number is det.
:- func tanh(ad_number) = ad_number is det.
:- func abs(ad_number) = ad_number is det.
%% TODO: add further functions and operators

%% derivative_F(F, Theta, Derivative, !Epsilon) takes a function F and an
%% initial value Theta, and returns the Derivative, threading Epsilon
%% (which accounts for the order of the derivatives).
%% Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out,
    int::in, int::out) is det.

%% derivative_F(F, Theta, Derivative) takes a function F and an initial
%% value Theta, and returns the Derivative, assuming the default
%% derivative count. Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out) is det.
%% gradient_F(F, Theta, Gradient) takes a function F and initial values
%% Theta, and returns the Gradient, assuming the default derivative count.
%% Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in, v_ad_number::in, v_ad_number::out) is det.

%% gradient_F(F, Theta, Gradient, !Epsilon) takes a function F and initial
%% values Theta, and returns the Gradient, threading Epsilon (which
%% accounts for the order of the derivatives). Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in, v_ad_number::in, v_ad_number::out,
    int::in, int::out) is det.

%% gradient_R(F, Theta, Gradient, !Epsilon) takes a function F and initial
%% values Theta, and returns the Gradient, threading Epsilon (which
%% accounts for the order of the derivatives). Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in, v_ad_number::in, v_ad_number::out,
    int::in, int::out) is det.

%% gradient_R(F, Theta, Gradient) takes a function F and initial values
%% Theta, and returns the Gradient, assuming the default derivative count.
%% Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in, v_ad_number::in, v_ad_number::out) is det.

%% gradient_ascent_F(F, Theta, Iterations, Eta, {Final, Objective, Derivatives})
%% takes a function F, initial values Theta, a number of Iterations and a
%% learning rate Eta, and calculates the *maximum*, returning the Final
%% parameters, the Objective and the Derivatives.
%% Uses forward differentiation.
:- pred gradient_ascent_F((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    int::in, float::in, {v_ad_number, ad_number, v_ad_number}::out) is det.

%% gradient_ascent_R(F, Theta, Iterations, Eta, {Final, Objective, Derivatives})
%% takes a function F, initial values Theta, a number of Iterations and a
%% learning rate Eta, and calculates the *maximum*, returning the Final
%% parameters, the Objective and the Derivatives.
%% Uses backward differentiation.
:- pred gradient_ascent_R((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    int::in, float::in, {v_ad_number, ad_number, v_ad_number}::out) is det.

%% multivariate_argmin_F(F, Theta, Final)
%% takes a function F and initial values Theta
%% and calculates the Final values for the *minimum*.
%% Uses forward differentiation.
:- pred multivariate_argmin_F((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    v_ad_number::out) is det.

%% multivariate_argmin_R(F, Theta, Final)
%% takes a function F and initial values Theta
%% and calculates the Final values for the *minimum*.
%% Uses backward differentiation.
:- pred multivariate_argmin_R((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    v_ad_number::out) is det.

%% multivariate_argmax_F(F, Theta, Final)
%% takes a function F and initial values Theta
%% and calculates the Final values for the *maximum*.
%% Uses forward differentiation.
:- pred multivariate_argmax_F((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    v_ad_number::out) is det.

%% multivariate_argmax_R(F, Theta, Final)
%% takes a function F and initial values Theta
%% and calculates the Final values for the *maximum*.
%% Uses backward differentiation.
:- pred multivariate_argmax_R((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    v_ad_number::out) is det.

%% multivariate_max_F(F, Theta, Value)
%% takes a function F and initial values Theta
%% and calculates the *maximum* Value.
%% Uses forward differentiation.
:- pred multivariate_max_F((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    ad_number::out) is det.
%% multivariate_max_R(F, Theta, Value)
%% takes a function F and initial values Theta
%% and calculates the *maximum* Value.
%% Uses backward differentiation.
:- pred multivariate_max_R((func(v_ad_number) = ad_number)::in, v_ad_number::in,
    ad_number::out) is det.

%% Some common utilities

%% sqr(X) = X*X
:- func sqr(ad_number) = ad_number.
%% map_n(F, N) = list.map(F, 1..N).
:- func map_n(func(int) = ad_number, int) = v_ad_number.
%% vplus(X, Y) = X + Y
:- func vplus(v_ad_number, v_ad_number) = v_ad_number.
%% vminus(X, Y) = X - Y
:- func vminus(v_ad_number, v_ad_number) = v_ad_number.
%% ktimesv(K, V) = K*V
:- func ktimesv(ad_number, v_ad_number) = v_ad_number.
%% magnitude_squared(V) = sum_i(V[i]*V[i])
:- func magnitude_squared(v_ad_number) = ad_number.
%% magnitude(V) = sqrt(sum_i(V[i]*V[i]))
:- func magnitude(v_ad_number) = ad_number.
%% distance_squared(X, Y) = magnitude_squared(X - Y)
:- func distance_squared(v_ad_number, v_ad_number) = ad_number.
%% distance(X, Y) = magnitude(X - Y)
:- func distance(v_ad_number, v_ad_number) = ad_number.

%% fdiff(F, X, Eps) gives the first derivative of F at X using a
%% symmetric finite difference with step Eps
:- func fdiff(func(float) = float, float, float) = float.
%% fdiff(F, X) = fdiff(F, X, 1.0e-5)
:- func fdiff(func(float) = float, float) = float.

%% submodule for operations and functions on v_ad_number
:- module ad.v.
:- interface.
%% Addition
:- func (v_ad_number::in) + (v_ad_number::in) = (v_ad_number::out) is det.
%% Subtraction
:- func (v_ad_number::in) - (v_ad_number::in) = (v_ad_number::out) is det.
%% multiplication by a scalar
:- func (ad_number::in) * (v_ad_number::in) = (v_ad_number::out) is det.
%% convert from a vector of floats
:- func from_list(v_float) = v_ad_number.
%% convert to a vector of floats
:- func to_list(v_ad_number) = v_float is det.
:- end_module ad.v.

%% submodule for operations and functions on m_ad_number
:- module ad.m.
:- interface.
%% Addition
:- func (m_ad_number::in) + (m_ad_number::in) = (m_ad_number::out) is det.
%% Subtraction
:- func (m_ad_number::in) - (m_ad_number::in) = (m_ad_number::out) is det.
%% convert from a matrix of floats
:- func from_lists(m_float) = m_ad_number.
%% convert to a matrix of floats
:- func to_lists(m_ad_number) = m_float is det.
:- end_module ad.m.

%% determine_fanout(Tape) is the fan-out operation for backward differentiation
:- func determine_fanout(ad_number) = ad_number.
%% reverse_phase(Sensitivity, Tape) is the reverse phase for backward differentiation
:- func reverse_phase(ad_number, ad_number) = ad_number.
%% extract_gradients(Tape) extracts the gradients as a vector
:- func extract_gradients(ad_number) = v_ad_number.
%% to_float(Ad_number) returns a float representation
:- func to_float(ad_number) = float.
```
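Finally, two more usage sketches against the documented interface. First, maximising a concave quadratic with `multivariate_argmax_R` (the module name `argmax_demo` and the starting point are illustrative):

```mercury
:- module argmax_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.
:- implementation.
:- import_module ad.
:- import_module list.
:- import_module float.

main(!IO) :-
    % Maximise the concave function -(x-1)^2 - (y+2)^2; the argmax is (1, -2).
    multivariate_argmax_R(
        (func(L) = Y :-
            ( if L = [X1, X2] then
                Y = base(0.0) - (X1 - base(1.0))*(X1 - base(1.0))
                    - (X2 + base(2.0))*(X2 + base(2.0))
            else
                Y = base(0.0)
            )
        ),
        [base(0.0), base(0.0)], ArgMax),
    print_line(ArgMax, !IO).   % expect approximately [base(1.0), base(-2.0)]
```

Second, `fdiff` can be used to cross-check a forward-mode derivative against a symmetric finite difference, in the same style as the atan2 test above (again a sketch; `fdiff_check` is a hypothetical module name):

```mercury
:- module fdiff_check.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.
:- implementation.
:- use_module math.
:- import_module ad.
:- import_module float.

main(!IO) :-
    % AD derivative of tanh(x*x) at x = 0.7 ...
    derivative_F(func(X) = tanh(X*X), base(0.7), AD),
    % ... against a symmetric finite difference (default step 1.0e-5).
    FD = fdiff(func(Xf) = math.tanh(Xf*Xf), 0.7),
    print_line(AD, !IO),
    print_line(base(FD), !IO).   % the two lines should agree to ~5 decimals
```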