Date: 2026-04-04
The notebook source code for this article is available on GitCode and GitHub.
MindSpore is an open-source deep learning framework led by Huawei and the MindSpore community. It is optimized for model training and inference on Huawei’s Ascend series processors (NPUs), though it also supports NVIDIA GPUs and CPU-only environments. It is a compelling alternative to PyTorch, the leading open-source deep learning framework, hosted under the Linux Foundation. MindSpore also provides utility classes and functions compatible with a subset of the PyTorch API to ease migrating training and inference pipelines from PyTorch; migration itself is out of scope for this article.
In this lab, we will use MindSpore to train a simple linear regression model, consisting of a single fully connected layer, to predict house prices in California. We will use the California housing dataset, which we’ll fetch with scikit-learn’s sklearn.datasets.fetch_california_housing function.
Before starting, we recommend reading the first 3 chapters of Dive into Deep Learning, also known as D2L. They cover the background knowledge on introductory probability, statistics and linear algebra, as well as the data preprocessing, transformation methods and linear regression techniques required to understand the steps performed in this lab.
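As a refresher on the linear regression background, a single fully connected layer with one output computes y_hat = X @ w + b, and training it means minimizing the mean squared error (MSE) between y_hat and the targets. The sketch below does this with full-batch gradient descent in plain NumPy. It is an illustration of the technique, not the lab's MindSpore code: it uses synthetic data (an assumption so it runs offline) with 8 features per sample, matching the number of features returned by fetch_california_housing. The helper name `fit_linear` and the hyperparameters are illustrative choices.

```python
import numpy as np

# Synthetic stand-in for the housing data (assumption: avoids a download).
# 8 features per sample, matching fetch_california_housing.
rng = np.random.default_rng(0)
n_samples, n_features = 1000, 8
X = rng.normal(size=(n_samples, n_features))
true_w = rng.normal(size=n_features)
true_b = 2.0
y = X @ true_w + true_b + 0.01 * rng.normal(size=n_samples)


def fit_linear(X, y, lr=0.1, epochs=200):
    """Fit y_hat = X @ w + b (one fully connected layer, one output)
    by full-batch gradient descent on the mean squared error."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        err = X @ w + b - y               # residuals, shape (n,)
        grad_w = (2.0 / n) * (X.T @ err)  # d(MSE)/dw
        grad_b = (2.0 / n) * err.sum()    # d(MSE)/db
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


w, b = fit_linear(X, y)
mse = np.mean((X @ w + b - y) ** 2)
print(f"learned bias: {b:.3f}, training MSE: {mse:.6f}")
```

Because the features here are standardized, a fixed learning rate of 0.1 converges quickly; on raw California housing features, which have very different scales, standardizing first (as covered in D2L's data preprocessing material) matters much more.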
The instructions in this lab were tested on the OrangePi AIpro (20T) development board. They should also work in a CPU-only environment with minimal configuration: simply set the NOTEBOOK_USE_CPU=1 environment variable. The lab may also work in NVIDIA GPU environments with few modifications by setting the NOTEBOOK_USE_GPU=1 environment variable, but this has not been explicitly tested.
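When more than one of these variables is set, the initialization code further below applies a fixed precedence: NOTEBOOK_USE_ASCEND first, then NOTEBOOK_USE_GPU, then NOTEBOOK_USE_CPU (or NOTEBOOK_CI_MODE, which also forces the CPU), falling back to Ascend when none is set. The sketch below restates that logic as a pure function, so you can check which device target a given environment would select without importing MindSpore. `resolve_platform` is a hypothetical helper, not part of the lab or of MindSpore.

```python
import os


def resolve_platform(env=os.environ):
    """Return the device target the notebook would select for the given environment.

    The first flag set to '1' wins (Ascend, then GPU, then CPU or CI mode);
    with no flag set, the notebook defaults to Ascend.
    """
    if env.get('NOTEBOOK_USE_ASCEND', '0') == '1':
        return 'Ascend'
    if env.get('NOTEBOOK_USE_GPU', '0') == '1':
        return 'GPU'
    if env.get('NOTEBOOK_USE_CPU', '0') == '1' or env.get('NOTEBOOK_CI_MODE', '0') == '1':
        return 'CPU'
    return 'Ascend'


print(resolve_platform({'NOTEBOOK_USE_CPU': '1'}))                          # CPU
print(resolve_platform({'NOTEBOOK_USE_GPU': '1', 'NOTEBOOK_USE_CPU': '1'}))  # GPU
```

Since a Jupyter kernel typically inherits the environment of the process that launched JupyterLab, setting the variable in the shell before starting JupyterLab is one way to make it visible to the notebook.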
The software versions used in this lab are pinned in the provided requirements.txt. Install them with the following command:
%pip install -r requirements.txt
Requirement already satisfied: absl-py==2.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 1)) (2.4.0)
Requirement already satisfied: attrs==25.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 2)) (25.4.0)
Requirement already satisfied: cloudpickle==3.1.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 3)) (3.1.2)
Requirement already satisfied: decorator==5.2.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 4)) (5.2.1)
Requirement already satisfied: jupyterlab==4.5.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 5)) (4.5.6)
Requirement already satisfied: jupyterlab-git==0.52.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 6)) (0.52.0)
Requirement already satisfied: jupyter-resource-usage==1.2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 7)) (1.2.0)
Requirement already satisfied: matplotlib==3.10.8 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 8)) (3.10.8)
Requirement already satisfied: mindspore==2.8.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 9)) (2.8.0)
Requirement already satisfied: ml-dtypes==0.5.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 10)) (0.5.4)
Requirement already satisfied: nbmake==1.5.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 11)) (1.5.5)
Requirement already satisfied: nbstripout==0.9.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 12)) (0.9.1)
Requirement already satisfied: pytest==9.0.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 13)) (9.0.2)
Requirement already satisfied: scikit-learn==1.8.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 14)) (1.8.0)
Requirement already satisfied: sympy==1.14.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 15)) (1.14.0)
Requirement already satisfied: tornado==6.5.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from -r requirements.txt (line 16)) (6.5.5)
Requirement already satisfied: async-lru>=1.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.3.0)
Requirement already satisfied: httpx<1,>=0.25.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.28.1)
Requirement already satisfied: ipykernel!=6.30.0,>=6.5.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (7.2.0)
Requirement already satisfied: jinja2>=3.0.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.1.6)
Requirement already satisfied: jupyter-core in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (5.9.1)
Requirement already satisfied: jupyter-lsp>=2.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.3.1)
Requirement already satisfied: jupyter-server<3,>=2.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.17.0)
Requirement already satisfied: jupyterlab-server<3,>=2.28.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.28.0)
Requirement already satisfied: notebook-shim>=0.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.2.4)
Requirement already satisfied: packaging in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (26.0)
Requirement already satisfied: setuptools>=41.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (82.0.1)
Requirement already satisfied: traitlets in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab==4.5.6->-r requirements.txt (line 5)) (5.14.3)
Requirement already satisfied: nbdime~=4.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (4.0.4)
Requirement already satisfied: nbformat in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (5.10.4)
Requirement already satisfied: pexpect in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (4.9.0)
Requirement already satisfied: prometheus-client in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.0->-r requirements.txt (line 7)) (0.24.1)
Requirement already satisfied: psutil>=5.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.0->-r requirements.txt (line 7)) (7.2.2)
Requirement already satisfied: pyzmq>=19 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-resource-usage==1.2.0->-r requirements.txt (line 7)) (27.1.0)
Requirement already satisfied: contourpy>=1.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (4.62.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (1.5.0)
Requirement already satisfied: numpy>=1.23 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (1.26.4)
Requirement already satisfied: pillow>=8 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (12.2.0)
Requirement already satisfied: pyparsing>=3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (3.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from matplotlib==3.10.8->-r requirements.txt (line 8)) (2.9.0.post0)
Requirement already satisfied: protobuf>=3.13.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (7.34.1)
Requirement already satisfied: asttokens>=2.0.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (3.0.1)
Requirement already satisfied: scipy>=1.5.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (1.17.1)
Requirement already satisfied: astunparse>=1.6.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (1.6.3)
Requirement already satisfied: safetensors>=0.4.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (0.7.0)
Requirement already satisfied: dill>=0.3.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from mindspore==2.8.0->-r requirements.txt (line 9)) (0.4.1)
Requirement already satisfied: nbclient>=0.6.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbmake==1.5.5->-r requirements.txt (line 11)) (0.10.4)
Requirement already satisfied: pygments>=2.7.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbmake==1.5.5->-r requirements.txt (line 11)) (2.20.0)
Requirement already satisfied: iniconfig>=1.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pytest==9.0.2->-r requirements.txt (line 13)) (2.3.0)
Requirement already satisfied: pluggy<2,>=1.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pytest==9.0.2->-r requirements.txt (line 13)) (1.6.0)
Requirement already satisfied: joblib>=1.3.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from scikit-learn==1.8.0->-r requirements.txt (line 14)) (1.5.3)
Requirement already satisfied: threadpoolctl>=3.2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from scikit-learn==1.8.0->-r requirements.txt (line 14)) (3.6.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from sympy==1.14.0->-r requirements.txt (line 15)) (1.3.0)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from astunparse>=1.6.3->mindspore==2.8.0->-r requirements.txt (line 9)) (0.46.3)
Requirement already satisfied: six<2.0,>=1.6.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from astunparse>=1.6.3->mindspore==2.8.0->-r requirements.txt (line 9)) (1.17.0)
Requirement already satisfied: anyio in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.13.0)
Requirement already satisfied: certifi in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2026.2.25)
Requirement already satisfied: httpcore==1.* in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.0.9)
Requirement already satisfied: idna in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.11)
Requirement already satisfied: h11>=0.16 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from httpcore==1.*->httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.16.0)
Requirement already satisfied: comm>=0.1.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.2.3)
Requirement already satisfied: debugpy>=1.6.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.8.20)
Requirement already satisfied: ipython>=7.23.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (9.12.0)
Requirement already satisfied: jupyter-client>=8.8.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (8.8.0)
Requirement already satisfied: matplotlib-inline>=0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.2.1)
Requirement already satisfied: nest-asyncio>=1.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.6.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jinja2>=3.0.3->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.0.3)
Requirement already satisfied: platformdirs>=2.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-core->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.9.4)
Requirement already satisfied: argon2-cffi>=21.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (25.1.0)
Requirement already satisfied: jupyter-events>=0.11.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.12.0)
Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.5.4)
Requirement already satisfied: nbconvert>=6.4.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (7.17.0)
Requirement already satisfied: send2trash>=1.8.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.1.0)
Requirement already satisfied: terminado>=0.8.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.18.1)
Requirement already satisfied: websocket-client>=1.7 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.9.0)
Requirement already satisfied: babel>=2.10 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.18.0)
Requirement already satisfied: json5>=0.9.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.14.0)
Requirement already satisfied: jsonschema>=4.18.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.26.0)
Requirement already satisfied: requests>=2.31 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.33.1)
Requirement already satisfied: colorama in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbdime~=4.0.1->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (0.4.6)
Requirement already satisfied: gitpython!=2.1.4,!=2.1.5,!=2.1.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbdime~=4.0.1->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (3.1.46)
Requirement already satisfied: fastjsonschema>=2.15 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbformat->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (2.21.2)
Requirement already satisfied: ptyprocess>=0.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from pexpect->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (0.7.0)
Requirement already satisfied: typing_extensions>=4.5 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from anyio->httpx<1,>=0.25.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.15.0)
Requirement already satisfied: argon2-cffi-bindings in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (25.1.0)
Requirement already satisfied: gitdb<5,>=4.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from gitpython!=2.1.4,!=2.1.5,!=2.1.6->nbdime~=4.0.1->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (4.0.12)
Requirement already satisfied: ipython-pygments-lexers>=1.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.1.1)
Requirement already satisfied: jedi>=0.18.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.19.2)
Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.0.52)
Requirement already satisfied: stack_data>=0.6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.6.3)
Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2025.9.1)
Requirement already satisfied: referencing>=0.28.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.37.0)
Requirement already satisfied: rpds-py>=0.25.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema>=4.18.0->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.30.0)
Requirement already satisfied: python-json-logger>=2.0.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.1.0)
Requirement already satisfied: pyyaml>=5.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (6.0.3)
Requirement already satisfied: rfc3339-validator in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.1.4)
Requirement already satisfied: rfc3986-validator>=0.1.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.1.1)
Requirement already satisfied: beautifulsoup4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (4.14.3)
Requirement already satisfied: bleach!=5.0.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (6.3.0)
Requirement already satisfied: defusedxml in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.7.1)
Requirement already satisfied: jupyterlab-pygments in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.3.0)
Requirement already satisfied: mistune<4,>=2.0.3 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.2.0)
Requirement already satisfied: pandocfilters>=1.4.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.5.1)
Requirement already satisfied: charset_normalizer<4,>=2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from requests>=2.31->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.4.7)
Requirement already satisfied: urllib3<3,>=1.26 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from requests>=2.31->jupyterlab-server<3,>=2.28.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.6.3)
Requirement already satisfied: webencodings in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.5.1)
Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from bleach[css]!=5.0.0->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.4.0)
Requirement already satisfied: smmap<6,>=3.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from gitdb<5,>=4.0.1->gitpython!=2.1.4,!=2.1.5,!=2.1.6->nbdime~=4.0.1->jupyterlab-git==0.52.0->-r requirements.txt (line 6)) (5.0.3)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jedi>=0.18.2->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.8.6)
Requirement already satisfied: fqdn in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.5.1)
Requirement already satisfied: isoduration in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (20.11.0)
Requirement already satisfied: jsonpointer>1.13 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.1.1)
Requirement already satisfied: rfc3987-syntax>=1.1.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.1.0)
Requirement already satisfied: uri-template in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.3.0)
Requirement already satisfied: webcolors>=24.6.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (25.10.0)
Requirement already satisfied: wcwidth in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.6.0)
Requirement already satisfied: executing>=1.2.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from stack_data>=0.6.0->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.2.1)
Requirement already satisfied: pure-eval in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from stack_data>=0.6.0->ipython>=7.23.1->ipykernel!=6.30.0,>=6.5.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (0.2.3)
Requirement already satisfied: cffi>=1.0.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.0.0)
Requirement already satisfied: soupsieve>=1.6.1 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from beautifulsoup4->nbconvert>=6.4.4->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2.8.3)
Requirement already satisfied: pycparser in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi>=21.1->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (3.0)
Requirement already satisfied: lark>=1.2.2 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.3.1)
Requirement already satisfied: arrow>=0.15.0 in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (1.4.0)
Requirement already satisfied: tzdata in /home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.11.0->jupyter-server<3,>=2.4.0->jupyterlab==4.5.6->-r requirements.txt (line 5)) (2026.1)
[notice] A new release of pip is available: 25.0.1 -> 26.0.1
[notice] To update, run: python -m pip install --upgrade pip
Note: you may need to restart the kernel to use updated packages.
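After restarting the kernel, it can be worth confirming that the installed packages actually match the pins in requirements.txt. The following is a minimal sketch using the standard library's `importlib.metadata.version`; the helper names `parse_pins` and `check_pins` are hypothetical, and only exact `name==version` pins are checked (other specifiers such as `>=` are skipped).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_pins(text):
    """Parse exact 'name==version' lines from requirements text into a dict."""
    pins = {}
    for line in text.splitlines():
        # Drop comments and environment markers, then surrounding whitespace.
        line = line.split('#', 1)[0].split(';', 1)[0].strip()
        if '==' in line:
            name, ver = line.split('==', 1)
            name = name.split('[', 1)[0].strip()  # drop extras like pkg[extra]
            pins[name] = ver.strip()
    return pins


def check_pins(pins, lookup=version):
    """Return {name: (pinned, installed)} for every package not matching its pin.

    installed is None when the package is not installed at all.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


# Usage in the notebook, assuming requirements.txt is in the working directory:
# with open('requirements.txt') as f:
#     print(check_pins(parse_pins(f.read())) or 'All pins satisfied')
```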
Let’s initialize MindSpore based on our hardware platform and verify that it is correctly installed.
import os
import mindspore

# Read the platform-selection flags from the environment ('1' enables a flag)
NOTEBOOK_USE_ASCEND = os.getenv('NOTEBOOK_USE_ASCEND', '0')
NOTEBOOK_USE_GPU = os.getenv('NOTEBOOK_USE_GPU', '0')
NOTEBOOK_USE_CPU = os.getenv('NOTEBOOK_USE_CPU', '0')
NOTEBOOK_CI_MODE = os.getenv('NOTEBOOK_CI_MODE', '0')

# Pick the device target: the first matching flag wins, defaulting to Ascend
if NOTEBOOK_USE_ASCEND == '1':
    platform = 'Ascend'
elif NOTEBOOK_USE_GPU == '1':
    platform = 'GPU'
elif NOTEBOOK_USE_CPU == '1' or NOTEBOOK_CI_MODE == '1':
    platform = 'CPU'  # CI runs are forced onto the CPU backend
else:
    platform = 'Ascend'

mindspore.set_device(device_target=platform)
# Run MindSpore's built-in installation check on the selected device
mindspore.run_check()
/usr/local/Ascend/cann-8.5.0/python/site-packages/tbe/dsl/classifier/transdata/transdata_classifier.py:223: SyntaxWarning: invalid escape sequence '\B'
Return BN\BH SCH Result
/usr/local/Ascend/cann-8.5.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:146: SyntaxWarning: invalid escape sequence '\c'
2. In forward, tiling would not split c1 and c0, find c1\c0 based on t2.
/usr/local/Ascend/cann-8.5.0/python/site-packages/tbe/dsl/unify_schedule/vector/transdata/common/graph/transdata_graph_info.py:172: SyntaxWarning: invalid escape sequence '\c'
1. Forward: tiling would not split c1\c0\h0, find c1\c0\h1\h0 based on t2
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
setattr(self, word, getattr(machar, word).flat[0])
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
return self._float_to_str(self.smallest_subnormal)
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
setattr(self, word, getattr(machar, word).flat[0])
/home/HwHiAiUser/.pyenv/versions/3.12.13/envs/orangepiaipro-20t/lib/python3.12/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
return self._float_to_str(self.smallest_subnormal)
MindSpore version: 2.8.0
The result of multiplication calculation is correct, MindSpore has been installed on platform [Ascend] successfully!
The SyntaxWarning and UserWarning messages above can be safely ignored as long as no [CRITICAL] or [ERROR] messages appear. If such messages do appear, check out the Ascend forum. If you see the output below, MindSpore 2.8.0 is correctly installed.
MindSpore version: 2.8.0
The result of multiplication calculation is correct, MindSpore has been installed on platform [Ascend] successfully!
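Let’s fetch the California housing dataset with scikit-learn’s fetch_california_housing. The target comes as a 1-D array of shape (20640,), so we’ll reshape it into a column of shape (20640, 1) that matches the shape of our model’s predictions. Otherwise, NumPy-style broadcasting could silently turn element-wise differences between predictions and labels into a full (20640, 20640) matrix.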
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
X = housing.data
y = housing.target.reshape(-1, 1)
X.shape, y.shape
((20640, 8), (20640, 1))
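The following small NumPy sketch shows why the reshape matters. It uses made-up numbers rather than the housing data. Subtracting a 1-D target of shape (N,) from a prediction column of shape (N, 1) does not raise an error. Instead, it broadcasts to an (N, N) matrix of all pairwise differences, so any mean squared error computed from it would be quietly wrong.

```python
import numpy as np

# Hypothetical toy values: predictions shaped like a (N, 1) model output
y_hat = np.array([[1.0], [2.0], [3.0], [4.0]])  # shape (4, 1)
y_flat = np.array([1.0, 2.0, 3.0, 4.0])         # shape (4,), like housing.target
y_col = y_flat.reshape(-1, 1)                    # shape (4, 1), like our y

# (4, 1) - (4,) broadcasts to a (4, 4) matrix of pairwise differences
bad_diff = y_hat - y_flat
# (4, 1) - (4, 1) stays element-wise, as intended
good_diff = y_hat - y_col

print(bad_diff.shape)   # (4, 4)
print(good_diff.shape)  # (4, 1)
# The "MSE" from the broadcast version is non-zero even though the predictions are perfect
print(np.mean(bad_diff ** 2), np.mean(good_diff ** 2))
```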
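Split the dataset into training and validation sets using an $80:20$ ratio, which is common in many machine learning applications. Because we pass shuffle=False, the first 80% of the rows form the training set and the last 20% form the validation set, in their original order.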
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((16512, 8), (4128, 8), (16512, 1), (4128, 1))
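With shuffling disabled, the split reduces to slicing. The following sketch uses a hypothetical helper, unshuffled_split. It assumes scikit-learn’s documented rounding for a float test_size, where the validation set gets ceil(test_size * n) rows. The helper reproduces the shapes printed above on toy arrays with the same number of rows. If the rows of a dataset are ordered in some systematic way, a common alternative is shuffle=True with a fixed random_state. That would give a different validation set from the one used in this lab.

```python
import math
import numpy as np

def unshuffled_split(X, y, test_size=0.2):
    """Hypothetical helper mimicking train_test_split(..., shuffle=False).

    Assumes n_test = ceil(test_size * n); the first n - n_test rows go to
    training and the last n_test rows go to validation, order preserved.
    """
    n = len(X)
    n_test = math.ceil(test_size * n)
    n_train = n - n_test
    return X[:n_train], X[n_train:], y[:n_train], y[n_train:]

# Toy arrays with the same number of rows as the housing data
X_toy = np.arange(20640 * 2).reshape(20640, 2)
y_toy = np.arange(20640).reshape(-1, 1)
X_tr, X_te, y_tr, y_te = unshuffled_split(X_toy, y_toy)

print(X_tr.shape, X_te.shape, y_tr.shape, y_te.shape)
print(y_te[0, 0])  # first validation row is row 16512 of the original order
```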
Let’s inspect the first sample in our training set. Notice that all features and labels are numeric floating point values, which makes this dataset a natural fit for linear regression.
X_train[0], y_train[0]
(array([ 8.3252 , 41. , 6.98412698, 1.02380952,
322. , 2.55555556, 37.88 , -122.23 ]),
array([4.526]))
Verify that our data is clean with no NaN values.
import numpy as np
np.any(np.isnan(X).ravel()), np.any(np.isnan(y).ravel())
(False, False)
Great - none of our samples contain NaNs! Let’s compute the mean and standard deviation of the labels in our training set.
y_train.mean(), y_train.std()
(2.02067031310562, 1.1352013072688294)
The scikit-learn documentation on real-world datasets states that house prices are given in multiples of USD\$100,000, so the label values range from $0.15$ to $5.0$. Nevertheless, it’s a good idea to transform the labels through the 2-step process described below. This makes their distribution more Gaussian and centered, which should help our linear regression model produce more accurate results.
1. Apply $\log(1 + y)$ to each label with numpy.log1p to reduce the skew of the price distribution
2. Apply StandardScaler to the results in (1) to normalize the distribution to have zero mean $\mu = 0$ and unit variance $\sigma^2 = 1$
We’ll achieve this through a custom scaler class HousingPriceScaler which inherits the following base classes in scikit-learn.
- sklearn.base.BaseEstimator: base class for all estimators in scikit-learn
- sklearn.base.TransformerMixin: base class for data transformation in scikit-learn; it provides the fit_transform method, which calls fit and then transform
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler
class HousingPriceScaler(BaseEstimator, TransformerMixin):
    def __init__(self):
        self.scaler = StandardScaler()
    def fit(self, y):
        y_log1p = np.log1p(y)
        print(f'After log1p, before scaling: mean = {y_log1p.mean():.4f}, stddev = {y_log1p.std():.4f}')
        self.scaler.fit(y_log1p)
        return self
    def transform(self, y):
        y_log1p = np.log1p(y)
        y_scaled = self.scaler.transform(y_log1p)
        return y_scaled
    def inverse_transform(self, y_scaled):
        y_log1p = self.scaler.inverse_transform(y_scaled)
        y = np.expm1(y_log1p)
        return y
scaler = HousingPriceScaler()
scaler
HousingPriceScaler()
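The forward and inverse transforms of HousingPriceScaler can be written as $z = \frac{\log(1 + y) - \mu}{\sigma}$ and $y = \exp(\sigma z + \mu) - 1$. Here $\mu$ and $\sigma$ are the mean and standard deviation of $\log(1 + y)$ over the training labels. The following plain NumPy sketch checks the round trip on a few made-up label values. It assumes StandardScaler uses the population standard deviation (ddof=0), which is also NumPy’s default.

```python
import numpy as np

# Made-up labels in the documented range (multiples of $100,000)
y = np.array([[0.15], [1.0], [2.5], [5.0]])

# Step 1: log1p reduces the skew of the prices
y_log = np.log1p(y)
# Step 2: standardize with the mean and population std of the log-labels
mu, sigma = y_log.mean(axis=0), y_log.std(axis=0)
z = (y_log - mu) / sigma

# Inverse: undo the standardization, then undo log1p with expm1
y_back = np.expm1(z * sigma + mu)

print(z.mean(), z.std())       # approximately 0 and 1
print(np.allclose(y_back, y))  # True
```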
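Before scaling the labels, let’s scale the features. The labels go through the 2-step process described above. For the features, it suffices to normalize them directly with sklearn.preprocessing.StandardScaler, which standardizes each feature column to mean $\mu = 0$ and standard deviation $\sigma = 1$. We fit the scaler on the training set only and reuse its statistics to transform the validation set. Note that the statistics printed below are pooled over all entries of the feature matrix rather than computed per feature, so they only summarize the overall effect.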
features_scaler = StandardScaler()
X_train_scaled = features_scaler.fit_transform(X_train)
print(f'Before scaling: mean = {X_train.mean():.4f}, stddev = {X_train.std():.4f}')
print(f'After scaling: mean = {X_train_scaled.mean():.4f}, stddev = {X_train_scaled.std():.4f}')
X_test_scaled = features_scaler.transform(X_test)
Before scaling: mean = 174.5399, stddev = 631.1331
After scaling: mean = -0.0000, stddev = 1.0000
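StandardScaler works column by column. The pooled "Before scaling" figures therefore mix all eight features together and mostly reflect the ones with the largest values. The following NumPy sketch uses hypothetical two-column data on very different scales and contrasts pooled statistics with the per-column standardization that StandardScaler performs.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: a small-valued and a large-valued feature
X = np.column_stack([rng.normal(4.0, 2.0, 1000),
                     rng.normal(1400.0, 1100.0, 1000)])

# Per-column standardization, as StandardScaler does
mean = X.mean(axis=0)
std = X.std(axis=0)  # population std (ddof=0)
X_scaled = (X - mean) / std

print(X.mean(), X.std())                            # pooled stats, dominated by column 2
print(X_scaled.mean(axis=0), X_scaled.std(axis=0))  # about [0, 0] and [1, 1]
```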
Now transform the labels in our training set with our custom scaler and observe that the transformed labels are also centered, with mean $\mu = 0$ and standard deviation $\sigma = 1$.
y_train_scaled = scaler.fit_transform(y_train)
print(f'Before scaling: mean = {y_train.mean():.4f}, stddev = {y_train.std():.4f}')
print(f'After scaling: mean = {y_train_scaled.mean():.4f}, stddev = {y_train_scaled.std():.4f}')
y_test_scaled = scaler.transform(y_test)
After log1p, before scaling: mean = 1.0418, stddev = 0.3506
Before scaling: mean = 2.0207, stddev = 1.1352
After scaling: mean = 0.0000, stddev = 1.0000
With our training labels appropriately scaled, it’s time to define and train our neural network!
Let’s define a simple neural network with exactly 1 fully connected layer. The fully connected layer has 8 input channels and 1 output channel, which correspond to the number of features and labels in our dataset, respectively.
With MindSpore, we do this by defining our own class that inherits from mindspore.nn.Cell and implementing the construct method, which defines the forward pass. Our fully connected layer is given by mindspore.nn.Dense.
Let’s also define our class to accept the following optional parameters.
- lr: the learning rate of our model. Defaults to 0.01 if not specified
- wd: the weight decay factor $\lambda$. Defaults to 0.0 if not specified
Note that we must use 16-bit floating point values, defined as float16 in mindspore.dtype. This is because the CANN kernel library for the Ascend 310B1 NPU chip on the OrangePi AIpro (20T) development board supports matrix-vector and matrix-matrix multiplication only for 16-bit floats. Using float16 reduces the numerical precision of our model. This is unavoidable on this hardware but sufficient for our use case.
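To get a feel for what "reduced precision" means, here is a small NumPy sketch, independent of MindSpore and the NPU, that probes the limits of IEEE 754 half precision (float16). The sample value 0.123456 is arbitrary.

```python
import numpy as np

info = np.finfo(np.float16)
print(info.eps)  # machine epsilon, 2**-10: about 3 significant decimal digits
print(info.max)  # largest finite float16 value (65504)

# Consecutive integers are exact only up to 2048; 2049 rounds to 2048
print(np.float16(2049) == np.float16(2048))  # True

# A standardized value near 0 keeps only a few significant digits
x = 0.123456
x16 = float(np.float16(x))
print(x16, abs(x16 - x))
```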
import mindspore.nn as nn
from mindspore import dtype as mstype
class MyModel(nn.Cell):
    def __init__(self, lr=0.01, wd=0.0):
        super().__init__()
        self.lr = lr
        self.wd = wd
        self.dense1 = nn.Dense(8, 1, dtype=mstype.float16)
    def construct(self, X):
        y_hat = self.dense1(X)
        return y_hat
model = MyModel(lr=0.03, wd=0.0)
model
MyModel(
(dense1): Dense(input_channels=8, output_channels=1, has_bias=True)
)
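Conceptually, MyModel computes $\hat{y} = X W^\top + b$ for each batch. The following NumPy sketch shows the shapes involved. It assumes nn.Dense’s documented weight layout of (out_channels, in_channels), which gives (1, 8) here. The random values are placeholders and do not reflect MindSpore’s actual initializer.

```python
import numpy as np

rng = np.random.default_rng(42)
# Placeholder parameters mirroring nn.Dense(8, 1): weight (1, 8), bias (1,)
W = rng.normal(size=(1, 8)).astype(np.float16)
b = np.zeros(1, dtype=np.float16)

# One batch of 512 scaled feature vectors, cast to float16 as in the lab
X_batch = rng.normal(size=(512, 8)).astype(np.float16)

# Linear regression: one prediction per sample
y_hat = X_batch @ W.T + b

print(y_hat.shape, y_hat.dtype)  # (512, 1) float16
```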
Use the mean squared error (MSE) as our loss function, provided by MindSpore as mindspore.nn.MSELoss.
loss_fn = nn.MSELoss()
loss_fn
MSELoss()
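For $n$ predictions, the MSE is $\frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)^2$. The following NumPy sketch computes it directly. The mse function is a hypothetical helper intended to match nn.MSELoss with its default reduction='mean'. The input values are made up.

```python
import numpy as np

def mse(y_hat, y):
    """Mean squared error averaged over all elements (reduction='mean')."""
    return np.mean((y_hat - y) ** 2)

y_hat = np.array([[0.5], [-1.0], [2.0]])
y = np.array([[0.0], [-1.0], [1.0]])
print(mse(y_hat, y))  # (0.25 + 0.0 + 1.0) / 3
```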
Use minibatch stochastic gradient descent (SGD) as our optimizer, provided by MindSpore as mindspore.nn.SGD. Pass in the learning rate and weight decay factor from our model via the learning_rate and weight_decay arguments, respectively.
optimizer = nn.SGD(params=model.trainable_params(),
learning_rate=model.lr,
weight_decay=model.wd)
optimizer
SGD()
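Without momentum, one SGD update with L2 weight decay is commonly written as $w \leftarrow w - \eta\,(\nabla L(w) + \lambda w)$, where $\eta$ is the learning rate and $\lambda$ the weight decay factor. The following conceptual NumPy sketch implements that textbook rule with a hypothetical sgd_step helper. MindSpore’s nn.SGD also supports momentum, dampening and Nesterov updates, and its internals may differ in detail.

```python
import numpy as np

def sgd_step(w, grad, lr=0.03, wd=0.0):
    """Plain SGD update with L2 weight decay: w - lr * (grad + wd * w)."""
    return w - lr * (grad + wd * w)

w = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
w_plain = sgd_step(w, g, lr=0.1)          # gradient step only
w_decay = sgd_step(w, g, lr=0.1, wd=0.1)  # decay also pulls weights toward 0
print(w_plain, w_decay)
```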
Now define our forward and gradient functions. mindspore.value_and_grad accepts the following parameters.
- fn: our forward function. It takes the features X and labels y as arguments and returns a tuple of (loss, prediction)
- grad_position: set to None so that gradients are taken with respect to the model weights rather than the inputs
- weights: the model weights, available through our optimizer as optimizer.parameters
- has_aux: set to True so that only the first output of fn (the loss) is differentiated; the remaining outputs (the predictions) are returned alongside it without gradients
def forward_fn(X, y):
    y_hat = model(X)
    loss = loss_fn(y_hat, y)
    return loss, y_hat
grad_fn = mindspore.value_and_grad(fn=forward_fn,
                                   grad_position=None,
                                   weights=optimizer.parameters,
                                   has_aux=True)
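MindSpore’s value_and_grad computes the gradients automatically. For a single linear layer trained with MSE, the gradients also have a closed form. With residuals $r = Xw + b - y$, the gradients are $\partial L / \partial w = \frac{2}{n} X^\top r$ and $\partial L / \partial b = \frac{2}{n} \sum_i r_i$. The following NumPy sketch evaluates them in float64 and checks one entry against a central finite difference. The weight here is a column of shape (8, 1), which is the transpose of nn.Dense’s (1, 8) layout, and the data is random.

```python
import numpy as np

def mse_and_grads(X, y, w, b):
    """MSE loss and its analytic gradients for y_hat = X @ w + b."""
    n = X.shape[0]
    r = X @ w + b - y             # residuals, shape (n, 1)
    loss = np.mean(r ** 2)
    grad_w = (2.0 / n) * X.T @ r  # shape (d, 1)
    grad_b = (2.0 / n) * r.sum()  # scalar
    return loss, grad_w, grad_b

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 8))
y = rng.normal(size=(64, 1))
w = rng.normal(size=(8, 1))
b = 0.1

loss, grad_w, grad_b = mse_and_grads(X, y, w, b)

# Central finite difference on the first weight
eps = 1e-6
w_plus = w.copy()
w_plus[0, 0] += eps
w_minus = w.copy()
w_minus[0, 0] -= eps
numeric = (mse_and_grads(X, y, w_plus, b)[0] - mse_and_grads(X, y, w_minus, b)[0]) / (2 * eps)
print(grad_w[0, 0], numeric)  # the two values should agree closely
```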
Define our training step and training epoch. In each step, we process one batch of a fixed size, usually a power of 2, e.g. $2^9 = 512$. One full pass over the training data is known as an epoch.
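As a quick check on the arithmetic, the following sketch computes how many batches of 512 the 16512 training samples produce. It assumes the final partial batch is kept rather than dropped, which is what the 33 batches per epoch in the training log suggest.

```python
import math

n_train = 16512   # training rows after the 80:20 split
batch_size = 512  # 2**9

n_batches = math.ceil(n_train / batch_size)          # batches per epoch
last_batch = n_train - (n_batches - 1) * batch_size  # size of the final, partial batch

print(n_batches, last_batch)  # 33 128
```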
def train_step(X, y):
    (loss, _), grads = grad_fn(X, y)
    optimizer(grads)
    return loss
def train_epoch(dataset, epoch=0):
    model.set_train()
    print(f'Epoch {epoch} start')
    batch_total = dataset.get_dataset_size()
    training_losses = []
    for batch_idx, (X_batch, y_batch) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(X_batch, y_batch)
        print(f'Training loss (scaled): {loss.asnumpy():.4f} [{batch_idx}/{batch_total}]')
        training_losses.append(loss)
    print(f'Epoch {epoch} end')
    return training_losses
At the end of each epoch, let’s also validate our model against the validation set. Recall that our model was trained on features normalized by StandardScaler and on labels that went through the 2-step transformation described in an earlier section. Its predictions are therefore in the same transformed space. We can undo these transformations with the inverse_transform method of our scaler class, which gives us the predicted house prices in multiples of USD\$100,000.
def validate_epoch(epoch=0):
    model.set_train(False)
    y_hat_scaled = model(mindspore.Tensor(X_test_scaled.astype(np.float16))).asnumpy()
    y_hat = scaler.inverse_transform(y_hat_scaled)
    validation_loss_scaled = loss_fn(mindspore.Tensor(y_hat_scaled), mindspore.Tensor(y_test_scaled))
    validation_loss = loss_fn(mindspore.Tensor(y_hat), mindspore.Tensor(y_test))
    print(f'Validation loss at epoch {epoch} (scaled): {validation_loss_scaled.asnumpy():.4f}')
    print(f'Validation loss at epoch {epoch}: {validation_loss.asnumpy():.4f}')
    return validation_loss_scaled
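The unscaled validation loss is an MSE in squared units of USD\$100,000. Taking its square root gives a root mean squared error (RMSE) in the same units as the prices, which is easier to interpret. The scaled loss lives in the standardized log space and is not directly comparable. The following sketch uses a hypothetical helper, mse_to_rmse_usd, and an illustrative MSE value of 0.75.

```python
import math

def mse_to_rmse_usd(mse, unit_usd=100_000):
    """Convert an MSE in (multiples of $100,000)^2 into an RMSE in US dollars."""
    return math.sqrt(mse) * unit_usd

# Illustrative input: a validation MSE of 0.75 in squared $100,000 units
print(round(mse_to_rmse_usd(0.75)))  # about 86,603 dollars
```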
Convert our dataset to a mindspore.dataset.NumpySlicesDataset and feed it to our model in batches of $2^9 = 512$ samples. We already used two of its methods in the train_epoch function:
- get_dataset_size: get the total number of batches in the batched dataset
- create_tuple_iterator: iterate over the dataset, yielding one (X_batch, y_batch) pair per batch; wrapping it in Python’s enumerate supplies the batch_idx
import mindspore.dataset as ds
import mindspore.ops as ops
train_ds = ds.NumpySlicesDataset(data=(X_train_scaled.astype(np.float16), y_train_scaled.astype(np.float16)),
                                 column_names=['features', 'labels'])
train_ds = train_ds.batch(batch_size=512)
epochs = 15
training_losses_scaled = []
validation_losses_scaled = []
for epoch_idx in range(epochs):
    training_losses_scaled.extend(train_epoch(train_ds, epoch=epoch_idx))
    validation_losses_scaled.append(validate_epoch(epoch=epoch_idx))
training_losses_scaled = ops.cast(ops.stack(training_losses_scaled, axis=0), dtype=mstype.float64)
validation_losses_scaled = ops.stack(validation_losses_scaled, axis=0)
training_losses_scaled.shape, validation_losses_scaled.shape
Epoch 0 start
Training loss (scaled): 1.3799 [0/33]
Training loss (scaled): 1.2891 [1/33]
Training loss (scaled): 1.1133 [2/33]
Training loss (scaled): 1.0420 [3/33]
Training loss (scaled): 0.9453 [4/33]
Training loss (scaled): 1.2617 [5/33]
Training loss (scaled): 0.8091 [6/33]
Training loss (scaled): 0.7905 [7/33]
Training loss (scaled): 0.7300 [8/33]
Training loss (scaled): 0.6133 [9/33]
Training loss (scaled): 0.6831 [10/33]
Training loss (scaled): 0.6797 [11/33]
Training loss (scaled): 0.7168 [12/33]
Training loss (scaled): 0.6226 [13/33]
Training loss (scaled): 0.7051 [14/33]
Training loss (scaled): 0.5557 [15/33]
Training loss (scaled): 0.5259 [16/33]
Training loss (scaled): 0.5215 [17/33]
Training loss (scaled): 0.4946 [18/33]
Training loss (scaled): 0.4644 [19/33]
Training loss (scaled): 0.4788 [20/33]
Training loss (scaled): 0.4482 [21/33]
Training loss (scaled): 0.4346 [22/33]
Training loss (scaled): 0.4951 [23/33]
Training loss (scaled): 0.5820 [24/33]
Training loss (scaled): 0.4407 [25/33]
Training loss (scaled): 0.6157 [26/33]
Training loss (scaled): 0.4985 [27/33]
Training loss (scaled): 1.1260 [28/33]
Training loss (scaled): 0.5576 [29/33]
Training loss (scaled): 0.5366 [30/33]
Training loss (scaled): 0.5557 [31/33]
Training loss (scaled): 0.4836 [32/33]
Epoch 0 end
Validation loss at epoch 0 (scaled): 0.5919
Validation loss at epoch 0: 284.3527
Epoch 1 start
Training loss (scaled): 0.4980 [0/33]
Training loss (scaled): 0.5049 [1/33]
Training loss (scaled): 0.4485 [2/33]
Training loss (scaled): 0.4717 [3/33]
Training loss (scaled): 0.4900 [4/33]
Training loss (scaled): 0.4839 [5/33]
Training loss (scaled): 0.5483 [6/33]
Training loss (scaled): 0.4541 [7/33]
Training loss (scaled): 0.5259 [8/33]
Training loss (scaled): 0.4878 [9/33]
Training loss (scaled): 0.4846 [10/33]
Training loss (scaled): 0.5684 [11/33]
Training loss (scaled): 0.4893 [12/33]
Training loss (scaled): 0.4773 [13/33]
Training loss (scaled): 0.4653 [14/33]
Training loss (scaled): 0.4905 [15/33]
Training loss (scaled): 0.4165 [16/33]
Training loss (scaled): 0.4250 [17/33]
Training loss (scaled): 0.4495 [18/33]
Training loss (scaled): 0.4648 [19/33]
Training loss (scaled): 0.5757 [20/33]
Training loss (scaled): 0.4775 [21/33]
Training loss (scaled): 0.4695 [22/33]
Training loss (scaled): 0.5386 [23/33]
Training loss (scaled): 0.4153 [24/33]
Training loss (scaled): 0.5107 [25/33]
Training loss (scaled): 0.5142 [26/33]
Training loss (scaled): 0.4727 [27/33]
Training loss (scaled): 0.4907 [28/33]
Training loss (scaled): 0.4236 [29/33]
Training loss (scaled): 0.4270 [30/33]
Training loss (scaled): 0.4553 [31/33]
Training loss (scaled): 0.3872 [32/33]
Epoch 1 end
Validation loss at epoch 1 (scaled): 0.4860
Validation loss at epoch 1: 0.7928
Epoch 2 start
Training loss (scaled): 0.4849 [0/33]
Training loss (scaled): 0.4739 [1/33]
Training loss (scaled): 0.4854 [2/33]
Training loss (scaled): 0.4985 [3/33]
Training loss (scaled): 0.4985 [4/33]
Training loss (scaled): 0.4639 [5/33]
Training loss (scaled): 0.4258 [6/33]
Training loss (scaled): 0.5562 [7/33]
Training loss (scaled): 0.4263 [8/33]
Training loss (scaled): 0.3901 [9/33]
Training loss (scaled): 0.4709 [10/33]
Training loss (scaled): 0.4087 [11/33]
Training loss (scaled): 0.5078 [12/33]
Training loss (scaled): 0.4644 [13/33]
Training loss (scaled): 0.4473 [14/33]
Training loss (scaled): 0.4390 [15/33]
Training loss (scaled): 0.4197 [16/33]
Training loss (scaled): 0.4197 [17/33]
Training loss (scaled): 0.5298 [18/33]
Training loss (scaled): 0.4229 [19/33]
Training loss (scaled): 0.4641 [20/33]
Training loss (scaled): 0.4233 [21/33]
Training loss (scaled): 0.4404 [22/33]
Training loss (scaled): 0.5151 [23/33]
Training loss (scaled): 0.5049 [24/33]
Training loss (scaled): 0.4377 [25/33]
Training loss (scaled): 0.5542 [26/33]
Training loss (scaled): 0.4241 [27/33]
Training loss (scaled): 0.4631 [28/33]
Training loss (scaled): 0.4189 [29/33]
Training loss (scaled): 0.3826 [30/33]
Training loss (scaled): 0.4807 [31/33]
Training loss (scaled): 0.4712 [32/33]
Epoch 2 end
Validation loss at epoch 2 (scaled): 0.4533
Validation loss at epoch 2: 0.8286
Epoch 3 start
Training loss (scaled): 0.5015 [0/33]
Training loss (scaled): 0.5015 [1/33]
Training loss (scaled): 0.4878 [2/33]
Training loss (scaled): 0.4609 [3/33]
Training loss (scaled): 0.4365 [4/33]
Training loss (scaled): 0.4392 [5/33]
Training loss (scaled): 0.4238 [6/33]
Training loss (scaled): 0.4951 [7/33]
Training loss (scaled): 0.4319 [8/33]
Training loss (scaled): 0.4109 [9/33]
Training loss (scaled): 0.5068 [10/33]
Training loss (scaled): 0.4204 [11/33]
Training loss (scaled): 0.4448 [12/33]
Training loss (scaled): 0.4456 [13/33]
Training loss (scaled): 0.4688 [14/33]
Training loss (scaled): 0.4763 [15/33]
Training loss (scaled): 0.4114 [16/33]
Training loss (scaled): 0.4299 [17/33]
Training loss (scaled): 0.4363 [18/33]
Training loss (scaled): 0.3933 [19/33]
Training loss (scaled): 0.4175 [20/33]
Training loss (scaled): 0.4856 [21/33]
Training loss (scaled): 0.5078 [22/33]
Training loss (scaled): 0.4187 [23/33]
Training loss (scaled): 0.4204 [24/33]
Training loss (scaled): 0.5415 [25/33]
Training loss (scaled): 0.3647 [26/33]
Training loss (scaled): 0.4138 [27/33]
Training loss (scaled): 0.4194 [28/33]
Training loss (scaled): 0.4302 [29/33]
Training loss (scaled): 0.4780 [30/33]
Training loss (scaled): 0.3928 [31/33]
Training loss (scaled): 0.4355 [32/33]
Epoch 3 end
Validation loss at epoch 3 (scaled): 0.4293
Validation loss at epoch 3: 0.7472
Epoch 4 start
Training loss (scaled): 0.4875 [0/33]
Training loss (scaled): 0.4434 [1/33]
Training loss (scaled): 0.3662 [2/33]
Training loss (scaled): 0.4177 [3/33]
Training loss (scaled): 0.4675 [4/33]
Training loss (scaled): 0.4407 [5/33]
Training loss (scaled): 0.4023 [6/33]
Training loss (scaled): 0.4033 [7/33]
Training loss (scaled): 0.5107 [8/33]
Training loss (scaled): 0.4307 [9/33]
Training loss (scaled): 0.4255 [10/33]
Training loss (scaled): 0.4106 [11/33]
Training loss (scaled): 0.4382 [12/33]
Training loss (scaled): 0.4121 [13/33]
Training loss (scaled): 0.4995 [14/33]
Training loss (scaled): 0.4846 [15/33]
Training loss (scaled): 0.4290 [16/33]
Training loss (scaled): 0.4783 [17/33]
Training loss (scaled): 0.4761 [18/33]
Training loss (scaled): 0.4272 [19/33]
Training loss (scaled): 0.4092 [20/33]
Training loss (scaled): 0.4395 [21/33]
Training loss (scaled): 0.3760 [22/33]
Training loss (scaled): 0.4736 [23/33]
Training loss (scaled): 0.4136 [24/33]
Training loss (scaled): 0.4019 [25/33]
Training loss (scaled): 0.5293 [26/33]
Training loss (scaled): 0.4460 [27/33]
Training loss (scaled): 0.4272 [28/33]
Training loss (scaled): 0.4282 [29/33]
Training loss (scaled): 0.4248 [30/33]
Training loss (scaled): 0.4141 [31/33]
Training loss (scaled): 0.4194 [32/33]
Epoch 4 end
Validation loss at epoch 4 (scaled): 0.4155
Validation loss at epoch 4: 0.7660
Epoch 5 start
Training loss (scaled): 0.4307 [0/33]
Training loss (scaled): 0.4094 [1/33]
Training loss (scaled): 0.4468 [2/33]
Training loss (scaled): 0.4265 [3/33]
Training loss (scaled): 0.4500 [4/33]
Training loss (scaled): 0.4451 [5/33]
Training loss (scaled): 0.4629 [6/33]
Training loss (scaled): 0.4011 [7/33]
Training loss (scaled): 0.3989 [8/33]
Training loss (scaled): 0.4148 [9/33]
Training loss (scaled): 0.4294 [10/33]
Training loss (scaled): 0.3813 [11/33]
Training loss (scaled): 0.4500 [12/33]
Training loss (scaled): 0.4382 [13/33]
Training loss (scaled): 0.4919 [14/33]
Training loss (scaled): 0.4407 [15/33]
Training loss (scaled): 0.4353 [16/33]
Training loss (scaled): 0.3940 [17/33]
Training loss (scaled): 0.4619 [18/33]
Training loss (scaled): 0.3779 [19/33]
Training loss (scaled): 0.4795 [20/33]
Training loss (scaled): 0.4648 [21/33]
Training loss (scaled): 0.3726 [22/33]
Training loss (scaled): 0.4307 [23/33]
Training loss (scaled): 0.4753 [24/33]
Training loss (scaled): 0.4597 [25/33]
Training loss (scaled): 0.3872 [26/33]
Training loss (scaled): 0.4868 [27/33]
Training loss (scaled): 0.3879 [28/33]
Training loss (scaled): 0.3955 [29/33]
Training loss (scaled): 0.4275 [30/33]
Training loss (scaled): 0.4331 [31/33]
Training loss (scaled): 0.3843 [32/33]
Epoch 5 end
Validation loss at epoch 5 (scaled): 0.3980
Validation loss at epoch 5: 0.7297
Epoch 6 start
Training loss (scaled): 0.3804 [0/33]
Training loss (scaled): 0.4395 [1/33]
Training loss (scaled): 0.3594 [2/33]
Training loss (scaled): 0.4094 [3/33]
Training loss (scaled): 0.4695 [4/33]
Training loss (scaled): 0.4031 [5/33]
Training loss (scaled): 0.4036 [6/33]
Training loss (scaled): 0.4141 [7/33]
Training loss (scaled): 0.4409 [8/33]
Training loss (scaled): 0.4438 [9/33]
Training loss (scaled): 0.4038 [10/33]
Training loss (scaled): 0.4348 [11/33]
Training loss (scaled): 0.4473 [12/33]
Training loss (scaled): 0.4678 [13/33]
Training loss (scaled): 0.4485 [14/33]
Training loss (scaled): 0.4634 [15/33]
Training loss (scaled): 0.4153 [16/33]
Training loss (scaled): 0.3677 [17/33]
Training loss (scaled): 0.4082 [18/33]
Training loss (scaled): 0.4609 [19/33]
Training loss (scaled): 0.4353 [20/33]
Training loss (scaled): 0.4473 [21/33]
Training loss (scaled): 0.3916 [22/33]
Training loss (scaled): 0.3882 [23/33]
Training loss (scaled): 0.4001 [24/33]
Training loss (scaled): 0.4312 [25/33]
Training loss (scaled): 0.4265 [26/33]
Training loss (scaled): 0.4602 [27/33]
Training loss (scaled): 0.4067 [28/33]
Training loss (scaled): 0.5522 [29/33]
Training loss (scaled): 0.3811 [30/33]
Training loss (scaled): 0.4260 [31/33]
Training loss (scaled): 0.3186 [32/33]
Epoch 6 end
Validation loss at epoch 6 (scaled): 0.3879
Validation loss at epoch 6: 0.7155
Epoch 7 start
Training loss (scaled): 0.3860 [0/33]
Training loss (scaled): 0.3933 [1/33]
Training loss (scaled): 0.3569 [2/33]
Training loss (scaled): 0.4600 [3/33]
Training loss (scaled): 0.4119 [4/33]
Training loss (scaled): 0.5010 [5/33]
Training loss (scaled): 0.3506 [6/33]
Training loss (scaled): 0.3987 [7/33]
Training loss (scaled): 0.4194 [8/33]
Training loss (scaled): 0.4277 [9/33]
Training loss (scaled): 0.4182 [10/33]
Training loss (scaled): 0.3855 [11/33]
Training loss (scaled): 0.4700 [12/33]
Training loss (scaled): 0.3882 [13/33]
Training loss (scaled): 0.4424 [14/33]
Training loss (scaled): 0.4597 [15/33]
Training loss (scaled): 0.3625 [16/33]
Training loss (scaled): 0.4492 [17/33]
Training loss (scaled): 0.3894 [18/33]
Training loss (scaled): 0.4302 [19/33]
Training loss (scaled): 0.6040 [20/33]
Training loss (scaled): 0.4192 [21/33]
Training loss (scaled): 0.4119 [22/33]
Training loss (scaled): 0.3787 [23/33]
Training loss (scaled): 0.4592 [24/33]
Training loss (scaled): 0.3889 [25/33]
Training loss (scaled): 0.4431 [26/33]
Training loss (scaled): 0.4304 [27/33]
Training loss (scaled): 0.4224 [28/33]
Training loss (scaled): 0.4387 [29/33]
Training loss (scaled): 0.3936 [30/33]
Training loss (scaled): 0.4045 [31/33]
Training loss (scaled): 0.4368 [32/33]
Epoch 7 end
Validation loss at epoch 7 (scaled): 0.3768
Validation loss at epoch 7: 0.7110
Epoch 8 start
Training loss (scaled): 0.4438 [0/33]
Training loss (scaled): 0.3782 [1/33]
Training loss (scaled): 0.4084 [2/33]
Training loss (scaled): 0.3774 [3/33]
Training loss (scaled): 0.4077 [4/33]
Training loss (scaled): 0.4534 [5/33]
Training loss (scaled): 0.3889 [6/33]
Training loss (scaled): 0.4089 [7/33]
Training loss (scaled): 0.3799 [8/33]
Training loss (scaled): 0.4436 [9/33]
Training loss (scaled): 0.4214 [10/33]
Training loss (scaled): 0.4421 [11/33]
Training loss (scaled): 0.4214 [12/33]
Training loss (scaled): 0.3994 [13/33]
Training loss (scaled): 0.4077 [14/33]
Training loss (scaled): 0.4004 [15/33]
Training loss (scaled): 0.4548 [16/33]
Training loss (scaled): 0.4453 [17/33]
Training loss (scaled): 0.4199 [18/33]
Training loss (scaled): 0.4426 [19/33]
Training loss (scaled): 0.3887 [20/33]
Training loss (scaled): 0.4299 [21/33]
Training loss (scaled): 0.4333 [22/33]
Training loss (scaled): 0.4502 [23/33]
Training loss (scaled): 0.4641 [24/33]
Training loss (scaled): 0.3931 [25/33]
Training loss (scaled): 0.4053 [26/33]
Training loss (scaled): 0.4231 [27/33]
Training loss (scaled): 0.4282 [28/33]
Training loss (scaled): 0.4646 [29/33]
Training loss (scaled): 0.3545 [30/33]
Training loss (scaled): 0.3992 [31/33]
Training loss (scaled): 0.4082 [32/33]
Epoch 8 end
Validation loss at epoch 8 (scaled): 0.3708
Validation loss at epoch 8: 0.6989
Epoch 9 start
Training loss (scaled): 0.3257 [0/33]
Training loss (scaled): 0.3926 [1/33]
Training loss (scaled): 0.4119 [2/33]
Training loss (scaled): 0.5244 [3/33]
Training loss (scaled): 0.4785 [4/33]
Training loss (scaled): 0.4290 [5/33]
Training loss (scaled): 0.4595 [6/33]
Training loss (scaled): 0.4036 [7/33]
Training loss (scaled): 0.4187 [8/33]
Training loss (scaled): 0.4075 [9/33]
Training loss (scaled): 0.3384 [10/33]
Training loss (scaled): 0.4338 [11/33]
Training loss (scaled): 0.4077 [12/33]
Training loss (scaled): 0.4485 [13/33]
Training loss (scaled): 0.4392 [14/33]
Training loss (scaled): 0.3640 [15/33]
Training loss (scaled): 0.3997 [16/33]
Training loss (scaled): 0.4136 [17/33]
Training loss (scaled): 0.3755 [18/33]
Training loss (scaled): 0.4504 [19/33]
Training loss (scaled): 0.4307 [20/33]
Training loss (scaled): 0.4619 [21/33]
Training loss (scaled): 0.3486 [22/33]
Training loss (scaled): 0.4304 [23/33]
Training loss (scaled): 0.3750 [24/33]
Training loss (scaled): 0.4080 [25/33]
Training loss (scaled): 0.4058 [26/33]
Training loss (scaled): 0.4558 [27/33]
Training loss (scaled): 0.4185 [28/33]
Training loss (scaled): 0.4119 [29/33]
Training loss (scaled): 0.4121 [30/33]
Training loss (scaled): 0.3945 [31/33]
Training loss (scaled): 0.4443 [32/33]
Epoch 9 end
Validation loss at epoch 9 (scaled): 0.3671
Validation loss at epoch 9: 0.7032
Epoch 10 start
Training loss (scaled): 0.4160 [0/33]
Training loss (scaled): 0.3657 [1/33]
Training loss (scaled): 0.4250 [2/33]
Training loss (scaled): 0.3606 [3/33]
Training loss (scaled): 0.4211 [4/33]
Training loss (scaled): 0.4133 [5/33]
Training loss (scaled): 0.4048 [6/33]
Training loss (scaled): 0.5000 [7/33]
Training loss (scaled): 0.3999 [8/33]
Training loss (scaled): 0.4199 [9/33]
Training loss (scaled): 0.4463 [10/33]
Training loss (scaled): 0.3999 [11/33]
Training loss (scaled): 0.3787 [12/33]
Training loss (scaled): 0.3945 [13/33]
Training loss (scaled): 0.4580 [14/33]
Training loss (scaled): 0.4402 [15/33]
Training loss (scaled): 0.4546 [16/33]
Training loss (scaled): 0.3735 [17/33]
Training loss (scaled): 0.4141 [18/33]
Training loss (scaled): 0.4097 [19/33]
Training loss (scaled): 0.3789 [20/33]
Training loss (scaled): 0.3848 [21/33]
Training loss (scaled): 0.3628 [22/33]
Training loss (scaled): 0.3931 [23/33]
Training loss (scaled): 0.4067 [24/33]
Training loss (scaled): 0.4187 [25/33]
Training loss (scaled): 0.4050 [26/33]
Training loss (scaled): 0.3977 [27/33]
Training loss (scaled): 0.3567 [28/33]
Training loss (scaled): 0.4460 [29/33]
Training loss (scaled): 0.4719 [30/33]
Training loss (scaled): 0.4302 [31/33]
Training loss (scaled): 0.4878 [32/33]
Epoch 10 end
Validation loss at epoch 10 (scaled): 0.3616
Validation loss at epoch 10: 0.7103
Epoch 11 start
Training loss (scaled): 0.4387 [0/33]
Training loss (scaled): 0.4246 [1/33]
Training loss (scaled): 0.4741 [2/33]
Training loss (scaled): 0.4214 [3/33]
Training loss (scaled): 0.3855 [4/33]
Training loss (scaled): 0.4331 [5/33]
Training loss (scaled): 0.3752 [6/33]
Training loss (scaled): 0.3652 [7/33]
Training loss (scaled): 0.4309 [8/33]
Training loss (scaled): 0.3799 [9/33]
Training loss (scaled): 0.4414 [10/33]
Training loss (scaled): 0.4143 [11/33]
Training loss (scaled): 0.3821 [12/33]
Training loss (scaled): 0.3503 [13/33]
Training loss (scaled): 0.4426 [14/33]
Training loss (scaled): 0.3931 [15/33]
Training loss (scaled): 0.4011 [16/33]
Training loss (scaled): 0.4319 [17/33]
Training loss (scaled): 0.4209 [18/33]
Training loss (scaled): 0.4080 [19/33]
Training loss (scaled): 0.5190 [20/33]
Training loss (scaled): 0.4082 [21/33]
Training loss (scaled): 0.4209 [22/33]
Training loss (scaled): 0.3794 [23/33]
Training loss (scaled): 0.3772 [24/33]
Training loss (scaled): 0.3721 [25/33]
Training loss (scaled): 0.4241 [26/33]
Training loss (scaled): 0.4021 [27/33]
Training loss (scaled): 0.3865 [28/33]
Training loss (scaled): 0.4167 [29/33]
Training loss (scaled): 0.4834 [30/33]
Training loss (scaled): 0.3765 [31/33]
Training loss (scaled): 0.3296 [32/33]
Epoch 11 end
Validation loss at epoch 11 (scaled): 0.3568
Validation loss at epoch 11: 0.7126
Epoch 12 start
Training loss (scaled): 0.3965 [0/33]
Training loss (scaled): 0.3430 [1/33]
Training loss (scaled): 0.4348 [2/33]
Training loss (scaled): 0.3718 [3/33]
Training loss (scaled): 0.4041 [4/33]
Training loss (scaled): 0.4045 [5/33]
Training loss (scaled): 0.4658 [6/33]
Training loss (scaled): 0.4331 [7/33]
Training loss (scaled): 0.3931 [8/33]
Training loss (scaled): 0.4214 [9/33]
Training loss (scaled): 0.4358 [10/33]
Training loss (scaled): 0.4421 [11/33]
Training loss (scaled): 0.4126 [12/33]
Training loss (scaled): 0.4233 [13/33]
Training loss (scaled): 0.3572 [14/33]
Training loss (scaled): 0.3760 [15/33]
Training loss (scaled): 0.3989 [16/33]
Training loss (scaled): 0.3936 [17/33]
Training loss (scaled): 0.4368 [18/33]
Training loss (scaled): 0.4128 [19/33]
Training loss (scaled): 0.4697 [20/33]
Training loss (scaled): 0.4104 [21/33]
Training loss (scaled): 0.4207 [22/33]
Training loss (scaled): 0.3813 [23/33]
Training loss (scaled): 0.3743 [24/33]
Training loss (scaled): 0.4106 [25/33]
Training loss (scaled): 0.4231 [26/33]
Training loss (scaled): 0.4358 [27/33]
Training loss (scaled): 0.4773 [28/33]
Training loss (scaled): 0.4075 [29/33]
Training loss (scaled): 0.3992 [30/33]
Training loss (scaled): 0.3733 [31/33]
Training loss (scaled): 0.3323 [32/33]
Epoch 12 end
Validation loss at epoch 12 (scaled): 0.3552
Validation loss at epoch 12: 0.7013
Epoch 13 start
Training loss (scaled): 0.4077 [0/33]
Training loss (scaled): 0.3345 [1/33]
Training loss (scaled): 0.4136 [2/33]
Training loss (scaled): 0.3936 [3/33]
Training loss (scaled): 0.4302 [4/33]
Training loss (scaled): 0.3457 [5/33]
Training loss (scaled): 0.4199 [6/33]
Training loss (scaled): 0.4307 [7/33]
Training loss (scaled): 0.4023 [8/33]
Training loss (scaled): 0.3308 [9/33]
Training loss (scaled): 0.3813 [10/33]
Training loss (scaled): 0.3892 [11/33]
Training loss (scaled): 0.3499 [12/33]
Training loss (scaled): 0.4666 [13/33]
Training loss (scaled): 0.4937 [14/33]
Training loss (scaled): 0.4275 [15/33]
Training loss (scaled): 0.4087 [16/33]
Training loss (scaled): 0.4822 [17/33]
Training loss (scaled): 0.4170 [18/33]
Training loss (scaled): 0.3948 [19/33]
Training loss (scaled): 0.4583 [20/33]
Training loss (scaled): 0.4229 [21/33]
Training loss (scaled): 0.3940 [22/33]
Training loss (scaled): 0.3770 [23/33]
Training loss (scaled): 0.4866 [24/33]
Training loss (scaled): 0.3835 [25/33]
Training loss (scaled): 0.4238 [26/33]
Training loss (scaled): 0.3699 [27/33]
Training loss (scaled): 0.3833 [28/33]
Training loss (scaled): 0.4583 [29/33]
Training loss (scaled): 0.4092 [30/33]
Training loss (scaled): 0.3901 [31/33]
Training loss (scaled): 0.3960 [32/33]
Epoch 13 end
Validation loss at epoch 13 (scaled): 0.3552
Validation loss at epoch 13: 0.7009
Epoch 14 start
Training loss (scaled): 0.3989 [0/33]
Training loss (scaled): 0.4541 [1/33]
Training loss (scaled): 0.4031 [2/33]
Training loss (scaled): 0.3794 [3/33]
Training loss (scaled): 0.4490 [4/33]
Training loss (scaled): 0.4451 [5/33]
Training loss (scaled): 0.4978 [6/33]
Training loss (scaled): 0.3982 [7/33]
Training loss (scaled): 0.4253 [8/33]
Training loss (scaled): 0.3896 [9/33]
Training loss (scaled): 0.4092 [10/33]
Training loss (scaled): 0.3547 [11/33]
Training loss (scaled): 0.4739 [12/33]
Training loss (scaled): 0.4065 [13/33]
Training loss (scaled): 0.4004 [14/33]
Training loss (scaled): 0.3972 [15/33]
Training loss (scaled): 0.3875 [16/33]
Training loss (scaled): 0.3767 [17/33]
Training loss (scaled): 0.4216 [18/33]
Training loss (scaled): 0.3386 [19/33]
Training loss (scaled): 0.3823 [20/33]
Training loss (scaled): 0.4116 [21/33]
Training loss (scaled): 0.3882 [22/33]
Training loss (scaled): 0.3708 [23/33]
Training loss (scaled): 0.4268 [24/33]
Training loss (scaled): 0.3677 [25/33]
Training loss (scaled): 0.4009 [26/33]
Training loss (scaled): 0.4165 [27/33]
Training loss (scaled): 0.4329 [28/33]
Training loss (scaled): 0.3398 [29/33]
Training loss (scaled): 0.4089 [30/33]
Training loss (scaled): 0.4509 [31/33]
Training loss (scaled): 0.4446 [32/33]
Epoch 14 end
Validation loss at epoch 14 (scaled): 0.3620
Validation loss at epoch 14: 0.6809
((495,), (15,))
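Let's plot the scaled training loss (one point per batch) against the scaled validation loss (one point per epoch). The model behaves consistently on seen and unseen inputs. The validation loss decreases steadily before leveling off, and it stays close to, even slightly below, the training loss, so there is no sign of overfitting.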
import matplotlib.pyplot as plt
import numpy as np
# One training loss per batch, one validation loss per epoch
batch_total = training_losses_scaled.size
epoch_total = validation_losses_scaled.size
batches_per_epoch = batch_total // epoch_total
# x-axis positions: every batch for training, the end of each epoch for validation
batch_x = np.arange(batch_total)
epoch_x = np.arange(epoch_total) * batches_per_epoch + batches_per_epoch
# Convert the MindSpore tensors to NumPy arrays for Matplotlib
training_losses_y = training_losses_scaled.asnumpy()
validation_losses_y = validation_losses_scaled.asnumpy()
plt.figure(figsize=(8, 5))
plt.plot(batch_x, training_losses_y, label='Training loss (scaled)', linestyle='-')
plt.plot(epoch_x, validation_losses_y, label='Validation loss (scaled)', linestyle='--')
plt.title('Training vs. validation loss (scaled)')
plt.xlabel('Batch number')
plt.ylabel('Loss (scaled)')
# A log scale makes small changes late in training easier to see
plt.yscale('log')
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend()
plt.show()
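As a quick sanity check of the x-axis arithmetic in the plotting code, here is a NumPy-only sketch. It plugs in the array sizes printed earlier as ((495,), (15,)), meaning 495 batch losses and 15 epoch losses. It assumes that exactly one validation loss is recorded at the end of every epoch.

```python
import numpy as np

# Sizes taken from the printed shapes ((495,), (15,))
batch_total = 495   # one training loss per batch
epoch_total = 15    # one validation loss per epoch

batches_per_epoch = batch_total // epoch_total  # 33 batches per epoch

# Training losses are plotted at every batch index: 0, 1, ..., 494
batch_x = np.arange(batch_total)

# Validation losses are plotted at the end of each epoch: 33, 66, ..., 495
epoch_x = np.arange(epoch_total) * batches_per_epoch + batches_per_epoch

# Alternative: align each validation point with the last batch of its epoch,
# giving 32, 65, ..., 494
epoch_x_last_batch = epoch_x - 1

print(batches_per_epoch)          # 33
print(epoch_x[:3], epoch_x[-1])   # [33 66 99] 495
print(epoch_x_last_batch[-1])     # 494
```

With the original placement, the last validation point lands one step past the final batch index. At this scale the one-batch offset is barely visible, so either choice works.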
Let’s take a sample from the validation set and use our model to predict the house price. We’ll see how it compares to the actual price and how much it differs.
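The model is a single fully connected layer, so a prediction is just a weighted sum of the 8 standardized features plus a bias, y = x·Wᵀ + b. The NumPy sketch below uses made-up weights and features (hypothetical, not the trained values). It only shows the shapes involved when one sample is fed through such a layer. The result is still in standardized units, which is why the notebook inverse-transforms the model's output before converting it to dollars.

```python
import numpy as np

# Hypothetical standardized feature vector for one district (8 features)
x = np.arange(8, dtype=np.float32) / 10            # shape (8,)

# Hypothetical parameters of an 8-input, 1-output fully connected layer
W = np.full((1, 8), 0.5, dtype=np.float32)         # shape (out, in) = (1, 8)
b = np.array([0.1], dtype=np.float32)              # shape (1,)

# Treat the sample as a batch of one: (1, 8) @ (8, 1) + (1,) -> (1, 1)
y_hat_scaled = x.reshape(1, 8) @ W.T + b

print(y_hat_scaled.shape)   # (1, 1)
print(y_hat_scaled.item())  # approximately 1.5
```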
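Recall that the features and the labels were each standardized with a StandardScaler. Here, scaler is the one fitted on the labels, so the model's prediction must be converted back to the original scale with scaler.inverse_transform. The labels are median house values in units of USD$100,000, hence the multiplication by 100,000 to obtain a price in dollars.
X_test_scaled_sample = X_test_scaled[0]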
y_test_scaled_sample = y_test_scaled[0].reshape(-1, 1)
model.set_train(False)
y_hat_scaled_sample = model(mindspore.Tensor(X_test_scaled_sample.astype(np.float16))).asnumpy().reshape(-1, 1)
y_hat_sample = scaler.inverse_transform(y_hat_scaled_sample).item()
y_hat_sample_usd = y_hat_sample * 100_000
y_test_sample = scaler.inverse_transform(y_test_scaled_sample).item()
y_test_sample_usd = y_test_sample * 100_000
print(f'Predicted house price (USD$): {y_hat_sample_usd:>10.2f}')
print(f'Actual house price (USD$): {y_test_sample_usd:>13.2f}')
print(f'Percentage error: {(y_hat_sample_usd - y_test_sample_usd) / y_test_sample_usd * 100:>22.2f}%')
Predicted house price (USD$): 129003.91
Actual house price (USD$): 165600.00
Percentage error: -22.10%
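Here is a concrete view of the unscaling step. StandardScaler.inverse_transform undoes standardization by computing y_scaled × scale_ + mean_, where scale_ is the standard deviation and mean_ is the mean learned during fit. The NumPy sketch below uses a hypothetical mean and standard deviation (not the values fitted in this lab) to show the round trip. It then recomputes the percentage error from the two prices printed above.

```python
import numpy as np

# Hypothetical statistics of the training labels (units of USD$100,000)
label_mean = 2.07
label_std = 1.15

# Actual label for the sample, in units of USD$100,000 (165,600 USD)
y = np.array([[1.656]])

# Standardize, then invert: y_scaled * std + mean recovers the original value
y_scaled = (y - label_mean) / label_std
y_restored = y_scaled * label_std + label_mean

# Convert to dollars, as the notebook does with * 100_000
y_usd = y_restored.item() * 100_000
print(f'{y_usd:.2f}')       # 165600.00

# Percentage error from the prices printed by the notebook
predicted_usd = 129003.91
actual_usd = 165600.00
pct_error = (predicted_usd - actual_usd) / actual_usd * 100
print(f'{pct_error:.2f}%')  # -22.10%
```

This error comes from a single sample, so on its own it says little about average accuracy. The validation loss is the better summary of that.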
Now that our linear regression model is trained to predict house prices in California, let's save and export it so others can download and import our model for inference. The most common export formats supported by MindSpore include, but are not limited to, the following:
- MindIR: use the mindspore.export function with file_format='MINDIR' to export models in this format
- ONNX: use the mindspore.export function with file_format='ONNX' to export models in this format
Let's export our model in MindIR format and specify the following arguments:
- net: our trained model
- *inputs: sample inputs that our model takes. Since each sample in our dataset has 8 features, our sample tensor has shape (1, 8)
- file_name: the name of our exported model, minus the file extension
- file_format: the format of our exported model. The appropriate file extension, e.g. .mindir, is appended automatically
# Export our model in MindIR format to `california-housing-linear-simple.mindir`
mindspore.export(model,
mindspore.Tensor(X_test_scaled_sample.astype(np.float16).reshape(-1, 8)),
file_name='california-housing-linear-simple',
file_format='MINDIR')
# Verify the file exists
os.path.isfile('california-housing-linear-simple.mindir')
True
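In the export step, the reshape(-1, 8) call turns a single sample of shape (8,) into a batch of one with shape (1, 8), matching the shape described for *inputs. The NumPy sketch below uses a zero-filled stand-in rather than a real sample. It shows how -1 lets NumPy infer the leading batch dimension.

```python
import numpy as np

# Stand-in for one standardized sample with 8 features (values are placeholders)
sample = np.zeros(8, dtype=np.float16)
print(sample.shape)                      # (8,)

# -1 tells NumPy to infer the leading (batch) dimension: 8 / 8 = 1
batch_of_one = sample.reshape(-1, 8)
print(batch_of_one.shape)                # (1, 8)

# The same call on 3 samples' worth of values yields a batch of 3
three_samples = np.zeros(3 * 8, dtype=np.float16).reshape(-1, 8)
print(three_samples.shape)               # (3, 8)
```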
The file_format='ONNX' option of mindspore.export is deprecated. To export our model in ONNX format, we will use mindspore.onnx.export instead, with the following parameters:
- net: our trained model
- *inputs: sample inputs to our model, as with mindspore.export
- file_name: filename of our exported ONNX model. Unlike with mindspore.export, the .onnx file extension must be specified explicitly
- input_names: list of names for our model inputs, e.g. ['features']
- output_names: list of names for our model outputs, e.g. ['labels']
# Export our model in ONNX format to `california-housing-linear-simple.onnx`
mindspore.onnx.export(model,
mindspore.Tensor(X_test_scaled_sample.astype(np.float16).reshape(-1, 8)),
file_name='california-housing-linear-simple.onnx',
input_names=['features'],
output_names=['labels'])
# Verify the file exists
os.path.isfile('california-housing-linear-simple.onnx')
True
With our model exported in both formats, we can upload them to object storage or a dedicated hub for sharing machine learning models such as Hugging Face. This is outside the scope of this article.
In this interactive notebook article, we saw how to train a simple linear regression model in MindSpore to predict house prices in California.
This experiment has only scratched the surface of what’s possible with MindSpore and the Ascend ecosystem. I hope you enjoyed this article and stay tuned for updates! ;-)