Merge pull request #7 from ablaom/docstring-tweak

boborbt authored Apr 9, 2024
2 parents bcafc93 + b0d6408 commit 4afd14a

Showing 1 changed file with 90 additions and 41 deletions: src/PartitionedLS.jl
@@ -19,7 +19,7 @@ const MMI = MLJModelInterface
"""
$(TYPEDEF)
The PartLSFitResult struct represents the solution of the partitioned least squares problem.
It contains the values of the α and β variables, the intercept t and the partition matrix P.
## Fields
@@ -95,7 +95,7 @@ parameter η controls the strength of the regularization.
## Main idea
K new rows are added to the data matrix X: row ``k \\in \\{1 \\dots K\\}`` is a vector of zeros except for
the components that correspond to features belonging to the k-th partition, which are set to sqrt(η).
The target vector y is extended with K zeros.
The point of this change is that when the objective function is evaluated as ``\\|Xw - y\\|^2``, the new part of
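The augmentation can be sketched as follows; this is a standalone illustration of the idea above, not the package's actual `regularizeProblem` implementation, and the concrete numbers are made up for the example:

```julia
# Augment (X, y) so that minimizing the squared error on (Xn, yn) also
# penalizes each partition's coefficients with weight η, as described above.
η = 0.5
X = [1.0 2.0 3.0;
     3.0 3.0 4.0]
y = [1.0, 1.0]
P = [1 0;          # rows: features, columns: partitions
     1 0;
     0 1]
K = size(P, 2)
# one extra row per partition: sqrt(η) on that partition's features, 0 elsewhere
Xn = vcat(X, sqrt(η) .* Matrix{Float64}(P'))
yn = vcat(y, zeros(K))
```

With these inputs `Xn` gains two rows, `[√η √η 0]` and `[0 0 √η]`, and `yn` ends with two zeros, so the extra terms of the squared error are exactly the η-weighted penalties on each partition's coefficients.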
@@ -106,7 +106,7 @@ function regularizeProblem(X, y, P, η)
if η == 0
return X, y
end

Xn = X
yn = y
for k in 1:size(P, 2)
@@ -140,9 +140,9 @@ Make predictions for the dataset `X` using the PartLS model `model`.
## Arguments
- `model`: a [PartLSFitResult](@ref)
- `X`: the data containing the examples for which the predictions are sought
## Return
the predictions of the given model on examples in X.
"""
function predict(model::PartLSFitResult, X::Array{Float64,2})
(; α, β, t, P) = model
@@ -157,46 +157,64 @@ include("PartitionedLSBnB.jl")
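Since the body of `predict` is folded away in this diff, here is a sketch of what the prediction step plausibly computes, under the assumption (not confirmed by the visible code) that the effective weight of feature `i` is `α[i] * β[k]` for `i` in partition `k`, i.e. `w = α .* (P * β)`:

```julia
# Hypothetical prediction step: per-feature weights α scaled by the weight
# β[k] of the partition each feature belongs to, plus the intercept t.
α = [0.5, 0.5, 1.0]        # within-partition feature weights (made up)
β = [2.0, -1.0]            # partition weights (made up)
t = 0.1                    # intercept
P = [1 0; 1 0; 0 1]        # rows: features, columns: partitions
X = [1.0 2.0 3.0;
     3.0 3.0 4.0]
w = α .* (P * β)           # effective per-feature weights
ŷ = X * w .+ t
```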
"""
PartLS
A model type for fitting a partitioned least squares model to data. Both an MLJ and a native
interface are provided.
# MLJ Interface
From MLJ, the type can be imported using
PartLS = @load PartLS pkg=PartitionedLS
Construct an instance with default hyper-parameters using the syntax `model =
PartLS()`. Provide keyword arguments to override hyper-parameter defaults, as in
`model = PartLS(P=...)`.
## Training data
In MLJ or MLJBase, bind an instance `model` to data with
mach = machine(model, X, y)
where
- `X`: any matrix with element type `Float64`, or any table with columns of type `Float64`
Train the machine using `fit!(mach)`.
## Hyper-parameters
- `Optimizer`: the optimization algorithm to use. It can be `Opt`, `Alt` or `BnB` (names
  exported by `PartitionedLS.jl`).
- `P`: the partition matrix. It is a binary matrix where each row corresponds to a
  feature and each column corresponds to a partition. The element `P_{i, k} = 1` if
  feature `i` belongs to partition `k`.
- `η`: the regularization parameter. It controls the strength of the regularization.
- `ϵ`: the tolerance parameter. It is used to determine when the `Alt` optimization
  algorithm has converged. Only used by the `Alt` algorithm.
- `T`: the maximum number of iterations for the `Alt` optimization algorithm; the
  algorithm stops after `T` iterations even if convergence has not been reached. Only
  used by the `Alt` algorithm.
- `rng`: the random number generator to use.
  - If `nothing`, the global random number generator `rand` is used.
  - If an integer, the global random number generator `rand` is used after seeding it
    with the given integer.
  - If an object of type `AbstractRNG`, the given random number generator is used.
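The examples in this docstring use a `P` with one row per feature and one column per partition. For illustration, a hypothetical helper (not part of `PartitionedLS.jl`) that builds such a matrix from groups of feature indices might look like:

```julia
# Hypothetical helper: build a partition matrix from groups of feature
# indices, so that P[i, k] == 1 iff feature i is listed in groups[k].
function partition_matrix(groups, nfeatures)
    P = zeros(Int, nfeatures, length(groups))
    for (k, group) in enumerate(groups)
        for i in group
            P[i, k] = 1
        end
    end
    return P
end

# features 1 and 2 in partition 1, feature 3 in partition 2
P = partition_matrix([[1, 2], [3]], 3)
```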
## Operations
- `predict(mach, Xnew)`: return the predictions of the model on new data `Xnew`
## Fitted parameters
The fields of `fitted_params(mach)` are:
@@ -207,31 +225,62 @@ The fields of `fitted_params(mach)` are:
- `P`: the partition matrix. It is a binary matrix where each row corresponds to a feature and each column
  corresponds to a partition. The element `P_{i, k} = 1` if feature `i` belongs to partition `k`.
## Examples
```julia
PartLS = @load PartLS pkg=PartitionedLS
X = [[1. 2. 3.];
[3. 3. 4.];
[8. 1. 3.];
[5. 3. 1.]]
y = [1.;
1.;
2.;
3.]
P = [[1 0];
[1 0];
[0 1]]
model = PartLS(P=P)
mach = machine(model, X, y) |> fit!
# predictions on the training set:
predict(mach, X)
```
# Native Interface
```julia
using PartitionedLS
X = [[1. 2. 3.];
[3. 3. 4.];
[8. 1. 3.];
[5. 3. 1.]]
y = [1.;
1.;
2.;
3.]
P = [[1 0];
[1 0];
[0 1]]
# fit using the optimal algorithm
result = fit(Opt, X, y, P, η = 0.0)
y_hat = predict(result.model, X)
```
For other `fit` keyword options, refer to the "Hyper-parameters" section for the MLJ
interface.
"""
MMI.@mlj_model mutable struct PartLS <: MMI.Deterministic
Optimizer::Union{Type{Opt},Type{Alt},Type{BnB}} = Opt
@@ -303,7 +352,7 @@ MMI.metadata_model(PartLS,
input_scitype = Union{MMI.Table{AbstractVector{MMI.Continuous}}, AbstractMatrix{MMI.Continuous}}, # what input data is supported?
target_scitype = AbstractVector{MMI.Continuous}, # for a supervised model, what target?
supports_weights = false, # does the model support sample weights?
load_path = "PartitionedLS.PartLS"
)
end

@@ -333,7 +382,7 @@ Train the machine using `fit!(mach)`.
- `η`: the regularization parameter. It controls the strength of the regularization.
- `ϵ`: the tolerance parameter. It is used to determine when the `Alt` optimization algorithm has converged. Only used by the `Alt` algorithm.
- `T`: the maximum number of iterations for the `Alt` optimization algorithm; the algorithm stops after `T` iterations even if convergence has not been reached. Only used by the `Alt` algorithm.
- `rng`: the random number generator to use.
- If `nothing`, the global random number generator `rand` is used.
- If an integer, the global random number generator `rand` is used after seeding it with the given integer.
- If an object of type `AbstractRNG`, the given random number generator is used.
@@ -360,22 +409,22 @@ The fields of `fitted_params(mach)` are:
using PartitionedLS
X = [[1. 2. 3.];
[3. 3. 4.];
[8. 1. 3.];
[5. 3. 1.]]
y = [1.;
1.;
2.;
3.]
P = [[1 0];
[1 0];
[0 1]]
# fit using the optimal algorithm
result = fit(Opt, X, y, P, η = 0.0)
y_hat = predict(result.model, X)
```