2015年8月1日土曜日

C言語でニューラルネット(2次元平面での分離)

前回に引き続き,ニューラルネットワークです.

前回は,ニューラルネットによる関数近似を行いましたが,今回はちょっとパターン認識寄り.
というのも,2次元平面上のある点が,AとBどちらに分類されるのかという問題なためです.

たとえば,以下の様な状況の場合.

2D

これを2種類に分類したいとします.
理想的な場合,このように分類できると想像できると思います.

2D2

先ほどの4点を与えたとき,大体の人はこのように線を引いて分けると思います.
このような1本の線では分離ができない問題のことを,線形分離不可能な問題と言ったりします.

線形分離可能な例はこんな感じ.

2D3

1本の直線で,分類できています.

線形分離可能な問題は,単純パーセプトロンなどでも解けますが,線形分離不可能の場合には,隠れ層を持つようなニューラルネットワークの構成である必要があります.

このような隠れ層を持つニューラルネットワークの学習方法のひとつが,誤差逆伝播法(Back Propagation)です.

というわけで,2次元での分類を行っていきます.

まず,2次元の入力(x軸,y軸)に対して出力(●,×)があるので,入力次元は2,出力次元は1です.

data.datrho.dat
001
010
100
110

まずは,線形分離可能なパターンから試してみます.
ここで,data.datの左側がx軸,右側がy軸で,rho.datの0が●,1が×を表しています.
よって,このデータは先ほどの直線で分離できるデータを数字で表したものとなっています.

これをニューラルネットワークに入力として与えると,以下のようになります.

2D4

このように,確りと分離できていることがわかります.
プログラム中では,positive.datに0.5以上(厳密には超過)の出力群を,negative.datに0.5以下の出力群をファイル出力しています.

次に,線形分離不可能の場合.

data.datrho.dat
001
010
100
111

これは最初の図に対応します.
このときの出力は,以下のようになります.

2D5

左下と右上で,×になっています.
ここからわかるように,非線形の分離問題(EX-OR問題)に対しても分離を行えていることがわかります.

プログラムリストは,前回の物から少し変えてあります.(主にTest関数内)
以下リスト.

#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<time.h>

#define Input_size 4
#define Input_dim 2
#define Hidden_dim 10
#define alpha 10
#define eta 0.8
#define test_size 10
#define th 0.5

#define INPUT_NAME "data.dat"
#define RHO_NAME "rho.dat"

double a[Input_size][Input_dim] = { 0 }, rho[Input_size] = { 0 };
double v[Hidden_dim][Input_dim + 1] = { 0 }, w[Hidden_dim + 1] = { 0 };
double I[Input_dim + 1] = { 0 }, H[Hidden_dim + 1] = { 0 };
double E[Input_size] = { 0 };
double O = 0;
double O_out[Input_size][Input_dim + 1] = { 0 }, test_out[(test_size+1)*(test_size+1)][Input_dim + 1];
double dk, dj[Hidden_dim] = { 0 };

void Loading();
void Init();
void Display();
void Test();
void Save();

int main(void)
{
    int i, j, k, t;
    double temp, E_max = 0;

    Loading();
    Init();

    Display();

    //main loop
    k = 0;
    while (1){
        k++;

        for (t = 0; t < Input_size; t++){
            for (i = 0; i < Input_dim; i++){
                I[i + 1] = a[t][i];
            }

            for (i = 0; i < Hidden_dim; i++){
                temp = 0;
                for (j = 0; j < Input_dim + 1; j++){
                    temp += I[j] * v[i][j];
                }
                H[i + 1] = 1 / (1 + exp(-alpha*temp));
            }

            temp = 0;
            for (i = 0; i < Hidden_dim + 1; i++){
                temp += H[i] * w[i];
            }

            O = 1 / (1 + exp(-alpha*temp));

            E[t] = 1 / 2.0*pow((rho[t] - O), 2);

            dk = -(rho[t] - O)*O*(1 - O);

            for (i = 0; i < Hidden_dim; i++){
                dj[i] = dk*w[i + 1] * H[i + 1] * (1 - H[i + 1]);
            }

            for (i = 0; i < Hidden_dim + 1; i++){
                w[i] -= eta*dk*H[i];
            }

            for (i = 0; i < Hidden_dim; i++){
                for (j = 0; j < Input_dim + 1; j++){
                    v[i][j] = v[i][j] - eta*dj[i] * I[j];
                }
            }

            for (i = 0; i < Input_dim + 1; i++){
                if (i == Input_dim){
                    O_out[t][i] = O;
                }
                else{
                    O_out[t][i] = a[t][i];
                }
            }
        }
        /*
        E_max=0;
        for (i=0;i<Input_size;i++){
        if(fabs(E[i])>E_max){
        E_max=fabs(E[i]);
        }
        }

        if(E_max<1e-5){
        printf("\nk=%d\nE_max=%f\n",k,E_max);
        for (i=0;i<Input_size;i++){
        printf("%f\n",E[i]);
        }
        printf("\nO_out=\n");
        for (i=0;i<Input_size;i++){
        for (j=0;j<Input_dim+1;j++){
        printf("%f\t",O_out[i][j]);
        }
        printf("\n");
        }
        break;
        }
        */
        if (k >= 1e6){
            printf("\nk=%d\n", k);
            for (i = 0; i < Input_size; i++){
                printf("%f\n", E[i]);
            }
            printf("\nO_out=\n");
            for (i = 0; i < Input_size; i++){
                for (j = 0; j < Input_dim + 1; j++){
                    printf("%f\t", O_out[i][j]);
                }
                printf("\n");
            }
            break;
        }
    }

    Test();
    Save();
    return 0;
}

