We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
common/layersのSoftmaxWithLossの動作についてですが、現状のものでは、
forward(x, t)の引数のx, tが例えば、 x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]]) t = np.array([[0, 0, 1], [0, 1, 0]]) または、 x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]]) t = np.array([2, 1]) などのバッチ形式の時は問題ないのですが、
x = np.array([1.0, 1.5, 2.0]) t = np.array([0, 0, 1]) または、 x = np.array([1.0, 1.5, 2.0]) t = np.array(2) などのベクトル形式の場合、エラーになってしまいます。
一応、クラスSoftmaxWithLossを下記のように変更し、
66 class SoftmaxWithLoss: 67 def init(self): 68 self.params, self.grads = [], [] 69 self.y = None # softmaxの出力 70 self.t = None # 教師ラベル 71 72 def forward(self, x, t): 73 self.t = t 74 self.y = softmax(x) 75 76 # 教師ラベルがone-hotベクトルの場合、正解のインデックスに変換 77 if self.y.ndim == 1: # add 78 self.t = self.t.reshape(1, self.t.size) # add 79 self.y = self.y.reshape(1, self.y.size) # add 80 81 if self.t.size == self.y.size: 82 self.t = self.t.argmax(axis=1) 83 84 loss = cross_entropy_error(self.y, self.t) 85 return loss 86 87 def backward(self, dout=1): 88 #batch_size = self.t.shape[0 # delete] 89 batch_size = self.y.shape[0] # modify 90 #print('here1') 91 #print(batch_size) 92 dx = self.y.copy() 93 dx[np.arange(batch_size), self.t] -= 1 94 dx *= dout 95 dx = dx / batch_size 96 97 return dx 98
common.functionsの関数cross_entropy_error()を下記の様に変更すると、
25 def cross_entropy_error(y, t): 26 #print('here2') 27 #if y.ndim == 1: # delete 28 # t = t.reshape(1, t.size) # delete 29 # y = y.reshape(1, y.size) # delete 30 31 # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換 32 #if t.size == y.size: # delet 33 # t = t.argmax(axis=1) # delet 34 35 batch_size = y.shape[0] 36 37 return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size 38
ベクトル形式でも問題なく動作してくれる様なのですが、現状のものは引数のxとtに関して これ以外の入力形式を想定して書かれているのでしょうか?その場合、どういった入力形式を 想定しているのか教えて下さい。
The text was updated successfully, but these errors were encountered:
No branches or pull requests
common/layersのSoftmaxWithLossの動作についてですが、現状のものでは、
forward(x, t)の引数のx, tが例えば、
x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]])
t = np.array([[0, 0, 1], [0, 1, 0]])
または、
x = np.array([[1.0, 1.5, 2.0], [1.2, 1.5, 1.7]])
t = np.array([2, 1])
などのバッチ形式の時は問題ないのですが、
x = np.array([1.0, 1.5, 2.0])
t = np.array([0, 0, 1])
または、
x = np.array([1.0, 1.5, 2.0])
t = np.array(2)
などのベクトル形式の場合、エラーになってしまいます。
一応、クラスSoftmaxWithLossを下記のように変更し、
66 class SoftmaxWithLoss:
67 def init(self):
68 self.params, self.grads = [], []
69 self.y = None # softmaxの出力
70 self.t = None # 教師ラベル
71
72 def forward(self, x, t):
73 self.t = t
74 self.y = softmax(x)
75
76 # 教師ラベルがone-hotベクトルの場合、正解のインデックスに変換
77 if self.y.ndim == 1: # add
78 self.t = self.t.reshape(1, self.t.size) # add
79 self.y = self.y.reshape(1, self.y.size) # add
80
81 if self.t.size == self.y.size:
82 self.t = self.t.argmax(axis=1)
83
84 loss = cross_entropy_error(self.y, self.t)
85 return loss
86
87 def backward(self, dout=1):
88 #batch_size = self.t.shape[0 # delete]
89 batch_size = self.y.shape[0] # modify
90 #print('here1')
91 #print(batch_size)
92 dx = self.y.copy()
93 dx[np.arange(batch_size), self.t] -= 1
94 dx *= dout
95 dx = dx / batch_size
96
97 return dx
98
common.functionsの関数cross_entropy_error()を下記の様に変更すると、
25 def cross_entropy_error(y, t):
26 #print('here2')
27 #if y.ndim == 1: # delete
28 # t = t.reshape(1, t.size) # delete
29 # y = y.reshape(1, y.size) # delete
30
31 # 教師データがone-hot-vectorの場合、正解ラベルのインデックスに変換
32 #if t.size == y.size: # delet
33 # t = t.argmax(axis=1) # delet
34
35 batch_size = y.shape[0]
36
37 return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
38
ベクトル形式でも問題なく動作してくれる様なのですが、現状のものは引数のxとtに関して
これ以外の入力形式を想定して書かれているのでしょうか?その場合、どういった入力形式を
想定しているのか教えて下さい。
The text was updated successfully, but these errors were encountered: