Date: 2026-05-23
The notebook source code for this article is available on GitHub.
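Transfer learning is the practice of adapting a model pre-trained on one task to a new but related task. One example is low-rank adaptation (LoRA), a common technique for fine-tuning large language models (LLMs). LoRA keeps the pre-trained weights frozen and trains only small low-rank update matrices.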
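LoRA is not part of this experiment, but the idea behind it is short enough to sketch. In its usual formulation, LoRA keeps a pre-trained weight matrix W0 (d × k) frozen and learns two small matrices, B (d × r) and A (r × k). A layer then computes h = W0·x + (alpha / r)·B·A·x. The pure-Python sketch below assumes that notation. The names W0, A, B, alpha and r are illustrative and are not MindSpore or MindCV APIs. The sketch shows why the update is cheap: only r·(d + k) values are trained instead of d·k.

```python
def matvec(matrix, vector):
    """Multiply a matrix (a list of rows) by a vector (a list of numbers)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W0, A, B, x, alpha, r):
    """Compute h = W0 x + (alpha / r) * B (A x).

    W0 (d x k) stays frozen; only A (r x k) and B (d x r) would be trained.
    """
    base = matvec(W0, x)              # output of the frozen pre-trained layer
    update = matvec(B, matvec(A, x))  # low-rank correction, computed as B (A x)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def trainable_params(d, k, r):
    """Return (full fine-tuning count, LoRA count) for one d x k weight matrix."""
    return d * k, r * (d + k)


if __name__ == "__main__":
    W0 = [[1.0, 2.0], [3.0, 4.0]]
    A = [[0.5, 0.5]]    # rank r = 1, input size k = 2
    B = [[0.0], [0.0]]  # B starts at zero, so LoRA initially leaves W0 x unchanged
    print(lora_forward(W0, A, B, [1.0, 1.0], alpha=1.0, r=1))  # [3.0, 7.0]
    print(trainable_params(512, 512, 8))                       # (262144, 8192)
```

The sketch computes B·(A·x) instead of forming the d × k product B·A first. This keeps the extra work proportional to the rank r. Initializing B to zero is the usual choice, so training starts from the pre-trained layer's exact behaviour.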
In this notebook experiment, let's load ResNet-18 with MindCV and fine-tune it on the CIFAR-100 dataset. ResNet-18 is a widely used convolutional neural network (CNN), and the model loaded here is pre-trained on ImageNet. Both ImageNet and CIFAR-100 are image classification datasets, which makes this model a good fit for our task. Furthermore, ImageNet contains far more samples and categories than CIFAR-100: about 1.28 million training images in 1,000 classes, versus 50,000 training images in 100 classes. The filters learned on ImageNet should therefore transfer well to CIFAR-100 with only minor adjustments.
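The sketch below puts numbers on "far more samples and categories" and on what a "minor adjustment" can look like. It compares the standard training splits and counts the parameters in ResNet-18's final fully connected layer, which maps a 512-dimensional pooled feature vector to class scores. It assumes the common approach of replacing only that layer with a 100-way head. The dataset sizes are the standard ImageNet-1k (ILSVRC 2012) and CIFAR-100 figures. This is plain arithmetic and does not call MindCV.

```python
# Standard training-split sizes (ImageNet-1k / ILSVRC 2012 and CIFAR-100).
IMAGENET_TRAIN_IMAGES, IMAGENET_CLASSES = 1_281_167, 1000
CIFAR100_TRAIN_IMAGES, CIFAR100_CLASSES = 50_000, 100

# ResNet-18's global-average-pooled feature vector has 512 channels.
RESNET18_FEATURES = 512


def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


if __name__ == "__main__":
    ratio = IMAGENET_TRAIN_IMAGES / CIFAR100_TRAIN_IMAGES
    print(f"ImageNet-1k has {ratio:.1f}x the training images and "
          f"{IMAGENET_CLASSES // CIFAR100_CLASSES}x the classes of CIFAR-100")
    # Swapping the 1000-way ImageNet head for a 100-way CIFAR-100 head:
    print("ImageNet head:", linear_params(RESNET18_FEATURES, IMAGENET_CLASSES))   # 513000
    print("CIFAR-100 head:", linear_params(RESNET18_FEATURES, CIFAR100_CLASSES))  # 51300
```

Under that assumption, only the new head starts from random initialization, and it has about 51 thousand parameters. Every other layer of ResNet-18 starts from the pre-trained weights. In its ImageNet configuration, ResNet-18 has roughly 11.7 million parameters.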
To follow this notebook experiment comfortably, it's recommended to work through the first 8 chapters of the Dive into Deep Learning (D2L) textbook or to have equivalent experience with machine learning.
This notebook experiment was performed on the OrangePi AIpro (20T) development board. The software used in this notebook is listed below.
MindCV is installed from commit 51b0636da1e38a44a52edf76cf314d1c7c18883a (2025-07-24).
!cat requirements.txt
absl-py==2.4.0
attrs==26.1.0
cloudpickle==3.1.2
decorator==5.2.1
git+https://github.com/mindspore-lab/mindcv.git@51b0636da1e38a44a52edf76cf314d1c7c18883a
jupyterlab==4.5.7
jupyterlab-git==0.53.0
jupyter-resource-usage==1.2.1
loguru==0.7.3
matplotlib==3.10.9
mindspore==2.9.0
ml-dtypes==0.5.4
msguard==0.0.8
openpyxl==3.1.5
opentelemetry-exporter-otlp-proto-grpc==1.33.1
opentelemetry-exporter-otlp-proto-http==1.33.1
pandas~=2.2
plotly>=5.11.0
pydantic==2.13.4
sympy==1.14.0
tornado==6.5.5
%pip install -r requirements.txt
Collecting git+https://github.com/mindspore-lab/mindcv.git@51b0636da1e38a44a52edf76cf314d1c7c18883a (from -r requirements.txt (line 5))
Cloning https://github.com/mindspore-lab/mindcv.git (to revision 51b0636da1e38a44a52edf76cf314d1c7c18883a) to /tmp/pip-req-build-athjd9sc
Running command git clone --filter=blob:none --quiet https://github.com/mindspore-lab/mindcv.git /tmp/pip-req-build-athjd9sc
Running command git rev-parse -q --verify 'sha^51b0636da1e38a44a52edf76cf314d1c7c18883a'
Running command git fetch -q https://github.com/mindspore-lab/mindcv.git 51b0636da1e38a44a52edf76cf314d1c7c18883a
Resolved https://github.com/mindspore-lab/mindcv.git to commit 51b0636da1e38a44a52edf76cf314d1c7c18883a
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: absl-py==2.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 1)) (2.4.0)
Requirement already satisfied: attrs==26.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 2)) (26.1.0)
Requirement already satisfied: cloudpickle==3.1.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 3)) (3.1.2)
Requirement already satisfied: decorator==5.2.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 4)) (5.2.1)
Requirement already satisfied: jupyterlab==4.5.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 6)) (4.5.7)
Requirement already satisfied: jupyterlab-git==0.53.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 7)) (0.53.0)
Requirement already satisfied: jupyter-resource-usage==1.2.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 8)) (1.2.1)
Requirement already satisfied: loguru==0.7.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 9)) (0.7.3)
Requirement already satisfied: matplotlib==3.10.9 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 10)) (3.10.9)
Requirement already satisfied: mindspore==2.9.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 11)) (2.9.0)
Requirement already satisfied: ml-dtypes==0.5.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 12)) (0.5.4)
Requirement already satisfied: msguard==0.0.8 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 13)) (0.0.8)
Requirement already satisfied: openpyxl==3.1.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 14)) (3.1.5)
Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc==1.33.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 15)) (1.33.1)
Requirement already satisfied: opentelemetry-exporter-otlp-proto-http==1.33.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 16)) (1.33.1)
Requirement already satisfied: pandas~=2.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 17)) (2.3.3)
Requirement already satisfied: plotly>=5.11.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 18)) (6.7.0)
Requirement already satisfied: pydantic==2.13.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 19)) (2.13.4)
Requirement already satisfied: sympy==1.14.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 20)) (1.14.0)
Requirement already satisfied: tornado==6.5.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 21)) (6.5.5)
Requirement already satisfied: async-lru>=1.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.3.0)
Requirement already satisfied: httpx<1,>=0.25.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.28.1)
Requirement already satisfied: ipykernel!=6.30.0,>=6.5.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (7.2.0)
Requirement already satisfied: jinja2>=3.0.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.1.6)
Requirement already satisfied: jupyter-core in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (5.9.1)
Requirement already satisfied: jupyter-lsp>=2.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.3.1)
Requirement already satisfied: jupyter-server<3,>=2.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.18.2)
Requirement already satisfied: jupyterlab-server<3,>=2.28.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.28.0)
Requirement already satisfied: notebook-shim>=0.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.2.4)
Requirement already satisfied: packaging in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (26.2)
Requirement already satisfied: setuptools>=41.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (82.0.1)
Requirement already satisfied: traitlets in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.7->-r requirements.txt (line 6)) (5.15.0)
Requirement already satisfied: jupyterlab-git-core>=0.52.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (0.53.0)
Requirement already satisfied: prometheus-client in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.1->-r requirements.txt (line 8)) (0.25.0)
Requirement already satisfied: psutil>=5.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.1->-r requirements.txt (line 8)) (7.2.2)
Requirement already satisfied: pyzmq>=19 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.1->-r requirements.txt (line 8)) (27.1.0)
Requirement already satisfied: contourpy>=1.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (4.63.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (1.5.0)
Requirement already satisfied: numpy>=1.23 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (1.26.4)
Requirement already satisfied: pillow>=8 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (12.2.0)
Requirement already satisfied: pyparsing>=3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (3.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.9->-r requirements.txt (line 10)) (2.9.0.post0)
Requirement already satisfied: protobuf>=3.13.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (5.29.6)
Requirement already satisfied: asttokens>=2.0.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (3.0.1)
Requirement already satisfied: scipy>=1.5.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (1.17.1)
Requirement already satisfied: astunparse>=1.6.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (1.6.3)
Requirement already satisfied: safetensors>=0.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (0.7.0)
Requirement already satisfied: dill>=0.3.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.9.0->-r requirements.txt (line 11)) (0.4.1)
Requirement already satisfied: et-xmlfile in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from openpyxl==3.1.5->-r requirements.txt (line 14)) (2.0.0)
Requirement already satisfied: deprecated>=1.2.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.3.1)
Requirement already satisfied: googleapis-common-protos~=1.52 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.75.0)
Requirement already satisfied: grpcio<2.0.0,>=1.63.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.80.0)
Requirement already satisfied: opentelemetry-api~=1.15 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.33.1)
Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.33.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.33.1)
Requirement already satisfied: opentelemetry-proto==1.33.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.33.1)
Requirement already satisfied: opentelemetry-sdk~=1.33.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (1.33.1)
Requirement already satisfied: requests~=2.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-exporter-otlp-proto-http==1.33.1->-r requirements.txt (line 16)) (2.34.2)
Requirement already satisfied: annotated-types>=0.6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pydantic==2.13.4->-r requirements.txt (line 19)) (0.7.0)
Requirement already satisfied: pydantic-core==2.46.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pydantic==2.13.4->-r requirements.txt (line 19)) (2.46.4)
Requirement already satisfied: typing-extensions>=4.14.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pydantic==2.13.4->-r requirements.txt (line 19)) (4.15.0)
Requirement already satisfied: typing-inspection>=0.4.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pydantic==2.13.4->-r requirements.txt (line 19)) (0.4.2)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from sympy==1.14.0->-r requirements.txt (line 20)) (1.3.0)
Requirement already satisfied: PyYAML>=5.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindcv==0.5.0->-r requirements.txt (line 5)) (6.0.3)
Requirement already satisfied: tqdm in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindcv==0.5.0->-r requirements.txt (line 5)) (4.67.3)
Requirement already satisfied: pytz>=2020.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pandas~=2.2->-r requirements.txt (line 17)) (2026.2)
Requirement already satisfied: tzdata>=2022.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pandas~=2.2->-r requirements.txt (line 17)) (2026.2)
Requirement already satisfied: narwhals>=1.15.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from plotly>=5.11.0->-r requirements.txt (line 18)) (2.21.2)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from astunparse>=1.6.3->mindspore==2.9.0->-r requirements.txt (line 11)) (0.47.0)
Requirement already satisfied: six<2.0,>=1.6.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from astunparse>=1.6.3->mindspore==2.9.0->-r requirements.txt (line 11)) (1.17.0)
Requirement already satisfied: wrapt<3,>=1.10 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from deprecated>=1.2.6->opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (2.2.0)
Requirement already satisfied: anyio in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (4.13.0)
Requirement already satisfied: certifi in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2026.5.20)
Requirement already satisfied: httpcore==1.* in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.0.9)
Requirement already satisfied: idna in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.16)
Requirement already satisfied: h11>=0.16 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpcore==1.*->httpx<1,>=0.25.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.16.0)
Requirement already satisfied: comm>=0.1.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.2.3)
Requirement already satisfied: debugpy>=1.6.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.8.20)
Requirement already satisfied: ipython>=7.23.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (9.13.0)
Requirement already satisfied: jupyter-client>=8.8.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (8.8.0)
Requirement already satisfied: matplotlib-inline>=0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.2.2)
Requirement already satisfied: nest-asyncio>=1.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.6.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jinja2>=3.0.3->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.0.3)
Requirement already satisfied: platformdirs>=2.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-core->jupyterlab==4.5.7->-r requirements.txt (line 6)) (4.9.6)
Requirement already satisfied: argon2-cffi>=21.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (25.1.0)
Requirement already satisfied: jupyter-events>=0.11.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.12.1)
Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.5.4)
Requirement already satisfied: nbconvert>=6.4.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (7.17.1)
Requirement already satisfied: nbformat>=5.3.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (5.10.4)
Requirement already satisfied: send2trash>=1.8.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.1.0)
Requirement already satisfied: terminado>=0.8.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.18.1)
Requirement already satisfied: websocket-client>=1.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.9.0)
Requirement already satisfied: pexpect in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git-core>=0.52.0->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (4.9.0)
Requirement already satisfied: nbdime~=4.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (4.0.4)
Requirement already satisfied: babel>=2.10 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.18.0)
Requirement already satisfied: json5>=0.9.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.14.0)
Requirement already satisfied: jsonschema>=4.18.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (4.26.0)
Requirement already satisfied: importlib-metadata<8.7.0,>=6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-api~=1.15->opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (8.6.1)
Requirement already satisfied: opentelemetry-semantic-conventions==0.54b1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from opentelemetry-sdk~=1.33.1->opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (0.54b1)
Requirement already satisfied: charset_normalizer<4,>=2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from requests~=2.7->opentelemetry-exporter-otlp-proto-http==1.33.1->-r requirements.txt (line 16)) (3.4.7)
Requirement already satisfied: urllib3<3,>=1.26 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from requests~=2.7->opentelemetry-exporter-otlp-proto-http==1.33.1->-r requirements.txt (line 16)) (2.7.0)
Requirement already satisfied: argon2-cffi-bindings in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (25.1.0)
Requirement already satisfied: zipp>=3.20 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from importlib-metadata<8.7.0,>=6.0->opentelemetry-api~=1.15->opentelemetry-exporter-otlp-proto-grpc==1.33.1->-r requirements.txt (line 15)) (4.1.0)
Requirement already satisfied: ipython-pygments-lexers>=1.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.1.1)
Requirement already satisfied: jedi>=0.18.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.20.0)
Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.0.52)
Requirement already satisfied: pygments>=2.14.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.20.0)
Requirement already satisfied: stack_data>=0.6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.6.3)
Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2025.9.1)
Requirement already satisfied: referencing>=0.28.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.37.0)
Requirement already satisfied: rpds-py>=0.25.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.30.0)
Requirement already satisfied: python-json-logger>=2.0.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (4.1.0)
Requirement already satisfied: rfc3339-validator in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.1.4)
Requirement already satisfied: rfc3986-validator>=0.1.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.1.1)
Requirement already satisfied: beautifulsoup4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (4.14.3)
Requirement already satisfied: bleach!=5.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (6.3.0)
Requirement already satisfied: defusedxml in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.7.1)
Requirement already satisfied: jupyterlab-pygments in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.3.0)
Requirement already satisfied: mistune<4,>=2.0.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.2.1)
Requirement already satisfied: nbclient>=0.5.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.10.4)
Requirement already satisfied: pandocfilters>=1.4.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.5.1)
Requirement already satisfied: colorama in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbdime~=4.0.1->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (0.4.6)
Requirement already satisfied: gitpython!=2.1.4,!=2.1.5,!=2.1.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbdime~=4.0.1->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (3.1.50)
Requirement already satisfied: fastjsonschema>=2.15 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbformat>=5.3.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.21.2)
Requirement already satisfied: ptyprocess>=0.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pexpect->jupyterlab-git-core>=0.52.0->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (0.7.0)
Requirement already satisfied: webencodings in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.5.1)
Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.4.0)
Requirement already satisfied: gitdb<5,>=4.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from gitpython!=2.1.4,!=2.1.5,!=2.1.6->nbdime~=4.0.1->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (4.0.12)
Requirement already satisfied: parso<0.9.0,>=0.8.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jedi>=0.18.2->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.8.7)
Requirement already satisfied: fqdn in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.5.1)
Requirement already satisfied: isoduration in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (20.11.0)
Requirement already satisfied: jsonpointer>1.13 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.1.1)
Requirement already satisfied: rfc3987-syntax>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.1.0)
Requirement already satisfied: uri-template in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.3.0)
Requirement already satisfied: webcolors>=24.6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (25.10.0)
Requirement already satisfied: wcwidth in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.7.0)
Requirement already satisfied: executing>=1.2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from stack_data>=0.6.0->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.2.1)
Requirement already satisfied: pure-eval in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from stack_data>=0.6.0->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (0.2.3)
Requirement already satisfied: cffi>=1.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.0.0)
Requirement already satisfied: soupsieve>=1.6.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from beautifulsoup4->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (2.8.3)
Requirement already satisfied: pycparser in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (3.0)
Requirement already satisfied: smmap<6,>=3.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from gitdb<5,>=4.0.1->gitpython!=2.1.4,!=2.1.5,!=2.1.6->nbdime~=4.0.1->jupyterlab-git-core[nbdime]>=0.52.0->jupyterlab-git==0.53.0->-r requirements.txt (line 7)) (5.0.3)
Requirement already satisfied: lark>=1.2.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.3.1)
Requirement already satisfied: arrow>=0.15.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.7->-r requirements.txt (line 6)) (1.4.0)
[notice] A new release of pip is available: 25.0.1 -> 26.1.1
[notice] To update, run: python -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
import mindspore
mindspore.set_device(device_target='Ascend', device_id=0)
mindspore.run_check()
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/classifier/transdata/transdata_classifier.py:223: SyntaxWarning: invalid escape sequence '\B'
Return BN\BH SCH Result
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
setattr(self, word, getattr(machar, word).flat[0])
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
return self._float_to_str(self.smallest_subnormal)
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
setattr(self, word, getattr(machar, word).flat[0])
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
return self._float_to_str(self.smallest_subnormal)
MindSpore version: 2.9.0
The result of multiplication calculation is correct, MindSpore has been installed on platform [Ascend] successfully!
The mindcv.create_dataset function lets us load the CIFAR-100 dataset with a single line of code. Let's take a batch of just 4 samples and visualize them with Matplotlib.
import mindcv
visual_ds = mindcv.create_dataset('cifar100', split='train', download=True)
visual_ds = visual_ds.batch(batch_size=4)
visual_ds
<mindspore.dataset.engine.datasets.BatchDataset at 0xe7ff49ef0f50>
Our dataset contains 3 columns.
image: the image containing the object we wish to classify
coarse_label: a coarse label indicating the broad category that our object belongs to, e.g. insects, flowers or people. There are 20 coarse labels in CIFAR-100.
fine_label: a fine label indicating the precise category that our object belongs to, e.g. bee, orchid or girl. There are 100 fine labels in CIFAR-100.
Since CIFAR-100 includes both coarse and fine labels with a hierarchical relationship between the two, i.e. each fine label belongs to exactly one coarse label (an orchid is always a flower and never a person), it's a canonical example of hierarchical classification.
visual_ds.get_col_names()
['image', 'coarse_label', 'fine_label']
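Because every fine label has exactly one parent coarse label, we can recover the hierarchy from (coarse_label, fine_label) pairs and check that no fine label ever appears under two parents. The sketch below does this in plain Python; fine_to_coarse_map is a hypothetical helper, and the hand-written pairs stand in for pairs you might collect by iterating over the dataset.

```python
from collections import defaultdict


def fine_to_coarse_map(pairs):
    """Hypothetical helper: build a fine -> coarse mapping from (coarse, fine) pairs.

    Raises ValueError if a fine label appears under more than one coarse label,
    which would violate a strict two-level hierarchy.
    """
    parents = defaultdict(set)
    for coarse, fine in pairs:
        parents[fine].add(coarse)
    conflicts = {fine: sorted(c) for fine, c in parents.items() if len(c) > 1}
    if conflicts:
        raise ValueError(f'fine labels with multiple parents: {conflicts}')
    # Each set now holds exactly one coarse label
    return {fine: next(iter(c)) for fine, c in parents.items()}


# Hand-written pairs of integer ids: 2 = flowers, 14 = people, 54 = orchid, 35 = girl
# (indices follow the label lists defined below)
pairs = [(2, 54), (14, 35), (2, 54), (14, 35)]
print(fine_to_coarse_map(pairs))  # {54: 2, 35: 14}
```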
Let’s inspect the shapes of our sample features and labels.
X_samples, y0_samples, y1_samples = next(visual_ds.create_tuple_iterator())
X_samples.shape, y0_samples.shape, y1_samples.shape
((4, 32, 32, 3), (4,), (4,))
Below, we define mappings from the raw integer labels to their human-readable equivalents. This mapping is taken from the CIFAR-100 dataset on Hugging Face.
y0_labels = [
'aquatic_mammals',
'fish',
'flowers',
'food_containers',
'fruit_and_vegetables',
'household_electrical_devices',
'household_furniture',
'insects',
'large_carnivores',
'large_man-made_outdoor_things',
'large_natural_outdoor_scenes',
'large_omnivores_and_herbivores',
'medium_mammals',
'non-insect_invertebrates',
'people',
'reptiles',
'small_mammals',
'trees',
'vehicles_1',
'vehicles_2'
]
len(y0_labels)
20
y1_labels = [
'apple',
'aquarium_fish',
'baby',
'bear',
'beaver',
'bed',
'bee',
'beetle',
'bicycle',
'bottle',
'bowl',
'boy',
'bridge',
'bus',
'butterfly',
'camel',
'can',
'castle',
'caterpillar',
'cattle',
'chair',
'chimpanzee',
'clock',
'cloud',
'cockroach',
'couch',
'crab',
'crocodile',
'cup',
'dinosaur',
'dolphin',
'elephant',
'flatfish',
'forest',
'fox',
'girl',
'hamster',
'house',
'kangaroo',
'keyboard',
'lamp',
'lawn_mower',
'leopard',
'lion',
'lizard',
'lobster',
'man',
'maple_tree',
'motorcycle',
'mountain',
'mouse',
'mushroom',
'oak_tree',
'orange',
'orchid',
'otter',
'palm_tree',
'pear',
'pickup_truck',
'pine_tree',
'plain',
'plate',
'poppy',
'porcupine',
'possum',
'rabbit',
'raccoon',
'ray',
'road',
'rocket',
'rose',
'sea',
'seal',
'shark',
'shrew',
'skunk',
'skyscraper',
'snail',
'snake',
'spider',
'squirrel',
'streetcar',
'sunflower',
'sweet_pepper',
'table',
'tank',
'telephone',
'television',
'tiger',
'tractor',
'train',
'trout',
'tulip',
'turtle',
'wardrobe',
'whale',
'willow_tree',
'wolf',
'woman',
'worm'
]
len(y1_labels)
100
Now let's visualize our 4 samples with Matplotlib. We chose 4 samples instead of 10 as we have done before because some of the category names are rather long, and it's difficult to get Matplotlib to display them without the titles overlapping each other and the images.
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2)
for axis_idx, axis in enumerate(axes.flatten()):
    # Title each image with its coarse label on the first line and fine label on the second
    axis.set_title(f'{y0_labels[y0_samples[axis_idx]]}\n{y1_labels[y1_samples[axis_idx]]}')
    axis.imshow(X_samples[axis_idx])
    axis.axis('off')
fig.tight_layout()
fig.show()

Let's load both the training and test sets this time, using the test split as our validation set, and pass them through our standard data pre-processing pipeline: each image is resized from 32x32 to 64x64 pixels, rescaled from [0, 255] to [0, 1] and converted from HWC to CHW layout.
We also one-hot encode both the coarse and fine labels and cast them to float32. Finally, we split our training and validation sets into batches of $2^7 = 128$ samples for fine-tuning.
train_ds = mindcv.create_dataset('cifar100', split='train', download=True)
test_ds = mindcv.create_dataset('cifar100', split='test', download=True)
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype
def transform_ds(dataset):
    # Resize 32x32 images to 64x64, scale pixels to [0, 1] and convert HWC -> CHW
    image_transforms = [
        vision.Resize(size=(64, 64)),
        vision.Rescale(rescale=1/255, shift=0),
        vision.HWC2CHW()
    ]
    # One-hot encode the 20 coarse labels as float32
    coarse_label_transforms = [
        transforms.OneHot(num_classes=20),
        transforms.TypeCast(data_type=mstype.float32)
    ]
    # One-hot encode the 100 fine labels as float32
    fine_label_transforms = [
        transforms.OneHot(num_classes=100),
        transforms.TypeCast(data_type=mstype.float32)
    ]
    dataset = dataset.map(operations=image_transforms, input_columns='image')
    dataset = dataset.map(operations=coarse_label_transforms, input_columns='coarse_label')
    dataset = dataset.map(operations=fine_label_transforms, input_columns='fine_label')
    # Keep the final partial batch so that every sample is used
    dataset = dataset.batch(batch_size=128, drop_remainder=False)
    return dataset
train_ds, test_ds = transform_ds(dataset=train_ds), transform_ds(dataset=test_ds)
train_ds, test_ds
(<mindspore.dataset.engine.datasets.BatchDataset at 0xe7feec15c230>,
<mindspore.dataset.engine.datasets.BatchDataset at 0xe7feec15d7f0>)
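To make the per-sample effect of transform_ds concrete, here's a NumPy approximation of the same steps applied to a single random image. It's a sketch, not MindSpore's implementation: a 2x nearest-neighbour upsample (np.repeat) stands in for vision.Resize, which interpolates bilinearly by default, and preprocess_sample is a hypothetical helper.

```python
import numpy as np


def preprocess_sample(image, coarse_label, fine_label):
    """Hypothetical helper approximating transform_ds for one (32, 32, 3) uint8 image."""
    # 32x32 -> 64x64 by repeating each pixel along height and width
    # (an approximation of vision.Resize, which defaults to bilinear interpolation)
    image = image.repeat(2, axis=0).repeat(2, axis=1)
    # uint8 [0, 255] -> float32 [0, 1], like Rescale(rescale=1/255, shift=0)
    image = image.astype(np.float32) / 255.0
    # HWC -> CHW, like HWC2CHW()
    image = image.transpose(2, 0, 1)
    # One-hot labels as float32, like OneHot followed by TypeCast
    coarse = np.eye(20, dtype=np.float32)[coarse_label]
    fine = np.eye(100, dtype=np.float32)[fine_label]
    return image, coarse, fine


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
x, y0, y1 = preprocess_sample(image, coarse_label=14, fine_label=35)
print(x.shape, y0.shape, y1.shape)  # (3, 64, 64) (20,) (100,)
```

Batching 128 such samples then stacks them into tensors of shape (128, 3, 64, 64), (128, 20) and (128, 100).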
Although it appears to be undocumented, we can call mindcv.list_models with the wildcard pattern resnet* and pretrained=True to list the ResNet models that have pre-trained weights available. We'll use ResNet-18, the smallest of these, so that fine-tuning completes on our development board within a reasonable timeframe, e.g. 2 hours.
resnet_pretrained = mindcv.list_models('resnet*', pretrained=True)
resnet_pretrained
['resnet101',
'resnet152',
'resnet18',
'resnet34',
'resnet50',
'resnetv2_101',
'resnetv2_50']
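The wildcard appears to follow shell-style globbing. Assuming MindCV's registry filters model names the way Python's fnmatch module does (an assumption based on the timm-style registry it resembles), the sketch below reproduces the matching on a hypothetical list of names. It also explains why the ResNetV2 models show up above: resnet* matches any name that starts with resnet.

```python
import fnmatch

# Hypothetical registry contents, for illustration only
model_names = ['resnet18', 'resnet50', 'resnetv2_50', 'resnext50_32x4d', 'vgg16']

# '*' matches any sequence of characters, so every name starting with 'resnet' matches;
# 'resnext50_32x4d' does not, because 'resnext' diverges from 'resnet' at the sixth letter
print(fnmatch.filter(model_names, 'resnet*'))  # ['resnet18', 'resnet50', 'resnetv2_50']
```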
Let's load the pre-trained ResNet-18 with mindcv.create_model.
resnet18 = mindcv.create_model('resnet18', pretrained=True)
resnet18
ResNet(
(conv1): Conv2d(input_channels=3, output_channels=64, kernel_size=(7, 7), stride=(2, 2), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36a240>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(max_pool): MaxPool2d(kernel_size=3, stride=2, pad_mode=SAME)
(layer1): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec369b20>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer1.0.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer1.0.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer1.0.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer1.0.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff2117b9e0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer1.0.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer1.0.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer1.0.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer1.0.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1490>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer1.1.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer1.1.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer1.1.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer1.1.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15cfb0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer1.1.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer1.1.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer1.1.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer1.1.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
)
(layer2): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff49e71550>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer2.0.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer2.0.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer2.0.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer2.0.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1d30>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer2.0.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer2.0.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer2.0.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer2.0.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0d3920>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer2.0.down_sample.1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer2.0.down_sample.1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer2.0.down_sample.1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer2.0.down_sample.1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec156de0>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer2.1.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer2.1.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer2.1.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer2.1.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff0c14f1d0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer2.1.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer2.1.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer2.1.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer2.1.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
(layer3): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15d190>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer3.0.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer3.0.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer3.0.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer3.0.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126750>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer3.0.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer3.0.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer3.0.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer3.0.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=128, output_channels=256, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1820>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer3.0.down_sample.1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer3.0.down_sample.1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer3.0.down_sample.1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer3.0.down_sample.1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36b980>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer3.1.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer3.1.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer3.1.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer3.1.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff21118320>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer3.1.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer3.1.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer3.1.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer3.1.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
(layer4): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=512, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0f9400>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer4.0.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer4.0.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer4.0.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer4.0.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec4c33e0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer4.0.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer4.0.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer4.0.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer4.0.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=256, output_channels=512, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126330>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer4.0.down_sample.1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer4.0.down_sample.1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer4.0.down_sample.1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer4.0.down_sample.1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec2e5850>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer4.1.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer4.1.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer4.1.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer4.1.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec5b6a80>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=layer4.1.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=layer4.1.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=layer4.1.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=layer4.1.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
(pool): GlobalAvgPooling()
(classifier): Dense(input_channels=512, output_channels=1000, has_bias=True)
)
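The printout above lists each layer's kernel size, stride and padding, which is enough to trace how the spatial size of a feature map shrinks through the network. The sketch below applies the usual output-size formula floor((n + 2p - k) / s) + 1 to the stem and to the first convolution of each stage, assuming that the pad_mode=SAME max pooling produces ceil(n / s) outputs. trace_resnet18 is a hypothetical helper written for this illustration.

```python
import math


def conv_out(size, kernel, stride, padding):
    # Output size of a convolution or pooling window with explicit padding
    return (size + 2 * padding - kernel) // stride + 1


def same_out(size, stride):
    # 'SAME' padding keeps ceil(size / stride) outputs regardless of kernel size
    return math.ceil(size / stride)


def trace_resnet18(size):
    """Hypothetical helper: spatial size after the stem and each stage of ResNet-18."""
    sizes = {'input': size}
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    sizes['conv1'] = size
    size = same_out(size, stride=2)  # max_pool with pad_mode=SAME
    sizes['max_pool'] = size
    sizes['layer1'] = size  # every conv in layer1 has stride 1
    for stage in ('layer2', 'layer3', 'layer4'):
        # The first BasicBlock of each later stage uses a 3x3 conv with stride 2, padding 1
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes[stage] = size
    return sizes


print(trace_resnet18(64))
# {'input': 64, 'conv1': 32, 'max_pool': 16, 'layer1': 16, 'layer2': 8, 'layer3': 4, 'layer4': 2}
```

With our 64x64 inputs, layer4 outputs a 2x2 map with 512 channels, which global average pooling reduces to the 512 features consumed by the classifier. With the native 32x32 CIFAR-100 images, the map would already have shrunk to 1x1 by layer4.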
For our hierarchical classification task, let's reuse the stem and body of ResNet-18 and replace the original head, a single 1000-way classifier, with 2 parallel fully connected (FC) layers: one for predicting the coarse labels and one for predicting the fine labels.
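As a quick check on how much smaller the new head is, here's the parameter arithmetic for a dense layer with bias (in x out weights plus out biases), applied to the sizes in the printout above. This is worked arithmetic only; dense_params is a hypothetical helper, not a MindCV function.

```python
def dense_params(in_features, out_features, bias=True):
    """Hypothetical helper: parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


original_head = dense_params(512, 1000)                     # Dense(512, 1000) in ResNet-18
new_heads = dense_params(512, 20) + dense_params(512, 100)  # coarse head + fine head
print(original_head, new_heads)  # 513000 61560
```

These 61,560 head parameters start from a fresh initialization, whereas every layer in the backbone starts from the ImageNet weights.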
We'll then concatenate the $20 + 100 = 120$ logits from the coarse and fine heads into a single tensor of shape (batch_size, 120), with the coarse logits in the first 20 columns, and return the result.
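Since the model returns a single (batch_size, 120) tensor, a training loss has to split it back into its coarse and fine parts. One plausible objective, assumed here for illustration rather than taken from this notebook, is the unweighted sum of a softmax cross-entropy on each part. The NumPy sketch below implements that; two_head_cross_entropy is a hypothetical helper.

```python
import numpy as np

NUM_COARSE, NUM_FINE = 20, 100


def log_softmax(z):
    # Subtract the row-wise max for numerical stability before exponentiating
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def two_head_cross_entropy(logits, coarse_onehot, fine_onehot):
    """Hypothetical helper: mean coarse CE + mean fine CE over (batch, 120) logits."""
    coarse_logits = logits[:, :NUM_COARSE]  # columns 0..19
    fine_logits = logits[:, NUM_COARSE:]    # columns 20..119
    coarse_ce = -(coarse_onehot * log_softmax(coarse_logits)).sum(axis=-1).mean()
    fine_ce = -(fine_onehot * log_softmax(fine_logits)).sum(axis=-1).mean()
    return coarse_ce + fine_ce


# All-zero logits mean uniform predictions over both heads
logits = np.zeros((4, NUM_COARSE + NUM_FINE))
coarse_onehot = np.eye(NUM_COARSE)[[2, 14, 2, 14]]
fine_onehot = np.eye(NUM_FINE)[[54, 35, 54, 35]]
loss = two_head_cross_entropy(logits, coarse_onehot, fine_onehot)
print(loss, np.log(20) + np.log(100))  # both ≈ 7.6009, i.e. ln(20) + ln(100)
```

A uniform guess costs ln(20) + ln(100) nats, which gives a useful reference point for the loss at the start of fine-tuning.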
import mindspore.nn as nn
from mindspore import ops
class MultiHeadResNet18(nn.Cell):
    def __init__(self, resnet18):
        super().__init__()
        # Reuse the pre-trained stem and body, dropping the original 1000-way classifier
        self.backbone = nn.SequentialCell([
            resnet18.conv1,
            resnet18.bn1,
            resnet18.relu,
            resnet18.max_pool,
            resnet18.layer1,
            resnet18.layer2,
            resnet18.layer3,
            resnet18.layer4,
            resnet18.pool,
            nn.Flatten()
        ])
        # Two new heads sharing the same 512-dimensional backbone features
        self.coarse_classifier = nn.Dense(512, 20)
        self.fine_classifier = nn.Dense(512, 100)
    def construct(self, X):
        features = self.backbone(X)
        coarse_logits = self.coarse_classifier(features)
        fine_logits = self.fine_classifier(features)
        # Shape (batch_size, 120): 20 coarse logits followed by 100 fine logits
        return ops.concat((coarse_logits, fine_logits), axis=-1)
multihead_resnet18 = MultiHeadResNet18(resnet18=resnet18)
multihead_resnet18
MultiHeadResNet18(
(backbone): SequentialCell(
(0): Conv2d(input_channels=3, output_channels=64, kernel_size=(7, 7), stride=(2, 2), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36a240>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(2): ReLU()
(3): MaxPool2d(kernel_size=3, stride=2, pad_mode=SAME)
(4): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec369b20>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.0.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.0.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.0.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.0.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff2117b9e0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.0.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.0.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.0.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.0.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1490>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.1.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.1.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.1.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.1.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15cfb0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.1.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.1.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.1.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.1.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
)
(5): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff49e71550>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1d30>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0d3920>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.down_sample.1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.down_sample.1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.down_sample.1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.down_sample.1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec156de0>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.1.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.1.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.1.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.1.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff0c14f1d0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.1.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.1.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.1.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.1.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
(6): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15d190>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126750>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=128, output_channels=256, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1820>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.down_sample.1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.down_sample.1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.down_sample.1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.down_sample.1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36b980>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.1.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.1.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.1.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.1.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff21118320>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.1.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.1.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.1.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.1.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
(7): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=512, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0f9400>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec4c33e0>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(down_sample): SequentialCell(
(0): Conv2d(input_channels=256, output_channels=512, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126330>, bias_init=None, format=NCHW)
(1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.down_sample.1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.down_sample.1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.down_sample.1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.down_sample.1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec2e5850>, bias_init=None, format=NCHW)
(bn1): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.1.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.1.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.1.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.1.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec5b6a80>, bias_init=None, format=NCHW)
(bn2): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.1.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.1.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.1.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.1.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
(8): GlobalAvgPooling()
(9): Flatten()
)
(coarse_classifier): Dense(input_channels=512, output_channels=20, has_bias=True)
(fine_classifier): Dense(input_channels=512, output_channels=100, has_bias=True)
)
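The printout above lists every layer's channel counts, kernel sizes and bias flags, so we can tally the trainable parameters by hand. Below is a plain-Python sketch with no MindSpore dependency. It relies on three facts from the printout:

- Every `Conv2d` has `has_bias=False`.
- Only `gamma` and `beta` in each `BatchNorm2d` are trainable, because `moving_mean` and `moving_variance` have `requires_grad=False`.
- Both `Dense` heads have `has_bias=True`.

The helper names are hypothetical and not part of MindCV.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    # Every Conv2d in the printout has has_bias=False, so only the kernel counts.
    return c_in * c_out * k * k

def bn_params(c: int) -> int:
    # Trainable gamma and beta; moving_mean/moving_variance are not trainable.
    return 2 * c

def dense_params(c_in: int, c_out: int) -> int:
    # Weight matrix plus bias (has_bias=True).
    return c_in * c_out + c_out

def basic_block(c_in: int, c_out: int, stride: int) -> int:
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)   # conv1 + bn1
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)  # conv2 + bn2
    if stride != 1 or c_in != c_out:
        # down_sample branch: 1x1 Conv2d followed by BatchNorm2d
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total

def backbone_params() -> int:
    total = conv_params(3, 64, 7) + bn_params(64)  # backbone.0 and backbone.1
    c_in = 64
    # backbone.4 to backbone.7: two BasicBlocks per stage
    for c_out, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        total += basic_block(c_in, c_out, stride) + basic_block(c_out, c_out, 1)
        c_in = c_out
    return total

backbone = backbone_params()
coarse = dense_params(512, 20)
fine = dense_params(512, 100)
print(backbone, coarse, fine, backbone + coarse + fine)
# 11176512 10260 51300 11238072
```

The backbone total of 11,176,512 equals ImageNet ResNet-18's 11,689,512 parameters minus its 1000-class head (512 × 1000 + 1000 = 513,000). The two new heads add only 61,560 trainable parameters on top of the pre-trained backbone.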
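Let’s wrap our model with mindspore.amp.auto_mixed_precision so that MindSpore casts between FP32 and FP16 for us. This saves us from inserting the casts manually.

With amp_level='O2', the network runs in FP16 except for BatchNorm, which stays in FP32. The printout below shows how this is done:

- Each BatchNorm2d is wrapped in an _OutputTo16 cell that casts its result back to FP16.
- The whole network is wrapped in an _OutputTo32 cell, so the final logits are returned in FP32.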
import mindspore.amp as amp
multihead_resnet18_amp = amp.auto_mixed_precision(network=multihead_resnet18, amp_level='O2')
multihead_resnet18_amp
_OutputTo32(
(_backbone): MultiHeadResNet18(
(backbone): SequentialCell(
(0): Conv2d(input_channels=3, output_channels=64, kernel_size=(7, 7), stride=(2, 2), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36a240>, bias_init=None, format=NCHW)
(1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
(2): ReLU()
(3): MaxPool2d(kernel_size=3, stride=2, pad_mode=SAME)
(4): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec369b20>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.0.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.0.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.0.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.0.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff2117b9e0>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.0.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.0.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.0.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.0.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1490>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.1.bn1.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.1.bn1.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.1.bn1.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.1.bn1.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=64, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15cfb0>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=64, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.4.1.bn2.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.4.1.bn2.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.4.1.bn2.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.4.1.bn2.moving_variance, shape=(64,), dtype=Float32, requires_grad=False))
)
)
)
(5): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=64, output_channels=128, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff49e71550>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1d30>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
(down_sample): SequentialCell(
(0): Conv2d(input_channels=64, output_channels=128, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0d3920>, bias_init=None, format=NCHW)
(1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.0.down_sample.1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.0.down_sample.1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.0.down_sample.1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.0.down_sample.1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec156de0>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.1.bn1.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.1.bn1.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.1.bn1.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.1.bn1.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=128, output_channels=128, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff0c14f1d0>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=128, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.5.1.bn2.gamma, shape=(128,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.5.1.bn2.beta, shape=(128,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.5.1.bn2.moving_mean, shape=(128,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.5.1.bn2.moving_variance, shape=(128,), dtype=Float32, requires_grad=False))
)
)
)
(6): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=128, output_channels=256, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec15d190>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126750>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
(down_sample): SequentialCell(
(0): Conv2d(input_channels=128, output_channels=256, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff211b1820>, bias_init=None, format=NCHW)
(1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.0.down_sample.1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.0.down_sample.1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.0.down_sample.1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.0.down_sample.1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec36b980>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.1.bn1.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.1.bn1.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.1.bn1.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.1.bn1.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=256, output_channels=256, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7ff21118320>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=256, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.6.1.bn2.gamma, shape=(256,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.6.1.bn2.beta, shape=(256,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.6.1.bn2.moving_mean, shape=(256,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.6.1.bn2.moving_variance, shape=(256,), dtype=Float32, requires_grad=False))
)
)
)
(7): SequentialCell(
(0): BasicBlock(
(conv1): Conv2d(input_channels=256, output_channels=512, kernel_size=(3, 3), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec0f9400>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec4c33e0>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
(down_sample): SequentialCell(
(0): Conv2d(input_channels=256, output_channels=512, kernel_size=(1, 1), stride=(2, 2), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec126330>, bias_init=None, format=NCHW)
(1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.0.down_sample.1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.0.down_sample.1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.0.down_sample.1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.0.down_sample.1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
)
(1): BasicBlock(
(conv1): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec2e5850>, bias_init=None, format=NCHW)
(bn1): _OutputTo16(
(_backbone): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.1.bn1.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.1.bn1.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.1.bn1.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.1.bn1.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
(relu): ReLU()
(conv2): Conv2d(input_channels=512, output_channels=512, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xe7feec5b6a80>, bias_init=None, format=NCHW)
(bn2): _OutputTo16(
(_backbone): BatchNorm2d(num_features=512, eps=1e-05, momentum=0.9, gamma=Parameter (name=backbone.7.1.bn2.gamma, shape=(512,), dtype=Float32, requires_grad=True), beta=Parameter (name=backbone.7.1.bn2.beta, shape=(512,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=backbone.7.1.bn2.moving_mean, shape=(512,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=backbone.7.1.bn2.moving_variance, shape=(512,), dtype=Float32, requires_grad=False))
)
)
)
(8): GlobalAvgPooling()
(9): Flatten()
)
(coarse_classifier): Dense(input_channels=512, output_channels=20, has_bias=True)
(fine_classifier): Dense(input_channels=512, output_channels=100, has_bias=True)
)
)
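Why does O2 leave BatchNorm in FP32? Here is a small stdlib-only sketch of half-precision limits. It uses Python's `struct` format `'e'` (IEEE 754 binary16) to round values to FP16.

Two caveats apply. The batch-norm helper and the naive one-pass sum of squares are illustrative assumptions, not MindSpore's actual kernels. Also, `struct` raises `OverflowError` where FP16 hardware would produce `inf`.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value

def to_fp16(x: float) -> float:
    """Round a Python float to FP16 and back (struct format 'e' is binary16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def batch_norm_then_cast(xs: list[float], eps: float = 1e-5) -> list[float]:
    """Normalize in full precision, then cast each output to FP16.
    A rough analogue of BatchNorm2d wrapped in _OutputTo16 (no gamma/beta)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [to_fp16((x - mean) / math.sqrt(var + eps)) for x in xs]

def fp16_sum_of_squares(xs: list[float]) -> float:
    """Naively accumulate x * x with every intermediate rounded to FP16.
    struct raises OverflowError where FP16 hardware would produce inf."""
    acc = 0.0
    for x in xs:
        sq = to_fp16(to_fp16(x) * to_fp16(x))
        acc = to_fp16(acc + sq)
    return acc

print(to_fp16(FP16_MAX))      # 65504.0
print(to_fp16(1.0 + 2**-12))  # 1.0: below FP16's resolution near 1.0 (2**-10)

activations = [200.0, 210.0, 190.0]
print(batch_norm_then_cast(activations))  # about [0.0, 1.22, -1.22], all finite
try:
    fp16_sum_of_squares(activations)  # 200**2 + 210**2 is about 84,100 > FP16_MAX
except OverflowError as err:
    print("FP16 overflow:", err)
```

Normalizing first keeps the outputs small (here within ±1.23), so casting them to FP16 afterwards is harmless. It is the intermediate statistics that need FP32's wider range.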
Let’s create a dummy RGB “image” of $64 \times 64$ pixels and check the shape of the network’s output. The coarse head produces 20 raw logits and the fine head produces 100, so we expect $20 + 100 = 120$ logits per image.
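Before running the network, we can trace the $64 \times 64$ input through the backbone by hand using the kernel sizes, strides, padding and pad modes in the printout. Below is a plain-Python sketch with hypothetical helper names. It uses two output-size rules:

- `pad_mode=pad`: $\lfloor (H + 2p - k)/s \rfloor + 1$.
- `pad_mode=same`: $\lceil H/s \rceil$.

Within each stage, only the first $3 \times 3$ convolution changes the spatial size.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size for pad_mode='pad' (explicit symmetric padding)."""
    return (size + 2 * padding - kernel) // stride + 1

def same_out(size: int, stride: int) -> int:
    """Spatial output size for pad_mode='same': ceil(size / stride)."""
    return -(-size // stride)

def trace_resnet18(size: int) -> list[tuple[str, int]]:
    """Trace the feature-map height/width through the backbone in the printout."""
    trace = []
    size = conv_out(size, kernel=7, stride=2, padding=3)   # backbone.0 Conv2d
    trace.append(("backbone.0 conv 7x7/2", size))
    size = same_out(size, stride=2)                        # backbone.3 MaxPool2d
    trace.append(("backbone.3 maxpool/2", size))
    # backbone.4 to backbone.7: each stage's first 3x3 conv has stride 1
    # (backbone.4) or 2 (the rest), padding 1; later convs keep the size.
    for idx, stride in zip((4, 5, 6, 7), (1, 2, 2, 2)):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        trace.append((f"backbone.{idx}", size))
    return trace

for name, size in trace_resnet18(64):
    print(f"{name:<24} {size}x{size}")
```

The trace gives 32, 16, 16, 8, 4 and then 2. GlobalAvgPooling collapses the final $2 \times 2 \times 512$ map to 512 features, which Flatten passes to both Dense heads, giving $20 + 100 = 120$ logits.

The 1×1 `down_sample` convolutions use `pad_mode=same`, so $\lceil 16/2 \rceil = 8$ matches the main path and the residual addition lines up. On CIFAR-100’s native $32 \times 32$ images, the same trace ends at a $1 \times 1$ map.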
import mindspore.ops as ops
# A batch of one random 3-channel (RGB) 64×64 input in NCHW layout
X_rgb_64x64 = ops.randn(1, 3, 64, 64)
y_hat = multihead_resnet18_amp(X_rgb_64x64)
y_hat.shape
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/classifier/transdata/transdata_classifier.py:223: SyntaxWarning: invalid escape sequence '\B'
Return BN\BH SCH Result
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-9.0.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/usr/local/Ascend/cann-9.0.0/opp/built-in/op_impl/ai_core/tbe/impl/ops_legacy/dynamic/gelu_grad_v2.py:97: SyntaxWarning: invalid escape sequence '\h'
gelu_grad_erf = erfc(-\hat{x}) / 2 + (1 /sqrt(Pi)) * (\hat{x}) * exp(-\hat{x}^2)
/usr/local/Ascend/cann-9.0.0/opp/built-in/op_impl/ai_core/tbe/impl/ops_legacy/dynamic/gelu_grad_v2.py:157: SyntaxWarning: invalid escape sequence '\h'
gelu_grad_erf = erfc(-\hat{x}) / 2 + (1 /sqrt(Pi)) * (\hat{x}) * exp(-\hat{x}^2)
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
/usr/local/Ascend/cann-9.0.0/python/site-packages/asc_op_compile_base/asc_op_compiler/ascendc_compile_gen_code.py:179: SyntaxWarning: invalid escape sequence '\w'
match = re.search(f'{option}=(\w+)', ' '.join(compile_options))
(1, 120)
Let’s define our loss function. Since CIFAR-100 groups its 100 fine labels into 20 coarse superclasses, our loss function should mirror this hierarchy.
The idea is to penalize an incorrect coarse prediction more severely than an incorrect fine prediction. It is forgivable for our model to mistake an apple for a pear (both belong to the same superclass), but it really shouldn’t mistake a person for a tree!
We weight the fine-label loss by a factor $\lambda$ relative to the coarse-label loss. For this notebook experiment, let’s set $\lambda = 0.5$: small enough that incorrect fine label predictions are not penalized too heavily, yet large enough that our model still learns the fine labels quickly within a limited number of epochs. We’ll fine-tune our model over 10 epochs, which takes just under 2 hours on our OrangePi AIpro (20T) development board, a borderline but acceptable timeframe.
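Written out, the loss computed by `HierarchicalLoss` is the batch mean of a coarse softmax cross-entropy plus $\lambda$ times a fine softmax cross-entropy. The notation here is ours, not the original author’s: for sample $n$ in a batch of $N$, $z^{(n)} \in \mathbb{R}^{120}$ are the logits, with the first 20 entries forming the coarse head and the last 100 the fine head, and $c_n \in \{1, \dots, 20\}$, $f_n \in \{1, \dots, 100\}$ are the true coarse and fine class indices.

$$
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\left[-\log\frac{e^{z^{(n)}_{c_n}}}{\sum_{j=1}^{20} e^{z^{(n)}_{j}}} \;-\; \lambda \log\frac{e^{z^{(n)}_{20+f_n}}}{\sum_{k=1}^{100} e^{z^{(n)}_{20+k}}}\right]
$$

With $\lambda = 0.5$, the coarse term carries twice the weight of the fine term.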
class HierarchicalLoss(nn.Cell):
    def __init__(self, lambd=0.5, reduction='mean'):
        super().__init__()
        self.lambd = lambd
        self.reduction = reduction
        # sparse defaults to False, so labels are expected to be one-hot encoded
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
    def construct(self, logits, labels):
        # First 20 logits belong to the coarse head, the remaining 100 to the fine head
        coarse_logits, fine_logits = logits[:, :20], logits[:, 20:]
        coarse_labels, fine_labels = labels
        coarse_loss = self.ce(coarse_logits, coarse_labels)
        fine_loss = self.ce(fine_logits, fine_labels)
        # Coarse mistakes are weighted by 1, fine mistakes by lambd
        return coarse_loss + self.lambd * fine_loss
loss_fn = HierarchicalLoss(lambd=0.5, reduction='mean')
loss_fn
HierarchicalLoss(
(ce): SoftmaxCrossEntropyWithLogits()
)
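For readers without an Ascend board, here is a minimal NumPy sketch of the same computation. It assumes, as the notebook’s accuracy code suggests, that labels arrive one-hot encoded and that the first 20 logits form the coarse head. `softmax_cross_entropy` and `hierarchical_loss` are hypothetical helpers for illustration, not MindSpore APIs.

```python
import numpy as np

def softmax_cross_entropy(logits, onehot):
    """Mean softmax cross-entropy over the batch, with one-hot (dense) labels."""
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(onehot * log_probs).sum(axis=1).mean())

def hierarchical_loss(logits, coarse_onehot, fine_onehot, lambd=0.5, n_coarse=20):
    """Coarse cross-entropy plus lambd times fine cross-entropy."""
    coarse_logits, fine_logits = logits[:, :n_coarse], logits[:, n_coarse:]
    coarse_loss = softmax_cross_entropy(coarse_logits, coarse_onehot)
    fine_loss = softmax_cross_entropy(fine_logits, fine_onehot)
    return coarse_loss + lambd * fine_loss

# Usage: all-zero logits give uniform predictions, so each term equals log(num_classes).
logits = np.zeros((4, 120))
coarse = np.eye(20)[[0, 3, 7, 19]]
fine = np.eye(100)[[2, 15, 37, 99]]
print(hierarchical_loss(logits, coarse, fine))  # log(20) + 0.5 * log(100), about 5.2983
```

This gives a handy sanity check: a head that predicts near-uniformly should start around 5.30, and the first logged training loss below (5.6848) is in the same range.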
Let’s use minibatch SGD as our optimizer with a learning rate of 1e-3 for faster convergence within 10 epochs. For a fine-tuning task like this one, 1e-3 is already considered aggressive, since large updates risk destroying the pre-trained image filters; a learning rate of 1e-4 usually suffices.
optimizer = nn.SGD(params=multihead_resnet18_amp.trainable_params(), learning_rate=1e-3)
optimizer
SGD()
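Without momentum (MindSpore’s `nn.SGD` defaults to `momentum=0.0`), each step is plain gradient descent, $w \leftarrow w - \eta \nabla_w \mathcal{L}$, so the change applied to the pre-trained filters scales linearly with the learning rate $\eta$. A minimal NumPy sketch, using random vectors as hypothetical stand-ins for real weights and gradients:

```python
import numpy as np

def sgd_step(weights, grads, lr):
    """One vanilla SGD update (no momentum, no weight decay)."""
    return weights - lr * grads

rng = np.random.default_rng(0)
w_pretrained = rng.standard_normal(1000)  # stand-in for pre-trained filter weights
g = rng.standard_normal(1000)             # stand-in for a minibatch gradient

for lr in (1e-3, 1e-4):
    w_new = sgd_step(w_pretrained, g, lr)
    # Size of the update relative to the size of the weights themselves.
    rel_change = np.linalg.norm(w_new - w_pretrained) / np.linalg.norm(w_pretrained)
    print(f'lr={lr:g}: relative change {rel_change:.2e}')
```

For the same gradient, 1e-3 moves the weights ten times further per step than 1e-4, trading some risk to the pre-trained filters for faster convergence.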
Define our forward and gradient functions as usual.
def forward(X, y):
    y_hat = multihead_resnet18_amp(X)
    loss = loss_fn(y_hat, y)
    return loss, y_hat
grad_fn = mindspore.value_and_grad(fn=forward, grad_position=None, weights=optimizer.parameters, has_aux=True)
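With `has_aux=True`, `mindspore.value_and_grad` differentiates only the first output of `forward` (the loss) and passes `y_hat` through untouched, so `grad_fn` returns `((loss, y_hat), grads)`, which is why `train_batch` unpacks `(loss, _), grads`. The plain-Python sketch below illustrates that contract using central finite differences; `value_and_grad_fd` is a hypothetical helper, and MindSpore itself uses automatic differentiation rather than finite differences.

```python
import numpy as np

def value_and_grad_fd(fn, eps=1e-6):
    """Wrap fn(w) -> (loss, aux) so it returns ((loss, aux), grad).

    The gradient is taken with respect to the loss only, via central finite
    differences; the auxiliary output is returned as-is.
    """
    def wrapped(w):
        loss, aux = fn(w)
        grad = np.zeros_like(w)
        for i in range(w.size):
            step = np.zeros_like(w)
            step.flat[i] = eps
            loss_plus, _ = fn(w + step)
            loss_minus, _ = fn(w - step)
            grad.flat[i] = (loss_plus - loss_minus) / (2 * eps)
        return (loss, aux), grad
    return wrapped

# Toy "model": a linear prediction with squared-error loss; the prediction is the aux output.
x, y = np.array([1.0, 2.0]), 3.0

def forward(w):
    y_hat = float(w @ x)
    return (y_hat - y) ** 2, y_hat

grad_fn = value_and_grad_fd(forward)
(loss, y_hat), grads = grad_fn(np.array([0.5, 0.5]))
# loss = 2.25, y_hat = 1.5, grads approximately [-3., -6.]
```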
Now define our training logic per batch and per epoch. This is mostly similar to what we’ve seen before, except that our dataset now returns 2 sets of labels (coarse and fine) instead of 1, which we must accommodate.
def train_batch(X_batch, y_batch):
    (loss, _), grads = grad_fn(X_batch, y_batch)
    optimizer(grads)
    return loss
def train_epoch(epoch=0):
    print(f'Epoch {epoch} start')
    batch_count = train_ds.get_dataset_size()
    training_losses = []
    multihead_resnet18_amp.set_train()
    # Each batch yields images plus their coarse (y0) and fine (y1) labels
    for batch_idx, (X_batch, y0_batch, y1_batch) in enumerate(train_ds.create_tuple_iterator()):
        loss = train_batch(X_batch, (y0_batch, y1_batch))
        if batch_idx % 20 == 0:
            print(f'Training loss: {loss.asnumpy():.4f} [{batch_idx}/{batch_count}]')
        training_losses.append(loss)
    print(f'Epoch {epoch} end')
    return training_losses
Define the validation logic as well. Again, the only real difference is accommodating both the coarse and fine labels.
def validate_epoch(epoch=0):
    validation_losses = []
    multihead_resnet18_amp.set_train(False)
    for X_batch, y0_batch, y1_batch in test_ds.create_tuple_iterator():
        batch_size = X_batch.shape[0]
        y_hat = multihead_resnet18_amp(X_batch)
        validation_loss = (batch_size, loss_fn(y_hat, (y0_batch, y1_batch)).item())
        validation_losses.append(validation_loss)
    val_samples_total = sum(batch_size for batch_size, _ in validation_losses)
    val_loss = sum(batch_size * batch_loss for batch_size, batch_loss in validation_losses) / val_samples_total
    print(f'Validation loss after epoch {epoch}: {val_loss:.4f}')
    return val_loss
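`validate_epoch` weights each batch’s mean loss by its batch size instead of averaging the batch means directly, which matters whenever the last batch is smaller than the rest. Assuming a batch size of 128 with no remainder dropped (consistent with the 391 training batches for 50,000 images), CIFAR-100’s 10,000 test images would split into 78 full batches and a final batch of 16. A minimal sketch with made-up loss values, not numbers from this run:

```python
def weighted_mean(pairs):
    """Average per-batch means, weighting each by its number of samples."""
    total = sum(n for n, _ in pairs)
    return sum(n * m for n, m in pairs) / total

# Hypothetical (batch_size, mean_loss) pairs: two full batches and one small final batch.
batches = [(128, 2.0), (128, 2.0), (16, 5.0)]
naive = sum(m for _, m in batches) / len(batches)
print(weighted_mean(batches), naive)  # about 2.1765 vs 3.0
```

The weighted version equals the true per-sample mean, while the naive average lets the 16-sample batch count as much as a 128-sample one.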
Fine-tune our model over 10 epochs. Be patient: this takes just under 2 hours to complete on our development board.
training_losses = []
validation_losses = []
epochs = 10
print(f'Fine-tuning our model over {epochs} epochs ...')
for epoch in range(epochs):
    training_losses.extend(train_epoch(epoch=epoch))
    validation_losses.append(validate_epoch(epoch=epoch))
training_losses, validation_losses = ops.stack(training_losses, axis=0), mindspore.Tensor(validation_losses)
training_losses.shape, validation_losses.shape
Fine-tuning our model over 10 epochs ...
Epoch 0 start
Training loss: 5.6848 [0/391]
Training loss: 5.4499 [20/391]
Training loss: 5.3580 [40/391]
Training loss: 5.3982 [60/391]
Training loss: 5.3181 [80/391]
Training loss: 5.3137 [100/391]
Training loss: 5.1766 [120/391]
Training loss: 5.1904 [140/391]
Training loss: 5.0723 [160/391]
Training loss: 5.0245 [180/391]
Training loss: 4.9998 [200/391]
Training loss: 4.7842 [220/391]
Training loss: 4.9190 [240/391]
Training loss: 4.9942 [260/391]
Training loss: 4.8681 [280/391]
Training loss: 4.7217 [300/391]
Training loss: 4.6198 [320/391]
Training loss: 4.7625 [340/391]
Training loss: 4.4851 [360/391]
Training loss: 4.7031 [380/391]
Epoch 0 end
Validation loss after epoch 0: 4.5388
Epoch 1 start
Training loss: 4.4457 [0/391]
Training loss: 4.5591 [20/391]
Training loss: 4.3686 [40/391]
Training loss: 4.4903 [60/391]
Training loss: 4.4057 [80/391]
Training loss: 4.2155 [100/391]
Training loss: 4.1198 [120/391]
Training loss: 4.1873 [140/391]
Training loss: 4.2746 [160/391]
Training loss: 4.1506 [180/391]
Training loss: 4.1203 [200/391]
Training loss: 4.1334 [220/391]
Training loss: 3.8344 [240/391]
Training loss: 3.9543 [260/391]
Training loss: 3.8445 [280/391]
Training loss: 4.1660 [300/391]
Training loss: 3.8658 [320/391]
Training loss: 3.9542 [340/391]
Training loss: 3.7998 [360/391]
Training loss: 3.7627 [380/391]
Epoch 1 end
Validation loss after epoch 1: 3.7990
Epoch 2 start
Training loss: 3.8752 [0/391]
Training loss: 3.6399 [20/391]
Training loss: 3.6589 [40/391]
Training loss: 3.5796 [60/391]
Training loss: 3.6930 [80/391]
Training loss: 3.6689 [100/391]
Training loss: 3.5086 [120/391]
Training loss: 3.5895 [140/391]
Training loss: 3.3900 [160/391]
Training loss: 3.4246 [180/391]
Training loss: 3.4555 [200/391]
Training loss: 3.2600 [220/391]
Training loss: 3.6037 [240/391]
Training loss: 3.5174 [260/391]
Training loss: 3.3731 [280/391]
Training loss: 3.4723 [300/391]
Training loss: 3.3939 [320/391]
Training loss: 3.5425 [340/391]
Training loss: 3.1965 [360/391]
Training loss: 3.3080 [380/391]
Epoch 2 end
Validation loss after epoch 2: 3.3056
Epoch 3 start
Training loss: 3.1990 [0/391]
Training loss: 3.2452 [20/391]
Training loss: 3.1243 [40/391]
Training loss: 3.3510 [60/391]
Training loss: 3.2063 [80/391]
Training loss: 3.0302 [100/391]
Training loss: 3.1367 [120/391]
Training loss: 3.1910 [140/391]
Training loss: 3.1967 [160/391]
Training loss: 2.9561 [180/391]
Training loss: 3.0482 [200/391]
Training loss: 3.2196 [220/391]
Training loss: 2.9519 [240/391]
Training loss: 3.2386 [260/391]
Training loss: 2.8335 [280/391]
Training loss: 3.0779 [300/391]
Training loss: 2.8997 [320/391]
Training loss: 3.1181 [340/391]
Training loss: 3.0453 [360/391]
Training loss: 2.8809 [380/391]
Epoch 3 end
Validation loss after epoch 3: 2.9454
Epoch 4 start
Training loss: 2.8008 [0/391]
Training loss: 2.8289 [20/391]
Training loss: 2.7191 [40/391]
Training loss: 2.8720 [60/391]
Training loss: 2.7193 [80/391]
Training loss: 2.7098 [100/391]
Training loss: 2.6996 [120/391]
Training loss: 2.8685 [140/391]
Training loss: 2.6722 [160/391]
Training loss: 2.7407 [180/391]
Training loss: 3.0916 [200/391]
Training loss: 2.8873 [220/391]
Training loss: 2.7637 [240/391]
Training loss: 2.7557 [260/391]
Training loss: 2.6343 [280/391]
Training loss: 2.8405 [300/391]
Training loss: 2.6399 [320/391]
Training loss: 2.8021 [340/391]
Training loss: 2.5994 [360/391]
Training loss: 2.4809 [380/391]
Epoch 4 end
Validation loss after epoch 4: 2.7025
Epoch 5 start
Training loss: 2.7227 [0/391]
Training loss: 2.7959 [20/391]
Training loss: 2.5210 [40/391]
Training loss: 2.6694 [60/391]
Training loss: 2.6951 [80/391]
Training loss: 2.5896 [100/391]
Training loss: 2.4578 [120/391]
Training loss: 2.5537 [140/391]
Training loss: 2.7304 [160/391]
Training loss: 2.5739 [180/391]
Training loss: 2.5543 [200/391]
Training loss: 2.5630 [220/391]
Training loss: 2.2733 [240/391]
Training loss: 2.3827 [260/391]
Training loss: 2.2933 [280/391]
Training loss: 2.4656 [300/391]
Training loss: 2.4673 [320/391]
Training loss: 2.4325 [340/391]
Training loss: 2.3518 [360/391]
Training loss: 2.3556 [380/391]
Epoch 5 end
Validation loss after epoch 5: 2.4836
Epoch 6 start
Training loss: 2.6754 [0/391]
Training loss: 2.2903 [20/391]
Training loss: 2.3814 [40/391]
Training loss: 2.3325 [60/391]
Training loss: 2.3894 [80/391]
Training loss: 2.1230 [100/391]
Training loss: 2.3667 [120/391]
Training loss: 2.5547 [140/391]
Training loss: 2.4975 [160/391]
Training loss: 2.3667 [180/391]
Training loss: 2.3595 [200/391]
Training loss: 2.3678 [220/391]
Training loss: 2.3864 [240/391]
Training loss: 2.1519 [260/391]
Training loss: 2.5530 [280/391]
Training loss: 2.3646 [300/391]
Training loss: 2.2571 [320/391]
Training loss: 2.0247 [340/391]
Training loss: 2.2899 [360/391]
Training loss: 2.4016 [380/391]
Epoch 6 end
Validation loss after epoch 6: 2.3266
Epoch 7 start
Training loss: 2.0084 [0/391]
Training loss: 2.3109 [20/391]
Training loss: 2.3336 [40/391]
Training loss: 2.2305 [60/391]
Training loss: 2.1838 [80/391]
Training loss: 2.2407 [100/391]
Training loss: 2.1323 [120/391]
Training loss: 2.1440 [140/391]
Training loss: 2.2685 [160/391]
Training loss: 2.1336 [180/391]
Training loss: 2.1021 [200/391]
Training loss: 2.1731 [220/391]
Training loss: 2.1500 [240/391]
Training loss: 2.1780 [260/391]
Training loss: 2.0409 [280/391]
Training loss: 2.0791 [300/391]
Training loss: 2.1021 [320/391]
Training loss: 2.0038 [340/391]
Training loss: 2.1276 [360/391]
Training loss: 2.0658 [380/391]
Epoch 7 end
Validation loss after epoch 7: 2.2192
Epoch 8 start
Training loss: 2.1134 [0/391]
Training loss: 2.1615 [20/391]
Training loss: 2.0193 [40/391]
Training loss: 2.3047 [60/391]
Training loss: 2.1270 [80/391]
Training loss: 2.2316 [100/391]
Training loss: 2.2435 [120/391]
Training loss: 1.9793 [140/391]
Training loss: 1.9127 [160/391]
Training loss: 2.2763 [180/391]
Training loss: 2.1120 [200/391]
Training loss: 2.0741 [220/391]
Training loss: 2.1669 [240/391]
Training loss: 2.0715 [260/391]
Training loss: 2.1798 [280/391]
Training loss: 2.0331 [300/391]
Training loss: 2.0233 [320/391]
Training loss: 2.1274 [340/391]
Training loss: 1.9132 [360/391]
Training loss: 2.1143 [380/391]
Epoch 8 end
Validation loss after epoch 8: 2.0878
Epoch 9 start
Training loss: 2.0792 [0/391]
Training loss: 1.9131 [20/391]
Training loss: 2.0230 [40/391]
Training loss: 2.1467 [60/391]
Training loss: 1.7917 [80/391]
Training loss: 2.1788 [100/391]
Training loss: 1.7646 [120/391]
Training loss: 1.7108 [140/391]
Training loss: 1.9769 [160/391]
Training loss: 2.1211 [180/391]
Training loss: 1.9208 [200/391]
Training loss: 1.9581 [220/391]
Training loss: 2.2359 [240/391]
Training loss: 1.7331 [260/391]
Training loss: 1.9252 [280/391]
Training loss: 1.9974 [300/391]
Training loss: 1.9489 [320/391]
Training loss: 2.2466 [340/391]
Training loss: 1.7049 [360/391]
Training loss: 1.8673 [380/391]
Epoch 9 end
Validation loss after epoch 9: 2.0185
((3910,), (10,))
Let’s plot the training vs. validation loss with Matplotlib.
import matplotlib.pyplot as plt
import numpy as np
batches_per_epoch = train_ds.get_dataset_size()  # 391 batches per epoch in this run
batches_x = np.arange(batches_per_epoch * epochs)
epochs_x = np.arange(epochs) * batches_per_epoch + batches_per_epoch
training_losses_y = training_losses.asnumpy()
validation_losses_y = validation_losses.asnumpy()
plt.figure(figsize=(8, 5))
plt.plot(batches_x, training_losses_y, label='Training loss', linestyle='-')
plt.plot(epochs_x, validation_losses_y, label='Validation loss', linestyle='--')
plt.title('Training vs. validation losses')
plt.xlabel('Batch number')
plt.ylabel('Loss')
plt.yscale('log')
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.show()
Figure: Training vs. validation losses (loss on a log scale vs. batch number)
Finally, let’s evaluate our model with 3 accuracy metrics: coarse accuracy (superclass correct), fine accuracy (class correct) and true accuracy (both correct).
multihead_resnet18_amp.set_train(False)
coarse_accs, fine_accs, true_accs = [], [], []
for X_batch, y0_batch, y1_batch in test_ds.create_tuple_iterator():
    batch_size = X_batch.shape[0]
    y_hat = multihead_resnet18_amp(X_batch)
    # Split logits into the coarse (first 20) and fine (last 100) heads
    y0_hat, y1_hat = y_hat[:, :20], y_hat[:, 20:]
    y0_hat, y1_hat = ops.argmax(y0_hat, dim=1), ops.argmax(y1_hat, dim=1)
    # Labels are one-hot, so recover class indices with argmax as well
    y0_batch, y1_batch = ops.argmax(y0_batch, dim=1), ops.argmax(y1_batch, dim=1)
    y0_correct, y1_correct = (y0_hat == y0_batch).sum().item(), (y1_hat == y1_batch).sum().item()
    y0_acc, y1_acc = y0_correct / batch_size, y1_correct / batch_size
    coarse_accs.append((batch_size, y0_acc))
    fine_accs.append((batch_size, y1_acc))
    y_hat, y_batch = ops.stack((y0_hat, y1_hat), axis=1), ops.stack((y0_batch, y1_batch), axis=1)
    y_correct = ops.amin((y_hat == y_batch).astype(dtype=mstype.int32), axis=1).sum().item()
    y_acc = y_correct / batch_size
    true_accs.append((batch_size, y_acc))
val_samples_total = sum(batch_size for batch_size, _ in true_accs)
coarse_acc = sum(batch_size * y0_acc for batch_size, y0_acc in coarse_accs) / val_samples_total
fine_acc = sum(batch_size * y1_acc for batch_size, y1_acc in fine_accs) / val_samples_total
true_acc = sum(batch_size * y_acc for batch_size, y_acc in true_accs) / val_samples_total
print(f'Coarse accuracy: {coarse_acc:.4f}')
print(f'Fine accuracy: {fine_acc:.4f}')
print(f'True accuracy: {true_acc:.4f}')
Coarse accuracy: 0.6987
Fine accuracy: 0.4645
True accuracy: 0.4251
The coarse accuracy is close to $70\%$ and the true (hierarchical) accuracy, which requires both labels to be correct, is over $40\%$. Not bad!
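The three metrics can be reproduced on toy data with NumPy. In this sketch, `hierarchical_accuracies` is our own hypothetical helper; it assumes one-hot labels and the 20 + 100 logit layout used above, and counts a sample as truly correct only when both heads are right, which is what the `ops.amin` over the stacked comparisons computes.

```python
import numpy as np

def hierarchical_accuracies(logits, coarse_onehot, fine_onehot, n_coarse=20):
    """Return (coarse_acc, fine_acc, true_acc) for one batch."""
    coarse_ok = logits[:, :n_coarse].argmax(axis=1) == coarse_onehot.argmax(axis=1)
    fine_ok = logits[:, n_coarse:].argmax(axis=1) == fine_onehot.argmax(axis=1)
    # A sample is "truly" correct only if both the coarse and fine predictions match.
    both_ok = coarse_ok & fine_ok
    return coarse_ok.mean(), fine_ok.mean(), both_ok.mean()

# Toy batch of 4: sample 1 gets the fine label wrong, sample 3 the coarse label.
coarse_true, fine_true = [0, 1, 2, 3], [10, 11, 12, 13]
coarse_pred, fine_pred = [0, 1, 2, 5], [10, 99, 12, 13]
logits = np.zeros((4, 120))
logits[np.arange(4), coarse_pred] = 1.0
logits[np.arange(4), 20 + np.array(fine_pred)] = 1.0
acc = hierarchical_accuracies(logits, np.eye(20)[coarse_true], np.eye(100)[fine_true])
# coarse 0.75, fine 0.75, true 0.5
```

Because a truly correct sample must be correct at both levels, true accuracy satisfies $\max(0, a_{\text{coarse}} + a_{\text{fine}} - 1) \le a_{\text{true}} \le \min(a_{\text{coarse}}, a_{\text{fine}})$. The reported numbers fit: $0.1632 \le 0.4251 \le 0.4645$.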
In this notebook experiment, we saw how to load ResNet-18 with pre-trained weights using MindCV and use it for transfer learning by fine-tuning it on the CIFAR-100 dataset with MindSpore.
While fine-tuning a modern CNN such as ResNet-18 is interesting in its own right, it’s just the tip of the iceberg in the field of transfer learning, and we have yet to explore the transfer learning techniques most relevant to modern deep learning. As we continue our deep learning journey, we’ll likely cover LLM parameter-efficient fine-tuning (PEFT) techniques such as LoRA in a future notebook experiment. PEFT enables modern ML practitioners to quickly adapt an existing LLM to a specific domain by freezing the original model weights and attaching a small set of new trainable weights, greatly reducing the time, effort and resources required for LLM fine-tuning.
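As a preview of the idea, here is a minimal NumPy sketch of a LoRA-style linear layer: the pre-trained weight matrix $W$ stays frozen and only two small rank-$r$ matrices $A$ and $B$ are trained, with their product scaled by $\alpha / r$ as in the original LoRA formulation. The layer sizes and the helper name `lora_forward` are illustrative assumptions, not taken from any library.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 512, 512, 8, 16

W = rng.standard_normal((d_out, d_in))     # frozen pre-trained weights
A = rng.standard_normal((r, d_in)) * 0.01  # trainable, small random init
B = np.zeros((d_out, r))                   # trainable, zero init so training starts from W

def lora_forward(x, W, A, B, alpha, r):
    """Frozen path W @ x plus the scaled low-rank update (alpha / r) * B @ (A @ x)."""
    return W @ x + (alpha / r) * (B @ (A @ x))

x = rng.standard_normal(d_in)
# With B = 0 the adapted layer reproduces the pre-trained layer exactly.
assert np.allclose(lora_forward(x, W, A, B, alpha, r), W @ x)

full_params = W.size           # parameters updated by full fine-tuning
lora_params = A.size + B.size  # parameters updated by LoRA
print(full_params, lora_params)  # 262144 8192
```

Here LoRA trains about 3% of the parameters that full fine-tuning of this layer would touch, which is where the savings in time and memory come from.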
I hope you enjoyed following along with this notebook experiment as much as I enjoyed authoring it. Stay tuned for updates ;-)