PLA感知学习算法

Wesley13
• 阅读 580
  1 #include <vector>
  2 #include<iomanip>
  3 #include <string>
  4 #include<stdio.h>
  5 #include<string.h> 
  6 #include <fstream>
  7 #include <iostream>
  8 #include<set>
  9 #include<algorithm>
 10 #include<cstdio>
 11 #include<iomanip>
 12 #include<map>
 13 #include<cmath>
 14 #define col 41
 15 #define row 7000
 16 
 17 using namespace std;
 18 double label[8010][80];           //训练集 
 19 double test_label[8010][80];      //测试集
 20 double valition_label[8010][80];  //验证集 
 21 string s[8010]; 
 22 string ss[8010]; 
 23 string s2[8010]; 
 24 
 25 
 26 //logitic函数,将负无穷到正无穷 转化 -1到 1 
 27  
 28 double logistic(double n){
 29     
 30    return 1/(1+exp(-1.0*n));
 31     
 32 }
 33 
 34 double geterror(int n){
 35     return 0.001;
 36 }
 37 
 38 double cut_t(string s, int t){
 39     string str = s;
 40     int r = 0;
 41     double count = 0.0;
 42     bool flag = true;
 43     double flag1 = 1.0;
 44     double sum = 0.0;
 45     for(int i=0;i<str.length();i++){
 46         
 47         if(r==t && str[i] == '-'){
 48             flag1 = -1;
 49             continue;
 50         }
 51         if(str[i]==','){
 52             r++;
 53             continue;
 54         }
 55         if(r==t){
 56             if(flag == false){
 57                 count ++;
 58             }
 59             if(str[i] == '.'){
 60                 flag = false;
 61             }
 62             else {
 63                 sum = sum + (str[i] - '0') * 1.0;
 64                 sum = sum * 10;
 65             }
 66 
 67         }
 68     }
 69   
 70     for(int i=0;i<=count;i++){
 71          sum = sum/10;
 72     }
 73     return sum*flag1;
 74 }
 75 
 76 
 77 
 78 int main()
 79 {
 80 
 81 
 82 /*************************************************读文件***********************************************************/ 
 83   
 84     fstream myfile("C:\\AI_data\\lab5\\train.csv");
 85     fstream valition("C:\\AI_data\\lab5\\valition.csv");
 86     fstream test("C:\\AI_data\\lab5\\test.csv");
 87     
 88     
 89     int num=0;
 90     string temp;
 91     if(!myfile.is_open())
 92     {
 93         cout << "1未成功打开文件" << endl;
 94     }
 95     while(getline(myfile,temp))   //读入文本中的词 
 96     {
 97         s[num] = temp;
 98         num++;
 99     }        
100     
101    
102     int num1 = 0;
103     string temp1;
104     if(!test.is_open())
105     {
106         cout << "2未成功打开文件" << endl;
107     }
108     while(getline(test,temp1))   //读入文本中的词 
109     {
110         ss[num1] = temp1;
111         num1++;
112     }
113        
114        
115     int num2 = 0;
116     string temp2;
117     if(!valition.is_open())
118     {
119         cout << "3未成功打开文件" << endl;
120     }
121     while(getline(valition,temp2))   //读入文本中的词 
122     {
123         s2[num2] = temp2;
124         num2++;
125     }
126     
127    
128 /***********************************************处理文本********************************************************************/ 
129     for(int i=0; i<num; i++){
130         
131         int len = s[i].length(); 
132         string str = s[i];
133         char t[8000]="";
134         for(int j=0;j<col;j++){
135             label[i][0] = 1.0;                //需要在每一个样例前面加上一个 1 
136             label[i][j+1] = cut_t(s[i],j);
137         }
138         // for(int j=0;j<=col;j++)  cout<<label[i][j]<<" ";
139         //cout<<endl;
140         for(int w=0;w<len;w++){
141              t[w] = str[w]; 
142         }
143         const char *d = " , \n" ;
144         char* p = strtok(t,d); 
145         while(p)
146         {
147             p=strtok(NULL,d);
148         }
149     }
150     
151     for(int i=0; i<num1; i++){
152         
153         int len1 = ss[i].length(); 
154         string str1 = ss[i];
155         char tt[8000]="";
156         //cout<<ss[i]<<endl;
157         for(int j=0;j<col-1;j++){
158             test_label[i][0] = 1.0;                //需要在每一个样例前面加上一个 1 
159             test_label[i][j+1] = cut_t(ss[i],j);
160             //cout<<test_label[i][j]<<" ";
161         }
162          //for(int j=0;j<67;j++)  cout<<label[i][j]<<endl;
163         //cout<<endl;
164         for(int w=0;w<len1;w++){
165              tt[w] = str1[w]; 
166         }
167         const char *d = " , \n" ;
168         char* p = strtok(tt,d); 
169         while(p)
170         {
171             p=strtok(NULL,d);
172         }
173     }
174     
175     for(int i=0; i<num2; i++){
176         
177         int len2 = s2[i].length(); 
178         string str2 = s2[i];
179         char t2[8000]="";
180         for(int j=0;j<col;j++){
181             valition_label[i][0] = 1.0;                //需要在每一个样例前面加上一个 1 
182             valition_label[i][j+1] = cut_t(s2[i],j);
183         }
184         // for(int j=0;j<=col;j++)  cout<<label[i][j]<<" ";
185         //cout<<endl;
186         for(int w=0;w<len2;w++){
187              t2[w] = str2[w]; 
188         }
189         const char *d2 = " , \n" ;
190         char* p2 = strtok(t2,d2); 
191         while(p2)
192         {
193             p2=strtok(NULL,d2);
194         }
195     }
196 /***************************************************************** PLA算法执行  ************************************************/  
197  
198 
199  
200  
201     double w[col];                   //初始的 w[] 数组 
202     double new_w[col];
203     double zhishu[row];
204     for(int j=0;j<col;j++){
205         w[j] = 1.0; 
206     }
207     for(int ui=0;ui<row;ui++){
208         zhishu[ui] = 0.0;
209     }
210     
211     int a = 6000;                    //由于不能全部划分,所以设立一个最大次数 
212     double error = 0.5;
213     while(a--){
214         //double error = geterror(a);  //调整步长 
215                 
216         for(int j=0;j<col;j++){    //初始化数组,用来更新w[]数组 
217             new_w[j] = 0.0;
218         } 
219         for(int i=0;i<num;i++){    // 遍历所有样本进行一轮迭代 
220             for(int j=0;j<col;j++){
221                zhishu[i] += label[i][j]*w[j]; // 对每一个导数进行存储 
222             }                     
223             //进行logistic变换 
224             zhishu[i] = logistic(zhishu[i]) - label[i][col];  
225         }
226         bool flag = true;    //判断是否收敛 
227         for(int jt=0;jt<col;jt++){
228             for(int it=0;it<num;it++){  //更新 w[] 
229                  new_w[jt] +=  label[it][jt]*zhishu[it];
230             }
231             new_w[jt] = new_w[jt]*error;   
232             if(new_w[jt] != 0) flag = false; 
233             w[jt] = w[jt] - new_w[jt];  //为下一次迭代 w[] 
234         }
235         if(flag){  //如果收敛 
236              cout<<a<<endl;
237              cout<<"完美收敛,提前结束"<<endl; 
238              break;
239         }
240         
241     }
242     
243     //统计各个指标 
244     double TP = 0.0;
245     double FN = 0.0;
246     double TN = 0.0;
247     double FP = 0.0;
248     double Acc = 0.0;
249     double Rec = 0.0;
250     double Pre = 0.0;
251     double F1 = 0.0;
252      
253  
254     for(int i=0;i<num2;i++){
255        int flag1 = 1;
256        double sum2 = 0.0;
257        for(int j=0;j<col;j++){
258             sum2 += valition_label[i][j]*w[j];
259        }
260        if(logistic(sum2) < 0.5) flag1 = 0;
261        else flag1 = 1;
262        
263        
264        if(flag1 == 1 && valition_label[i][col] == 1) TP++;
265        else if(flag1 == 0 && valition_label[i][col] == 1) FN++;
266        else if(flag1 == 0 && valition_label[i][col] == 0) TN++;
267        else FP++;
268     }
269     cout<<"TP = "<<TP<<endl;
270     cout<<"TN = "<<TN<<endl;
271     cout<<"FN = "<<FN<<endl;
272     cout<<"FP = "<<FP<<endl;
273     Acc = (TP+TN)/(TP+TN+FP+FN);
274     Rec = TP/(TP+FN);
275     Pre = TP/(TP+FP);
276     F1 = 2*Pre*Rec / (Pre+Rec);
277 
278     cout<<"Acc = "<<Acc<<endl;
279     cout<<"Rec = "<<Rec<<endl;
280     cout<<"Pre = "<<Pre<<endl;
281     cout<<"F1 = "<<F1<<endl;
282 
283 
284 
285     for(int k=0;k<num1;k++){
286         
287        int flag3 = 0;
288        double sum3 = 0.0;
289        for(int j=0;j<col;j++){
290             sum3 += test_label[k][j]*w[j];
291        }
292 
293        if(logistic(sum3) <  0.5 ) flag3 = 0;
294        else flag3 = 1;
295        cout<<flag3<<endl; 
296     }
297     
298     test.close(); 
299     myfile.close();
300     return 0;
301 }

