mercury-ad: Mercury module for automatic differentiation

Introduction

A Mercury module for automatic differentiation, providing both forward and backward (reverse-mode) differentiation.

The module adapts the approach presented in https://github.com/qobi/AD-Rosetta-Stone/. Interestingly, most of the functional approaches to backward differentiation described there use references to update a tape in place, whereas this library implements the fan-out and reverse phases in a purely functional style: each tape carries an integer tag, and the sensitivity values are collected from the tagged tapes to extract the gradients.

The module also includes some simple optimisers and implements one of the extended examples (mlp_R) from the same repository.

The central type is ad_number, whose base(float) constructor wraps plain numeric values. Full documentation is given below.
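
As a first taste, here is a minimal sketch (the module name tiny_ad is arbitrary; it uses only declarations from the documentation below) that differentiates f(X) = X*X + 3*X at X = 2.0 by forward mode; the expected derivative is 2*2.0 + 3.0 = 7.0:

:- module tiny_ad.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module ad.

main(!IO) :-
    %% d/dX (X*X + 3*X) = 2*X + 3, which is 7.0 at X = 2.0
    derivative_F(func(X) = X*X + base(3.0)*X, base(2.0), D),
    print_line(D, !IO).   %% prints base(7.0)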

Test examples

Some test code is here:

:- module test_ad.

:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- use_module math.
:- import_module ad.
:- import_module ad.v.
:- import_module list.
:- import_module float.

main(!IO) :-
    derivative_F(func(X) = exp(base(2.0)*X), base(1.0), GradF),
    print_line("Expected: ", !IO), print_line(base(math.exp(2.0)*2.0), !IO),
    print_line(GradF, !IO),
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=exp(base(2.0)*A)+B else Y = base(0.0)),
		   from_list([1.0,3.0]), Grad0),
    print_line("Expected: ", !IO), print_line([base(math.exp(2.0)*2.0),base(1.0)], !IO),
    print_line(Grad0, !IO),
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=B+A*A*A else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad),
    print_line("Expected: ", !IO), print_line([base(3.0*1.1*1.1),base(1.0)], !IO),
    print_line(Grad, !IO),
    gradient_R(func(List) = Y :-
		   (if List=[A,B] then Y=exp(B+A*A*A) else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad2),
    print_line("Expected: ", !IO),
    print_line([base(math.exp(2.3+1.1*1.1*1.1)*(3.0*1.1*1.1)),
		base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad2, !IO),
    gradient_F(func(List) = Y :-
		   (if List=[A,B] then Y=exp(B+A*A*A) else Y = base(0.0)),
		   [base(1.1),base(2.3)], Grad3),
    print_line("Expected: ", !IO),
    print_line([base(math.exp(2.3+1.1*1.1*1.1)*(3.0*1.1*1.1)),
		base(math.exp(2.3+1.1*1.1*1.1))], !IO),
    print_line(Grad3, !IO),
    multivariate_argmin_F(func(AB) = Y :-
		   (if AB = [A,B]
		    then Y = A*A+(B-base(1.0))*(B-base(1.0))
		    else Y = base(0.0)),
		   [base(1.0),base(2.0)], Y4),
    print_line("Expected: ", !IO),
    print_line([base(0.0),base(1.0)], !IO),
    print_line(Y4,!IO),
    multivariate_argmin_R(func(AB) = Y :-
		   (if AB = [A,B]
		    then Y = A*A+(B-base(1.0))*(B-base(1.0))
		    else Y = base(0.0)),
		   [base(1.0),base(2.0)], Y5),
    print_line("Expected: ", !IO),
    print_line([base(0.0),base(1.0)], !IO),
    print_line(Y5,!IO),
    multivariate_argmin_R(rosenbrock,
			  [base(-3.0),base(4.0)],Y6),
    print_line("Expected: ", !IO),
    print_line([base(1.0),base(1.0)], !IO),
    print_line(Y6,!IO),
    multivariate_argmin_F(rosenbrock,
			  [base(-3.0),base(4.0)],Y7),
    print_line("Expected: ", !IO),
    print_line([base(1.0),base(1.0)], !IO),
    print_line(Y7,!IO),
    gradient_F(func(List) = Y :-
		   (if List=[A,B] then Y=atan2(A,B) else Y = base(0.0)),
		   [base(0.5),base(0.6)], Y8),
    print_line("Expected: ", !IO),
    Y8_Y = base(fdiff(func(Yi)=math.atan2(Yi,0.6),0.5)),
    Y8_X = base(fdiff(func(Xi)=math.atan2(0.5,Xi),0.6)),
    print_line([Y8_Y, Y8_X], !IO),
    print_line(Y8,!IO).

:- func rosenbrock(v_ad_number) = ad_number.
    %% Rosenbrock function (1-X)^2 + 100*(Y-X*X)^2, with its global
    %% minimum at (X,Y) = (1,1).
rosenbrock(In) = Result :-
    ( if In = [X,Y] then
	A = base(1.0),
	B = base(100.0),
	Result = (A-X)*(A-X) + B*(Y-X*X)*(Y-X*X)
    else
	Result = base(0.0)
    ).

Building and running the test program gives:

mmc --make test_ad libad && ./test_ad
Expected: 
base(14.7781121978613)
base(14.7781121978613)
Expected: 
[base(14.7781121978613), base(1.0)]
[base(14.7781121978613), base(1.0)]
Expected: 
[base(3.630000000000001), base(1.0)]
[base(3.630000000000001), base(1.0)]
Expected: 
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
Expected: 
[base(137.0344903162743), base(37.75054829649429)]
[base(137.0344903162743), base(37.75054829649429)]
Expected: 
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
Expected: 
[base(0.0), base(1.0)]
[base(-4.2936212903462384e-09), base(0.9999999957063787)]
Expected: 
[base(1.0), base(1.0)]
[base(0.9999914554400818), base(0.9999828755368568)]
Expected: 
[base(1.0), base(1.0)]
[base(0.9999914554355536), base(0.9999828755391909)]
Expected: 
[base(0.9836065574087004), base(-0.8196721312081489)]
[base(0.9836065573770492), base(-0.819672131147541)]

Documentation

%--------------------------------------------------%
% Copyright (C) 2023 Mark Clements.
% This file is distributed under the terms specified in LICENSE.
%--------------------------------------------------%
%
% File: ad.m.
% Authors: mclements
% Stability: low.
%
% This module defines backward and forward automatic
% differentiation
%
%--------------------------------------------------%

:- module ad.
:- interface.
:- import_module list.
:- import_module float.

    %% main representation type
:- type ad_number --->
   dual_number(int,       % epsilon (used for order of derivative)
	       ad_number, % value
	       ad_number) % derivative
   ;
   tape(int,              % variable order (new)
	int,              % epsilon (used for order of derivative)
	ad_number,        % value
	list(ad_number),  % factors
	list(ad_number),  % tape
	int,              % fanout 
	ad_number)        % sensitivity
   ;
   base(float).
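    %% Forward mode computes with dual numbers V + Eps*Dv, where Eps^2 = 0;
    %% e.g. (V + Eps*Dv) * (V + Eps*Dv) = V*V + Eps*(2*V*Dv), which is how
    %% a derivative is carried alongside a value. dual_number stores this
    %% (value, derivative) pair, with the int field separating nested
    %% derivative levels.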

    %% vector of ad_numbers
:- type v_ad_number == list(ad_number).
    %% matrix of ad_numbers
:- type m_ad_number == list(list(ad_number)).
    %% vector of floats
:- type v_float == list(float).
    %% matrix of floats
:- type m_float == list(list(float)).

    %% make_dual_number(Epsilon, Value, Derivative) constructs a dual_number
:- func make_dual_number(int,ad_number,ad_number) = ad_number.
    %% make_tape(Tag, Epsilon, Value, Factors, Tapes) constructs a tape
:- func make_tape(int, int, ad_number, v_ad_number,
		  v_ad_number) = ad_number.

%% arithmetic functions and comparison predicates defined for differentiation
:- func (ad_number::in) + (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) - (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) * (ad_number::in) = (ad_number::out) is det.
:- func (ad_number::in) / (ad_number::in) = (ad_number::out) is det.
:- func pow(ad_number, ad_number) = ad_number.
:- pred (ad_number::in) < (ad_number::in) is semidet.
:- pred (ad_number::in) =< (ad_number::in) is semidet.
:- pred (ad_number::in) > (ad_number::in) is semidet.
:- pred (ad_number::in) >= (ad_number::in) is semidet.
:- pred (ad_number::in) == (ad_number::in) is semidet. % equality
:- func exp(ad_number) = ad_number is det.
:- func ln(ad_number) = ad_number is det.
:- func log2(ad_number) = ad_number is det.
:- func log10(ad_number) = ad_number is det.
:- func log(ad_number,ad_number) = ad_number is det.
:- func sqrt(ad_number) = ad_number is det.
:- func sin(ad_number) = ad_number is det.
:- func cos(ad_number) = ad_number is det.
:- func tan(ad_number) = ad_number is det.
:- func asin(ad_number) = ad_number is det.
:- func acos(ad_number) = ad_number is det.
:- func atan(ad_number) = ad_number is det.
:- func atan2(ad_number,ad_number) = ad_number is det.
:- func sinh(ad_number) = ad_number is det.
:- func cosh(ad_number) = ad_number is det.
:- func tanh(ad_number) = ad_number is det.
:- func abs(ad_number) = ad_number is det.
%% TODO: add further functions and operators

    %% derivative_F(F,Theta,Derivative,!Epsilon) takes a function F and an initial value Theta,
    %% and returns the Derivative, threading the Epsilon counter (which tracks the
    %% nesting level of derivatives).
    %% Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out,
		     int::in, int::out) is det.
    %% derivative_F(F,Theta,Derivative) takes a function F and an initial value Theta,
    %% and returns the Derivative, using the default Epsilon counter.
    %% Uses forward differentiation.
:- pred derivative_F((func(ad_number) = ad_number)::in, ad_number::in, ad_number::out) is det.

    %% gradient_F(F,Theta,Gradient) takes a function F and initial values Theta,
    %% and returns the Gradient, using the default Epsilon counter.
    %% Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out) is det.
    %% gradient_F(F,Theta,Gradient,!Epsilon) takes a function F and initial values Theta,
    %% and returns the Gradient, threading the Epsilon counter.
    %% Uses forward differentiation.
:- pred gradient_F((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out,
		   int::in, int::out) is det.

    %% gradient_R(F,Theta,Gradient,!Epsilon) takes a function F and initial values Theta,
    %% and returns the Gradient, threading the Epsilon counter.
    %% Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out,
		   int::in, int::out) is det.
    %% gradient_R(F,Theta,Gradient) takes a function F and initial values Theta,
    %% and returns the Gradient, using the default Epsilon counter.
    %% Uses backward differentiation.
:- pred gradient_R((func(v_ad_number) = ad_number)::in,
		   v_ad_number::in, v_ad_number::out) is det.

    %% gradient_ascent_F(F,Theta,Iterations,Eta,{Final,Objective,Derivatives})
    %% takes a function F, initial values Theta, a number of Iterations and a step size Eta,
    %% and calculates the *maximum*, returning the Final parameters, the Objective and the Derivatives.
    %% Uses forward differentiation.
:- pred gradient_ascent_F((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   int::in,
			   float::in,
			   {v_ad_number, ad_number, v_ad_number}::out) is det.
    %% gradient_ascent_R(F,Theta,Iterations,Eta,{Final,Objective,Derivatives})
    %% takes a function F, initial values Theta, a number of Iterations and a step size Eta,
    %% and calculates the *maximum*, returning the Final parameters, the Objective and the Derivatives.
    %% Uses backward differentiation.
:- pred gradient_ascent_R((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   int::in,
			   float::in,
			   {v_ad_number, ad_number, v_ad_number}::out) is det.
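    %% For example (a sketch, not part of the test suite), maximising
    %% f(X,Y) = -X*X - (Y-1)*(Y-1) from (2,2) with 1000 iterations and
    %% Eta = 0.1:
    %%     gradient_ascent_R(
    %%         func(L) = ( if L = [X,Y]
    %%                     then base(0.0) - X*X - (Y-base(1.0))*(Y-base(1.0))
    %%                     else base(0.0) ),
    %%         [base(2.0), base(2.0)], 1000, 0.1, {Final, Objective, Grad})
    %% should drive Final towards [base(0.0), base(1.0)].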

    %% multivariate_argmin_F(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *minimum*.
    %% Uses forward differentiation.
:- pred multivariate_argmin_F((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.
    %% multivariate_argmin_R(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *minimum*.
    %% Uses backward differentiation.
:- pred multivariate_argmin_R((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.

    %% multivariate_argmax_F(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *maximum*.
    %% Uses forward differentiation.
:- pred multivariate_argmax_F((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.
    %% multivariate_argmax_R(F,Theta,Final)
    %% takes a function F and initial values Theta
    %% and calculates the Final values for the *maximum*.
    %% Uses backward differentiation.
:- pred multivariate_argmax_R((func(v_ad_number) = ad_number)::in,
			      v_ad_number::in,
			      v_ad_number::out) is det.

    %% multivariate_max_F(F,Theta,Value)
    %% takes a function F and initial values Theta
    %% and calculates the *maximum* Value.
    %% Uses forward differentiation.
:- pred multivariate_max_F((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   ad_number::out) is det.
    %% multivariate_max_R(F,Theta,Value)
    %% takes a function F and initial values Theta
    %% and calculates the *maximum* Value.
    %% Uses backward differentiation.
:- pred multivariate_max_R((func(v_ad_number) = ad_number)::in,
			   v_ad_number::in,
			   ad_number::out) is det.
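    %% For example (a sketch, not part of the test suite),
    %%     multivariate_max_F(
    %%         func(L) = ( if L = [X,Y]
    %%                     then base(4.0) - X*X - Y*Y
    %%                     else base(0.0) ),
    %%         [base(1.0), base(1.0)], Max)
    %% should return Max close to base(4.0), attained at (0,0).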

%% Some common utilities
    %% sqr(X) = X*X
:- func sqr(ad_number) = ad_number.
    %% map_n(F,N) = list.map(F, 1..N).
:- func map_n(func(int) = ad_number, int) = v_ad_number.
    %% vplus(X,Y) = X + Y
:- func vplus(v_ad_number, v_ad_number) = v_ad_number.
    %% vminus(X,Y) = X - Y
:- func vminus(v_ad_number, v_ad_number) = v_ad_number.
    %% ktimesv(K,V) = K*V
:- func ktimesv(ad_number, v_ad_number) = v_ad_number.
    %% magnitude_squared(V) = sum_i(V[i]*V[i])
:- func magnitude_squared(v_ad_number) = ad_number.
    %% magnitude(V) = sqrt(sum_i(V[i]*V[i]))
:- func magnitude(v_ad_number) = ad_number.
    %% distance_squared(X,Y) = magnitude_squared(X-Y)
:- func distance_squared(v_ad_number,v_ad_number) = ad_number.
    %% distance(X,Y) = magnitude(X-Y)
:- func distance(v_ad_number,v_ad_number) = ad_number.
    %% fdiff(F,X,Eps) approximates the first derivative of F at X by a
    %% symmetric finite difference with step Eps, i.e.
    %% (F(X + Eps) - F(X - Eps)) / (2.0 * Eps)
:- func fdiff(func(float)=float,float,float) = float.
    %% fdiff(F,X) = fdiff(F,X,1.0e-5)
:- func fdiff(func(float)=float,float) = float.

%% submodule for operations and functions on v_ad_number
:- module ad.v.
:- interface.
    %% Addition
:- func (v_ad_number::in) + (v_ad_number::in) = (v_ad_number::out) is det.
    %% Subtraction
:- func (v_ad_number::in) - (v_ad_number::in) = (v_ad_number::out) is det.
    %% multiplication by a scalar
:- func (ad_number::in) * (v_ad_number::in) = (v_ad_number::out) is det.
    %% convert from a vector of floats
:- func from_list(v_float) = v_ad_number.
    %% convert to a vector of floats
:- func to_list(v_ad_number) = v_float is det.
:- end_module ad.v.
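    %% For example (a sketch), the gradient of magnitude_squared at (1,2),
    %% converted back to plain floats:
    %%     gradient_R(magnitude_squared, from_list([1.0, 2.0]), G),
    %%     to_list(G) = [2.0, 4.0]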

%% submodule for operations and functions on m_ad_number
:- module ad.m.
:- interface.
    %% Addition
:- func (m_ad_number::in) + (m_ad_number::in) = (m_ad_number::out) is det.
    %% Subtraction
:- func (m_ad_number::in) - (m_ad_number::in) = (m_ad_number::out) is det.
    %% convert from a matrix of floats
:- func from_lists(m_float) = m_ad_number.
    %% convert to a matrix of floats
:- func to_lists(m_ad_number) = m_float is det.
:- end_module ad.m.
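    %% For example (a sketch), adding two matrices with ad.m:
    %%     M = from_lists([[1.0, 2.0], [3.0, 4.0]]),
    %%     to_lists(M + M) = [[2.0, 4.0], [6.0, 8.0]]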

    %% determine_fanout(Tape) is the fanout operation for backward differentiation
:- func determine_fanout(ad_number) = ad_number.
    %% reverse_phase(Sensitivity,Tape) is the reverse phase for backward differentiation
:- func reverse_phase(ad_number, ad_number) = ad_number.
    %% extract_gradients(Tape) extracts the gradients as a vector
:- func extract_gradients(ad_number) = v_ad_number.
    %% to_float(Ad_number) returns a float representation
:- func to_float(ad_number) = float.
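    %% These expose the reverse-mode pipeline underlying gradient_R (a
    %% structural sketch; the implementation details are in ad.m): evaluate
    %% F on tape-wrapped inputs, call determine_fanout on the result, run
    %% reverse_phase with an initial sensitivity of one, and read the
    %% gradient off with extract_gradients.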
