Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix static type checker errors #8

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/bike/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

variables["wheelDiameter"].expression = "300 mm"
variables["wheelThickness"].expression = "71 mm"
variables["forkAngle"].expression = "30 deg"
variables["forkAngle"].expression = "20 deg"

client.set_variables(doc.did, doc.wid, elements["variables"].id, variables)
assembly = client.get_assembly(doc.did, doc.wtype, doc.wid, elements["assembly"].id)
assembly, _ = client.get_assembly(doc.did, doc.wtype, doc.wid, elements["assembly"].id)

occurences = get_occurences(assembly)
instances = get_instances(assembly)
Expand Down
75 changes: 52 additions & 23 deletions onshape_api/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import secrets
import string
from enum import Enum
from typing import BinaryIO
from typing import Any, BinaryIO, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse

import requests
Expand All @@ -32,7 +32,7 @@ class HTTP(str, Enum):
DELETE = "delete"


def load_env_variables(env):
def load_env_variables(env: str) -> tuple[str, str]:
"""
Load environment variables from the specified .env file.

Expand Down Expand Up @@ -61,7 +61,7 @@ def load_env_variables(env):
return access_key, secret_key


def make_nonce():
def make_nonce() -> str:
"""
Generate a unique ID for the request, 25 chars in length

Expand All @@ -85,7 +85,7 @@ class Client:
- logging (bool, default=True): Turn logging on or off
"""

def __init__(self, env="./.env", log_file="./onshape.log", log_level=1):
def __init__(self, env: str = "./.env", log_file: str = "./onshape.log", log_level: int = 1) -> None:
"""
Instantiates an instance of the Onshape class. Reads credentials from a .env file.

Expand All @@ -96,13 +96,13 @@ def __init__(self, env="./.env", log_file="./onshape.log", log_level=1):
- env (str, default='./.env'): Environment file location
"""

self._url = BASE_URL
self._url: str = BASE_URL
self._access_key, self._secret_key = load_env_variables(env)
LOGGER.set_file_name(log_file)
LOGGER.set_stream_level(LOG_LEVEL[log_level])
LOGGER.info(f"Onshape API initialized with env file: {env}")

def get_document(self, did):
def get_document(self, did: str) -> DocumentMetaData:
"""
Get details for a specified document.

Expand All @@ -116,7 +116,7 @@ def get_document(self, did):

return DocumentMetaData.model_validate(_request_json)

def get_elements(self, did, wtype, wid):
def get_elements(self, did: str, wtype: str, wid: str) -> dict[str, Element]:
"""
Get list of elements in a document.

Expand All @@ -138,7 +138,7 @@ def get_elements(self, did, wtype, wid):

return {element["name"]: Element.model_validate(element) for element in _elements_json}

def get_features_from_partstudio(self, did, wid, eid):
def get_features_from_partstudio(self, did: str, wid: str, eid: str) -> requests.Response:
"""
Gets the feature list for specified document / workspace / part studio.

Expand All @@ -156,7 +156,7 @@ def get_features_from_partstudio(self, did, wid, eid):
"/api/partstudios/d/" + did + "/w/" + wid + "/e/" + eid + "/features",
)

def get_features_from_assembly(self, did, wtype, wid, eid):
def get_features_from_assembly(self, did: str, wtype: str, wid: str, eid: str) -> dict[str, Any]:
"""
Gets the feature list for specified document / workspace / part studio.

Expand All @@ -173,7 +173,7 @@ def get_features_from_assembly(self, did, wtype, wid, eid):
"get", "/api/assemblies/d/" + did + "/" + wtype + "/" + wid + "/e/" + eid + "/features"
).json()

def get_variables(self, did, wid, eid):
def get_variables(self, did: str, wid: str, eid: str) -> dict[str, Variable]:
"""
Get list of variables in a variable studio.

Expand All @@ -194,7 +194,7 @@ def get_variables(self, did, wid, eid):

return {variable["name"]: Variable.model_validate(variable) for variable in _variables_json[0]["variables"]}

def set_variables(self, did, wid, eid, variables):
def set_variables(self, did: str, wid: str, eid: str, variables: dict[str, Variable]) -> requests.Response:
"""
Set variables in a variable studio.

Expand All @@ -219,7 +219,7 @@ def set_variables(self, did, wid, eid, variables):
body=payload,
)

def create_assembly(self, did, wid, name="My Assembly"):
def create_assembly(self, did: str, wid: str, name: str = "My Assembly") -> requests.Response:
"""
Creates a new assembly element in the specified document / workspace.

Expand All @@ -236,7 +236,9 @@ def create_assembly(self, did, wid, name="My Assembly"):

return self.request(HTTP.POST, "/api/assemblies/d/" + did + "/w/" + wid, body=payload)

