Skip to content

Commit

Permalink
add mpc operator add, move mean_normalize to ml.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kaih70 committed Sep 16, 2020
1 parent 57c82ab commit 63ae7e6
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 123 deletions.
21 changes: 14 additions & 7 deletions core/paddlefl_mpc/mpc_protocol/aby3_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,24 @@ class Aby3OperatorsImpl : public MpcOperators {
auto a_tuple = from_tensor(in);
auto a_ = std::get<0>(a_tuple).get();

auto b_tuple = from_tensor<BoolTensor>(pos_info);
auto b_ = std::get<0>(b_tuple).get();

auto out_tuple = from_tensor(out);
auto out_ = std::get<0>(out_tuple).get();

if (pos_info) {
auto b_tuple = from_tensor<BoolTensor>(pos_info);
auto b_ = std::get<0>(b_tuple).get();
a_->max_pooling(out_, b_);
}

void max(const Tensor* in, Tensor* out) override {

auto a_tuple = from_tensor(in);
auto a_ = std::get<0>(a_tuple).get();

auto out_tuple = from_tensor(out);
auto out_ = std::get<0>(out_tuple).get();

a_->max_pooling(out_, b_);
} else {
a_->max_pooling(out_, nullptr);
}
a_->max_pooling(out_, nullptr);
}

void inverse_square_root(const Tensor* in, Tensor* out) override {
Expand Down
4 changes: 4 additions & 0 deletions core/paddlefl_mpc/mpc_protocol/mpc_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class MpcOperators {
// for filter in other shape, reshape input first
virtual void max_pooling(const Tensor* in, Tensor* out, Tensor* pos_info) {}

// column wise max
// in shape [n, ...], out shape [1, ...]
// The default body is empty: a protocol implementation that does not
// override this method leaves `out` unmodified.
virtual void max(const Tensor* in, Tensor* out) {}

virtual void inverse_square_root(const Tensor* in, Tensor* out) = 0;

virtual void predicts_to_indices(const Tensor* in,
Expand Down
4 changes: 2 additions & 2 deletions core/paddlefl_mpc/operators/mpc_mean_normalize_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class MpcMeanNormalizationKernel : public MpcOpKernel<T> {
->mpc_operators()->neg(min, &neg_min);

mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->max_pooling(&neg_min, &neg_min_global, nullptr);
->mpc_operators()->max(&neg_min, &neg_min_global);

mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->max_pooling(max, &max_global, nullptr);
->mpc_operators()->max(max, &max_global);

range->mutable_data<T>(
framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0);
Expand Down
3 changes: 0 additions & 3 deletions python/paddle_fl/mpc/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from .rnn import *
from . import metric_op
from .metric_op import *
from . import data_preprocessing
from .data_preprocessing import *

__all__ = []
__all__ += basic.__all__
Expand All @@ -48,4 +46,3 @@
__all__ += compare.__all__
__all__ += conv.__all__
__all__ += metric_op.__all__
__all__ += data_preprocessing.__all__
107 changes: 0 additions & 107 deletions python/paddle_fl/mpc/layers/data_preprocessing.py

This file was deleted.

98 changes: 94 additions & 4 deletions python/paddle_fl/mpc/layers/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
'pool2d',
'batch_norm',
'reshape',
'mean_normalize',
]


Expand Down Expand Up @@ -612,7 +613,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):

helper = MpcLayerHelper("reshape2", **locals())
_helper = LayerHelper("reshape2", **locals())

def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
Expand All @@ -625,7 +626,7 @@ def get_new_shape_tensor(list_shape):
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor

def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
Expand Down Expand Up @@ -662,13 +663,13 @@ def get_attr_shape(list_shape):
assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, "
"but received %s." % len(shape))
attrs["shape"] = get_attr_shape(shape)

if utils._contain_var(shape):
inputs['ShapeTensor'] = get_new_shape_tensor(shape)
elif isinstance(actual_shape, Variable):
actual_shape.stop_gradient = True
inputs["Shape"] = actual_shape

out = x if inplace else helper.create_mpc_variable_for_type_inference(
dtype=x.dtype)
x_shape = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
Expand All @@ -680,3 +681,92 @@ def get_attr_shape(list_shape):
"XShape": x_shape})

