It's a little bit of a long one, as I wanted to share my learning process, discovery steps, and the changes I made along the way as they were happening. For a straight-up solution, please see the code at the end.
The second version of the code was my starting point. Over the course of a few days, I went from making simple tuning changes and just rerunning the code to mostly grasping what each part of the code does (not necessarily knowing 100% how it works). As I barely knew the two libraries, I had to do a lot of googling and reading documentation to understand the basics, with occasional ChatGPT use here and there. ChatGPT was especially helpful in explaining certain usage mechanics via simple examples, but more on that later.
The first thing I did was implement the idea from my last comment at the time, made under the 2nd answer here: for the non-overlapping cases, when a weight gets close to zero during solving, we can recompute the machine/product compatibility matrix to let the solver separate the saturated machine into its own group. This lets it arrive at its own, lower average than the rest of its group would normally arrive at. For this, I added a second loop on top of the first one, put the machine group calculation logic at the beginning of the outer loop, and then added an extra exit condition to the inner loop:
```
zero_weight_threshold = 1e-3

while True:
    # Build a graph where machines/nodes are connected (edges) if they share a product
    G = nx.Graph()
    G.add_nodes_from(range(num_machines))
    for n1 in range(num_machines):
        for n2 in range(n1 + 1, num_machines):
            if machine_product_compatibility[n1].mul(machine_product_compatibility[n2]).any():
                G.add_edge(n1, n2)

    # Extract connected components (groups of machines)
    machine_groups = [np.sort(list(group)) for group in nx.connected_components(G)]

    while True:
        # ... solve here using the calculated machine_groups ...

        if ((weights < zero_weight_threshold) & machine_product_compatibility.bool()).any():
            # Recompute the machine/product compatibility using the updated weights
            machine_product_compatibility = (weights >= zero_weight_threshold).float()
            break
```
This was the first breakthrough, as it removed the imbalance caused by one of the machines "dragging down" the average for the whole group, and allowed the loss function to keep decreasing into the <1e-14 range. I had to increase the learning rate so that it wouldn't take forever to detect a weight reaching close to zero, but that seemed to introduce trouble in later iterations. To help with this, I did two things. First off, I noticed that the weights were really taking their time reaching the extremes of 0 and 1, and managed to track it down to the `logits_to_weights` function. The sigmoid that was used to constrain the weights to the (0, 1) range was a logical choice, but something that converges quicker was clearly needed. Fortunately, the topic of so-called "activation functions" is quite familiar to me from learning about neural networks in the past, so I replaced it with an appropriately scaled `tanh` call. Keeping in mind the single-machine groups and the constraint that the weights always sum to 1, I also added a small overscale to help single-machine weights converge quicker, which also seemed to reduce the number of iterations needed:
```
overscale: float = 1.01

def logits_to_weights(logits):
    return (logits.tanh() / 2 + 0.5) * overscale
```
I've also experimented with a dynamic learning rate. Out of all the available options in Torch, this one seemed the most suitable:
```
optimiser = torch.optim.Adam([logits], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, factor=0.95, patience=70)

while True:
    # ...
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    scheduler.step(loss.item())  # calling the scheduler with the current loss value here
    # ...
```
The values of `factor` and `patience`, as well as the learning rate itself, can be tuned to possibly achieve quicker convergence times/iteration counts. The ranges I experimented with were roughly 0.01-0.5 for the learning rate, 0.5-0.99 for the factor, and 10-100 for the patience. 0.95 and 70 for factor and patience seemed to be the sweet spot, with the best average performance over the given examples. The learning rate can usually start at 0.3-0.4 for a fast start, but values that are too high result in loss oscillations; 0.1-0.2 or lower seemed to be the most stable across the given examples and the tests I performed.
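To give a feel for what those numbers mean in practice (a toy illustration, not part of the solver): with `factor=0.95` and `patience=70`, the scheduler waits for about 70 iterations without improvement before multiplying the learning rate by 0.95, so the decay is fairly gentle:

```
import torch

# Dummy parameter/optimiser, just to drive the scheduler
params = [torch.nn.Parameter(torch.zeros(1))]
optimiser = torch.optim.Adam(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, factor=0.95, patience=70)

# Feed a constant (never-improving) loss: the LR drops only once every ~71 steps
for step in range(300):
    scheduler.step(1.0)

print(optimiser.param_groups[0]['lr'])  # ~0.1 * 0.95**4 ≈ 0.0815
```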
This solution seemed ready and was actually working perfectly on all the training examples, so I tried using it on some bigger, custom problems: 11x9 and 17x6 machines/products respectively. The second problem was actually one of the harder ones I had, with the highest average utilization reaching around 102% - I wanted to see what the 2nd loss term (over-utilization) does to the result. The solution I got was correct for the first problem, but incorrect for the second - the utilization percentages ended up slightly off the average. This one boggled me for some time, and after many tests and a lot of observing the results, I reached the conclusion that the 2nd over-utilization loss term was not really needed at all! It turns out that the 3rd loss term, which evenly distributes the utilization, is enough by itself, and the 2nd term was actually competing value-wise with the 3rd term, preventing the solution from converging as expected. Removing the over-utilization term entirely gave me the exact solution I expected to see, and everything worked from here!
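To make that concrete, here is a minimal sketch of the two competing terms. The over-utilisation penalty is written in the usual `relu(utilisation - 1)` form as an assumption (the exact formulation in the original code may differ); the uniformity term is the one kept in the final code:

```
import torch

# Hypothetical values: one group of four machines whose balanced average is ~102%
utilisation_per_mach = torch.tensor([1.04, 1.00, 1.03, 1.01])
machine_groups = [torch.tensor([0, 1, 2, 3])]

# Removed term (assumed form): penalises any machine above 100% utilisation,
# which fights the uniformity term whenever the balanced average is above 100%
over_utilisation_loss = torch.relu(utilisation_per_mach - 1).square().mean()

# Kept term: deviation of each machine from its group's average utilisation
uniformity_loss = torch.stack([
    (utilisation_per_mach[ixs] - utilisation_per_mach[ixs].mean()).square().mean()
    for ixs in machine_groups
]).mean()

loss = uniformity_loss  # dropping the over-utilisation term lets the group settle at its true ~102% average
```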
I was basically ready to post the solution at this point, but out of sheer research curiosity, I stumbled upon posts talking about "parametrization". Digging deeper into the topic revealed that it's apparently a great way to specify a "values always sum up to 1" constraint, and since that's exactly what I was already doing in this code (for the weights), it piqued my interest. This is where ChatGPT proved the most useful, by showing me simple examples and explaining how to specify such a constraint.
This was the second breakthrough. It turns out that the `logits_to_weights` function can be rewritten as follows:
```
from torch.nn.functional import softmax

def logits_to_weights(logits):
    weights = logits.masked_fill(machine_product_compatibility == 0, -np.inf)
    return softmax(weights, dim=0)
```
It turns out the `-inf` values let the softmax function ignore the incompatible entries of the compatibility matrix, letting the remaining compatible values sum up to exactly 1.0, every time, in every case. This removes the 1st weights-sum term from the loss function, leaving only the uniformity one and simplifying everything. The learning rate also had to be reduced to ~0.1, as convergence became so quick that higher values often tended to overshoot. This simple change improved the convergence speed and reduced the required iterations roughly 10-fold, while simultaneously removing the possibility of the weight sums being not quite right and introducing inaccuracy into the results. Fantastic!
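A quick standalone demonstration of the mechanism (toy numbers, not from the actual problem): with `dim=0`, the softmax normalises each column (product) over the machines, the `-inf` entries come out as exactly 0, and every column sums to exactly 1:

```
import torch
from torch.nn.functional import softmax

NEG_INF = float("-inf")

# 3 machines x 2 products; -inf marks incompatible machine/product pairs
logits = torch.tensor([
    [0.5,     NEG_INF],
    [1.2,     0.3],
    [NEG_INF, 0.3],
])

weights = softmax(logits, dim=0)  # normalise over machines, per product (column)
print(weights)                    # incompatible entries are exactly 0.0
print(weights.sum(dim=0))         # tensor([1., 1.]) - each product's weights sum to 1
```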
Some other, less significant changes include:

- `machine_product_compatibility` is now a bool matrix, as this better fits its definition and purpose. All affected places in the code have been adjusted.
- Simplified the `num_products` and `num_machines` definitions, so that they're inferred from the compatibility matrix size.
- The `machine_capacities` calculation now matches more closely the way it was calculated in my Excel sheet (86400 / cycle_time).
- Added `zero_weight_threshold`, which controls when to treat a weight as "zeroed out" during solving (for the compatibility matrix recalculation).
- `logits` now starts as a copy of the compatibility matrix instead of being random, as this translates well to the distribution of the starting weights.
- There's now a single, separate variable that tracks the number of iterations during solving, even if the inner loop breaks for the outer loop to recalculate the compatibility matrix.
- Added `learning_rate` and the groups to the loss graph, to show when the groups are recalculated during solving, as well as how the learning rate changes over time.
- The `masked_fill` of `-inf` values is now called only when the compatibility matrix changes (that's why the respective call in `logits_to_weights` is commented out).
- The plots and prints at the very end have been adjusted to remove information that's obsolete or no longer present.
The final solution code:
```
from collections import defaultdict

import torch
import numpy as np
from torch import nn
import networkx as nx
from matplotlib import pyplot as plt
from torch.nn.functional import softmax

torch.set_printoptions(linewidth=180)
np.set_printoptions(linewidth=180, suppress=True)

# Overlapping example
machine_product_compatibility = torch.tensor([
    # Pa Pb Pc Pd
    [1, 0, 1, 0],  # M1: Pa, Pc
    [0, 1, 0, 1],  # M2: Pb, Pd
    [0, 1, 1, 0],  # M3: Pb, Pc
    [1, 0, 0, 1],  # M4: Pa, Pd
    [1, 1, 0, 0],  # M5: Pa, Pb
    [0, 0, 1, 1],  # M6: Pc, Pd
]).bool()

# Non-overlapping, ex1
# machine_product_compatibility = torch.tensor([
#     # Pa Pb Pc Pd
#     [1, 0, 0, 0],  # M1
#     [1, 0, 1, 0],  # M2
#     [0, 0, 1, 0],  # M3
#     [0, 0, 0, 1],  # M4
#     [0, 1, 0, 1],  # M5
#     [0, 1, 0, 0],  # M6
# ]).bool()

# Non-overlapping, ex2
# machine_product_compatibility = torch.tensor([
#     # Pa Pb Pc Pd
#     [1, 0, 0, 0],  # M1
#     [1, 1, 0, 0],  # M2
#     [0, 1, 0, 0],  # M3
#     [0, 0, 0, 1],  # M4
#     [0, 0, 1, 1],  # M5
#     [0, 0, 1, 0],  # M6
# ]).bool()

#                               Pa    Pb    Pc    Pd
product_totals = torch.tensor([6720, 2880, 3840, 5760]).float() / 10

num_machines = machine_product_compatibility.size(dim=0)

#                            M1  M2  M3  M4  M5  M6
cycle_times = torch.tensor([36, 36, 36, 36, 18, 18]).float()

# Products per day per machine
T_workday = 24 * 60 * 60
machine_capacities = torch.full((num_machines,), T_workday).float().div(cycle_times)


def logits_to_weights(logits):
    # weights = logits.masked_fill(machine_product_compatibility == 0, float('-inf'))
    return softmax(logits, dim=0)


np.random.seed(0)
torch.manual_seed(0)

i = 0
total_epochs = 10_000         # stop after this many epochs
loss_threshold = 1e-19        # stop once loss is lower than this value
zero_weight_threshold = 1e-3  # decides when to zero-out the weight, due to group splitting
results = defaultdict(list)

logits = nn.Parameter(
    machine_product_compatibility.float()
    .masked_fill(machine_product_compatibility == 0, -np.inf)
)
optimiser = torch.optim.Adam([logits], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, factor=0.95, patience=70)

while True:
    # Build a graph where machines/nodes are connected (edges) if they share a product
    G = nx.Graph()
    G.add_nodes_from(range(num_machines))
    for n1 in range(num_machines):
        for n2 in range(n1 + 1, num_machines):
            if (machine_product_compatibility[n1] & machine_product_compatibility[n2]).any():
                G.add_edge(n1, n2)

    # Extract connected components (groups of machines)
    machine_groups = [np.sort(list(group)) for group in nx.connected_components(G)]
    print(f"Num. of machine groups is {len(machine_groups)}:", machine_groups)

    if i > 0:  # skip first iteration
        results["compatibility_changed"].append(i)

    while True:
        # Each weight is bound (0, 1), and only available for compatible pairs
        weights = logits_to_weights(logits)

        # Current utilisation per machine
        utilisation_per_mach = (
            product_totals.mul(weights)  # Ratio of each product passing through the machine
            .sum(dim=1)                  # Total amount of all products passing through the machine
            .div(machine_capacities)     # Express as fraction of each machine's capacity
        )

        # Encourage uniformity, per machine group
        loss = torch.stack([
            # Deviation of machines from group average
            (utilisation_per_mach[mach_ixs] - utilisation_per_mach[mach_ixs].mean())
            .square().mean()                # Scalar MSE
            for mach_ixs in machine_groups  # Repeat for each group of machines
        ]).mean()

        # Optimisation step
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        scheduler.step(loss.item())

        # === Record metrics ===
        results["learning_rate"].append(optimiser.param_groups[0]['lr'])
        results["loss"].append(loss.item())
        i += 1

        # Stop on minimal loss or total epochs reached,
        # or break out to recalculate the compatibility matrix using the current weights
        if loss < loss_threshold:
            break
        elif i >= total_epochs:
            break
        elif ((weights < zero_weight_threshold) & machine_product_compatibility).any():
            # Recompute the machine/product compatibility using the updated weights
            machine_product_compatibility = (weights >= zero_weight_threshold)
            with torch.no_grad():
                logits.masked_fill_(machine_product_compatibility == 0, -np.inf)
            break

    if loss < loss_threshold:
        print("Stopping at near-zero loss:", loss.item())
        break
    elif i >= total_epochs:
        print("Stopping at max epochs:", i)
        break

# === Plot ===
f, ax = plt.subplots(figsize=(5, 2.8))
for name, values in results.items():
    if name == "compatibility_changed":
        continue
    ax.plot(values, label=name)

if results["compatibility_changed"]:
    for xgroup in results["compatibility_changed"]:
        vl = ax.axvline(xgroup, color="black", ymin=0.8)
    vl.set_label("compatibility_changed")

ax.legend(framealpha=0)
ax.spines[["top", "right"]].set_visible(False)
ax.semilogy(True)
ax.grid(axis='y', which="major", linestyle=":")
ax.set(
    title="Loss convergence",
    xlabel="iteration",
    ylabel="total loss"
)

# Visualise machine relationships graph
pos = nx.spring_layout(G, seed=1)
f, ax = plt.subplots(figsize=(5, 3), layout="tight")
nx.draw_networkx_nodes(G, pos, node_color="lightblue", node_size=400, ax=ax)
nx.draw_networkx_edges(G, pos, width=1.5, edge_color="gray", ax=ax)
nx.draw_networkx_labels(G, pos, font_size=10, font_weight="bold", ax=ax)
ax.set_title("Machine relationships via shared products", fontsize=12)
ax.spines[:].set_visible(False)

# Machine groups and their averages
f, ax = plt.subplots(figsize=(5, 3), layout="tight")
for mach_ixs in machine_groups:
    ax.plot(mach_ixs, utilisation_per_mach[mach_ixs].detach(), linestyle="none", marker='d')
    ax.plot(
        mach_ixs,
        [utilisation_per_mach[mach_ixs].detach().mean(),] * len(mach_ixs),
        color="black",
    )
ax.set(
    xlabel="machine index",
    ylabel="utilisation",
    title="Machine utilisation and group averages"
)
ax.spines[["top", "right"]].set_visible(False)

# Product weights (compatible) for each machine
print(
    "\nPercentage product weights for compatible machines (rows) and products (columns):\n",
    weights.detach().numpy().round(7),
    sep=''
)

# Machine time per product
print(
    "\n\nPercent utilisation per machine:\n",
    utilisation_per_mach.detach().numpy().round(5),
    sep=''
)

print("Solution found after", i, "iterations")

plt.show()
```