def get_assembly(self, did, wtype, wid, eid, configuration="default"):
def get_assembly(
self, did: str, wtype: str, wid: str, eid: str, configuration: str = "default"
) -> tuple[Assembly, dict[str, Any]]:
_request_path = "/api/assemblies/d/" + did + "/" + wtype + "/" + wid + "/e/" + eid
_assembly_json = self.request(
HTTP.GET,
Expand All @@ -251,10 +253,10 @@ def get_assembly(self, did, wtype, wid, eid, configuration="default"):

return Assembly.model_validate(_assembly_json), _assembly_json

def get_parts(self, did, wid, eid):
def get_parts(self, did: str, wid: str, eid: str) -> None:
pass

def download_stl(self, did, wid, eid, partID, buffer: BinaryIO):
def download_stl(self, did: str, wid: str, eid: str, partID: str, buffer: BinaryIO) -> None:
"""
Exports STL export from a part studio and saves it to a file.

Expand Down Expand Up @@ -288,7 +290,7 @@ def download_stl(self, did, wid, eid, partID, buffer: BinaryIO):
else:
LOGGER.info(f"Failed to download STL file: {response.status_code} - {response.text}")

def get_mass_properties(self, did, wid, eid, partID):
def get_mass_properties(self, did: str, wid: str, eid: str, partID: str) -> MassModel:
"""
Get mass properties for a part in a part studio.

Expand All @@ -306,7 +308,16 @@ def get_mass_properties(self, did, wid, eid, partID):

return MassModel.model_validate(_resonse_json["bodies"][partID])

def request(self, method, path, query=None, headers=None, body=None, base_url=None, log_response=True):
def request(
self,
method: Union[HTTP, str],
path: str,
query: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
body: Optional[dict[str, Any]] = None,
base_url: Optional[str] = None,
log_response: bool = True,
) -> requests.Response:
"""
Issues a request to Onshape

Expand Down Expand Up @@ -347,10 +358,12 @@ def request(self, method, path, query=None, headers=None, body=None, base_url=No

return res

def _build_url(self, base_url, path, query):
def _build_url(self, base_url: str, path: str, query: dict[str, Any]) -> str:
return base_url + path + "?" + urlencode(query)

def _send_request(self, method, url, headers, body):
def _send_request(
self, method: Union[HTTP, str], url: str, headers: dict[str, str], body: Optional[dict[str, Any]]
) -> requests.Response:
return requests.request(
method,
url,
Expand All @@ -361,7 +374,9 @@ def _send_request(self, method, url, headers, body):
timeout=10, # Specify an appropriate timeout value in seconds
)

def _handle_redirect(self, res, method, headers, log_response=True):
def _handle_redirect(
self, res: requests.Response, method: Union[HTTP, str], headers: dict[str, str], log_response: bool = True
) -> requests.Response:
location = urlparse(res.headers["Location"])
querystring = parse_qs(location.query)

Expand All @@ -374,13 +389,21 @@ def _handle_redirect(self, res, method, headers, log_response=True):
method, location.path, query=new_query, headers=headers, base_url=new_base_url, log_response=log_response
)

def _log_response(self, res):
def _log_response(self, res: requests.Response) -> None:
if not 200 <= res.status_code <= 206:
LOGGER.debug(f"Request failed, details: {res.text}")
else:
LOGGER.debug(f"Request succeeded, details: {res.text}")

def _make_auth(self, method, date, nonce, path, query=None, ctype="application/json"):
def _make_auth(
self,
method: Union[HTTP, str],
date: str,
nonce: str,
path: str,
query: Optional[dict[str, Any]] = None,
ctype: str = "application/json",
) -> str:
"""
Create the request signature to authenticate

Expand Down Expand Up @@ -412,7 +435,13 @@ def _make_auth(self, method, date, nonce, path, query=None, ctype="application/j

return auth

def _make_headers(self, method, path, query=None, headers=None):
def _make_headers(
self,
method: Union[HTTP, str],
path: str,
query: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, str]:
"""
Creates a headers object to sign the request

Expand Down
4 changes: 2 additions & 2 deletions onshape_api/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
AUTOMATE_ASSEMBLYID_PATTERN = r"(?P<documentId>\w{24})_(?P<documentMicroversion>\w{24})_(?P<elementId>\w{24})"


def extract_ids(assembly_id):
def extract_ids(assembly_id: str) -> dict[str, str | None]:
match = re.match(AUTOMATE_ASSEMBLYID_PATTERN, assembly_id)
if match:
return match.groupdict()
else:
return {"documentId": None, "documentMicroversion": None, "elementId": None}


def get_assembly_df(automate_assembly_df):
def get_assembly_df(automate_assembly_df: pd.DataFrame) -> pd.DataFrame:
assembly_df = automate_assembly_df["assemblyId"].apply(extract_ids).apply(pd.Series)
return assembly_df

Expand Down
42 changes: 27 additions & 15 deletions onshape_api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import networkx as nx
import numpy as np
import stl
from networkx import DiGraph, Graph

from onshape_api.connect import Client
from onshape_api.log import LOGGER
Expand Down Expand Up @@ -51,12 +52,17 @@ def generate_names(max_length: int) -> list[str]:
return random.sample(words, max_length)


def show_graph(graph: nx.Graph):
def get_random_color() -> tuple[float, float, float, float]:
_color: COLORS = random.SystemRandom().choice(list(COLORS)) # nosec
return tuple(_color.value)


def show_graph(graph: Graph) -> None:
nx.draw_circular(graph, with_labels=True)
plt.show()


def convert_to_digraph(graph: nx.Graph) -> nx.DiGraph:
def convert_to_digraph(graph: Graph) -> tuple[DiGraph, str]:
_centrality = nx.closeness_centrality(graph)
_root_node = max(_centrality, key=_centrality.get)
_graph = nx.bfs_tree(graph, _root_node)
Expand All @@ -69,8 +75,8 @@ def create_graph(
parts: dict[str, Part],
mates: dict[str, MateFeatureData],
directed: bool = True,
):
graph = nx.Graph()
) -> Union[Graph, tuple[DiGraph, str]]:
graph: Graph = Graph()

for occurence in occurences:
if instances[occurence].type == InstanceType.PART:
Expand All @@ -90,14 +96,16 @@ def create_graph(
LOGGER.warning(f"Mate {mate} not found")

if directed:
graph = convert_to_digraph(graph)
graph, root_node = convert_to_digraph(graph)

LOGGER.info(f"Graph created with {len(graph.nodes)} nodes and {len(graph.edges)} edges")

return graph


def download_stl_mesh(did, wid, eid, partID, client: Client, transform: np.ndarray, file_name: str) -> str:
def download_stl_mesh(
did: str, wid: str, eid: str, partID: str, client: Client, transform: np.ndarray, file_name: str
) -> str:
try:
with io.BytesIO() as buffer:
LOGGER.info(f"Downloading mesh for {file_name}...")
Expand Down Expand Up @@ -128,7 +136,7 @@ def get_robot_link(
workspaceId: str,
client: Client,
mate: Optional[Union[MateFeatureData, None]] = None,
):
) -> tuple[Link, np.matrix]:
LOGGER.info(f"Creating robot link for {name}")

if mate is None:
Expand Down Expand Up @@ -159,7 +167,7 @@ def get_robot_link(
visual=VisualLink(
origin=_origin,
geometry=MeshGeometry(_mesh_path),
material=Material.from_color(name=f"{name}_material", color=random.SystemRandom().choice(list(COLORS))),
material=Material.from_color(name=f"{name}_material", color=get_random_color()),
),
inertial=InertialLink(
origin=Origin(
Expand Down Expand Up @@ -190,7 +198,7 @@ def get_robot_joint(
child: str,
mate: MateFeatureData,
stl_to_parent_tf: np.matrix,
):
) -> Union[RevoluteJoint, FixedJoint]:
LOGGER.info(f"Creating robot joint from {parent} to {child}")

parent_to_mate_tf = mate.matedEntities[1].matedCS.part_to_mate_tf
Expand Down Expand Up @@ -226,22 +234,22 @@ def get_robot_joint(


def get_urdf_components(
graph: Union[nx.Graph, nx.DiGraph],
graph: Union[Graph, DiGraph],
workspaceId: str,
parts: dict[str, Part],
mass_properties: dict[str, MassModel],
mates: dict[str, MateFeatureData],
client: Client,
):
if not isinstance(graph, nx.DiGraph):
) -> tuple[list[Link], list[Union[RevoluteJoint, FixedJoint]]]:
if not isinstance(graph, DiGraph):
graph, root_node = convert_to_digraph(graph)

joints = []
links = []
joints: list[Union[RevoluteJoint, FixedJoint]] = []
links: list[Link] = []

_readable_names = generate_names(len(graph.nodes))
_readable_names_mapping = dict(zip(graph.nodes, _readable_names))
_stl_to_link_tf_mapping = {}
_stl_to_link_tf_mapping: dict[str, np.matrix] = {}

LOGGER.info(f"Processing root node: {_readable_names_mapping[root_node]}")

Expand Down Expand Up @@ -285,3 +293,7 @@ def get_urdf_components(
links.append(_link)

return links, joints


if __name__ == "__main__":
print(get_random_color())
Loading
Loading