Skip to content

Commit

Permalink
Initial attempt at adding discriminated unions.
Browse files Browse the repository at this point in the history
  • Loading branch information
senthurayyappan committed Nov 8, 2024
1 parent f23c306 commit 558ecd8
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 31 deletions.
2 changes: 1 addition & 1 deletion onshape_api/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def get_assembly(self, did, wtype, wid, eid, configuration="default"):
query={
"includeMateFeatures": "true",
"includeMateConnectors": "true",
"includeNonSolids": "true",
"includeNonSolids": "false",
"configuration": configuration,
},
).json()
Expand Down
18 changes: 14 additions & 4 deletions onshape_api/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pandas as pd

import onshape_api as osa
from onshape_api.connect import Client
from onshape_api.models import Assembly

AUTOMATE_ASSEMBLYID_PATTERN = r"(?P<documentId>\w{24})_(?P<documentMicroversion>\w{24})_(?P<elementId>\w{24})"

Expand All @@ -22,9 +23,7 @@ def get_assembly_df(automate_assembly_df):
return assembly_df


if __name__ == "__main__":
client = osa.Client()

def save_all_jsons(client: Client):
if not os.path.exists("assemblies.parquet"):
automate_assembly_df = pd.read_parquet("automate_assemblies.parquet", engine="pyarrow")
assembly_df = get_assembly_df(automate_assembly_df)
Expand Down Expand Up @@ -56,3 +55,14 @@ def get_assembly_df(automate_assembly_df):
print(f"Assembly JSON saved to {json_file_path}")
except Exception as e:
print(f"An error occurred for row {index}: {e}")


if __name__ == "__main__":
client = Client()
# save_all_jsons(client)

json_file_path = "mate_relations.json"
with open(json_file_path) as json_file:
assembly_json = json.load(json_file)

assembly = Assembly.model_validate(assembly_json)
178 changes: 152 additions & 26 deletions onshape_api/models/assembly.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from enum import Enum
from typing import Union
from typing import Literal, Union

import numpy as np
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator

from onshape_api.models.document import Document
from onshape_api.models.mass import MassProperties
from onshape_api.utilities.helpers import generate_uid


class InstanceType(str, Enum):
class INSTANCE_TYPE(str, Enum):
PART = "Part"
ASSEMBLY = "Assembly"


class MATETYPE(str, Enum):
class MATE_TYPE(str, Enum):
SLIDER = "SLIDER"
CYLINDRICAL = "CYLINDRICAL"
REVOLUTE = "REVOLUTE"
Expand All @@ -25,8 +25,18 @@ class MATETYPE(str, Enum):
PARALLEL = "PARALLEL"


class AssemblyFeatureType(str, Enum):
class RELATION_TYPE(str, Enum):
LINEAR = "LINEAR"
GEAR = "GEAR"
SCREW = "SCREW"
RACK_AND_PINION = "RACK_AND_PINION"


class ASSEMBLY_FEATURE_TYPE(str, Enum):
MATE = "mate"
MATERELATION = "mateRelation"
MATEGROUP = "mateGroup"
MATECONNECTOR = "mateConnector"


class Occurrence(BaseModel):
Expand Down Expand Up @@ -140,15 +150,15 @@ class PartInstance(IDBase):
"""

isStandardContent: bool
type: InstanceType
type: INSTANCE_TYPE
id: str
name: str
suppressed: bool
partId: str

@field_validator("type")
def check_type(cls, v: InstanceType) -> InstanceType:
if v != InstanceType.PART:
def check_type(cls, v: INSTANCE_TYPE) -> INSTANCE_TYPE:
if v != INSTANCE_TYPE.PART:
raise ValueError("Type must be Part")

return v
Expand Down Expand Up @@ -182,13 +192,13 @@ class AssemblyInstance(IDBase):
"""

id: str
type: InstanceType
type: INSTANCE_TYPE
name: str
suppressed: bool

@field_validator("type")
def check_type(cls, v: InstanceType) -> InstanceType:
if v != InstanceType.ASSEMBLY:
def check_type(cls, v: INSTANCE_TYPE) -> INSTANCE_TYPE:
if v != INSTANCE_TYPE.ASSEMBLY:
raise ValueError("Type must be Assembly")

