00001 #ifndef __IcaNet1__
00002 #define __IcaNet1__
00003
00028 #include "RealVector.hpp"
00029
00030 class IcaNet1
00031 {
00032 private:
00033
00034 bool combined;
00035 bool superGaussian;
00036
00037 double learnRate;
00038
00039 RealVector* w;
00040 double wNorm2;
00041 RealVector* dw;
00042
00043 protected:
00044
00045
00046 virtual double gMin(double u){ return (u - u*u*u); }
00047 virtual double gPlus(double u)
00048 {
00049 wNorm2= w->vT_v();
00050 double wNorm4= wNorm2*wNorm2;
00051 return (-u*wNorm4 + u*u*u);
00052 }
00053
00054 virtual void learn(RealVector* sample);
00055
00056 public:
00057
00058
00059 IcaNet1(double rate, bool superType)
00060 {learnRate=rate; superGaussian= false; w=0; dw=0; combined= false;}
00061 IcaNet1(double rate, bool superType, bool combineFlag)
00062 {learnRate=rate; superGaussian= false; w=0; dw=0; combined= combineFlag; }
00063 virtual ~IcaNet1()
00064 {
00065 if( w != 0 )
00066 {
00067 delete w;
00068 delete dw;
00069 }
00070 }
00071
00072 virtual RealVector* getW(){return w;}
00073
00074
00075 virtual void init(int size);
00076 virtual void learning(RealVector* sample, int nmax=1);
00077
00078
00079 virtual double forward(RealVector* src);
00080
00081 virtual void output()
00082 {
00083 printf("Single ICA Net");
00084 }
00085 };
00086
00087 #endif