diff --git a/d2x/auth/sf/auth_url.py b/d2x/auth/sf/auth_url.py index 9013460..7c2db4f 100644 --- a/d2x/auth/sf/auth_url.py +++ b/d2x/auth/sf/auth_url.py @@ -12,6 +12,7 @@ from rich.panel import Panel from rich.progress import Progress, SpinnerColumn, TextColumn from rich.table import Table +from simple_salesforce import Salesforce # Local imports from d2x.models.sf.auth import ( @@ -23,7 +24,7 @@ SfdxAuthUrlModel, ) from d2x.ux.gh.actions import summary as gha_summary, output as gha_output -from d2x.models.sf.org import SalesforceOrgInfo +from d2x.models.sf.org import SalesforceOrgInfo, ScratchOrg from d2x.base.types import CLIOptions from d2x.api.gh import ( set_environment_variable, @@ -62,7 +63,7 @@ def exchange_token(org_info: SalesforceOrgInfo, cli_options: CLIOptions): # Create debug info debug_info = TokenExchangeDebug( - url=f"https://{org_info.full_domain}{token_url_path}", + url=f"https://{org_info.org.full_domain}{token_url_path}", method="POST", headers=headers, request=token_request, @@ -71,8 +72,8 @@ def exchange_token(org_info: SalesforceOrgInfo, cli_options: CLIOptions): console.print(debug_info.to_table()) # Make request - progress.add_task(f"Connecting to {org_info.full_domain}...", total=None) - conn = http.client.HTTPSConnection(org_info.full_domain) + progress.add_task(f"Connecting to {org_info.org.full_domain}...", total=None) + conn = http.client.HTTPSConnection(org_info.org.full_domain) task = progress.add_task("Exchanging tokens...", total=None) conn.request("POST", token_url_path, body, headers) @@ -116,7 +117,7 @@ def exchange_token(org_info: SalesforceOrgInfo, cli_options: CLIOptions): # Display success success_panel = Panel( - f"[green]Successfully authenticated to {org_info.full_domain}\n" + f"[green]Successfully authenticated to {org_info.org.full_domain}\n" f"[blue]Token Details:[/]\n" f" Issued At: {token_response.issued_at.strftime('%Y-%m-%d %H:%M:%S')}\n" f" Expires At: {token_response.expires_at.strftime('%Y-%m-%d %H:%M:%S')}\n" @@ -148,9 +149,56 @@ def exchange_token(org_info: SalesforceOrgInfo, cli_options: CLIOptions): raise +def create_scratch_org(org_info: ScratchOrg, cli_options: CLIOptions): + """Create a scratch org using simple-salesforce""" + console = cli_options.console + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + try: + progress.add_task("Creating scratch org...", total=None) + + # Authenticate to Salesforce using simple-salesforce + sf = Salesforce( + username=org_info.auth_info.client_id, + password=org_info.auth_info.client_secret.get_secret_value(), + security_token=org_info.auth_info.refresh_token, + domain=org_info.full_domain, + ) + + # Create scratch org + result = sf.restful("sobjects/ScratchOrg", method="POST", json=org_info.dict()) + + # Display success + success_panel = Panel( + f"[green]Successfully created scratch org\n" + f"[blue]Org Details:[/]\n" + f" Org ID: {result['id']}\n" + f" Status: {result['status']}\n" + f" Expiration Date: {result['expirationDate']}", + title="[green]Scratch Org Creation Success", + border_style="green", + ) + console.print(success_panel) + + return result + + except Exception as e: + error_panel = Panel( + f"[red]Error: {str(e)}", + title="[red]Scratch Org Creation Failed", + border_style="red", + ) + console.print(error_panel) + raise + + def get_full_domain(org_info: SalesforceOrgInfo) -> str: """Construct the full domain from SalesforceOrgInfo.""" - return org_info.full_domain.rstrip("/") + return org_info.org.full_domain.rstrip("/") def main(cli_options: CLIOptions): diff --git a/d2x/cli/main.py b/d2x/cli/main.py index 2e34dae..3cf7365 100644 --- a/d2x/cli/main.py +++ b/d2x/cli/main.py @@ -7,6 +7,8 @@ from typing import Optional from importlib.metadata import version, PackageNotFoundError from d2x.env.gh import set_environment_variable, get_environment_variable, set_environment_secret, get_environment_secret +from d2x.auth.sf.auth_url import create_scratch_org, exchange_token +from d2x.models.sf.org import SalesforceOrgInfo, ScratchOrg # Disable rich_click's syntax highlighting click.SHOW_ARGUMENTS = False @@ -80,6 +82,68 @@ def url(output_format: OutputFormatType, debug: bool): raise +@sf.group() +def org(): + """Salesforce org commands""" + pass + + +@org.command() +@common_options +def create(output_format: OutputFormatType, debug: bool): + """Create a Salesforce scratch org""" + cli_options = CLIOptions(output_format=output_format, debug=debug) + try: + # Assuming org_info is obtained from somewhere, e.g., environment variable or input + org_info = ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="example.my.salesforce.com", + auth_info=AuthInfo( + client_id="your_client_id", + client_secret="your_client_secret", + refresh_token="your_refresh_token", + instance_url="https://example.my.salesforce.com", + ), + ) + create_scratch_org(org_info, cli_options) + except: + if debug: + type, value, tb = sys.exc_info() + pdb.post_mortem(tb) + else: + raise + + +@org.command() +@common_options +def exchange(output_format: OutputFormatType, debug: bool): + """Exchange auth code for an OAuth grant""" + cli_options = CLIOptions(output_format=output_format, debug=debug) + try: + # Assuming org_info is obtained from somewhere, e.g., environment variable or input + org_info = SalesforceOrgInfo( + auth_info=AuthInfo( + client_id="your_client_id", + client_secret="your_client_secret", + refresh_token="your_refresh_token", + instance_url="https://example.my.salesforce.com", + ), + org=ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="example.my.salesforce.com", + ), + ) + exchange_token(org_info, cli_options) + except: + if debug: + type, value, tb = sys.exc_info() + pdb.post_mortem(tb) + else: + raise + + @d2x_cli.group() def env(): """Environment commands""" diff --git a/d2x/models/sf/org.py b/d2x/models/sf/org.py index 8553717..8ca274a 100644 --- a/d2x/models/sf/org.py +++ b/d2x/models/sf/org.py @@ -1,5 +1,5 @@ -from typing import Optional, Literal -from pydantic import Field +from typing import Optional, Literal, Union +from pydantic import Field, BaseModel from d2x.base.models import CommonBaseModel from d2x.models.sf.auth import AuthInfo, DomainType, OrgType @@ -23,12 +23,9 @@ PodType = Literal["cs", "db", None] -class SalesforceOrgInfo(CommonBaseModel): - """Structured information about a Salesforce org.""" +class BaseSalesforceOrg(BaseModel): + """Base model for Salesforce orgs with a pydantic discriminator.""" - auth_info: AuthInfo = Field( - ..., description="Authentication information for the Salesforce org." - ) org_type: OrgType = Field(..., description="Type of the Salesforce org.") domain_type: DomainType = Field( ..., description="Type of domain for the Salesforce org." @@ -44,10 +41,79 @@ class SalesforceOrgInfo(CommonBaseModel): pod_number: Optional[str] = Field(None, description="Pod number if applicable.") pod_type: Optional[PodType] = Field(None, description="Pod type if applicable.") + class Config: + use_enum_values = True + discriminator = "org_type" + + +class ProductionOrg(BaseSalesforceOrg): + """Model for Production Salesforce orgs.""" + + org_type: Literal[OrgType.PRODUCTION] = Field( + default=OrgType.PRODUCTION, description="Type of the Salesforce org." + ) + instance_url: str = Field(..., description="Instance URL of the production org.") + created_date: str = Field(..., description="Creation date of the production org.") + last_modified_date: str = Field(..., description="Last modified date of the production org.") + status: str = Field(..., description="Status of the production org.") + + +class TrialOrg(BaseSalesforceOrg): + """Model for Trial Salesforce orgs.""" + + org_type: Literal[OrgType.DEMO] = Field( + default=OrgType.DEMO, description="Type of the Salesforce org." + ) + expiration_date: Optional[str] = Field( + None, description="Expiration date of the trial org." + ) + instance_url: str = Field(..., description="Instance URL of the trial org.") + created_date: str = Field(..., description="Creation date of the trial org.") + last_modified_date: str = Field(..., description="Last modified date of the trial org.") + status: str = Field(..., description="Status of the trial org.") + + +class SandboxOrg(BaseSalesforceOrg): + """Model for Sandbox Salesforce orgs.""" + + org_type: Literal[OrgType.SANDBOX] = Field( + default=OrgType.SANDBOX, description="Type of the Salesforce org." + ) + instance_url: str = Field(..., description="Instance URL of the sandbox org.") + created_date: str = Field(..., description="Creation date of the sandbox org.") + last_modified_date: str = Field(..., description="Last modified date of the sandbox org.") + status: str = Field(..., description="Status of the sandbox org.") + + +class ScratchOrg(BaseSalesforceOrg): + """Model for Scratch Salesforce orgs.""" + + org_type: Literal[OrgType.SCRATCH] = Field( + default=OrgType.SCRATCH, description="Type of the Salesforce org." + ) + expiration_date: Optional[str] = Field( + None, description="Expiration date of the scratch org." + ) + instance_url: str = Field(..., description="Instance URL of the scratch org.") + created_date: str = Field(..., description="Creation date of the scratch org.") + last_modified_date: str = Field(..., description="Last modified date of the scratch org.") + status: str = Field(..., description="Status of the scratch org.") + + +class SalesforceOrgInfo(CommonBaseModel): + """Structured information about a Salesforce org.""" + + auth_info: AuthInfo = Field( + ..., description="Authentication information for the Salesforce org." + ) + org: Union[ProductionOrg, TrialOrg, SandboxOrg, ScratchOrg] = Field( + ..., description="Salesforce org details." + ) + @property def is_classic_pod(self) -> bool: """Determine if the pod is a classic pod.""" - return self.pod_type in ["cs", "db"] + return self.org.pod_type in ["cs", "db"] @property def is_hyperforce(self) -> bool: @@ -57,4 +123,4 @@ def is_hyperforce(self) -> bool: @property def is_sandbox(self) -> bool: """Determine if the org is a sandbox.""" - return self.org_type == OrgType.SANDBOX + return self.org.org_type == OrgType.SANDBOX diff --git a/tests/test_auth_url.py b/tests/test_auth_url.py index a925713..4e6eceb 100644 --- a/tests/test_auth_url.py +++ b/tests/test_auth_url.py @@ -2,8 +2,8 @@ import unittest from unittest.mock import patch, MagicMock from pydantic import SecretStr -from d2x.auth.sf.auth_url import exchange_token -from d2x.models.sf.org import SalesforceOrgInfo +from d2x.auth.sf.auth_url import exchange_token, create_scratch_org +from d2x.models.sf.org import SalesforceOrgInfo, ScratchOrg from d2x.base.types import CLIOptions from d2x.models.sf.auth import AuthInfo @@ -20,9 +20,15 @@ def test_exchange_token_success(self, mock_https_connection, mock_set_env_var): refresh_token="test_refresh_token", instance_url="https://test.salesforce.com", ), - org_type="production", - domain_type="pod", - full_domain="test.salesforce.com", + org=ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="test.salesforce.com", + instance_url="https://test.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ), ) # Mock the CLIOptions @@ -69,9 +75,15 @@ def test_exchange_token_failure(self, mock_https_connection, mock_set_env_var): refresh_token="test_refresh_token", instance_url="https://test.salesforce.com", ), - org_type="production", - domain_type="pod", - full_domain="test.salesforce.com", + org=ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="test.salesforce.com", + instance_url="https://test.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ), ) # Mock the CLIOptions @@ -93,5 +105,77 @@ def test_exchange_token_failure(self, mock_https_connection, mock_set_env_var): exchange_token(org_info, cli_options) +class TestCreateScratchOrg(unittest.TestCase): + @patch("d2x.auth.sf.auth_url.Salesforce") + def test_create_scratch_org_success(self, mock_salesforce): + # Mock the ScratchOrg + org_info = ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="test.salesforce.com", + instance_url="https://test.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + auth_info=AuthInfo( + client_id="test_client_id", + client_secret=SecretStr("test_client_secret"), + refresh_token="test_refresh_token", + instance_url="https://test.salesforce.com", + ), + ) + + # Mock the CLIOptions + cli_options = CLIOptions(output_format="text", debug=False) + + # Mock the Salesforce instance and response + mock_sf_instance = MagicMock() + mock_salesforce.return_value = mock_sf_instance + mock_sf_instance.restful.return_value = { + "id": "test_org_id", + "status": "Active", + "expirationDate": "2021-01-10", + } + + # Call the function + result = create_scratch_org(org_info, cli_options) + + # Assertions + self.assertEqual(result["id"], "test_org_id") + self.assertEqual(result["status"], "Active") + self.assertEqual(result["expirationDate"], "2021-01-10") + + @patch("d2x.auth.sf.auth_url.Salesforce") + def test_create_scratch_org_failure(self, mock_salesforce): + # Mock the ScratchOrg + org_info = ScratchOrg( + org_type="scratch", + domain_type="my", + full_domain="test.salesforce.com", + instance_url="https://test.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + auth_info=AuthInfo( + client_id="test_client_id", + client_secret=SecretStr("test_client_secret"), + refresh_token="test_refresh_token", + instance_url="https://test.salesforce.com", + ), + ) + + # Mock the CLIOptions + cli_options = CLIOptions(output_format="text", debug=False) + + # Mock the Salesforce instance and response + mock_sf_instance = MagicMock() + mock_salesforce.return_value = mock_sf_instance + mock_sf_instance.restful.side_effect = Exception("Failed to create scratch org") + + # Call the function and assert exception + with self.assertRaises(Exception): + create_scratch_org(org_info, cli_options) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..6a6ff2d --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import patch, MagicMock +from click.testing import CliRunner +from d2x.cli.main import d2x_cli +from d2x.models.sf.org import ScratchOrg +from d2x.models.sf.auth import AuthInfo +from d2x.base.types import CLIOptions + + +class TestCLICommands(unittest.TestCase): + @patch("d2x.cli.main.create_scratch_org") + def test_create_scratch_org_command(self, mock_create_scratch_org): + runner = CliRunner() + result = runner.invoke(d2x_cli, ["sf", "org", "create"]) + + self.assertEqual(result.exit_code, 0) + mock_create_scratch_org.assert_called_once() + + @patch("d2x.cli.main.exchange_token") + def test_exchange_token_command(self, mock_exchange_token): + runner = CliRunner() + result = runner.invoke(d2x_cli, ["sf", "org", "exchange"]) + + self.assertEqual(result.exit_code, 0) + mock_exchange_token.assert_called_once() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_org_models.py b/tests/test_org_models.py new file mode 100644 index 0000000..9b05a7f --- /dev/null +++ b/tests/test_org_models.py @@ -0,0 +1,101 @@ +import unittest +from pydantic import ValidationError +from d2x.models.sf.org import ( + BaseSalesforceOrg, + ProductionOrg, + TrialOrg, + SandboxOrg, + ScratchOrg, + OrgType, + DomainType, +) + + +class TestBaseSalesforceOrg(unittest.TestCase): + def test_production_org(self): + org = ProductionOrg( + org_type=OrgType.PRODUCTION, + domain_type=DomainType.POD, + full_domain="example.salesforce.com", + instance_url="https://example.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ) + self.assertEqual(org.org_type, OrgType.PRODUCTION) + self.assertEqual(org.domain_type, DomainType.POD) + self.assertEqual(org.full_domain, "example.salesforce.com") + self.assertEqual(org.instance_url, "https://example.salesforce.com") + self.assertEqual(org.created_date, "2021-01-01") + self.assertEqual(org.last_modified_date, "2021-01-02") + self.assertEqual(org.status, "Active") + + def test_trial_org(self): + org = TrialOrg( + org_type=OrgType.DEMO, + domain_type=DomainType.POD, + full_domain="example.salesforce.com", + instance_url="https://example.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ) + self.assertEqual(org.org_type, OrgType.DEMO) + self.assertEqual(org.domain_type, DomainType.POD) + self.assertEqual(org.full_domain, "example.salesforce.com") + self.assertEqual(org.instance_url, "https://example.salesforce.com") + self.assertEqual(org.created_date, "2021-01-01") + self.assertEqual(org.last_modified_date, "2021-01-02") + self.assertEqual(org.status, "Active") + + def test_sandbox_org(self): + org = SandboxOrg( + org_type=OrgType.SANDBOX, + domain_type=DomainType.POD, + full_domain="example.salesforce.com", + instance_url="https://example.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ) + self.assertEqual(org.org_type, OrgType.SANDBOX) + self.assertEqual(org.domain_type, DomainType.POD) + self.assertEqual(org.full_domain, "example.salesforce.com") + self.assertEqual(org.instance_url, "https://example.salesforce.com") + self.assertEqual(org.created_date, "2021-01-01") + self.assertEqual(org.last_modified_date, "2021-01-02") + self.assertEqual(org.status, "Active") + + def test_scratch_org(self): + org = ScratchOrg( + org_type=OrgType.SCRATCH, + domain_type=DomainType.POD, + full_domain="example.salesforce.com", + instance_url="https://example.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ) + self.assertEqual(org.org_type, OrgType.SCRATCH) + self.assertEqual(org.domain_type, DomainType.POD) + self.assertEqual(org.full_domain, "example.salesforce.com") + self.assertEqual(org.instance_url, "https://example.salesforce.com") + self.assertEqual(org.created_date, "2021-01-01") + self.assertEqual(org.last_modified_date, "2021-01-02") + self.assertEqual(org.status, "Active") + + def test_invalid_org_type(self): + with self.assertRaises(ValidationError): + ProductionOrg( + org_type="invalid_org_type", + domain_type=DomainType.POD, + full_domain="example.salesforce.com", + instance_url="https://example.salesforce.com", + created_date="2021-01-01", + last_modified_date="2021-01-02", + status="Active", + ) + + +if __name__ == "__main__": + unittest.main()