Support a training pipeline via onnx-mlir
Goals:
Given a generic ONNX model/graph as input, you are asked to transform that ONNX graph into a training-based ONNX representation that is then ingested into onnx-mlir (MLIR). You should describe your approach in detail, document your transformations/conversions, and include sanity checks that test the validity of your conversions. The final IR needs to be explicit: all operations of an entire training process (forward propagation + loss function computation + backward propagation with optimizer updates) must be part of the IR, either via annotation (`forward_conv`, `backward_conv`, etc.) or as explicit calls to numerical operations (`reduce`, `add`, `multiply`, `dot`).
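To make the "explicit IR" requirement concrete, here is a minimal sketch in plain Python (no ONNX or MLIR dependency) of one training step for a single dense layer with MSE loss and SGD. Every numerical operation is recorded as an annotated node, so the forward, loss, backward, and update phases are all visible in the trace. The `Node` record, the annotation names (`forward_dense`, `loss_mse`, `backward_dense`, `sgd_update`), and the tensor names are illustrative assumptions, not onnx-mlir's API.

```python
from dataclasses import dataclass


@dataclass
class Node:
    annotation: str  # training phase, e.g. "forward_dense", "backward_dense"
    op: str          # primitive numerical op (dot, add, sub, multiply, reduce_*)
    inputs: tuple    # names of consumed tensors
    output: str      # name of produced tensor


def train_step(x, y, w, b, lr):
    """One SGD step for y_hat = x . w + b with MSE loss, recording every op."""
    ir = []

    def emit(annotation, op, inputs, output):
        ir.append(Node(annotation, op, tuple(inputs), output))

    n = len(x)
    # forward propagation
    z = [sum(xi * wi for xi, wi in zip(row, w)) for row in x]
    emit("forward_dense", "dot", ["x", "w"], "z")
    y_hat = [zi + b for zi in z]
    emit("forward_dense", "add", ["z", "b"], "y_hat")

    # loss computation: mean((y_hat - y)^2)
    diff = [p - t for p, t in zip(y_hat, y)]
    emit("loss_mse", "sub", ["y_hat", "y"], "diff")
    sq = [d * d for d in diff]
    emit("loss_mse", "multiply", ["diff", "diff"], "sq")
    loss = sum(sq) / n
    emit("loss_mse", "reduce_mean", ["sq"], "loss")

    # backward propagation
    grad_y_hat = [2.0 * d / n for d in diff]
    emit("backward_mse", "multiply", ["diff", "two_over_n"], "grad_y_hat")
    grad_w = [sum(x[i][k] * grad_y_hat[i] for i in range(n)) for k in range(len(w))]
    emit("backward_dense", "dot", ["x_T", "grad_y_hat"], "grad_w")
    grad_b = sum(grad_y_hat)
    emit("backward_dense", "reduce_sum", ["grad_y_hat"], "grad_b")

    # optimizer update (plain SGD): param - lr * grad
    step_w = [lr * g for g in grad_w]
    emit("sgd_update", "multiply", ["lr", "grad_w"], "step_w")
    new_w = [wi - s for wi, s in zip(w, step_w)]
    emit("sgd_update", "sub", ["w", "step_w"], "w_new")
    step_b = lr * grad_b
    emit("sgd_update", "multiply", ["lr", "grad_b"], "step_b")
    new_b = b - step_b
    emit("sgd_update", "sub", ["b", "step_b"], "b_new")
    return loss, grad_w, grad_b, new_w, new_b, ir


loss, grad_w, grad_b, new_w, new_b, ir = train_step(
    x=[[1.0, 2.0], [3.0, 4.0]], y=[1.0, 2.0], w=[0.0, 0.0], b=0.0, lr=0.1)
for node in ir:
    print(f"{node.annotation:15} {node.op:12} {node.inputs} -> {node.output}")
```

Swapping SGD for another optimizer only changes the `sgd_update` nodes; the forward, loss, and backward nodes stay the same.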
Inputs: ONNX graphs
Outputs: Your custom Training IR + Documentation + Outputs
Requirements
- The output graph must support forward and backward propagation as described above. You also need to consider how to incorporate the loss function into the framework's graph lowering.
- All transformations/conversions must generalize and be explicitly defined. When you design a pass, consider edge cases that may come up in other architectures. If you cannot handle all such edge cases, state which ones the pass would need to be extended to cover.
- Make this work first with static shapes, then make it generic with respect to batch size. Bonus points are awarded if you can make it batch-dimension agnostic.
- Make sure your solution generalizes to a variety of loss functions (MSE, CrossEntropy) and optimizers (Adam, Adagrad, RMSProp, SGD, etc.) rather than just SGD.
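For the requirement that the loss function be part of the lowered graph, one option is to treat the loss as just another differentiable node whose backward rule seeds the backward pass. The sketch below, in plain Python and assuming a single sample (no batching) for brevity, gives MSE and softmax cross-entropy with their analytic gradients, plus a central finite-difference check of the kind the sanity-check requirement asks for. All function names are hypothetical.

```python
import math


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mse_grad(pred, target):
    # d/dp_i of mean((p - t)^2) = 2 (p_i - t_i) / n
    return [2.0 * (p - t) / len(pred) for p, t in zip(pred, target)]


def softmax_cross_entropy(logits, label):
    # log-sum-exp with max subtraction for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]


def softmax_cross_entropy_grad(logits, label):
    # d/dlogit_i = softmax_i - onehot_i
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total - (1.0 if i == label else 0.0) for i, e in enumerate(exps)]


def numeric_grad(f, x, eps=1e-6):
    """Central finite differences: (f(x + eps) - f(x - eps)) / (2 eps)."""
    grads = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


def grads_match(analytic, numeric, tol=1e-5):
    return all(abs(a - n) <= tol for a, n in zip(analytic, numeric))


pred, target = [0.5, -1.0, 2.0], [0.0, 0.0, 1.0]
print(grads_match(mse_grad(pred, target),
                  numeric_grad(lambda p: mse(p, target), pred)))
logits, label = [1.0, 2.0, 0.5], 1
print(grads_match(softmax_cross_entropy_grad(logits, label),
                  numeric_grad(lambda z: softmax_cross_entropy(z, label), logits)))
```

The same finite-difference check can be run against any backward subgraph the conversion emits, not only the loss.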
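For the requirement that passes generalize and name the edge cases they miss, a sketch of a backward pass that stays explicit about its coverage: gradient rules are registered per ONNX `op_type`, integer inputs such as Reshape's `shape` and Gather's `indices` are excluded from differentiation, and any op without a rule is reported instead of silently skipped. The registry, the `OnnxNode` stand-in, the rule names, and the `zero` constant tensor are hypothetical; node attributes (for example Cast's target type) are omitted for brevity. Greater, Cast, Mul, Shape, and Reshape are standard ONNX ops.

```python
from dataclasses import dataclass


@dataclass
class OnnxNode:
    # Simplified stand-in for an ONNX NodeProto (attributes omitted).
    op_type: str
    inputs: list
    outputs: list


# Hypothetical registry: ONNX op_type -> builder of backward nodes.
GRAD_RULES = {}

# Input positions that never receive a gradient (integer shapes / indices).
NON_DIFFERENTIABLE_INPUTS = {"Reshape": {1}, "Gather": {1}}


def grad_rule(op_type):
    def register(fn):
        GRAD_RULES[op_type] = fn
        return fn
    return register


@grad_rule("Relu")
def relu_backward(node):
    x, y = node.inputs[0], node.outputs[0]
    # grad_x = grad_y * (x > 0); "zero" is an assumed scalar constant tensor
    return [OnnxNode("Greater", [x, "zero"], [f"{x}_pos"]),
            OnnxNode("Cast", [f"{x}_pos"], [f"{x}_mask"]),
            OnnxNode("Mul", [f"grad_{y}", f"{x}_mask"], [f"grad_{x}"])]


@grad_rule("Reshape")
def reshape_backward(node):
    x, y = node.inputs[0], node.outputs[0]
    # reshape the incoming gradient back to x's runtime shape;
    # the shape input (position 1) gets no gradient
    return [OnnxNode("Shape", [x], [f"{x}_shape"]),
            OnnxNode("Reshape", [f"grad_{y}", f"{x}_shape"], [f"grad_{x}"])]


def differentiable_inputs(node):
    skip = NON_DIFFERENTIABLE_INPUTS.get(node.op_type, set())
    return [name for i, name in enumerate(node.inputs) if i not in skip]


def unsupported_ops(nodes):
    """Op types with no backward rule; report them instead of guessing."""
    return sorted({n.op_type for n in nodes if n.op_type not in GRAD_RULES})


graph = [OnnxNode("Relu", ["x"], ["h"]),
         OnnxNode("Reshape", ["h", "new_shape"], ["r"]),
         OnnxNode("LSTM", ["r", "W", "R"], ["y"])]
print(unsupported_ops(graph))  # ['LSTM']
```

The output of `unsupported_ops` doubles as the list of edge cases the documentation should say the pass would need to be extended for.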
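For the static-then-dynamic batch requirement, a minimal sketch assuming shapes are tuples whose entries are either ints (static) or strings (symbolic, such as `"N"` for the batch). Shape inference passes the symbolic batch dimension through, and the place a training graph most often bakes in the batch size, the 1/N factor of a mean-reduced loss and its gradient, is computed from the tensor's runtime shape when N is symbolic. `IRNode`, the helper names, and `idx0` (an assumed constant holding index 0) are hypothetical; Shape, Gather, Cast, and Reciprocal are standard ONNX ops.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Dim = Union[int, str]  # int = static size, str = symbolic size such as "N"


@dataclass
class IRNode:
    op_type: str
    inputs: list
    output: str
    value: Optional[float] = None  # set only for Constant nodes


def matmul_shape(a: Tuple[Dim, ...], b: Tuple[Dim, ...]) -> Tuple[Dim, ...]:
    """[..., K] x [K, M] -> [..., M]; a symbolic batch dim passes through."""
    k_a, k_b = a[-1], b[0]
    if isinstance(k_a, int) and isinstance(k_b, int) and k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return tuple(a[:-1]) + (b[-1],)


def inverse_batch_size(shape: Tuple[Dim, ...], tensor: str) -> list:
    """Nodes producing 1/N for a mean over axis 0 of `tensor`."""
    n = shape[0]
    if isinstance(n, int):
        # static batch: fold 1/N into a constant at conversion time
        return [IRNode("Constant", [], "inv_n", value=1.0 / n)]
    # symbolic batch: read N from the tensor's shape at run time
    return [IRNode("Shape", [tensor], f"{tensor}_shape"),
            IRNode("Gather", [f"{tensor}_shape", "idx0"], "n_int"),
            IRNode("Cast", ["n_int"], "n"),
            IRNode("Reciprocal", ["n"], "inv_n")]


print(matmul_shape(("N", 784), (784, 10)))  # ('N', 10)
print([node.op_type for node in inverse_batch_size(("N", 10), "logits")])
```

Bias gradients reduced over axis 0 need no such treatment, since a sum over the batch axis does not depend on knowing N ahead of time.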
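For the optimizer requirement, supporting more than SGD is easier if each optimizer is a pure update rule with explicit state: Adagrad's accumulated squared gradients, RMSProp's running average, and Adam's moments and step count all have to become tensors in the IR. The sketch below uses the standard update formulas on scalars; the hyperparameter defaults and the `OPTIMIZERS` table are illustrative assumptions. A lowering would emit the same formulas as elementwise Mul/Add/Sqrt/Div nodes, with the state tensors as extra graph inputs and outputs.

```python
import math


# Each optimizer maps (param, grad, state) -> (new_param, new_state).

def sgd(p, g, state, lr=0.1):
    return p - lr * g, state


def adagrad(p, g, state, lr=0.1, eps=1e-8):
    s = state.get("sum_sq", 0.0) + g * g  # accumulated squared gradients
    return p - lr * g / (math.sqrt(s) + eps), {"sum_sq": s}


def rmsprop(p, g, state, lr=0.1, rho=0.9, eps=1e-8):
    s = rho * state.get("sq_avg", 0.0) + (1 - rho) * g * g  # running average
    return p - lr * g / (math.sqrt(s) + eps), {"sq_avg": s}


def adam(p, g, state, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    t = state.get("t", 0) + 1
    m = beta1 * state.get("m", 0.0) + (1 - beta1) * g
    v = beta2 * state.get("v", 0.0) + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # bias correction
    v_hat = v / (1 - beta2 ** t)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), {"t": t, "m": m, "v": v}


OPTIMIZERS = {"sgd": sgd, "adagrad": adagrad, "rmsprop": rmsprop, "adam": adam}

for name, step in OPTIMIZERS.items():
    new_p, new_state = step(1.0, 0.5, {})
    print(name, round(new_p, 4), sorted(new_state))
```

Because every rule shares one signature, the graph builder can emit the update subgraph for each weight without knowing which optimizer was chosen.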
Acceptance Criteria
High-Level Criteria
- How easy it is to read and understand your conversion to the ONNX training IR, the final IR, and your documentation.
- Whether end-to-end conversion works on the networks we provide to you.
- Whether end-to-end conversion works on networks we have kept hidden from you.
- Whether your IR handles both static and non-static (dynamic) shapes.
Implementation Criteria
- The final IR needs to be traversable: every operation in the compute graph must be accessible from a programming environment, and your submission should include a clear example of doing so.
- The entry points of the graph should be clearly indicated (which tensors are inputs, which are weights, and which are outputs), as well as which tensors are intermediates (activations). Gradients should be annotated as `grad_{weight_node_name}`.
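For the traversability criterion, one low-tech option is to export the final IR to a neutral format such as JSON, where each node carries its name, phase annotation, op, inputs, and outputs, and then walk it with ordinary code. The schema and node names below are a hypothetical example, not an onnx-mlir output format. If the IR stays in ONNX form, the same walk can be done over `model.graph.node` from the `onnx` Python package, using each node's `op_type`, `input`, and `output` fields.

```python
import json
from collections import deque

# Hypothetical JSON export of a training IR: one record per operation.
IR_JSON = """
{"nodes": [
  {"name": "fwd_matmul",    "phase": "forward_matmul",  "op": "MatMul",
   "inputs": ["x", "W"],             "outputs": ["y_hat"]},
  {"name": "loss_sub",      "phase": "loss_mse",        "op": "Sub",
   "inputs": ["y_hat", "y"],         "outputs": ["diff"]},
  {"name": "loss_mul",      "phase": "loss_mse",        "op": "Mul",
   "inputs": ["diff", "diff"],       "outputs": ["sq"]},
  {"name": "loss_mean",     "phase": "loss_mse",        "op": "ReduceMean",
   "inputs": ["sq"],                 "outputs": ["loss"]},
  {"name": "bwd_loss",      "phase": "backward_mse",    "op": "Mul",
   "inputs": ["diff", "two_over_n"], "outputs": ["grad_y_hat"]},
  {"name": "bwd_transpose", "phase": "backward_matmul", "op": "Transpose",
   "inputs": ["x"],                  "outputs": ["x_T"]},
  {"name": "bwd_matmul",    "phase": "backward_matmul", "op": "MatMul",
   "inputs": ["x_T", "grad_y_hat"],  "outputs": ["grad_W"]},
  {"name": "sgd_scale",     "phase": "sgd_update",      "op": "Mul",
   "inputs": ["lr", "grad_W"],       "outputs": ["step_W"]},
  {"name": "sgd_sub",       "phase": "sgd_update",      "op": "Sub",
   "inputs": ["W", "step_W"],        "outputs": ["W_new"]}
]}
"""


def load_ir(text):
    nodes = json.loads(text)["nodes"]
    producer = {out: n for n in nodes for out in n["outputs"]}  # tensor -> node
    return nodes, producer


def nodes_in_phase(nodes, prefix):
    return [n["name"] for n in nodes if n["phase"].startswith(prefix)]


def ancestors(tensor, producer):
    """Names of all nodes that (transitively) contribute to `tensor`."""
    seen, queue = [], deque([tensor])
    while queue:
        node = producer.get(queue.popleft())
        if node is None or node["name"] in seen:
            continue  # graph input / weight / constant, or already visited
        seen.append(node["name"])
        queue.extend(node["inputs"])
    return seen


nodes, producer = load_ir(IR_JSON)
print(nodes_in_phase(nodes, "backward"))
print(ancestors("grad_W", producer))
```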
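For the entry-point and gradient-naming criterion, a sketch of deriving every tensor's role mechanically, assuming weights are stored as ONNX initializers (the usual convention) and gradient tensors follow the `grad_{weight_node_name}` rule. Nodes are reduced to their output lists for brevity; `classify_tensors` and `missing_gradients` are hypothetical helper names. The final check flags weights that never received a gradient, which is a cheap sanity test for the backward pass.

```python
def classify_tensors(graph_inputs, initializers, graph_outputs, nodes):
    """Label every tensor as input, weight, output, activation, or gradient."""
    weights = set(initializers)
    roles = {}
    for name in graph_inputs:
        # Older ONNX models also list initializers in graph.input;
        # anything backed by an initializer is treated as a weight.
        roles[name] = "weight" if name in weights else "input"
    for name in weights:
        roles[name] = "weight"
    for node in nodes:
        for name in node["outputs"]:
            if name.startswith("grad_"):
                roles[name] = "gradient"
            elif name in graph_outputs:
                roles[name] = "output"
            else:
                roles[name] = "activation"
    return roles


def missing_gradients(roles):
    """Weights that have no tensor named grad_{weight_node_name}."""
    return sorted(w for w, r in roles.items()
                  if r == "weight" and f"grad_{w}" not in roles)


nodes = [
    {"outputs": ["h"]},       # forward activation
    {"outputs": ["loss"]},    # loss value
    {"outputs": ["grad_h"]},  # intermediate gradient
    {"outputs": ["grad_W"]},  # weight gradient
]
roles = classify_tensors(["x", "W"], ["W", "b"], ["loss", "grad_W"], nodes)
print(roles)
print(missing_gradients(roles))  # ['b']
```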