logo资料库

径向基网络.doc

第1页 / 共12页
第2页 / 共12页
第3页 / 共12页
第4页 / 共12页
第5页 / 共12页
第6页 / 共12页
第7页 / 共12页
第8页 / 共12页
资料共12页,剩余部分请下载后查看
一 径向基网络结构和原理
二 径向基网络的学习算法
三 径向基网络具体实现方法
四 实验仿真
实验结果
五 附件:
附件1
附件2 :实验数据源
一 径向基网络结构和原理 径向基网络的拓扑结构图如下所示,其网络有三层构成,第一层是输入 层,第二层是隐含层,第三层是输出层。采用径向基函数(常用高斯函数) 作为基函数,将输入向量空间转换到隐含层空间,实现对原问题的线性可 分。 图 1-径向基网络结构 径向基网络核心是隐含层采用了径向基函数,它计算的是输入向量和基 函数中心之间的欧式距离,而不是输入向量与权值的内积。基函数一般采 用的是高斯函数。 径向基函数其实就是沿径向对称的变量函数,通常定义为空间中任意一 点 x 到某一中心 c 的之间欧式距离的单调函数。当 x 远离 c 时取得很小的 值,当 x 越靠近中心 c 时取得值越大,在中心 c 出取得最大的值。现已证 明,在一定的条件下,径向基φ(||x-c||)可以逼近几乎所有的函数。
图 2-高斯函数 径向基函数最早是用来解决插值处理问题,具体的是在每个样本上面放 一个高斯函数,函数的中心处在样本点上,如下图所示,然后假设真实的 拟合训练数据的是蓝色的。对于新的数据 x1,f(x1)的值等于 b 点的坐标加 上 c 点的坐标,b 点的纵坐标的是第一个高斯函数的值乘以一个大点的权值, c 的纵坐标是第二个高斯函数的值乘以小点的权值得到,其他样本点的值都 是 0.所以 x1 点的函数值由附近的 b 和 c 确定,拓展到任意新的 x,红色高 斯函数乘以一个权值后在对应 x 的地方加起来就可以拟合真实的函数曲线。 二 径向基网络的学习算法 根据上面的网络的结构,可以知径向基网络需要训练的有三个参数①隐 含层中基函数的中心②隐含层中基函数的标准差③隐含层与输出层间的权 值 常见的学习算法有随机选取固定中心、自组织选取中心和有监督选取中 心随机选取中心只需要训练隐含层和输出层之间的权值,其他参数都是固 定的。 自组织选取中心需要训练的是基函数的中心和隐含层和输出层的权值, 基函数中心的选取常采用聚类的方法。 有监督选取中心上面的 3 个参数都是通过监督学习获得的。 三 径向基网络具体实现方法
采用监督学习算法对网络所有的参数(径向基函数的中心、方差和隐含 层到输出层的权值)进行训练,主要是对代价函数(均方误差)进行梯度 下降,然后修正每个参数。具体如下: (1)随机初始化径向基函数的中心、方差和隐含层到输出层的权值。 (2)通过梯度下降来对网络中的三种参数都进行监督训练优化。代价 函数是网络输出和期望输出的均方误差: 具体核心步骤: (1) 定义代价函数 (2) 修改基函数中心 (3) 修改标准差 (4) 修改权值 四 实验仿真 实验结果 中心数 30 30 30 35 120 30 学习率 0.025 0.025 0.025 0.025 0.025 0.1 迭代次数 1500 2000 3000 3000 3000 3000 准确度 80%-85.7143% 74.2857%-88.5741% 74.2857%-94.2857% 80%-85% 80%-82.8571% 80%-88.5714%
五 附件: 附件 1 learnRBF.m 文件 close all;clear;clc; % train data num trainNumSamples=115; % test data num testNumSamples=35; % input data demensions inputSize=4; % output data demensions outputSize=3; %% --------------------------------------------------------------- %% ----------------step 0: load data and initial input data------- display('step 0 :load data and initial input data...'); %assign random values in the range [-1, +1] % the learning process % load the training set data and test set data load data.dat % normalize the input data to rang [-1 +1] datanew=data(:,3:6); maxv=max(max(datanew));
minv=min(min(datanew)); datanorm=2*((datanew-minv)/(maxv-minv)-0.5); %train_x=zeros(inputSize,trainNumSamples); % 4 115 %train_y=zeros(outputSize,trainNumSamples); % 3 115 outputData=zeros(150,3); for i=1:150 if(data(i,2)==0) outputData(i,:)=[1 0 0]; elseif(data(i,2)==1) outputData(i,:)=[0 1 0]; outputData(i,:)=[0 0 1]; else end end datanew=[datanorm outputData]; idx=randperm(150); idx=idx(1:trainNumSamples); train=datanew(idx,:); test=datanew; test(idx,:)=[]; train_x=train(:,1:4)'; train_y=train(:,5:7)'; test_x=test(:,1:4)'; test_y=test(:,5:7)'; %% --------------------------------------------------------------- %% ---------------------------step 1:initial rbf data------------- numSamples=size(train_x,2); rbf.inputSize=size(train_x,1); % num of Radial Basis function rbf.hiddenSize=30; rbf.outputSize=size(train_y,1); % learning rate rbf.alpha=0.025; % center of RBF for i=1:rbf.hiddenSize % randomly pick up some samples to initialize center index=randi([1,numSamples]); rbf.center(:,i)=train_x(:,index); end % delta of RBF
rbf.delta =rand(1,rbf.hiddenSize); % weight of RBF r =1.0; % random number betweeen[-r,r] rbf.weight=rand(rbf.outputSize,rbf.hiddenSize)*2*r-r; %% --------------------------------------------------------------- %% --------------------------step 2:start training---------------- display('step 2:start training...'); maxIter=3000; preCost=0; for i=1:maxIter fprintf(1,'Iteration %d,',i); rbf=trainRBF(rbf,train_x,train_y); fprintf(1,'the cost is %d\n ',rbf.cost); curCost= rbf.cost; if abs(curCost-preCost)<1e-8 disp('Reached Iteration termination condition and Termination now !'); break; end preCost=curCost; end %% --------------------------------------------------------------- %% -----------------------------step 3:start testing-------------- display('step 3:start testing...'); output=zeros(outputSize,testNumSamples); Green=zeros(rbf.hiddenSize,1); for i=1:size(test_x,2) Green(j,1)=green(test_x(:,i),rbf.center(:,j),rbf.delta(j)); for j=1:rbf.hiddenSize end output(:,i)=rbf.weight*Green; end end Otest=zeros(3,35); for i=1:size(output,2) Otest(:,i)=output(:,i)>[0.8;0.8;0.8]; desire=zeros(testNumSamples,1); test=zeros(testNumSamples,1); for i=1:testNumSamples if (test_y(1,i)==1) desire(i)=0;
elseif (test_y(2,i)==1) desire(i)=1; desire(i)=2; else end end for i=1:testNumSamples if (Otest(1,i)==1) elseif (Otest(2,i)==1) test(i)=0; test(i)=1; elseif (Otest(3,i)==1) test(i)=2; test(i)=3; else end end % calculate the accuracy of test sets Accuracy=0; for i=1:testNumSamples if (test(i)==desire(i)) Accuracy=Accuracy+1; end end Accuracy=100*Accuracy/testNumSamples; t=['准确度: ' num2str(Accuracy) '%']; disp(t); % plot the output and desired output i=1:testNumSamples; figure;plot(i,desire,'b-',i,test,'r-'); title('有监督选取中心'); legend('期望输出',t); trainRBF.m 文件 function [rbf]=trainRBF(rbf,train_x,train_y) %% ----------------------step1:calculate gradient--------------- numSamples = size(train_x,2); Green=zeros(rbf.hiddenSize,1); output=zeros(rbf.outputSize,1); delta_weight=zeros(rbf.outputSize,rbf.hiddenSize); delta_center=zeros(rbf.inputSize,rbf.hiddenSize); delta_delta=zeros(1,rbf.hiddenSize);
rbf.cost=0; for i=1:numSamples % Feed forward for j=1:rbf.hiddenSize end output=rbf.weight*Green; Green(j,1)=green(train_x(:,i),rbf.center(:,j),rbf.delta(j)); % Back propagation delta3=-(train_y(:,i)-output); rbf.cost=rbf.cost+sum(delta3.^2); delta_weight=delta_weight+delta3*Green'; delta2=rbf.weight'*delta3.*Green; for j=1:rbf.hiddenSize delta_center(:,j)=delta_center(:,j)+delta2(j) .* (train_x(:, i) - rbf.center(:, j)) ./ rbf.delta(j)^2; delta_delta(j) = delta_delta(j)+ delta2(j) * sum((train_x(:, i) - rbf.center(:, j)).^2) ./ rbf.delta(j)^3; end end %% --------------------------step 2 : update parameters--------- rbf.cost=0.5*rbf.cost./numSamples; rbf.weight=rbf.weight-rbf.alpha.*delta_weight./numSamples; rbf.center=rbf.center-rbf.alpha.*delta_center./numSamples; rbf.delta=rbf.delta-rbf.alpha.*delta_delta./numSamples; end green.m文件 function greenValue=green(x,c,delta) greenValue=exp(-1.0*sum((x-c).^2)/(2*delta^2)); end 附件 2 :实验数据源 说明: 第 1 列是行号无实际意义,第 2 列表示蝴蝶花的种类,第 3 列到第 6 列表示花的属性值。 % data.dat 1 0 2 14 33 50 2 1 24 56 31 67 3 1 23 51 31 69 4 0 2 10 36 46
分享到:
收藏