I have a matrix A of dimension 1000x70000. My loss function includes A, and I want to find the optimal value of A using gradient descent, under the constraint that every row of A stays in the probability simplex (i.e. each row is non-negative and sums to 1). I have initialised A as shown below:
```python
A = np.random.dirichlet(np.ones(70000), 1000)  # each row already sums to 1
A = torch.tensor(A, requires_grad=True)
```

and my training loop looks like this:
```python
for epoch in range(500):
    y_pred = forward(X)
    y = model(torch.mm(A.float(), X))
    l = loss(y, y_pred)
    l.backward()
    A.grad.data = -A.grad.data
    optimizer.step()
    optimizer.zero_grad()
    if epoch % 2 == 0:
        print("Loss", l, "\n")
```
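One standard way to enforce this kind of constraint is projected gradient descent: take the unconstrained step, then project each row back onto the probability simplex. A minimal, self-contained sketch of a row-wise Euclidean simplex projection is below; the function name `project_simplex_rows` is illustrative, not part of the original code, and in the loop above it would be called on `A.data` right after `optimizer.step()`.

```python
import torch

def project_simplex_rows(A: torch.Tensor) -> torch.Tensor:
    """Project each row of A onto the probability simplex
    {w : w >= 0, sum(w) = 1} using the sort-based algorithm
    (O(n log n) per row, fully vectorised over rows)."""
    n = A.shape[1]
    # Sort each row in descending order.
    u, _ = torch.sort(A, dim=1, descending=True)
    # Cumulative sums shifted by the simplex target (1).
    cssv = u.cumsum(dim=1) - 1.0
    k = torch.arange(1, n + 1, device=A.device, dtype=A.dtype)
    # The support of the projection is the longest prefix where u_k * k > cssv_k.
    cond = u * k > cssv
    rho = cond.sum(dim=1, keepdim=True) - 1          # last index in the support
    theta = cssv.gather(1, rho) / (rho + 1).to(A.dtype)
    # Shift by the per-row threshold and clip negatives to zero.
    return torch.clamp(A - theta, min=0.0)
```

With this, the step `A.data = project_simplex_rows(A.data)` after `optimizer.step()` keeps the iterates feasible without interfering with autograd. An alternative that avoids projection entirely is to optimise unconstrained logits and use `A = torch.softmax(logits, dim=1)` inside the forward pass, so the rows are on the simplex by construction.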