I had some time to read up on this and decided to share.
In "serious" ray-tracers that grown-ups use, cones and other similar shapes such as cylinders are usually represented as either tessellated surfaces or as quadric shapes. There exists pure cone-only intersection code that may or may not be more optimized, but implementing the more general quadric approach gives so much more value-for-effort for the beginner that it is worth implementing first. Thus, in this answer I will focus on the quadric solution.
In geometry, a quadric is "any surface that can be defined by an algebraic equation of second degree". In normal form this equation looks like this:
Ax^2 + By^2 + Cz^2 + 2Dxy + 2Exz + 2Fyz + 2Gx + 2Hy + 2Iz + J = 0
By selecting different values for A through J, and by carefully applying axis-aligned clipping planes in the right places, we can define what kind of surface we want. Here are some example shapes (from here):

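To make the coefficient selection concrete, here is a minimal sketch in Java. The class `QuadricCoefficients` and its factory methods are hypothetical names made up for illustration; they are not part of the code copied below. It only evaluates the left-hand side of the equation, so a result of 0 means the point lies on the surface.

```java
// Illustrative sketch only: QuadricCoefficients is a hypothetical helper,
// not part of the ray-tracer code quoted in this answer.
final class QuadricCoefficients {
    final double A, B, C, D, E, F, G, H, I, J;

    QuadricCoefficients(double a, double b, double c, double d, double e,
                        double f, double g, double h, double i, double j) {
        A = a; B = b; C = c; D = d; E = e;
        F = f; G = g; H = h; I = i; J = j;
    }

    /** Left-hand side of the quadric equation; 0 means (x, y, z) is on the surface. */
    double evaluate(double x, double y, double z) {
        return A * x * x + B * y * y + C * z * z
             + 2 * D * x * y + 2 * E * x * z + 2 * F * y * z
             + 2 * G * x + 2 * H * y + 2 * I * z + J;
    }

    /** Unit sphere: x^2 + y^2 + z^2 - 1 = 0. */
    static QuadricCoefficients unitSphere() {
        return new QuadricCoefficients(1, 1, 1, 0, 0, 0, 0, 0, 0, -1);
    }

    /** Infinite cylinder of radius 1 around the y axis: x^2 + z^2 - 1 = 0. */
    static QuadricCoefficients unitCylinderY() {
        return new QuadricCoefficients(1, 0, 1, 0, 0, 0, 0, 0, 0, -1);
    }

    /** Infinite double cone around the y axis, apex at the origin: x^2 - y^2 + z^2 = 0. */
    static QuadricCoefficients coneY() {
        return new QuadricCoefficients(1, -1, 1, 0, 0, 0, 0, 0, 0, 0);
    }
}
```

Note that the cylinder and the cone are infinite as written. Clipping the cone to, say, -1 <= y <= 0 is what turns the double cone into the single finite cone you usually want.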
Finally to the code. Since my ray-tracer was based on the ubiquitous instructional ray-tracer by Leonard McMillan I found a plethora of student extensions more or less based on this code. One in particular by a guy called Charlie had an implementation of Ray-Quadric intersection and shading code. I have shamelessly copied the important parts here (full code here), and will comment on it as soon as I get a chance to actually test it in my own application.

    class Quadric extends Renderable {
        private static final float twoPI = (float)(Math.PI * 2.0);

        float A, B, C, D, E, F, G, H, I, J;
        float A2, B2, C2;
        float miny, maxy;
        Matrix3D OStoWS, WStoOS;
        float theta;

        public Quadric(Surface s, Matrix3D o2w, Matrix3D w2o,
                       float zmin, float zmax, float thetamax,
                       float a, float b, float c, float d, float e,
                       float f, float g, float h, float i, float j) {
            surface = s;
            A = a; B = b; C = c; D = d; E = e;
            F = f; G = g; H = h; I = i; J = j;
            A2 = 2.0f * A;
            B2 = 2.0f * B;
            C2 = 2.0f * C;
            miny = -zmin - Constants.EPSILON;
            maxy = zmax + Constants.EPSILON;
            OStoWS = new Matrix3D(o2w);
            WStoOS = new Matrix3D(w2o);
            theta = (float)(thetamax * Math.PI / 180.0f - Math.PI);
        }

        // scratch state reused between calls
        Vector3D d = new Vector3D();
        Vector3D o = new Vector3D();
        float a2_1;
        float a, b, c, tOS, tWS, discrim;
        float dMagOS;

        public boolean intersect(Ray ray) {
            // transform the ray from world space (WS) to object space (OS)
            WStoOS.transform(ray.origin, o);
            WStoOS.transformNormal(ray.direction, d);
            dMagOS = 1.0f / d.length();
            d.normalize();
            boolean flag = false;

            // quadratic a*t^2 + b*t + c = 0 for the ray o + t*d;
            // note that only the A, B, C and J terms are used, D through I are ignored
            a = A * d.x * d.x + B * d.y * d.y + C * d.z * d.z;
            b = A2 * o.x * d.x + B2 * o.y * d.y + C2 * o.z * d.z;
            c = A * o.x * o.x + B * o.y * o.y + C * o.z * o.z + J;

            if (a == 0.0f) {
                tOS = - c / b;
                flag = true;                          // remember this
            } else {
                discrim = b * b - 4.0f * a * c;       // compute discriminant
                if (discrim < 0.0f) return false;     // no intersection
                discrim = (float)Math.sqrt(discrim);  // store sqrt value
                // fix: when a < 0 (possible for cones) flip the sign so that
                // (-b - discrim) * a2_1 is still the smaller root
                if (a < 0.0f) discrim = -discrim;
                a2_1 = 1.0f / (2.0f * a);             // store 1 / (2a)
                tOS = (-b - discrim) * a2_1;          // near intersection
                if (tOS < 0.0f) {                     // near intersection too close
                    tOS = (-b + discrim) * a2_1;      // use far intersection
                    flag = true;                      // remember this
                }
            }

            tWS = tOS * dMagOS;  // need to scale intersection tOS to get tWS
            if ((tWS > ray.tWS) || (tWS < 0)) return false;  // trivial reject
            ray.hitOS.addScaled(o, d, tOS);  // get hit point in object space

            if (ray.hitOS.y < miny || ray.hitOS.y > maxy) {  // outside clip box?
                if (!flag) {
                    // alright, so the near intersection is outside the clipping box,
                    // but that doesn't mean the far intersection isn't inside
                    tOS = (-b + discrim) * a2_1;  // compute the far intersection
                    tWS = tOS * dMagOS;           // need to scale intersection t
                    if ((tWS > ray.tWS) || (tWS < 0)) return false;  // trivial reject
                    ray.hitOS.addScaled(o, d, tOS);
                    if (ray.hitOS.y < miny || ray.hitOS.y > maxy) return false;
                } else {
                    // both near and far intersection flunked clip box
                    return false;
                }
            }

            if (theta != twoPI) {
                // now clip against the accept angle
                float a = (float)Math.atan2(-ray.hitOS.z, -ray.hitOS.x);
                if (a > theta) {
                    if (!flag) {
                        // alright, so the near intersection is outside the accept angle,
                        // but that doesn't mean the far intersection isn't inside
                        tOS = (-b + discrim) * a2_1;  // compute the far intersection
                        tWS = tOS * dMagOS;           // need to scale intersection t
                        if ((tWS > ray.tWS) || (tWS < 0)) return false;  // trivial reject
                        ray.hitOS.addScaled(o, d, tOS);  // update intersection point
                        a = (float)Math.atan2(-ray.hitOS.z, -ray.hitOS.x);
                        if (a > theta) return false;
                        // we haven't actually tested this yet at this point
                        if (ray.hitOS.y < miny || ray.hitOS.y > maxy) return false;
                    } else {
                        // both near and far intersection flunked accept angle
                        return false;
                    }
                }
            }

            ray.tWS = tWS;
            ray.object = this;
            return true;
        }

        public void computeNormal(Ray ray) {
            // gradient of A*x^2 + B*y^2 + C*z^2 + J at the hit point
            Vector3D nOS = new Vector3D(
                A2 * ray.hitOS.x,
                B2 * ray.hitOS.y,
                C2 * ray.hitOS.z);
            OStoWS.transformNormal(nOS, ray.n);
            ray.n.normalize();
        }
    }
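Since the class above depends on the rest of the ray-tracer (`Ray`, `Matrix3D`, `Vector3D` and so on), here is a stripped-down, self-contained sketch of the same idea that can be run on its own. It is my own simplification under stated assumptions, not Charlie's code: the name `QuadricIntersectSketch` is hypothetical, there are no object/world transforms and no angle clipping, and only the A, B, C and J terms are handled. Substituting the ray `o + t*d` into `Ax^2 + By^2 + Cz^2 + J = 0` gives a quadratic `a*t^2 + b*t + c = 0` whose roots are the candidate hits.

```java
// Illustrative sketch only: a hypothetical, self-contained variant of the
// ray/quadric test, without transforms or angle clipping.
final class QuadricIntersectSketch {
    /**
     * Returns the smallest t >= 0 at which o + t*d hits the surface
     * A*x^2 + B*y^2 + C*z^2 + J = 0 with minY <= y <= maxY,
     * or NaN if there is no such hit. o and d are {x, y, z} arrays.
     */
    static double intersect(double[] o, double[] d,
                            double A, double B, double C, double J,
                            double minY, double maxY) {
        // coefficients of a*t^2 + b*t + c = 0
        double a = A * d[0] * d[0] + B * d[1] * d[1] + C * d[2] * d[2];
        double b = 2 * (A * o[0] * d[0] + B * o[1] * d[1] + C * o[2] * d[2]);
        double c = A * o[0] * o[0] + B * o[1] * o[1] + C * o[2] * o[2] + J;

        double[] roots;
        if (a == 0.0) {
            if (b == 0.0) return Double.NaN;       // degenerate: no single solution
            roots = new double[] { -c / b };       // equation is linear in t
        } else {
            double disc = b * b - 4 * a * c;
            if (disc < 0) return Double.NaN;       // ray misses the surface
            double s = Math.sqrt(disc);
            double t0 = (-b - s) / (2 * a);
            double t1 = (-b + s) / (2 * a);
            // a can be negative (e.g. for a cone), so order the roots explicitly
            roots = new double[] { Math.min(t0, t1), Math.max(t0, t1) };
        }

        for (double t : roots) {
            if (t < 0) continue;                   // behind the ray origin
            double y = o[1] + t * d[1];
            if (y >= minY && y <= maxY) return t;  // first hit inside the clip range
        }
        return Double.NaN;
    }
}
```

For example, a ray starting at (0.5, 5, 0) and pointing straight down at the cone `x^2 - y^2 + z^2 = 0` hits the upper half at t = 4.5 and the lower half at t = 5.5. With the clip range set to -1 <= y <= 0, only the second hit survives, which is the "far intersection may still be inside" case the original code handles with its `flag` variable.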