Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add method to set default kind of new variables in VarStore #893

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/nn/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,24 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
};

/// Creates a new tensor with the specified shape, device, kind, and initialization.
pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError> {
pub fn f_init(i: Init, dims: &[i64], device: Device, kind: Kind) -> Result<Tensor, TchError> {
match i {
Init::Const(cst) => {
// Optimize the cases for which a single C++ call is sufficient.
if cst == 0. {
Tensor::f_zeros(dims, (Kind::Float, device))
Tensor::f_zeros(dims, (kind, device))
} else if (cst - 1.).abs() <= f64::EPSILON {
Tensor::f_ones(dims, (Kind::Float, device))
Tensor::f_ones(dims, (kind, device))
} else {
Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
Tensor::f_ones(dims, (kind, device)).map(|t| t * cst)
}
}
Init::Uniform { lo, up } => {
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
}
Init::Uniform { lo, up } => Tensor::f_zeros(dims, (kind, device))?.f_uniform_(lo, up),
Init::Randn { mean, stdev } => {
if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
Tensor::f_randn(dims, (Kind::Float, device))
Tensor::f_randn(dims, (kind, device))
} else {
Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
Tensor::f_randn(dims, (kind, device)).map(|t| t * stdev + mean)
}
}
Init::Kaiming { dist, fan, non_linearity } => {
Expand All @@ -130,10 +128,10 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
match dist {
NormalOrUniform::Uniform => {
let bound = 3f64.sqrt() * std;
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(-bound, bound)
Tensor::f_zeros(dims, (kind, device))?.f_uniform_(-bound, bound)
}
NormalOrUniform::Normal => {
let randn = Tensor::f_randn(dims, (Kind::Float, device))?;
let randn = Tensor::f_randn(dims, (kind, device))?;
Ok(randn * std)
}
}
Expand All @@ -148,7 +146,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
let cols: i64 = dims.iter().skip(1).product();

let mut flattened =
Tensor::f_empty([rows, cols], (Kind::Float, device))?.f_normal_(0.0, 1.0)?;
Tensor::f_empty([rows, cols], (kind, device))?.f_normal_(0.0, 1.0)?;
let flattened = if rows < cols { flattened.f_t_()? } else { flattened };

let (mut q, r) = Tensor::f_linalg_qr(&flattened, "reduced")?;
Expand All @@ -166,7 +164,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>

/// Creates a new float tensor with the specified shape, device, and initialization.
pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
f_init(i, dims, device).unwrap()
f_init(i, dims, device, Kind::Float).unwrap()
}

impl Init {
Expand Down Expand Up @@ -197,7 +195,9 @@ impl Init {
tensor.copy_(&(tensor.randn_like() * stdev + mean));
}
Init::Orthogonal { gain } => {
let q = f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device()).unwrap();
let q =
f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device(), Kind::Float)
.unwrap();
crate::no_grad(|| tensor.view_as(&q).copy_(&q));
}
}
Expand Down
19 changes: 16 additions & 3 deletions src/nn/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct Variables {
pub struct VarStore {
pub variables_: Arc<Mutex<Variables>>,
device: Device,
kind: Kind,
}

/// A variable store with an associated path for variables naming.
Expand All @@ -57,7 +58,7 @@ impl VarStore {
pub fn new(device: Device) -> VarStore {
let variables =
Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
VarStore { variables_: Arc::new(Mutex::new(variables)), device }
VarStore { variables_: Arc::new(Mutex::new(variables)), device, kind: Kind::Float }
}

pub fn merge(var_stores: Vec<(VarStore, Option<&str>)>) -> Result<VarStore, TchError> {
Expand Down Expand Up @@ -110,6 +111,11 @@ impl VarStore {
self.device
}

/// Gets the default kind of new variables.
pub fn kind(&self) -> Kind {
self.kind
}

/// Returns the number of tensors currently stored on this var-store.
pub fn len(&self) -> usize {
let variables = self.variables_.lock().unwrap();
Expand Down Expand Up @@ -322,13 +328,15 @@ impl VarStore {
}
}

/// Casts all variables in a var store to the target kind .
/// Casts all variables in a var store to the target kind and sets the default kind
/// for new variables.
///
/// For floating-point conversion, methods `half`, `bfloat16`, `float` and `double`
/// should be preferred as they ensure only float-like variables will be converted
/// to the target type.
pub fn set_kind(&mut self, kind: Kind) {
self.root().set_kind(kind);
self.kind = kind;
}

/// Casts all float-like variable of a var store to half-precision (Half kind).
Expand Down Expand Up @@ -410,6 +418,11 @@ impl<'a> Path<'a> {
self.var_store.device
}

/// Gets the default kind of new variables.
pub fn kind(&self) -> Kind {
self.var_store.kind
}

pub fn path(&self, name: &str) -> String {
if name.chars().any(|x| x == SEP) {
panic!("variable name cannot contain {SEP} {name}");
Expand Down Expand Up @@ -551,7 +564,7 @@ impl<'a> Path<'a> {
/// The variable uses a float tensor initialized as per the
/// related argument.
pub fn f_var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
let v = super::f_init(init, dims, self.device())?;
let v = super::f_init(init, dims, self.device(), self.kind())?;
Ok(self.add(name, v, true))
}

Expand Down
3 changes: 2 additions & 1 deletion tests/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ fn init_test() {
"{}",
"ortho_norm initialization failed {ortho_norm}"
);
let ortho_shape_fail = tch::nn::f_init(Init::Orthogonal { gain: 1.0 }, &[10], Device::Cpu);
let ortho_shape_fail =
tch::nn::f_init(Init::Orthogonal { gain: 1.0 }, &[10], Device::Cpu, tch::Kind::Float);
assert!(ortho_shape_fail.is_err());
let kaiming_u = vs.root().var("kaiming_u", &[20, 100], nn::init::DEFAULT_KAIMING_UNIFORM);
assert!(f64::abs(f64_from(&kaiming_u.mean(Kind::Float))) < 5e-3);
Expand Down
Loading