上面是原始的PLA实现,下面是PLA基于口袋算法的优化:

  1 #include <vector>
  2 #include<iomanip>
  3 #include <string>
  4 #include<stdio.h>
  5 #include<string.h> 
  6 #include <fstream>
  7 #include <iostream>
  8 #include<set>
  9 #include<algorithm>
 10 #include<cstdio>
 11 #include<iomanip>
 12 #include<map>
 13 #include<cmath>
 14 using namespace std;
 15 double label[4010][80];
 16 string s[4010];
 17 
 18 double cut_t(string s, int t){
 19     string str = s;
 20     int r = 0;
 21     double count = 0.0;
 22     bool flag = true;
 23     double flag1 = 1.0;
 24     double sum = 0.0;
 25     for(int i=0;i<str.length();i++){
 26         
 27         if(r==t && str[i] == '-'){
 28             flag1 = -1;
 29             continue;
 30         }
 31         if(str[i]==','){
 32             r++;
 33             continue;
 34         }
 35         if(r==t){
 36             if(flag == false){
 37                 count ++;
 38             }
 39             if(str[i] == '.'){
 40                 flag = false;
 41             }
 42             else {
 43                 sum = sum + (str[i] - '0') * 1.0;
 44                 sum = sum * 10;
 45             }
 46 
 47         }
 48     }
 49   
 50     for(int i=0;i<=count;i++){
 51          sum = sum/10;
 52     }
 53     return sum*flag1;
 54 }
 55 
 56 
 57 
 58 int main()
 59 {
 60 
 61 
 62 /*************************************************读文件***********************************************************/ 
 63   
 64     fstream myfile("F:\\AI_data\\lab3\\train.txt");
 65     int num=0;
 66     string temp;
 67     if (!myfile.is_open())
 68     {
 69         cout << "未成功打开文件" << endl;
 70     }
 71     while(getline(myfile,temp))   //读入文本中的词 
 72     {
 73         s[num] = temp;
 74         num++;
 75     }        
 76     
 77    
 78 /***********************************************处理文本********************************************************************/ 
 79     for(int i=0; i<num; i++){
 80         
 81         int len = s[i].length(); 
 82         string str = s[i];
 83         char t[8000]="";
 84         for(int j=0;j<66;j++){
 85             label[i][0] = 1.0;                //需要在每一个样例前面加上一个 1 
 86             label[i][j+1] = cut_t(s[i],j);
 87         }
 88          //for(int j=0;j<67;j++)  cout<<label[i][j]<<endl;
 89         // cout<<endl;
 90         for(int w=0;w<len;w++){
 91              t[w] = str[w]; 
 92         }
 93         const char *d = " , \n" ;
 94         char* p = strtok(t,d); 
 95         while(p)
 96         {
 97             p=strtok(NULL,d);
 98         }
 99     }
100   
101 /***************************************************************** PLA算法执行  ************************************************/  
102  
103     double w[66];                   //初始的 w[] 数组
104     double change_w[66]; 
105     //double w[7];
106     double store[4010];
107     double sum = 0.0;
108     for(int j=0;j<66;j++){
109         w[j] = 1.0;
110         change_w[j] = 1.0;
111     }
112     int a = 2000;
113     
114      
115     while(a--){                             //规定迭代次数 
116         
117         bool flag2 = true;
118         long double counter_right1 = 0;     //两次的正确的数目统计 
119         long double counter_right2 = 0;
120         int dex = 0;
121          
122         for(int i=0;i<num;i++){             //遍历所有数据 
123             
124             sum = 0.0;
125             
126             for(int j=0;j<66;j++){          //进行计算 
127                 
128                 sum += label[i][j]*w[j];
129                 //cout<< i << "   "<<j<<endl; 
130             }
131             //cout<<"sum= "<<sum<<endl;
132             int flag = 0;
133             
134             if(sum > 0.0){                 //对结果的符号进行判断 
135                 flag = 1;
136             }
137             else{
138                 flag = -1;
139             }
140         
141             //cout<<flag << "   "<<label[i][66]<<endl; 
142             if(flag != label[i][66] ){    //判断结果是否是正确的,不正确需要考虑这个w[] 
143                 if(flag2){
144                     for(int k=0;k<66;k++){
145                        change_w[k] = w[k] + label[i][k]*label[i][66];
146                        dex = i;
147                        //cout<<w[k]<<endl;
148                     }
149                 }
150                 flag2 = false;           //一次只考虑第一个不正确的 w[] 
151             }
152             else counter_right1++;       //记录第一个 w[] 的正确率 
153         } 
154         
155         
156         for(int i=0;i<num;i++){         //遍历所有数据 
157             
158             sum = 0.0;
159             
160             for(int j=0;j<66;j++){       //用第二个w[]进行迭代 
161                 
162                 sum += label[i][j]*change_w[j];
163                 //cout<< i << "   "<<j<<endl; 
164             }
165             //cout<<"sum= "<<sum<<endl;
166             int flag = 0;
167              
168             if(sum > 0.0){              //算出结果的符号 
169                 flag = 1;
170             }
171             else{
172                 flag = -1;
173             }
174                                        //记录正确率 
175             //cout<<flag << "   "<<label[i][66]<<endl; 
176             if(flag == label[i][66] )  counter_right2++;
177         }
178         
179         
180         //两个w[]数组正确率比较 ,第一个正确率高则返回原来的w[],否则w[] 替换为更新后的,进入下一轮迭代 
181         if(counter_right1 > counter_right2){
182             for(int j=0;j<66;j++){
183                 w[j] = change_w[j] - label[dex][j]*label[dex][66];   
184             }
185         }
186         else{
187             for(int j=0;j<66;j++){
188                 w[j] = change_w[j] ;
189             } 
190         }         
191         
192     }
193     
194     
195     double TP = 0.0;
196     double FN = 0.0;
197     double TN = 0.0;
198     double FP = 0.0;
199     double Acc = 0.0;
200     double Rec = 0.0;
201     double Pre = 0.0;
202     double F1 = 0.0;
203      
204     for(int i=0;i<num;i++){
205        int flag1 = 1;
206        double sum1 = 0.0;
207        for(int j=0;j<66;j++){
208             sum1 += label[i][j]*w[j];
209             //cout<<" w[] = "<<w[j]<<endl; 
210        }
211        cout<<sum1<<endl;
212        
213        if(sum1 >= 0 ) flag1 = 1;
214        else flag1 = -1;
215        
216        cout<<flag1<<" ";
217        cout<<label[i][66]<<endl;
218        if(flag1 == 1 && label[i][66] == 1) TP++;
219        else if(flag1 == -1 && label[i][66] == 1) FN++;
220        else if(flag1 == -1 && label[i][66] == -1) TN++;
221        else if(flag1 == 1 && label[i][66] == -1)FP++;
222        
223     }
224     cout<<TP<<endl;
225     cout<<TN<<endl;
226     cout<<FN<<endl;
227     cout<<TN<<endl;
228     Acc = (TP+TN)/(TP+TN+FP+FN);
229     Rec = TP/(TP+FN);
230     Pre = TP/(TP+FP);
231     F1 = 2*Pre*Rec / (Pre+Rec);
232 
233     cout<<"Acc = "<<Acc<<endl;
234     cout<<"Rec = "<<Rec<<endl;
235     cout<<"Pre = "<<Pre<<endl;
236     cout<<"F1 = "<<F1<<endl;
237     for(int k=0;k<66;k++){
238         cout<< w[k] <<endl;
239     }
240      
241     myfile.close();
242     return 0;
243 }
点赞
收藏
评论区
推荐文章
blmius blmius
2年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
Easter79 Easter79
2年前
swap空间的增减方法
(1)增大swap空间去激活swap交换区:swapoff v /dev/vg00/lvswap扩展交换lv:lvextend L 10G /dev/vg00/lvswap重新生成swap交换区:mkswap /dev/vg00/lvswap激活新生成的交换区:swapon v /dev/vg00/lvswap
Jacquelyn38 Jacquelyn38
2年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
Wesley13 Wesley13
2年前
Java获得今日零时零分零秒的时间(Date型)
publicDatezeroTime()throwsParseException{    DatetimenewDate();    SimpleDateFormatsimpnewSimpleDateFormat("yyyyMMdd00:00:00");    SimpleDateFormatsimp2newS
Stella981 Stella981
2年前
KVM调整cpu和内存
一.修改kvm虚拟机的配置1、virsheditcentos7找到“memory”和“vcpu”标签,将<namecentos7</name<uuid2220a6d1a36a4fbb8523e078b3dfe795</uuid
Wesley13 Wesley13
2年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
2年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
2年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
2年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
3个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这