return helper.append_activation(out)


def mean_normalize(f_min, f_max, f_mean, sample_num):
    '''
    Mean normalization is a method used to normalize the range of independent
    variables or features of data.
    Refer to:
    https://en.wikipedia.org/wiki/Feature_scaling#Mean_normalization

    This layer aggregates the per-party statistics (local min, max, mean and
    sample count) into the global range and global mean of every feature.

    Note: the shapes below describe the plaintext layout. As in the example,
    the encrypted mpc variables fed to this layer carry an extra leading
    share dimension (e.g. [2, P, N] instead of [P, N]).

    Args:
        f_min (Variable): A 2-D tensor with shape [P, N], where P is the party
                          num and N is the feature num. Each row contains the
                          local min feature val of N features.
        f_max (Variable): A 2-D tensor with shape [P, N], where P is the party
                          num and N is the feature num. Each row contains the
                          local max feature val of N features.
        f_mean (Variable): A 2-D tensor with shape [P, N], where P is the party
                           num and N is the feature num. Each row contains the
                           local mean feature val of N features.
        sample_num (Variable): A 1-D tensor with shape [P], where P is the
                               party num. Each element contains sample num
                               of party_i.
    Returns:
        f_range (Variable): A 1-D tensor with shape [N], where N is the
                            feature num. Each element contains global
                            range of feature_i.
        f_mean_out (Variable): A 1-D tensor with shape [N], where N is the
                               feature num. Each element contains global
                               mean of feature_i.
    Raises:
        TypeError: If any input variable is not of dtype int64.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import paddle_fl.mpc as pfl_mpc
            pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port)
            # 2 for share, 4 for 4 party, 100 for feat_num
            input_size = [2, 4, 100]
            mi = pfl_mpc.data(name='mi', shape=input_size, dtype='int64')
            ma = pfl_mpc.data(name='ma', shape=input_size, dtype='int64')
            me = pfl_mpc.data(name='me', shape=input_size, dtype='int64')
            sn = pfl_mpc.data(name='sn', shape=input_size[:-1], dtype='int64')
            out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma,
                    f_mean=me, sample_num=sn)
            exe = fluid.Executor(place=fluid.CPUPlace())
            # feed encrypted data
            f_range, f_mean = exe.run(feed={'mi': f_min, 'ma': f_max,
            'me': f_mean, 'sn': sample_num}, fetch_list=[out0, out1])
    '''
    # Must stay the first statement: **locals() forwards exactly the four
    # arguments to the helper.
    helper = MpcLayerHelper("mean_normalize", **locals())

    # Validate the dtype each input actually carries. helper.input_dtype()
    # cannot be used here because this layer has no single 'input' argument,
    # so read the dtype from every variable directly.
    for arg_name, arg in (('f_min', f_min), ('f_max', f_max),
                          ('f_mean', f_mean), ('sample_num', sample_num)):
        check_dtype(arg.dtype, arg_name, ['int64'], 'mean_normalize')

    f_range = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype)
    f_mean_out = helper.create_mpc_variable_for_type_inference(
        dtype=f_min.dtype)

    # to avoid circular dependencies
    from .math import reduce_sum

    # Total sample count over all parties, consumed by the op as "TotalNum".
    total_num = reduce_sum(sample_num)

    helper.append_op(
        type='mpc_mean_normalize',
        inputs={
            "Min": f_min,
            "Max": f_max,
            "Mean": f_mean,
            "SampleNum": sample_num,
            "TotalNum": total_num,
        },
        outputs={
            "Range": f_range,
            "MeanOut": f_mean_out,
        }, )

    return f_range, f_mean_out

0 comments on commit 63ae7e6

Please sign in to comment.