ここ
http://mars.elcom.nitech.ac.jp/java-cai/neuro/me …
を参考にニューラルネットワークのBP学習のプログラムを作成しているのですが、
学習データについて疑問があります。
学習に使用する学習データをプログラムでは乱数を生成して作っている
のですが、学習データの生成に際して、データの分散や平均の大きさを
考えるべきなのでしょうか?
また考えるべきだとしたら、どのように評価したらよいのでしょうか?
No.1ベストアンサー
- 回答日時:
データの分散や平均の大きさを考えると、統計ですね。
参考まで。
/* Back-Propagation Program ver1.01 */
#include<stdio.h>
#include<math.h>
#include<stdlib.h>
#include<time.h>
#define INPUT 2
#define HIDDEN 4
#define OUTPUT 1
#define PATTERN 4
#define PR 100
#define MAX_T 10000
#define eta 2.4
#define eps 1.0e-4
#define alpha 0.8
#define beta 0.8
#define W0 0.5
double xi[INPUT+1],v[HIDDEN+1],o[OUTPUT],zeta[OUTPUT];
double w1[HIDDEN][INPUT+1],w2[OUTPUT][HIDDEN+1];
double d_w1[HIDDEN][INPUT+1],d_w2[OUTPUT][HIDDEN+1];
double pre_dw1[HIDDEN][INPUT+1],pre_dw2[OUTPUT][HIDDEN+1];
double data[PATTERN][INPUT],t_data[PATTERN][OUTPUT];
void load_data(char *filename);
void back_propagation();
void w_init();
double ranran();
void dw_init();
void xi_set(long int t, int p);
void forward(long int t);
void backward();
double calc_error();
void modify_w();
void w_print();
double sigmoid(double u);
main(int argc, char *argv[])
{
load_data(*++argv);
back_propagation();
w_print();
}
void load_data(char *filename)
{
int p,k,i;
double value;
FILE *fp;
fp = fopen(filename,"r");
if( fp == NULL ) {
fprintf(stderr,"File Open Error!\n");
exit(0);
}
for( p=0 ; p < PATTERN ; p++ ){
for( k=0 ; k < INPUT ; k++ ){
fscanf(fp," %lf",&value);
data[p][k] = value;
}
for( i=0 ; i < OUTPUT ; i++ ){
fscanf(fp," %lf",&value);
t_data[p][i] = value;
}
}
fclose(fp);
printf("Input Desired\n");
for( p=0 ; p<PATTERN ; p++ ){
printf("{");
for( k=0 ; k<INPUT ; k++ )
printf(" %.0lf,",data[p][k]);
printf("} -> {");
for( i=0 ; i<OUTPUT ; i++)
printf("%.0lf,",t_data[p][i]);
printf("}\n");
}
putchar('\n');
}
void back_propagation()
{
long int t;
int p;
double E,Esum;
w_init();
for( t=0 ; t < MAX_T ; t++ ){
dw_init();
for( p=0, Esum=0 ; p < PATTERN ; p++ ){
xi_set(t,p);
forward(t);
backward();
Esum += calc_error();
}
modify_w();
E = Esum / (OUTPUT * PATTERN);
if( t%PR == 0 )
printf("%ld %e\n",t,E);
if( E < eps )
break;
}
printf("\nTime = %ld",t);
if( t == MAX_T )
printf(" (MAX) You must retry!");
putchar('\n');
for( p=0 ; p < PATTERN ; p++ ){
xi_set(0,p);
forward(0);
}
printf("E = %e\n",E);
}
void w_init()
{
int i,j,k;
long time_t;
time_t = time(NULL);
//srand48(time_t);
srand(time_t);
for( j=0 ; j < HIDDEN ; j++ )
for( k=0 ; k <INPUT+1 ; k++ ){
w1[j][k] = ranran();
d_w1[j][k] = 0.0;
}
for( i=0 ; i < OUTPUT ; i++ )
for( j=0 ; j < HIDDEN+1 ; j++ ){
w2[i][j] = ranran();
d_w2[i][j] = 0.0;
}
}
double ranran()
{
double r;
//r = drand48();
r = rand();
r = r * 2*W0 - W0;
return r;
}
void dw_init()
{
int i,j,k;
for( j=0 ; j < HIDDEN ; j++)
for( k=0 ; k < INPUT+1 ; k++ ){
pre_dw1[j][k] = d_w1[j][k];
d_w1[j][k] = 0.0;
}
for( i=0 ; i <OUTPUT ; i++ )
for( j=0 ; j < HIDDEN+1 ; j++ ){
pre_dw2[i][j] = d_w2[i][j];
d_w2[i][j] = 0.0;
}
}
void xi_set(long int t, int p)
{
int i,k;
if( t%PR == 0 ) printf("Input ");
for( k=0 ; k < INPUT ; k++ ){
xi[k] = data[p][k];
if( t%PR == 0 ) printf(" %.0lf ",xi[k]);
}
xi[INPUT] = 1.0;
if( t%PR == 0 ) putchar('(');
for( i=0 ; i < OUTPUT ; i++ ){
zeta[i] = t_data[p][i];
if( t%PR == 0 ) printf(" %.0lf ",zeta[i]);
}
if( t%PR == 0 ) printf(")\n");
}
void forward(long int t)
{
int i,j,k;
double sum;
for( j=0 ; j < HIDDEN ; j++ ){
for( k=0, sum=0 ; k < INPUT+1 ; k++ )
sum += xi[k] * w1[j][k];
v[j] = sigmoid(sum);
}
if( t%PR == 0 ) printf("Output ");
v[HIDDEN] = 1.0;
for( i=0 ; i < OUTPUT ; i++ ){
for( j=0, sum=0 ; j < HIDDEN+1 ; j++ )
sum += v[j] * w2[i][j];
o[i] = sigmoid(sum);
if( t%PR == 0 ) printf(" %.4lf",o[i]);
}
if(t %PR == 0 ) putchar('\n');
}
void backward()
{
int i,j,k;
double delta2[OUTPUT],delta1[HIDDEN+1],sum;
for( i=0 ; i < OUTPUT ; i++ )
delta2[i] = beta * o[i] * (1-o[i]) * (zeta[i]-o[i]);
for( j=0 ; j < HIDDEN ; j++){
for( i=0, sum=0 ; i < OUTPUT ; i++ )
sum += w2[i][j] * delta2[i];
delta1[j] = beta * v[j] * (1-v[j]) * sum;
}
for( i=0 ; i < OUTPUT ; i++ )
for( j=0 ; j < HIDDEN+1 ; j++)
d_w2[i][j] += delta2[i] * v[j];
for( j=0 ; j < HIDDEN ; j++ )
for( k=0 ; k < INPUT+1 ; k++ )
d_w1[j][k] += delta1[j] * xi[k];
}
double calc_error()
{
double E=0;
int i;
for( i=0 ; i < OUTPUT ; i++ )
E += (zeta[i]-o[i]) * (zeta[i]-o[i]);
return E;
}
void modify_w()
{
int i,j,k;
for( i=0 ; i < OUTPUT ; i++ )
for( j=0 ; j < HIDDEN+1 ; j++ ){
d_w2[i][j] = eta * d_w2[i][j] + alpha * pre_dw2[i][j];
w2[i][j] = w2[i][j] + d_w2[i][j];
}
for( j=0 ; j < HIDDEN ; j++)
for( k=0 ; k < INPUT+1 ; k++ ){
d_w1[j][k] = eta * d_w1[j][k] + alpha * pre_dw1[j][k];
w1[j][k] = w1[j][k] + d_w1[j][k];
}
}
void w_print()
{
int i,j,k;
printf("Weight\n");
for(j =0 ; j < HIDDEN ; j++ ){
printf("w1[%d]={",j);
for( k=0 ; k < INPUT ; k++ ){
if( k != 0 )
putchar(',');
printf("%.6lf",w1[j][k]);
}
printf("} theta1[%d]=%.6lf\n",j,w1[j][k]);
}
for( i=0 ; i < OUTPUT ; i++ ){
printf("w2[%d]={",i);
for( j=0 ; j < HIDDEN ;j ++ ){
if( j != 0 )
putchar(',');
printf("%.6lf",w2[i][j]);
}
printf("} theta2[%d]=%.6lf\n",i,w2[i][j]);
}
}
double sigmoid(double u)
{
return 1.0 / (1.0+exp(-beta*u));
}
この回答への補足
ファイルから読み込むデータはデータ作成の仕方によって良いデータ、悪いデータ
というのができてしまうと思うのですが、良いデータ、悪いデータの見極め方はないでしょうか?
No.2
- 回答日時:
このプログラムを例にすると、エラー値が出ます。
ゼロに近ければ成功です。
入力
x1 x2xor
000
011
101
110
出力
デフォールトでは500回に1回学習の途中結果が次のように表示されます。
Input 1 1 ( 0 ) <- パターン0の入力と教師信号
Output 0.0127 <- その入力を入れたとき実際の出力
Input 1 0 ( 1 ) <- パターン1の入力と教師信号
Output 0.9863 <- その入力を入れたとき実際の出力
Input 0 1 ( 1 ) <- パターン2の入力と教師信号
Output 0.9856 <- その入力を入れたとき実際の出力
Input 0 0 ( 0 ) <- パターン3の入力と教師信号
Output 0.0134 <- その入力を入れたとき実際の出力
500 9.205321e-05 <- 現在の学習回数とエラー値
学習が終了すると次のように表示されます。
Time = 2277 <- かかった学習の回数
Input 1 1 ( 0 ) <- パターン0の入力と教師信号
Output 0.0041 <- その入力を入れたときの出力
Input 1 0 ( 1 ) <- 以下各パターンについて同じ
Output 0.9958
Input 0 1 ( 1 )
Output 0.9949
Input 0 0 ( 0 )
Output 0.0044
E = 9.998749e-06 <- 最終的なエラー値(ゼロに近い値なはず)
Weight <- 学習終了後の重みと閾値
w1[0]={3.914727,-3.845814} theta1[0]=-2.108753
w1[1]={3.554649,-3.407200} theta1[1]=1.706878
w2[0]={5.865372,-5.808394} theta2[0]=2.825594
お探しのQ&Aが見つからない時は、教えて!gooで質問しましょう!
似たような質問が見つかりました
- IT・エンジニアリング 大規模言語モデルは今後どのように進化していくでしょうか? 1 2023/07/20 19:17
- その他(プログラミング・Web制作) python コードについて(初学者です) 3 2023/07/20 14:44
- 宇宙科学・天文学・天気 AIが答えた方程式 1 2023/02/20 00:12
- AJAX Pythonを無料(安価)で学ぶ方法ってありますか? 4 2023/08/11 17:23
- その他(コンピューター・テクノロジー) ChatGPTの学習データのサイズは800GBらしいです 受付中 1 2023/03/23 21:15
- Excel(エクセル) マクロか関数で処理したいのですが、教えて頂けませんか。 8 2022/10/31 15:18
- 物理学 大学で物理学を学んでいる人、大学で物理を学んでいた人へ質問です。 私は現在大学1年で、物理を学んでい 2 2022/10/03 20:00
- 大学受験 推薦入試について教えていただきたいことがあります。 私は、この春高校三年生になります。進路について考 1 2022/04/05 02:04
- 統計学 統計学の質問【帰無仮説】 高校の新学習指導要領では、統計的仮説検定の基本的な考え方が必修単元となった 5 2023/05/23 21:00
- 介護福祉士・ケアマネージャー・社会福祉士 介護関係者の方に真面目な相談です!! 大変悩んでます…。。 介護専門学生ですが訪問介護に3日実習いき 1 2022/05/23 16:56
おすすめ情報
デイリーランキングこのカテゴリの人気デイリーQ&Aランキング
-
マイナスからプラスへ転じた時...
-
エクセルで可視セルにのみ値貼...
-
2÷3などの余りについて
-
Aの値からBの値を除するとは??
-
20'(角度)の計算がわかりま...
-
値差の%計算方法について
-
EXCELの分散分析表のP-値が....
-
Excelで1つしかない値だけを抽...
-
10%引いた元の数字を出すには?
-
ある商品のロス率を5%見込み、...
-
yはxに比例し、x=-2のとき、y=-...
-
パーセントの出し方を教えて下さい
-
信頼区間の1.96や1.65ってどこ...
-
変数とパラメータとは違うもの...
-
折れ線グラフの下の面積の求め方
-
楕円の外周の計算方法
-
エクセル 3つの値の中からデー...
-
A÷B=X と (A/4)÷(B/4)=X でX...
-
勾配曲線とは何ですか?
-
「n進法から10進法への変換」
マンスリーランキングこのカテゴリの人気マンスリーQ&Aランキング
-
マイナスからプラスへ転じた時...
-
2÷3などの余りについて
-
信頼区間の1.96や1.65ってどこ...
-
Excelで1つしかない値だけを抽...
-
変数とパラメータとは違うもの...
-
「Aに対するBの割合」と「Aに対...
-
10C7 =10.9.8.7.6.5.4/7.6.5.4...
-
0 <= ある値Aのある値B乗 <= あ...
-
20'(角度)の計算がわかりま...
-
ある商品のロス率を5%見込み、...
-
中学数学 代表値について
-
Aの値からBの値を除するとは??
-
教えてください。数学Bの二項分...
-
エクセルで可視セルにのみ値貼...
-
a^2の√=a が成り立たない場合
-
比と比の値について。 a:b=a/b ...
-
EXCELの分散分析表のP-値が....
-
値差の%計算方法について
-
10%引いた元の数字を出すには?
-
パーセントの出し方を教えて下さい
おすすめ情報