From cd1b8317fe02ffb1c5a165e7a8b26b4761e1851c Mon Sep 17 00:00:00 2001 From: Senthur Ayyappan Date: Sun, 3 Nov 2024 19:44:39 -0500 Subject: [PATCH 1/4] Fix static type checker errors Fixes #7 Fix static type checker errors and add type annotations across multiple files. * **onshape_api/connect.py** - Add type annotations to all functions and methods. - Fix type errors in `Client` class methods. - Ensure all imports are correctly typed. * **onshape_api/data/preprocess.py** - Add type annotations to all functions. - Fix type errors in `extract_ids` and `get_assembly_df` functions. * **onshape_api/graph.py** - Add type annotations to all functions. - Fix type inconsistencies in `create_graph` and `get_robot_link` functions. * **onshape_api/log.py** - Add type annotations to all methods. - Fix type inconsistencies in `Logger` class methods. * **onshape_api/models/assembly.py** - Add type annotations to all classes and methods. - Fix type issues in `PartInstance` and `AssemblyInstance` classes. * **onshape_api/models/document.py** - Add type annotations to all classes and methods. - Fix type issues in `Document` and `DocumentMetaData` classes. * **onshape_api/models/element.py** - Add type annotations to all classes and methods. - Fix type issues in `Element` class. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/imsenthur/onshape-api/issues/7?shareId=XXXX-XXXX-XXXX-XXXX). --- onshape_api/connect.py | 75 +++++++++++++++++++++++----------- onshape_api/data/preprocess.py | 9 +--- onshape_api/graph.py | 40 +++++++++--------- onshape_api/log.py | 32 +++++++-------- onshape_api/models/assembly.py | 20 ++++----- onshape_api/models/document.py | 8 ++-- onshape_api/models/element.py | 9 ++-- onshape_api/models/mass.py | 6 +-- onshape_api/models/variable.py | 12 +++--- 9 files changed, 118 insertions(+), 93 deletions(-) diff --git a/onshape_api/connect.py b/onshape_api/connect.py index 47bdcc7..751bd78 100644 --- a/onshape_api/connect.py +++ b/onshape_api/connect.py @@ -6,7 +6,7 @@ import secrets import string from enum import Enum -from typing import BinaryIO +from typing import Any, BinaryIO, Optional, Union from urllib.parse import parse_qs, urlencode, urlparse import requests @@ -32,7 +32,7 @@ class HTTP(str, Enum): DELETE = "delete" -def load_env_variables(env): +def load_env_variables(env: str) -> tuple[str, str]: """ Load environment variables from the specified .env file. @@ -61,7 +61,7 @@ def load_env_variables(env): return access_key, secret_key -def make_nonce(): +def make_nonce() -> str: """ Generate a unique ID for the request, 25 chars in length @@ -85,7 +85,7 @@ class Client: - logging (bool, default=True): Turn logging on or off """ - def __init__(self, env="./.env", log_file="./onshape.log", log_level=1): + def __init__(self, env: str = "./.env", log_file: str = "./onshape.log", log_level: int = 1) -> None: """ Instantiates an instance of the Onshape class. Reads credentials from a .env file. @@ -96,13 +96,13 @@ def __init__(self, env="./.env", log_file="./onshape.log", log_level=1): - env (str, default='./.env'): Environment file location """ - self._url = BASE_URL + self._url: str = BASE_URL self._access_key, self._secret_key = load_env_variables(env) LOGGER.set_file_name(log_file) LOGGER.set_stream_level(LOG_LEVEL[log_level]) LOGGER.info(f"Onshape API initialized with env file: {env}") - def get_document(self, did): + def get_document(self, did: str) -> DocumentMetaData: """ Get details for a specified document. 
@@ -116,7 +116,7 @@ def get_document(self, did): return DocumentMetaData.model_validate(_request_json) - def get_elements(self, did, wtype, wid): + def get_elements(self, did: str, wtype: str, wid: str) -> dict[str, Element]: """ Get list of elements in a document. @@ -138,7 +138,7 @@ def get_elements(self, did, wtype, wid): return {element["name"]: Element.model_validate(element) for element in _elements_json} - def get_features_from_partstudio(self, did, wid, eid): + def get_features_from_partstudio(self, did: str, wid: str, eid: str) -> requests.Response: """ Gets the feature list for specified document / workspace / part studio. @@ -156,7 +156,7 @@ def get_features_from_partstudio(self, did, wid, eid): "/api/partstudios/d/" + did + "/w/" + wid + "/e/" + eid + "/features", ) - def get_features_from_assembly(self, did, wtype, wid, eid): + def get_features_from_assembly(self, did: str, wtype: str, wid: str, eid: str) -> dict[str, Any]: """ Gets the feature list for specified document / workspace / part studio. @@ -173,7 +173,7 @@ def get_features_from_assembly(self, did, wtype, wid, eid): "get", "/api/assemblies/d/" + did + "/" + wtype + "/" + wid + "/e/" + eid + "/features" ).json() - def get_variables(self, did, wid, eid): + def get_variables(self, did: str, wid: str, eid: str) -> dict[str, Variable]: """ Get list of variables in a variable studio. @@ -194,7 +194,7 @@ def get_variables(self, did, wid, eid): return {variable["name"]: Variable.model_validate(variable) for variable in _variables_json[0]["variables"]} - def set_variables(self, did, wid, eid, variables): + def set_variables(self, did: str, wid: str, eid: str, variables: dict[str, Variable]) -> requests.Response: """ Set variables in a variable studio. @@ -219,7 +219,7 @@ def set_variables(self, did, wid, eid, variables): body=payload, ) - def create_assembly(self, did, wid, name="My Assembly"): + def create_assembly(self, did: str, wid: str, name: str = "My Assembly") -> requests.Response: """ Creates a new assembly element in the specified document / workspace. @@ -236,7 +236,9 @@ def create_assembly(self, did, wid, name="My Assembly"): return self.request(HTTP.POST, "/api/assemblies/d/" + did + "/w/" + wid, body=payload) - def get_assembly(self, did, wtype, wid, eid, configuration="default"): + def get_assembly( + self, did: str, wtype: str, wid: str, eid: str, configuration: str = "default" + ) -> tuple[Assembly, dict[str, Any]]: _request_path = "/api/assemblies/d/" + did + "/" + wtype + "/" + wid + "/e/" + eid _assembly_json = self.request( HTTP.GET, @@ -251,10 +253,10 @@ def get_assembly(self, did, wtype, wid, eid, configuration="default"): return Assembly.model_validate(_assembly_json), _assembly_json - def get_parts(self, did, wid, eid): + def get_parts(self, did: str, wid: str, eid: str) -> None: pass - def download_stl(self, did, wid, eid, partID, buffer: BinaryIO): + def download_stl(self, did: str, wid: str, eid: str, partID: str, buffer: BinaryIO) -> None: """ Exports STL export from a part studio and saves it to a file. @@ -288,7 +290,7 @@ def download_stl(self, did, wid, eid, partID, buffer: BinaryIO): else: LOGGER.info(f"Failed to download STL file: {response.status_code} - {response.text}") - def get_mass_properties(self, did, wid, eid, partID): + def get_mass_properties(self, did: str, wid: str, eid: str, partID: str) -> MassModel: """ Get mass properties for a part in a part studio. 
@@ -306,7 +308,16 @@ def get_mass_properties(self, did, wid, eid, partID): return MassModel.model_validate(_resonse_json["bodies"][partID]) - def request(self, method, path, query=None, headers=None, body=None, base_url=None, log_response=True): + def request( + self, + method: Union[HTTP, str], + path: str, + query: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + body: Optional[dict[str, Any]] = None, + base_url: Optional[str] = None, + log_response: bool = True, + ) -> requests.Response: """ Issues a request to Onshape @@ -347,10 +358,12 @@ def request(self, method, path, query=None, headers=None, body=None, base_url=No return res - def _build_url(self, base_url, path, query): + def _build_url(self, base_url: str, path: str, query: dict[str, Any]) -> str: return base_url + path + "?" + urlencode(query) - def _send_request(self, method, url, headers, body): + def _send_request( + self, method: Union[HTTP, str], url: str, headers: dict[str, str], body: Optional[dict[str, Any]] + ) -> requests.Response: return requests.request( method, url, @@ -361,7 +374,9 @@ def _send_request(self, method, url, headers, body): timeout=10, # Specify an appropriate timeout value in seconds ) - def _handle_redirect(self, res, method, headers, log_response=True): + def _handle_redirect( + self, res: requests.Response, method: Union[HTTP, str], headers: dict[str, str], log_response: bool = True + ) -> requests.Response: location = urlparse(res.headers["Location"]) querystring = parse_qs(location.query) @@ -374,13 +389,21 @@ def _handle_redirect(self, res, method, headers, log_response=True): method, location.path, query=new_query, headers=headers, base_url=new_base_url, log_response=log_response ) - def _log_response(self, res): + def _log_response(self, res: requests.Response) -> None: if not 200 <= res.status_code <= 206: LOGGER.debug(f"Request failed, details: {res.text}") else: LOGGER.debug(f"Request succeeded, details: {res.text}") - def _make_auth(self, method, date, nonce, path, query=None, ctype="application/json"): + def _make_auth( + self, + method: Union[HTTP, str], + date: str, + nonce: str, + path: str, + query: Optional[dict[str, Any]] = None, + ctype: str = "application/json", + ) -> str: """ Create the request signature to authenticate @@ -412,7 +435,13 @@ def _make_auth(self, method, date, nonce, path, query=None, ctype="application/j return auth - def _make_headers(self, method, path, query=None, headers=None): + def _make_headers( + self, + method: Union[HTTP, str], + path: str, + query: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ) -> dict[str, str]: """ Creates a headers object to sign the request diff --git a/onshape_api/data/preprocess.py b/onshape_api/data/preprocess.py index a39fda4..2c4a6b9 100644 --- a/onshape_api/data/preprocess.py +++ b/onshape_api/data/preprocess.py @@ -1,27 +1,22 @@ import json import os import re - import pandas as pd - import onshape_api as osa AUTOMATE_ASSEMBLYID_PATTERN = r"(?P\w{24})_(?P\w{24})_(?P\w{24})" - -def extract_ids(assembly_id): +def extract_ids(assembly_id: str) -> dict[str, str | None]: match = re.match(AUTOMATE_ASSEMBLYID_PATTERN, assembly_id) if match: return match.groupdict() else: return {"documentId": None, "documentMicroversion": None, "elementId": None} - -def get_assembly_df(automate_assembly_df): +def get_assembly_df(automate_assembly_df: pd.DataFrame) -> pd.DataFrame: assembly_df = automate_assembly_df["assemblyId"].apply(extract_ids).apply(pd.Series) return assembly_df - 
if __name__ == "__main__": client = osa.Client() diff --git a/onshape_api/graph.py b/onshape_api/graph.py index 18d2fff..f1a15b8 100644 --- a/onshape_api/graph.py +++ b/onshape_api/graph.py @@ -1,7 +1,7 @@ import io import os import random -from typing import Optional, Union +from typing import Optional, Union, Dict, Tuple, List import matplotlib.pyplot as plt import networkx as nx @@ -39,7 +39,7 @@ CURRENT_DIR = os.getcwd() -def generate_names(max_length: int) -> list[str]: +def generate_names(max_length: int) -> List[str]: words_file_path = os.path.join(SCRIPT_DIR, "words.txt") with open(words_file_path) as file: @@ -51,12 +51,12 @@ def generate_names(max_length: int) -> list[str]: return random.sample(words, max_length) -def show_graph(graph: nx.Graph): +def show_graph(graph: nx.Graph) -> None: nx.draw_circular(graph, with_labels=True) plt.show() -def convert_to_digraph(graph: nx.Graph) -> nx.DiGraph: +def convert_to_digraph(graph: nx.Graph) -> Tuple[nx.DiGraph, str]: _centrality = nx.closeness_centrality(graph) _root_node = max(_centrality, key=_centrality.get) _graph = nx.bfs_tree(graph, _root_node) @@ -64,12 +64,12 @@ def convert_to_digraph(graph: nx.Graph) -> nx.DiGraph: def create_graph( - occurences: dict[str, Occurrence], - instances: dict[str, Instance], - parts: dict[str, Part], - mates: dict[str, MateFeatureData], + occurences: Dict[str, Occurrence], + instances: Dict[str, Instance], + parts: Dict[str, Part], + mates: Dict[str, MateFeatureData], directed: bool = True, -): +) -> Union[nx.Graph, Tuple[nx.DiGraph, str]]: graph = nx.Graph() for occurence in occurences: @@ -90,14 +90,14 @@ def create_graph( LOGGER.warning(f"Mate {mate} not found") if directed: - graph = convert_to_digraph(graph) + graph, root_node = convert_to_digraph(graph) LOGGER.info(f"Graph created with {len(graph.nodes)} nodes and {len(graph.edges)} edges") return graph -def download_stl_mesh(did, wid, eid, partID, client: Client, transform: np.ndarray, file_name: str) -> str: +def download_stl_mesh(did: str, wid: str, eid: str, partID: str, client: Client, transform: np.ndarray, file_name: str) -> str: try: with io.BytesIO() as buffer: LOGGER.info(f"Downloading mesh for {file_name}...") @@ -128,7 +128,7 @@ def get_robot_link( workspaceId: str, client: Client, mate: Optional[Union[MateFeatureData, None]] = None, -): +) -> Tuple[Link, np.matrix]: LOGGER.info(f"Creating robot link for {name}") if mate is None: @@ -190,7 +190,7 @@ def get_robot_joint( child: str, mate: MateFeatureData, stl_to_parent_tf: np.matrix, -): +) -> Union[RevoluteJoint, FixedJoint]: LOGGER.info(f"Creating robot joint from {parent} to {child}") parent_to_mate_tf = mate.matedEntities[1].matedCS.part_to_mate_tf @@ -228,20 +228,20 @@ def get_robot_joint( def get_urdf_components( graph: Union[nx.Graph, nx.DiGraph], workspaceId: str, - parts: dict[str, Part], - mass_properties: dict[str, MassModel], - mates: dict[str, MateFeatureData], + parts: Dict[str, Part], + mass_properties: Dict[str, MassModel], + mates: Dict[str, MateFeatureData], client: Client, -): +) -> Tuple[List[Link], List[Union[RevoluteJoint, FixedJoint]]]: if not isinstance(graph, nx.DiGraph): graph, root_node = convert_to_digraph(graph) - joints = [] - links = [] + joints: List[Union[RevoluteJoint, FixedJoint]] = [] + links: List[Link] = [] _readable_names = generate_names(len(graph.nodes)) _readable_names_mapping = dict(zip(graph.nodes, _readable_names)) - _stl_to_link_tf_mapping = {} + _stl_to_link_tf_mapping: Dict[str, np.matrix] = {} LOGGER.info(f"Processing root node: 
{_readable_names_mapping[root_node]}") diff --git a/onshape_api/log.py b/onshape_api/log.py index 8303db5..98b2a54 100644 --- a/onshape_api/log.py +++ b/onshape_api/log.py @@ -47,7 +47,7 @@ class LogLevel(Enum): class Logger(logging.Logger): _instance = None - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> "Logger": if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @@ -60,7 +60,7 @@ def __init__( stream_level: LogLevel = LogLevel.INFO, file_max_bytes: int = 0, file_backup_count: int = 5, - file_name: Union[str, None] = None, + file_name: Optional[str] = None, buffer_size: int = 1000, ) -> None: if not hasattr(self, "_initialized"): @@ -76,7 +76,7 @@ def __init__( self._file_path: str = "" self._csv_path: str = "" self._file: Optional[Any] = None - self._writer = None + self._writer: Optional[csv.writer] = None self._is_logging = False self._header_written = False @@ -121,16 +121,16 @@ def _setup_file_handler(self) -> None: self._file_handler.setFormatter(fmt=self._std_formatter) self.addHandler(hdlr=self._file_handler) - def _ensure_file_handler(self): + def _ensure_file_handler(self) -> None: if not hasattr(self, "_file_handler"): self._setup_file_handler() - def track_variable(self, var_func: Callable[[], Any], name: str): + def track_variable(self, var_func: Callable[[], Any], name: str) -> None: var_id = id(var_func) self._tracked_vars[var_id] = var_func self._var_names[var_id] = name - def untrack_variable(self, var_func: Callable[[], Any]): + def untrack_variable(self, var_func: Callable[[], Any]) -> None: var_id = id(var_func) self._tracked_vars.pop(var_id, None) self._var_names.pop(var_id, None) @@ -138,7 +138,7 @@ def untrack_variable(self, var_func: Callable[[], Any]): def __repr__(self) -> str: return f"Logger(file_path={self._file_path})" - def set_file_name(self, file_name: Union[str, None]) -> None: + def set_file_name(self, file_name: Optional[str]) -> None: self._user_file_name = file_name self._file_path = "" self._csv_path = "" @@ -177,7 +177,7 @@ def update(self) -> None: if len(self._buffer) >= self._buffer_size: self.flush_buffer() - def flush_buffer(self): + def flush_buffer(self) -> None: if not self._buffer: return @@ -214,11 +214,11 @@ def _generate_file_paths(self) -> None: def __enter__(self) -> "Logger": return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__(self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None: self.flush_buffer() self.close() - def reset(self): + def reset(self) -> None: self._buffer.clear() self._tracked_vars.clear() self._var_names.clear() @@ -233,27 +233,27 @@ def close(self) -> None: self._file = None self._writer = None - def debug(self, msg, *args, **kwargs): + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: self._ensure_file_handler() super().debug(msg, *args, **kwargs) - def info(self, msg, *args, **kwargs): + def info(self, msg: str, *args: Any, **kwargs: Any) -> None: self._ensure_file_handler() super().info(msg, *args, **kwargs) - def warning(self, msg, *args, **kwargs): + def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: self._ensure_file_handler() super().warning(msg, *args, **kwargs) - def error(self, msg, *args, **kwargs): + def error(self, msg: str, *args: Any, **kwargs: Any) -> None: self._ensure_file_handler() super().error(msg, *args, **kwargs) - def critical(self, msg, *args, **kwargs): + def critical(self, msg: str, *args: Any, **kwargs: 
Any) -> None: self._ensure_file_handler() super().critical(msg, *args, **kwargs) - def log(self, level, msg, *args, **kwargs): + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: self._ensure_file_handler() super().log(level, msg, *args, **kwargs) diff --git a/onshape_api/models/assembly.py b/onshape_api/models/assembly.py index 78c1372..b6e41e6 100644 --- a/onshape_api/models/assembly.py +++ b/onshape_api/models/assembly.py @@ -49,7 +49,7 @@ class Occurrence(BaseModel): path: list[str] @field_validator("transform") - def check_transform(cls, v): + def check_transform(cls, v: list[float]) -> list[float]: if len(v) != 16: raise ValueError("Transform must have 16 values") @@ -75,14 +75,14 @@ class IDBase(BaseModel): documentMicroversion: str @field_validator("documentId", "elementId", "documentMicroversion") - def check_ids(cls, v): + def check_ids(cls, v: str) -> str: if len(v) != 24: raise ValueError("DocumentId must have 24 characters") return v @property - def uid(self): + def uid(self) -> str: return generate_uid([self.documentId, self.documentMicroversion, self.elementId, self.fullConfiguration]) @@ -107,7 +107,7 @@ class Part(IDBase): bodyType: str @property - def uid(self): + def uid(self) -> str: return generate_uid([ self.documentId, self.documentMicroversion, @@ -144,14 +144,14 @@ class PartInstance(IDBase): partId: str @field_validator("type") - def check_type(cls, v): + def check_type(cls, v: InstanceType) -> InstanceType: if v != InstanceType.PART: raise ValueError("Type must be Part") return v @property - def uid(self): + def uid(self) -> str: return generate_uid([ self.documentId, self.documentMicroversion, @@ -184,7 +184,7 @@ class AssemblyInstance(IDBase): suppressed: bool @field_validator("type") - def check_type(cls, v): + def check_type(cls, v: InstanceType) -> InstanceType: if v != InstanceType.ASSEMBLY: raise ValueError("Type must be Assembly") @@ -211,7 +211,7 @@ class MatedCS(BaseModel): origin: list[float] @field_validator("xAxis", "yAxis", "zAxis", "origin") - def check_vectors(cls, v): + def check_vectors(cls, v: list[float]) -> list[float]: if len(v) != 3: raise ValueError("Vectors must have 3 values") @@ -332,7 +332,7 @@ class MateFeature(BaseModel): # return v @field_validator("featureType") - def check_featureType(cls, v): + def check_featureType(cls, v: str) -> str: if v != AssemblyFeatureType.MATE: raise ValueError("FeatureType must be Mate") @@ -349,7 +349,7 @@ class SubAssembly(IDBase): features: list[MateFeature] @property - def uid(self): + def uid(self) -> str: return generate_uid([self.documentId, self.documentMicroversion, self.elementId, self.fullConfiguration]) diff --git a/onshape_api/models/document.py b/onshape_api/models/document.py index 0598039..c5d3890 100644 --- a/onshape_api/models/document.py +++ b/onshape_api/models/document.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Union, cast +from typing import Union, cast, Tuple import regex as re from pydantic import BaseModel, field_validator @@ -16,7 +16,7 @@ class WORKSPACE_TYPE(str, Enum): DOCUMENT_PATTERN = r"https://cad.onshape.com/documents/([\w\d]+)/(w|v|m)/([\w\d]+)/e/([\w\d]+)" -def parse_url(url: str) -> str: +def parse_url(url: str) -> Tuple[str, WORKSPACE_TYPE, str, str]: pattern = re.match( DOCUMENT_PATTERN, url, @@ -36,7 +36,7 @@ def parse_url(url: str) -> str: class Document(BaseModel): url: Union[str, None] did: str - wtype: str + wtype: WORKSPACE_TYPE wid: str eid: str @@ -49,7 +49,7 @@ def check_ids(cls, value: str) -> str: return value 
@field_validator("wtype") - def check_wtype(cls, value: str) -> str: + def check_wtype(cls, value: WORKSPACE_TYPE) -> WORKSPACE_TYPE: if not value: raise ValueError("Workspace type cannot be empty, please check the URL") diff --git a/onshape_api/models/element.py b/onshape_api/models/element.py index 3101617..644f698 100644 --- a/onshape_api/models/element.py +++ b/onshape_api/models/element.py @@ -23,6 +23,7 @@ """ from enum import Enum +from typing import Any from pydantic import BaseModel, field_validator @@ -39,15 +40,15 @@ class ELEMENT_TYPE(str, Enum): class Element(BaseModel): id: str name: str - elementType: str + elementType: ELEMENT_TYPE microversionId: str @field_validator("elementType") - def validate_type(cls, value: str) -> str: + def validate_type(cls, value: str) -> ELEMENT_TYPE: if value not in ELEMENT_TYPE.__members__.values(): raise ValueError(f"Invalid element type: {value}") - return value + return ELEMENT_TYPE(value) @field_validator("id") def validate_id(cls, value: str) -> str: @@ -65,7 +66,7 @@ def validate_mid(cls, value: str) -> str: if __name__ == "__main__": - element_json = { + element_json: dict[str, Any] = { "name": "wheelAndFork", "id": "0b0c209535554345432581fe", "type": "Part Studio", diff --git a/onshape_api/models/mass.py b/onshape_api/models/mass.py index bf0d444..34d43c8 100644 --- a/onshape_api/models/mass.py +++ b/onshape_api/models/mass.py @@ -53,11 +53,11 @@ class MassModel(BaseModel): mass: list[float] centroid: list[float] inertia: list[float] - principalInertia: list[float, float, float] + principalInertia: list[float] principalAxes: list[PrincipalAxis] @field_validator("principalAxes") - def check_principal_axes(cls, v): + def check_principal_axes(cls, v: list[PrincipalAxis]) -> list[PrincipalAxis]: if len(v) != 3: raise ValueError("Principal axes must have 3 elements") return v @@ -92,7 +92,7 @@ def inertia_wrt(self, reference: np.matrix) -> np.matrix: def center_of_mass_wrt(self, reference: np.matrix) -> np.ndarray: if reference.shape != (4, 4): - raise ValueError("Reference frame must be a 3x3 matrix") + raise ValueError("Reference frame must be a 4x4 matrix") com = np.matrix([*list(self.center_of_mass), 1.0]) com_wrt = (reference * com.T)[:3] diff --git a/onshape_api/models/variable.py b/onshape_api/models/variable.py index 7ec57bd..a4fe6b1 100644 --- a/onshape_api/models/variable.py +++ b/onshape_api/models/variable.py @@ -10,7 +10,7 @@ """ from enum import Enum -from typing import Union +from typing import Union, Optional from pydantic import BaseModel, field_validator @@ -36,11 +36,11 @@ class Variable(BaseModel): } """ - type: str + type: VARIABLE_TYPE name: str - value: Union[str, None] = None - description: str = None - expression: str = None + value: Optional[str] = None + description: Optional[str] = None + expression: Optional[str] = None @field_validator("name") def validate_name(cls, value: str) -> str: @@ -50,7 +50,7 @@ def validate_name(cls, value: str) -> str: return value @field_validator("type") - def validate_type(cls, value: str) -> str: + def validate_type(cls, value: VARIABLE_TYPE) -> VARIABLE_TYPE: if value not in VARIABLE_TYPE.__members__.values(): raise ValueError(f"Invalid variable type: {value}") From 54d47013669b2ca73b1e54151c7fc69967484e4c Mon Sep 17 00:00:00 2001 From: imsenthur Date: Mon, 4 Nov 2024 15:07:04 -0500 Subject: [PATCH 2/4] More ruff fixes. 
--- onshape_api/data/preprocess.py | 5 +++++ onshape_api/graph.py | 36 ++++++++++++++++++---------------- onshape_api/log.py | 6 ++++-- onshape_api/models/document.py | 4 ++-- onshape_api/models/variable.py | 2 +- 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/onshape_api/data/preprocess.py b/onshape_api/data/preprocess.py index 2c4a6b9..d1c6826 100644 --- a/onshape_api/data/preprocess.py +++ b/onshape_api/data/preprocess.py @@ -1,11 +1,14 @@ import json import os import re + import pandas as pd + import onshape_api as osa AUTOMATE_ASSEMBLYID_PATTERN = r"(?P\w{24})_(?P\w{24})_(?P\w{24})" + def extract_ids(assembly_id: str) -> dict[str, str | None]: match = re.match(AUTOMATE_ASSEMBLYID_PATTERN, assembly_id) if match: @@ -13,10 +16,12 @@ def extract_ids(assembly_id: str) -> dict[str, str | None]: else: return {"documentId": None, "documentMicroversion": None, "elementId": None} + def get_assembly_df(automate_assembly_df: pd.DataFrame) -> pd.DataFrame: assembly_df = automate_assembly_df["assemblyId"].apply(extract_ids).apply(pd.Series) return assembly_df + if __name__ == "__main__": client = osa.Client() diff --git a/onshape_api/graph.py b/onshape_api/graph.py index f1a15b8..bef01c4 100644 --- a/onshape_api/graph.py +++ b/onshape_api/graph.py @@ -1,7 +1,7 @@ import io import os import random -from typing import Optional, Union, Dict, Tuple, List +from typing import Optional, Union import matplotlib.pyplot as plt import networkx as nx @@ -39,7 +39,7 @@ CURRENT_DIR = os.getcwd() -def generate_names(max_length: int) -> List[str]: +def generate_names(max_length: int) -> list[str]: words_file_path = os.path.join(SCRIPT_DIR, "words.txt") with open(words_file_path) as file: @@ -56,7 +56,7 @@ def show_graph(graph: nx.Graph) -> None: plt.show() -def convert_to_digraph(graph: nx.Graph) -> Tuple[nx.DiGraph, str]: +def convert_to_digraph(graph: nx.Graph) -> tuple[nx.DiGraph, str]: _centrality = nx.closeness_centrality(graph) _root_node = max(_centrality, key=_centrality.get) _graph = nx.bfs_tree(graph, _root_node) @@ -64,12 +64,12 @@ def convert_to_digraph(graph: nx.Graph) -> Tuple[nx.DiGraph, str]: def create_graph( - occurences: Dict[str, Occurrence], - instances: Dict[str, Instance], - parts: Dict[str, Part], - mates: Dict[str, MateFeatureData], + occurences: dict[str, Occurrence], + instances: dict[str, Instance], + parts: dict[str, Part], + mates: dict[str, MateFeatureData], directed: bool = True, -) -> Union[nx.Graph, Tuple[nx.DiGraph, str]]: +) -> Union[nx.Graph, tuple[nx.DiGraph, str]]: graph = nx.Graph() for occurence in occurences: @@ -97,7 +97,9 @@ def create_graph( return graph -def download_stl_mesh(did: str, wid: str, eid: str, partID: str, client: Client, transform: np.ndarray, file_name: str) -> str: +def download_stl_mesh( + did: str, wid: str, eid: str, partID: str, client: Client, transform: np.ndarray, file_name: str +) -> str: try: with io.BytesIO() as buffer: LOGGER.info(f"Downloading mesh for {file_name}...") @@ -128,7 +130,7 @@ def get_robot_link( workspaceId: str, client: Client, mate: Optional[Union[MateFeatureData, None]] = None, -) -> Tuple[Link, np.matrix]: +) -> tuple[Link, np.matrix]: LOGGER.info(f"Creating robot link for {name}") if mate is None: @@ -228,20 +230,20 @@ def get_robot_joint( def get_urdf_components( graph: Union[nx.Graph, nx.DiGraph], workspaceId: str, - parts: Dict[str, Part], - mass_properties: Dict[str, MassModel], - mates: Dict[str, MateFeatureData], + parts: dict[str, Part], + mass_properties: dict[str, MassModel], + mates: dict[str, 
MateFeatureData], client: Client, -) -> Tuple[List[Link], List[Union[RevoluteJoint, FixedJoint]]]: +) -> tuple[list[Link], list[Union[RevoluteJoint, FixedJoint]]]: if not isinstance(graph, nx.DiGraph): graph, root_node = convert_to_digraph(graph) - joints: List[Union[RevoluteJoint, FixedJoint]] = [] - links: List[Link] = [] + joints: list[Union[RevoluteJoint, FixedJoint]] = [] + links: list[Link] = [] _readable_names = generate_names(len(graph.nodes)) _readable_names_mapping = dict(zip(graph.nodes, _readable_names)) - _stl_to_link_tf_mapping: Dict[str, np.matrix] = {} + _stl_to_link_tf_mapping: dict[str, np.matrix] = {} LOGGER.info(f"Processing root node: {_readable_names_mapping[root_node]}") diff --git a/onshape_api/log.py b/onshape_api/log.py index 98b2a54..75e8a2a 100644 --- a/onshape_api/log.py +++ b/onshape_api/log.py @@ -31,7 +31,7 @@ from datetime import datetime from enum import Enum from logging.handlers import RotatingFileHandler -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional __all__ = ["LOGGER", "LOG_LEVEL", "Logger"] @@ -214,7 +214,9 @@ def _generate_file_paths(self) -> None: def __enter__(self) -> "Logger": return self - def __exit__(self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None: + def __exit__( + self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[Any] + ) -> None: self.flush_buffer() self.close() diff --git a/onshape_api/models/document.py b/onshape_api/models/document.py index c5d3890..fc0b144 100644 --- a/onshape_api/models/document.py +++ b/onshape_api/models/document.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Union, cast, Tuple +from typing import Union, cast import regex as re from pydantic import BaseModel, field_validator @@ -16,7 +16,7 @@ class WORKSPACE_TYPE(str, Enum): DOCUMENT_PATTERN = r"https://cad.onshape.com/documents/([\w\d]+)/(w|v|m)/([\w\d]+)/e/([\w\d]+)" -def parse_url(url: str) -> Tuple[str, WORKSPACE_TYPE, str, str]: +def parse_url(url: str) -> tuple[str, WORKSPACE_TYPE, str, str]: pattern = re.match( DOCUMENT_PATTERN, url, diff --git a/onshape_api/models/variable.py b/onshape_api/models/variable.py index a4fe6b1..382daf0 100644 --- a/onshape_api/models/variable.py +++ b/onshape_api/models/variable.py @@ -10,7 +10,7 @@ """ from enum import Enum -from typing import Union, Optional +from typing import Optional from pydantic import BaseModel, field_validator From bf51f8a6cafa7e4b7c138489abbd8a67f798edd8 Mon Sep 17 00:00:00 2001 From: imsenthur Date: Mon, 4 Nov 2024 15:41:50 -0500 Subject: [PATCH 3/4] Enforcing CRLF and minor ruff fixes. 
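Besides the line-ending enforcement, `examples/bike/main.py` now unpacks the `(Assembly, dict)` tuple that `client.get_assembly` returns, and `MassModel.center_of_mass_wrt` returns a plain tuple while still requiring a 4x4 homogeneous reference frame. A rough sketch of that transform with made-up numbers; this is not the model class itself, only the math it applies:

```python
import numpy as np

# Illustration only: transforming a center of mass by a 4x4 homogeneous frame,
# mirroring what MassModel.center_of_mass_wrt does after this patch.
reference = np.matrix([
    [1.0, 0.0, 0.0, 0.5],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.2],
    [0.0, 0.0, 0.0, 1.0],
])  # pure translation by (0.5, 0, 0.2); a real frame would usually rotate too

center_of_mass = (0.1, 0.0, 0.05)              # body-frame CoM in metres, hypothetical
com = np.matrix([*list(center_of_mass), 1.0])  # homogeneous row vector
com_wrt = (reference * com.T)[:3]              # rotate + translate, drop the w term

print((com_wrt[0, 0], com_wrt[1, 0], com_wrt[2, 0]))  # world-frame CoM, roughly (0.6, 0.0, 0.25)
```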
--- examples/bike/main.py | 4 ++-- onshape_api/graph.py | 23 ++++++++++++++++------- onshape_api/models/link.py | 2 +- onshape_api/models/mass.py | 4 ++-- pyproject.toml | 1 + 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/bike/main.py b/examples/bike/main.py index aa1fb4c..9d63c58 100644 --- a/examples/bike/main.py +++ b/examples/bike/main.py @@ -21,10 +21,10 @@ variables["wheelDiameter"].expression = "300 mm" variables["wheelThickness"].expression = "71 mm" -variables["forkAngle"].expression = "30 deg" +variables["forkAngle"].expression = "20 deg" client.set_variables(doc.did, doc.wid, elements["variables"].id, variables) -assembly = client.get_assembly(doc.did, doc.wtype, doc.wid, elements["assembly"].id) +assembly, _ = client.get_assembly(doc.did, doc.wtype, doc.wid, elements["assembly"].id) occurences = get_occurences(assembly) instances = get_instances(assembly) diff --git a/onshape_api/graph.py b/onshape_api/graph.py index bef01c4..0078d8c 100644 --- a/onshape_api/graph.py +++ b/onshape_api/graph.py @@ -7,6 +7,7 @@ import networkx as nx import numpy as np import stl +from networkx import DiGraph, Graph from onshape_api.connect import Client from onshape_api.log import LOGGER @@ -51,12 +52,16 @@ def generate_names(max_length: int) -> list[str]: return random.sample(words, max_length) -def show_graph(graph: nx.Graph) -> None: +def get_random_color() -> tuple[float, float, float, float]: + return random.SystemRandom().choice(list(COLORS)).value + + +def show_graph(graph: Graph) -> None: nx.draw_circular(graph, with_labels=True) plt.show() -def convert_to_digraph(graph: nx.Graph) -> tuple[nx.DiGraph, str]: +def convert_to_digraph(graph: Graph) -> tuple[DiGraph, str]: _centrality = nx.closeness_centrality(graph) _root_node = max(_centrality, key=_centrality.get) _graph = nx.bfs_tree(graph, _root_node) @@ -69,8 +74,8 @@ def create_graph( parts: dict[str, Part], mates: dict[str, MateFeatureData], directed: bool = True, -) -> Union[nx.Graph, tuple[nx.DiGraph, str]]: - graph = nx.Graph() +) -> Union[Graph, tuple[DiGraph, str]]: + graph: Graph = Graph() for occurence in occurences: if instances[occurence].type == InstanceType.PART: @@ -161,7 +166,7 @@ def get_robot_link( visual=VisualLink( origin=_origin, geometry=MeshGeometry(_mesh_path), - material=Material.from_color(name=f"{name}_material", color=random.SystemRandom().choice(list(COLORS))), + material=Material.from_color(name=f"{name}_material", color=get_random_color()), ), inertial=InertialLink( origin=Origin( @@ -228,14 +233,14 @@ def get_robot_joint( def get_urdf_components( - graph: Union[nx.Graph, nx.DiGraph], + graph: Union[Graph, DiGraph], workspaceId: str, parts: dict[str, Part], mass_properties: dict[str, MassModel], mates: dict[str, MateFeatureData], client: Client, ) -> tuple[list[Link], list[Union[RevoluteJoint, FixedJoint]]]: - if not isinstance(graph, nx.DiGraph): + if not isinstance(graph, DiGraph): graph, root_node = convert_to_digraph(graph) joints: list[Union[RevoluteJoint, FixedJoint]] = [] @@ -287,3 +292,7 @@ def get_urdf_components( links.append(_link) return links, joints + + +if __name__ == "__main__": + print(get_random_color()) diff --git a/onshape_api/models/link.py b/onshape_api/models/link.py index 533c26d..f0c9317 100644 --- a/onshape_api/models/link.py +++ b/onshape_api/models/link.py @@ -102,7 +102,7 @@ def to_xml(self, root: ET.Element | None = None) -> ET.Element: return material @classmethod - def from_color(cls, name: str, color: COLORS) -> "Material": + def from_color(cls, 
name: str, color: tuple[float, float, float, float]) -> "Material": return cls(name, color) diff --git a/onshape_api/models/mass.py b/onshape_api/models/mass.py index 34d43c8..2967800 100644 --- a/onshape_api/models/mass.py +++ b/onshape_api/models/mass.py @@ -90,13 +90,13 @@ def inertia_wrt(self, reference: np.matrix) -> np.matrix: return reference @ self.inertia_matrix @ reference.T - def center_of_mass_wrt(self, reference: np.matrix) -> np.ndarray: + def center_of_mass_wrt(self, reference: np.matrix) -> tuple[float, float, float]: if reference.shape != (4, 4): raise ValueError("Reference frame must be a 4x4 matrix") com = np.matrix([*list(self.center_of_mass), 1.0]) com_wrt = (reference * com.T)[:3] - return np.array([com_wrt[0, 0], com_wrt[1, 0], com_wrt[2, 0]]) + return (com_wrt[0, 0], com_wrt[1, 0], com_wrt[2, 0]) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index a46d904..bf15566 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ deptry = "^0.16.2" mypy = "^1.5.1" pre-commit = "^3.4.0" tox = "^4.11.1" +types-networkx = "^3.3.0.20241020" [tool.poetry.group.docs.dependencies] mkdocs = "^1.4.2" From e8e7b5cf5ec149ca4b6aa26873dcc2b242e102e2 Mon Sep 17 00:00:00 2001 From: imsenthur Date: Mon, 4 Nov 2024 16:01:33 -0500 Subject: [PATCH 4/4] From 60+ ruff errors to 45. --- onshape_api/graph.py | 3 ++- onshape_api/mesh.py | 2 +- pyproject.toml | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onshape_api/graph.py b/onshape_api/graph.py index 0078d8c..0b966f1 100644 --- a/onshape_api/graph.py +++ b/onshape_api/graph.py @@ -53,7 +53,8 @@ def generate_names(max_length: int) -> list[str]: def get_random_color() -> tuple[float, float, float, float]: - return random.SystemRandom().choice(list(COLORS)).value + _color: COLORS = random.SystemRandom().choice(list(COLORS)) # nosec + return tuple(_color.value) def show_graph(graph: Graph) -> None: diff --git a/onshape_api/mesh.py b/onshape_api/mesh.py index 0e05189..bf971b1 100644 --- a/onshape_api/mesh.py +++ b/onshape_api/mesh.py @@ -53,4 +53,4 @@ def transform_inertia_matrix(inertia_matrix: np.matrix, rotation: np.matrix) -> Returns: - np.matrix: Transformed inertia matrix """ - return rotation @ inertia_matrix @ rotation.T + return np.matrix(rotation @ inertia_matrix @ rotation.T) diff --git a/pyproject.toml b/pyproject.toml index bf15566..6d7a760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ mypy = "^1.5.1" pre-commit = "^3.4.0" tox = "^4.11.1" types-networkx = "^3.3.0.20241020" +types-requests = "^2.32.0.20241016" +pandas-stubs = "^2.2.3.241009" [tool.poetry.group.docs.dependencies] mkdocs = "^1.4.2"
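The `transform_inertia_matrix` change above only wraps the existing expression in `np.matrix` so the declared return type holds; the underlying operation is the usual similarity transform of an inertia tensor. A small self-contained sketch with hypothetical values, not taken from the patch:

```python
import numpy as np

# Illustration only: rotating an inertia tensor into another frame via
# rotation @ inertia @ rotation.T, the expression transform_inertia_matrix wraps.
inertia = np.matrix(np.diag([0.1, 0.2, 0.3]))  # body-frame inertia, kg*m^2 (made up)
rotation = np.matrix([
    [0.0, -1.0, 0.0],
    [1.0,  0.0, 0.0],
    [0.0,  0.0, 1.0],
])                                             # 90-degree rotation about z

transformed = np.matrix(rotation @ inertia @ rotation.T)
print(transformed)  # the x and y moments of inertia have swapped places
```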