CNN框架初探

CNN

很久没更新博客了,前段时间面试上了字节,然后考了很多考试,所以一直没时间做自己的东西,现在终于放暑假了,更新下上个月的一些小成果吧。
之前一直在寻找java的cnn框架,没找到,只找到了一个很老的深度学习框架,简单用了下,拟合个比较简单的东西还是比较方便的,但是貌似没有卷积池化这样的操作,所以不是特别适合做计算机视觉。
其实实际上java确实不是特别适合做这方面的东西,也可能是我对java的GPU操作不太熟悉,还没涉及到,但我又想快点写出来一个cnn,所以暂时只用到了CPU,在之后有机会的话呢我看看如果有对应的java操作GPU的文档和相对的分布式开发的话呢,我会认真更新下这个框架的。
目前我只支持了:

  1. 卷积层
  2. 池化层
  3. relu层
  4. dropout层
  5. 全连接层
  6. softmax和SVM

tanh和sigmod还没写,使用感觉上感觉没有relu好用,就暂时没写上去。

神经网络

这是我先前写的一个框架,我当时写的时候没用到矩阵,而是真的用的是神经网络的思路,真的是神经元,线,层,网络作为java bean,然后生成,不详细介绍如何使用了,算是一个废稿。

package Main;

import java.io.FileNotFoundException;
import java.io.IOException;

import bean.Layer;
import bean.Net;
import utils.NetSave;

public class Start {
	public static String FileName = "D:/dat.save";
	
	public static void main(String[] args) {
		try {
			creat();
		} catch (IOException e) {
			e.printStackTrace();
		}
		System.out.println("------------导入神经网络------------");
		try {
			input();
		} catch (ClassNotFoundException | IOException e) {
			e.printStackTrace();
		}
		
	}
	public static void input() throws ClassNotFoundException, IOException {
		Net n = NetSave.input(FileName);
		double[] a2 = n.test(new double[] {1.0,1.0});
		for(double c : a2)System.out.println("预期结果:0\t 输出结果:"+c);
		a2 = n.test(new double[] {0.0,1.0});
		for(double c : a2)System.out.println("预期结果:1\t 输出结果:"+c);
		a2 = n.test(new double[] {1.0,0.0});
		for(double c : a2)System.out.println("预期结果:1\t 输出结果:"+c);
		a2 = n.test(new double[] {0.0,0.0});
		for(double c : a2)System.out.println("预期结果:0\t 输出结果:"+c);
	}
	public static void creat() throws FileNotFoundException, IOException {
		Net n = new Net();//建立一个神经网络
		n.AddLayer(new Layer(2));//输入层
		n.AddLayer(new Layer(10));//隐藏层
		n.AddLayer(new Layer(20));//隐藏层
		n.AddLayer(new Layer(1));//输出层
		n.connect();//连接神经网络
		//没训练前的结果
		double[] a1 = n.test(new double[] {1.0,0.0});
		for(double c : a1)System.out.println("没训练前的结果:"+c);
		//训练输入数据
		double[] one = new double[]{0.0,0.0};
		double[] two = new double[]{1.0,0.0};
		double[] three = new double[]{0.0,1.0};
		double[] four = new double[]{1.0,1.0};
		//训练输出结果
		double[] res1 = new double[]{0.0};
		double[] res2 = new double[]{1.0};
		System.out.println("开始误差:"+n.train(two, res2));
		long startTime = System.currentTimeMillis();
		//开始训练
		for(int i = 0;i < 1000000;i++) {
			n.train(one,res1);
			n.train(two, res2);
			n.train(three, res2);
			n.train(four, res1);
		}
		System.out.println("结束误差:"+n.train(two, res2));
		System.out.println("训练所用时间:"+(double)(((System.currentTimeMillis()-startTime)/100)%60)/10+"s");
		 //测试训练结果
		double[] a2 = n.test(new double[] {1.0,1.0});
		for(double c : a2)System.out.println("预期结果:0\t 输出结果:"+c);
		
		a2 = n.test(new double[] {0.0,1.0});
		for(double c : a2)System.out.println("预期结果:1\t 输出结果:"+c);
		
		a2 = n.test(new double[] {1.0,0.0});
		for(double c : a2)System.out.println("预期结果:1\t 输出结果:"+c);
		
		a2 = n.test(new double[] {0.0,0.0});
		for(double c : a2)System.out.println("预期结果:0\t 输出结果:"+c);
		
		//保存
		NetSave.save(FileName, n);
	}
}

