Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rawls attempts to fetch "rawls" workspaces from WSM in addition to "MC" #2825

Merged
merged 6 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,6 @@ class AggregatedWorkspaceService(workspaceManagerDAO: WorkspaceManagerDAO) exten
}
}

/**
* Optimized version of [[fetchAggregatedWorkspace]]
*
* If the provided workspace is not of type "MC", returns the provided "rawls" workspace with no WSM information, as
* it can be assumed to be GCP and does not need to call out to WSM for this.
*/
def optimizedFetchAggregatedWorkspace(workspace: Workspace, ctx: RawlsRequestContext): AggregatedWorkspace =
workspace.workspaceType match {
case WorkspaceType.RawlsWorkspace =>
AggregatedWorkspace(workspace, Some(workspace.googleProjectId), azureCloudContext = None, policies = List.empty)
case WorkspaceType.McWorkspace =>
fetchAggregatedWorkspace(workspace, ctx)
}

private def aggregateMCWorkspaceWithWSMInfo(workspace: Workspace,
wsmInfo: WorkspaceDescription
): AggregatedWorkspace =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,24 @@ class WorkspaceService(protected val ctx: RawlsRequestContext,
getV2WorkspaceContextAndPermissions(workspaceName, SamWorkspaceActions.read, Option(attrSpecs)) flatMap {
workspaceContext =>
dataSource.inTransaction { dataAccess =>
val wsmContext = wsmService.optimizedFetchAggregatedWorkspace(workspaceContext, ctx)
// some GCP workspaces, like those with linked snapshots, have a stub WSM workspace
ashanhol marked this conversation as resolved.
Show resolved Hide resolved
val wsmContext =
try
wsmService.fetchAggregatedWorkspace(workspaceContext, ctx)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need a test for this specific call or does that exist somewhere else already? (I'd assume so but wanted to check)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im still trying to figure that out

catch {
case e: AggregateWorkspaceNotFoundException =>
// return workspace with no WSM information for gcp workspace
if (workspaceContext.workspaceType == WorkspaceType.RawlsWorkspace) {
AggregatedWorkspace(workspaceContext,
Some(workspaceContext.googleProjectId),
azureCloudContext = None,
policies = List.empty
)
} else {
// bubble up an MC workspace exception
throw e
ashanhol marked this conversation as resolved.
Show resolved Hide resolved
}
}
Comment on lines +315 to +331
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is correct and in good shape, and I would not block the PR on changes here. However, the use of try/catch is Java-idiomatic, where the use of Try/Success/Failure is Scala-idiomatic. To be a good Scala citizen, you'd have something like this:

Suggested change
val wsmContext =
try
wsmService.fetchAggregatedWorkspace(workspaceContext, ctx)
catch {
case e: AggregateWorkspaceNotFoundException =>
// return workspace with no WSM information for gcp workspace
if (workspaceContext.workspaceType == WorkspaceType.RawlsWorkspace) {
AggregatedWorkspace(workspaceContext,
Some(workspaceContext.googleProjectId),
azureCloudContext = None,
policies = List.empty
)
} else {
// bubble up an MC workspace exception
throw e
}
}
val wsmContext =
Try(wsmService.fetchAggregatedWorkspace(workspaceContext, ctx)) match {
case Success(found) => found
case Failure(notFound: AggregateWorkspaceNotFoundException)
if workspaceContext.workspaceType == WorkspaceType.RawlsWorkspace =>
// return workspace with no WSM information for gcp workspace
AggregatedWorkspace(workspaceContext,
Some(workspaceContext.googleProjectId),
azureCloudContext = None,
policies = List.empty
)
case Failure(x) => throw x // bubble up an MC workspace exception
}


// maximum access level is required to calculate canCompute and canShare. Therefore, if any of
// accessLevel, canCompute, canShare is specified, we have to get it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,24 +217,6 @@ class AggregatedWorkspaceServiceSpec extends AnyFlatSpec with MockitoTestUtils {
verify(wsmDao).getWorkspace(ArgumentMatchers.eq(legacyRawlsWorkspace.workspaceIdAsUUID), any[RawlsRequestContext])
}

behavior of "optimizedFetchAggregatedWorkspace"

it should "not reach out to WSM for legacy GCP Rawls workspaces" in {
val wsmDao = mock[WorkspaceManagerDAO]
when(wsmDao.getWorkspace(any[UUID], any[RawlsRequestContext])).thenReturn(
new WorkspaceDescription().stage(WorkspaceStageModel.RAWLS_WORKSPACE)
)
val svc = new AggregatedWorkspaceService(wsmDao)

val aggregatedWorkspace = svc.optimizedFetchAggregatedWorkspace(legacyRawlsWorkspace, defaultRequestContext)

aggregatedWorkspace.baseWorkspace shouldBe legacyRawlsWorkspace
aggregatedWorkspace.getCloudPlatform shouldBe Some(WorkspaceCloudPlatform.Gcp)
aggregatedWorkspace.azureCloudContext shouldBe None
aggregatedWorkspace.googleProjectId shouldBe Some(legacyRawlsWorkspace.googleProjectId)
verify(wsmDao, never()).getWorkspace(any[UUID], any[RawlsRequestContext])
}

behavior of "fetchAggregatedWorkspaces"

it should "match an MC workspace with the correct WSM information" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3042,6 +3042,28 @@ class WorkspaceServiceSpec
)
}

private def createGcpWorkspaceStub(services: TestApiService,
workspaceName: String,
policies: List[WsmPolicyInput] = List(),
workspaceService: WorkspaceService
): Workspace = {
val workspaceRequest = WorkspaceRequest(
testData.testProject1Name.value,
workspaceName,
Map.empty
)
when(services.workspaceManagerDAO.getWorkspace(any[UUID], any[RawlsRequestContext])).thenReturn(
new WorkspaceDescription()
.stage(WorkspaceStageModel.RAWLS_WORKSPACE)
.policies(policies.asJava)
)

Await.result(
services.mcWorkspaceService.createMultiCloudOrRawlsWorkspace(workspaceRequest, workspaceService),
Duration.Inf
)
}

it should "get the details of an Azure workspace" in withTestDataServices { services =>
val managedAppCoordinates = AzureManagedAppCoordinates(UUID.randomUUID(), UUID.randomUUID(), "fake_mrg_id")
val workspace = createAzureWorkspace(services, managedAppCoordinates)
Expand Down Expand Up @@ -3148,6 +3170,47 @@ class WorkspaceServiceSpec
additionalData.tail.head.getOrElse("pair2Key", "fail") shouldEqual "pair2Val"
}

it should "return the policies of a GCP workspace" in withTestDataServices { services =>
val workspaceName = s"rawls-test-workspace-${UUID.randomUUID().toString}"
val workspaceRequest = WorkspaceRequest(
testData.testProject1Name.value,
workspaceName,
Map.empty
)
val wsmPolicyInput = new WsmPolicyInput()
.name("test_name")
.namespace("test_namespace")
.additionalData(
List(
new WsmPolicyPair().value("pair1Val").key("pair1Key"),
new WsmPolicyPair().value("pair2Val").key("pair2Key")
).asJava
)
val workspace = createGcpWorkspaceStub(services, workspaceName, List(wsmPolicyInput), services.workspaceService)
val readWorkspace = Await.result(services.workspaceService.getWorkspace(
WorkspaceName(workspace.namespace, workspace.name),
WorkspaceFieldSpecs()
),
Duration.Inf
)

val response = readWorkspace.convertTo[WorkspaceResponse]

response.workspace.name shouldBe workspaceName
response.azureContext shouldEqual None
response.workspace.cloudPlatform shouldBe Some(WorkspaceCloudPlatform.Gcp)
response.policies should not be empty
val policies: List[WorkspacePolicy] = response.policies.get
policies should not be empty
val policy: WorkspacePolicy = policies.head
policy.name shouldBe wsmPolicyInput.getName
policy.namespace shouldBe wsmPolicyInput.getNamespace
val additionalData = policy.additionalData
additionalData.length shouldEqual 2
additionalData.head.getOrElse("pair1Key", "fail") shouldEqual "pair1Val"
additionalData.tail.head.getOrElse("pair2Key", "fail") shouldEqual "pair2Val"
Comment on lines +3197 to +3211
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also won't block the PR on any changes here; this code is fine. You could streamline a lot of the assertions by defining the expected result and checking for that, though:

Suggested change
val response = readWorkspace.convertTo[WorkspaceResponse]
response.workspace.name shouldBe workspaceName
response.azureContext shouldEqual None
response.workspace.cloudPlatform shouldBe Some(WorkspaceCloudPlatform.Gcp)
response.policies should not be empty
val policies: List[WorkspacePolicy] = response.policies.get
policies should not be empty
val policy: WorkspacePolicy = policies.head
policy.name shouldBe wsmPolicyInput.getName
policy.namespace shouldBe wsmPolicyInput.getNamespace
val additionalData = policy.additionalData
additionalData.length shouldEqual 2
additionalData.head.getOrElse("pair1Key", "fail") shouldEqual "pair1Val"
additionalData.tail.head.getOrElse("pair2Key", "fail") shouldEqual "pair2Val"
val response = readWorkspace.convertTo[WorkspaceResponse]
response.workspace.name shouldBe workspaceName
response.azureContext shouldEqual None
response.workspace.cloudPlatform shouldBe Some(WorkspaceCloudPlatform.Gcp)
val expectedPolicy: WorkspacePolicy = new WorkspacePolicy("test_name",
"test_namespace",
List(
Map("pair1Key" -> "pair1Val"),
Map("pair2Key" -> "pair2Val")
)
)
response.policies should contain(List(expectedPolicy))

}

it should "return correct canCompute permission for Azure workspaces" in withTestDataServices { services =>
val managedAppCoordinates = AzureManagedAppCoordinates(UUID.randomUUID(), UUID.randomUUID(), "fake_mrg_id")
val workspace = createAzureWorkspace(services, managedAppCoordinates)
Expand Down
Loading