Skip to content

Commit

Permalink
fix: add IOException for TaskManager interface #1598
Browse files Browse the repository at this point in the history
  • Loading branch information
bamthomas committed Nov 20, 2024
1 parent 16fe540 commit 4771cd5
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ public TaskResource(final DatashareTaskFactory taskFactory, final TaskManager ta
parameters = {@Parameter(name = "filter", description = "pattern contained in the task name", in = ParameterIn.QUERY)})
@ApiResponse(responseCode = "200", description = "returns the list of tasks", useReturnTypeSchema = true)
@Get("/all")
public List<Task<?>> tasks(Context context) {
public List<Task<?>> tasks(Context context) throws IOException {
Pattern pattern = Pattern.compile(StringUtils.isEmpty(context.get("filter")) ? ".*": String.format(".*%s.*", context.get("filter")));
return taskManager.getTasks((User) context.currentUser(), pattern);
}

@Operation(description = "Gets one task with its id.")
@ApiResponse(responseCode = "200", description = "returns the task from its id", useReturnTypeSchema = true)
@Get("/:id")
public Task<?> getTask(@Parameter(name = "id", description = "task id", in = ParameterIn.PATH) String id) {
public Task<?> getTask(@Parameter(name = "id", description = "task id", in = ParameterIn.PATH) String id) throws IOException {
return notFoundIfNull(taskManager.getTask(id));
}

Expand Down Expand Up @@ -223,15 +223,15 @@ public TaskResponse scanFile(@Parameter(name = "filePath", description = "path o
@Operation(description = "Cleans all DONE tasks.")
@ApiResponse(responseCode = "200", description = "returns 200 and the list of removed tasks", useReturnTypeSchema = true)
@Post("/clean")
public List<Task<?>> cleanDoneTasks() {
public List<Task<?>> cleanDoneTasks() throws IOException {
return taskManager.clearDoneTasks();
}

@Operation(description = "Cleans a specific task.")
@ApiResponse(responseCode = "200", description = "returns 200 if the task is removed")
@ApiResponse(responseCode = "403", description = "returns 403 if the task is still in RUNNING state")
@Delete("/clean/:taskName:")
public Payload cleanTask(@Parameter(name = "taskName", description = "name of the task to delete", in = ParameterIn.PATH) final String taskId, Context context) {
public Payload cleanTask(@Parameter(name = "taskName", description = "name of the task to delete", in = ParameterIn.PATH) final String taskId, Context context) throws IOException {
Task<?> task = forbiddenIfNotSameUser(context, notFoundIfNull(taskManager.getTask(taskId)));
if (task.getState() == Task.State.RUNNING) {
return forbidden();
Expand All @@ -251,7 +251,7 @@ public Payload cleanTaskPreflight(final String taskName) {
@Operation(description = "Cancels the task with the given name.")
@ApiResponse(responseCode = "200", description = "returns 200 with the cancellation status (true/false)", useReturnTypeSchema = true)
@Put("/stop/:taskId:")
public boolean stopTask(@Parameter(name = "taskName", description = "name of the task to cancel", in = ParameterIn.PATH) final String taskId) {
public boolean stopTask(@Parameter(name = "taskName", description = "name of the task to cancel", in = ParameterIn.PATH) final String taskId) throws IOException {
return taskManager.stopTask(notFoundIfNull(taskManager.getTask(taskId)).id);
}

Expand All @@ -266,7 +266,7 @@ public Payload stopTaskPreflight(final String taskName) {
"If the status is false, it means that the thread has not been stopped.")
@ApiResponse(responseCode = "200", description = "returns 200 and the tasks stop result map", useReturnTypeSchema = true)
@Put("/stopAll")
public Map<String, Boolean> stopAllTasks(final Context context) {
public Map<String, Boolean> stopAllTasks(final Context context) throws IOException {
return taskManager.stopAllTasks((User) context.currentUser());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
Expand All @@ -25,15 +24,15 @@
public interface TaskManager extends Closeable {
Logger logger = LoggerFactory.getLogger(TaskManager.class);

boolean stopTask(String taskId);
<V> Task<V> clearTask(String taskId);
boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException;
<V> Task<V> getTask(String taskId);
List<Task<?>> getTasks();
List<Task<?>> getTasks(User user, Pattern pattern);
List<Task<?>> clearDoneTasks();
void clear();
boolean save(Task<?> task);
boolean stopTask(String taskId) throws IOException;
<V> Task<V> clearTask(String taskId) throws IOException;
boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException, IOException;
<V> Task<V> getTask(String taskId) throws IOException;
List<Task<?>> getTasks() throws IOException;
List<Task<?>> getTasks(User user, Pattern pattern) throws IOException;
List<Task<?>> clearDoneTasks() throws IOException;
void clear() throws IOException;
boolean save(Task<?> task) throws IOException;
void enqueue(Task<?> task) throws IOException;

static List<Task<?>> getTasks(Stream<Task<?>> stream, User user, Pattern pattern) {
Expand All @@ -43,11 +42,18 @@ static List<Task<?>> getTasks(Stream<Task<?>> stream, User user, Pattern pattern
collect(toList());
}

default Map<String, Boolean> stopAllTasks(User user) {
default Map<String, Boolean> stopAllTasks(User user) throws IOException {
return getTasks().stream().
filter(t -> user.equals(t.getUser())).
filter(t -> t.getState() == Task.State.RUNNING || t.getState() == Task.State.QUEUED).collect(
toMap(t -> t.id, t -> stopTask(t.id)));
toMap(t -> t.id, t -> {
try {
return stopTask(t.id);
} catch (IOException e) {
logger.error("cannot stop task {}", t.id, e);
return false;
}
}));
}


Expand Down Expand Up @@ -88,7 +94,7 @@ default <V> String startTask(Task<V> taskView) throws IOException {
return saved ? taskView.id: null;
}

default <V extends Serializable> Task<V> setResult(ResultEvent<V> e) {
default <V extends Serializable> Task<V> setResult(ResultEvent<V> e) throws IOException {
Task<V> taskView = getTask(e.taskId);
if (taskView != null) {
logger.info("result event for {}", e.taskId);
Expand All @@ -100,7 +106,7 @@ default <V extends Serializable> Task<V> setResult(ResultEvent<V> e) {
return taskView;
}

default <V extends Serializable> Task<V> setError(ErrorEvent e) {
default <V extends Serializable> Task<V> setError(ErrorEvent e) throws IOException {
Task<V> taskView = getTask(e.taskId);
if (taskView != null) {
logger.info("error event for {}", e.taskId);
Expand All @@ -112,8 +118,8 @@ default <V extends Serializable> Task<V> setError(ErrorEvent e) {
return taskView;
}

default Task<?> setCanceled(CancelledEvent e) {
Task<?> taskView = getTask(e.taskId);
default <V> Task<V> setCanceled(CancelledEvent e) throws IOException {
Task<V> taskView = getTask(e.taskId);
if (taskView != null) {
logger.info("canceled event for {}", e.taskId);
taskView.cancel();
Expand All @@ -131,35 +137,45 @@ default Task<?> setCanceled(CancelledEvent e) {
return taskView;
}

default Task<?> setProgress(ProgressEvent e) {
default <V> Task<V> setProgress(ProgressEvent e) throws IOException {
logger.debug("progress event for {}", e.taskId);
Task<?> taskView = getTask(e.taskId);
Task<V> taskView = getTask(e.taskId);
if (taskView != null) {
taskView.setProgress(e.progress);
save(taskView);
}
return taskView;
}

default <V extends Serializable> Task<?> handleAck(TaskEvent e) {
if (e instanceof CancelledEvent) {
return setCanceled((CancelledEvent) e);
}
if (e instanceof ResultEvent) {
return setResult(((ResultEvent<V>) e));
}
if (e instanceof ErrorEvent) {
return setError((ErrorEvent) e);
}
if (e instanceof ProgressEvent) {
return setProgress((ProgressEvent)e);
default <V extends Serializable> Task<V> handleAck(TaskEvent e) {
try {
if (e instanceof CancelledEvent ce) {
return setCanceled(ce);
}
if (e instanceof ResultEvent) {
return setResult((ResultEvent<V>) e);
}
if (e instanceof ErrorEvent ee) {
return setError(ee);
}
if (e instanceof ProgressEvent pe) {
return setProgress(pe);
}
logger.warn("received event not handled {}", e);
return null;
} catch (IOException ioe) {
throw new TaskEventHandlingException(ioe);
}
logger.warn("received event not handled {}", e);
return null;
}

default void foo(String taskId, StateLatch stateLatch) {
// for tests
default void setLatch(String taskId, StateLatch stateLatch) throws IOException {
getTask(taskId).setLatch(stateLatch);
}

class TaskEventHandlingException extends RuntimeException {
public TaskEventHandlingException(Exception cause) {
super(cause);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,9 @@ public <V> Task<V> clearTask(String taskId) {
}

@Override
public boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException {
try {
amqp.publish(AmqpQueue.WORKER_EVENT, new ShutdownEvent());
return true;
} catch (IOException e) {
throw new RuntimeException(e);
}
public boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException, IOException {
amqp.publish(AmqpQueue.WORKER_EVENT, new ShutdownEvent());
return true;
}

public boolean save(Task<?> task) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public TaskManagerRedis(RedissonClient redissonClient, String taskMapName, Routi
this.tasks = new RedissonMap<>(new TaskViewCodec(), commandSyncService, taskMapName, redissonClient, null, null);
this.eventTopic = redissonClient.getTopic(EVENT_CHANNEL_NAME);
this.eventCallback = eventCallback;
addEventListener(this::handleEvent);
eventTopic.addListener(TaskEvent.class, (channelString, message) -> handleEvent(message));
}

@Override
Expand Down Expand Up @@ -105,11 +105,7 @@ public boolean stopTask(String taskId) {
}

public void handleEvent(TaskEvent e) {
ofNullable(TaskManager.super.handleAck(e)).ifPresent(t -> ofNullable(eventCallback).ifPresent(Runnable::run));
}

public void addEventListener(Consumer<TaskEvent> callback) {
eventTopic.addListener(TaskEvent.class, (channelString, message) -> callback.accept(message));
ofNullable(TaskManager.super.handleAck(e)).flatMap(t -> ofNullable(eventCallback)).ifPresent(Runnable::run);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ public TaskInspector(TaskManager taskManager) {
this.taskManager = taskManager;
}

public boolean awaitToBeStarted(String taskId, int timeoutMs) throws InterruptedException {
public boolean awaitToBeStarted(String taskId, int timeoutMs) throws Exception {
return this.awaitStatus(taskId, Task.State.RUNNING, timeoutMs, TimeUnit.MILLISECONDS);
}

public boolean awaitStatus(String taskId, Task.State state, long timeout, TimeUnit unit) throws InterruptedException {
public boolean awaitStatus(String taskId, Task.State state, long timeout, TimeUnit unit) throws Exception {
StateLatch stateLatch = new StateLatch();
taskManager.foo(taskId, stateLatch);
taskManager.setLatch(taskId, stateLatch);
return stateLatch.await(state, timeout, unit);
}
}

0 comments on commit 4771cd5

Please sign in to comment.