-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
47 lines (41 loc) · 2.24 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import streamlit as st
import numpy as np
# Load the fitted parameters from disk. np.load on an .npz archive returns a
# lazy NpzFile that keeps the underlying file open, so read the three arrays
# inside a `with` block to release the handle as soon as they are in memory.
# x_mean / x_std are used by predict() to standardize the feature row and
# theta holds the coefficients it multiplies that row by.
with np.load('weight.npz') as model:
    x_mean = model['x_mean']
    x_std = model['x_std']
    theta = model['theta']
# Define the prediction function.
# st.cache_data (rather than st.cache_resource) is used because this is a pure
# computation returning data: each caller gets its own copy of the cached
# result instead of one shared, mutable object.
@st.cache_data
def predict(carat, cut, color, clarity, depth, table, x, y, z, x_mean, x_std, theta):
    """Predict a diamond price from its attributes using the linear weights.

    The three categorical grades are converted to integer codes, the nine
    features are standardized with ``x_mean`` / ``x_std``, a constant 1.0
    bias term is prepended, and the resulting row is multiplied by ``theta``.

    Args:
        carat: Carat weight.
        cut: Cut grade label, e.g. ``'Fair'`` ... ``'Ideal'``.
        color: Color grade label, e.g. ``'J'`` ... ``'D'``.
        clarity: Clarity grade label, e.g. ``'I1'`` ... ``'IF'``.
        depth: Depth percentage.
        table: Table percentage.
        x: Length in mm.
        y: Width in mm.
        z: Height in mm.
        x_mean: Per-feature means subtracted from the feature row.
            Assumed to broadcast against a (1, 9) row -- TODO confirm
            against the contents of weight.npz.
        x_std: Per-feature scales the centered row is divided by (same
            shape assumption as ``x_mean``).
        theta: Coefficients applied to the 10-element row
            ``[1, standardized features...]``. Presumably shaped (10, 1),
            since the caller indexes the result with ``[0, 0]`` -- verify.

    Returns:
        The array produced by ``row.dot(theta)``.
    """
    # Integer codes assigned to each categorical grade.
    cut_mapping = {'Fair': 0, 'Good': 1, 'Very Good': 2, 'Premium': 3, 'Ideal': 4}
    color_mapping = {'J': 0, 'I': 1, 'H': 2, 'G': 3, 'F': 4, 'E': 5, 'D': 6}
    clarity_mapping = {'I1': 0, 'SI2': 1, 'SI1': 2, 'VS2': 3, 'VS1': 4, 'VVS2': 5, 'VVS1': 6, 'IF': 7}

    # NOTE: labels missing from a mapping fall back to code 0, exactly as the
    # original implementation did; the UI only offers labels present above.
    cut_code = cut_mapping.get(cut, 0)
    color_code = color_mapping.get(color, 0)
    clarity_code = clarity_mapping.get(clarity, 0)

    # Single-row feature matrix, in the fixed column order used below.
    features = np.array(
        [[carat, cut_code, color_code, clarity_code, depth, table, x, y, z]],
        dtype=float,
    )

    # Standardize with the supplied statistics.
    features = (features - x_mean) / x_std

    # Prepend the constant bias column so theta's first entry acts as intercept.
    bias = np.ones((1, 1))
    design_row = np.concatenate((bias, features), axis=1)

    return design_row.dot(theta)
# ---- Page heading ----
st.title('💎DIAMOND PRICE PREDICTION 💎')
st.header('Vui lòng nhập các đặc trưng của viên kim cương bạn muốn mua:')

# Labels offered in each drop-down; they match the keys predict() encodes.
CUT_CHOICES = ['Fair', 'Good', 'Very Good', 'Premium', 'Ideal']
COLOR_CHOICES = ['J', 'I', 'H', 'G', 'F', 'E', 'D']
CLARITY_CHOICES = ['I1', 'SI2', 'SI1', 'VS2', 'VS1', 'VVS2', 'VVS1', 'IF']


def _bounded_input(label):
    """Render a numeric field limited to 0.1-100.0 that starts at 1.0."""
    return st.number_input(label, min_value=0.1, max_value=100.0, value=1.0)


# ---- Feature inputs (widgets are created in display order) ----
carat = st.number_input('Carat Weight:', min_value=0.1, max_value=10.0, value=1.0)
cut = st.selectbox('Cut Rating:', CUT_CHOICES)
color = st.selectbox('Color Rating:', COLOR_CHOICES)
clarity = st.selectbox('Clarity Rating:', CLARITY_CHOICES)
depth = _bounded_input('Diamond Depth Percentage:')
table = _bounded_input('Diamond Table Percentage:')
x = _bounded_input('Diamond Length (X) in mm:')
y = _bounded_input('Diamond Width (Y) in mm:')
z = _bounded_input('Diamond Height (Z) in mm:')

# ---- Run the model only when the user presses the button ----
predict_clicked = st.button('Predict Price')
if predict_clicked:
    out = predict(carat, cut, color, clarity, depth, table, x, y, z, x_mean, x_std, theta)
    # The prediction is a 2-D array; show its single entry with two decimals.
    st.success(f'Giá dự đoán của viên kim cương là: ${out[0,0]:.2f} USD')