void Loading(){
    FILE *fp;
    int i, j;

    //Input a
    if (fopen_s(&fp, INPUT_NAME, "r") != 0){
        printf("INPUT FILE open error!\n");
        exit(-1);
    }

    for (i = 0; i < Input_size; i++){
        for (j = 0; j < Input_dim; j++){
            fscanf_s(fp, "%lf ", &a[i][j]);
        }
    }
    fclose(fp);

    //Input rho 
    if (fopen_s(&fp, RHO_NAME, "r") != 0){
        printf("RHO FILE open error!\n");
        exit(-1);
    }

    for (i = 0; i < Input_size; i++){
        fscanf_s(fp, "%lf ", &rho[i]);
    }
    fclose(fp);
}

void Init(){
    int i, j;

    I[0] = 1;
    H[0] = 1;

    srand((unsigned int)time(NULL));

    for (i = 0; i < Hidden_dim; i++){
        for (j = 0; j < Input_dim + 1; j++){
            v[i][j] = 0.1*rand() / (double)RAND_MAX;
        }
    }

    for (i = 0; i < Hidden_dim + 1; i++){
        w[i] = 0.1*rand() / (double)RAND_MAX;
    }
}

void Display(){
    int i, j;

    printf("a=\n");
    for (i = 0; i < Input_size; i++){
        for (j = 0; j < Input_dim; j++){
            printf("%f", a[i][j]);
        }
        printf("\n");
    }

    printf("\nrho=\n");
    for (i = 0; i < Input_size; i++){
        printf("%f\n", rho[i]);
    }

    printf("\nv=\n");
    for (i = 0; i < Hidden_dim; i++){
        for (j = 0; j < Input_dim + 1; j++){
            printf("%f\t", v[i][j]);
        }
        printf("\n");
    }
}

void Test(){
    int i, j, p, q;
    double temp;

    for (p = 0; p < test_size + 1; p++){
        for (q = 0; q < test_size + 1; q++){
            I[1] = (double)p/test_size;
            I[2] = (double)q/test_size;

            for (i = 0; i < Hidden_dim; i++){
                temp = 0;
                for (j = 0; j < Input_dim + 1; j++){
                    temp += I[j] * v[i][j];
                }
                H[i + 1] = 1 / (1 + exp(-alpha*temp));
            }

            temp = 0;
            for (i = 0; i < Hidden_dim + 1; i++){
                temp += H[i] * w[i];
            }

            O = 1 / (1 + exp(-alpha*temp));

            for (i = 0; i < Input_dim + 1; i++){
                if (i == Input_dim){
                    test_out[p*(test_size+1)+q][i] = O;
                }
                else{
                    test_out[p*(test_size+1)+q][0] = (double)p/test_size;
                    test_out[p*(test_size+1)+q][1] = (double)q/test_size;
                }
            }
        }
    }
}

void Save(){
    int i, j;
    FILE *fp;
    FILE *p, *n;

    if (fopen_s(&fp, "./O_out.dat", "w") != 0){
        printf("O_out.dat open Error!\n");
        exit(-1);
    }

    for (i = 0; i < Input_size; i++){
        for (j = 0; j < Input_dim; j++){
            fprintf(fp, "%f\t%f\n", a[i][j], rho[i]);
        }
    }

    fclose(fp);

    if (fopen_s(&p, "./positive.dat", "w") != 0){
        exit(-1);
    }
    if (fopen_s(&n, "./negative.dat", "w") != 0){
        exit(-1);
    }

    for (i = 0; i < (test_size + 1)*(test_size + 1); i++){
        for (j = 0; j < Input_dim + 1; j++){
            if (test_out[i][Input_dim] > th){
                fprintf(p, "%f\t", test_out[i][j]);
            }
            else {
                fprintf(n, "%f\t", test_out[i][j]);
            }
        }
        if (test_out[i][Input_dim] > th){
            fprintf(p, "\n");
        }
        else {
            fprintf(n, "\n");
        }
    }
    fclose(fp);
}

0 件のコメント:

コメントを投稿