talk: use tcolorbox + minted
Drop the minted hack used to fix box alignment; it broke with recent TeX stacks
(texlive 2023.20231207).
elcorto committed Feb 4, 2024
1 parent bccd624 commit a46eb46
Showing 1 changed file with 58 additions and 52 deletions.
talk/main.tex: 58 additions & 52 deletions
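Before the diff itself, a minimal sketch of the tcolorbox + minted combination this commit switches to, distilled from the definitions added below in talk/main.tex. The document class, the demo body, the file path in the usage comment, and the explicit listing engine=minted key (the diff leaves that key commented out and relies on the loaded library) are placeholder assumptions, not part of the commit.

% Minimal sketch, not part of the commit. Compile with -shell-escape,
% which minted needs in order to call pygmentize.
\documentclass{article}
\usepackage{tcolorbox}
\tcbuselibrary{minted}            % load tcolorbox's minted listing library

\tcbset{%
  listing engine=minted,          % set explicitly here for clarity
  minted options={fontsize=\small},
  listing only,                   % show only the code, no "text side"
  boxsep=0pt,
  leftrule=0pt, rightrule=0pt, toprule=0pt, bottomrule=0pt,
}

% inline code: \begin{mintedcode}{python} ... \end{mintedcode}
\newtcblisting{mintedcode}[1]{minted language=#1}

% code read from a file: \mintedfromfile{python}{some/file.py}
\newtcbinputlisting[]{\mintedfromfile}[2]{%
  minted language=#1,
  listing file={#2},
}

\begin{document}
\begin{mintedcode}{python}
f = lambda x: np.log(np.power(np.sin(x), 2))
\end{mintedcode}
\end{document}

The tcolorbox frames take over the background and spacing handled before by the \renewenvironment patch of minted@colorbg (removed below), which is the part that stopped working on recent TeX Live.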
@@ -9,7 +9,7 @@
\usepackage{bm}
\usepackage{xspace}
\usepackage{amsmath}
\usepackage{minted}
\usepackage{tcolorbox}
\usepackage[
backend=biber,
maxbibnames=2,
@@ -26,30 +26,35 @@
% minted
%----------------------------------------------------------------------------------%

% https://tex.stackexchange.com/a/173854
%
%% fix the minted@colorbg environment
\makeatletter
\renewenvironment{minted@colorbg}[1]
{\def\minted@bgcol{#1}%
\noindent
\begin{lrbox}{\minted@bgbox}
\begin{minipage}{\linewidth-2\fboxsep}}
{\end{minipage}%
\end{lrbox}%
\setlength{\topsep}{\bigskipamount}% set the vertical space
\trivlist\item\relax % ensure going to a new line
\colorbox{\minted@bgcol}{\usebox{\minted@bgbox}}%
\endtrivlist % close the trivlist
}
\makeatother

\definecolor{mintedbg}{rgb}{0.95,0.95,0.95}
\setminted{
bgcolor=mintedbg,
fontsize=\small,

\tcbuselibrary{minted}

\tcbset{%
% must use \tcbuselibrary{minted}
%%listing engine=minted,
minted options={fontsize=\small},
minted style=gruvbox-light,
listing only,
boxsep=0pt,
leftrule=0pt,
rightrule=0pt,
toprule=0pt,
bottomrule=0pt,
right=2pt,
top=2pt,
bottom=2pt,
}

\newtcbinputlisting[]{\mintedfromfile}[2]{%
minted language=#1,
listing file={#2},
left=20pt,
}

\newtcblisting{mintedcode}[1]{%
minted language=#1,
left=0pt,
}

%----------------------------------------------------------------------------------%
% beamer style and layout
@@ -178,36 +183,36 @@
f(x) = \ln(\sin^2(x))
\end{equation*}
Code:
\begin{minted}{python}
\begin{mintedcode}{python}
f = lambda x: np.log(np.power(np.sin(x), 2))
\end{minted}
\end{mintedcode}
Derivative:
\begin{equation*}
\td{f}{x} = \red{\td{c}{b}}\,\green{\td{b}{a}}\,\blue{\td{a}{x}}
= \red{\frac{1}{\sin^2(x)}}\,\green{2\,\sin(x)}\,\blue{\cos(x)}
= 2\,\frac{\cos(x)}{\sin(x)} = 2\,\cot(x)
\end{equation*}
Code:
\begin{minted}{python}
\begin{mintedcode}{python}
fprime = lambda x: 2 * np.cos(x) / np.sin(x)
\end{minted}
\end{mintedcode}
\end{frame}


\begin{frame}
\frametitle{AD teaser: arbitrary code}
\only<1>{
\inputminted{python}{code/jax_ad_teaser_func.py}
\mintedfromfile{python}{code/jax_ad_teaser_func.py}
\vfill
}
\only<2->{
\uncover<2->{
\inputminted{python}{code/jax_ad_teaser_func_w_imports.py}
\mintedfromfile{python}{code/jax_ad_teaser_func_w_imports.py}
}
}
\vspace{-0.3cm}
\uncover<3->{
\inputminted{python}{code/jax_ad_teaser_grad_usage.py}
\mintedfromfile{python}{code/jax_ad_teaser_grad_usage.py}
}
\end{frame}

Expand All @@ -221,9 +226,9 @@
\ve f(\ve x) = \ve c(\ve b(\ve a(\ve x)))
\end{equation*}
For $n=m$:
\begin{minted}{python}
\begin{mintedcode}{python}
f = lambda x: np.log(np.power(np.sin(x), 2))
\end{minted}
\end{mintedcode}
\vfill
\uncover<2->{
Jacobian:
@@ -313,26 +318,26 @@
\begin{frame}
\frametitle{Custom JVPs in \jax}
\only<1>{
\inputminted{python}{code/jax_defjvp_mysin_no_decorator.py}
\mintedfromfile{python}{code/jax_defjvp_mysin_no_decorator.py}
}
\only<2->{
\uncover<2->{
\inputminted{python}{code/jax_defjvp_mysin.py}
\mintedfromfile{python}{code/jax_defjvp_mysin.py}
}
}
\vspace{-0.3cm}
\uncover<3->{%
\inputminted{python}{code/jax_defjvp_with_jac.py}
\mintedfromfile{python}{code/jax_defjvp_with_jac.py}
}
\vspace{-0.3cm}
\uncover<4->{%
\inputminted{python}{code/jax_defjvp.py}
\mintedfromfile{python}{code/jax_defjvp.py}
}
\end{frame}

\begin{frame}
\frametitle{Default JVPs for \numpy primitives in \jax}
\inputminted{python}{code/jax_lax_sin.py}
\mintedfromfile{python}{code/jax_lax_sin.py}
\end{frame}

\begin{frame}
@@ -382,8 +387,8 @@

\begin{frame}
\frametitle{Reverse mode VJPs in \pytorch}
\url{https://github.com/pytorch/pytorch/blob/master/tools/autograd/derivatives.yaml}
\inputminted{yaml}{code/pytorch_vjp.yaml}
\url{https://github.com/pytorch/pytorch/blob/main/tools/autograd/derivatives.yaml}
\mintedfromfile{yaml}{code/pytorch_vjp.yaml}
Scripts to generate C++ code (\co{libtorch.so}) and Python bindings.
\end{frame}

@@ -397,23 +402,24 @@
\begin{equation*}
f(\ve x) = c(\ve b(\ve a(\ve x)))
\end{equation*}
\inputminted{python}{code/scalar_field.py}
\mintedfromfile{python}{code/scalar_field.py}
}
\end{frame}

\begin{frame}[fragile]
%%\begin{frame}[fragile]
\begin{frame}
\frametitle{\pytorch: Reverse mode example}
\begin{equation*}
f: \mathbb R^n\ra \mathbb R, \quad f(\ve x) = c(\ve b(\ve a(\ve x)))
\end{equation*}
\inputminted[xleftmargin=.5cm]{python}{code/pytorch_fwd_rev_1.py}
\mintedfromfile{python}{code/pytorch_fwd_rev_1.py}
\uncover<2->{
Forward pass, tracing, observe \texttt{grad\textunderscore fn}
\inputminted[xleftmargin=.5cm]{python}{code/pytorch_fwd_rev_2.py}
\mintedfromfile{python}{code/pytorch_fwd_rev_2.py}
}
\uncover<3->{
Backward pass
\inputminted[xleftmargin=.5cm]{python}{code/pytorch_fwd_rev_3.py}
\mintedfromfile{python}{code/pytorch_fwd_rev_3.py}
}
\end{frame}

@@ -430,26 +436,26 @@
\end{align*}
\uncover<2->{
Default: initialize backward pass with $\pdi{f}{f}=1$
\inputminted[xleftmargin=.5cm]{python}{code/pytorch_rev_detail_1.py}
\mintedfromfile{python}{code/pytorch_rev_detail_1.py}
}
\uncover<3->{
VJP: extract one row of Jacobian $\pdi{\ve b}{\ve a}$
\inputminted[xleftmargin=.5cm]{python}{code/pytorch_rev_detail_2.py}
\mintedfromfile{python}{code/pytorch_rev_detail_2.py}
}
\end{frame}

\begin{frame}[fragile]
\frametitle{Higher order fun}
\inputminted{python}{code/higher_order_1.py}
\mintedfromfile{python}{code/higher_order_1.py}
\vspace{-0.3cm}
\uncover<2->{
\inputminted{python}{code/higher_order_2.py}
\mintedfromfile{python}{code/higher_order_2.py}
}
\uncover<3->{
$\ma J:=\partial f; \ma H = \partial^2 f = \partial(\partial f) = \partial(\nabla f)$ for $f:\mathbb R^n\ra\mathbb R$
}
\uncover<4->{
\inputminted{python}{code/higher_order_hessian.py}
\mintedfromfile{python}{code/higher_order_hessian.py}
}
\end{frame}

@@ -459,15 +465,15 @@
\begin{itemize}
\item restricted to \numpy/\pytorch-based code or need to wrap target package
\item in-place ops (\co{a[i] *= 3}) can break AD, \pytorch: Tensor versioning system
\begin{minted}{python}
\begin{mintedcode}{python}
RuntimeError: one of the variables needed for gradient
computation has been modified by an inplace operation:
[torch.FloatTensor []], which is output 0 of
SelectBackward, is at version 55; expected version 51
instead. Hint: enable anomaly detection to find the
operation that failed to compute its gradient, with
torch.autograd.set_detect_anomaly(True).
\end{minted}
\end{mintedcode}
\item numerical accuracy a.k.a. when do you want \verb|custom_jvp|
\begin{itemize}
\item AD derivatives can generate numerically unstable code
@@ -485,7 +491,7 @@

\begin{frame}
\frametitle{\texttt{jaxpr}}
\inputminted{python}{code/jaxpr.py}
\mintedfromfile{python}{code/jaxpr.py}
\end{frame}

\end{document}
