diff --git a/stack/mtp_base.py b/mtp_device/mtp_base.py similarity index 98% rename from stack/mtp_base.py rename to mtp_device/mtp_base.py index 2f323f9..810dded 100644 --- a/stack/mtp_base.py +++ b/mtp_device/mtp_base.py @@ -21,7 +21,7 @@ def __init__(self, info): self.info = info def get_handles(self): - return [self] + return [self.get_uid()] def get_uid(self): return self.uid diff --git a/mtp_device/mtp_device.py b/mtp_device/mtp_device.py new file mode 100644 index 0000000..ebaa822 --- /dev/null +++ b/mtp_device/mtp_device.py @@ -0,0 +1,282 @@ +from mtp_base import MtpObjectContainer, MtpEntityInfoInterface +from mtp_proto import MU16, MU32, MStr, MArray, OperationDataCodes, ResponseCodes, mtp_data +from struct import unpack + + +class MtpProtocolException(Exception): + + def __init__(self, response, msg=None): + super(MtpProtocolException, self).__init__(msg) + self.response = response + + +operations = {} + + +def operation(opcode, num_params=None, session_required=True): + ''' + Decorator for an API operation function + + :param opcode: operation code + :param num_params: number of parameter the operation expects (default: None) + :param session_required: is the operation requires a session (default: True) + ''' + + def decorator(func): + + def wrapper(self, request, response): + try: + if num_params is not None: + if request.num_params() < num_params: + raise MtpProtocolException(ResponseCodes.PARAMETER_NOT_SUPPORTED) + if session_required and (self.session_id is None): + raise MtpProtocolException(ResponseCodes.SESSION_NOT_OPEN) + res = func(self, request) + except MtpProtocolException as ex: + response.status = ex.response + res = None + return res + + if opcode in operations: + raise Exception('operation %#x already defined', opcode) + operations[opcode] = wrapper + return wrapper + + return decorator + + +class MtpRequest(object): + + def __init__(self, tid, code, params): + self.tid = tid + self.code = code + self.params = params + + def num_params(self): + return len(self.params) + + def get_param(self, idx): + if idx < len(self.params): + return self.params[idx] + return None + + @classmethod + def from_buff(cls, data): + if len(data) % 4 != 0: + raise Exception('request length (%#x) is not a multiple of four' % (len(data))) + if len(data) < 0xc: + raise Exception('request too short') + length, ctype, opcode, tid = unpack('