我测试的主要是对XOR这个模型的拟合,支持这个网络的导入导出,其实写的还是可以的,但是大网络肯定是有点困难的,这里给出输出的内容:

没训练前的结果:0.7155445521857534
开始误差:0.04045745089560178
结束误差:9.991615427844533E-6
训练所用时间:1.5s
预期结果:0	 输出结果:0.004821836413339121
预期结果:1	 输出结果:0.9959640239707258
预期结果:1	 输出结果:0.9955297659791846
预期结果:0	 输出结果:0.003557716622898065
------------导入神经网络------------
预期结果:0	 输出结果:0.004821836413339121
预期结果:1	 输出结果:0.9959640239707258
预期结果:1	 输出结果:0.9955297659791846
预期结果:0	 输出结果:0.003557716622898065

可以看到
一开始的误差是0.7,到最后能够到0.0000009,这个已经算是很好的了,用了1.5s,也算是比较快的了,而且到最后输出的结果和预期很接近了。
但是由于不太适合做一些大的网路,所以我就没有继续写这个框架了,转而编写了下面的这个卷积神经网络。

卷积神经网络

写完上面的那个简单的网络之后我就想着快点写一下这个卷积神经网络的框架,但是下手比较难,我就去学习了相关的课程之后写了一个C++版本的,后来才转到了java版本,沿用了caffe的思路,两天时间不到就把这个小框架写出来了,对于简单的网络来说拟合起来还是比较简单的,但是对于大型的网络,速度确实很欠缺而且震荡还是有一定的偏大,可能是我的模型不是很好的缘故,暂时先看我们最经典入门的手写数字集合的识别是如何实现的。
首先我们需要准备需要训练的模型以及数据。
数据集的话呢我使用的是mnist的手写训练集


然后我们准备一个模型填充我们的数据,首先需要明确一点就是我们需要填充的是一个四个维度的矩阵,四个维度分别是:

  1. 训练数据的个数
  2. 训练数据的通道
  3. 训练数据的高
  4. 训练数据的宽
    第一个是我们有多少个样本,第二个是我们样本的通道数(有rgb三通道和灰度图像这样子的),宽高就是图片的宽高了,填充方法如下:
 public static Blob getImages(String fileName) {
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) {                        // 读取魔数
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);           // 读取样本总数
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每行所含像素点数
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每列所含像素点数
                Blob res = new Blob(number, 1,yPixel, xPixel, Cube.FILLZEZORS);
                //Blob res = new Blob(number, 3,yPixel, xPixel, Cube.FILLZEZORS);
                for (int i = 0; i < number; i++) {
                	for(int y = 0;y<yPixel;y++) {
                		for(int x = 0;x<xPixel;x++) {
                			res.Get(i).set(0, y, x, bin.read() / 255.0);
                			//res.Get(i).set(1, y, x, 1);
                			//res.Get(i).set(2, y, x, 0);
                		}
                	}
                }
                return res;
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * get labels of `train` or `test`
     *
     * @param fileName the file of 'train' or 'test' about label
     * @return
     */
    public static Blob getLabels(String fileName) {
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) {
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);
                Blob res = new Blob(number, 10,1, 1, Cube.FILLZEZORS);
                for (int i = 0; i < number; i++) {
                    res.Get(i).set(bin.read(), 0, 0, 1);
                }
                return res;
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

分别需要填充到输入和输出。
输入的话呢是60000 * 1 * 28 * 28,这里需要注意的是灰度值我除了255.0
输出的话呢用one-hot,因为10种可能(0-9),所以就是60000 * 10 * 1 * 1
是个通道对应着softmax的是个输入输出。
准备完了这些之后我们需要准备我们训练的模型:

