Skip to content

Commit

Permalink
[Issue #745] add ML-based auto-scaling policy (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
slzo authored Jan 3, 2025
1 parent c625127 commit f6396dc
Show file tree
Hide file tree
Showing 11 changed files with 525 additions and 5 deletions.
12 changes: 12 additions & 0 deletions pixels-common/src/main/resources/pixels.properties
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,18 @@ vm.key.name=pixels
# use `,` to add multiple security groups, e.g. `pixels-sg1,pixels-sg2`
vm.security.groups=pixels-sg1

# path of the python interpreter to use
# if you use a global python installation instead of a python virtual environment (venv), you can simply set this to `python`
python.env.path = /home/ubuntu/dev/venvpixels/bin/python
# path of code used to forecast the resource usage
forecast.code.path = /home/ubuntu/dev/pixels/pixels-daemon/src/main/java/io/pixelsdb/pixels/daemon/scaling/policy/helper/forecast.py
# directory where the dumped history data (query logs) is stored
pixels.historyData.dir=/home/ubuntu/opt/pixels/historyData/
# split cputime (ms)
cpuspl = [10000,60000,300000,600000]
# split mem (G)
memspl = [1,8,16,32,64]

###### experimental settings ######
# the rate of free memory in jvm
experimental.gc.threshold=0.3
Expand Down
13 changes: 13 additions & 0 deletions pixels-daemon/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
<optional>true</optional>
</dependency>

<!-- trino-jdbc -->
<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
</dependency>
<!-- grpc -->
<dependency>
<groupId>io.grpc</groupId>
Expand Down Expand Up @@ -189,6 +194,14 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2024 PixelsDB.
*
* This file is part of Pixels.
*
* Pixels is free software: you can redistribute it and/or modify
* it under the terms of the Affero GNU General Public License as
* published by the Free Software Foundation, either version 3 of
* the License, or (at your option) any later version.
*
* Pixels is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Affero GNU General Public License for more details.
*
* You should have received a copy of the Affero GNU General Public
* License along with Pixels. If not, see
* <https://www.gnu.org/licenses/>.
*/
package io.pixelsdb.pixels.daemon.scaling.policy;

import java.io.IOException;
import java.io.InputStreamReader;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.pixelsdb.pixels.common.utils.ConfigFactory;
import io.pixelsdb.pixels.daemon.TransProto;
import io.pixelsdb.pixels.daemon.TransServiceGrpc;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.BufferedReader;
import java.util.Arrays;

/**
 * ML-based auto-scaling policy: dumps recent transaction statistics, runs an
 * external python script to forecast the cpu-time and memory usage for the
 * next interval, and scales the cluster to the forecast demand.
 */
public class AIASPolicy extends Policy
{
    // was LogManager.getLogger(BasicPolicy.class) — log events were attributed to the wrong class
    private static final Logger log = LogManager.getLogger(AIASPolicy.class);

    /**
     * Derive a target worker count from the forecast usage and scale to it.
     *
     * @param cpuTimes forecast cpu time per bucket; the last element is excluded
     *                 from the sum (presumably the "big query" bucket — TODO confirm)
     * @param memUsages forecast memory usage per bucket, same convention
     */
    private void scaling(double[] cpuTimes, double[] memUsages)
    {
        /* the unit of cpuTimes is 5mins = the forecast interval
           spl = [10*1000, 60*1000, 5*60*1000, 10*60*1000] #10s 1min 5min 10min
           memspl = [G, 8*G, 16*G, 32*G, 64*G, 128*G] #1G 8G 16G 32G 64G 128G */
        int workerNumByCpu = (int) Math.ceil((Arrays.stream(cpuTimes).sum() - cpuTimes[cpuTimes.length - 1]) / 8);
        int workerNumByMem = (int) Math.ceil((Arrays.stream(memUsages).sum() - memUsages[memUsages.length - 1]) / 32);
        int workerNum = Math.max(workerNumByCpu, workerNumByMem);
        scalingManager.expendTo(workerNum);
        // report the chosen target, not just the two candidates
        log.info("expand to {} workers (cpu-based={}, mem-based={})", workerNum, workerNumByCpu, workerNumByMem);
    }

    /**
     * Ask the transaction service to dump its statistics for the given
     * (week-granularity) timestamp.
     *
     * @param timestamp the week index the dump belongs to
     * @return true if the dump request was sent without throwing
     */
    public boolean transDump(long timestamp)
    {
        String host = ConfigFactory.Instance().getProperty("trans.server.host");
        int port = Integer.parseInt(ConfigFactory.Instance().getProperty("trans.server.port"));
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port)
                .usePlaintext().build();
        try
        {
            TransServiceGrpc.TransServiceBlockingStub stub = TransServiceGrpc.newBlockingStub(channel);
            TransProto.DumpTransRequest request =
                    TransProto.DumpTransRequest.newBuilder().setTimestamp(timestamp).build();
            TransProto.DumpTransResponse response = stub.dumpTrans(request);
            return true;
        } finally
        {
            // shut the channel down even if the rpc throws; previously it leaked on failure
            channel.shutdownNow();
        }
    }

    @Override
    public void doAutoScaling()
    {
        try
        {
            ConfigFactory config = ConfigFactory.Instance();
            long timestamp = System.currentTimeMillis();
            // bucket the timestamp by week, so history data is grouped into weekly csv files
            timestamp /= 7*24*60*60*1000;
            transDump(timestamp);
            // forecast according to history data: current week and previous week csv files
            String[] cmd = {config.getProperty("python.env.path"),
                    config.getProperty("forecast.code.path"),
                    config.getProperty("pixels.historyData.dir") + timestamp + ".csv",
                    config.getProperty("pixels.historyData.dir") + (timestamp - 1) + ".csv",
                    config.getProperty("cpuspl"),
                    config.getProperty("memspl")
            };
            Process proc = Runtime.getRuntime().exec(cmd);
            proc.waitFor();
            // close both readers even when parsing fails; previously they leaked
            try (BufferedReader stdout = new BufferedReader(new InputStreamReader(proc.getInputStream()));
                 BufferedReader stderr = new BufferedReader(new InputStreamReader(proc.getErrorStream())))
            {
                String line;
                while ((line = stderr.readLine()) != null)
                {
                    log.error(line);
                }
                // the script prints two lines: cpu-time forecasts, then memory forecasts
                String cpuLine = stdout.readLine();
                String memLine = cpuLine == null ? null : stdout.readLine();
                if (cpuLine == null || memLine == null)
                {
                    // previously this was an NPE on line.split(...)
                    log.error("forecast script produced incomplete output, skip scaling");
                    return;
                }
                double[] cpuTimes = Arrays.stream(cpuLine.split(" "))
                        .mapToDouble(Double::parseDouble)
                        .toArray();
                double[] memUsages = Arrays.stream(memLine.split(" "))
                        .mapToDouble(Double::parseDouble)
                        .toArray();
                scaling(cpuTimes, memUsages);
            }
        } catch (IOException e)
        {
            // was e.printStackTrace(); route through the logger instead
            log.error("failed to run the forecast script", e);
        } catch (InterruptedException e)
        {
            // restore the interrupt flag before propagating
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,47 @@
*/
package io.pixelsdb.pixels.daemon.scaling.policy;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Auto-scaling policy driven by a PID controller: the measured query
 * concurrency is compared against a target, and the controller output is the
 * scaling factor applied to the current instance count.
 */
public class PidPolicy extends Policy
{
    // was LogManager.getLogger(BasicPolicy.class) — log events were attributed to the wrong class
    private static final Logger log = LogManager.getLogger(PidPolicy.class);

    /**
     * A textbook PID controller over the query-concurrency error.
     * Static nested class: it does not touch the enclosing instance, so it
     * should not hold a hidden reference to it.
     */
    private static class PidController
    {
        private final float kp, ki, kd;
        private final int queryConcurrencyTarget;
        private int integral;
        private int prevError;

        PidController()
        {
            this.kp = 0.8F;
            this.ki = 0.05F;
            this.kd = 0.05F;
            this.queryConcurrencyTarget = 1;
            this.integral = 0;
            this.prevError = 0;
        }

        /**
         * One PID step: returns kp*e + ki*sum(e) + kd*(e - e_prev) for the
         * current concurrency error, updating the integral and previous error.
         */
        float calculate(int queryConcurrency)
        {
            int error = queryConcurrency - queryConcurrencyTarget;
            integral += error;
            int derivative = error - prevError;
            float result = kp * error + ki * integral + kd * derivative;
            prevError = error;
            return result;
        }
        // TODO: fine-tune kp/ki/kd (the original placeholder method argsuit() was removed)
    }

    private final PidController controller = new PidController();

    @Override
    public void doAutoScaling()
    {
        int queryConcurrency = metricsQueue.getLast();
        // was System.out.println debug leftovers; use the logger instead
        log.debug("Receive metrics: {}", metricsQueue);
        float scaleFactor = controller.calculate(queryConcurrency);
        log.info("scale instances by factor {}", scaleFactor);
        scalingManager.multiplyInstance(scaleFactor);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import sys
import copy
import duckdb
import datetime
import pandas as pd
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA

# the split function
#1.split the query to small/mid/big query by the cputime
#2.split the query to different timestrap every 5 mins
def dataprocess(df, startime, delta, startime2, delta2, cpuspl, memspl):
    """Bucket queries by size class and by time window.

    Splits every query row of ``df`` twice:
      1. into a cpu-time size class (boundaries ``cpuspl``) and a time window
         of width ``delta`` starting at ``startime``;
      2. into a memory size class (boundaries ``memspl``) and a time window
         of width ``delta2`` starting at ``startime2``.

    Returns ``(cpuTimes, mems)`` where each is a list with one dict per size
    class (len(spl)+1 classes; the last one collects values above the largest
    boundary), mapping window start time -> list of raw values in that window.

    NOTE(review): assumes ``df`` is sorted ascending by ``createdTime`` (main()
    orders the query that way) and has columns ``createdTime``,
    ``cpuTimeTotal`` and ``memoryUsed`` — confirm against the dump schema.
    """
    # G = 1024**3
    # memspl = [G, 8*G, 16*G, 32*G, 64*G, 128*G] #1G 8G 16G 32G 64G 128G
    mems = [{startime2:[]} for i in range(len(memspl)+1)]

    # cpuspl = [10*1000, 60*1000, 5*60*1000, 10*60*1000] #10s 1min 5min 10min
    cpuTimes = [{startime:[]} for i in range(len(cpuspl)+1)]

    for i in df.itertuples():
        # the row falls past the current cpu window: advance the window,
        # inserting empty lists for any fully-skipped windows
        if (startime+delta) <= i.createdTime:
            startime += delta
            while (startime+delta) < i.createdTime:
                for j in range(len(cpuspl)+1):
                    cpuTimes[j][startime] = []
                startime += delta
            # first row of the new window: initialize every size class for
            # this window and place the row in the first class it fits
            flag = 0
            for j in range(len(cpuspl)):
                cpuTimes[j][startime] = []
                if (not flag and i.cpuTimeTotal <= cpuspl[j]):
                    cpuTimes[j][startime].append(i.cpuTimeTotal)
                    flag = 1
            cpuTimes[len(cpuspl)][startime] = []
            if not flag:
                # larger than every boundary: overflow class
                cpuTimes[len(cpuspl)][startime].append(i.cpuTimeTotal)
        else:
            # row belongs to the current window: just classify it
            flag = 0
            for j in range(len(cpuspl)):
                if i.cpuTimeTotal <= cpuspl[j]:
                    cpuTimes[j][startime].append(i.cpuTimeTotal)
                    flag = 1
                    break
            if not flag:
                cpuTimes[len(cpuspl)][startime].append(i.cpuTimeTotal)
        #----------------------------------------------------------#
        # same window-advance + classify logic, for memory usage
        if (startime2+delta2) <= i.createdTime:
            startime2 += delta2
            while (startime2+delta2) < i.createdTime:
                for j in range(len(memspl)+1):
                    mems[j][startime2] = []
                startime2 += delta2
            flag = 0
            for j in range(len(memspl)):
                mems[j][startime2] = []
                if (not flag and i.memoryUsed <= memspl[j]):
                    mems[j][startime2].append(i.memoryUsed)
                    flag = 1
            mems[len(memspl)][startime2] = []
            if not flag:
                mems[len(memspl)][startime2].append(i.memoryUsed)
        else:
            flag = 0
            for j in range(len(memspl)):
                if i.memoryUsed <= memspl[j]:
                    mems[j][startime2].append(i.memoryUsed)
                    flag = 1
                    break
            if not flag:
                mems[len(memspl)][startime2].append(i.memoryUsed)

    return cpuTimes,mems

# use autoARIMA forecast the cputime usage next 5min
def cpuTimeForecast(historyDatas, curtime):
    """Forecast cpu-time load for each size-class bucket with AutoARIMA.

    ``historyDatas`` is a list of dicts mapping window start time -> list of
    cpu times (ms) observed in that window. Returns one forecast value per
    bucket: the 90% upper confidence bound at the horizon covering ``curtime``.
    """
    interval_ms = 5 * 60 * 1000  # 5min, the forecast interval
    forecasts = []
    for bucket in historyDatas:
        timestamps = list(bucket)
        # per-window load, normalized by the interval length
        loads = [sum(bucket[ts]) / interval_ms for ts in timestamps]
        frame = pd.DataFrame({'unique_id': [1] * len(bucket),
                              'ds': timestamps,
                              'y': loads})
        model = StatsForecast(models=[AutoARIMA()], freq='5min')
        model.fit(frame)
        # horizon: steps from the last observation up to curtime, plus margin
        # NOTE(review): divisor is 15 minutes while freq is '5min' — confirm intended
        horizon = int((curtime - timestamps[-1]) / (datetime.timedelta(minutes=15))) + 2
        upper90 = (model.predict(h=horizon, level=[90]))['AutoARIMA-hi-90'].iloc[-1]
        forecasts.append(upper90)
    return forecasts

def memUasgeForecast(historyDatas, curtime):
    """Forecast memory usage (in GiB) for each size-class bucket with AutoARIMA.

    Mirrors cpuTimeForecast but normalizes by 1 GiB instead of the interval.
    Returns the 90% upper confidence bound at the horizon covering ``curtime``
    for every bucket.
    """
    gib = 1024 ** 3  # 1G
    forecasts = []
    for bucket in historyDatas:
        timestamps = list(bucket)
        usage = [sum(bucket[ts]) / gib for ts in timestamps]
        frame = pd.DataFrame({'unique_id': [1] * len(bucket),
                              'ds': timestamps,
                              'y': usage})
        model = StatsForecast(models=[AutoARIMA()], freq='5min')
        model.fit(frame)
        # NOTE(review): divisor is 15 minutes while freq is '5min' — confirm intended
        horizon = int((curtime - timestamps[-1]) / (datetime.timedelta(minutes=15))) + 2
        upper90 = (model.predict(h=horizon, level=[90]))['AutoARIMA-hi-90'].iloc[-1]
        # memUsage = [i/300 for i in memUsage]
        forecasts.append(upper90)
    return forecasts



def main():
    """Entry point: argv = [script, current_csv, previous_csv, cpuspl, memspl].

    Loads up to two weekly history csv files via duckdb, buckets the queries
    with dataprocess(), forecasts the next interval, and prints two lines:
    cpu-time forecasts, then memory forecasts (consumed by AIASPolicy).
    """
    # quote the csv paths for embedding in the duckdb read_csv_auto list
    logfile = '\'' + sys.argv[1] + '\''
    if os.path.exists(sys.argv[2]):
        # previous week's file exists: include it for more training data
        logfile = logfile + ',' + '\'' + sys.argv[2] + '\''
    # argv[3]/argv[4] arrive as bracketed lists, e.g. "[10000,60000,...]"
    cpuspl = [int(i) for i in sys.argv[3][1:-1].split(',')]
    memspl = [int(i)*(1024**3) for i in sys.argv[4][1:-1].split(',')]
    curtime = datetime.datetime.now().replace(microsecond=0)
    startime = curtime - datetime.timedelta(days=7)
    # NOTE(review): query text is built by concatenation; the inputs come from
    # the daemon's own config, not end users, but parameters would be safer
    que = " select createdTime, cpuTimeTotal, memoryUsed from read_csv_auto([" \
        + logfile \
        + "]) where createdTime between '" \
        + str(startime) \
        + "' and '" \
        + str(curtime) \
        + "' order by createdTime"
    df = duckdb.query(que).df()

    delta = datetime.timedelta(minutes=5)
    cpuTimeData,memUsageData = dataprocess(df, startime, delta, startime, delta, cpuspl, memspl)

    cpuTimes = cpuTimeForecast(cpuTimeData, curtime)
    print(*cpuTimes)
    memUsages = memUasgeForecast(memUsageData, curtime)
    print(*memUsages)


# guard the entry point: previously main() ran unconditionally on import
if __name__ == '__main__':
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.pixelsdb.pixels.daemon.scaling.MetricsQueue;
import io.pixelsdb.pixels.daemon.scaling.policy.BasicPolicy;
import io.pixelsdb.pixels.daemon.scaling.policy.PidPolicy;
import io.pixelsdb.pixels.daemon.scaling.policy.AIASPolicy;
import io.pixelsdb.pixels.daemon.scaling.policy.Policy;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -47,6 +48,9 @@ public PolicyManager()
case "pid":
policy = new PidPolicy();
break;
case "AIAS":
policy = new AIASPolicy();
break;
default:
policy = new BasicPolicy();
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,24 @@ public void multiplyInstance(float percent)
reduceSome(-count);
}
}

/**
 * Scale the cluster so that exactly {@code target} instances are running:
 * expands when fewer are running, reduces when more are.
 * (Name kept as-is — "expendTo" — for compatibility with existing callers.)
 *
 * @param target the desired number of running instances
 */
public void expendTo(int target)
{
    // count currently running instances; iterate values() directly instead
    // of doing a get() lookup for every key
    int running = 0;
    for (InstanceState state : instanceMap.values())
    {
        if (state == InstanceState.RUNNING)
        {
            running++;
        }
    }
    int delta = target - running;
    if (delta >= 0)
    {
        expandSome(delta);
    } else
    {
        reduceSome(-delta);
    }
}
}
Loading

0 comments on commit f6396dc

Please sign in to comment.