return v
Expand Down Expand Up @@ -281,11 +291,15 @@ class MateFeatureData(BaseModel):
"""

matedEntities: list[MatedEntity]
mateType: MATETYPE
mateType: MATE_TYPE
name: str


class MateFeature(BaseModel):
class BaseAssemblyFeature(BaseModel):
featureType: str


class MateFeature(BaseAssemblyFeature):
"""
Feature model
{
Expand Down Expand Up @@ -324,22 +338,132 @@ class MateFeature(BaseModel):

id: str
suppressed: bool
featureType: str
featureType: Literal["MateFeature"] = "MateFeature"
featureData: MateFeatureData

# @field_validator("id")
# def check_id(cls, v):
# if len(v) != 17:
# raise ValueError("Id must have 17 characters")

# return v
class MateRelationMate(BaseModel):
"""
{
"featureId": "S4/TgCRmQt1nIHHp",
"occurrence": []
},
"""

@field_validator("featureType")
def check_featureType(cls, v: str) -> str:
if v != AssemblyFeatureType.MATE:
raise ValueError("FeatureType must be Mate")
featureId: str
occurrence: list[str]

return v

class MateRelationFeatureData(BaseModel):
"""
{
"relationType": "GEAR",
"mates": [
{
"featureId": "S4/TgCRmQt1nIHHp",
"occurrence": []
},
{
"featureId": "QwaoOeXYPifsN7CP",
"occurrence": []
}
],
"reverseDirection": false,
"relationRatio": 1,
"name": "Gear 1"
}
"""

relationType: RELATION_TYPE
mates: list[MateRelationMate]
reverseDirection: bool
relationRatio: float
name: str


class MateRelationFeature(BaseAssemblyFeature):
"""
{
"id": "amcpeia1Lm2LN2He",
"suppressed": false,
"featureType": "mateRelation",
"featureData":
{
"relationType": "GEAR",
"mates": [
{
"featureId": "S4/TgCRmQt1nIHHp",
"occurrence": []
},
{
"featureId": "QwaoOeXYPifsN7CP",
"occurrence": []
}
],
"reverseDirection": false,
"relationRatio": 1,
"name": "Gear 1"
}
},
"""

id: str
suppressed: bool
featureType: Literal["MateRelationFeature"] = "MateRelationFeature"
featureData: MateRelationFeatureData


class MateGroupFeatureOccurrence(BaseModel):
occurrence: list[str]


class MateGroupFeatureData(BaseModel):
occurrences: list[MateGroupFeatureOccurrence]
name: str


class MateGroupFeature(BaseAssemblyFeature):
id: str
suppressed: bool
featureType: Literal["MateGroupFeature"] = "MateGroupFeature"
featureData: MateGroupFeatureData


class MateConnectorFeatureData(BaseModel):
mateConnectorCS: MatedCS
occurence: list[str]
name: str


class MateConnectorFeature(BaseAssemblyFeature):
"""
{
"id": "MftzXroqpwJJDurRm",
"suppressed": false,
"featureType": "mateConnector",
"featureData": {
"mateConnectorCS": {
"xAxis": [],
"yAxis": [],
"zAxis": [],
"origin": []
},
"occurrence": [
"MplKLzV/4d+nqmD18"
],
"name": "Mate connector 1"
}
},
"""

id: str
suppressed: bool
featureType: Literal["MateConnectorFeature"] = "MateConnectorFeature"
featureData: MateConnectorFeatureData


class Pattern(BaseModel):
pass


class SubAssembly(IDBase):
Expand All @@ -348,8 +472,10 @@ class SubAssembly(IDBase):
"""

instances: list[Union[PartInstance, AssemblyInstance]]
patterns: list[dict]
features: list[MateFeature]
patterns: list[Pattern]
features: list[Union[MateFeature, MateRelationFeature, MateGroupFeature, MateConnectorFeature]] = Field(
..., discriminator="featureType"
)

@property
def uid(self) -> str:
Expand Down

0 comments on commit 558ecd8

Please sign in to comment.