Skip to content

Commit

Permalink
docstring fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 10, 2025
1 parent f1d1188 commit e174077
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 42 deletions.
4 changes: 4 additions & 0 deletions docs/src/api/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Linear Models

```@docs
StandardRidge
```

## Gaussian Regression

Currently, v0.10, is unavailable.
Expand Down
43 changes: 1 addition & 42 deletions src/train/linear_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,6 @@ The equations for ridge regression are as follows:
there's usually no need to tweak this
- `reg`: regularization coefficient. Default is set to 0.0 (linear regression).
# Examples
```jldoctest
julia> ridge_reg = StandardRidge()
StandardRidge(0.0)
julia> ol = train(ridge_reg, rand(Float32, 10, 10), rand(Float32, 10, 10))
OutputLayer successfully trained with output size: 10
julia> ol.output_matrix #visualize output matrix
10×10 Matrix{Float32}:
0.456574 -0.0407612 0.121963 … 0.859327 -0.127494 0.0572494
0.133216 -0.0337922 0.0185378 0.24077 0.0297829 0.31512
0.379672 -1.24541 -0.444314 1.02269 -0.0446086 0.482282
1.18455 -0.517971 -0.133498 0.84473 0.31575 0.205857
-0.119345 0.563294 0.747992 0.0102919 1.509 -0.328005
-0.0716812 0.0976365 0.628654 … -0.516041 2.4309 -0.113402
0.0153872 -0.52334 0.0526867 0.729326 2.98958 1.32703
0.154027 0.6013 1.05548 -0.0840203 0.991182 -0.328555
1.11007 -0.0371736 -0.0529418 0.186796 -1.21815 0.204838
0.282996 -0.263799 0.132079 0.875417 0.497951 0.273423
julia> ridge_reg = StandardRidge(0.001) #passing a value
StandardRidge(0.001)
julia> ol = train(ridge_reg, rand(Float16, 10, 10), rand(Float16, 10, 10))
OutputLayer successfully trained with output size: 10
julia> ol.output_matrix
10×10 Matrix{Float16}:
-1.251 3.074 -1.566 -0.10297 … 0.3823 1.341 -1.77 -0.445
0.11017 -2.027 0.8975 0.872 -0.643 0.02615 1.083 0.615
0.2634 3.514 -1.168 -1.532 1.486 0.1255 -1.795 -0.06555
0.964 0.9463 -0.006855 -0.519 0.0743 -0.181 -0.433 0.06793
-0.389 1.887 -0.702 -0.8906 0.221 1.303 -1.318 0.2634
-0.1337 -0.4453 -0.06866 0.557 … -0.322 0.247 0.2554 0.5933
-0.6724 0.906 -0.547 0.697 -0.2664 0.809 -0.6836 0.2358
0.8843 -3.664 1.615 1.417 -0.6094 -0.59 1.975 0.4785
1.266 -0.933 0.0664 -0.4497 -0.0759 -0.03897 1.117 0.3152
0.6353 1.327 -0.6978 -1.053 0.8037 0.6577 -0.7246 0.07336
```
"""
struct StandardRidge
Expand Down Expand Up @@ -84,8 +44,7 @@ function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractAr
))
end

T = eltype(states)
output_layer = Matrix(((states * states' + T(sr.reg) * I) \
output_layer = Matrix(((states * states' + sr.reg * I) \
(states * target_data'))')
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
end

0 comments on commit e174077

Please sign in to comment.