{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mixture of experts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we approximate a piecewise continuous function with a mixture of expert metamodels.\n",
"\n",
"The metamodels form the family $f_k$, $\\forall k \\in [1, N]$, and the approximation reads:\n",
"\n",
"$$f(\\underline{x}) = f_1(\\underline{x}) \\quad \\forall \\underline{z} \\in \\text{Class}\\ 1$$\n",
"\n",
"$$\\dots$$\n",
"\n",
"$$f(\\underline{x}) = f_k(\\underline{x}) \\quad \\forall \\underline{z} \\in \\text{Class}\\ k$$\n",
"\n",
"$$\\dots$$\n",
"\n",
"$$f(\\underline{x}) = f_N(\\underline{x}) \\quad \\forall \\underline{z} \\in \\text{Class}\\ N$$\n",
"\n",
"where the $N$ classes are defined by the classifier.\n",
"\n",
"In supervised mode, the classifier partitions the input and output spaces at once by classifying the points:\n",
"\n",
"$$\\underline{z} = (\\underline{x}, f(\\underline{x}))$$\n",
"\n",
"The classifier is a `MixtureClassifier` built on a `Mixture` distribution defined as:\n",
"$$p(\\underline{z}) = \\sum_{k=1}^N w_k p_k(\\underline{z})$$\n",
"\n",
"A point $\\underline{z}$ is assigned to the class $j = \\arg\\max_k \\log w_k p_k(\\underline{z})$, where $\\log w_k p_k(\\underline{z})$ is the grade of $\\underline{z}$ with respect to the class $k$.\n",
"\n",
"In supervised mode, the mixture of experts evaluated at $\\underline{x}$ grades each point $\\underline{z}_k = (\\underline{x}, f_k(\\underline{x}))$ against the class $k$ and returns the value of the expert with the highest grade.\n",
"\n"
]
},
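{
"cell_type": "markdown",
"metadata": {},
"source": [
"The assignment rule above can be sketched in plain NumPy. This is a minimal illustration, not the OpenTURNS implementation: it assumes Gaussian atoms with diagonal covariance (the same two atoms as the mixture defined later in this notebook) and rescales the weights so that they sum to 1. The helpers `log_normal_pdf`, `grades` and `classify` are hypothetical names, and classes are numbered from 0 in the sketch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def log_normal_pdf(z, mean, sigma):\n",
"    # Log-density of a Normal vector with independent components (diagonal covariance)\n",
"    z = np.asarray(z, dtype=float)\n",
"    mean = np.asarray(mean, dtype=float)\n",
"    sigma = np.asarray(sigma, dtype=float)\n",
"    u = (z - mean) / sigma\n",
"    return -0.5 * u.dot(u) - np.sum(np.log(sigma)) - 0.5 * z.size * np.log(2.0 * np.pi)\n",
"\n",
"def grades(z, weights, means, sigmas):\n",
"    # Grade of z for each class k: log(w_k) + log p_k(z), weights rescaled to sum to 1\n",
"    w = np.asarray(weights, dtype=float)\n",
"    w = w / w.sum()\n",
"    return np.array([np.log(wk) + log_normal_pdf(z, m, s)\n",
"                     for wk, m, s in zip(w, means, sigmas)])\n",
"\n",
"def classify(z, weights, means, sigmas):\n",
"    # Assignment rule: index of the largest grade\n",
"    return int(np.argmax(grades(z, weights, means, sigmas)))\n",
"\n",
"# Two classes centred at (-1, -1) and (1, 1), unit standard deviations, equal weights\n",
"means = [[-1.0, -1.0], [1.0, 1.0]]\n",
"sigmas = [[1.0, 1.0], [1.0, 1.0]]\n",
"weights = [1.0, 1.0]\n",
"print(classify([-0.5, -0.5], weights, means, sigmas))  # prints 0\n",
"print(classify([0.5, 1.75], weights, means, sigmas))   # prints 1\n",
"```\n",
"\n",
"With equal weights and identical covariances, only the distance to each centre decides the class, which is why the second point goes to the class centred at (1, 1)."
]
},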
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import openturns as ot\n",
"from matplotlib import pyplot as plt\n",
"from openturns.viewer import View\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dimension = 1\n",
"\n",
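"# Define the piecewise model we want to rebuild:\n",
"#   f(x) = (x + 0.75)**2 - 0.75**2 = x * (x + 1.5)   if x < 0\n",
"#   f(x) = 2 - x**2                                   if x >= 0\n",
"def piecewise(X):\n",
"    # Convert the input Sample to an array of shape (size, 1); np.asarray copies only\n",
"    # when needed, whereas np.array(..., copy=False) fails under NumPy 2 if a copy is required\n",
"    xarray = np.asarray(X)\n",
"    return np.piecewise(xarray, [xarray < 0.0, xarray >= 0.0],\n",
"                        [lambda x: x * (x + 1.5), lambda x: 2.0 - x * x])\n",
"\n",
"f = ot.PythonFunction(1, 1, func_sample=piecewise)"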
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
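"# Build a polynomial chaos metamodel over each segment:\n",
"# Legendre polynomials up to total degree 5, with coefficients\n",
"# estimated by least squares on a Monte Carlo design of 100 points\n",
"degree = 5\n",
"samplingSize = 100\n",
"enumerateFunction = ot.LinearEnumerateFunction(dimension)\n",
"productBasis = ot.OrthogonalProductPolynomialFactory([ot.LegendreFactory()] * dimension, enumerateFunction)\n",
"# Keep the first basis terms, i.e. all polynomials of degree <= 5\n",
"adaptiveStrategy = ot.FixedStrategy(productBasis, enumerateFunction.getStrataCumulatedCardinal(degree))\n",
"projectionStrategy = ot.LeastSquaresStrategy(ot.MonteCarloExperiment(samplingSize))"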
]
},
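{
"cell_type": "markdown",
"metadata": {},
"source": [
"`FixedStrategy` together with `LeastSquaresStrategy` fits the model on a fixed polynomial basis by least squares. The sketch below is a rough stand-alone analogue built on `numpy.polynomial.legendre.legfit`, not the OpenTURNS algorithm: it draws 100 uniform points per segment, maps the segment affinely onto [-1, 1] (the reference interval of the Legendre family) and fits the polynomials of degree 0 to 5. The names `model`, `fit_segment`, `fit1` and `fit2` are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from numpy.polynomial import legendre\n",
"\n",
"def model(x):\n",
"    # The piecewise model of this notebook, for a 1-d NumPy array\n",
"    return np.where(x < 0.0, x * (x + 1.5), 2.0 - x * x)\n",
"\n",
"def fit_segment(a, b, degree=5, sampling_size=100, seed=0):\n",
"    # Monte Carlo design: uniform random points on the segment [a, b)\n",
"    x = np.random.RandomState(seed).uniform(a, b, sampling_size)\n",
"\n",
"    def to_reference(u):\n",
"        # Affine map of [a, b] onto [-1, 1], where Legendre polynomials are orthogonal\n",
"        return (2.0 * np.asarray(u, dtype=float) - (a + b)) / (b - a)\n",
"\n",
"    # Least-squares coefficients on the Legendre polynomials of degree 0..degree\n",
"    coefs = legendre.legfit(to_reference(x), model(x), degree)\n",
"    return lambda u: legendre.legval(to_reference(u), coefs)\n",
"\n",
"fit1 = fit_segment(-1.0, 0.0)\n",
"fit2 = fit_segment(0.0, 1.0)\n",
"grid1 = np.linspace(-1.0, -1e-6, 11)\n",
"grid2 = np.linspace(1e-6, 1.0, 11)\n",
"print(np.max(np.abs(fit1(grid1) - model(grid1))))\n",
"print(np.max(np.abs(fit2(grid2) - model(grid2))))\n",
"```\n",
"\n",
"Each branch of the model is quadratic, so a degree-5 basis contains it exactly and both printed maximum errors should be at round-off level. The fit degrades only when a segment straddles the discontinuity at 0, which is why the two metamodels are built separately."
]
},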
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"class=Graph name=v0 as a function of x0 title=v0 as a function of x0 xTitle=x0 yTitle=v0 (curve of mm1 on [-1, -1e-06], 129 points; data omitted)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Segment 1: (-1.0, 0.0)\n",
"d1 = ot.Uniform(-1.0, 0.0)\n",
"fc1 = ot.FunctionalChaosAlgorithm(f, d1, adaptiveStrategy, projectionStrategy)\n",
"fc1.run()\n",
"mm1 = fc1.getResult().getMetaModel()\n",
"mm1.draw(-1.0, -1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"class=Graph name=v0 as a function of x0 title=v0 as a function of x0 xTitle=x0 yTitle=v0 (curve of mm2 on [1e-06, 1], 129 points; data omitted)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Segment 2: (0.0, 1.0)\n",
"d2 = ot.Uniform(0.0, 1.0)\n",
"fc2 = ot.FunctionalChaosAlgorithm(f, d2, adaptiveStrategy, projectionStrategy)\n",
"fc2.run()\n",
"mm2 = fc2.getResult().getMetaModel()\n",
"mm2.draw(1e-6, 1.0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Define the mixture: one bivariate Normal atom per segment in the (x, y) plane\n",
"R = ot.CorrelationMatrix(2)  # identity: uncorrelated components\n",
"d1 = ot.Normal([-1.0, -1.0], [1.0] * 2, R)  # segment 1\n",
"d2 = ot.Normal([1.0, 1.0], [1.0] * 2, R)  # segment 2\n",
"weights = [1.0] * 2\n",
"atoms = [d1, d2]\n",
"mixture = ot.Mixture(atoms, weights)"
]
},
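{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the parameters above (unit standard deviations, identity correlation matrix, equal weights), and assuming that `ot.Mixture` normalizes the weights `[1.0, 1.0]` to $1/2$ each, the mixture density over $\\underline{z} = (x, y)$ works out as:\n",
"\n",
"$$p(\\underline{z}) = \\frac{1}{4\\pi}\\exp\\left(-\\frac{(x+1)^2 + (y+1)^2}{2}\\right) + \\frac{1}{4\\pi}\\exp\\left(-\\frac{(x-1)^2 + (y-1)^2}{2}\\right)$$\n",
"\n",
"Both atoms share the same weight and covariance, so the terms $\\log w_k$ and the normalizing constants cancel when two grades are compared: a point $\\underline{z}$ goes to the class whose centre, $(-1, -1)$ or $(1, 1)$, is closer in Euclidean distance. Expanding $(x+1)^2 + (y+1)^2 < (x-1)^2 + (y-1)^2$ gives $x + y < 0$ for class 1, and $x + y > 0$ for class 2."
]
},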
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Create the classifier based on the mixture\n",
"classifier = ot.MixtureClassifier(mixture)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Create local experts using the metamodels\n",
"experts = ot.Basis([mm1, mm2])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Create a mixture of experts\n",
"evaluation = ot.ExpertMixture(experts, classifier)\n",
"moe = ot.Function(evaluation)"
]
},
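{
"cell_type": "markdown",
"metadata": {},
"source": [
"The selection performed by `ExpertMixture` can be mimicked in a few lines of NumPy. This is a hedged sketch, not the library code: it assumes that, for an input $x$, each expert $f_k$ is graded on $\\underline{z}_k = (x, f_k(x))$ against class $k$ and that the value of the best-graded expert is returned. It also replaces the chaos metamodels by the exact branches of the model, and `grade` and `expert_mixture` are hypothetical helpers.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical stand-ins for mm1 and mm2: the exact branches of the model\n",
"experts = [lambda x: x * (x + 1.5), lambda x: 2.0 - x * x]\n",
"# Atoms of the mixture: unit, uncorrelated Normals centred at (-1, -1) and (1, 1)\n",
"means = np.array([[-1.0, -1.0], [1.0, 1.0]])\n",
"weights = np.array([0.5, 0.5])\n",
"\n",
"def grade(z, k):\n",
"    # log(w_k) + log p_k(z) for a bivariate standard Normal shifted to means[k]\n",
"    u = z - means[k]\n",
"    return np.log(weights[k]) - 0.5 * u.dot(u) - np.log(2.0 * np.pi)\n",
"\n",
"def expert_mixture(x):\n",
"    # Grade z_k = (x, f_k(x)) against class k and keep the value of the best expert;\n",
"    # the strict '>' keeps the first expert in case of a tie\n",
"    best_value, best_grade = None, -np.inf\n",
"    for k, expert in enumerate(experts):\n",
"        value = expert(x)\n",
"        g = grade(np.array([x, value]), k)\n",
"        if g > best_grade:\n",
"            best_value, best_grade = value, g\n",
"    return best_value\n",
"\n",
"for x in [-0.5, -0.015625, 0.015625, 0.5]:\n",
"    print(x, expert_mixture(x))\n",
"```\n",
"\n",
"With these choices the printed values are -0.5, about -0.0232, about 1.9998 and 1.75: the output switches from the left branch to the right branch between x = -0.015625 and x = 0.015625, which is the jump at 0 that the mixture of experts is meant to capture."
]
},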
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"class=Graph name=v0 as a function of x0 title=v0 as a function of x0 xTitle=x0 yTitle=v0 (curve of moe on [-1, 1], 129 points; it jumps from about 0 to about 2 just after x0 = 0; data omitted)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Draw the mixture of experts\n",
"moe.draw(-1.0, 1.0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}