-
Notifications
You must be signed in to change notification settings - Fork 0
/
支持向量机SVM.py
54 lines (41 loc) · 1.24 KB
/
支持向量机SVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.svm import LinearSVC
import time
# Flat script: fit a linear SVM on a tiny 2-D toy data set, print the fitted
# parameters, and plot the decision boundary together with both margin lines.

# Monotonic, high-resolution clock intended for measuring elapsed intervals.
start = time.perf_counter()

# Reuse the data from the perceptron model example.
Dict = {
    'x1': [1, 4, 3, 2, 5, 6, 6, 7],
    'x2': [4, 8, 6, 3, 5, 6, 13, 14],
    'label': [0, 0, 0, 1, 1, 1, 0, 0],
}
Data_table = pd.DataFrame(Dict)
features = Data_table[['x1', 'x2']]
labels = Data_table['label']

# Scatter the samples, coloured by class label.
plt.scatter(Data_table['x1'], Data_table['x2'], c=labels)

svm_model = LinearSVC()
svm_model.fit(features, labels)
# Two classes, so coef_ has a single row (w_1, w_2) and intercept_ is a
# length-1 array.
w_i = svm_model.coef_[0]
b = svm_model.intercept_
score = svm_model.score(features, labels)  # mean accuracy on the training data
print(score)
print(w_i)
print(b)

b0 = b[0]  # scalar intercept for the line equations and the annotation

# Each curve solves w_1*x + w_2*y + b = c for y (assumes w_2 != 0, i.e. the
# boundary is not vertical): c = 0 is the decision boundary, and c = -1 and
# c = +1 are the two margin lines.
x = np.linspace(0, 7, 100)
y = (w_i[0] * x + b0) / (-w_i[1])
y1 = (w_i[0] * x + b0 + 1) / (-w_i[1])  # w.x + b = -1
y2 = (w_i[0] * x + b0 - 1) / (-w_i[1])  # w.x + b = +1
plt.plot(x, y1, linestyle='--', color='black')
plt.plot(x, y2, linestyle='--', color='black')
plt.plot(x, y)
plt.title('Support Vector Machine Using Sklearn')
plt.xlabel('X')
plt.ylabel('Y')

# Annotate with the parameters actually learned in this run rather than
# hard-coded numbers, which go stale whenever the data or solver changes.
text1 = r'$w_1=%.4f,w_2=%.4f$' % (w_i[0], w_i[1])
text2 = r'$b=%.4f$' % b0
text3 = r'$y_{i}(W^{T}X+b)\geq 1$'
plt.text(0, 13, text1, fontsize=17)
plt.text(0, 11, text2, fontsize=17)
plt.text(0, 9, text3, fontsize=17)
plt.grid()

# Stop the clock before plt.show(): with an interactive backend show() blocks
# until the window is closed, and that wait is not program run time.
# Store the result in `elapsed` rather than `time` so the imported `time`
# module is not shadowed.
end = time.perf_counter()
elapsed = end - start
print('程序运行时间为: %s Seconds' % elapsed)
plt.show()