{ 
	"train":{
        "learning rate" : 0.001,
        "lr decay": 0.9999,
        "optimizer": "rmsprop",
        "momentum coefficient": 0.95,
        "rmsprop decay": 0.99,
        "reg coefficient": 0.2,
        "num epochs": 1,
        "batch size": 50,
        "lr update": true,
        "snapshot": false,
        "snapshot interval":20,
        "fine tune": false,
        "pre train model": "./test.modle"
	},
	"net":[
		{
			"name":"conv1",
			"type":"Conv",
			"kernel num":3,
			"kernel height":5,
			"kernel width":5,
			"pad":2,
			"stride":1,			
			"conv weight init": "msra"
		},
		{
			"name":"pool1",
			"type":"Pool", 
			"kernel height":3,
			"kernel width":3,
			"stride":2
		},
		{
			"name":"relu1",
			"type":"Relu" 
		},
		{
			"name":"drop1",
			"type":"Dropout",
			"drop rate": 0.5
		},
		{
			"name":"fc2",
			"type":"Fc",   
			"kernel num":10,
			"fc weight init":"msra"
		},
		{
			"name":"softmax",
			"type":"Softmax"	
		}
	]      
}

快照啥的我还没实现,主要是有点懒。。。

  • 学习率:0.001
  • 学习率衰减0.9999,给个接近一的数就好了,不用衰减的很大
  • 参数更新使用rmsprop,当然我还提供了momentum和sgd
  • 训练样本1次
  • 每个batch size有50个样本
  • 下面就是各种层了

有了模型之后我们就需要加载这个模型:

	NetParam param = new NetParam();
	try {
		param = JsonTools.GetCnnParam("myModel_cnnfirst.json");
	} catch (Exception e) {
		System.out.println("CNN Param read Error!");
	}

我提供了个静态方法专门用来导入。
填充数据:

Blob train_images = Mnist.getImages("mnist/train/train-images.idx3-ubyte");
Blob train_labels = Mnist.getLabels("mnist/train/train-labels.idx1-ubyte");

初始化网络:

Net net = null;
try {
	net = new Net(param,train_images,train_labels);
} catch (Exception e) {
	e.printStackTrace();
	System.out.println("Net init error!");
}

写的比较草,所以只要了训练集。
在训练之前我们手写几个数字:


然后我们直接输出他的可能性:

System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\0.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\1.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\2.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\3.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\4.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\5.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\6.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\7.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\8.bmp")))[0]));
System.out.println(GetString(SoftmaxLayer.softmax_get_(net.test( Graph.getBlob("C:\\Users\\a\\Desktop\\test\\9.bmp")))[0]));

Graph.getBlob得到我们单个图片的输入,使用net.test得到需要测试的输入数据的输出值,然后传递给softmax进行概率计算,返回的是一个double类型的数组(就是每个输出的概率值),然后调用自己的GetString得到字符串输出。

public static String GetString(double[] g) {
	int index = 0;
	double max = g[0];
	for(int i = 1;i < g.length;i++) {
		if(g[i] > max) {
			max = g[i];
			index = i;
		}
	}
	String res = "预测最大可能的类别:"+index+"\t可能性:"+max;
	return res;
}

有了这些之后我们就可以调用

net.train();//训练

训练就开始了,训练期间会输出损失值,正确率和进度
训练完毕之后就可以再预测一下,看下结果,我这里简单贴一下输出的内容:

预测最大可能的类别:1	可能性:0.15088678841313677
预测最大可能的类别:7	可能性:0.13096024361576306
预测最大可能的类别:9	可能性:0.16541533979038608
预测最大可能的类别:9	可能性:0.15770358849700633
预测最大可能的类别:1	可能性:0.1580772891202943
预测最大可能的类别:9	可能性:0.1644280197622801
预测最大可能的类别:9	可能性:0.16054313587429986
预测最大可能的类别:9	可能性:0.1315615988557002
预测最大可能的类别:9	可能性:0.15567194300064927
预测最大可能的类别:9	可能性:0.16451444471281498
当前训练次数:0	損失值:2.3914618339620244	准确率14.000000000000002%	当前进度:0.0%
当前训练次数:1	損失值:2.300239870224877	准确率12.0%	当前进度:0.08333333333333334%
当前训练次数:2	損失值:2.2317275920680752	准确率24.0%	当前进度:0.16666666666666669%
当前训练次数:3	損失值:2.242225818126596	准确率20.0%	当前进度:0.25%
当前训练次数:4	損失值:2.171027707628599	准确率20.0%	当前进度:0.33333333333333337%
当前训练次数:5	損失值:2.1067916567748552	准确率22.0%	当前进度:0.4166666666666667%
当前训练次数:6	損失值:2.0932036593437715	准确率34.0%	当前进度:0.5%
当前训练次数:7	損失值:2.0302915435219466	准确率32.0%	当前进度:0.5833333333333334%
当前训练次数:8	損失值:1.8635343803725808	准确率50.0%	当前进度:0.6666666666666667%
当前训练次数:9	損失值:1.9585453836591684	准确率32.0%	当前进度:0.75%
当前训练次数:10	損失值:1.978877139528592	准确率22.0%	当前进度:0.8333333333333334%
当前训练次数:11	損失值:1.9902384241645896	准确率24.0%	当前进度:0.9166666666666666%
当前训练次数:12	損失值:2.0420873144096543	准确率26.0%	当前进度:1.0%
当前训练次数:13	損失值:1.6897064679384752	准确率54.0%	当前进度:1.0833333333333335%
当前训练次数:14	損失值:1.6926356764617654	准确率60.0%	当前进度:1.1666666666666667%
当前训练次数:15	損失值:1.7189233098386958	准确率46.0%	当前进度:1.25%
当前训练次数:16	損失值:1.480252896290082	准确率60.0%	当前进度:1.3333333333333335%
当前训练次数:17	損失值:1.5389264293606228	准确率54.0%	当前进度:1.4166666666666665%
当前训练次数:18	損失值:1.626967065007884	准确率52.0%	当前进度:1.5%
......中间内容省略......
当前训练次数:1195	損失值:0.19112057434706473	准确率96.0%	当前进度:99.58333333333333%
当前训练次数:1196	損失值:0.05352951050307908	准确率100.0%	当前进度:99.66666666666667%
当前训练次数:1197	損失值:0.0803565317731044	准确率100.0%	当前进度:99.75%
当前训练次数:1198	損失值:0.5872361406379709	准确率84.0%	当前进度:99.83333333333333%
当前训练次数:1199	損失值:0.10712553062810029	准确率100.0%	当前进度:99.91666666666667%
预测最大可能的类别:0	可能性:0.9926752540760908
预测最大可能的类别:0	可能性:0.9967641949574358
预测最大可能的类别:1	可能性:0.9918003449031905
预测最大可能的类别:2	可能性:0.9924374734570615
预测最大可能的类别:3	可能性:0.9849900242379964
预测最大可能的类别:4	可能性:0.9893327948873324
预测最大可能的类别:5	可能性:0.9545521357289892
预测最大可能的类别:6	可能性:0.9961029307767508
预测最大可能的类别:7	可能性:0.9998860249502625
预测最大可能的类别:8	可能性:0.9756162928779206
预测最大可能的类别:9	可能性:0.9679179856928872

可以看到我们一开始的预测是错误的,并且概率都在0.1左右,因为十分类,概率都是初始化都是0.1的,并且我们一开始的损失是2.3,原因是 -log(0.1) 就是2.3,随着我们的训练(大约半分钟吧),最终我们对预测的结果都是非常准确的,叨叨了0.999左右,这才是训练了一次训练集,训练个两三次差不多对这个模型就能拟合的差不多。
但是经过我对一些比较复杂的图片的识别测试发现,有的时候损失值一直不怎么下降,这也是我现在还没发出来的原因,我准备找找原因之后,先把CPU这个版本发出来,主要发亮部分:

  1. java代码框架
  2. 工具

这个工具主要由两部分组成,我前段时间编写的那个Cheetah脚本语言作为数据输入,然后驱动后面的CNN框架进行训练,前面用一个语言进行输入是为了提高工具的灵活性,比如可以从网络中进行数据获取之类的。

2 Likes